diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4598eba6c..139cc5c26 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,8 +31,10 @@ jobs: - name: Use mock TranscriptionManager (CI only) working-directory: src-tauri run: | - # Swap to mock adapter - avoids compiling whisper/Vulkan + # Swap to mock adapters - avoids compiling whisper/Vulkan cp src/managers/transcription_mock.rs src/managers/transcription.rs + cp src/managers/streaming_mock.rs src/managers/streaming.rs + cp src/audio_toolkit/vad/adapter_mock.rs src/audio_toolkit/vad/adapter.rs sed -i '/^transcribe-rs/d' Cargo.toml - name: Run Rust tests diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 1765bf9f3..468078355 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -1696,16 +1696,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" -[[package]] -name = "eyre" -version = "0.6.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cd915d99f24784cdc19fd37ef22b97e3ff0ae756c7e492e9fbfe897d61e2aec" -dependencies = [ - "indenter", - "once_cell", -] - [[package]] name = "fallible-iterator" version = "0.3.0" @@ -2467,7 +2457,6 @@ dependencies = [ "tempfile", "tokio", "transcribe-rs", - "vad-rs", "windows 0.61.3", "winreg 0.55.0", ] @@ -2864,12 +2853,6 @@ dependencies = [ "tiff", ] -[[package]] -name = "indenter" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "964de6e86d545b246d84badc0fef527924ace5134f30641c203ef52ba83f58d5" - [[package]] name = "indexmap" version = "1.9.3" @@ -4980,12 +4963,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "ringbuffer" -version = "0.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3df6368f71f205ff9c33c076d170dd56ebf68e8161c733c0caa07a7a5509ed53" - [[package]] name = "rkyv" version = "0.7.46" @@ -6992,9 +6969,9 @@ dependencies = [ [[package]] name = "transcribe-rs" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af54b24283d1548883a79f258c75c0e02938f5d66073b84b99dcaf00cb06f7" +checksum = "e38caeddd0be4528f3ed21e57eb8baba38b5a4d8b821b97d49894d0adbb0420e" dependencies = [ "base64 0.22.1", "derive_builder", @@ -7246,17 +7223,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "vad-rs" -version = "0.1.6" -source = "git+https://github.com/cjpais/vad-rs#2a412ed858695b9251f3f5a1a20d95b59fa7c498" -dependencies = [ - "eyre", - "ndarray", - "ort", - "ringbuffer", -] - [[package]] name = "value-bag" version = "1.12.0" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index e6afa5e7a..8bdb19dc3 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -55,7 +55,6 @@ hound = "3.5.1" log = "0.4.25" env_filter = "0.1.0" tokio = "1.43.0" -vad-rs = { git = "https://github.com/cjpais/vad-rs", default-features = false } enigo = "0.6.1" rodio = { git = "https://github.com/cjpais/rodio.git" } reqwest = { version = "0.12", features = ["json", "stream"] } @@ -69,7 +68,7 @@ rusqlite = { version = "0.37", features = ["bundled"] } tar = "0.4.44" flate2 = "1.0" sha2 = "0.10" -transcribe-rs = { version = "0.3.3", features = ["whisper-cpp", "onnx"] } +transcribe-rs = { version = "0.3.4", features = ["whisper-cpp", "onnx", "vad-silero"] } handy-keys = "0.2.4" ferrous-opencc = "0.2.3" clap = { version = "4", features = ["derive"] } @@ -88,7 +87,7 @@ tauri-plugin-single-instance = "2.3.2" tauri-plugin-updater = "2.10.0" [target.'cfg(windows)'.dependencies] -transcribe-rs = { version = "0.3.3", features = ["whisper-vulkan", "ort-directml"] } +transcribe-rs = { version = "0.3.4", features = ["whisper-vulkan", "ort-directml", "vad-silero"] } windows = { version = "0.61.3", features = [ "Win32_Media_Audio_Endpoints", "Win32_System_Com_StructuredStorage", @@ -100,12 +99,12 @@ winreg = "0.55" [target.'cfg(target_os = "macos")'.dependencies] tauri-nspanel = { git = "https://github.com/ahkohd/tauri-nspanel", branch = "v2.1" } -transcribe-rs = { version = "0.3.3", features = ["whisper-metal"] } +transcribe-rs = { version = "0.3.4", features = ["whisper-metal", "vad-silero"] } [target.'cfg(target_os = "linux")'.dependencies] gtk-layer-shell = { version = "0.8", features = ["v0_6"] } gtk = "0.18" -transcribe-rs = { version = "0.3.3", features = ["whisper-vulkan"] } +transcribe-rs = { version = "0.3.4", features = ["whisper-vulkan", "vad-silero"] } [patch.crates-io] tauri-runtime = { git = "https://github.com/cjpais/tauri.git", branch = "handy-2.10.2" } diff --git a/src-tauri/src/actions.rs b/src-tauri/src/actions.rs index 4a738dfde..6b2cef76e 100644 --- a/src-tauri/src/actions.rs +++ b/src-tauri/src/actions.rs @@ -4,8 +4,11 @@ use crate::audio_feedback::{play_feedback_sound, play_feedback_sound_blocking, S use crate::audio_toolkit::{is_microphone_access_denied, is_no_input_device_error}; use crate::managers::audio::AudioRecordingManager; use crate::managers::history::HistoryManager; +use crate::managers::streaming::StreamingSession; use crate::managers::transcription::TranscriptionManager; -use crate::settings::{get_settings, AppSettings, APPLE_INTELLIGENCE_PROVIDER_ID}; +use crate::settings::{ + get_settings, AppSettings, TranscriptionMode, APPLE_INTELLIGENCE_PROVIDER_ID, +}; use crate::shortcut; use crate::tray::{change_tray_icon, TrayIconState}; use crate::utils::{ @@ -16,7 +19,7 @@ use ferrous_opencc::{config::BuiltinConfig, OpenCC}; use log::{debug, error, warn}; use once_cell::sync::Lazy; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Instant; use tauri::Manager; use tauri::{AppHandle, Emitter}; @@ -38,6 +41,9 @@ impl Drop for FinishGuard { } } +/// Active streaming session (if any). Keeps state isolated per recording. +static STREAMING_SESSION: Lazy>> = Lazy::new(|| Mutex::new(None)); + // Shortcut Action Trait pub trait ShortcutAction: Send + Sync { fn start(&self, app: &AppHandle, binding_id: &str, shortcut_str: &str); @@ -361,6 +367,86 @@ pub(crate) async fn process_transcription_output( } } +async fn save_wav_and_history( + hm: &Arc, + audio: &[f32], + transcription: &str, + post_process: bool, + post_processed_text: Option, + post_process_prompt: Option, +) { + if audio.is_empty() { + return; + } + let file_name = format!("handy-{}.wav", chrono::Utc::now().timestamp()); + let wav_path = hm.recordings_dir().join(&file_name); + let samples = audio.to_vec(); + let sample_count = samples.len(); + let wav_path_for_verify = wav_path.clone(); + + let wav_saved = match tauri::async_runtime::spawn_blocking(move || { + crate::audio_toolkit::save_wav_file(&wav_path, &samples) + }) + .await + { + Ok(Ok(())) => crate::audio_toolkit::verify_wav_file(&wav_path_for_verify, sample_count) + .map(|_| true) + .unwrap_or_else(|e| { + error!("WAV verification failed: {}", e); + false + }), + Ok(Err(e)) => { + error!("Failed to save WAV: {}", e); + false + } + Err(e) => { + error!("WAV save panicked: {}", e); + false + } + }; + + if wav_saved { + if let Err(err) = hm.save_entry( + file_name, + transcription.to_string(), + post_process, + post_processed_text, + post_process_prompt, + ) { + error!("Failed to save history entry: {}", err); + } + } +} + +fn start_streaming_session( + app: &AppHandle, + mode: TranscriptionMode, + tm: Arc, + chunk_rx: std::sync::mpsc::Receiver>, + settings: &AppSettings, +) { + let vad_path = app + .path() + .resolve( + "resources/models/silero_vad_v4.onnx", + tauri::path::BaseDirectory::Resource, + ) + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let session = StreamingSession::start( + mode, + tm, + app.clone(), + chunk_rx, + vad_path, + settings.realtime_chunk_duration_secs, + ); + *STREAMING_SESSION.lock().unwrap() = Some(session); + debug!("Streaming session started (mode: {:?})", mode); +} + impl ShortcutAction for TranscribeAction { fn start(&self, app: &AppHandle, binding_id: &str, _shortcut_str: &str) { let start_time = Instant::now(); @@ -381,39 +467,78 @@ impl ShortcutAction for TranscribeAction { let is_always_on = settings.always_on_microphone; debug!("Microphone mode - always_on: {}", is_always_on); + let transcription_mode = settings.transcription_mode; + let is_streaming = !matches!(transcription_mode, TranscriptionMode::Standard); + let mut recording_error: Option = None; + + // Start recording — streaming modes use start_streaming, standard uses start + if is_streaming { + // Clear any previous session + *STREAMING_SESSION.lock().unwrap() = None; + } + if is_always_on { // Always-on mode: Play audio feedback immediately, then apply mute after sound finishes debug!("Always-on mode: Playing audio feedback immediately"); let rm_clone = Arc::clone(&rm); let app_clone = app.clone(); - // The blocking helper exits immediately if audio feedback is disabled, - // so we can always reuse this thread to ensure mute happens right after playback. std::thread::spawn(move || { play_feedback_sound_blocking(&app_clone, SoundType::Start); rm_clone.apply_mute(); }); - if let Err(e) = rm.try_start_recording(&binding_id) { + if is_streaming { + match rm.try_start_streaming_recording(&binding_id) { + Ok(chunk_rx) => { + start_streaming_session( + app, + transcription_mode, + Arc::clone(&app.state::>()), + chunk_rx, + &settings, + ); + } + Err(e) => { + debug!("Streaming recording failed: {}", e); + recording_error = Some(e); + } + } + } else if let Err(e) = rm.try_start_recording(&binding_id) { debug!("Recording failed: {}", e); recording_error = Some(e); } } else { // On-demand mode: Start recording first, then play audio feedback, then apply mute - // This allows the microphone to be activated before playing the sound debug!("On-demand mode: Starting recording first, then audio feedback"); let recording_start_time = Instant::now(); - match rm.try_start_recording(&binding_id) { + + let started = if is_streaming { + match rm.try_start_streaming_recording(&binding_id) { + Ok(chunk_rx) => { + start_streaming_session( + app, + transcription_mode, + Arc::clone(&app.state::>()), + chunk_rx, + &settings, + ); + Ok(()) + } + Err(e) => Err(e), + } + } else { + rm.try_start_recording(&binding_id) + }; + + match started { Ok(()) => { debug!("Recording started in {:?}", recording_start_time.elapsed()); - // Small delay to ensure microphone stream is active let app_clone = app.clone(); let rm_clone = Arc::clone(&rm); std::thread::spawn(move || { std::thread::sleep(std::time::Duration::from_millis(100)); debug!("Handling delayed audio feedback/mute sequence"); - // Helper handles disabled audio feedback by returning early, so we reuse it - // to keep mute sequencing consistent in every mode. play_feedback_sound_blocking(&app_clone, SoundType::Start); rm_clone.apply_mute(); }); @@ -458,7 +583,6 @@ impl ShortcutAction for TranscribeAction { } fn stop(&self, app: &AppHandle, binding_id: &str, _shortcut_str: &str) { - // Unregister the cancel shortcut when transcription stops shortcut::unregister_cancel_shortcut(app); let stop_time = Instant::now(); @@ -469,16 +593,16 @@ impl ShortcutAction for TranscribeAction { let tm = Arc::clone(&app.state::>()); let hm = Arc::clone(&app.state::>()); + let settings = get_settings(app); + let transcription_mode = settings.transcription_mode; + change_tray_icon(app, TrayIconState::Transcribing); show_transcribing_overlay(app); - // Unmute before playing audio feedback so the stop sound is audible rm.remove_mute(); - - // Play audio feedback for recording stop play_feedback_sound(app, SoundType::Stop); - let binding_id = binding_id.to_string(); // Clone binding_id for the async task + let binding_id = binding_id.to_string(); let post_process = self.post_process; tauri::async_runtime::spawn(async move { @@ -488,19 +612,100 @@ impl ShortcutAction for TranscribeAction { binding_id ); + // Stop recording — this drops chunk_tx in streaming mode, + // causing the streaming worker's channel to close. let stop_recording_time = Instant::now(); - if let Some(samples) = rm.stop_recording(&binding_id) { - debug!( - "Recording stopped and samples retrieved in {:?}, sample count: {}", - stop_recording_time.elapsed(), - samples.len() - ); + let samples = rm.stop_recording(&binding_id); + debug!("Recording stopped in {:?}", stop_recording_time.elapsed()); + + if !matches!(transcription_mode, TranscriptionMode::Standard) { + // ── Streaming modes ── + let session = STREAMING_SESSION.lock().unwrap().take(); + if let Some(session) = session { + let result = session.finish(); + let transcription = result.combined_text; + debug!( + "Streaming finished: '{}' ({} audio samples)", + transcription, + result.audio.len() + ); + + if transcription.is_empty() { + utils::hide_recording_overlay(&ah); + change_tray_icon(&ah, TrayIconState::Idle); + return; + } + + match transcription_mode { + TranscriptionMode::Stream | TranscriptionMode::Realtime => { + // Text was already pasted live. Just save history + clean up. + save_wav_and_history( + &hm, + &result.audio, + &transcription, + false, + None, + None, + ) + .await; + utils::hide_recording_overlay(&ah); + change_tray_icon(&ah, TrayIconState::Idle); + } + TranscriptionMode::BatchStream => { + // Run post-processing on combined text, then paste once. + if post_process { + show_processing_overlay(&ah); + } + let processed = + process_transcription_output(&ah, &transcription, post_process) + .await; + save_wav_and_history( + &hm, + &result.audio, + &transcription, + post_process, + processed.post_processed_text.clone(), + processed.post_process_prompt.clone(), + ) + .await; - if samples.is_empty() { - debug!("Recording produced no audio samples; skipping persistence"); + if processed.final_text.is_empty() { + utils::hide_recording_overlay(&ah); + change_tray_icon(&ah, TrayIconState::Idle); + } else { + let ah_clone = ah.clone(); + let final_text = processed.final_text; + ah.run_on_main_thread(move || { + if let Err(e) = utils::paste(final_text, ah_clone.clone()) { + error!("Failed to paste transcription: {}", e); + } + utils::hide_recording_overlay(&ah_clone); + change_tray_icon(&ah_clone, TrayIconState::Idle); + }) + .unwrap_or_else(|e| { + error!("Failed to run paste on main thread: {:?}", e); + utils::hide_recording_overlay(&ah); + change_tray_icon(&ah, TrayIconState::Idle); + }); + } + } + TranscriptionMode::Standard => unreachable!(), + } + } else { + debug!("No streaming session found"); utils::hide_recording_overlay(&ah); change_tray_icon(&ah, TrayIconState::Idle); - } else { + } + } else { + // ── Standard mode (unchanged) ── + if let Some(samples) = samples { + if samples.is_empty() { + debug!("Recording produced no audio samples"); + utils::hide_recording_overlay(&ah); + change_tray_icon(&ah, TrayIconState::Idle); + return; + } + // Save WAV concurrently with transcription let sample_count = samples.len(); let file_name = format!("handy-{}.wav", chrono::Utc::now().timestamp()); @@ -511,30 +716,25 @@ impl ShortcutAction for TranscribeAction { crate::audio_toolkit::save_wav_file(&wav_path, &samples_for_wav) }); - // Transcribe concurrently with WAV save let transcription_time = Instant::now(); let transcription_result = tm.transcribe(samples); - // Await WAV save and verify let wav_saved = match wav_handle.await { - Ok(Ok(())) => { - match crate::audio_toolkit::verify_wav_file( - &wav_path_for_verify, - sample_count, - ) { - Ok(()) => true, - Err(e) => { - error!("WAV verification failed: {}", e); - false - } - } - } + Ok(Ok(())) => crate::audio_toolkit::verify_wav_file( + &wav_path_for_verify, + sample_count, + ) + .map(|_| true) + .unwrap_or_else(|e| { + error!("WAV verification failed: {}", e); + false + }), Ok(Err(e)) => { - error!("Failed to save WAV file: {}", e); + error!("Failed to save WAV: {}", e); false } Err(e) => { - error!("WAV save task panicked: {}", e); + error!("WAV save panicked: {}", e); false } }; @@ -554,7 +754,6 @@ impl ShortcutAction for TranscribeAction { process_transcription_output(&ah, &transcription, post_process) .await; - // Save to history if WAV was saved if wav_saved { if let Err(err) = hm.save_entry( file_name, @@ -572,15 +771,10 @@ impl ShortcutAction for TranscribeAction { change_tray_icon(&ah, TrayIconState::Idle); } else { let ah_clone = ah.clone(); - let paste_time = Instant::now(); let final_text = processed.final_text; ah.run_on_main_thread(move || { - match utils::paste(final_text, ah_clone.clone()) { - Ok(()) => debug!( - "Text pasted successfully in {:?}", - paste_time.elapsed() - ), - Err(e) => error!("Failed to paste transcription: {}", e), + if let Err(e) = utils::paste(final_text, ah_clone.clone()) { + error!("Failed to paste transcription: {}", e); } utils::hide_recording_overlay(&ah_clone); change_tray_icon(&ah_clone, TrayIconState::Idle); @@ -593,8 +787,7 @@ impl ShortcutAction for TranscribeAction { } } Err(err) => { - debug!("Global Shortcut Transcription error: {}", err); - // Save entry with empty text so user can retry + debug!("Transcription error: {}", err); if wav_saved { if let Err(save_err) = hm.save_entry( file_name, @@ -610,11 +803,11 @@ impl ShortcutAction for TranscribeAction { change_tray_icon(&ah, TrayIconState::Idle); } } + } else { + debug!("No samples retrieved from recording stop"); + utils::hide_recording_overlay(&ah); + change_tray_icon(&ah, TrayIconState::Idle); } - } else { - debug!("No samples retrieved from recording stop"); - utils::hide_recording_overlay(&ah); - change_tray_icon(&ah, TrayIconState::Idle); } }); diff --git a/src-tauri/src/audio_toolkit/audio/recorder.rs b/src-tauri/src/audio_toolkit/audio/recorder.rs index ef94a9836..58a13b6e9 100644 --- a/src-tauri/src/audio_toolkit/audio/recorder.rs +++ b/src-tauri/src/audio_toolkit/audio/recorder.rs @@ -21,6 +21,7 @@ use crate::audio_toolkit::{ enum Cmd { Start, + StartStreaming(mpsc::Sender>), Stop(mpsc::Sender>), Shutdown, } @@ -202,6 +203,14 @@ impl AudioRecorder { Ok(()) } + pub fn start_streaming(&self) -> Result>, Box> { + let (tx, rx) = mpsc::channel(); + if let Some(cmd_tx) = &self.cmd_tx { + cmd_tx.send(Cmd::StartStreaming(tx))?; + } + Ok(rx) + } + pub fn stop(&self) -> Result, Box> { let (resp_tx, resp_rx) = mpsc::channel(); if let Some(tx) = &self.cmd_tx { @@ -408,6 +417,7 @@ fn run_consumer( let mut processed_samples = Vec::::new(); let mut recording = false; + let mut chunk_tx: Option>> = None; // ---------- spectrum visualisation setup ---------------------------- // const BUCKETS: usize = 16; @@ -461,7 +471,20 @@ fn run_consumer( // ---------- existing pipeline ------------------------------------ // frame_resampler.push(&raw, &mut |frame: &[f32]| { - handle_frame(frame, recording, &vad, &mut processed_samples) + if !recording { + return; + } + + if let Some(tx) = &chunk_tx { + // Streaming mode: send raw resampled frames to the streaming session. + // VAD + chunking is handled by transcribe-rs in the session. + if tx.send(frame.to_vec()).is_err() { + log::warn!("Streaming channel closed"); + } + } else { + // Standard mode: use Handy's VAD filter + handle_frame(frame, recording, &vad, &mut processed_samples) + } }); // non-blocking check for a command @@ -470,6 +493,17 @@ fn run_consumer( Cmd::Start => { stop_flag.store(false, Ordering::Relaxed); processed_samples.clear(); + chunk_tx = None; + recording = true; + visualizer.reset(); + if let Some(v) = &vad { + v.lock().unwrap().reset(); + } + } + Cmd::StartStreaming(tx) => { + stop_flag.store(false, Ordering::Relaxed); + processed_samples.clear(); + chunk_tx = Some(tx); recording = true; visualizer.reset(); if let Some(v) = &vad { @@ -480,6 +514,10 @@ fn run_consumer( recording = false; stop_flag.store(true, Ordering::Relaxed); + // Drop chunk_tx so the streaming session's worker thread + // exits its chunk_rx.iter() loop. + chunk_tx = None; + // Drain all remaining audio until the producer confirms end-of-stream. // The cpal callback sees the stop flag, sends EndOfStream, and goes // silent — guaranteeing every captured sample is in the channel diff --git a/src-tauri/src/audio_toolkit/bin/cli.rs b/src-tauri/src/audio_toolkit/bin/cli.rs index a99e5a64f..87598d2cc 100644 --- a/src-tauri/src/audio_toolkit/bin/cli.rs +++ b/src-tauri/src/audio_toolkit/bin/cli.rs @@ -176,7 +176,7 @@ fn main() -> Result<(), Box> { print_help(); let silero = SileroVad::new("./resources/models/silero_vad_v4.onnx", 0.5)?; - let smoothed_vad = SmoothedVad::new(Box::new(silero), 15, 15); + let smoothed_vad = SmoothedVad::new(Box::new(silero), 15, 15, 2); let recorder = AudioRecorder::new()?.with_vad(Box::new(smoothed_vad)); let mut state = RecorderState::new(recorder); diff --git a/src-tauri/src/audio_toolkit/vad/adapter.rs b/src-tauri/src/audio_toolkit/vad/adapter.rs new file mode 100644 index 000000000..8651a5018 --- /dev/null +++ b/src-tauri/src/audio_toolkit/vad/adapter.rs @@ -0,0 +1,123 @@ +//! Adapts transcribe-rs VAD types to Handy's VoiceActivityDetector trait. + +use anyhow::Result; +use std::path::Path; +use transcribe_rs::vad::Vad; + +use super::{VadFrame, VoiceActivityDetector}; + +/// Wrapper around `transcribe_rs::vad::SileroVad`. +pub struct SileroVad(transcribe_rs::vad::SileroVad); + +impl SileroVad { + pub fn new>(model_path: P, threshold: f32) -> Result { + let inner = transcribe_rs::vad::SileroVad::new(model_path, threshold) + .map_err(|e| anyhow::anyhow!("Failed to create SileroVad: {e}"))?; + Ok(Self(inner)) + } +} + +// transcribe-rs Vad is Send but not Sync. SileroVad is only accessed +// behind &mut (single-threaded in the recorder worker), so Sync is safe. +unsafe impl Sync for SileroVad {} + +impl VoiceActivityDetector for SileroVad { + fn push_frame<'a>(&'a mut self, frame: &'a [f32]) -> Result> { + let speech = self + .0 + .is_speech(frame) + .map_err(|e| anyhow::anyhow!("{e}"))?; + if speech { + Ok(VadFrame::Speech(frame)) + } else { + Ok(VadFrame::Noise) + } + } + + fn reset(&mut self) { + self.0.reset(); + } +} + +/// Wrapper around `transcribe_rs::vad::SmoothedVad` that implements +/// Handy's `VoiceActivityDetector` trait with prefill-aware output. +pub struct SmoothedVad { + inner: transcribe_rs::vad::SmoothedVad, + temp_out: Vec, +} + +// Same reasoning as SileroVad — only accessed behind &mut in recorder worker. +unsafe impl Sync for SmoothedVad {} + +impl SmoothedVad { + pub fn new( + inner_vad: Box, + prefill_frames: usize, + hangover_frames: usize, + onset_frames: usize, + ) -> Self { + let adapted = VadAdapter(inner_vad); + let inner = transcribe_rs::vad::SmoothedVad::new( + Box::new(adapted), + prefill_frames, + hangover_frames, + onset_frames, + ); + Self { + inner, + temp_out: Vec::new(), + } + } +} + +impl VoiceActivityDetector for SmoothedVad { + fn push_frame<'a>(&'a mut self, frame: &'a [f32]) -> Result> { + let speech = self + .inner + .is_speech(frame) + .map_err(|e| anyhow::anyhow!("{e}"))?; + + if speech { + let prefill = self.inner.drain_prefill(); + if !prefill.is_empty() { + self.temp_out.clear(); + self.temp_out.extend_from_slice(&prefill); + self.temp_out.extend_from_slice(frame); + Ok(VadFrame::Speech(&self.temp_out)) + } else { + Ok(VadFrame::Speech(frame)) + } + } else { + Ok(VadFrame::Noise) + } + } + + fn reset(&mut self) { + self.inner.reset(); + self.temp_out.clear(); + } +} + +/// Adapts a Handy `VoiceActivityDetector` to the transcribe-rs `Vad` trait. +struct VadAdapter(Box); + +unsafe impl Send for VadAdapter {} + +impl Vad for VadAdapter { + fn frame_size(&self) -> usize { + 480 // 30ms at 16kHz + } + + fn is_speech( + &mut self, + frame: &[f32], + ) -> std::result::Result { + self.0 + .is_voice(frame) + .map_err(|e| transcribe_rs::TranscribeError::Inference(e.to_string())) + } + + fn reset(&mut self) { + self.0.reset(); + } +} diff --git a/src-tauri/src/audio_toolkit/vad/adapter_mock.rs b/src-tauri/src/audio_toolkit/vad/adapter_mock.rs new file mode 100644 index 000000000..5458e3422 --- /dev/null +++ b/src-tauri/src/audio_toolkit/vad/adapter_mock.rs @@ -0,0 +1,46 @@ +// CI-only mock VAD adapter - avoids transcribe-rs dependency. +// This file is copied over adapter.rs during CI tests. + +use anyhow::Result; +use std::path::Path; + +use super::{VadFrame, VoiceActivityDetector}; + +/// Mock SileroVad that always reports noise. +pub struct SileroVad; + +impl SileroVad { + pub fn new>(_model_path: P, _threshold: f32) -> Result { + Ok(Self) + } +} + +impl VoiceActivityDetector for SileroVad { + fn push_frame<'a>(&'a mut self, _frame: &'a [f32]) -> Result> { + Ok(VadFrame::Noise) + } + + fn reset(&mut self) {} +} + +/// Mock SmoothedVad that always reports noise. +pub struct SmoothedVad; + +impl SmoothedVad { + pub fn new( + _inner_vad: Box, + _prefill_frames: usize, + _hangover_frames: usize, + _onset_frames: usize, + ) -> Self { + Self + } +} + +impl VoiceActivityDetector for SmoothedVad { + fn push_frame<'a>(&'a mut self, _frame: &'a [f32]) -> Result> { + Ok(VadFrame::Noise) + } + + fn reset(&mut self) {} +} diff --git a/src-tauri/src/audio_toolkit/vad/mod.rs b/src-tauri/src/audio_toolkit/vad/mod.rs index 7283e0dc1..fa043366f 100644 --- a/src-tauri/src/audio_toolkit/vad/mod.rs +++ b/src-tauri/src/audio_toolkit/vad/mod.rs @@ -25,8 +25,6 @@ pub trait VoiceActivityDetector: Send + Sync { fn reset(&mut self) {} } -mod silero; -mod smoothed; +mod adapter; -pub use silero::SileroVad; -pub use smoothed::SmoothedVad; +pub use adapter::{SileroVad, SmoothedVad}; diff --git a/src-tauri/src/audio_toolkit/vad/silero.rs b/src-tauri/src/audio_toolkit/vad/silero.rs deleted file mode 100644 index 2deec7ecc..000000000 --- a/src-tauri/src/audio_toolkit/vad/silero.rs +++ /dev/null @@ -1,52 +0,0 @@ -use anyhow::Result; -use std::path::Path; - -use vad_rs::Vad; - -use super::{VadFrame, VoiceActivityDetector}; -use crate::audio_toolkit::constants; - -const SILERO_FRAME_MS: u32 = 30; -const SILERO_FRAME_SAMPLES: usize = - (constants::WHISPER_SAMPLE_RATE * SILERO_FRAME_MS / 1000) as usize; - -pub struct SileroVad { - engine: Vad, - threshold: f32, -} - -impl SileroVad { - pub fn new>(model_path: P, threshold: f32) -> Result { - if !(0.0..=1.0).contains(&threshold) { - anyhow::bail!("threshold must be between 0.0 and 1.0"); - } - - Ok(Self { - engine: Vad::new(&model_path, constants::WHISPER_SAMPLE_RATE as usize) - .map_err(|e| anyhow::anyhow!("Failed to create VAD: {e}"))?, - threshold, - }) - } -} - -impl VoiceActivityDetector for SileroVad { - fn push_frame<'a>(&'a mut self, frame: &'a [f32]) -> Result> { - if frame.len() != SILERO_FRAME_SAMPLES { - anyhow::bail!( - "expected {SILERO_FRAME_SAMPLES} samples, got {}", - frame.len() - ); - } - - let result = self - .engine - .compute(frame) - .map_err(|e| anyhow::anyhow!("Silero VAD error: {e}"))?; - - if result.prob > self.threshold { - Ok(VadFrame::Speech(frame)) - } else { - Ok(VadFrame::Noise) - } - } -} diff --git a/src-tauri/src/audio_toolkit/vad/smoothed.rs b/src-tauri/src/audio_toolkit/vad/smoothed.rs deleted file mode 100644 index c0e461627..000000000 --- a/src-tauri/src/audio_toolkit/vad/smoothed.rs +++ /dev/null @@ -1,105 +0,0 @@ -use super::{VadFrame, VoiceActivityDetector}; -use anyhow::Result; -use std::collections::VecDeque; - -pub struct SmoothedVad { - inner_vad: Box, - prefill_frames: usize, - hangover_frames: usize, - onset_frames: usize, - - frame_buffer: VecDeque>, - hangover_counter: usize, - onset_counter: usize, - in_speech: bool, - - temp_out: Vec, -} - -impl SmoothedVad { - pub fn new( - inner_vad: Box, - prefill_frames: usize, - hangover_frames: usize, - onset_frames: usize, - ) -> Self { - Self { - inner_vad, - prefill_frames, - hangover_frames, - onset_frames, - frame_buffer: VecDeque::new(), - hangover_counter: 0, - onset_counter: 0, - in_speech: false, - temp_out: Vec::new(), - } - } -} - -impl VoiceActivityDetector for SmoothedVad { - fn push_frame<'a>(&'a mut self, frame: &'a [f32]) -> Result> { - // 1. Buffer every incoming frame for possible pre-roll - self.frame_buffer.push_back(frame.to_vec()); - while self.frame_buffer.len() > self.prefill_frames + 1 { - self.frame_buffer.pop_front(); - } - - // 2. Delegate to the wrapped boolean VAD - let is_voice = self.inner_vad.is_voice(frame)?; - - match (self.in_speech, is_voice) { - // Potential start of speech - need to accumulate onset frames - (false, true) => { - self.onset_counter += 1; - if self.onset_counter >= self.onset_frames { - // We have enough consecutive voice frames to trigger speech - self.in_speech = true; - self.hangover_counter = self.hangover_frames; - self.onset_counter = 0; // Reset for next time - - // Collect prefill + current frame - self.temp_out.clear(); - for buf in &self.frame_buffer { - self.temp_out.extend(buf); - } - Ok(VadFrame::Speech(&self.temp_out)) - } else { - // Not enough frames yet, still silence - Ok(VadFrame::Noise) - } - } - - // Ongoing Speech - (true, true) => { - self.hangover_counter = self.hangover_frames; - Ok(VadFrame::Speech(frame)) - } - - // End of Speech or interruption during onset phase - (true, false) => { - if self.hangover_counter > 0 { - self.hangover_counter -= 1; - Ok(VadFrame::Speech(frame)) - } else { - self.in_speech = false; - Ok(VadFrame::Noise) - } - } - - // Silence or broken onset sequence - (false, false) => { - self.onset_counter = 0; // Reset onset counter on silence - Ok(VadFrame::Noise) - } - } - } - - fn reset(&mut self) { - self.frame_buffer.clear(); - self.hangover_counter = 0; - self.onset_counter = 0; - self.in_speech = false; - self.temp_out.clear(); - } -} diff --git a/src-tauri/src/commands/history.rs b/src-tauri/src/commands/history.rs index 3792f6f1d..f03c41637 100644 --- a/src-tauri/src/commands/history.rs +++ b/src-tauri/src/commands/history.rs @@ -1,9 +1,12 @@ use crate::actions::process_transcription_output; use crate::managers::{ history::{HistoryManager, PaginatedHistory}, + streaming::transcribe_chunked, transcription::TranscriptionManager, }; +use log::debug; use std::sync::Arc; +use tauri::Manager; use tauri::{AppHandle, State}; #[tauri::command] @@ -84,10 +87,37 @@ pub async fn retry_history_entry_transcription( transcription_manager.initiate_model_load(); let tm = Arc::clone(&transcription_manager); - let transcription = tauri::async_runtime::spawn_blocking(move || tm.transcribe(samples)) + let duration_secs = samples.len() as f32 / 16000.0; + + let transcription = if duration_secs <= 30.0 { + // Short audio: single-shot transcription + tauri::async_runtime::spawn_blocking(move || tm.transcribe(samples)) + .await + .map_err(|e| format!("Transcription task panicked: {}", e))? + .map_err(|e| e.to_string())? + } else { + // Long audio: chunked batch transcription for better performance + debug!( + "Retry using chunked transcription for {:.1}s audio", + duration_secs + ); + let vad_model_path = app + .path() + .resolve( + "resources/models/silero_vad_v4.onnx", + tauri::path::BaseDirectory::Resource, + ) + .map_err(|e| format!("Failed to resolve VAD model path: {}", e))? + .to_string_lossy() + .to_string(); + + tauri::async_runtime::spawn_blocking(move || { + transcribe_chunked(&tm, &samples, &vad_model_path) + }) .await .map_err(|e| format!("Transcription task panicked: {}", e))? - .map_err(|e| e.to_string())?; + .map_err(|e| e.to_string())? + }; if transcription.is_empty() { return Err("Recording contains no speech".to_string()); diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 74472b7d6..2cc3f1c35 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -347,6 +347,8 @@ pub fn run(cli_args: CliArgs) { shortcut::change_auto_submit_key_setting, shortcut::change_post_process_enabled_setting, shortcut::change_experimental_enabled_setting, + shortcut::change_transcription_mode_setting, + shortcut::change_realtime_chunk_duration_setting, shortcut::change_post_process_base_url_setting, shortcut::change_post_process_api_key_setting, shortcut::change_post_process_model_setting, diff --git a/src-tauri/src/managers/audio.rs b/src-tauri/src/managers/audio.rs index b3378b720..5097fc0ce 100644 --- a/src-tauri/src/managers/audio.rs +++ b/src-tauri/src/managers/audio.rs @@ -4,7 +4,7 @@ use crate::settings::{get_settings, AppSettings}; use crate::utils; use log::{debug, error, info}; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{mpsc, Arc, Mutex}; use std::time::{Duration, Instant}; use tauri::Manager; @@ -403,6 +403,42 @@ impl AudioRecordingManager { } } + pub fn try_start_streaming_recording( + &self, + binding_id: &str, + ) -> Result>, String> { + let mut state = self.state.lock().unwrap(); + + if let RecordingState::Idle = *state { + // Ensure microphone is open in on-demand mode + if matches!(*self.mode.lock().unwrap(), MicrophoneMode::OnDemand) { + self.close_generation.fetch_add(1, Ordering::SeqCst); + if let Err(e) = self.start_microphone_stream() { + let msg = format!("{e}"); + error!("Failed to open microphone stream: {msg}"); + return Err(msg); + } + } + + if let Some(rec) = self.recorder.lock().unwrap().as_ref() { + match rec.start_streaming() { + Ok(rx) => { + *self.is_recording.lock().unwrap() = true; + *state = RecordingState::Recording { + binding_id: binding_id.to_string(), + }; + debug!("Streaming recording started for binding {binding_id}"); + return Ok(rx); + } + Err(e) => return Err(format!("Failed to start streaming: {e}")), + } + } + Err("Recorder not available".to_string()) + } else { + Err("Already recording".to_string()) + } + } + pub fn update_selected_device(&self) -> Result<(), anyhow::Error> { // If currently open, restart the microphone stream to use the new device if *self.is_open.lock().unwrap() { diff --git a/src-tauri/src/managers/mod.rs b/src-tauri/src/managers/mod.rs index 1239dc26b..f10f0ab15 100644 --- a/src-tauri/src/managers/mod.rs +++ b/src-tauri/src/managers/mod.rs @@ -1,4 +1,5 @@ pub mod audio; pub mod history; pub mod model; +pub mod streaming; pub mod transcription; diff --git a/src-tauri/src/managers/streaming.rs b/src-tauri/src/managers/streaming.rs new file mode 100644 index 000000000..98590ff3e --- /dev/null +++ b/src-tauri/src/managers/streaming.rs @@ -0,0 +1,264 @@ +use std::sync::{mpsc, Arc}; +use std::thread; + +use log::{debug, error}; +use tauri::AppHandle; +use transcribe_rs::transcriber::{ + EnergyAdaptiveChunked, EnergyAdaptiveConfig, Transcriber, VadChunked, VadChunkedConfig, +}; +use transcribe_rs::vad::SmoothedVad; +use transcribe_rs::TranscribeOptions; + +use crate::managers::transcription::TranscriptionManager; +use crate::settings::TranscriptionMode; +use crate::utils; + +pub struct StreamingFinishResult { + pub audio: Vec, + pub combined_text: String, +} + +pub struct StreamingSession { + #[allow(dead_code)] + mode: TranscriptionMode, + tm: Arc, + #[allow(dead_code)] + app: AppHandle, + audio_buf: Arc>>, + text_buf: Arc>>, + worker_handle: Option>, + active: bool, +} + +impl StreamingSession { + pub fn start( + mode: TranscriptionMode, + tm: Arc, + app: AppHandle, + chunk_rx: mpsc::Receiver>, + vad_model_path: String, + realtime_chunk_duration_secs: f32, + ) -> Self { + let audio_buf = Arc::new(std::sync::Mutex::new(Vec::new())); + let text_buf = Arc::new(std::sync::Mutex::new(Vec::new())); + + let audio_buf_worker = Arc::clone(&audio_buf); + let text_buf_worker = Arc::clone(&text_buf); + let tm_worker = Arc::clone(&tm); + let app_worker = app.clone(); + let pastes_live = mode == TranscriptionMode::Stream || mode == TranscriptionMode::Realtime; + + tm.begin_streaming(); + + let worker_handle = thread::spawn(move || { + let transcriber_result = + create_transcriber(mode, &vad_model_path, realtime_chunk_duration_secs); + + let mut transcriber = match transcriber_result { + Ok(t) => t, + Err(e) => { + error!("Failed to create transcriber: {}", e); + return; + } + }; + + // Track text already pasted by feed() so we can extract the + // remainder from finish() for live-paste modes. + let mut feed_texts: Vec = Vec::new(); + + for chunk in chunk_rx.iter() { + // Accumulate raw audio for history + audio_buf_worker.lock().unwrap().extend_from_slice(&chunk); + + // Feed to the transcriber via engine access + let feed_result = tm_worker.with_engine(|model| { + let results = transcriber + .feed(model, &chunk) + .map_err(|e| anyhow::anyhow!("Transcriber feed error: {}", e))?; + Ok(results) + }); + + match feed_result { + Ok(results) => { + for result in results { + let text = result.text.trim().to_string(); + if !text.is_empty() { + debug!("Streaming chunk transcribed: '{}'", text); + feed_texts.push(text.clone()); + + if pastes_live { + // Strip trailing punctuation and add a space so + // the next chunk flows naturally after this one. + let paste_text = + format!("{} ", strip_trailing_punctuation(&text)); + paste_on_main_thread(&app_worker, paste_text); + } + } + } + } + Err(e) => error!("Streaming transcription error: {}", e), + } + } + + // Channel closed — finish() transcribes the remainder and returns + // ALL chunks merged (feed results + remainder). + let finish_result = tm_worker.with_engine(|model| { + let result = transcriber + .finish(model) + .map_err(|e| anyhow::anyhow!("Transcriber finish error: {}", e))?; + Ok(result) + }); + + match finish_result { + Ok(result) => { + let full_text = result.text.trim().to_string(); + if !full_text.is_empty() { + debug!("Streaming session complete: '{}'", full_text); + // Use finish()'s merged result as the authoritative combined text + *text_buf_worker.lock().unwrap() = vec![full_text.clone()]; + + // For live-paste modes: extract the remainder that wasn't + // already pasted by feed() and paste it now + if pastes_live { + let already_pasted = feed_texts.join(" "); + let remainder = if full_text.len() > already_pasted.len() { + full_text[already_pasted.len()..].trim() + } else { + "" + }; + if !remainder.is_empty() { + debug!("Pasting remainder: '{}'", remainder); + paste_on_main_thread(&app_worker, remainder.to_string()); + } + } + } + } + Err(e) => error!("Streaming finish error: {}", e), + } + + debug!("Streaming consumer thread exiting"); + }); + + Self { + mode, + tm, + app, + audio_buf, + text_buf, + worker_handle: Some(worker_handle), + active: true, + } + } + + /// Wait for the worker thread to complete and return the combined result. + pub fn finish(mut self) -> StreamingFinishResult { + // Wait for the worker to finish (it exits when chunk_rx is dropped) + if let Some(handle) = self.worker_handle.take() { + let _ = handle.join(); + } + + let audio = std::mem::take(&mut *self.audio_buf.lock().unwrap()); + let combined_text = self.text_buf.lock().unwrap().join(" "); + + self.active = false; + self.tm.end_streaming(); + + StreamingFinishResult { + audio, + combined_text, + } + } +} + +impl Drop for StreamingSession { + fn drop(&mut self) { + if self.active { + self.tm.end_streaming(); + } + } +} + +fn create_transcriber( + mode: TranscriptionMode, + vad_model_path: &str, + realtime_chunk_duration_secs: f32, +) -> Result, anyhow::Error> { + let options = TranscribeOptions::default(); + + match mode { + TranscriptionMode::Realtime => { + let config = EnergyAdaptiveConfig { + target_chunk_secs: realtime_chunk_duration_secs, + search_window_secs: 1.0, + padding_secs: 0.0, + min_chunk_secs: 1.0, + frame_size: 480, + merge_separator: " ".into(), + }; + Ok(Box::new(EnergyAdaptiveChunked::new(config, options))) + } + TranscriptionMode::Stream | TranscriptionMode::BatchStream => { + let silero = transcribe_rs::vad::SileroVad::new(vad_model_path, 0.3) + .map_err(|e| anyhow::anyhow!("Failed to create SileroVad: {}", e))?; + let vad = SmoothedVad::new(Box::new(silero), 15, 15, 2); + let config = VadChunkedConfig { + min_chunk_secs: 5.0, + max_chunk_secs: 30.0, + padding_secs: 0.0, + smart_split_search_secs: Some(3.0), + merge_separator: " ".into(), + }; + Ok(Box::new(VadChunked::new(Box::new(vad), config, options))) + } + TranscriptionMode::Standard => { + unreachable!("Standard mode should not create a streaming session") + } + } +} + +/// Strip trailing sentence-ending punctuation from a chunk so it flows +/// naturally into the next chunk when pasted live. +fn strip_trailing_punctuation(text: &str) -> &str { + text.trim_end_matches(|c: char| matches!(c, '.' | ',' | ';' | '!' | '?' | '…')) + .trim_end() +} + +/// Batch-transcribe long audio using VAD-based chunking. +/// +/// This is used by the history retry command when audio exceeds 30 s. +pub fn transcribe_chunked( + tm: &TranscriptionManager, + samples: &[f32], + vad_model_path: &str, +) -> Result { + let silero = transcribe_rs::vad::SileroVad::new(vad_model_path, 0.3) + .map_err(|e| anyhow::anyhow!("Failed to create SileroVad: {}", e))?; + let vad = SmoothedVad::new(Box::new(silero), 15, 15, 2); + let config = VadChunkedConfig { + min_chunk_secs: 10.0, + max_chunk_secs: 30.0, + padding_secs: 0.0, + smart_split_search_secs: Some(3.0), + merge_separator: " ".into(), + }; + let mut transcriber = VadChunked::new(Box::new(vad), config, TranscribeOptions::default()); + + tm.with_engine(|model| { + transcriber + .feed(model, samples) + .map_err(|e| anyhow::anyhow!("Transcriber feed error: {}", e))?; + let result = transcriber + .finish(model) + .map_err(|e| anyhow::anyhow!("Transcriber finish error: {}", e))?; + Ok(result.text.trim().to_string()) + }) +} + +fn paste_on_main_thread(app: &AppHandle, text: String) { + let app_clone = app.clone(); + let _ = app.run_on_main_thread(move || { + if let Err(e) = utils::paste(text, app_clone.clone()) { + error!("Failed to paste streamed chunk: {}", e); + } + }); +} diff --git a/src-tauri/src/managers/streaming_mock.rs b/src-tauri/src/managers/streaming_mock.rs new file mode 100644 index 000000000..2f6ac0466 --- /dev/null +++ b/src-tauri/src/managers/streaming_mock.rs @@ -0,0 +1,45 @@ +// CI-only mock StreamingSession - avoids transcribe-rs dependency. +// This file is copied over streaming.rs during CI tests. + +use std::sync::mpsc; + +use crate::managers::transcription::TranscriptionManager; +use crate::settings::TranscriptionMode; +use std::sync::Arc; +use tauri::AppHandle; + +pub struct StreamingFinishResult { + pub audio: Vec, + pub combined_text: String, +} + +pub struct StreamingSession; + +impl StreamingSession { + pub fn start( + _mode: TranscriptionMode, + _tm: Arc, + _app: AppHandle, + _chunk_rx: mpsc::Receiver>, + _vad_model_path: String, + _realtime_chunk_duration_secs: f32, + ) -> Self { + Self + } + + pub fn finish(self) -> StreamingFinishResult { + StreamingFinishResult { + audio: Vec::new(), + combined_text: String::new(), + } + } +} + +/// No-op in CI mock. +pub fn transcribe_chunked( + _tm: &TranscriptionManager, + _samples: &[f32], + _vad_model_path: &str, +) -> Result { + Ok(String::new()) +} diff --git a/src-tauri/src/managers/transcription.rs b/src-tauri/src/managers/transcription.rs index 32dd42c76..7a07a0908 100644 --- a/src-tauri/src/managers/transcription.rs +++ b/src-tauri/src/managers/transcription.rs @@ -9,7 +9,7 @@ use log::{debug, error, info, warn}; use serde::Serialize; use specta::Type; use std::panic::{catch_unwind, AssertUnwindSafe}; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}; use std::sync::{Arc, Condvar, Mutex, MutexGuard, OnceLock}; use std::thread; use std::time::{Duration, SystemTime}; @@ -71,6 +71,7 @@ pub struct TranscriptionManager { watcher_handle: Arc>>>, is_loading: Arc>, loading_condvar: Arc, + streaming_sessions: Arc, } impl TranscriptionManager { @@ -85,6 +86,7 @@ impl TranscriptionManager { watcher_handle: Arc::new(Mutex::new(None)), is_loading: Arc::new(Mutex::new(false)), loading_condvar: Arc::new(Condvar::new()), + streaming_sessions: Arc::new(AtomicUsize::new(0)), }; // Start the idle watcher @@ -112,6 +114,12 @@ impl TranscriptionManager { continue; } + // Skip unloading while a streaming session is active + if manager_cloned.streaming_sessions.load(Ordering::Relaxed) > 0 { + manager_cloned.touch_activity(); + continue; + } + // While recording, keep the idle timer fresh so the // model is never unloaded mid-session. let is_recording = app_handle_cloned @@ -235,8 +243,13 @@ impl TranscriptionManager { self.last_activity.store(Self::now_ms(), Ordering::Relaxed); } - /// Unloads the model immediately if the setting is enabled and the model is loaded + /// Unloads the model immediately if the setting is enabled and the model is loaded. + /// Skips unloading when a streaming session is active. pub fn maybe_unload_immediately(&self, context: &str) { + if self.streaming_sessions.load(Ordering::Relaxed) > 0 { + debug!("Skipping immediate unload during active streaming session"); + return; + } let settings = get_settings(&self.app_handle); if settings.model_unload_timeout == ModelUnloadTimeout::Immediately && self.is_model_loaded() @@ -248,6 +261,106 @@ impl TranscriptionManager { } } + /// Mark a streaming session as active. While active, the model will not be + /// auto-unloaded between transcription calls. + pub fn begin_streaming(&self) { + self.streaming_sessions.fetch_add(1, Ordering::Relaxed); + debug!( + "Streaming session started (active: {})", + self.streaming_sessions.load(Ordering::Relaxed) + ); + } + + /// Mark a streaming session as finished. If no sessions remain, normal + /// unload behaviour resumes. + pub fn end_streaming(&self) { + let prev = self.streaming_sessions.fetch_sub(1, Ordering::Relaxed); + debug!( + "Streaming session ended (active: {})", + prev.saturating_sub(1) + ); + } + + /// Execute a closure with exclusive access to the loaded speech model. + /// + /// The engine is taken out of the mutex before calling `f` and put back + /// after, so the mutex is NOT held during the (potentially slow) model call. + /// Panics in `f` are caught and the engine is discarded (unloaded). + pub fn with_engine(&self, f: F) -> Result + where + F: FnOnce(&mut dyn SpeechModel) -> Result, + { + // Wait for any in-progress model loading to complete + { + let mut is_loading = self.is_loading.lock().unwrap(); + while *is_loading { + is_loading = self.loading_condvar.wait(is_loading).unwrap(); + } + } + + let mut engine_guard = self.lock_engine(); + let mut engine = engine_guard + .take() + .ok_or_else(|| anyhow::anyhow!("No model loaded. Please check your model settings."))?; + drop(engine_guard); + + let mut f = Some(f); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + let f = f.take().unwrap(); + let model_ref: &mut dyn SpeechModel = match &mut engine { + LoadedEngine::Whisper(e) => e, + LoadedEngine::Parakeet(e) => e, + LoadedEngine::Moonshine(e) => e, + LoadedEngine::MoonshineStreaming(e) => e, + LoadedEngine::SenseVoice(e) => e, + LoadedEngine::GigaAM(e) => e, + LoadedEngine::Canary(e) => e, + }; + f(model_ref) + })); + + match result { + Ok(inner) => { + let mut engine_guard = self.lock_engine(); + *engine_guard = Some(engine); + self.touch_activity(); + inner + } + Err(panic_payload) => { + let panic_msg = if let Some(s) = panic_payload.downcast_ref::<&str>() { + s.to_string() + } else if let Some(s) = panic_payload.downcast_ref::() { + s.clone() + } else { + "unknown panic".to_string() + }; + error!("Engine panicked in with_engine: {}", panic_msg); + + { + let mut current_model = self + .current_model_id + .lock() + .unwrap_or_else(|e| e.into_inner()); + *current_model = None; + } + let _ = self.app_handle.emit( + "model-state-changed", + ModelStateEvent { + event_type: "unloaded".to_string(), + model_id: None, + model_name: None, + error: Some(format!("Engine panicked: {}", panic_msg)), + }, + ); + + Err(anyhow::anyhow!( + "Engine panicked: {}. Model has been unloaded.", + panic_msg + )) + } + } + } + pub fn load_model(&self, model_id: &str) -> Result<()> { let load_start = std::time::Instant::now(); debug!("Starting to load model: {}", model_id); diff --git a/src-tauri/src/settings.rs b/src-tauri/src/settings.rs index 6bf657b4f..c5d5166af 100644 --- a/src-tauri/src/settings.rs +++ b/src-tauri/src/settings.rs @@ -304,6 +304,21 @@ impl Default for OrtAcceleratorSetting { } } +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Type)] +#[serde(rename_all = "snake_case")] +pub enum TranscriptionMode { + Standard, + Realtime, + Stream, + BatchStream, +} + +impl Default for TranscriptionMode { + fn default() -> Self { + TranscriptionMode::Standard + } +} + /* still handy for composing the initial JSON in the store ------------- */ #[derive(Serialize, Deserialize, Debug, Clone, Type)] pub struct AppSettings { @@ -401,6 +416,10 @@ pub struct AppSettings { pub whisper_gpu_device: i32, #[serde(default)] pub extra_recording_buffer_ms: u64, + #[serde(default)] + pub transcription_mode: TranscriptionMode, + #[serde(default = "default_realtime_chunk_duration_secs")] + pub realtime_chunk_duration_secs: f32, } fn default_model() -> String { @@ -611,6 +630,10 @@ fn default_whisper_gpu_device() -> i32 { -1 // auto } +fn default_realtime_chunk_duration_secs() -> f32 { + 3.0 +} + fn default_typing_tool() -> TypingTool { TypingTool::Auto } @@ -775,6 +798,8 @@ pub fn get_default_settings() -> AppSettings { ort_accelerator: OrtAcceleratorSetting::default(), whisper_gpu_device: default_whisper_gpu_device(), extra_recording_buffer_ms: 0, + transcription_mode: TranscriptionMode::default(), + realtime_chunk_duration_secs: default_realtime_chunk_duration_secs(), } } diff --git a/src-tauri/src/shortcut/mod.rs b/src-tauri/src/shortcut/mod.rs index 6d179f175..3e05f8fff 100644 --- a/src-tauri/src/shortcut/mod.rs +++ b/src-tauri/src/shortcut/mod.rs @@ -824,6 +824,27 @@ pub fn change_experimental_enabled_setting(app: AppHandle, enabled: bool) -> Res Ok(()) } +#[tauri::command] +#[specta::specta] +pub fn change_transcription_mode_setting( + app: AppHandle, + mode: settings::TranscriptionMode, +) -> Result<(), String> { + let mut settings = settings::get_settings(&app); + settings.transcription_mode = mode; + settings::write_settings(&app, settings); + Ok(()) +} + +#[tauri::command] +#[specta::specta] +pub fn change_realtime_chunk_duration_setting(app: AppHandle, duration: f32) -> Result<(), String> { + let mut settings = settings::get_settings(&app); + settings.realtime_chunk_duration_secs = duration; + settings::write_settings(&app, settings); + Ok(()) +} + #[tauri::command] #[specta::specta] pub fn change_post_process_base_url_setting( diff --git a/src/bindings.ts b/src/bindings.ts index 14d98b380..0b24da0f8 100644 --- a/src/bindings.ts +++ b/src/bindings.ts @@ -192,6 +192,22 @@ async changeExperimentalEnabledSetting(enabled: boolean) : Promise> { + try { + return { status: "ok", data: await TAURI_INVOKE("change_transcription_mode_setting", { mode }) }; +} catch (e) { + if(e instanceof Error) throw e; + else return { status: "error", error: e as any }; +} +}, +async changeRealtimeChunkDurationSetting(duration: number) : Promise> { + try { + return { status: "ok", data: await TAURI_INVOKE("change_realtime_chunk_duration_setting", { duration }) }; +} catch (e) { + if(e instanceof Error) throw e; + else return { status: "error", error: e as any }; +} +}, async changePostProcessBaseUrlSetting(providerId: string, baseUrl: string) : Promise> { try { return { status: "ok", data: await TAURI_INVOKE("change_post_process_base_url_setting", { providerId, baseUrl }) }; @@ -827,7 +843,7 @@ historyUpdatePayload: "history-update-payload" /** user-defined types **/ -export type AppSettings = { bindings: Partial<{ [key in string]: ShortcutBinding }>; push_to_talk: boolean; audio_feedback: boolean; audio_feedback_volume?: number; sound_theme?: SoundTheme; start_hidden?: boolean; autostart_enabled?: boolean; update_checks_enabled?: boolean; selected_model?: string; always_on_microphone?: boolean; selected_microphone?: string | null; clamshell_microphone?: string | null; selected_output_device?: string | null; translate_to_english?: boolean; selected_language?: string; overlay_position?: OverlayPosition; debug_mode?: boolean; log_level?: LogLevel; custom_words?: string[]; model_unload_timeout?: ModelUnloadTimeout; word_correction_threshold?: number; history_limit?: number; recording_retention_period?: RecordingRetentionPeriod; paste_method?: PasteMethod; clipboard_handling?: ClipboardHandling; auto_submit?: boolean; auto_submit_key?: AutoSubmitKey; post_process_enabled?: boolean; post_process_provider_id?: string; post_process_providers?: PostProcessProvider[]; post_process_api_keys?: Partial<{ [key in string]: string }>; post_process_models?: Partial<{ [key in string]: string }>; post_process_prompts?: LLMPrompt[]; post_process_selected_prompt_id?: string | null; mute_while_recording?: boolean; append_trailing_space?: boolean; app_language?: string; experimental_enabled?: boolean; lazy_stream_close?: boolean; keyboard_implementation?: KeyboardImplementation; show_tray_icon?: boolean; paste_delay_ms?: number; typing_tool?: TypingTool; external_script_path: string | null; custom_filler_words?: string[] | null; whisper_accelerator?: WhisperAcceleratorSetting; ort_accelerator?: OrtAcceleratorSetting; whisper_gpu_device?: number; extra_recording_buffer_ms?: number } +export type AppSettings = { bindings: Partial<{ [key in string]: ShortcutBinding }>; push_to_talk: boolean; audio_feedback: boolean; audio_feedback_volume?: number; sound_theme?: SoundTheme; start_hidden?: boolean; autostart_enabled?: boolean; update_checks_enabled?: boolean; selected_model?: string; always_on_microphone?: boolean; selected_microphone?: string | null; clamshell_microphone?: string | null; selected_output_device?: string | null; translate_to_english?: boolean; selected_language?: string; overlay_position?: OverlayPosition; debug_mode?: boolean; log_level?: LogLevel; custom_words?: string[]; model_unload_timeout?: ModelUnloadTimeout; word_correction_threshold?: number; history_limit?: number; recording_retention_period?: RecordingRetentionPeriod; paste_method?: PasteMethod; clipboard_handling?: ClipboardHandling; auto_submit?: boolean; auto_submit_key?: AutoSubmitKey; post_process_enabled?: boolean; post_process_provider_id?: string; post_process_providers?: PostProcessProvider[]; post_process_api_keys?: Partial<{ [key in string]: string }>; post_process_models?: Partial<{ [key in string]: string }>; post_process_prompts?: LLMPrompt[]; post_process_selected_prompt_id?: string | null; mute_while_recording?: boolean; append_trailing_space?: boolean; app_language?: string; experimental_enabled?: boolean; lazy_stream_close?: boolean; keyboard_implementation?: KeyboardImplementation; show_tray_icon?: boolean; paste_delay_ms?: number; typing_tool?: TypingTool; external_script_path: string | null; custom_filler_words?: string[] | null; whisper_accelerator?: WhisperAcceleratorSetting; ort_accelerator?: OrtAcceleratorSetting; whisper_gpu_device?: number; extra_recording_buffer_ms?: number; transcription_mode?: TranscriptionMode; realtime_chunk_duration_secs?: number } export type AudioDevice = { index: string; name: string; is_default: boolean } export type AutoSubmitKey = "enter" | "ctrl_enter" | "cmd_enter" export type AvailableAccelerators = { whisper: string[]; ort: string[]; gpu_devices: GpuDeviceOption[] } @@ -861,6 +877,7 @@ export type PostProcessProvider = { id: string; label: string; base_url: string; export type RecordingRetentionPeriod = "never" | "preserve_limit" | "days_3" | "weeks_2" | "months_3" export type ShortcutBinding = { id: string; name: string; description: string; default_binding: string; current_binding: string } export type SoundTheme = "marimba" | "pop" | "custom" +export type TranscriptionMode = "standard" | "realtime" | "stream" | "batch_stream" export type TypingTool = "auto" | "wtype" | "kwtype" | "dotool" | "ydotool" | "xdotool" export type WhisperAcceleratorSetting = "auto" | "cpu" | "gpu" export type WindowsMicrophonePermissionStatus = { supported: boolean; overall_access: PermissionAccess; device_access: PermissionAccess; app_access: PermissionAccess; desktop_app_access: PermissionAccess } diff --git a/src/components/settings/RealtimeChunkDuration.tsx b/src/components/settings/RealtimeChunkDuration.tsx new file mode 100644 index 000000000..59510b935 --- /dev/null +++ b/src/components/settings/RealtimeChunkDuration.tsx @@ -0,0 +1,35 @@ +import React from "react"; +import { useTranslation } from "react-i18next"; +import { Slider } from "../ui/Slider"; +import { useSettings } from "../../hooks/useSettings"; + +interface RealtimeChunkDurationProps { + descriptionMode?: "inline" | "tooltip"; + grouped?: boolean; +} + +export const RealtimeChunkDuration: React.FC = + React.memo(({ descriptionMode = "tooltip", grouped = false }) => { + const { t } = useTranslation(); + const { getSetting, updateSetting } = useSettings(); + const duration = getSetting("realtime_chunk_duration_secs") ?? 3.0; + + return ( + + updateSetting("realtime_chunk_duration_secs", value) + } + min={1} + max={10} + step={0.5} + label={t("settings.experimental.realtimeChunkDuration.title")} + description={t( + "settings.experimental.realtimeChunkDuration.description", + )} + descriptionMode={descriptionMode} + grouped={grouped} + formatValue={(value) => `${value}s`} + /> + ); + }); diff --git a/src/components/settings/TranscriptionModeSetting.tsx b/src/components/settings/TranscriptionModeSetting.tsx new file mode 100644 index 000000000..33f4651eb --- /dev/null +++ b/src/components/settings/TranscriptionModeSetting.tsx @@ -0,0 +1,58 @@ +import React from "react"; +import { useTranslation } from "react-i18next"; +import { Dropdown } from "../ui/Dropdown"; +import { SettingContainer } from "../ui/SettingContainer"; +import { useSettings } from "../../hooks/useSettings"; +import type { TranscriptionMode } from "@/bindings"; + +interface TranscriptionModeSettingProps { + descriptionMode?: "inline" | "tooltip"; + grouped?: boolean; +} + +export const TranscriptionModeSetting: React.FC = + React.memo(({ descriptionMode = "tooltip", grouped = false }) => { + const { t } = useTranslation(); + const { getSetting, updateSetting, isUpdating } = useSettings(); + + const selectedMode = (getSetting("transcription_mode") || + "standard") as TranscriptionMode; + + const options = [ + { + value: "standard", + label: t("settings.experimental.transcriptionMode.options.standard"), + }, + { + value: "realtime", + label: t("settings.experimental.transcriptionMode.options.realtime"), + }, + { + value: "stream", + label: t("settings.experimental.transcriptionMode.options.stream"), + }, + { + value: "batch_stream", + label: t("settings.experimental.transcriptionMode.options.batchStream"), + }, + ]; + + return ( + + + updateSetting("transcription_mode", value as TranscriptionMode) + } + disabled={isUpdating("transcription_mode")} + /> + + ); + }); diff --git a/src/components/settings/advanced/AdvancedSettings.tsx b/src/components/settings/advanced/AdvancedSettings.tsx index 733f97db6..14cacf269 100644 --- a/src/components/settings/advanced/AdvancedSettings.tsx +++ b/src/components/settings/advanced/AdvancedSettings.tsx @@ -20,11 +20,14 @@ import { useSettings } from "../../../hooks/useSettings"; import { KeyboardImplementationSelector } from "../debug/KeyboardImplementationSelector"; import { AccelerationSelector } from "../AccelerationSelector"; import { LazyStreamClose } from "../LazyStreamClose"; +import { TranscriptionModeSetting } from "../TranscriptionModeSetting"; +import { RealtimeChunkDuration } from "../RealtimeChunkDuration"; export const AdvancedSettings: React.FC = () => { const { t } = useTranslation(); const { getSetting } = useSettings(); const experimentalEnabled = getSetting("experimental_enabled") || false; + const transcriptionMode = getSetting("transcription_mode") || "standard"; return (
@@ -59,6 +62,10 @@ export const AdvancedSettings: React.FC = () => { {experimentalEnabled && ( + + {transcriptionMode === "realtime" && ( + + )} commands.changeAppLanguageSetting(value as string), experimental_enabled: (value) => commands.changeExperimentalEnabledSetting(value as boolean), + transcription_mode: (value) => + commands.changeTranscriptionModeSetting(value as any), + realtime_chunk_duration_secs: (value) => + commands.changeRealtimeChunkDurationSetting(value as number), lazy_stream_close: (value) => commands.changeLazyStreamCloseSetting(value as boolean), show_tray_icon: (value) =>