diff --git a/src/accel.rs b/src/accel.rs index 3d6b2a0..ffab123 100644 --- a/src/accel.rs +++ b/src/accel.rs @@ -6,7 +6,7 @@ use std::fmt; use std::str::FromStr; -use std::sync::atomic::{AtomicI32, AtomicU8, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU8, Ordering}; use serde::{Deserialize, Serialize}; @@ -72,6 +72,39 @@ pub fn get_ort_accelerator() -> OrtAccelerator { OrtAccelerator::from_u8(ORT_ACCELERATOR.load(Ordering::Relaxed)) } +static ORT_INTRA_THREADS: AtomicU8 = AtomicU8::new(0); + +/// Set the global ORT intra-op thread count. +/// +/// Applies to all ORT sessions (encoder and decoder) unless overridden by +/// an explicit thread count parameter. 0 means use the ORT default (all +/// available cores). Call before loading models; already-loaded sessions +/// are not affected. +pub fn set_ort_intra_threads(count: u8) { + ORT_INTRA_THREADS.store(count, Ordering::Relaxed); +} + +/// Get the current ORT intra-op thread count (0 = ORT default). +pub fn get_ort_intra_threads() -> u8 { + ORT_INTRA_THREADS.load(Ordering::Relaxed) +} + +static DECODER_GPU: AtomicBool = AtomicBool::new(false); + +/// Enable GPU execution providers for decoder sessions. +/// +/// By default, decoder sessions use CPU-only with arena allocator because +/// sequential execution makes GPU kernel launch overhead net-negative for +/// per-token latency at batch size 1. Set to `true` for GPU benchmarking. +pub fn set_decoder_gpu(enable: bool) { + DECODER_GPU.store(enable, Ordering::Relaxed); +} + +/// Get whether decoder sessions should use GPU execution providers. +pub fn get_decoder_gpu() -> bool { + DECODER_GPU.load(Ordering::Relaxed) +} + impl OrtAccelerator { /// Return the list of ORT accelerators that are compiled-in for the current build. 
/// @@ -314,6 +347,8 @@ mod tests { set_ort_accelerator(OrtAccelerator::Auto); set_whisper_accelerator(WhisperAccelerator::Auto); set_whisper_gpu_device(GPU_DEVICE_AUTO); + set_ort_intra_threads(0); + set_decoder_gpu(false); } } @@ -422,6 +457,22 @@ mod tests { assert_eq!(OrtAccelerator::from_u8(255), OrtAccelerator::Auto); } + // -- ORT thread count tests -- + + #[test] + fn ort_intra_threads_default_is_zero() { + let _g = AccelGuard::new(); + set_ort_intra_threads(0); + assert_eq!(get_ort_intra_threads(), 0); + } + + #[test] + fn ort_intra_threads_roundtrip() { + let _g = AccelGuard::new(); + set_ort_intra_threads(6); + assert_eq!(get_ort_intra_threads(), 6); + } + // -- Whisper tests -- #[test] diff --git a/src/lib.rs b/src/lib.rs index deb229e..59a2721 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,7 +90,8 @@ pub mod accel; pub mod audio; pub mod error; pub use accel::{ - get_ort_accelerator, get_whisper_accelerator, get_whisper_gpu_device, set_ort_accelerator, + get_decoder_gpu, get_ort_accelerator, get_ort_intra_threads, get_whisper_accelerator, + get_whisper_gpu_device, set_decoder_gpu, set_ort_accelerator, set_ort_intra_threads, set_whisper_accelerator, set_whisper_gpu_device, OrtAccelerator, WhisperAccelerator, GPU_DEVICE_AUTO, }; diff --git a/src/onnx/session.rs b/src/onnx/session.rs index 0e43ce9..400110e 100644 --- a/src/onnx/session.rs +++ b/src/onnx/session.rs @@ -144,6 +144,19 @@ fn is_xnnpack_active() -> bool { pref == OrtAccelerator::Xnnpack && cfg!(feature = "ort-xnnpack") } +/// Resolve intra-op thread count: explicit param > global setting > ORT default. +fn resolve_intra_threads(explicit: usize) -> Option<usize> { + if explicit > 0 { + return Some(explicit); + } + let global = crate::accel::get_ort_intra_threads(); + if global > 0 { + Some(global as usize) + } else { + None + } +} + /// Internal session builder with full control over threading and EP selection. 
fn build_session( path: &Path, @@ -158,8 +171,9 @@ // force a single intra-op thread when XNNPACK is the active EP. builder = builder.with_intra_op_spinning(false)?; builder = builder.with_intra_threads(1)?; - } else if let Some(n) = intra_threads { - if n > 0 { + } else { + let threads = resolve_intra_threads(intra_threads.unwrap_or(0)); + if let Some(n) = threads { builder = builder.with_intra_threads(n)?; } } @@ -287,3 +301,65 @@ pub fn read_metadata_float_vec( None => Ok(None), } } + +/// Session for autoregressive decoders: sequential execution, configurable +/// intra-op threads. +/// +/// By default forces CPU-only with arena allocator — sequential execution makes +/// GPU kernel launch overhead net-negative for per-token latency at batch size 1. +/// +/// Call [`crate::set_decoder_gpu`] with `true` to use the global accelerator +/// preference for decoder sessions (for GPU benchmarking). +/// +/// Thread count resolution: +/// 1. `num_threads` if > 0 +/// 2. Global `set_ort_intra_threads` if > 0 +/// 3. ORT default (all cores) +pub fn create_decoder_session(path: &Path, num_threads: usize) -> Result<Session> { + let use_gpu = crate::accel::get_decoder_gpu(); + + let threads = resolve_intra_threads(num_threads); + + let mut builder = Session::builder()? + .with_optimization_level(GraphOptimizationLevel::Level3)? + .with_parallel_execution(false)?; + if let Some(n) = threads { + builder = builder.with_intra_threads(n)?; + } + + let session = if use_gpu { + log::info!("Decoder session using global accelerator (set_decoder_gpu=true)"); + builder + .with_execution_providers(execution_providers())? + .commit_from_file(path)? + } else { + if get_ort_accelerator() != OrtAccelerator::CpuOnly { + log::info!( + "Decoder session uses CPU-only (sequential execution); \ + call set_decoder_gpu(true) to override" + ); + } + builder + .with_execution_providers([CPUExecutionProvider::default() + .with_arena_allocator(true) + .build()])? + .commit_from_file(path)? 
+ }; + + for input in session.inputs() { + log::info!( + "Model input: name={}, type={:?}", + input.name(), + input.dtype() + ); + } + for output in session.outputs() { + log::info!( + "Model output: name={}, type={:?}", + output.name(), + output.dtype() + ); + } + + Ok(session) +}