Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion src/accel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

use std::fmt;
use std::str::FromStr;
use std::sync::atomic::{AtomicI32, AtomicU8, Ordering};
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU8, Ordering};

use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -45,6 +45,39 @@ pub fn get_ort_accelerator() -> OrtAccelerator {
OrtAccelerator::from_u8(ORT_ACCELERATOR.load(Ordering::Relaxed))
}

static ORT_INTRA_THREADS: AtomicU8 = AtomicU8::new(0);

/// Set the process-wide ORT intra-op thread count.
///
/// A value of 0 defers to ORT's default (all available cores). The setting
/// is read at session-build time, so it must be set before models are
/// loaded; sessions that already exist keep their original thread count.
/// An explicit per-session thread parameter takes precedence over this.
pub fn set_ort_intra_threads(count: u8) {
    ORT_INTRA_THREADS.store(count, Ordering::Relaxed);
}

/// Current process-wide ORT intra-op thread count (0 = ORT default).
pub fn get_ort_intra_threads() -> u8 {
    ORT_INTRA_THREADS.load(Ordering::Relaxed)
}

static DECODER_GPU: AtomicBool = AtomicBool::new(false);

/// Opt decoder sessions in to GPU execution providers.
///
/// Decoders default to CPU-only with the arena allocator: at batch size 1,
/// token-by-token sequential execution makes GPU kernel-launch overhead a
/// net loss for per-token latency. Pass `true` (e.g. for GPU benchmarking)
/// to let decoder sessions follow the global accelerator preference.
pub fn set_decoder_gpu(enable: bool) {
    DECODER_GPU.store(enable, Ordering::Relaxed);
}

/// Whether decoder sessions should use GPU execution providers.
pub fn get_decoder_gpu() -> bool {
    DECODER_GPU.load(Ordering::Relaxed)
}

impl OrtAccelerator {
/// Return the list of ORT accelerators that are compiled-in for the current build.
///
Expand Down Expand Up @@ -262,6 +295,8 @@ mod tests {
// Reset every process-wide accel setting to its default when the guard is
// dropped, so one test cannot leak global state into the next. Each call
// resets an independent atomic, so ordering here does not matter.
fn drop(&mut self) {
set_ort_accelerator(OrtAccelerator::Auto);
set_whisper_accelerator(WhisperAccelerator::Auto);
set_ort_intra_threads(0);
set_decoder_gpu(false);
set_whisper_gpu_device(GPU_DEVICE_AUTO);
}
}
Expand Down Expand Up @@ -340,6 +375,22 @@ mod tests {
assert_eq!(OrtAccelerator::from_u8(255), OrtAccelerator::Auto);
}

// -- ORT thread count tests --

#[test]
fn ort_intra_threads_default_is_zero() {
    // Guard restores all global accel settings when dropped.
    let _guard = AccelGuard::new();
    set_ort_intra_threads(0);
    let current = get_ort_intra_threads();
    assert_eq!(current, 0);
}

#[test]
fn ort_intra_threads_roundtrip() {
    // A nonzero value stored through the setter must be read back verbatim.
    let _guard = AccelGuard::new();
    set_ort_intra_threads(6);
    let current = get_ort_intra_threads();
    assert_eq!(current, 6);
}

// -- Whisper tests --

#[test]
Expand Down
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ pub mod accel;
pub mod audio;
pub mod error;
pub use accel::{
get_ort_accelerator, get_whisper_accelerator, get_whisper_gpu_device, set_ort_accelerator,
get_decoder_gpu, get_ort_accelerator, get_ort_intra_threads, get_whisper_accelerator,
get_whisper_gpu_device, set_decoder_gpu, set_ort_accelerator, set_ort_intra_threads,
set_whisper_accelerator, set_whisper_gpu_device, OrtAccelerator, WhisperAccelerator,
GPU_DEVICE_AUTO,
};
Expand Down
82 changes: 78 additions & 4 deletions src/onnx/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ fn directml_active() -> bool {
get_ort_accelerator() == OrtAccelerator::DirectMl && cfg!(feature = "ort-directml")
}

/// Resolve intra-op thread count: explicit param > global setting > ORT default.
fn resolve_intra_threads(explicit: usize) -> Option<usize> {
    // An explicit, nonzero per-session value always wins.
    if explicit > 0 {
        return Some(explicit);
    }
    // Otherwise fall back to the process-wide setting; 0 means "let ORT decide".
    match crate::accel::get_ort_intra_threads() {
        0 => None,
        n => Some(n as usize),
    }
}

/// Internal session builder with full control over threading and EP selection.
fn build_session(
path: &Path,
Expand All @@ -75,10 +88,9 @@ fn build_session(
let mut builder =
Session::builder()?.with_optimization_level(GraphOptimizationLevel::Level3)?;

if let Some(n) = intra_threads {
if n > 0 {
builder = builder.with_intra_threads(n)?;
}
let threads = resolve_intra_threads(intra_threads.unwrap_or(0));
if let Some(n) = threads {
builder = builder.with_intra_threads(n)?;
}

// DirectML requires parallel_execution(false) and memory_pattern(false)
Expand Down Expand Up @@ -203,3 +215,65 @@ pub fn read_metadata_float_vec(
None => Ok(None),
}
}

/// Session for autoregressive decoders: sequential execution, configurable
/// intra-op threads.
///
/// By default forces CPU-only with arena allocator — sequential execution makes
/// GPU kernel launch overhead net-negative for per-token latency at batch size 1.
///
/// Call [`crate::set_decoder_gpu`] with `true` to use the global accelerator
/// preference for decoder sessions (for GPU benchmarking).
///
/// Thread count resolution:
/// 1. `num_threads` if > 0
/// 2. Global `set_ort_intra_threads` if > 0
/// 3. ORT default (all cores)
pub fn create_decoder_session(path: &Path, num_threads: usize) -> Result<Session, ort::Error> {
    // Snapshot the global flag once; a later set_decoder_gpu() call does not
    // affect a session that is already being built.
    let use_gpu = crate::accel::get_decoder_gpu();

    let threads = resolve_intra_threads(num_threads);

    // Parallel execution is disabled for all decoder sessions: tokens are
    // produced one at a time, so inter-op parallelism only adds overhead.
    let mut builder = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_parallel_execution(false)?;
    if let Some(n) = threads {
        builder = builder.with_intra_threads(n)?;
    }

    let session = if use_gpu {
        log::info!("Decoder session using global accelerator (set_decoder_gpu=true)");
        builder
            .with_execution_providers(execution_providers())?
            .commit_from_file(path)?
    } else {
        // Explain at load time why a user-selected accelerator is being
        // ignored for this session, and how to override that.
        if get_ort_accelerator() != OrtAccelerator::CpuOnly {
            log::info!(
                "Decoder session uses CPU-only (sequential execution); \
                call set_decoder_gpu(true) to override"
            );
        }
        builder
            .with_execution_providers([CPUExecutionProvider::default()
                .with_arena_allocator(true)
                .build()])?
            .commit_from_file(path)?
    };

    // Log the model's input/output signature to aid debugging of
    // name/dtype mismatches against the tokenizer/feature pipeline.
    for input in session.inputs() {
        log::info!(
            "Model input: name={}, type={:?}",
            input.name(),
            input.dtype()
        );
    }
    for output in session.outputs() {
        log::info!(
            "Model output: name={}, type={:?}",
            output.name(),
            output.dtype()
        );
    }

    Ok(session)
}