Skip to content

Commit 2314e44

Browse files
committed
feat: make wasi-nn immutable
Signed-off-by: Harald Hoyer <harald@profian.com>
1 parent 4475ebf commit 2314e44

2 files changed

Lines changed: 19 additions & 18 deletions

File tree

crates/wasi-nn/src/ctx.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@ use crate::r#impl::UsageError;
66
use crate::witx::types::{Graph, GraphEncoding, GraphExecutionContext};
77
use std::collections::HashMap;
88
use std::hash::Hash;
9+
use std::sync::RwLock;
910
use thiserror::Error;
1011
use wiggle::GuestError;
1112

1213
/// Capture the state necessary for calling into the backend ML libraries.
1314
pub struct WasiNnCtx {
14-
pub(crate) backends: HashMap<u8, Box<dyn Backend>>,
15-
pub(crate) graphs: Table<Graph, Box<dyn BackendGraph>>,
16-
pub(crate) executions: Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>,
15+
pub(crate) backends: RwLock<HashMap<u8, Box<dyn Backend>>>,
16+
pub(crate) graphs: RwLock<Table<Graph, Box<dyn BackendGraph>>>,
17+
pub(crate) executions: RwLock<Table<GraphExecutionContext, Box<dyn BackendExecutionContext>>>,
1718
}
1819

1920
impl WasiNnCtx {
@@ -27,9 +28,9 @@ impl WasiNnCtx {
2728
Box::new(OpenvinoBackend::default()) as Box<dyn Backend>,
2829
);
2930
Ok(Self {
30-
backends,
31-
graphs: Table::default(),
32-
executions: Table::default(),
31+
backends: RwLock::new(backends),
32+
graphs: RwLock::new(Table::default()),
33+
executions: RwLock::new(Table::default()),
3334
})
3435
}
3536
}

crates/wasi-nn/src/impl.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,61 +26,61 @@ pub enum UsageError {
2626

2727
impl<'a> WasiEphemeralNn for WasiNnCtx {
2828
fn load<'b>(
29-
&mut self,
29+
&self,
3030
builders: &GraphBuilderArray<'_>,
3131
encoding: GraphEncoding,
3232
target: ExecutionTarget,
3333
) -> Result<Graph> {
3434
let encoding_id: u8 = encoding.into();
35-
let graph = if let Some(backend) = self.backends.get_mut(&encoding_id) {
35+
let graph = if let Some(backend) = self.backends.write().unwrap().get_mut(&encoding_id) {
3636
backend.load(builders, target)?
3737
} else {
3838
return Err(UsageError::InvalidEncoding(encoding).into());
3939
};
40-
let graph_id = self.graphs.insert(graph);
40+
let graph_id = self.graphs.write().unwrap().insert(graph);
4141
Ok(graph_id)
4242
}
4343

44-
fn init_execution_context(&mut self, graph_id: Graph) -> Result<GraphExecutionContext> {
45-
let exec_context = if let Some(graph) = self.graphs.get_mut(graph_id) {
44+
fn init_execution_context(&self, graph_id: Graph) -> Result<GraphExecutionContext> {
45+
let exec_context = if let Some(graph) = self.graphs.write().unwrap().get_mut(graph_id) {
4646
graph.init_execution_context()?
4747
} else {
4848
return Err(UsageError::InvalidGraphHandle.into());
4949
};
5050

51-
let exec_context_id = self.executions.insert(exec_context);
51+
let exec_context_id = self.executions.write().unwrap().insert(exec_context);
5252
Ok(exec_context_id)
5353
}
5454

5555
fn set_input<'b>(
56-
&mut self,
56+
&self,
5757
exec_context_id: GraphExecutionContext,
5858
index: u32,
5959
tensor: &Tensor<'b>,
6060
) -> Result<()> {
61-
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
61+
if let Some(exec_context) = self.executions.write().unwrap().get_mut(exec_context_id) {
6262
Ok(exec_context.set_input(index, tensor)?)
6363
} else {
6464
Err(UsageError::InvalidGraphHandle.into())
6565
}
6666
}
6767

68-
fn compute(&mut self, exec_context_id: GraphExecutionContext) -> Result<()> {
69-
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
68+
fn compute(&self, exec_context_id: GraphExecutionContext) -> Result<()> {
69+
if let Some(exec_context) = self.executions.write().unwrap().get_mut(exec_context_id) {
7070
Ok(exec_context.compute()?)
7171
} else {
7272
Err(UsageError::InvalidExecutionContextHandle.into())
7373
}
7474
}
7575

7676
fn get_output<'b>(
77-
&mut self,
77+
&self,
7878
exec_context_id: GraphExecutionContext,
7979
index: u32,
8080
out_buffer: &GuestPtr<'_, u8>,
8181
out_buffer_max_size: u32,
8282
) -> Result<u32> {
83-
if let Some(exec_context) = self.executions.get_mut(exec_context_id) {
83+
if let Some(exec_context) = self.executions.write().unwrap().get_mut(exec_context_id) {
8484
let mut destination = out_buffer
8585
.as_array(out_buffer_max_size)
8686
.as_slice_mut()?

0 commit comments

Comments
 (0)