Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src-tauri/src/managers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod audio;
pub mod history;
pub mod model;
pub mod transcription;
mod transcription_recovery;
105 changes: 99 additions & 6 deletions src-tauri/src/managers/transcription.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
use super::transcription_recovery::{
transcribe_with_chunk_retry, ChunkRetryPolicy, ChunkedTranscriptionResult,
ChunkedTranscriptionSegment,
};
use crate::audio_toolkit::{apply_custom_words, filter_transcription_output};
use crate::managers::audio::AudioRecordingManager;
use crate::managers::model::{EngineType, ModelManager};
Expand All @@ -24,7 +28,7 @@ use transcribe_rs::{
Quantization,
},
whisper_cpp::{WhisperEngine, WhisperInferenceParams},
SpeechModel, TranscribeOptions,
SpeechModel, TranscribeError, TranscribeOptions,
};

#[derive(Clone, Debug, Serialize)]
Expand All @@ -45,6 +49,65 @@ enum LoadedEngine {
Canary(CanaryModel),
}

const PARAKEET_RETRYABLE_BROADCAST_MARKERS: [&str; 2] = [
"broadcastiterator::init",
"attempting to broadcast an axis by a dimension other than 1",
];
const PARAKEET_CHUNK_RETRY_POLICY: ChunkRetryPolicy<TranscribeError> = ChunkRetryPolicy {
label: "Parakeet",
sample_rate_hz: crate::audio_toolkit::constants::WHISPER_SAMPLE_RATE as usize,
max_split_depth: 9,
min_chunk_samples: crate::audio_toolkit::constants::WHISPER_SAMPLE_RATE as usize * 30,
split_padding_samples: (crate::audio_toolkit::constants::WHISPER_SAMPLE_RATE as usize * 3) / 4,
max_merge_word_overlap: 12,
should_retry: should_retry_parakeet_chunk_error,
};

fn should_retry_parakeet_chunk_error(error: &TranscribeError) -> bool {
let TranscribeError::Inference(message) = error else {
return false;
};

let normalized = message.to_lowercase();
PARAKEET_RETRYABLE_BROADCAST_MARKERS
.iter()
.all(|marker| normalized.contains(marker))
}

fn to_chunked_transcription_result(
result: transcribe_rs::TranscriptionResult,
) -> ChunkedTranscriptionResult {
ChunkedTranscriptionResult {
text: result.text,
segments: result.segments.map(|segments| {
segments
.into_iter()
.map(|segment| ChunkedTranscriptionSegment {
start: segment.start,
end: segment.end,
text: segment.text,
})
.collect()
}),
}
}

fn to_transcribe_result(result: ChunkedTranscriptionResult) -> transcribe_rs::TranscriptionResult {
transcribe_rs::TranscriptionResult {
text: result.text,
segments: result.segments.map(|segments| {
segments
.into_iter()
.map(|segment| transcribe_rs::TranscriptionSegment {
start: segment.start,
end: segment.end,
text: segment.text,
})
.collect()
}),
}
}

/// RAII guard that clears the `is_loading` flag and notifies waiters on drop.
/// Ensures the loading flag is always reset, even on early returns or panics.
pub struct LoadingGuard {
Expand Down Expand Up @@ -550,11 +613,19 @@ impl TranscriptionManager {
timestamp_granularity: Some(TimestampGranularity::Segment),
..Default::default()
};
parakeet_engine
.transcribe_with(&audio, &params)
.map_err(|e| {
anyhow::anyhow!("Parakeet transcription failed: {}", e)
})
let mut transcribe_chunk = |chunk: &[f32]| {
parakeet_engine
.transcribe_with(chunk, &params)
.map(to_chunked_transcription_result)
};

transcribe_with_chunk_retry(
&audio,
&PARAKEET_CHUNK_RETRY_POLICY,
&mut transcribe_chunk,
)
.map(to_transcribe_result)
.map_err(|e| anyhow::anyhow!("Parakeet transcription failed: {}", e))
}
LoadedEngine::Moonshine(moonshine_engine) => moonshine_engine
.transcribe(&audio, &TranscribeOptions::default())
Expand Down Expand Up @@ -704,6 +775,28 @@ impl TranscriptionManager {
}
}

#[cfg(test)]
mod tests {
use super::should_retry_parakeet_chunk_error;
use transcribe_rs::TranscribeError;

#[test]
fn should_retry_parakeet_chunk_error_matches_broadcast_failure() {
let error = TranscribeError::Inference("Non-zero status code returned while running Add node. Status Message: BroadcastIterator::Init axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 9545 by 14545".to_string());

assert!(should_retry_parakeet_chunk_error(&error));
}

#[test]
fn should_retry_parakeet_chunk_error_ignores_generic_ort_errors() {
let error = TranscribeError::Inference(
"ORT error: failed to allocate memory for inference".to_string(),
);

assert!(!should_retry_parakeet_chunk_error(&error));
}
}

/// Apply the user's accelerator preferences to the transcribe-rs global atomics.
/// Called on startup and whenever the user changes the setting.
pub fn apply_accelerator_settings(app: &tauri::AppHandle) {
Expand Down
Loading
Loading