diff --git a/Cargo.lock b/Cargo.lock index d389a718..1462112c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4623,12 +4623,12 @@ dependencies = [ "tokio", "tokio-stream", "toml", + "tonic", "tracing", "tracing-appender", "tracing-subscriber", "uuid", "weaver-database", - "weaver-embedding", "weaver-trace", ] @@ -4720,6 +4720,7 @@ dependencies = [ "tonic-build", "tower", "tracing", + "weaver-core", "weaver-inference", ] diff --git a/crates/weaver-core/Cargo.toml b/crates/weaver-core/Cargo.toml index 06b4e310..02dd2b55 100644 --- a/crates/weaver-core/Cargo.toml +++ b/crates/weaver-core/Cargo.toml @@ -23,8 +23,23 @@ tracing = { workspace = true } tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } uuid = { workspace = true } +# `tonic` is a workspace-level transport dep used here only for the +# `From for EmbeddingError` and +# `From for EmbeddingError` impls in `embedder.rs`. +# Orphan rules require these impls to live in the trait's crate +# (weaver-core); they're how the gRPC backend in `weaver-embedding` +# (and elsewhere) propagates transport-layer errors via `?` into the +# abstract `EmbeddingError::Transport` variant. Removed in PR-3.A +# along with the gRPC client itself. +tonic = { workspace = true } weaver-database = { workspace = true } -weaver-embedding = { workspace = true } +# `weaver-embedding` was a direct dep until PR-0.5.B. Its only use was +# constructing `EmbeddingClient` directly inside `tools/memory/codebase_search.rs`. +# That call site now goes through `ExecutionContext.embedder` +# (`Arc`), which the daemon +# populates at boot. Dropping the dep breaks the cycle that would form +# once `weaver-embedding` adds its own `weaver-core` dep for the +# relocated trait. 
# Production dep as of Phase 1a Stage 1.2 — `engine::query` calls # `weaver_trace::current_span_id()` to capture the TOOL span's id at # tool-execute time so the resulting `MemoryRetrieval` block can carry diff --git a/crates/weaver-core/src/embedder.rs b/crates/weaver-core/src/embedder.rs new file mode 100644 index 00000000..5221c984 --- /dev/null +++ b/crates/weaver-core/src/embedder.rs @@ -0,0 +1,393 @@ +//! The `Embedder` trait — backend-agnostic interface for embedding services. +//! +//! Per `docs/specs/weaver-spu-Spec.md` §4.1, this trait + its associated types +//! live in `weaver-core` so trait-only consumers (e.g., `weaver-database`, +//! `weaver-core::surfacing`) don't pull a backend implementation crate into +//! their dep graph. Backend implementations live where their concrete code +//! lives: +//! +//! - **Pre-Phase-1 cutover**: gRPC client to the Python `weaver-embedder.service` +//! in `weaver-embedding::grpc_client::EmbeddingClient`. +//! - **Post-Phase-1 cutover**: in-process candle backend +//! `weaver-spu::encoder::client::EmbedderClient`. +//! +//! Consumers hold `Arc` and operate against either impl +//! transparently. +//! +//! # Why a trait, not just two structs +//! +//! Phase 1 ships the Rust backend behind a daemon-config flag +//! (`backend = "python" | "rust"` in `server.toml`). At daemon-boot time +//! the orchestrator picks one impl and hands it to the runtime as +//! `Arc`. The cutover is a one-line constructor change at +//! the wiring point; consumer code is unchanged. +//! +//! Once Phase 3 cleanup removes the Python service, the trait can either +//! stay as defense-in-depth (cheap; one unused indirection) or get +//! inlined against the single remaining impl. That's a future call. +//! +//! # Backend-agnostic types +//! +//! The trait deliberately surfaces backend-agnostic types: +//! +//! - [`EmbedResult`] — batch-embed result. +//! - [`LateChunkedResult`] / [`LateChunkResult`] — late-chunked embed +//! result. +//! 
- [`EmbedderInfo`] — identity / readiness, replaces the proto +//! `InfoResponse` at the trait boundary. +//! - [`EmbeddingError`] — abstract over transport-layer errors via a +//! `Box` payload, so this crate doesn't need a +//! tonic dep just to express gRPC failures. Backend impls map their +//! concrete transport errors into the abstract `Transport` variant via +//! `From` impls (see `weaver-embedding::grpc_client`). + +use async_trait::async_trait; + +/// Backend-agnostic identity / readiness information. +/// +/// Returned by [`Embedder::info`]. Replaces the proto `InfoResponse` +/// at the trait boundary so non-gRPC backends don't have to construct +/// proto types. +#[derive(Debug, Clone)] +pub struct EmbedderInfo { + /// Logical model identifier (e.g. `"jinaai/jina-embeddings-v4"`). + pub model_name: String, + /// `true` once the backend is ready to serve embed requests. + pub model_loaded: bool, + /// Output embedding dimension. + pub dimension: u32, + /// Maximum sequence length the backend accepts in a single embed + /// call (in tokens). + pub max_seq_length: u32, + /// Backend-supplied identifier for the loaded weights — typically + /// the HF snapshot SHA. Empty string if the backend doesn't surface + /// this. Used by the cohort-pin guard for identity-drift detection. + pub weight_revision: String, +} + +/// Result of an embedding operation. +#[derive(Debug, Clone)] +pub struct EmbedResult { + /// Embedding vectors, one per input text. + pub embeddings: Vec>, + /// Model identifier used. + pub model: String, + /// Embedding dimension. + pub dimension: u32, + /// Wall-clock time in milliseconds. + pub duration_ms: u64, +} + +/// Result of a late-chunked embedding operation — one [`LateChunkResult`] +/// per chunk produced by the service, in document order. 
+#[derive(Debug, Clone)] +pub struct LateChunkedResult { + pub chunks: Vec<LateChunkResult>, + pub model: String, + pub dimension: u32, + pub duration_ms: u64, + pub context_window_used: u32, +} + +/// A single late-chunked span: its text, its position in the source +/// document, and its context-aware embedding. +#[derive(Debug, Clone)] +pub struct LateChunkResult { + pub text: String, + pub embedding: Vec<f32>, + pub start_char: u32, + pub end_char: u32, + pub start_token: u32, + pub end_token: u32, + pub chunk_index: u32, + pub total_chunks: u32, +} + +/// Error type for embedding-trait operations. +/// +/// Variants are deliberately abstract over transport so this crate +/// doesn't pick up a tonic dep just to express gRPC errors. Backend +/// implementations (e.g., the gRPC client in `weaver-embedding`) map +/// their concrete transport errors into [`Self::Transport`] via `From` +/// impls and the message preserves the upstream error's `Display`. +#[derive(Debug, thiserror::Error)] +pub enum EmbeddingError { + /// Transport-layer error (connection, status, RPC). Concrete error + /// preserved for chaining; the Display preserves the upstream + /// message. + #[error("transport error: {0}")] + Transport(Box<dyn std::error::Error + Send + Sync>), + + /// Backend returned a malformed response (wrong dimension, wrong + /// embedding count, missing fields, etc.). + #[error("invalid response: {0}")] + InvalidResponse(String), + + /// Backend reports as unavailable (service up but model not loaded + /// or in a terminal error state). + #[error("embedder not available: {0}")] + NotAvailable(String), +} + +// gRPC-transport `From` impls. These live here (not in +// `weaver-embedding`) because the orphan rule permits implementing the +// foreign `From` trait only in the crate that defines the self type — +// `EmbeddingError` is defined here, while in `weaver-embedding` both +// trait and type would be foreign. The cost is one tonic dep +// in `weaver-core`; the alternative (explicit `.map_err` at every `?` +// site in the gRPC client) is a larger maintenance burden than the +// dep.
Removed in PR-3.A along with the gRPC client itself, when the +// last consumer (weaver-embedding) goes away. +impl From for EmbeddingError { + fn from(e: tonic::transport::Error) -> Self { + EmbeddingError::Transport(Box::new(e)) + } +} + +impl From for EmbeddingError { + fn from(e: tonic::Status) -> Self { + EmbeddingError::Transport(Box::new(e)) + } +} + +/// Backend-agnostic embedding service. +/// +/// Implementations: +/// - `weaver_embedding::grpc_client::EmbeddingClient` — Python service over +/// Unix socket / TCP gRPC (pre-Phase-1 cutover). +/// - `weaver_spu::encoder::client::EmbedderClient` — Rust in-process +/// candle backend (post-Phase-1 cutover). +/// +/// `Send + Sync` so consumers can hold `Arc` across task +/// boundaries. Object-safe via `#[async_trait]`. +#[async_trait] +pub trait Embedder: Send + Sync { + /// Embed a batch of texts. Each text in `texts` produces one + /// embedding vector in the result's `embeddings`. `task` is one of: + /// + /// - `"retrieval.query"` — `retrieval` adapter, `"Query: "` prefix + /// - `"retrieval.passage"` — `retrieval` adapter, `"Passage: "` prefix + /// - `"text-matching"` — `text-matching` adapter, `"Query: "` prefix + /// on both legs (symmetric) + /// - `"code"` — `code` adapter, no prefix + /// + /// Per `docs/specs/weaver-spu-Spec.md` §3 ground truth: 3 LoRA + /// adapters (`retrieval`, `text-matching`, `code`) with `query` / + /// `passage` as prompt-prefix variants of the single `retrieval` + /// adapter, dispatched at forward time on `task_label`. + /// + /// `batch_size = None` lets the backend pick. + async fn embed( + &self, + texts: &[String], + task: &str, + batch_size: Option, + ) -> Result; + + /// Convenience wrapper: embed a single text and return the vector + /// directly. Default implementation wraps [`Self::embed`] and + /// enforces an exact 1:1 response — anything other than exactly one + /// embedding is treated as a backend protocol violation, not + /// silently truncated. 
+ async fn embed_one(&self, text: &str, task: &str) -> Result, EmbeddingError> { + let result = self.embed(&[text.to_string()], task, None).await?; + match result.embeddings.len() { + 1 => Ok(result.embeddings.into_iter().next().unwrap()), + n => Err(EmbeddingError::InvalidResponse(format!( + "expected exactly 1 embedding for embed_one, got {n}" + ))), + } + } + + /// Embed a long document with late chunking — the backend encodes + /// the whole document once, then derives per-chunk vectors from the + /// contextual token states. Returns one chunk per element in + /// document order. + async fn embed_late_chunked( + &self, + text: &str, + task: &str, + ) -> Result; + + /// Backend identity + readiness. The returned [`EmbedderInfo`] is + /// the basis for the cohort-pin guard's drift detection at daemon + /// boot. + async fn info(&self) -> Result; + + /// Liveness check. Default implementation returns `true` if + /// [`Self::info`] succeeds and reports `model_loaded`. Backends can + /// override for cheaper probes. + async fn health_check(&self) -> bool { + match self.info().await { + Ok(info) => info.model_loaded, + Err(_) => false, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + /// Static fixture impl. Confirms the trait is object-safe (compiles + /// at all if you can construct `Arc` from a concrete + /// type), and exercises the default `embed_one` and `health_check` + /// implementations against a controllable backend. 
+ struct FixtureEmbedder { + loaded: bool, + dimension: u32, + } + + #[async_trait] + impl Embedder for FixtureEmbedder { + async fn embed( + &self, + texts: &[String], + _task: &str, + _batch_size: Option, + ) -> Result { + Ok(EmbedResult { + embeddings: texts + .iter() + .map(|_| vec![0.0_f32; self.dimension as usize]) + .collect(), + model: "fixture".into(), + dimension: self.dimension, + duration_ms: 0, + }) + } + + async fn embed_late_chunked( + &self, + text: &str, + _task: &str, + ) -> Result { + Ok(LateChunkedResult { + chunks: vec![LateChunkResult { + text: text.into(), + embedding: vec![0.0_f32; self.dimension as usize], + start_char: 0, + end_char: text.len() as u32, + start_token: 0, + end_token: 0, + chunk_index: 0, + total_chunks: 1, + }], + model: "fixture".into(), + dimension: self.dimension, + duration_ms: 0, + context_window_used: 0, + }) + } + + async fn info(&self) -> Result { + Ok(EmbedderInfo { + model_name: "fixture".into(), + model_loaded: self.loaded, + dimension: self.dimension, + max_seq_length: 32_768, + weight_revision: "fixture-rev".into(), + }) + } + } + + #[tokio::test] + async fn trait_is_object_safe() { + // The compile is the test — if the trait is object-unsafe this + // expression fails to compile. (Object-safety is what lets the + // surfacing engine hold `Arc` at runtime per + // `docs/specs/weaver-spu-Spec.md` §4.1.) + let _: Arc = Arc::new(FixtureEmbedder { + loaded: true, + dimension: 8, + }); + } + + #[tokio::test] + async fn embed_one_default_wraps_embed() { + let e = FixtureEmbedder { + loaded: true, + dimension: 4, + }; + let v = e.embed_one("hello", "retrieval.query").await.unwrap(); + assert_eq!(v.len(), 4); + assert!(v.iter().all(|&x| x == 0.0)); + } + + /// A backend that returns the wrong number of embeddings from + /// `embed` (server-side bug or protocol-version skew) must produce + /// a clear `InvalidResponse` error through `embed_one`'s default + /// impl, not panic and not silently truncate. 
+ struct CountedBackend(Vec>); + + #[async_trait] + impl Embedder for CountedBackend { + async fn embed( + &self, + _: &[String], + _: &str, + _: Option, + ) -> Result { + Ok(EmbedResult { + embeddings: self.0.clone(), + model: "counted".into(), + dimension: self.0.first().map_or(0, |v| v.len() as u32), + duration_ms: 0, + }) + } + async fn embed_late_chunked( + &self, + _: &str, + _: &str, + ) -> Result { + unreachable!() + } + async fn info(&self) -> Result { + unreachable!() + } + } + + #[tokio::test] + async fn embed_one_errors_when_backend_returns_no_embeddings() { + let backend = CountedBackend(vec![]); + let err = backend.embed_one("x", "y").await.unwrap_err(); + match err { + EmbeddingError::InvalidResponse(msg) => { + assert!( + msg.contains("expected exactly 1") && msg.contains("got 0"), + "msg = {msg:?}" + ); + } + other => panic!("expected InvalidResponse, got {other:?}"), + } + } + + #[tokio::test] + async fn embed_one_errors_when_backend_returns_multiple_embeddings() { + let backend = CountedBackend(vec![vec![1.0_f32; 4], vec![2.0_f32; 4]]); + let err = backend.embed_one("x", "y").await.unwrap_err(); + match err { + EmbeddingError::InvalidResponse(msg) => { + assert!( + msg.contains("expected exactly 1") && msg.contains("got 2"), + "msg = {msg:?}" + ); + } + other => panic!("expected InvalidResponse, got {other:?}"), + } + } + + #[tokio::test] + async fn health_check_default_reflects_model_loaded() { + let loaded = FixtureEmbedder { + loaded: true, + dimension: 4, + }; + assert!(loaded.health_check().await); + + let unloaded = FixtureEmbedder { + loaded: false, + dimension: 4, + }; + assert!(!unloaded.health_check().await); + } +} diff --git a/crates/weaver-core/src/engine/runtime.rs b/crates/weaver-core/src/engine/runtime.rs index 4c1944b4..97074def 100644 --- a/crates/weaver-core/src/engine/runtime.rs +++ b/crates/weaver-core/src/engine/runtime.rs @@ -214,6 +214,7 @@ impl ContextNapCallback for HadesNapCallback { db_pool: Some(self.pool.clone()), 
codebase_pool: None, memory_bypass: false, + embedder: None, }; let ts = chrono::Utc::now().timestamp_millis(); @@ -611,6 +612,30 @@ unsafe impl Send for AgentHandle {} unsafe impl Sync for AgentHandle {} impl AgentHandle { + /// Inject an `Embedder` backend into the agent's `ExecutionContext`. + /// + /// `AgentRuntime::spawn_*` constructs every agent with + /// `exec_ctx.embedder = None` because the runtime layer has no + /// reference to the daemon's embedder backend at spawn time. The + /// daemon (`weaver-interface::serve` per + /// `docs/specs/weaver-spu-Spec.md` PR-1.H) calls this method + /// post-spawn to wire the backend before the agent processes any + /// turn. Tools that require embeddings (currently + /// `CodebaseSearch`) check `ctx.embedder.is_some()` and return + /// `ToolError::ExecutionFailed` if it isn't populated, so calling + /// this is required before any embed-using tool gets invoked. + /// + /// Idempotent: calling twice replaces the prior backend without + /// re-spawning the agent. Pre-Phase-1 this is the gRPC client to + /// `weaver-embedder.service`; post-Phase-1 it's the in-process + /// candle backend in `weaver-spu`. + pub fn set_embedder( + &mut self, + embedder: std::sync::Arc, + ) { + self.exec_ctx.embedder = Some(embedder); + } + /// Send a user message and run the agentic loop. /// /// Returns a receiver that streams `AgentEvent`s **in real time** as @@ -876,11 +901,8 @@ impl AgentRuntime { // 2b. Pre-flight: ensure the model is loaded on the inference server // Bounded by the agent's configured timeout to avoid hanging on stalled sockets. 
let preflight_timeout = std::time::Duration::from_secs(timeout_secs); - match tokio::time::timeout( - preflight_timeout, - ensure_model_available(&provider, &state), - ) - .await + match tokio::time::timeout(preflight_timeout, ensure_model_available(&provider, &state)) + .await { Ok(result) => result?, Err(_) => { @@ -952,6 +974,12 @@ impl AgentRuntime { db_pool, codebase_pool, memory_bypass: state.memory_bypass.unwrap_or(false), + // Daemon (`weaver-interface::serve`) populates `embedder` at + // boot when the agent runtime is constructed; the engine + // runtime here is the agent-internal runtime that doesn't + // own the embedder backend, so it stays None until the + // calling code injects one (Phase 1 wiring lands in PR-1.H). + embedder: None, }; // 4. Build the scoped ToolRegistry @@ -1169,7 +1197,10 @@ async fn ensure_model_available( // 3. Model not loaded — can we auto-load it? let is_gguf = state.decoder_backend() == Some(BackendSpec::Gguf); - let path = state.decoder_path().map(str::trim).filter(|p| !p.is_empty()); + let path = state + .decoder_path() + .map(str::trim) + .filter(|p| !p.is_empty()); if !is_gguf || path.is_none() { return Err(RuntimeError::ModelLoadFailed(format!( diff --git a/crates/weaver-core/src/lib.rs b/crates/weaver-core/src/lib.rs index dd063c96..ee6868df 100644 --- a/crates/weaver-core/src/lib.rs +++ b/crates/weaver-core/src/lib.rs @@ -3,6 +3,7 @@ pub mod blocks; pub mod channel; pub mod config; pub mod default_prompt; +pub mod embedder; pub mod engine; pub mod error; pub mod gpu; diff --git a/crates/weaver-core/src/tool.rs b/crates/weaver-core/src/tool.rs index e42db6dd..66a954bf 100644 --- a/crates/weaver-core/src/tool.rs +++ b/crates/weaver-core/src/tool.rs @@ -185,7 +185,7 @@ pub enum PermissionResult { } /// Context available during tool execution. -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ExecutionContext { pub working_directory: std::path::PathBuf, /// Codebase/project-graph config for code-navigation tools. 
None if no @@ -207,4 +207,43 @@ pub struct ExecutionContext { /// calls with synthetic empty results so the Memory↔SPU contract carries /// no information; no ArangoDB traffic is emitted. Defaults to `false`. pub memory_bypass: bool, + /// Embedder backend handle for tools that need to embed query text + /// (currently `CodebaseSearch`). Populated by `weaver-interface::serve` + /// at daemon-boot time (production) or constructed per-test (test + /// fixtures). `None` when no embedder is configured — tools that + /// require embeddings return an error in that case rather than + /// constructing a transport themselves. + /// + /// Per `docs/specs/weaver-spu-Spec.md` PR-0.5.B: the trait was + /// relocated from `weaver-embedding::embedder` to + /// `weaver-core::embedder`; the concrete impl is supplied by the + /// daemon (currently `weaver-embedding::grpc_client::EmbeddingClient`, + /// post-Phase-1 `weaver-spu::encoder::client::EmbedderClient`). + pub embedder: Option>, +} + +// Manual `Debug` impl: `dyn Embedder` is not `Debug`-bound (derives +// would require all fields to be `Debug`). The trait surface in +// `weaver-core::embedder` doesn't require `Debug` because no consumer +// pretty-prints embedder backends — they call methods. The other +// fields print transparently; the embedder slot prints as a presence +// flag. 
+impl std::fmt::Debug for ExecutionContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExecutionContext") + .field("working_directory", &self.working_directory) + .field("hades", &self.hades) + .field("db_pool", &self.db_pool) + .field("codebase_pool", &self.codebase_pool) + .field("memory_bypass", &self.memory_bypass) + .field( + "embedder", + &if self.embedder.is_some() { + "Some(Arc)" + } else { + "None" + }, + ) + .finish() + } } diff --git a/crates/weaver-core/src/tools/memory/codebase_search.rs b/crates/weaver-core/src/tools/memory/codebase_search.rs index 9dc7ef00..03d12dfa 100644 --- a/crates/weaver-core/src/tools/memory/codebase_search.rs +++ b/crates/weaver-core/src/tools/memory/codebase_search.rs @@ -69,7 +69,6 @@ impl Tool for CodebaseSearchTool { async fn execute(&self, input: Value, ctx: &ExecutionContext) -> Result { use weaver_database::db::collections::CODEBASE; - use weaver_embedding::grpc_client::EmbeddingClient; let query = input["query"].as_str().unwrap(); let limit = input["limit"] @@ -80,11 +79,20 @@ impl Tool for CodebaseSearchTool { let pool = require_codebase_pool(ctx)?; - // Embed the query via the project embedder service. - let embedder = EmbeddingClient::connect_default().await.map_err(|e| { - ToolError::ExecutionFailed(format!( - "Cannot connect to embedder service. Is weaver-embedder running? ({e})" - )) + // Embed the query via the configured embedder backend. The + // daemon plumbs an `Arc` through `ExecutionContext` + // at boot (per `docs/specs/weaver-spu-Spec.md` PR-0.5.B); + // pre-Phase-1 this is the gRPC client to weaver-embedder.service, + // post-Phase-1 it's the in-process candle backend in weaver-spu. + // No daemon-injected embedder = no codebase_search; the tool + // refuses rather than try to construct a transport itself. 
+ let embedder = ctx.embedder.as_ref().ok_or_else(|| { + ToolError::ExecutionFailed( + "CodebaseSearch requires an embedder backend, but ExecutionContext has none. \ + Daemon should populate ExecutionContext.embedder at boot \ + (weaver-interface::serve)." + .to_string(), + ) })?; let query_vec = embedder .embed_one(query, "retrieval.query") diff --git a/crates/weaver-core/tests/engine_tests.rs b/crates/weaver-core/tests/engine_tests.rs index 1f530df6..c0ea65a2 100644 --- a/crates/weaver-core/tests/engine_tests.rs +++ b/crates/weaver-core/tests/engine_tests.rs @@ -73,6 +73,7 @@ fn test_ctx() -> ExecutionContext { db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, } } @@ -913,6 +914,7 @@ async fn budget_manager_fires_red_nap_and_scrubs_yellow_hint() { db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, }; let perm_ctx = PermissionContext { mode: PermissionMode::AllowAll, diff --git a/crates/weaver-core/tests/multi_agent_tests.rs b/crates/weaver-core/tests/multi_agent_tests.rs index d01850d0..9d9750ed 100644 --- a/crates/weaver-core/tests/multi_agent_tests.rs +++ b/crates/weaver-core/tests/multi_agent_tests.rs @@ -71,6 +71,7 @@ fn test_ctx() -> ExecutionContext { db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, } } @@ -182,6 +183,7 @@ async fn two_agents_concurrent_tool_calls() { db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, }; let ctx_b = test_ctx(); let registry_a = ToolRegistry::with_builtins(&ctx_a); diff --git a/crates/weaver-core/tests/tool_tests.rs b/crates/weaver-core/tests/tool_tests.rs index 28596a1a..99178845 100644 --- a/crates/weaver-core/tests/tool_tests.rs +++ b/crates/weaver-core/tests/tool_tests.rs @@ -32,6 +32,7 @@ fn ctx(dir: &TempDir) -> ExecutionContext { db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, } } diff --git a/crates/weaver-demo/src/hanoi/mod.rs b/crates/weaver-demo/src/hanoi/mod.rs index 
2011fd3a..96e95756 100644 --- a/crates/weaver-demo/src/hanoi/mod.rs +++ b/crates/weaver-demo/src/hanoi/mod.rs @@ -556,6 +556,7 @@ mod tests { db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, } } diff --git a/crates/weaver-demo/src/herobench/tools.rs b/crates/weaver-demo/src/herobench/tools.rs index 62840b47..aff810a8 100644 --- a/crates/weaver-demo/src/herobench/tools.rs +++ b/crates/weaver-demo/src/herobench/tools.rs @@ -2294,6 +2294,7 @@ mod tests { db_pool: None, // no HADES codebase_pool: None, memory_bypass: false, + embedder: None, }; let input = json!({ "task_name": "test", @@ -2323,6 +2324,7 @@ mod tests { db_pool: Some(pool), codebase_pool: None, memory_bypass: false, + embedder: None, }; // Cycle: a depends on b, b depends on a let input = json!({ diff --git a/crates/weaver-embedding/Cargo.toml b/crates/weaver-embedding/Cargo.toml index 70d33d0a..ca200ea3 100644 --- a/crates/weaver-embedding/Cargo.toml +++ b/crates/weaver-embedding/Cargo.toml @@ -41,6 +41,14 @@ candle = ["dep:candle-core", "dep:candle-nn", "dep:tokenizers"] candle-cuda = ["candle", "candle-core/cuda", "candle-nn/cuda"] [dependencies] +# `weaver-core` carries the `Embedder` trait + associated types +# post-PR-0.5.B. `grpc_client::EmbeddingClient`'s `Embedder` impl +# is in this crate (allowed by orphan rules: trait in weaver-core, +# concrete type in this crate). `embedder.rs` re-exports the trait +# during the transition window so existing consumers compile +# unchanged. +weaver-core = { workspace = true } + # `tracing` and `anyhow` are unconditional — `late_chunk` is a # pure-stdlib algorithm but we want module-level instrumentation # hooks and ergonomic error returns available for any future diff --git a/crates/weaver-embedding/src/embedder.rs b/crates/weaver-embedding/src/embedder.rs index 31194f4b..926e0f0d 100644 --- a/crates/weaver-embedding/src/embedder.rs +++ b/crates/weaver-embedding/src/embedder.rs @@ -1,326 +1,30 @@ -//! 
The `Embedder` trait — backend-agnostic interface for embedding services. -//! -//! Per `embedder-oxidization-Spec.md` §3 / §5.2, both the existing Python -//! gRPC backend ([`crate::grpc_client::EmbeddingClient`]) and the future -//! Rust in-process backend (Phase 1's `EmbedderClient`) implement this -//! trait. Consumers (the surfacing engine, Notepad, Pen, sleep stage A, -//! preseed materialization) take `Arc` and stay backend- -//! agnostic. -//! -//! # Why a trait, not just two structs -//! -//! Phase 1 ships the Rust backend behind a daemon-config flag -//! (`backend = "python" | "rust"` in `server.toml` per -//! `embedder-oxidization-Spec.md` §6.3). At daemon-boot time the -//! orchestrator picks one impl and hands it to the runtime as -//! `Arc`. The cutover is a one-line constructor change -//! at the wiring point; consumer code is unchanged. -//! -//! Once Phase 2 (`embedder-oxidization-Spec.md` §10 Phase 2 cleanup) -//! removes the Python service, the trait can either stay as -//! defense-in-depth (cheap; one unused indirection) or get inlined -//! against the single remaining impl. That's a future call. -//! -//! # Backend-agnostic types -//! -//! The trait deliberately does NOT return [`crate::proto::embedding`] -//! types directly. Doing so would force every backend to construct -//! proto messages even when the wire format isn't involved (the Rust -//! in-process backend has no gRPC; a hypothetical -//! [`candle`](https://github.com/huggingface/candle)-based backend -//! has no llama.cpp; etc). Instead, the trait surfaces: -//! -//! - [`crate::grpc_client::EmbedResult`] — already a backend-agnostic -//! struct (lives in `grpc_client.rs` for now; could move here if -//! the grpc_client impl path gets retired). -//! - [`crate::grpc_client::LateChunkedResult`] / -//! [`crate::grpc_client::LateChunkResult`] — same shape. -//! - [`EmbedderInfo`] (defined here) — replaces the proto -//! `InfoResponse` for the trait surface. -//! 
- [`crate::grpc_client::EmbeddingError`] — already abstract enough. -//! -//! When the Phase 0 → Phase 1 transition completes and the Python -//! backend retires, a follow-up PR may relocate the result/error -//! types out of `grpc_client.rs` into a neutral `types.rs` module. -//! Not in scope here. - -use async_trait::async_trait; - -use crate::grpc_client::{EmbedResult, EmbeddingError, LateChunkedResult}; - -/// Backend-agnostic identity / readiness information. -/// -/// Returned by [`Embedder::info`]. Replaces the proto `InfoResponse` -/// at the trait boundary so non-gRPC backends don't have to -/// construct proto types. -#[derive(Debug, Clone)] -pub struct EmbedderInfo { - /// Logical model identifier (e.g. `"jinaai/jina-embeddings-v4"`). - pub model_name: String, - /// `true` once the backend is ready to serve embed requests. - pub model_loaded: bool, - /// Output embedding dimension. - pub dimension: u32, - /// Maximum sequence length the backend accepts in a single - /// embed call (in tokens). - pub max_seq_length: u32, - /// Backend-supplied identifier for the loaded weights — typically - /// the HF snapshot SHA. Empty string if the backend doesn't - /// surface this. Used by `crate::pin` for cohort-identity - /// drift detection. - pub weight_revision: String, -} - -/// Backend-agnostic embedding service. -/// -/// Implementations live in: -/// - [`crate::grpc_client::EmbeddingClient`] — Python service over -/// Unix socket / TCP gRPC. -/// - (Phase 1) `crate::embedder_client::EmbedderClient` — Rust -/// in-process backend wrapping `weaver-inference`'s -/// `LlamaContext` for native Jina V4 GGUF inference. -/// -/// `Send + Sync` so consumers can hold `Arc` across -/// task boundaries. Object-safe via `#[async_trait]`. -#[async_trait] -pub trait Embedder: Send + Sync { - /// Embed a batch of texts. Each text in `texts` produces one - /// embedding vector in the result's `embeddings`. 
`task` is one - /// of Jina V4's task hints - /// (`"retrieval.query"`, `"retrieval.passage"`, `"classification"`, - /// `"clustering"`, `"separation"`); the backend selects the - /// matching LoRA adapter at inference time. - /// - /// `batch_size = None` lets the backend pick. - async fn embed( - &self, - texts: &[String], - task: &str, - batch_size: Option, - ) -> Result; - - /// Convenience wrapper: embed a single text and return the - /// vector directly. Default implementation wraps [`Self::embed`] - /// and enforces an exact 1:1 response — anything other than - /// exactly one embedding is treated as a backend protocol - /// violation, not silently truncated. The strict check matches - /// what `EmbeddingClient::embed` already enforces for its - /// inherent surface (`expected N, got M`); applying it here - /// gives consumers using `Arc` the same guarantee. - async fn embed_one(&self, text: &str, task: &str) -> Result, EmbeddingError> { - let result = self.embed(&[text.to_string()], task, None).await?; - match result.embeddings.len() { - 1 => Ok(result.embeddings.into_iter().next().unwrap()), - n => Err(EmbeddingError::InvalidResponse(format!( - "expected exactly 1 embedding for embed_one, got {n}" - ))), - } - } - - /// Embed a long document with late chunking — the backend - /// encodes the whole document once, then derives per-chunk - /// vectors from the contextual token states. Returns one - /// chunk per element in document order. See `late_chunk` for - /// the math; the backend implements the chunk-derivation - /// step natively (gRPC backend delegates to the Persephone - /// service; Rust backend will run the document encode locally - /// and apply [`crate::late_chunk::late_chunk_embeddings`]). - async fn embed_late_chunked( - &self, - text: &str, - task: &str, - ) -> Result; - - /// Backend identity + readiness. The returned [`EmbedderInfo`] - /// is the basis for [`crate::pin`]'s drift detection at daemon - /// boot. 
- async fn info(&self) -> Result; - - /// Liveness check. Default implementation returns `true` if - /// [`Self::info`] succeeds and reports `model_loaded`. Backends - /// can override for cheaper probes. - async fn health_check(&self) -> bool { - match self.info().await { - Ok(info) => info.model_loaded, - Err(_) => false, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::grpc_client::{EmbedResult, EmbeddingError, LateChunkResult, LateChunkedResult}; - use std::sync::Arc; - - /// Static fixture impl. Confirms the trait is object-safe (compiles - /// at all if you can construct `Arc` from a concrete - /// type), and exercises the default `embed_one` and `health_check` - /// implementations against a controllable backend. - struct FixtureEmbedder { - loaded: bool, - dimension: u32, - } - - #[async_trait] - impl Embedder for FixtureEmbedder { - async fn embed( - &self, - texts: &[String], - _task: &str, - _batch_size: Option, - ) -> Result { - Ok(EmbedResult { - embeddings: texts - .iter() - .map(|_| vec![0.0_f32; self.dimension as usize]) - .collect(), - model: "fixture".into(), - dimension: self.dimension, - duration_ms: 0, - }) - } - - async fn embed_late_chunked( - &self, - text: &str, - _task: &str, - ) -> Result { - Ok(LateChunkedResult { - chunks: vec![LateChunkResult { - text: text.into(), - embedding: vec![0.0_f32; self.dimension as usize], - start_char: 0, - end_char: text.len() as u32, - start_token: 0, - end_token: 0, - chunk_index: 0, - total_chunks: 1, - }], - model: "fixture".into(), - dimension: self.dimension, - duration_ms: 0, - context_window_used: 0, - }) - } - - async fn info(&self) -> Result { - Ok(EmbedderInfo { - model_name: "fixture".into(), - model_loaded: self.loaded, - dimension: self.dimension, - max_seq_length: 32_768, - weight_revision: "fixture-rev".into(), - }) - } - } - - #[tokio::test] - async fn trait_is_object_safe() { - // The compile is the test — if the trait is object-unsafe this - // expression fails to 
compile. (Object-safety is what lets the - // surfacing engine hold `Arc` at runtime per - // `embedder-oxidization-Spec.md` §3.) - let _: Arc = Arc::new(FixtureEmbedder { - loaded: true, - dimension: 8, - }); - } - - #[tokio::test] - async fn embed_one_default_wraps_embed() { - let e = FixtureEmbedder { - loaded: true, - dimension: 4, - }; - let v = e.embed_one("hello", "retrieval.query").await.unwrap(); - assert_eq!(v.len(), 4); - assert!(v.iter().all(|&x| x == 0.0)); - } - - /// A backend that returns the wrong number of embeddings from - /// `embed` (server-side bug or protocol-version skew) must - /// produce a clear `InvalidResponse` error through `embed_one`'s - /// default impl, not panic and not silently truncate. - /// - /// Tested with both an empty result (legacy edge case) and a - /// multi-result (new edge case per the strict 1:1 check) — the - /// CR finding called out that `next()`-style "take the first" - /// semantics silently swallow multi-result backend bugs. - struct CountedBackend(Vec>); - #[async_trait] - impl Embedder for CountedBackend { - async fn embed( - &self, - _: &[String], - _: &str, - _: Option, - ) -> Result { - Ok(EmbedResult { - embeddings: self.0.clone(), - model: "counted".into(), - dimension: self.0.first().map_or(0, |v| v.len() as u32), - duration_ms: 0, - }) - } - async fn embed_late_chunked( - &self, - _: &str, - _: &str, - ) -> Result { - unreachable!() - } - async fn info(&self) -> Result { - unreachable!() - } - } - - #[tokio::test] - async fn embed_one_errors_when_backend_returns_no_embeddings() { - let backend = CountedBackend(vec![]); - let err = backend.embed_one("x", "y").await.unwrap_err(); - match err { - EmbeddingError::InvalidResponse(msg) => { - assert!( - msg.contains("expected exactly 1") && msg.contains("got 0"), - "msg = {msg:?}" - ); - } - other => panic!("expected InvalidResponse, got {other:?}"), - } - } - - #[tokio::test] - async fn embed_one_errors_when_backend_returns_multiple_embeddings() { - // The 
strict-length check protects against silent truncation - // when a buggy backend returns N>1 embeddings for a 1-text - // input. Pre-fix the default impl would silently take the - // first; post-fix it errors loudly. - let backend = CountedBackend(vec![vec![1.0_f32; 4], vec![2.0_f32; 4]]); - let err = backend.embed_one("x", "y").await.unwrap_err(); - match err { - EmbeddingError::InvalidResponse(msg) => { - assert!( - msg.contains("expected exactly 1") && msg.contains("got 2"), - "msg = {msg:?}" - ); - } - other => panic!("expected InvalidResponse, got {other:?}"), - } - } - - #[tokio::test] - async fn health_check_default_reflects_model_loaded() { - let loaded = FixtureEmbedder { - loaded: true, - dimension: 4, - }; - assert!(loaded.health_check().await); - - let unloaded = FixtureEmbedder { - loaded: false, - dimension: 4, - }; - assert!(!unloaded.health_check().await); - } -} +//! Deprecated re-export — the `Embedder` trait + types relocated to +//! `weaver-core::embedder` per `docs/specs/weaver-spu-Spec.md` PR-0.5.B. +//! +//! New code should import from `weaver_core::embedder::*` directly. +//! This module remains during the transition window so existing +//! consumers compile unchanged. Removed in PR-0.5.E along with the +//! whole `weaver-embedding` crate (folded into `weaver-spu`). + +// `Embedder` is a trait — Rust has no stable trait-alias mechanism, so +// the only available form is `pub use` with a `#[deprecated]` attribute. +// Caveat: `#[deprecated]` on `pub use` does not reliably emit a warning +// at consumer call sites (see grpc_client.rs's note on this). Consumers +// migrating to `weaver_core::embedder::Embedder` should be informed via +// the spec + commit history rather than relying solely on rustc to +// flag it. Removed in PR-0.5.E. 
+#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::Embedder; this re-export goes away in PR-0.5.E" +)] +pub use weaver_core::embedder::Embedder; + +// Type aliases (not `pub use`) so deprecation warnings fire reliably at +// consumer call sites. Same underlying type — any consumer that has +// already migrated to `weaver_core::embedder::EmbedderInfo` keeps +// working unchanged. Removed in PR-0.5.E. +#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::EmbedderInfo; this alias goes away in PR-0.5.E" +)] +pub type EmbedderInfo = weaver_core::embedder::EmbedderInfo; diff --git a/crates/weaver-embedding/src/grpc_client.rs b/crates/weaver-embedding/src/grpc_client.rs index 0d6163a1..0d2cc722 100644 --- a/crates/weaver-embedding/src/grpc_client.rs +++ b/crates/weaver-embedding/src/grpc_client.rs @@ -4,6 +4,17 @@ //! or TCP endpoint. Wraps the generated tonic client with connection //! management, health checking, and ergonomic Rust types. +// Module-scoped `allow(deprecated)` — this file defines the deprecated +// type aliases (`EmbedResult`, `EmbeddingError`, `LateChunkResult`, +// `LateChunkedResult`, `EmbedderInfo`) per PR-0.5.B and uses them +// throughout its own implementation. The deprecations are for *external* +// consumers (so they migrate to `weaver_core::embedder::*` before +// PR-0.5.E); internal self-uses are benign and would otherwise fail +// `-D warnings`. Narrower than a crate-level allow per the reviewer's +// "narrower scoping" guidance — the allow is bounded to this single +// retiring module. +#![allow(deprecated)] + use std::path::{Path, PathBuf}; use std::time::Duration; @@ -64,26 +75,55 @@ impl Default for EmbeddingClientConfig { } } -/// Error type for embedding client operations. -#[derive(Debug, thiserror::Error)] -pub enum EmbeddingError { - /// gRPC transport/connection error. 
- #[error("connection error: {0}")] - Connection(#[from] tonic::transport::Error), - - /// gRPC status error from the service. - #[error("service error: {0}")] - Status(#[from] tonic::Status), - - /// Invalid response from the service. - #[error("invalid response: {0}")] - InvalidResponse(String), - - /// Embedder reports as unavailable (service up but model not loaded or - /// in a terminal error state). - #[error("embedder not available: {0}")] - NotAvailable(String), -} +// `EmbeddingError`, `EmbedResult`, `LateChunkedResult`, `LateChunkResult`, +// and `EmbedderInfo` relocated to `weaver-core::embedder` per +// `docs/specs/weaver-spu-Spec.md` PR-0.5.B. Backwards-compat shims here +// during the migration window. Removed in PR-0.5.E. +// +// **Why type aliases (not `pub use`)**: `#[deprecated]` on a `pub use` +// re-export does not reliably emit a warning at consumer call sites +// in stable/nightly rustc — the attribute attaches to the re-export +// itself, not to the consumer-side import. Deprecated type aliases +// emit warnings consistently when consumers reference the alias name, +// giving downstream callers a clear migration path before PR-0.5.E +// removal. Identical underlying types, so any consumer that has +// already migrated to `weaver_core::embedder::*` keeps working +// unchanged. +// +// The trait surface in weaver-core abstracts over transport via +// `EmbeddingError::Transport(Box)`. The +// `From` impls live in weaver-core (orphan rule) and let +// `?` propagate raw tonic errors through the call sites below. 
+ +#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::EmbedResult; this alias goes away in PR-0.5.E" +)] +pub type EmbedResult = weaver_core::embedder::EmbedResult; + +#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::EmbedderInfo; this alias goes away in PR-0.5.E" +)] +pub type EmbedderInfo = weaver_core::embedder::EmbedderInfo; + +#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::EmbeddingError; this alias goes away in PR-0.5.E" +)] +pub type EmbeddingError = weaver_core::embedder::EmbeddingError; + +#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::LateChunkResult; this alias goes away in PR-0.5.E" +)] +pub type LateChunkResult = weaver_core::embedder::LateChunkResult; + +#[deprecated( + since = "0.1.0", + note = "moved to weaver_core::embedder::LateChunkedResult; this alias goes away in PR-0.5.E" +)] +pub type LateChunkedResult = weaver_core::embedder::LateChunkedResult; /// Client for the Persephone embedding service. /// @@ -303,26 +343,39 @@ impl EmbeddingClient { /// don't also have to know the tonic error taxonomy. #[instrument(skip(self))] pub async fn ensure_ready(&self) -> Result { + // Transport-class errors (connection failure, gRPC status with + // an "unavailable / deadline / failed-precondition" code) all + // map to the abstract `Transport` variant in + // `weaver-core::embedder`. We downcast to the concrete tonic + // types here to preserve the prior code-class discrimination — + // gRPC `Status` codes UNAVAILABLE / DEADLINE_EXCEEDED / + // FAILED_PRECONDITION are mapped to `NotAvailable`; raw transport + // errors (TCP / Unix socket / HTTP) likewise. Other transport + // errors propagate unchanged. 
let info = match self.info().await { Ok(info) => info, - Err(EmbeddingError::Status(s)) - if matches!( - s.code(), - tonic::Code::Unavailable - | tonic::Code::DeadlineExceeded - | tonic::Code::FailedPrecondition, - ) => - { - return Err(EmbeddingError::NotAvailable(format!( - "embedder Info failed: {} ({})", - s.code(), - s.message() - ))); - } - Err(EmbeddingError::Connection(e)) => { - return Err(EmbeddingError::NotAvailable(format!( - "embedder transport error: {e}" - ))); + Err(EmbeddingError::Transport(e)) => { + if let Some(status) = e.downcast_ref::() { + if matches!( + status.code(), + tonic::Code::Unavailable + | tonic::Code::DeadlineExceeded + | tonic::Code::FailedPrecondition, + ) { + return Err(EmbeddingError::NotAvailable(format!( + "embedder Info failed: {} ({})", + status.code(), + status.message() + ))); + } + return Err(EmbeddingError::Transport(e)); + } + if e.downcast_ref::().is_some() { + return Err(EmbeddingError::NotAvailable(format!( + "embedder transport error: {e}" + ))); + } + return Err(EmbeddingError::Transport(e)); } Err(e) => return Err(e), }; @@ -424,43 +477,10 @@ impl EmbeddingClient { } } -/// Result of an embedding operation. -#[derive(Debug, Clone)] -pub struct EmbedResult { - /// Embedding vectors, one per input text. - pub embeddings: Vec>, - /// Model identifier used. - pub model: String, - /// Embedding dimension. - pub dimension: u32, - /// Wall-clock time in milliseconds. - pub duration_ms: u64, -} - -/// Result of a late-chunked embedding operation — one `LateChunkResult` -/// per chunk produced by the service, in document order. -#[derive(Debug, Clone)] -pub struct LateChunkedResult { - pub chunks: Vec, - pub model: String, - pub dimension: u32, - pub duration_ms: u64, - pub context_window_used: u32, -} - -/// A single late-chunked span: its text, its position in the source -/// document, and its context-aware embedding. 
-#[derive(Debug, Clone)] -pub struct LateChunkResult { - pub text: String, - pub embedding: Vec, - pub start_char: u32, - pub end_char: u32, - pub start_token: u32, - pub end_token: u32, - pub chunk_index: u32, - pub total_chunks: u32, -} +// `EmbedResult`, `LateChunkedResult`, `LateChunkResult` definitions +// relocated to `weaver-core::embedder` per PR-0.5.B; re-exported above +// from this module's existing public surface for backward compat +// during the migration window. /// Summary returned by a `graceful_shutdown` call. #[derive(Debug, Clone)] diff --git a/crates/weaver-embedding/src/lib.rs b/crates/weaver-embedding/src/lib.rs index c1ff5e0b..1127a13c 100644 --- a/crates/weaver-embedding/src/lib.rs +++ b/crates/weaver-embedding/src/lib.rs @@ -68,6 +68,19 @@ pub mod grpc_client; /// backend-agnostic. pub mod embedder; +// `Embedder` and `EmbedderInfo` re-exports — the trait + types relocated +// to `weaver-core::embedder` per `docs/specs/weaver-spu-Spec.md` PR-0.5.B. +// New code should import from `weaver_core::embedder::*`. These re-exports +// keep existing consumers compiling during the migration; removed in +// PR-0.5.E with the rest of this crate. +// +// `#[allow(deprecated)]` is required here because `embedder::Embedder` +// and `embedder::EmbedderInfo` are themselves marked deprecated in the +// `embedder` module — re-exporting them at the crate root would trigger +// `-D warnings` without this. Narrower allow than a crate-level one; +// scoped to just this re-export line per the reviewer's "narrower +// allow" guidance. +#[allow(deprecated)] pub use embedder::{Embedder, EmbedderInfo}; // Native llama.cpp GGUF embedding backend. Behind the `gguf` diff --git a/crates/weaver-embedding/tests/grpc_client.rs b/crates/weaver-embedding/tests/grpc_client.rs index 88fcd9be..969e900d 100644 --- a/crates/weaver-embedding/tests/grpc_client.rs +++ b/crates/weaver-embedding/tests/grpc_client.rs @@ -7,6 +7,15 @@ //! 
backed by a mock `EmbeddingService` and exercise the full client RPC //! surface (Embed, EmbedLateChunked, Info, ensure_ready, GracefulShutdown). +// Module-scoped `allow(deprecated)` — these tests reference the +// `weaver_embedding::grpc_client::{EmbedResult, EmbeddingError, ...}` +// type aliases that were marked `#[deprecated]` in PR-0.5.B (they now +// alias `weaver_core::embedder::*`). The tests stay on the legacy path +// because they're testing the legacy gRPC client; the whole crate +// retires in PR-0.5.E. Scoped to this single test file rather than +// crate-wide. +#![allow(deprecated)] + use std::path::PathBuf; use weaver_embedding::grpc_client::{EmbedResult, EmbeddingClientConfig, EmbeddingEndpoint}; @@ -115,6 +124,14 @@ mod mock_server { /// the server-inconsistency case (e.g. `Some(0)` paired with /// non-empty chunks — the Rust client must reject this). pub dimension_override: Option, + /// When `Some(status)`, `info()` returns the supplied tonic + /// `Status` instead of a successful `InfoResponse`. Used to + /// exercise `EmbeddingClient::ensure_ready`'s downcast path + /// (PR-0.5.B abstracted EmbeddingError over transport via + /// `Box`; ensure_ready downcasts back to specific + /// transport-error types to discriminate UNAVAILABLE-class + /// failures from other status codes). Default `None`. 
+ pub info_error: Option, } #[tonic::async_trait] @@ -201,6 +218,9 @@ mod mock_server { &self, _request: Request, ) -> Result, Status> { + if let Some(err) = self.info_error.clone() { + return Err(err); + } Ok(Response::new(InfoResponse { model_name: "mock".into(), dimension: 4, @@ -309,6 +329,7 @@ async fn test_info_roundtrip() { model_loaded: true, shutdown_count: Arc::new(AtomicU32::new(0)), dimension_override: None, + info_error: None, }) .await; @@ -334,6 +355,7 @@ async fn test_embed_roundtrip() { model_loaded: true, shutdown_count: Arc::new(AtomicU32::new(0)), dimension_override: None, + info_error: None, }) .await; @@ -363,6 +385,7 @@ async fn test_embed_late_chunked_roundtrip() { model_loaded: true, shutdown_count: Arc::new(AtomicU32::new(0)), dimension_override: None, + info_error: None, }) .await; @@ -414,6 +437,7 @@ async fn test_embed_late_chunked_rejects_zero_dimension_with_values() { model_loaded: true, shutdown_count: Arc::new(AtomicU32::new(0)), dimension_override: Some(0), + info_error: None, }) .await; @@ -439,6 +463,7 @@ async fn test_ensure_ready_ok_when_loaded() { model_loaded: true, shutdown_count: Arc::new(AtomicU32::new(0)), dimension_override: None, + info_error: None, }) .await; @@ -462,6 +487,7 @@ async fn test_ensure_ready_rejects_unloaded_model() { model_loaded: false, shutdown_count: Arc::new(AtomicU32::new(0)), dimension_override: None, + info_error: None, }) .await; @@ -489,6 +515,7 @@ async fn test_graceful_shutdown_roundtrip() { model_loaded: true, shutdown_count: counter.clone(), dimension_override: None, + info_error: None, }) .await; @@ -503,3 +530,162 @@ async fn test_graceful_shutdown_roundtrip() { srv.stop().await; } + +// ----------------------------------------------------------------------------- +// `ensure_ready` Transport-downcast regression tests (PR-0.5.B). +// +// PR-0.5.B abstracted `EmbeddingError` over transport via +// `Transport(Box)`. 
`EmbeddingClient::ensure_ready` +// downcasts the boxed error back to concrete tonic types +// (`tonic::Status`, `tonic::transport::Error`) so that +// UNAVAILABLE/DEADLINE_EXCEEDED/FAILED_PRECONDITION codes and raw +// transport failures both surface as `EmbeddingError::NotAvailable` (the +// shape callers pattern-match on). These tests pin that downcast + +// remap behavior so refactors of the abstract `Transport` payload +// can't silently break it. +// ----------------------------------------------------------------------------- + +#[tokio::test] +async fn test_ensure_ready_maps_unavailable_status_to_not_available() { + use mock_server::{MockService, spawn}; + use std::sync::{Arc, atomic::AtomicU32}; + use tonic::Status; + use weaver_embedding::grpc_client::EmbeddingError; + + let srv = spawn(MockService { + model_loaded: true, + shutdown_count: Arc::new(AtomicU32::new(0)), + dimension_override: None, + info_error: Some(Status::unavailable("backend warming up")), + }) + .await; + + let client = connect_tcp_client(srv.tcp_url.clone()).await; + let err = client.ensure_ready().await.unwrap_err(); + match err { + EmbeddingError::NotAvailable(msg) => { + assert!( + msg.contains("Unavailable") || msg.contains("warming up"), + "msg = {msg:?}" + ); + } + other => panic!("expected NotAvailable, got {other:?}"), + } + + srv.stop().await; +} + +#[tokio::test] +async fn test_ensure_ready_maps_deadline_exceeded_status_to_not_available() { + use mock_server::{MockService, spawn}; + use std::sync::{Arc, atomic::AtomicU32}; + use tonic::Status; + use weaver_embedding::grpc_client::EmbeddingError; + + let srv = spawn(MockService { + model_loaded: true, + shutdown_count: Arc::new(AtomicU32::new(0)), + dimension_override: None, + info_error: Some(Status::deadline_exceeded("timed out")), + }) + .await; + + let client = connect_tcp_client(srv.tcp_url.clone()).await; + let err = client.ensure_ready().await.unwrap_err(); + assert!( + matches!(err, EmbeddingError::NotAvailable(_)), + 
"expected NotAvailable for DEADLINE_EXCEEDED, got {err:?}" + ); + + srv.stop().await; +} + +#[tokio::test] +async fn test_ensure_ready_maps_failed_precondition_status_to_not_available() { + use mock_server::{MockService, spawn}; + use std::sync::{Arc, atomic::AtomicU32}; + use tonic::Status; + use weaver_embedding::grpc_client::EmbeddingError; + + let srv = spawn(MockService { + model_loaded: true, + shutdown_count: Arc::new(AtomicU32::new(0)), + dimension_override: None, + info_error: Some(Status::failed_precondition("model not loaded yet")), + }) + .await; + + let client = connect_tcp_client(srv.tcp_url.clone()).await; + let err = client.ensure_ready().await.unwrap_err(); + assert!( + matches!(err, EmbeddingError::NotAvailable(_)), + "expected NotAvailable for FAILED_PRECONDITION, got {err:?}" + ); + + srv.stop().await; +} + +#[tokio::test] +async fn test_ensure_ready_propagates_other_status_codes_as_transport() { + // Negative case: a non-UNAVAILABLE-class status code (e.g. INTERNAL) + // should propagate as `Transport`, not get re-mapped to NotAvailable. + // This pins the discrimination logic — callers shouldn't see + // arbitrary backend errors masked as "service is briefly down." 
+ use mock_server::{MockService, spawn}; + use std::sync::{Arc, atomic::AtomicU32}; + use tonic::Status; + use weaver_embedding::grpc_client::EmbeddingError; + + let srv = spawn(MockService { + model_loaded: true, + shutdown_count: Arc::new(AtomicU32::new(0)), + dimension_override: None, + info_error: Some(Status::internal("oops")), + }) + .await; + + let client = connect_tcp_client(srv.tcp_url.clone()).await; + let err = client.ensure_ready().await.unwrap_err(); + assert!( + matches!(err, EmbeddingError::Transport(_)), + "expected Transport for INTERNAL, got {err:?}" + ); + + srv.stop().await; +} + +#[tokio::test] +async fn test_ensure_ready_maps_dropped_server_to_not_available() { + // Raw `tonic::transport::Error` path: spin up the server, connect, + // stop the server, then call `ensure_ready`. The follow-up RPC + // fails at the transport layer (server gone), exercising the + // downcast-to-`tonic::transport::Error` arm in + // `ensure_ready`'s match. + use mock_server::{MockService, spawn}; + use std::sync::{Arc, atomic::AtomicU32}; + use weaver_embedding::grpc_client::EmbeddingError; + + let srv = spawn(MockService { + model_loaded: true, + shutdown_count: Arc::new(AtomicU32::new(0)), + dimension_override: None, + info_error: None, + }) + .await; + + let tcp_url = srv.tcp_url.clone(); + let client = connect_tcp_client(tcp_url).await; + // First call succeeds — confirms client + server are healthy. + client.info().await.expect("info() succeeds while server is up"); + + // Drop the server, then issue ensure_ready. Whether the next call + // surfaces as a tonic::transport::Error or an UNAVAILABLE Status + // depends on tonic's internal state machine, but BOTH paths must + // map to NotAvailable per ensure_ready's contract. 
+ srv.stop().await; + let err = client.ensure_ready().await.unwrap_err(); + assert!( + matches!(err, EmbeddingError::NotAvailable(_)), + "expected NotAvailable after server drop, got {err:?}" + ); +} diff --git a/crates/weaver-interface/src/main.rs b/crates/weaver-interface/src/main.rs index a3040b8a..a47690bb 100644 --- a/crates/weaver-interface/src/main.rs +++ b/crates/weaver-interface/src/main.rs @@ -454,6 +454,7 @@ async fn run_repl( db_pool: None, codebase_pool: None, memory_bypass: false, + embedder: None, }; let registry = ToolRegistry::with_builtins(&tool_ctx);