diff --git a/Cargo.lock b/Cargo.lock index c6eea2cc..9b9bab8b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1119,12 +1119,6 @@ dependencies = [ "glob", ] -[[package]] -name = "fixedbitset" -version = "0.5.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" - [[package]] name = "flate2" version = "1.1.9" @@ -2465,12 +2459,6 @@ dependencies = [ "syn", ] -[[package]] -name = "multimap" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" - [[package]] name = "nalgebra" version = "0.33.3" @@ -2746,16 +2734,6 @@ version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" -[[package]] -name = "petgraph" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" -dependencies = [ - "fixedbitset", - "indexmap", -] - [[package]] name = "phf" version = "0.11.3" @@ -2882,49 +2860,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive", -] - -[[package]] -name = "prost-build" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" -dependencies = [ - "heck", - "itertools 0.11.0", - "log", - "multimap", - "once_cell", - "petgraph", - "prettyplease", - "prost", - "prost-types", - "regex", - "syn", - "tempfile", -] - -[[package]] -name = "prost-derive" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" -dependencies = [ - "anyhow", - "itertools 0.11.0", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "prost-types" -version = "0.13.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c2c1bf36ddb1a1c396b3601a3cec27c2462e45f07c386894ec3ccf5332bd16" -dependencies = [ - "prost", ] [[package]] @@ -4180,20 +4115,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tonic-build" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847" -dependencies = [ - "prettyplease", - "proc-macro2", - "prost-build", - "prost-types", - "quote", - "syn", -] - [[package]] name = "tower" version = "0.5.3" @@ -4814,8 +4735,6 @@ dependencies = [ "hyperlocal", "moka", "proc-macro2", - "prost", - "prost-types", "rand 0.9.2", "regex", "rustpython-parser", @@ -4829,8 +4748,6 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tonic", - "tonic-build", "tower", "tracing", "urlencoding", @@ -4915,12 +4832,9 @@ dependencies = [ "encoding_rs", "half", "hf-hub", - "hyper-util", "llama-cpp-2", "memmap2", "parking_lot", - "prost", - "prost-types", "reqwest", "rustc-hash 1.1.0", "safetensors 0.7.0", @@ -4932,11 +4846,7 @@ dependencies = [ "tiktoken-rs", "tokenizers 0.21.4", "tokio", - "tokio-stream", "toml", - "tonic", - "tonic-build", - "tower", "tracing", "tracing-subscriber", "weaver-core", diff --git a/crates/weaver-core/src/embedder.rs b/crates/weaver-core/src/embedder.rs index a7a768ef..24bb669e 100644 --- 
a/crates/weaver-core/src/embedder.rs +++ b/crates/weaver-core/src/embedder.rs @@ -129,25 +129,6 @@ pub enum EmbeddingError { NotAvailable(String), } -// gRPC-transport `From` impls. These live here (not in -// `weaver_spu::encoder::grpc_client_legacy`) because Rust's orphan rule requires foreign-on- -// foreign impls to be in the trait's crate. The cost is one tonic dep -// in `weaver-core`; the alternative (explicit `.map_err` at every `?` -// site in the gRPC client) is a larger maintenance burden than the -// dep. Removed in PR-3.A along with the gRPC client itself, when the -// last consumer (`weaver_spu::encoder::grpc_client_legacy`) goes away. -impl From<tonic::transport::Error> for EmbeddingError { - fn from(e: tonic::transport::Error) -> Self { - EmbeddingError::Transport(Box::new(e)) - } -} - -impl From<tonic::Status> for EmbeddingError { - fn from(e: tonic::Status) -> Self { - EmbeddingError::Transport(Box::new(e)) - } -} - /// Backend-agnostic embedding service. /// /// Implementations: diff --git a/crates/weaver-database/Cargo.toml b/crates/weaver-database/Cargo.toml index 25cc19eb..d676bdb8 100644 --- a/crates/weaver-database/Cargo.toml +++ b/crates/weaver-database/Cargo.toml @@ -29,11 +29,6 @@ tracing.workspace = true anyhow.workspace = true thiserror.workspace = true -# gRPC / Protobuf -tonic.workspace = true -prost.workspace = true -prost-types.workspace = true - # Encoding base64.workspace = true urlencoding.workspace = true @@ -56,9 +51,6 @@ syn.workspace = true sha2.workspace = true proc-macro2.workspace = true -[build-dependencies] -tonic-build.workspace = true - [dev-dependencies] tempfile.workspace = true tokio-stream.workspace = true diff --git a/crates/weaver-database/build.rs b/crates/weaver-database/build.rs deleted file mode 100644 index 20566c6f..00000000 --- a/crates/weaver-database/build.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::path::PathBuf; - -fn main() -> Result<(), Box<dyn std::error::Error>> { - let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); - let proto_root = PathBuf::from(manifest_dir) - .join("../../proto") - .canonicalize() - .expect("proto/ directory not found — expected at workspace root"); - - let protos: Vec<PathBuf> = [ - "persephone/common/common.proto", - "persephone/embedding/embedding.proto", - "persephone/extraction/extraction.proto", - "persephone/training/training.proto", - ] - .iter() - .map(|p| proto_root.join(p)) - .collect(); - - tonic_build::configure() - .build_server(true) - .build_client(true) - .compile_protos(&protos, &[&proto_root])?; - - // Re-run if any proto file changes - println!("cargo:rerun-if-changed={}", proto_root.display()); - - Ok(()) -} diff --git a/crates/weaver-database/src/lib.rs b/crates/weaver-database/src/lib.rs index 65debd5d..c89f6030 100644 --- a/crates/weaver-database/src/lib.rs +++ b/crates/weaver-database/src/lib.rs @@ -4,15 +4,5 @@ pub mod code; pub mod config; pub mod db; pub mod graph; -// `persephone` (the embedding gRPC client) moved to -// `weaver-embedding::grpc_client` per -// `embedder-oxidization-Spec.md` §5.1 (issue #166 / sprint -// Block A.5). External consumers import from -// `weaver_spu::encoder::grpc_client_legacy` and `weaver_spu::proto::embedding`. -// `weaver-database`'s `proto` module continues to generate -// embedding-service types for backward compat with internal -// proto-validation tests; a future cleanup PR removes the -// duplication once all consumers migrate.
-pub mod proto; pub use config::HadesConfig; diff --git a/crates/weaver-database/src/proto.rs b/crates/weaver-database/src/proto.rs deleted file mode 100644 index 4046f6af..00000000 --- a/crates/weaver-database/src/proto.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Persephone gRPC/protobuf definitions. -//! -//! Generated from `.proto` files in the `proto/persephone/` directory. -//! Provides both client stubs and server traits for the embedding, -//! extraction, and training services. - -/// Common types shared across Persephone services. -pub mod common { - tonic::include_proto!("persephone.common"); -} - -/// Embedding service — vector embedding generation. -pub mod embedding { - tonic::include_proto!("persephone.embedding"); -} - -/// Extraction service — document content extraction. -pub mod extraction { - tonic::include_proto!("persephone.extraction"); -} - -/// Training service — RGCN link prediction on knowledge graphs. -pub mod training { - tonic::include_proto!("persephone.training"); -} diff --git a/crates/weaver-database/tests/proto_types.rs b/crates/weaver-database/tests/proto_types.rs deleted file mode 100644 index a7ab0228..00000000 --- a/crates/weaver-database/tests/proto_types.rs +++ /dev/null @@ -1,171 +0,0 @@ -//! Verify generated protobuf types compile and are accessible. - -use weaver_database::proto::common::{ChunkMetadata, ChunkingStrategy}; -use weaver_database::proto::embedding::{EmbedRequest, EmbedResponse, Embedding, InfoResponse}; -use weaver_database::proto::extraction::{ - ExtractRequest, ExtractResponse, ExtractorInfo, SourceType, Table, -}; -use weaver_database::proto::training::{ - CheckpointRequest, EvaluateRequest, GetEmbeddingsRequest, InitModelRequest, - LoadCheckpointRequest, LoadGraphRequest, ModelConfig, OptimizerConfig, TrainStepRequest, -}; - -#[test] -fn test_common_types() { - let meta = ChunkMetadata { - chunk_index: 0, - total_chunks: 5, - start_char: 0, - end_char: 500, - }; - assert_eq!(meta.chunk_index, 0); - assert_eq!(meta.total_chunks, 5); - - // Enum values - assert_eq!(ChunkingStrategy::Late as i32, 0); - assert_eq!(ChunkingStrategy::Semantic as i32, 1); - assert_eq!(ChunkingStrategy::Sliding as i32, 2); - assert_eq!(ChunkingStrategy::Token as i32, 3); -} - -#[test] -fn test_embedding_types() { - let req = EmbedRequest { - texts: vec!["hello world".to_string()], - task: "retrieval.passage".to_string(), - batch_size: 32, - }; - assert_eq!(req.texts.len(), 1); - assert_eq!(req.task, "retrieval.passage"); - - let embedding = Embedding { - values: vec![0.1, 0.2, 0.3], - }; - assert_eq!(embedding.values.len(), 3); - - let resp = EmbedResponse { - embeddings: vec![embedding], - model: "jinaai/jina-embeddings-v4".to_string(), - dimension: 2048, - duration_ms: 42, - }; - assert_eq!(resp.embeddings.len(), 1); - assert_eq!(resp.dimension, 2048); - - let info = InfoResponse { - model_name: "test".to_string(), - dimension: 2048, - max_seq_length: 32768, - supported_tasks: vec!["retrieval.passage".to_string()], - device: "cuda:0".to_string(), - model_loaded: true, - uptime_seconds: 123.4, - weight_revision: String::new(), - }; - assert!(info.model_loaded); - assert_eq!(info.model_name, "test"); -} - -#[test] -fn test_extraction_types() { - let req = ExtractRequest { - file_path: "/tmp/paper.pdf".to_string(), - content: vec![], - source_type: SourceType::Pdf.into(), - extract_tables: true, - extract_equations: true, - extract_images: false, - use_ocr: false, - }; - assert_eq!(req.source_type, SourceType::Pdf as i32); - assert!(req.extract_tables); - - let 
table = Table { - content: "col1 | col2".to_string(), - caption: "Table 1".to_string(), - index: 0, - }; - - let resp = ExtractResponse { - full_text: "Some extracted text".to_string(), - tables: vec![table], - equations: vec![], - images: vec![], - metadata: [("pages".to_string(), "10".to_string())].into(), - source_type: SourceType::Pdf.into(), - }; - assert_eq!(resp.tables.len(), 1); - assert_eq!(resp.metadata.get("pages").unwrap(), "10"); - - let info = ExtractorInfo { - supported_extensions: vec![".pdf".to_string(), ".tex".to_string()], - supported_types: vec![SourceType::Pdf.into(), SourceType::Latex.into()], - features: vec!["ocr".to_string(), "tables".to_string()], - gpu_available: true, - }; - assert_eq!(info.supported_extensions.len(), 2); -} - -#[test] -fn test_training_types() { - let model_config = ModelConfig { - num_relations: 22, - num_collection_types: 62, - hidden_dim: 256, - embed_dim: 128, - num_bases: 21, - dropout: 0.2, - }; - assert_eq!(model_config.num_relations, 22); - assert_eq!(model_config.embed_dim, 128); - - let optimizer = OptimizerConfig { - learning_rate: 0.01, - weight_decay: 5e-4, - }; - assert!(optimizer.learning_rate > 0.0); - - let init_req = InitModelRequest { - model: Some(model_config), - optimizer: Some(optimizer), - device: "cuda:2".to_string(), - }; - assert!(init_req.model.is_some()); - assert_eq!(init_req.device, "cuda:2"); - - let load_req = LoadGraphRequest { - safetensors_path: "/tmp/graph.safetensors".to_string(), - }; - assert!(!load_req.safetensors_path.is_empty()); - - let step_req = TrainStepRequest { - train_edge_indices: vec![0, 1, 2, 3], - neg_src: vec![5, 6], - neg_dst: vec![7, 8], - }; - assert_eq!(step_req.train_edge_indices.len(), 4); - assert_eq!(step_req.neg_src.len(), step_req.neg_dst.len()); - - let eval_req = EvaluateRequest { - edge_indices: vec![10, 11], - neg_src: vec![0], - neg_dst: vec![1], - }; - assert_eq!(eval_req.edge_indices.len(), 2); - - let emb_req = GetEmbeddingsRequest { - output_path: "/tmp/embeddings.safetensors".to_string(), - }; - assert!(!emb_req.output_path.is_empty()); - - let ckpt_req = CheckpointRequest { - path: "/tmp/model.pt".to_string(), - }; - assert!(!ckpt_req.path.is_empty()); - - let load_ckpt = LoadCheckpointRequest { - path: "/tmp/model.pt".to_string(), - device: "cuda:0".to_string(), - }; - assert_eq!(load_ckpt.device, "cuda:0"); -} diff --git a/crates/weaver-demo/src/bin/embedder_latency.rs b/crates/weaver-demo/src/bin/embedder_latency.rs deleted file mode 100644 index e1852f57..00000000 --- a/crates/weaver-demo/src/bin/embedder_latency.rs +++ /dev/null @@ -1,381 +0,0 @@ -// Pre-existing hidden-lifetime pattern at line 345 — not introduced -// by PR-0.5.E (the binary uses gRPC types whose generic lifetimes -// rustc began flagging more strictly). Out-of-scope cleanup for the -// no-op consolidation; allowed at file scope. -#![allow(elided_lifetimes_in_paths)] - -//! Embedder latency characterization for Option B (SuperEgo / ambient -//! surfacing) feasibility. -//! -//! The "latency is the enemy of agency" thesis says the harness's -//! internal-channel speed must beat the agent's external environment -//! by ≥1 OOM. The user's calibration: external = network = 1-10 ms; -//! the harness's transport floor target is the 100 ns range. We are -//! 4-5 OOM faster on transport-only IPC; what we measure here is -//! whether the embedder, which is part of the inner loop for -//! ambient surfacing, blows that budget. -//! -//! Two axes that matter: -//! -//! 
- **Input size sweep** — real `output.value` payloads from -//! benchmark traces are small (median 2 KB, max ~5 KB). We sweep -//! ~50 B → 32 KB to capture the realistic range plus the upper -//! tail. -//! - **Concurrency sweep** — Option B fires the embedder on every -//! tool output, and a twin-GPU bench has both agents firing into -//! the same embedder. Measuring 1 vs 2 vs 4 simultaneous requests -//! shows whether the current GPU2 (RTX 2000 75W) placement -//! collapses under contention. -//! -//! Output: a markdown table per (size × concurrency) cell with -//! P50 / P95 / P99 / max round-trip in nanoseconds, plus the -//! transport delta (round-trip minus server-reported compute) so -//! the IPC overhead is separable from the GPU cost. -//! -//! Per-call data is captured in Unix nanoseconds from -//! `SystemTime::UNIX_EPOCH` at request emit and response receive, -//! so future runs can be cross-correlated with packet captures or -//! CUDA-side instrumentation if/when we add per-stage profiling on -//! the Python embedder. - -use std::env; -use std::path::PathBuf; -use std::sync::Arc; -use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; - -use tokio::task::JoinSet; - -use weaver_spu::encoder::grpc_client_legacy::{EmbeddingClient, EmbeddingClientConfig, EmbeddingEndpoint}; - -/// Single measured call. -struct Sample { - /// Wall-clock at send (Unix nanoseconds). Captured for future - /// cross-process correlation; not consumed by the aggregate - /// stats emitted in this run. - #[allow(dead_code)] - t_send_unix_ns: u128, - /// Wall-clock at receive (Unix nanoseconds). Same rationale as - /// `t_send_unix_ns`. - #[allow(dead_code)] - t_recv_unix_ns: u128, - /// Round-trip from `Instant::now()` deltas — monotonic, immune to - /// wall-clock jumps. This is the load-bearing measurement. - roundtrip_ns: u128, - /// Server-reported compute time. Service still uses ms; converted - /// to ns at the boundary so all subsequent math is in one unit. - server_compute_ns: u128, - /// Round-trip minus server compute = IPC + serialization + - /// scheduler overhead. Negative values are clamped to 0 (the - /// server's `duration_ms` is integer-rounded so 0 ms server with - /// any positive client roundtrip still yields a sensible delta). - transport_overhead_ns: u128, -} - -/// Aggregated stats for a (size × concurrency) cell. -struct CellStats { - n: usize, - rt_p50_ns: u128, - rt_p95_ns: u128, - rt_p99_ns: u128, - rt_max_ns: u128, - rt_mean_ns: u128, - server_p50_ns: u128, - server_p95_ns: u128, - transport_p50_ns: u128, - transport_p95_ns: u128, -} - -impl CellStats { - fn from_samples(samples: &[Sample]) -> Self { - let mut rt: Vec = samples.iter().map(|s| s.roundtrip_ns).collect(); - let mut server: Vec = samples.iter().map(|s| s.server_compute_ns).collect(); - let mut transport: Vec = samples.iter().map(|s| s.transport_overhead_ns).collect(); - rt.sort_unstable(); - server.sort_unstable(); - transport.sort_unstable(); - - let pct = |sorted: &[u128], p: f64| -> u128 { - if sorted.is_empty() { - return 0; - } - // Nearest-rank percentile — adequate at n=100. We avoid - // interpolation so the reported number is an actual - // observed value, not a synthetic midpoint. 
- let idx = ((p * (sorted.len() as f64)).ceil() as usize).saturating_sub(1); - sorted[idx.min(sorted.len() - 1)] - }; - - let mean = if rt.is_empty() { - 0 - } else { - rt.iter().sum::() / (rt.len() as u128) - }; - - Self { - n: samples.len(), - rt_p50_ns: pct(&rt, 0.50), - rt_p95_ns: pct(&rt, 0.95), - rt_p99_ns: pct(&rt, 0.99), - rt_max_ns: *rt.last().unwrap_or(&0), - rt_mean_ns: mean, - server_p50_ns: pct(&server, 0.50), - server_p95_ns: pct(&server, 0.95), - transport_p50_ns: pct(&transport, 0.50), - transport_p95_ns: pct(&transport, 0.95), - } - } -} - -/// Format a nanosecond count with the most readable unit. -fn fmt_ns(ns: u128) -> String { - if ns < 1_000 { - format!("{ns} ns") - } else if ns < 1_000_000 { - format!("{:.2} µs", (ns as f64) / 1_000.0) - } else if ns < 1_000_000_000 { - format!("{:.2} ms", (ns as f64) / 1_000_000.0) - } else { - format!("{:.3} s", (ns as f64) / 1_000_000_000.0) - } -} - -async fn one_call(client: &EmbeddingClient, text: &str, task: &str) -> anyhow::Result { - let t_send_unix_ns = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_nanos(); - let t0 = Instant::now(); - - // `embed_one` calls `embed` with a single-element batch; the - // result carries the server-reported `duration_ms`. - let result = client.embed(&[text.to_string()], task, None).await?; - - let elapsed = t0.elapsed(); - let t_recv_unix_ns = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap_or_default() - .as_nanos(); - - let roundtrip_ns = elapsed.as_nanos(); - let server_compute_ns = (result.duration_ms as u128) * 1_000_000; - let transport_overhead_ns = roundtrip_ns.saturating_sub(server_compute_ns); - - Ok(Sample { - t_send_unix_ns, - t_recv_unix_ns, - roundtrip_ns, - server_compute_ns, - transport_overhead_ns, - }) -} - -/// Run a (size × concurrency) cell. -async fn run_cell( - client: Arc, - text: Arc, - task: &str, - iters: usize, - concurrency: usize, -) -> anyhow::Result> { - let mut all_samples: Vec = Vec::with_capacity(iters); - let task_str = task.to_string(); - - if concurrency == 1 { - for _ in 0..iters { - let s = one_call(&client, &text, &task_str).await?; - all_samples.push(s); - } - } else { - // Split iters across `concurrency` parallel tasks. Each task - // does its own sequential calls — the `concurrency` parameter - // models "how many simultaneous in-flight requests is the - // embedder receiving," which is what GPU2 contention sees - // when both agents fire ambient surfacing at once. - let per_worker = iters.div_ceil(concurrency); - let mut set = JoinSet::new(); - for _ in 0..concurrency { - let c = client.clone(); - let t = text.clone(); - let task_owned = task_str.clone(); - set.spawn(async move { - let mut local: Vec = Vec::with_capacity(per_worker); - for _ in 0..per_worker { - let s = one_call(&c, &t, &task_owned).await?; - local.push(s); - } - Ok::, anyhow::Error>(local) - }); - } - while let Some(res) = set.join_next().await { - let local = res??; - all_samples.extend(local); - } - // Trim any overshoot from the rounding in `per_worker`. - all_samples.truncate(iters); - } - - Ok(all_samples) -} - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - // Allow override of the default socket via env, useful for - // pointing at a moved embedder later (e.g. GPU0 co-resident - // experiment) without recompiling. 
- let socket = env::var("WEAVER_EMBEDDER_SOCKET") - .map(PathBuf::from) - .unwrap_or_else(|_| PathBuf::from("/run/weaver/embedder.sock")); - - println!("# Embedder latency characterization"); - println!(); - println!("**Date:** {}", chrono::Utc::now().to_rfc3339()); - println!("**Socket:** `{}`", socket.display()); - println!("**Task:** `retrieval.query`"); - println!(); - println!("Per-call timing uses `Instant::now()` for monotonic round-trip and"); - println!("`SystemTime` (Unix nanoseconds) for cross-process correlation."); - println!("Server-reported compute time is in milliseconds; transport overhead"); - println!("= round-trip − server compute."); - println!(); - - let cfg = EmbeddingClientConfig { - endpoint: EmbeddingEndpoint::Unix(socket), - timeout: Duration::from_secs(30), - connect_timeout: Duration::from_secs(5), - }; - - println!("## Connecting…"); - let client = EmbeddingClient::connect(cfg).await?; - let client = Arc::new(client); - println!("Connected."); - println!(); - - // Input size sweep. Sizes chosen from real trace data - // (median 2 KB, max ~5 KB) plus an upper-tail synthetic. - let sizes: &[(usize, &str)] = &[ - (50, "small_50B"), - (500, "small_500B"), - (2_000, "median_2KB"), - (5_000, "max_observed_5KB"), - (16_000, "upper_tail_16KB"), - ]; - - let task = "retrieval.query"; - let warmup_iters = 5; - let measured_iters = 100; - let concurrencies = [1usize, 2, 4]; - - println!("## Warmup"); - println!(); - println!("{warmup_iters} calls at the median size to warm the GPU and bring weights into HBM."); - println!(); - let warm_text = Arc::new("a".repeat(2_000)); - for i in 0..warmup_iters { - let s = one_call(&client, &warm_text, task).await?; - println!( - "warmup {} → roundtrip {}, server {} ms", - i + 1, - fmt_ns(s.roundtrip_ns), - s.server_compute_ns / 1_000_000, - ); - } - println!(); - - // Run each (size × concurrency) cell. - println!("## Results"); - println!(); - println!( - "| Size | Concurrency | n | RT P50 | RT P95 | RT P99 | RT max | RT mean | Server P50 | Server P95 | Transport P50 | Transport P95 |" - ); - println!("|---|---|---|---|---|---|---|---|---|---|---|---|"); - - let mut all_cells: Vec<(String, usize, CellStats)> = Vec::new(); - - for (size_bytes, size_label) in sizes { - // Synthetic payload of the target size. We don't care about - // semantics — Jina V4's compute cost depends on tokenized - // length, and 'a' chars tokenize predictably without - // exercising any vocab edge cases. 
- let text = Arc::new("a".repeat(*size_bytes)); - - for concurrency in concurrencies { - print!( - "{} (B={}) × c={} → running…", - size_label, size_bytes, concurrency - ); - let samples = run_cell( - client.clone(), - text.clone(), - task, - measured_iters, - concurrency, - ) - .await?; - let stats = CellStats::from_samples(&samples); - println!(" done."); - - println!( - "| {} (`{} B`) | {} | {} | {} | {} | {} | {} | {} | {} | {} | {} | {} |", - size_label, - size_bytes, - concurrency, - stats.n, - fmt_ns(stats.rt_p50_ns), - fmt_ns(stats.rt_p95_ns), - fmt_ns(stats.rt_p99_ns), - fmt_ns(stats.rt_max_ns), - fmt_ns(stats.rt_mean_ns), - fmt_ns(stats.server_p50_ns), - fmt_ns(stats.server_p95_ns), - fmt_ns(stats.transport_p50_ns), - fmt_ns(stats.transport_p95_ns), - ); - - all_cells.push((format!("{size_label} ({size_bytes} B)"), concurrency, stats)); - } - } - - println!(); - println!("## OOM-test reading"); - println!(); - println!("**Reference points** (per `latency is the enemy of agency`):"); - println!("- 100 ns target = harness IPC floor (decisively faster than network)."); - println!("- 1-10 ms = industry-standard \"good\" network latency (the bar to beat by ≥1 OOM)."); - println!("- ≥1 ms = same-OOM-as-network territory; harness loses its claimed advantage."); - println!(); - let single_cells: Vec<&CellStats> = all_cells - .iter() - .filter(|(_, c, _)| *c == 1) - .map(|(_, _, s)| s) - .collect(); - if let Some(median_cell) = single_cells.get(2) { - let rt_p95 = median_cell.rt_p95_ns; - let in_ms = (rt_p95 as f64) / 1_000_000.0; - let oom_vs_network = if rt_p95 > 0 { - (10_000_000_f64 / rt_p95 as f64).log10() - } else { - f64::INFINITY - }; - println!( - "Median-size (~2 KB), c=1, P95 round-trip: **{}** ({:.2} ms).", - fmt_ns(rt_p95), - in_ms - ); - println!("vs 10 ms network: {:.2} OOM faster.", oom_vs_network); - println!( - "vs 100 ns transport floor: {:.2} OOM slower.", - ((rt_p95 as f64) / 100.0).log10() - ); - println!(); - println!( - "**Reading:** the embedder is in the millisecond range, comparable to network latency. \ - It is part of the harness but its compute cost puts it in the same OOM as the external \ - environment it is supposed to beat. Ambient surfacing in Option B would have to either \ - (a) fire rarely enough that millisecond cost amortizes, or (b) move the embedder into a \ - topology where the compute cost is hidden behind cheaper transport (e.g. co-resident on \ - the agent's GPU with intra-VRAM transfer)." - ); - } - - Ok(()) -} diff --git a/crates/weaver-demo/src/herobench/benchmark.rs b/crates/weaver-demo/src/herobench/benchmark.rs index 23fc0c42..5a7aef1a 100644 --- a/crates/weaver-demo/src/herobench/benchmark.rs +++ b/crates/weaver-demo/src/herobench/benchmark.rs @@ -850,150 +850,106 @@ fn generate_task_ids(difficulty: u32, count: usize) -> Vec<(String, String)> { .collect() } -/// Construct an embedder for the PR-10 dedup gate. Tries in-process -/// (`EmbedderClient::from_snapshot`) first when the `embedder-rust` -/// feature is enabled and `WEAVER_SPU_JINA_V4_SNAPSHOT` is set; falls -/// back to the legacy gRPC `EmbeddingClient` against the Python -/// `weaver-embedder.service`. Returns `None` if both paths fail — -/// dedup degrades to unconditional writes in that case. +/// Construct an embedder for the PR-10 dedup gate. Loads the +/// in-process candle-backed `EmbedderClient` when the +/// `embedder-rust` feature is enabled and `WEAVER_SPU_JINA_V4_SNAPSHOT` +/// is set; otherwise returns `None` so dedup degrades to unconditional +/// writes. 
+/// +/// The legacy gRPC fallback against the Python +/// `weaver-embedder.service` retired in PR-1.J alongside the service +/// itself. async fn try_construct_embedder() -> Option> { #[cfg(feature = "embedder-rust")] { - if let Ok(snapshot) = std::env::var("WEAVER_SPU_JINA_V4_SNAPSHOT") - && !snapshot.trim().is_empty() - { - // GPU ordinal default: 0. Operators wanting a different - // card override via `WEAVER_SPU_CUDA_DEVICE` (matches the - // `jina_embed` smoke binary's env contract). - // - // Strict parse: a malformed value (typo) is an error, not a - // silent fall-through to 0 — otherwise a typo would publish - // dedup work on the wrong GPU without anyone noticing. Only - // *absent* env var defaults to 0. - let gpu_ordinal: usize = match std::env::var("WEAVER_SPU_CUDA_DEVICE") { - Ok(s) => match s.parse() { - Ok(n) => n, - Err(e) => { - tracing::warn!( - "WEAVER_SPU_CUDA_DEVICE = {s:?} is not a valid usize ({e}); \ - falling back to legacy gRPC embedder. \ - Set the env var to a CUDA ordinal (e.g., 0 or 1) or unset it." - ); - return try_construct_embedder_grpc().await; - } - }, - Err(std::env::VarError::NotPresent) => 0, - Err(e) => { - tracing::warn!( - "WEAVER_SPU_CUDA_DEVICE not valid Unicode ({e}); \ - falling back to legacy gRPC embedder." - ); - return try_construct_embedder_grpc().await; - } - }; - tracing::info!( - snapshot = %snapshot, - gpu = gpu_ordinal, - "constructing in-process EmbedderClient for dedup gate" - ); - // Heavy load is sync + GPU-bound; spawn_blocking keeps the - // tokio runtime healthy across the seconds-long mmap + - // weights faulting in. - let snapshot_path = std::path::PathBuf::from(&snapshot); - let result = tokio::task::spawn_blocking(move || { - let device = candle_core::Device::new_cuda(gpu_ordinal) - .map_err(|e| anyhow::anyhow!("CUDA cuda:{gpu_ordinal} unavailable: {e}"))?; - weaver_spu::encoder::client::EmbedderClient::from_snapshot( - &snapshot_path, - candle_core::DType::BF16, - &device, - ) - .map_err(|e| anyhow::anyhow!("EmbedderClient::from_snapshot: {e}")) - }) - .await; - match result { - Ok(Ok(client)) => { - tracing::info!( - "in-process EmbedderClient ready; dedup gate live" - ); - let arc: Arc = Arc::new(client); - return Some(arc); - } - Ok(Err(e)) => { - tracing::warn!( - "in-process EmbedderClient construction failed ({e}); \ - falling back to legacy gRPC embedder" - ); - } + let snapshot = match std::env::var("WEAVER_SPU_JINA_V4_SNAPSHOT") { + Ok(s) if !s.trim().is_empty() => s, + _ => { + tracing::warn!( + "WEAVER_SPU_JINA_V4_SNAPSHOT unset; \ + hypothesis dedup degraded to unconditional writes" + ); + return None; + } + }; + // GPU ordinal default: 0. Operators wanting a different + // card override via `WEAVER_SPU_CUDA_DEVICE` (matches the + // `jina_embed` smoke binary's env contract). + // + // Strict parse: a malformed value (typo) is an error, not a + // silent fall-through to 0 — otherwise a typo would publish + // dedup work on the wrong GPU without anyone noticing. Only + // *absent* env var defaults to 0. + let gpu_ordinal: usize = match std::env::var("WEAVER_SPU_CUDA_DEVICE") { + Ok(s) => match s.parse() { + Ok(n) => n, Err(e) => { tracing::warn!( - "in-process EmbedderClient task panicked ({e}); \ - falling back to legacy gRPC embedder" + "WEAVER_SPU_CUDA_DEVICE = {s:?} is not a valid usize ({e}); \ + hypothesis dedup degraded to unconditional writes. \ + Set the env var to a CUDA ordinal (e.g., 0 or 1) or unset it." 
); + return None; } + }, + Err(std::env::VarError::NotPresent) => 0, + Err(e) => { + tracing::warn!( + "WEAVER_SPU_CUDA_DEVICE not valid Unicode ({e}); \ + hypothesis dedup degraded to unconditional writes." + ); + return None; + } + }; + tracing::info!( + snapshot = %snapshot, + gpu = gpu_ordinal, + "constructing in-process EmbedderClient for dedup gate" + ); + // Heavy load is sync + GPU-bound; spawn_blocking keeps the + // tokio runtime healthy across the seconds-long mmap + + // weights faulting in. + let snapshot_path = std::path::PathBuf::from(&snapshot); + let result = tokio::task::spawn_blocking(move || { + let device = candle_core::Device::new_cuda(gpu_ordinal) + .map_err(|e| anyhow::anyhow!("CUDA cuda:{gpu_ordinal} unavailable: {e}"))?; + weaver_spu::encoder::client::EmbedderClient::from_snapshot( + &snapshot_path, + candle_core::DType::BF16, + &device, + ) + .map_err(|e| anyhow::anyhow!("EmbedderClient::from_snapshot: {e}")) + }) + .await; + match result { + Ok(Ok(client)) => { + tracing::info!("in-process EmbedderClient ready; dedup gate live"); + let arc: Arc = Arc::new(client); + return Some(arc); + } + Ok(Err(e)) => { + tracing::warn!( + "in-process EmbedderClient construction failed ({e}); \ + hypothesis dedup degraded to unconditional writes" + ); + return None; + } + Err(e) => { + tracing::warn!( + "in-process EmbedderClient task panicked ({e}); \ + hypothesis dedup degraded to unconditional writes" + ); + return None; } } } - - try_construct_embedder_grpc().await -} - -/// Bounded-timeout legacy gRPC fallback. Both the connect handshake -/// and the readiness probe are wrapped in -/// [`tokio::time::timeout`] so a wedged Python service or a stalled -/// network can't hang benchmark startup indefinitely. On timeout or -/// error → log warn → return `None` (dedup degrades to unconditional -/// writes). -async fn try_construct_embedder_grpc() -> Option> { - use std::time::Duration; - /// Bounded enough to absorb a cold-start Python embedder's socket - /// bind + first-RPC dispatch (~few seconds in practice), short - /// enough that a truly wedged service doesn't stall the whole - /// benchmark start. 
- const PROBE_TIMEOUT: Duration = Duration::from_secs(10); - - let client = match tokio::time::timeout( - PROBE_TIMEOUT, - weaver_spu::encoder::grpc_client_legacy::EmbeddingClient::connect_default(), - ) - .await + #[cfg(not(feature = "embedder-rust"))] { - Ok(Ok(c)) => c, - Ok(Err(e)) => { - tracing::warn!( - "embedder not reachable ({e}); \ - hypothesis dedup degraded to unconditional writes" - ); - return None; - } - Err(_) => { - tracing::warn!( - "embedder connect timed out after {PROBE_TIMEOUT:?}; \ - hypothesis dedup degraded to unconditional writes" - ); - return None; - } - }; - - match tokio::time::timeout(PROBE_TIMEOUT, client.ensure_ready()).await { - Ok(Ok(_)) => { - let arc: Arc = Arc::new(client); - Some(arc) - } - Ok(Err(e)) => { - tracing::warn!( - "embedder reached but not ready ({e}); \ - hypothesis dedup degraded to unconditional writes" - ); - None - } - Err(_) => { - tracing::warn!( - "embedder ensure_ready timed out after {PROBE_TIMEOUT:?}; \ - hypothesis dedup degraded to unconditional writes" - ); - None - } + tracing::warn!( + "weaver-demo built without `embedder-rust` feature; \ + hypothesis dedup degraded to unconditional writes" + ); + None } } diff --git a/crates/weaver-demo/tests/herobench_integration.rs b/crates/weaver-demo/tests/herobench_integration.rs index bb929bd9..6ed804a9 100644 --- a/crates/weaver-demo/tests/herobench_integration.rs +++ b/crates/weaver-demo/tests/herobench_integration.rs @@ -2294,24 +2294,59 @@ async fn integration_graph_stats() { // --ignored integration_dedup // --------------------------------------------------------------------------- -async fn try_connect_embedder() -> Option { - let client = match weaver_spu::encoder::grpc_client_legacy::EmbeddingClient::connect_default().await { - Ok(c) => c, +/// Construct an in-process candle-backed `EmbedderClient` for the +/// dedup integration tests. Requires `WEAVER_SPU_JINA_V4_SNAPSHOT` +/// to be set and the `embedder-rust` feature enabled. The legacy +/// gRPC connector retired in PR-1.J alongside the Python service. +/// +/// Returns: +/// - `Ok(None)` when `WEAVER_SPU_JINA_V4_SNAPSHOT` is unset — the +/// test then skips cleanly. This is the **only** path that's not +/// a failure; everything else is propagated as `Err` so the +/// test fails loudly instead of silently passing on infrastructure +/// issues (busted snapshot, GPU unavailable, OOM, `WEAVER_SPU_CUDA_DEVICE` +/// typo, etc.). 
+#[cfg(feature = "embedder-rust")] +async fn try_connect_embedder() +-> anyhow::Result> { + let snapshot = match std::env::var("WEAVER_SPU_JINA_V4_SNAPSHOT") { + Ok(s) if !s.trim().is_empty() => s, + _ => { + eprintln!("skipping: WEAVER_SPU_JINA_V4_SNAPSHOT unset"); + return Ok(None); + } + }; + let gpu_ordinal: usize = match std::env::var("WEAVER_SPU_CUDA_DEVICE") { + Ok(s) => s + .parse() + .map_err(|e| anyhow::anyhow!("WEAVER_SPU_CUDA_DEVICE={s:?} not a valid usize: {e}"))?, + Err(std::env::VarError::NotPresent) => 0, Err(e) => { - eprintln!("skipping: embedder not reachable ({e})"); - return None; + return Err(anyhow::anyhow!( + "WEAVER_SPU_CUDA_DEVICE not valid Unicode: {e}" + )); } }; - if let Err(e) = client.ensure_ready().await { - eprintln!("skipping: embedder not ready ({e})"); - return None; - } - Some(client) + let snapshot_path = std::path::PathBuf::from(&snapshot); + let client = tokio::task::spawn_blocking(move || { + let device = candle_core::Device::new_cuda(gpu_ordinal) + .map_err(|e| anyhow::anyhow!("cuda:{gpu_ordinal} unavailable: {e}"))?; + weaver_spu::encoder::client::EmbedderClient::from_snapshot( + &snapshot_path, + candle_core::DType::BF16, + &device, + ) + .map_err(|e| anyhow::anyhow!("EmbedderClient::from_snapshot: {e}")) + }) + .await + .map_err(|e| anyhow::anyhow!("in-process embedder spawn_blocking task panicked: {e}"))??; + Ok(Some(client)) } /// Clean up all attempt_hypothesis nodes + refines edges for a given task /// name. Ensures the test starts from an empty state regardless of prior /// runs. +#[cfg(feature = "embedder-rust")] async fn cleanup_hypothesis_task(pool: &weaver_database::db::ArangoPool, task_name: &str) { // Delete matching refines edges first (foreign keys). Filter on `_from` // (the canonical relationship field) rather than `_key` — edge-key @@ -2355,6 +2390,7 @@ async fn cleanup_hypothesis_task(pool: &weaver_database::db::ArangoPool, task_na /// Count `attempt_hypothesis` nodes for a task and the number of `refines` /// edges whose source is one of those nodes. 
+#[cfg(feature = "embedder-rust")] async fn count_task_nodes_and_refines( pool: &weaver_database::db::ArangoPool, task_name: &str, @@ -2409,11 +2445,15 @@ async fn count_task_nodes_and_refines( (nodes, edges) } +#[cfg(feature = "embedder-rust")] #[tokio::test] -#[ignore = "requires live Arango + Persephone embedder"] +#[ignore = "requires live Arango + Jina V4 snapshot (set WEAVER_SPU_JINA_V4_SNAPSHOT)"] async fn integration_dedup_exact_duplicate_is_skipped() { let Some(pool) = hades_pool() else { return }; - let Some(embedder) = try_connect_embedder().await else { + let Some(embedder) = try_connect_embedder() + .await + .expect("embedder construction should not fail (set WEAVER_SPU_JINA_V4_SNAPSHOT to skip cleanly)") + else { return; }; belief::ensure_collections(&pool) @@ -2451,11 +2491,15 @@ async fn integration_dedup_exact_duplicate_is_skipped() { cleanup_hypothesis_task(&pool, task).await; } +#[cfg(feature = "embedder-rust")] #[tokio::test] -#[ignore = "requires live Arango + Persephone embedder"] +#[ignore = "requires live Arango + Jina V4 snapshot (set WEAVER_SPU_JINA_V4_SNAPSHOT)"] async fn integration_dedup_near_duplicate_writes_refines_edge() { let Some(pool) = hades_pool() else { return }; - let Some(embedder) = try_connect_embedder().await else { + let Some(embedder) = try_connect_embedder() + .await + .expect("embedder construction should not fail (set WEAVER_SPU_JINA_V4_SNAPSHOT to skip cleanly)") + else { return; }; belief::ensure_collections(&pool) @@ -2511,11 +2555,15 @@ async fn integration_dedup_near_duplicate_writes_refines_edge() { cleanup_hypothesis_task(&pool, task).await; } +#[cfg(feature = "embedder-rust")] #[tokio::test] -#[ignore = "requires live Arango + Persephone embedder"] +#[ignore = "requires live Arango + Jina V4 snapshot (set WEAVER_SPU_JINA_V4_SNAPSHOT)"] async fn integration_dedup_novel_writes_no_refines() { let Some(pool) = hades_pool() else { return }; - let Some(embedder) = try_connect_embedder().await else { + let Some(embedder) = try_connect_embedder() + .await + .expect("embedder construction should not fail (set WEAVER_SPU_JINA_V4_SNAPSHOT to skip cleanly)") + else { return; }; belief::ensure_collections(&pool) diff --git a/crates/weaver-demo/tests/similarity_calibration.rs b/crates/weaver-demo/tests/similarity_calibration.rs deleted file mode 100644 index 1ccb5632..00000000 --- a/crates/weaver-demo/tests/similarity_calibration.rs +++ /dev/null @@ -1,102 +0,0 @@ -//! Live-embedder calibration for the creation-time dedup classifier. -//! -//! Per `docs/specs/memory-creation-dedup.md` acceptance #2, the PR-09 -//! thresholds (0.95 / 0.85) must cleanly separate 10 hypothesis pairs known -//! to be duplicates from 10 pairs known to be novel when embedded by the -//! production embedder (Jina V4 FP16 via Persephone at -//! `/run/weaver/embedder.sock`). -//! -//! This test embeds both corpora and asserts: -//! - every `DUPLICATE_PAIRS` entry classifies as `Duplicate` (score ≥ 0.95), -//! - every `NOVEL_PAIRS` entry classifies as `Novel` (score < 0.85). -//! -//! It is marked `#[ignore]` so unit-test runs on machines without the -//! embedder stay green. Run manually with: -//! -//! ```text -//! cargo test -p weaver-demo --test similarity_calibration -- --ignored --nocapture -//! ``` -//! -//! The `--nocapture` prints every observed score so any threshold delta can -//! be copied into the PR description. 
- -use weaver_demo::herobench::similarity::{ - Classification, DUPLICATE_PAIRS, DUPLICATE_THRESHOLD, MemoryCandidate, - NEAR_DUPLICATE_THRESHOLD, NOVEL_PAIRS, classify, -}; -use weaver_spu::encoder::grpc_client_legacy::EmbeddingClient; - -/// Embed one hypothesis string into a candidate with no key — callers -/// wrap the result with a key appropriate to the assertion being made. -async fn embed_as_candidate(client: &EmbeddingClient, text: &str, key: &str) -> MemoryCandidate { - let vec = client - .embed_one(text, "retrieval.query") - .await - .expect("embed_one failed; is the Persephone embedder up?"); - MemoryCandidate { - key: key.into(), - text: text.into(), - embedding: vec, - } -} - -#[tokio::test] -#[ignore = "requires live Persephone embedder at /run/weaver/embedder.sock"] -async fn calibrate_duplicate_and_novel_pairs() { - let client = EmbeddingClient::connect_default() - .await - .expect("connect to embedder"); - client - .ensure_ready() - .await - .expect("embedder not ready (model not loaded?)"); - - // --- Duplicates: every pair must score >= 0.95. ------------------------ - println!("\n=== DUPLICATE_PAIRS (threshold >= {DUPLICATE_THRESHOLD}) ==="); - let mut dup_failures: Vec = Vec::new(); - for (i, (a, b)) in DUPLICATE_PAIRS.iter().enumerate() { - let cand_b = embed_as_candidate(&client, b, &format!("dup_{i}_b")).await; - let query = embed_as_candidate(&client, a, &format!("dup_{i}_a")).await; - match classify(&query.embedding, std::slice::from_ref(&cand_b)) { - Classification::Duplicate { score, .. } => { - println!(" [{i:>2}] OK score={score:.4}"); - } - Classification::NearDuplicate { score, .. } => { - println!(" [{i:>2}] NEAR score={score:.4} (under threshold)"); - dup_failures.push(format!("pair {i}: NearDuplicate at {score:.4}")); - } - Classification::Novel { best_match } => { - let score = best_match.map(|(_, s)| s).unwrap_or(0.0); - println!(" [{i:>2}] NOVEL score={score:.4} (under threshold)"); - dup_failures.push(format!("pair {i}: Novel at {score:.4}")); - } - } - } - - // --- Novels: every pair must score < 0.85. ---------------------------- - println!("\n=== NOVEL_PAIRS (threshold < {NEAR_DUPLICATE_THRESHOLD}) ==="); - let mut novel_failures: Vec = Vec::new(); - for (i, (a, b)) in NOVEL_PAIRS.iter().enumerate() { - let cand_b = embed_as_candidate(&client, b, &format!("novel_{i}_b")).await; - let query = embed_as_candidate(&client, a, &format!("novel_{i}_a")).await; - match classify(&query.embedding, std::slice::from_ref(&cand_b)) { - Classification::Novel { best_match } => { - let score = best_match.map(|(_, s)| s).unwrap_or(0.0); - println!(" [{i:>2}] OK score={score:.4}"); - } - Classification::NearDuplicate { score, .. } => { - println!(" [{i:>2}] NEAR score={score:.4} (above threshold)"); - novel_failures.push(format!("pair {i}: NearDuplicate at {score:.4}")); - } - Classification::Duplicate { score, .. } => { - println!(" [{i:>2}] DUP score={score:.4} (above threshold)"); - novel_failures.push(format!("pair {i}: Duplicate at {score:.4}")); - } - } - } - - assert!( - dup_failures.is_empty() && novel_failures.is_empty(), - "calibration failed\nduplicate misses: {dup_failures:#?}\nnovel misses: {novel_failures:#?}" - ); -} diff --git a/crates/weaver-interface/src/serve.rs b/crates/weaver-interface/src/serve.rs index 9f2ec929..cd9c5bef 100644 --- a/crates/weaver-interface/src/serve.rs +++ b/crates/weaver-interface/src/serve.rs @@ -6,27 +6,12 @@ //! on demand via `weaver model load` or agent pre-flight. 
use std::path::{Path, PathBuf}; -use std::time::Duration; use anyhow::{Result, bail}; use clap::Args; use weaver_core::embedder::EmbedderInfo; -use weaver_spu::decoder::multi_model::EmbedderBackend; -use weaver_spu::encoder::grpc_client_legacy::{ - DEFAULT_EMBEDDER_SOCKET, EmbeddingClient, EmbeddingClientConfig, EmbeddingEndpoint, -}; use weaver_spu::core::pin::{self, DEFAULT_PIN_PATH, EmbedderPin, VerifyOutcome}; -const EMBEDDER_PROBE_TIMEOUT: Duration = Duration::from_secs(3); -/// Total wall-clock budget for establishing an embedder probe. Under systemd, -/// weaver-embedder.service is Type=simple and declared active the moment -/// Python's `exec` returns, but the gRPC server may take a few seconds more -/// to bind its socket and the Jina V4 weights may take longer still to -/// finish cold-loading. We retry connect+ensure_ready up to this deadline -/// before giving up, so ordering via `After=weaver-embedder.service` doesn't -/// race against that startup delay on every boot. -const EMBEDDER_PROBE_DEADLINE: Duration = Duration::from_secs(30); -const EMBEDDER_PROBE_BACKOFF: Duration = Duration::from_millis(500); /// Start the inference server #[derive(Args, Debug)] @@ -66,9 +51,9 @@ pub async fn handle_serve(args: ServeArgs) -> Result<()> { } // Embedder probe + pin verify runs BEFORE model loading so the daemon - // fails fast on a missing snapshot, busted CUDA init, or a wedged - // Python service rather than burning seconds on `build_server_state()` - // first. (PR #296 review.) + // fails fast on a missing snapshot, busted CUDA init, or a Rust + // EmbedderClient construction failure rather than burning seconds on + // `build_server_state()` first. (PR #296 review.) // // Probe+verify is gated on pin presence: // probe Some + pin any → verify (first-boot writes, match OK, mismatch aborts) @@ -79,26 +64,7 @@ pub async fn handle_serve(args: ServeArgs) -> Result<()> { // staying down until the embedder is reachable. Operators see a clear // refusal instead of silent drift. let pin_path = Path::new(DEFAULT_PIN_PATH); - let backend = config - .embedder - .as_ref() - .map(|e| e.backend) - .unwrap_or_default(); - let probe_result = match backend { - EmbedderBackend::Python => { - // Treat an empty string in [embedder].socket as unset — otherwise a - // blank override silently falls through and the probe hits the empty - // path. - let embedder_socket = config - .embedder - .as_ref() - .map(|e| e.socket.trim()) - .filter(|s| !s.is_empty()) - .unwrap_or(DEFAULT_EMBEDDER_SOCKET); - probe_embedder_python(embedder_socket).await - } - EmbedderBackend::Rust => probe_embedder_rust(config.embedder.as_ref()).await?, - }; + let probe_result = probe_embedder_rust(config.embedder.as_ref()).await?; match probe_result { Some(info) => verify_embedder_pin(pin_path, &info)?, None => { @@ -109,9 +75,13 @@ pub async fn handle_serve(args: ServeArgs) -> Result<()> { // Distinguish explicitly. match std::fs::metadata(pin_path) { Ok(_) => bail!( - "Refuse to start: embedder probe failed and a pin exists at {}. \ - The cohort identity cannot be verified without the embedder. \ - Bring weaver-embedder.service up first, then retry.", + "Refuse to start: in-process embedder probe failed and a pin exists at {}. \ + The cohort identity cannot be verified without a working embedder. \ + Check that [embedder].snapshot points at a valid Jina V4 HF snapshot \ + directory, that the GPU ordinal is available, and that the daemon is \ + built with --features embedder-rust. If the pin is intentionally stale \ + (e.g. 
a corpus re-embed is planned), run \ + `weaver harness embedder reset --force` to clear it before restart.", pin_path.display() ), Err(e) if e.kind() == std::io::ErrorKind::NotFound => { @@ -171,88 +141,21 @@ pub async fn handle_serve(args: ServeArgs) -> Result<()> { .await } -/// Probe the Python embedder socket with bounded retry and log its readiness state. -/// -/// Systemd ordering (`After=weaver-embedder.service` on weaver-infer) puts -/// the embedder first, but `Type=simple` considers a service "active" the -/// moment its exec returns — it does NOT wait for the Python gRPC server to -/// bind its socket or for Jina V4 to finish loading weights. So at the time -/// `weaver serve` runs this probe, the embedder may be "active" per systemd -/// but not yet accepting connections. We retry for `EMBEDDER_PROBE_DEADLINE` -/// with short backoffs to absorb that startup delay instead of failing on -/// the first attempt. -/// -/// Once we're past startup, the refuse-to-start-when-pin-exists check in -/// `handle_serve` catches a truly-wedged embedder: probe returns None → -/// refuse, rather than silently running without the pin verified. -async fn probe_embedder_python(socket: &str) -> Option { - let client_config = EmbeddingClientConfig { - endpoint: EmbeddingEndpoint::Unix(PathBuf::from(socket)), - connect_timeout: EMBEDDER_PROBE_TIMEOUT, - timeout: EMBEDDER_PROBE_TIMEOUT, - }; - - let started = std::time::Instant::now(); - let mut last_err: Option = None; - let mut attempts: u32 = 0; - while started.elapsed() < EMBEDDER_PROBE_DEADLINE { - attempts += 1; - match EmbeddingClient::connect(client_config.clone()).await { - Ok(client) => match client.ensure_ready().await { - Ok(info) => { - println!( - "Embedder ready at {} after {} probe attempt(s): {} (dim={}, max_seq={})", - socket, attempts, info.model_name, info.dimension, info.max_seq_length, - ); - // Convert proto InfoResponse → backend-agnostic - // EmbedderInfo so the pin-verifier downstream is - // backend-shape-independent. The Python service's - // `info()` already returns the proto type via the - // gRPC client's inherent `info` method (not the - // trait one) — this conversion mirrors what the - // Embedder trait impl does internally on - // `EmbeddingClient`. - return Some(EmbedderInfo { - model_name: info.model_name, - model_loaded: info.model_loaded, - dimension: info.dimension, - max_seq_length: info.max_seq_length, - weight_revision: info.weight_revision, - }); - } - Err(e) => last_err = Some(format!("ensure_ready: {e}")), - }, - Err(e) => last_err = Some(format!("connect: {e}")), - } - tokio::time::sleep(EMBEDDER_PROBE_BACKOFF).await; - } - - eprintln!( - "Embedder probe gave up after {:?} ({} attempt(s)) at {}: {} \ - (weaver-embedder.service may be wedged or still cold-loading weights)", - started.elapsed(), - attempts, - socket, - last_err.as_deref().unwrap_or("no attempts ran"), - ); - None -} - /// Construct the in-process Rust embedder backend /// ([`weaver_spu::encoder::client::EmbedderClient`]) and return its /// identity for the cohort-pin verifier. /// -/// Unlike the Python path, there's no socket-readiness retry loop — -/// the embedder either constructs successfully (model loaded, ready -/// to serve) or it doesn't. If construction fails (snapshot path -/// missing, GPU unavailable, weights corrupted), this is a hard boot -/// failure rather than a transient that we'd retry through; callers -/// see the error directly. 
+/// There's no socket-readiness retry loop — the embedder either +/// constructs successfully (model loaded, ready to serve) or it +/// doesn't. If construction fails (snapshot path missing, GPU +/// unavailable, weights corrupted), this is a hard boot failure +/// rather than a transient that we'd retry through; callers see the +/// error directly. /// /// **Feature gating**: requires `embedder-rust` (which transitively -/// enables `weaver-spu/flash-attn`). When the feature is OFF, this -/// function rejects `backend = "rust"` with a configuration error -/// rather than silently falling through to the Python path. +/// enables `weaver-spu/flash-attn`). When the feature is OFF the +/// daemon has no embedder backend at all; this function bails with a +/// rebuild instruction. async fn probe_embedder_rust( embedder_config: Option<&weaver_spu::decoder::multi_model::EmbedderConfig>, ) -> Result<Option<EmbedderInfo>> { @@ -260,10 +163,10 @@ async fn probe_embedder_rust( { let _ = embedder_config; bail!( - "[embedder].backend = \"rust\" requires the daemon to be built with \ - the `embedder-rust` feature; rebuild with \ + "Daemon was built without the `embedder-rust` feature; the in-process \ + embedder is the only supported backend post-PR-1.J. Rebuild with \ `cargo build -p weaver-interface --features inference,embedder-rust` \ - (or omit `embedder-rust` and use `backend = \"python\"`)" + and retry." ); } @@ -274,15 +177,14 @@ async fn probe_embedder_rust( let embedder_config = embedder_config.ok_or_else(|| { anyhow::anyhow!( - "[embedder].backend = \"rust\" but no [embedder] block in server.toml; \ - add the block with `snapshot = \"/path/to/jina-v4/snapshot//\"`" + "no [embedder] block in server.toml; add it with \ + `snapshot = \"/path/to/jina-v4/snapshot//\"`" ) })?; let snapshot = embedder_config.snapshot.as_ref().ok_or_else(|| { anyhow::anyhow!( - "[embedder].backend = \"rust\" requires [embedder].snapshot to point at \ - the Jina V4 HF snapshot directory (e.g. \ - `snapshot = \"/opt/weaver/huggingface/.../snapshots//\"`" + "[embedder].snapshot must point at the Jina V4 HF snapshot directory \ + (e.g. `snapshot = \"/opt/weaver/huggingface/.../snapshots//\"`)" ) })?; // GPU ordinal from the (now-deprecated-but-still-load-bearing) [embedder].gpu diff --git a/crates/weaver-spu/Cargo.toml b/crates/weaver-spu/Cargo.toml index 62e0e668..a3914304 100644 --- a/crates/weaver-spu/Cargo.toml +++ b/crates/weaver-spu/Cargo.toml @@ -63,17 +63,9 @@ llama-cpp-legacy = ["gguf"] # / GPU primitives the legacy decoder backend depends on. weaver-core = { workspace = true } -# Encoder-side gRPC stack (folded in from weaver-embedding in -# PR-0.5.D). Used by `encoder::grpc_client_legacy` to talk to the -# Python `weaver-embedder.service` during the migration window. -# Always compiled — `grpc_client_legacy` is the **production -# embedder backend** today, not a feature-gated path. Retires in -# PR-3.A alongside the gRPC client. -tonic = { workspace = true } -prost = { workspace = true } -prost-types = { workspace = true } -hyper-util = { workspace = true } -tower = { workspace = true } +# `async-trait` is consumed by the `weaver_core::embedder::Embedder` +# trait we implement; remains required even after the gRPC stack +# retired in PR-1.J.
async-trait = { workspace = true } serde_json = { workspace = true } chrono = { workspace = true } @@ -141,20 +133,13 @@ toml = { workspace = true } tracing-subscriber = { workspace = true } [build-dependencies] -# `weaver-inference` had a `cc` build-dep for compiling its custom -# CUDA kernels via build.rs. Folds in unchanged via PR-0.5.C. +# `cc` for compiling the custom CUDA kernels via build.rs (folded in +# from weaver-inference in PR-0.5.C). The `tonic-build` build-dep +# that lived alongside it for the Persephone proto retired in PR-1.J. cc = "1" -# `weaver-embedding` had a `tonic-build` build-dep for compiling the -# Persephone embedding-service proto. Folded in via PR-0.5.D; retires -# alongside the gRPC client in PR-3.A. -tonic-build = { workspace = true } [dev-dependencies] tempfile = { workspace = true } -# `tokio-stream` for the legacy grpc_client integration tests folded -# in from weaver-embedding in PR-0.5.D (the test spins up an -# in-process tonic server backed by a TcpListener-derived stream). -tokio-stream = { workspace = true } # `memmap2` for the layerwise-bisection diagnostic test (`tests/ # jina_v4_layerwise_bisection.rs`). The captured-state fixture is # ~2.4 GiB and reading it through std::fs::read materializes the diff --git a/crates/weaver-spu/build.rs b/crates/weaver-spu/build.rs index da55046c..b2053a8a 100644 --- a/crates/weaver-spu/build.rs +++ b/crates/weaver-spu/build.rs @@ -1,31 +1,16 @@ -//! `weaver-spu` build script — combines: +//! `weaver-spu` build script — compiles the legacy CUDA-kernel +//! sources (folded in from `weaver-inference` in PR-0.5.C). The +//! step is gated by the `cuda` feature, since the kernels only get +//! linked when the cudarc-backed decoder path is in scope. //! -//! 1. The legacy CUDA-kernel compile step (folded in from -//! `weaver-inference` in PR-0.5.C; gated by `cuda` feature, since -//! the kernels only get linked when the cudarc-backed decoder -//! path is in scope). -//! 2. The Persephone embedding-service proto compile step (folded -//! in from `weaver-embedding` in PR-0.5.D; **always compiled** -//! because `encoder::grpc_client_legacy` — the production -//! embedder backend during the migration window — depends on -//! the generated proto types unconditionally). -//! -//! The CUDA step is skipped when `cuda` is not set. The proto step -//! always runs. Default `cargo build` (no features) does the proto -//! compile only. - -use std::path::PathBuf; +//! The Persephone proto compile step that lived here through the +//! migration window retired in PR-1.J alongside the Python +//! embedder service. fn main() -> Result<(), Box<dyn std::error::Error>> { if std::env::var("CARGO_FEATURE_CUDA").is_ok() { compile_cuda_kernels(); } - // Persephone proto is always compiled — `encoder::grpc_client_legacy` - // is the production embedder backend during the migration window - // (talks to the Python `weaver-embedder.service`), not a - // feature-gated path. Retires alongside the gRPC client in - // PR-3.A. - compile_persephone_proto()?; Ok(()) } @@ -62,35 +47,3 @@ fn compile_cuda_kernels() { println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=cublas"); } - -fn compile_persephone_proto() -> Result<(), Box<dyn std::error::Error>> { - let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); - let proto_root = PathBuf::from(manifest_dir) - .join("../../proto") - .canonicalize() - .expect("proto/ directory not found — expected at workspace root"); - - // Embedding-only.
`weaver-database` keeps generating common / - // extraction / training for its own consumers; this proto step - // retires alongside the gRPC client in PR-3.A. - let protos: Vec = ["persephone/embedding/embedding.proto"] - .iter() - .map(|p| proto_root.join(p)) - .collect(); - - tonic_build::configure() - .build_server(true) - .build_client(true) - .compile_protos(&protos, &[&proto_root])?; - - // Per-proto rerun-if-changed so editing any individual `.proto` - // forces a rebuild — `proto_root` alone catches new files added - // to the directory but not edits to existing ones reliably - // across cargo versions. - for proto in &protos { - println!("cargo:rerun-if-changed={}", proto.display()); - } - println!("cargo:rerun-if-changed={}", proto_root.display()); - - Ok(()) -} diff --git a/crates/weaver-spu/src/decoder/multi_model.rs b/crates/weaver-spu/src/decoder/multi_model.rs index b01a0d01..2abc53f4 100644 --- a/crates/weaver-spu/src/decoder/multi_model.rs +++ b/crates/weaver-spu/src/decoder/multi_model.rs @@ -39,7 +39,7 @@ use serde::{Deserialize, Serialize}; use tracing::warn; use crate::decoder::family::{ModelMode, Pooling, family_by_id}; -use crate::core::gpu_orchestrator::{GpuConfig, GpuPolicy}; +use crate::core::gpu_orchestrator::GpuConfig; use crate::decoder::model_profile::{SamplingDefaults, ServerConfig}; /// Top-level multi-model configuration. @@ -104,62 +104,32 @@ pub struct GpuPairConfig { pub decoder_gpu: usize, } -/// Backend discriminator for the embedder. -/// -/// Selects which implementation of [`weaver_core::embedder::Embedder`] -/// the daemon constructs at boot. Per `embedder-oxidization-Spec.md` §3: -/// -/// - **Python** (default during the migration window): connect to the -/// external `weaver-embedder.service` over Unix-socket gRPC. The -/// daemon's only embedder use today is the boot-time cohort-pin probe. -/// - **Rust** (Phase 1 cutover): construct an in-process `EmbedderClient` -/// that owns a `JinaV4Embedder` via candle, no RPC hop. Requires the -/// `weaver-spu/flash-attn` feature on the daemon binary; without it, -/// selecting `Rust` here is a hard fail at boot rather than a silent -/// fallback to the Python path. -/// -/// Default stays `Python` until Phase 3 retires the gRPC path and the -/// default flips. See `docs/specs/embedder-oxidization-Spec.md` for the -/// cutover plan. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] -#[serde(rename_all = "lowercase")] -pub enum EmbedderBackend { - /// External Python `weaver-embedder.service` over Unix-socket gRPC. - #[default] - Python, - /// In-process candle-backed embedder (`weaver_spu::encoder::client:: - /// EmbedderClient`). - Rust, -} - /// Embedder service placement + operational defaults. /// /// Harness-scoped singleton: `weaver-infer` up ⇒ embedder up on the /// declared GPU. The embedder is mandatory infrastructure (same tier as -/// ArangoDB) — agents must not run without it. Fields here are the -/// single source of truth; the Python service reads the same TOML. +/// ArangoDB) — agents must not run without it. +/// +/// Post-PR-1.J the only embedder backend is the in-process candle +/// `EmbedderClient`; the Python `weaver-embedder.service` and its gRPC +/// client retired together. The legacy `socket`/`model_name`/ +/// `batch_size`/`use_fp16`/`idle_timeout_seconds` fields are kept on +/// the struct for back-compat with existing `server.toml` files but +/// are not read by anything that ships today. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct EmbedderConfig { - /// Which backend implementation to use. See [`EmbedderBackend`] for - /// the migration-window semantics. Defaults to `python` so existing - /// deployments keep their current path until the operator explicitly - /// flips to `rust` and provides a `snapshot` path. - #[serde(default)] - pub backend: EmbedderBackend, - /// Filesystem path to the Jina V4 HuggingFace snapshot directory /// (e.g. `/opt/weaver/huggingface/hub/.../snapshots//`). Read - /// by the in-process Rust backend at boot to construct the - /// `EmbedderClient` and run the cohort-pin probe. - /// - /// Only required when `backend = "rust"`. Ignored under - /// `backend = "python"` (the Python service has its own snapshot - /// resolution via `model_name`). Required-but-missing under `rust` - /// is a hard boot failure. + /// at boot to construct the in-process `EmbedderClient` and run + /// the cohort-pin probe. Required when `[embedder]` is present; + /// missing is a hard boot failure. #[serde(default)] pub snapshot: Option, - /// Unix socket path the Python embedder service binds. + /// **Deprecated post-PR-1.J.** Was the Unix socket path the Python + /// `weaver-embedder.service` bound to. Retained on the struct so + /// existing `server.toml` files don't error on parse; ignored at + /// runtime. #[serde(default = "default_embedder_socket")] pub socket: String, @@ -585,69 +555,26 @@ impl MultiModelConfig { } } - // Validate [embedder] — its GPU must be declared in [[gpus]] with - // policy = "locked" so the orchestrator reserves the card for the - // harness lifetime. Without this, a generation model could bind - // the same GPU and trigger an eviction race. - // - // The `[embedder].gpu` field is deprecated per - // `agent-spu-schema-Spec.md` §8. Surface a runtime warning - // when the Python embedder service path consumes it so - // operators see the deprecation in their logs and start - // migrating to per-agent placement. + // Validate [embedder] — `snapshot` is the only required field + // post-PR-1.J. The Python-service-era validations + // (`gpu`/`socket`/`model_name`/`batch_size`) retired alongside + // the Python service itself; those fields are kept on the + // struct for back-compat with existing `server.toml` files but + // are no longer load-bearing. if let Some(emb) = &self.embedder { - // Reject `backend = "rust"` without a snapshot at parse time so - // operators get a clear error from `weaver serve` startup - // instead of an opaque failure deep in `probe_embedder_rust()` - // after `build_server_state()` has already begun loading models. - // Empty-string `snapshot = ""` deserializes to - // `Some(PathBuf::new())`, which would slip past a bare - // `is_none()` check; treat empty paths as missing too. - // (PR #296 review.) let snapshot_missing = match emb.snapshot.as_ref() { None => true, + // `snapshot = ""` deserializes to Some(PathBuf::new()); + // treat that as missing too. Some(p) => p.as_os_str().is_empty(), }; - if emb.backend == EmbedderBackend::Rust && snapshot_missing { + if snapshot_missing { anyhow::bail!( - "[embedder] backend = \"rust\" requires [embedder].snapshot to point \ - at the Jina V4 HF snapshot directory (e.g. \ + "[embedder].snapshot is required — point it at the Jina V4 HF \ + snapshot directory (e.g. 
\ `snapshot = \"/opt/weaver/huggingface/.../snapshots//\"`)" ); } - tracing::warn!( - gpu = emb.gpu, - "[embedder].gpu is deprecated post-B′.3b; encoder placement \ - is now per-agent via spu.encoder.gpu. The Python embedder \ - service still reads this field during the migration \ - window; the Rust embedder backend ignores it." - ); - match self.gpus.iter().find(|g| g.id == emb.gpu) { - None => anyhow::bail!( - "[embedder] gpu = {} has no matching [[gpus]] entry — \ - add `[[gpus]] id = {} policy = \"locked\"`", - emb.gpu, - emb.gpu, - ), - Some(g) if g.policy != GpuPolicy::Locked => anyhow::bail!( - "[embedder] gpu = {} must have policy = \"locked\" \ - (got {:?}); the embedder reserves the card for the \ - harness lifetime", - emb.gpu, - g.policy, - ), - _ => {} - } - - if emb.socket.trim().is_empty() { - anyhow::bail!("[embedder] socket must not be empty"); - } - if emb.model_name.trim().is_empty() { - anyhow::bail!("[embedder] model_name must not be empty"); - } - if emb.batch_size == 0 { - anyhow::bail!("[embedder] batch_size must be >= 1"); - } } // Cross-entry validation: reject overlapping GPU assignments among @@ -1767,68 +1694,15 @@ gpu = 0 assert!(config.embedder.is_none()); } - #[test] - fn test_validate_embedder_missing_gpu_entry_rejected() { - // [embedder] gpu = 2 without a [[gpus]] id = 2 must fail — the - // orchestrator has no policy for that card. - let toml = r#" -[embedder] -gpu = 2 -"#; - let err = MultiModelConfig::parse(toml).unwrap_err(); - assert!( - err.to_string().contains("no matching [[gpus]] entry"), - "{err}" - ); - } - - #[test] - fn test_validate_embedder_gpu_not_locked_rejected() { - // [embedder] gpu points at a [[gpus]] entry with the wrong policy. - let toml = r#" -[[gpus]] -id = 2 -policy = "free" - -[embedder] -gpu = 2 -"#; - let err = MultiModelConfig::parse(toml).unwrap_err(); - assert!(err.to_string().contains("policy = \"locked\""), "{err}"); - } - - #[test] - fn test_validate_embedder_blank_socket_rejected() { - let toml = r#" -[[gpus]] -id = 2 -policy = "locked" - -[embedder] -gpu = 2 -socket = "" -"#; - let err = MultiModelConfig::parse(toml).unwrap_err(); - assert!( - err.to_string().contains("socket must not be empty"), - "{err}" - ); - } - - #[test] - fn test_validate_embedder_zero_batch_size_rejected() { - let toml = r#" -[[gpus]] -id = 2 -policy = "locked" - -[embedder] -gpu = 2 -batch_size = 0 -"#; - let err = MultiModelConfig::parse(toml).unwrap_err(); - assert!(err.to_string().contains("batch_size must be >= 1"), "{err}"); - } + // Python-service-era validation tests + // (`test_validate_embedder_missing_gpu_entry_rejected`, + // `test_validate_embedder_gpu_not_locked_rejected`, + // `test_validate_embedder_blank_socket_rejected`, + // `test_validate_embedder_zero_batch_size_rejected`) + // retired in PR-1.J alongside the validations they covered. + // The single remaining required field is `[embedder].snapshot`, + // exercised by `test_validate_embedder_requires_snapshot` / + // `_with_snapshot_ok` / `_rejects_empty_snapshot` below. #[test] fn jina_embedder_fixture_registers_all_three_specialists() { @@ -1885,44 +1759,48 @@ batch_size = 0 // ----- backend selection / snapshot validation (PR #296) ----- #[test] - fn test_validate_rust_backend_requires_snapshot() { + fn test_validate_embedder_requires_snapshot() { + // [embedder] block without `snapshot` errors at parse — there's + // only one embedder backend post-PR-1.J (in-process candle), + // and it needs the snapshot directory to construct the + // `EmbedderClient`. 
let toml = r#" [[gpus]] id = 2 policy = "locked" [embedder] -backend = "rust" gpu = 2 "#; let err = MultiModelConfig::parse(toml).unwrap_err(); let msg = err.to_string(); assert!( - msg.contains("requires [embedder].snapshot"), + msg.contains("[embedder].snapshot is required"), "expected snapshot-required error, got: {msg}" ); } #[test] - fn test_validate_rust_backend_with_snapshot_ok() { + fn test_validate_embedder_with_snapshot_ok() { let toml = r#" [[gpus]] id = 2 policy = "locked" [embedder] -backend = "rust" gpu = 2 snapshot = "/opt/weaver/huggingface/.../snapshots/abc123/" "#; let config = MultiModelConfig::parse(toml).expect("parse should succeed"); - let emb = config.embedder.as_ref().expect("[embedder] should be present"); - assert_eq!(emb.backend, EmbedderBackend::Rust); + let emb = config + .embedder + .as_ref() + .expect("[embedder] should be present"); assert!(emb.snapshot.is_some(), "snapshot should round-trip"); } #[test] - fn test_validate_rust_backend_rejects_empty_snapshot() { + fn test_validate_embedder_rejects_empty_snapshot() { // `snapshot = ""` deserializes to Some(PathBuf::new()) — the // validation must treat empty paths as missing, not as "set". let toml = r#" @@ -1931,33 +1809,14 @@ id = 2 policy = "locked" [embedder] -backend = "rust" gpu = 2 snapshot = "" "#; let err = MultiModelConfig::parse(toml).unwrap_err(); let msg = err.to_string(); assert!( - msg.contains("requires [embedder].snapshot"), + msg.contains("[embedder].snapshot is required"), "expected snapshot-required error, got: {msg}" ); } - - #[test] - fn test_validate_python_backend_without_snapshot_ok() { - // The default backend is python; no snapshot needed. Confirms - // the new check doesn't accidentally reject existing configs. - let toml = r#" -[[gpus]] -id = 2 -policy = "locked" - -[embedder] -gpu = 2 -"#; - let config = MultiModelConfig::parse(toml).expect("parse should succeed"); - let emb = config.embedder.as_ref().expect("[embedder] should be present"); - assert_eq!(emb.backend, EmbedderBackend::Python); - assert!(emb.snapshot.is_none()); - } } diff --git a/crates/weaver-spu/src/encoder/grpc_client_legacy.rs b/crates/weaver-spu/src/encoder/grpc_client_legacy.rs deleted file mode 100644 index 0ccb6723..00000000 --- a/crates/weaver-spu/src/encoder/grpc_client_legacy.rs +++ /dev/null @@ -1,572 +0,0 @@ -//! Persephone Embedding Client — gRPC client for vector embedding generation. -//! -//! Connects to the Persephone embedding service over a Unix domain socket -//! or TCP endpoint. Wraps the generated tonic client with connection -//! management, health checking, and ergonomic Rust types. - -// Module-scoped `allow(deprecated)` — this file defines the deprecated -// type aliases (`EmbedResult`, `EmbeddingError`, `LateChunkResult`, -// `LateChunkedResult`, `EmbedderInfo`) per PR-0.5.B and uses them -// throughout its own implementation. The deprecations are for *external* -// consumers (so they migrate to `weaver_core::embedder::*` before -// PR-0.5.E); internal self-uses are benign and would otherwise fail -// `-D warnings`. Narrower than a crate-level allow per the reviewer's -// "narrower scoping" guidance — the allow is bounded to this single -// retiring module. 
-#![allow(deprecated)] - -use std::path::{Path, PathBuf}; -use std::time::Duration; - -use crate::proto::embedding::embedding_service_client::EmbeddingServiceClient; -use crate::proto::embedding::{ - EmbedLateChunkedRequest, EmbedLateChunkedResponse, EmbedRequest, EmbedResponse, - GracefulShutdownRequest, GracefulShutdownResponse, InfoRequest, InfoResponse, -}; -use hyper_util::rt::TokioIo; -use tonic::transport::{Channel, Endpoint, Uri}; -use tower::service_fn; -use tracing::{debug, info, instrument, warn}; - -/// Default Unix-socket path for the Persephone embedder service. -/// -/// Matches `/run/weaver/embedder.sock` — the path systemd's -/// `weaver-embedder.service` binds. Defined here (rather than -/// imported from `weaver-database::config`) because the -/// post-`embedder-oxidization-Spec.md` ontology has -/// `weaver-embedding` upstream of `weaver-database`. Both crates -/// keep their own constant; they must stay in sync — a future PR -/// may consolidate by having `weaver-database` re-export this -/// one. -pub const DEFAULT_EMBEDDER_SOCKET: &str = "/run/weaver/embedder.sock"; - -/// Default timeout for embedding requests. -const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300); // 5 min for large batches -/// Default connection timeout. -const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); - -/// Configuration for the embedding client. -#[derive(Debug, Clone)] -pub struct EmbeddingClientConfig { - /// Unix socket path or TCP address for the embedding service. - pub endpoint: EmbeddingEndpoint, - /// Request timeout. - pub timeout: Duration, - /// Connection timeout. - pub connect_timeout: Duration, -} - -/// Endpoint for the embedding service. -#[derive(Debug, Clone)] -pub enum EmbeddingEndpoint { - /// Unix domain socket path. - Unix(PathBuf), - /// TCP address (e.g. "http://localhost:50051"). - Tcp(String), -} - -impl Default for EmbeddingClientConfig { - fn default() -> Self { - Self { - endpoint: EmbeddingEndpoint::Unix(PathBuf::from(DEFAULT_EMBEDDER_SOCKET)), - timeout: DEFAULT_TIMEOUT, - connect_timeout: DEFAULT_CONNECT_TIMEOUT, - } - } -} - -// `EmbeddingError`, `EmbedResult`, `LateChunkedResult`, `LateChunkResult`, -// and `EmbedderInfo` relocated to `weaver-core::embedder` per -// `docs/specs/weaver-spu-Spec.md` PR-0.5.B. Backwards-compat shims here -// during the migration window. Removed in PR-0.5.E. -// -// **Why type aliases (not `pub use`)**: `#[deprecated]` on a `pub use` -// re-export does not reliably emit a warning at consumer call sites -// in stable/nightly rustc — the attribute attaches to the re-export -// itself, not to the consumer-side import. Deprecated type aliases -// emit warnings consistently when consumers reference the alias name, -// giving downstream callers a clear migration path before PR-0.5.E -// removal. Identical underlying types, so any consumer that has -// already migrated to `weaver_core::embedder::*` keeps working -// unchanged. -// -// The trait surface in weaver-core abstracts over transport via -// `EmbeddingError::Transport(Box)`. The -// `From` impls live in weaver-core (orphan rule) and let -// `?` propagate raw tonic errors through the call sites below. 
- -#[deprecated( - since = "0.1.0", - note = "moved to weaver_core::embedder::EmbedResult; this alias goes away in PR-0.5.E" -)] -pub type EmbedResult = weaver_core::embedder::EmbedResult; - -#[deprecated( - since = "0.1.0", - note = "moved to weaver_core::embedder::EmbedderInfo; this alias goes away in PR-0.5.E" -)] -pub type EmbedderInfo = weaver_core::embedder::EmbedderInfo; - -#[deprecated( - since = "0.1.0", - note = "moved to weaver_core::embedder::EmbeddingError; this alias goes away in PR-0.5.E" -)] -pub type EmbeddingError = weaver_core::embedder::EmbeddingError; - -#[deprecated( - since = "0.1.0", - note = "moved to weaver_core::embedder::LateChunkResult; this alias goes away in PR-0.5.E" -)] -pub type LateChunkResult = weaver_core::embedder::LateChunkResult; - -#[deprecated( - since = "0.1.0", - note = "moved to weaver_core::embedder::LateChunkedResult; this alias goes away in PR-0.5.E" -)] -pub type LateChunkedResult = weaver_core::embedder::LateChunkedResult; - -/// Client for the Persephone embedding service. -/// -/// Provides ergonomic methods for embedding text and querying -/// provider capabilities. -#[derive(Clone)] -pub struct EmbeddingClient { - inner: EmbeddingServiceClient, - config: EmbeddingClientConfig, -} - -impl EmbeddingClient { - /// Connect to the embedding service. - #[instrument(skip_all)] - pub async fn connect(config: EmbeddingClientConfig) -> Result { - let channel = match &config.endpoint { - EmbeddingEndpoint::Unix(path) => { - debug!(socket = %path.display(), "connecting to embedding service via UDS"); - Self::connect_unix(path, &config).await? - } - EmbeddingEndpoint::Tcp(addr) => { - debug!(addr, "connecting to embedding service via TCP"); - Self::connect_tcp(addr, &config).await? - } - }; - - let inner = EmbeddingServiceClient::new(channel); - info!("connected to embedding service"); - - Ok(Self { inner, config }) - } - - /// Connect to the embedding service with default configuration. - pub async fn connect_default() -> Result { - Self::connect(EmbeddingClientConfig::default()).await - } - - /// Connect to an embedding service at the given Unix socket path. - pub async fn connect_unix_at(path: impl Into) -> Result { - let config = EmbeddingClientConfig { - endpoint: EmbeddingEndpoint::Unix(path.into()), - ..Default::default() - }; - Self::connect(config).await - } - - /// Embed a batch of texts into vectors. - /// - /// Returns a vector of embedding vectors, one per input text. 
- #[instrument(skip(self, texts), fields(count = texts.len()))] - pub async fn embed( - &self, - texts: &[String], - task: &str, - batch_size: Option, - ) -> Result { - let request = EmbedRequest { - texts: texts.to_vec(), - task: task.to_string(), - batch_size: batch_size.unwrap_or(0), - }; - - let response: EmbedResponse = self.inner.clone().embed(request).await?.into_inner(); - - let embeddings: Vec> = response.embeddings.into_iter().map(|e| e.values).collect(); - - if embeddings.len() != texts.len() { - return Err(EmbeddingError::InvalidResponse(format!( - "expected {} embeddings, got {}", - texts.len(), - embeddings.len() - ))); - } - - let expected_dim = response.dimension as usize; - for (i, emb) in embeddings.iter().enumerate() { - if emb.len() != expected_dim { - return Err(EmbeddingError::InvalidResponse(format!( - "embedding[{}] has dimension {}, expected {}", - i, - emb.len(), - expected_dim - ))); - } - } - - debug!( - count = embeddings.len(), - dimension = response.dimension, - duration_ms = response.duration_ms, - model = %response.model, - "embedding complete" - ); - - Ok(EmbedResult { - embeddings, - model: response.model, - dimension: response.dimension, - duration_ms: response.duration_ms, - }) - } - - /// Embed a single text string. - pub async fn embed_one(&self, text: &str, task: &str) -> Result, EmbeddingError> { - let result = self.embed(&[text.to_string()], task, None).await?; - Ok(result.embeddings.into_iter().next().unwrap()) - } - - /// Embed a full document with late chunking. - /// - /// The service encodes the whole document once, then derives per-chunk - /// embeddings from the contextual token states — each chunk's vector - /// reflects its surrounding context. Returns one `LateChunkResult` per - /// chunk, in document order. - #[instrument(skip(self, text), fields(text_len = text.len()))] - pub async fn embed_late_chunked( - &self, - text: &str, - task: &str, - ) -> Result { - let request = EmbedLateChunkedRequest { - text: text.to_string(), - task: task.to_string(), - }; - let response: EmbedLateChunkedResponse = self - .inner - .clone() - .embed_late_chunked(request) - .await? - .into_inner(); - - let expected_dim = response.dimension as usize; - let mut chunks = Vec::with_capacity(response.chunks.len()); - for (i, c) in response.chunks.into_iter().enumerate() { - // Reject missing embeddings outright. `unwrap_or_default()` used - // to mask absent/empty embeddings as a zero-length vector, which - // silently passed the dimension check when expected_dim==0 and - // otherwise failed with a misleading "dimension 0" error. Callers - // need to know the service returned no vector at all. 
- let embedding = c.embedding.ok_or_else(|| { - EmbeddingError::InvalidResponse(format!( - "late-chunk[{i}] response missing embedding field" - )) - })?; - let emb = embedding.values; - if emb.is_empty() { - return Err(EmbeddingError::InvalidResponse(format!( - "late-chunk[{i}] returned empty embedding vector" - ))); - } - if expected_dim == 0 { - return Err(EmbeddingError::InvalidResponse(format!( - "late-chunk[{i}] has embedding values but response.dimension is 0" - ))); - } - if emb.len() != expected_dim { - return Err(EmbeddingError::InvalidResponse(format!( - "late-chunk[{}] has dimension {}, expected {}", - i, - emb.len(), - expected_dim - ))); - } - chunks.push(LateChunkResult { - text: c.text, - embedding: emb, - start_char: c.start_char, - end_char: c.end_char, - start_token: c.start_token, - end_token: c.end_token, - chunk_index: c.chunk_index, - total_chunks: c.total_chunks, - }); - } - - debug!( - count = chunks.len(), - dimension = response.dimension, - context_window_used = response.context_window_used, - duration_ms = response.duration_ms, - "late-chunked embedding complete" - ); - - Ok(LateChunkedResult { - chunks, - model: response.model, - dimension: response.dimension, - duration_ms: response.duration_ms, - context_window_used: response.context_window_used, - }) - } - - /// Query the embedding provider's capabilities and model info. - #[instrument(skip(self))] - pub async fn info(&self) -> Result { - let response = self.inner.clone().info(InfoRequest {}).await?.into_inner(); - - debug!( - model = %response.model_name, - dimension = response.dimension, - device = %response.device, - loaded = response.model_loaded, - uptime_s = response.uptime_seconds, - "provider info retrieved" - ); - - Ok(response) - } - - /// Check that the service is up and the model is loaded. - /// - /// Returns `NotAvailable` for any pre-flight failure — both the - /// "service reachable but model_loaded=false" case and the wire-level - /// UNAVAILABLE/deadline/transport failures that mean the same thing - /// operationally (can't embed right now). Callers pattern-match on - /// `NotAvailable` and refuse to continue; funnelling here means they - /// don't also have to know the tonic error taxonomy. - #[instrument(skip(self))] - pub async fn ensure_ready(&self) -> Result { - // Transport-class errors (connection failure, gRPC status with - // an "unavailable / deadline / failed-precondition" code) all - // map to the abstract `Transport` variant in - // `weaver-core::embedder`. We downcast to the concrete tonic - // types here to preserve the prior code-class discrimination — - // gRPC `Status` codes UNAVAILABLE / DEADLINE_EXCEEDED / - // FAILED_PRECONDITION are mapped to `NotAvailable`; raw transport - // errors (TCP / Unix socket / HTTP) likewise. Other transport - // errors propagate unchanged. 
- let info = match self.info().await { - Ok(info) => info, - Err(EmbeddingError::Transport(e)) => { - if let Some(status) = e.downcast_ref::() { - if matches!( - status.code(), - tonic::Code::Unavailable - | tonic::Code::DeadlineExceeded - | tonic::Code::FailedPrecondition, - ) { - return Err(EmbeddingError::NotAvailable(format!( - "embedder Info failed: {} ({})", - status.code(), - status.message() - ))); - } - return Err(EmbeddingError::Transport(e)); - } - if e.downcast_ref::().is_some() { - return Err(EmbeddingError::NotAvailable(format!( - "embedder transport error: {e}" - ))); - } - return Err(EmbeddingError::Transport(e)); - } - Err(e) => return Err(e), - }; - if !info.model_loaded { - return Err(EmbeddingError::NotAvailable(format!( - "service reachable but model_loaded=false (model={})", - info.model_name - ))); - } - Ok(info) - } - - /// Request a graceful shutdown of the embedder service. - /// - /// The service drains in-flight requests (up to `drain_timeout`) and then - /// exits. Returns the drain summary. Harness-side lifecycle management - /// uses this on teardown so the model unloads cleanly instead of being - /// SIGKILLed out of a half-finished batch. - #[instrument(skip(self))] - pub async fn graceful_shutdown( - &self, - drain_timeout: Duration, - ) -> Result { - let request = GracefulShutdownRequest { - drain_timeout_ms: drain_timeout.as_millis().min(u32::MAX as u128) as u32, - }; - let response: GracefulShutdownResponse = self - .inner - .clone() - .graceful_shutdown(request) - .await? - .into_inner(); - Ok(ShutdownSummary { - drained: response.drained, - message: response.message, - }) - } - - /// Check if the service is reachable by calling Info. - /// - /// Uses a short timeout so a stalled service returns `false` quickly - /// rather than blocking for the full request timeout. - pub async fn health_check(&self) -> bool { - match tokio::time::timeout(Duration::from_secs(5), self.info()).await { - Ok(Ok(_)) => true, - Ok(Err(e)) => { - warn!(error = %e, "embedding service health check failed"); - false - } - Err(_) => { - warn!("embedding service health check timed out"); - false - } - } - } - - /// Get the configured endpoint. - pub fn endpoint(&self) -> &EmbeddingEndpoint { - &self.config.endpoint - } - - // ----------------------------------------------------------------------- - // Connection helpers - // ----------------------------------------------------------------------- - - async fn connect_unix( - path: &Path, - config: &EmbeddingClientConfig, - ) -> Result { - let path = path.to_path_buf(); - - // tonic requires a URI even for UDS; the authority is ignored. - let channel = Endpoint::from_static("http://[::]:50051") - .timeout(config.timeout) - .connect_timeout(config.connect_timeout) - .connect_with_connector(service_fn(move |_: Uri| { - let path = path.clone(); - async move { - let stream = tokio::net::UnixStream::connect(path).await?; - Ok::<_, std::io::Error>(TokioIo::new(stream)) - } - })) - .await?; - - Ok(channel) - } - - async fn connect_tcp( - addr: &str, - config: &EmbeddingClientConfig, - ) -> Result { - let channel = Endpoint::from_shared(addr.to_string())? - .timeout(config.timeout) - .connect_timeout(config.connect_timeout) - .connect() - .await?; - - Ok(channel) - } -} - -// `EmbedResult`, `LateChunkedResult`, `LateChunkResult` definitions -// relocated to `weaver-core::embedder` per PR-0.5.B; re-exported above -// from this module's existing public surface for backward compat -// during the migration window. 
- -/// Summary returned by a `graceful_shutdown` call. -#[derive(Debug, Clone)] -pub struct ShutdownSummary { - /// True if the service drained in-flight requests cleanly; - /// false if the drain timeout expired. - pub drained: bool, - /// Human-readable drain summary from the service. - pub message: String, -} - -impl std::fmt::Debug for EmbeddingClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("EmbeddingClient") - .field("endpoint", &self.config.endpoint) - .finish() - } -} - -// ===================================================================== -// `Embedder` trait impl -// ===================================================================== -// -// The gRPC backend presented as the backend-agnostic [`Embedder`] trait -// per `embedder-oxidization-Spec.md` §3 / §5.2. Consumer code that takes -// `Arc` (the surfacing engine, Notepad, Pen, sleep stage A, -// preseed) operates against this impl today and against the future Rust -// in-process impl post-Phase 1 cutover, with no consumer-side changes. -// -// The `info()` impl converts the proto `InfoResponse` to the -// backend-agnostic [`weaver_core::embedder::EmbedderInfo`] so the trait -// surface stays clean of proto types. - -#[async_trait::async_trait] -impl weaver_core::embedder::Embedder for EmbeddingClient { - async fn embed( - &self, - texts: &[String], - task: &str, - batch_size: Option, - ) -> Result { - // Delegate to the inherent method — the trait method's - // signature matches exactly so this is a thin forward. - EmbeddingClient::embed(self, texts, task, batch_size).await - } - - async fn embed_late_chunked( - &self, - text: &str, - task: &str, - ) -> Result { - EmbeddingClient::embed_late_chunked(self, text, task).await - } - - async fn info(&self) -> Result { - let resp = EmbeddingClient::info(self).await?; - Ok(weaver_core::embedder::EmbedderInfo { - model_name: resp.model_name, - model_loaded: resp.model_loaded, - dimension: resp.dimension, - max_seq_length: resp.max_seq_length, - weight_revision: resp.weight_revision, - }) - } - - // `embed_one` uses the trait's default impl. The inherent - // `EmbeddingClient::embed_one` exists for direct callers but - // the trait surface goes through `embed` for consistency. - - /// Override the trait's default `health_check` to bound the - /// `info()` call's latency at 5 seconds *while preserving the - /// trait-level readiness semantics* (must report - /// `model_loaded == true`). The trait's default impl runs `info()` - /// with no upper bound, so a stuck embedder hangs callers for the - /// full 300-second request timeout. The inherent - /// `EmbeddingClient::health_check` has the timeout but treats any - /// successful `info()` response as healthy — including - /// `model_loaded == false`, which is wrong for a readiness check - /// that consumers using `Arc` rely on. - /// - /// This implementation combines both: 5-second timeout AND - /// `model_loaded == true`. 
- async fn health_check(&self) -> bool { - match tokio::time::timeout(Duration::from_secs(5), EmbeddingClient::info(self)).await { - Ok(Ok(info)) => info.model_loaded, - Ok(Err(_)) | Err(_) => false, - } - } -} diff --git a/crates/weaver-spu/src/encoder/mod.rs b/crates/weaver-spu/src/encoder/mod.rs index 3f15a09d..3ad9616e 100644 --- a/crates/weaver-spu/src/encoder/mod.rs +++ b/crates/weaver-spu/src/encoder/mod.rs @@ -69,12 +69,11 @@ pub mod jina_v4; #[cfg(feature = "candle")] pub mod client; -// `grpc_client_legacy` is NOT feature-gated: it's the production -// embedder backend during the migration window (talks to the Python -// `weaver-embedder.service` over Unix socket gRPC). Always-available -// because the daemon needs an embedder regardless of features. -// Retires in PR-3.A. -pub mod grpc_client_legacy; +// `grpc_client_legacy` was the migration-window embedder backend +// talking to the Python `weaver-embedder.service` over Unix socket +// gRPC. Retired in PR-1.J alongside the Python service itself; the +// in-process `client::EmbedderClient` (PR-1.F) is the only embedder +// backend now. // `gguf_backend` is feature-gated because it pulls llama-cpp-2 (a // substantial transitive build cost). Used only by experimental diff --git a/crates/weaver-spu/src/lib.rs b/crates/weaver-spu/src/lib.rs index 8f99bb84..df0fe482 100644 --- a/crates/weaver-spu/src/lib.rs +++ b/crates/weaver-spu/src/lib.rs @@ -54,17 +54,3 @@ pub mod decoder; pub mod encoder; pub mod models; -/// Persephone gRPC/protobuf definitions — embedding service. -/// -/// Generated by `build.rs` from the workspace -/// `proto/persephone/embedding/embedding.proto`. Folded in from -/// `weaver-embedding::proto` via PR-0.5.D. Always available (the -/// gRPC client to the Python embedder is the production backend -/// during the migration window, not a feature). Removed in Phase 3 -/// PR-3.A alongside the gRPC client itself. -pub mod proto { - /// Embedding service — vector embedding generation. - pub mod embedding { - tonic::include_proto!("persephone.embedding"); - } -} diff --git a/crates/weaver-spu/tests/grpc_client_legacy.rs b/crates/weaver-spu/tests/grpc_client_legacy.rs deleted file mode 100644 index 7d596597..00000000 --- a/crates/weaver-spu/tests/grpc_client_legacy.rs +++ /dev/null @@ -1,691 +0,0 @@ -//! Tests for the Persephone embedding client. -//! -//! Connection tests require a running Persephone embedding service. -//! Type and config tests run without a service. -//! -//! Integration tests at the bottom spin up an in-process tonic server -//! backed by a mock `EmbeddingService` and exercise the full client RPC -//! surface (Embed, EmbedLateChunked, Info, ensure_ready, GracefulShutdown). - -// Module-scoped `allow(deprecated)` — these tests reference the -// `weaver_spu::encoder::grpc_client_legacy::{EmbedResult, EmbeddingError, ...}` -// type aliases that were marked `#[deprecated]` in PR-0.5.B (they now -// alias `weaver_core::embedder::*`). The tests stay on the legacy path -// because they're testing the legacy gRPC client; the whole crate -// retires in PR-0.5.E. Scoped to this single test file rather than -// crate-wide. 
-#![allow(deprecated)] - -use std::path::PathBuf; - -use weaver_spu::encoder::grpc_client_legacy::{EmbedResult, EmbeddingClientConfig, EmbeddingEndpoint}; - -#[test] -fn test_config_defaults() { - let config = EmbeddingClientConfig::default(); - match &config.endpoint { - EmbeddingEndpoint::Unix(path) => { - assert_eq!(path, &PathBuf::from("/run/weaver/embedder.sock")); - } - EmbeddingEndpoint::Tcp(_) => panic!("expected Unix endpoint by default"), - } - assert_eq!(config.timeout.as_secs(), 300); - assert_eq!(config.connect_timeout.as_secs(), 10); -} - -#[test] -fn test_embed_result_struct() { - let result = EmbedResult { - embeddings: vec![vec![0.1, 0.2, 0.3]], - model: "test-model".to_string(), - dimension: 3, - duration_ms: 42, - }; - assert_eq!(result.embeddings.len(), 1); - assert_eq!(result.embeddings[0].len(), 3); - assert_eq!(result.dimension, 3); -} - -#[test] -fn test_endpoint_variants() { - let unix = EmbeddingEndpoint::Unix(PathBuf::from("/tmp/test.sock")); - let tcp = EmbeddingEndpoint::Tcp("http://localhost:50051".to_string()); - - // Just verify they construct without panic - assert!(matches!(unix, EmbeddingEndpoint::Unix(_))); - assert!(matches!(tcp, EmbeddingEndpoint::Tcp(_))); -} - -#[tokio::test] -async fn test_connect_nonexistent_socket_fails() { - use weaver_spu::encoder::grpc_client_legacy::EmbeddingClient; - - let dir = tempfile::tempdir().expect("failed to create tempdir"); - let socket_path = dir.path().join("nonexistent.sock"); - - let config = EmbeddingClientConfig { - endpoint: EmbeddingEndpoint::Unix(socket_path), - connect_timeout: std::time::Duration::from_millis(500), - ..Default::default() - }; - - let result = EmbeddingClient::connect(config).await; - assert!( - result.is_err(), - "expected connection error for nonexistent socket" - ); -} - -// ----------------------------------------------------------------------------- -// Integration tests — mock service over a real TCP listener. -// ----------------------------------------------------------------------------- - -/// Sentinel HF snapshot SHA used by the mock Info handler so that -/// `test_info_roundtrip` can assert the `weight_revision` field survives -/// the tonic serialize/deserialize cycle. A value of `String::new()` would -/// silently pass even if the wire format dropped the field. -pub const MOCK_WEIGHT_REVISION: &str = "0000000000000000000000000000000000000000"; - -mod mock_server { - use std::sync::{ - Arc, - atomic::{AtomicU32, Ordering}, - }; - - use super::MOCK_WEIGHT_REVISION; - - use tokio::net::TcpListener; - use tokio::sync::oneshot; - use tokio::task::JoinHandle; - use tokio_stream::wrappers::TcpListenerStream; - use tonic::{Request, Response, Status, transport::Server}; - - use weaver_spu::proto::embedding::embedding_service_server::{ - EmbeddingService, EmbeddingServiceServer, - }; - use weaver_spu::proto::embedding::{ - EmbedLateChunkedRequest, EmbedLateChunkedResponse, EmbedRequest, EmbedResponse, Embedding, - GracefulShutdownRequest, GracefulShutdownResponse, InfoRequest, InfoResponse, LateChunk, - }; - - #[allow(dead_code)] - pub struct MockServer { - pub tcp_url: String, - pub shutdown: Option>, - pub handle: JoinHandle<()>, - pub shutdown_count: Arc, - } - - pub struct MockService { - pub model_loaded: bool, - pub shutdown_count: Arc, - /// When `Some(d)`, the mock reports `response.dimension = d` even - /// though per-chunk embeddings are still 4-wide. Lets a test force - /// the server-inconsistency case (e.g. 
`Some(0)` paired with - /// non-empty chunks — the Rust client must reject this). - pub dimension_override: Option, - /// When `Some(status)`, `info()` returns the supplied tonic - /// `Status` instead of a successful `InfoResponse`. Used to - /// exercise `EmbeddingClient::ensure_ready`'s downcast path - /// (PR-0.5.B abstracted EmbeddingError over transport via - /// `Box`; ensure_ready downcasts back to specific - /// transport-error types to discriminate UNAVAILABLE-class - /// failures from other status codes). Default `None`. - pub info_error: Option, - } - - #[tonic::async_trait] - impl EmbeddingService for MockService { - async fn embed( - &self, - request: Request, - ) -> Result, Status> { - let req = request.into_inner(); - if req.texts.is_empty() { - return Err(Status::invalid_argument("no texts")); - } - let embeddings = req - .texts - .iter() - .map(|t| Embedding { - values: vec![t.len() as f32; 4], - }) - .collect::>(); - Ok(Response::new(EmbedResponse { - embeddings, - model: "mock".into(), - dimension: 4, - duration_ms: 1, - })) - } - - async fn embed_late_chunked( - &self, - request: Request, - ) -> Result, Status> { - let req = request.into_inner(); - if req.text.is_empty() { - return Err(Status::invalid_argument("empty text")); - } - // Split on a char boundary — byte-halving would panic on any - // multibyte codepoint that straddles the midpoint, and the - // start_char/end_char fields are defined as character counts, - // not byte offsets. - let total_chars = req.text.chars().count() as u32; - let half_chars = total_chars / 2; - let split_byte = req - .text - .char_indices() - .nth(half_chars as usize) - .map(|(i, _)| i) - .unwrap_or(req.text.len()); - let chunks = vec![ - LateChunk { - text: req.text[..split_byte].to_string(), - embedding: Some(Embedding { - values: vec![1.0, 0.0, 0.0, 0.0], - }), - start_char: 0, - end_char: half_chars, - start_token: 0, - end_token: half_chars / 4, - chunk_index: 0, - total_chunks: 2, - }, - LateChunk { - text: req.text[split_byte..].to_string(), - embedding: Some(Embedding { - values: vec![0.0, 1.0, 0.0, 0.0], - }), - start_char: half_chars, - end_char: total_chars, - start_token: half_chars / 4, - end_token: total_chars / 4, - chunk_index: 1, - total_chunks: 2, - }, - ]; - Ok(Response::new(EmbedLateChunkedResponse { - chunks, - model: "mock".into(), - dimension: self.dimension_override.unwrap_or(4), - duration_ms: 2, - context_window_used: total_chars / 4, - })) - } - - async fn info( - &self, - _request: Request, - ) -> Result, Status> { - if let Some(err) = self.info_error.clone() { - return Err(err); - } - Ok(Response::new(InfoResponse { - model_name: "mock".into(), - dimension: 4, - max_seq_length: 1024, - supported_tasks: vec!["retrieval.passage".into()], - device: "cpu".into(), - model_loaded: self.model_loaded, - uptime_seconds: 1.0, - weight_revision: MOCK_WEIGHT_REVISION.to_string(), - })) - } - - async fn graceful_shutdown( - &self, - request: Request, - ) -> Result, Status> { - let _ = request.into_inner(); - self.shutdown_count.fetch_add(1, Ordering::SeqCst); - Ok(Response::new(GracefulShutdownResponse { - drained: true, - message: "mock drained".into(), - })) - } - } - - /// Spawn a mock gRPC server on a loopback TCP port. Returns a handle - /// whose `.stop().await` shuts the server down. TCP is used instead of - /// UDS to avoid implementing tonic's `Connected` trait for UnixStream — - /// the wire-level gRPC contract is the same either way. 
- pub async fn spawn(service: MockService) -> MockServer { - let (tx, rx) = oneshot::channel(); - let shutdown_count = service.shutdown_count.clone(); - let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind tcp"); - let addr = listener.local_addr().expect("addr"); - let incoming = TcpListenerStream::new(listener); - - let svc = EmbeddingServiceServer::new(service); - let server = Server::builder().add_service(svc); - - let handle = tokio::spawn(async move { - server - .serve_with_incoming_shutdown(incoming, async { - let _ = rx.await; - }) - .await - .expect("server exited cleanly"); - }); - - // Poll the bound addr until the server is actually accepting. A fixed - // sleep here was the textbook flaky-CI pattern — 50ms worked on this - // laptop but nothing guarantees it holds under a loaded runner. - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); - loop { - if tokio::net::TcpStream::connect(addr).await.is_ok() { - break; - } - if tokio::time::Instant::now() >= deadline { - panic!("mock gRPC server did not accept connections on {addr} within 2s"); - } - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - } - - MockServer { - tcp_url: format!("http://{addr}"), - shutdown: Some(tx), - handle, - shutdown_count, - } - } - - impl MockServer { - pub async fn stop(mut self) { - if let Some(tx) = self.shutdown.take() { - let _ = tx.send(()); - } - // Surface server-task panics. The spawn closure does - // `.expect("server exited cleanly")`; if that fires, we want the - // test to fail loudly — silently swallowing JoinError used to - // let mock-side panics masquerade as passing tests. - self.handle.await.expect("mock server task panicked"); - } - } -} - -/// Shared helper: build a client against the mock server's TCP URL. Every -/// integration test needs the same `EmbeddingClientConfig { endpoint: Tcp(…), -/// ..Default::default() }` dance, so centralize it here. -async fn connect_tcp_client(url: String) -> weaver_spu::encoder::grpc_client_legacy::EmbeddingClient { - use weaver_spu::encoder::grpc_client_legacy::{ - EmbeddingClient, EmbeddingClientConfig, EmbeddingEndpoint, - }; - EmbeddingClient::connect(EmbeddingClientConfig { - endpoint: EmbeddingEndpoint::Tcp(url), - ..Default::default() - }) - .await - .expect("connect to mock TCP endpoint") -} - -#[tokio::test] -async fn test_info_roundtrip() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let info = client.info().await.unwrap(); - assert_eq!(info.model_name, "mock"); - assert!(info.model_loaded); - assert_eq!(info.dimension, 4); - // Positive-path roundtrip check for the SG5-C cohort-identity field. - // If wire-format generation drops the new field, the mock's non-empty - // sentinel will surface as an empty string on the client side. 
- assert_eq!(info.weight_revision, MOCK_WEIGHT_REVISION); - - srv.stop().await; -} - -#[tokio::test] -async fn test_embed_roundtrip() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let result = client - .embed( - &["alpha".into(), "bravo-bravo".into()], - "retrieval.passage", - None, - ) - .await - .unwrap(); - assert_eq!(result.embeddings.len(), 2); - assert_eq!(result.dimension, 4); - assert_eq!(result.embeddings[0][0], 5.0); - assert_eq!(result.embeddings[1][0], 11.0); - - srv.stop().await; -} - -#[tokio::test] -async fn test_embed_late_chunked_roundtrip() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - // 43 chars — the mock splits at half_chars = 21, so spans/tokens are - // deterministic. Assert the full span contract so the proto field - // mapping stays honest in both directions. - let text = "the quick brown fox jumps over the lazy dog"; - assert_eq!(text.chars().count(), 43); - let result = client - .embed_late_chunked(text, "retrieval.passage") - .await - .unwrap(); - assert_eq!(result.chunks.len(), 2); - assert_eq!(result.dimension, 4); - assert_eq!(result.context_window_used, 10); // 43 / 4 - - let c0 = &result.chunks[0]; - assert_eq!(c0.chunk_index, 0); - assert_eq!(c0.total_chunks, 2); - assert_eq!(c0.start_char, 0); - assert_eq!(c0.end_char, 21); - assert_eq!(c0.start_token, 0); - assert_eq!(c0.end_token, 5); // 21 / 4 - assert_eq!(c0.embedding.len(), 4); - - let c1 = &result.chunks[1]; - assert_eq!(c1.chunk_index, 1); - assert_eq!(c1.total_chunks, 2); - assert_eq!(c1.start_char, 21); - assert_eq!(c1.end_char, 43); - assert_eq!(c1.start_token, 5); - assert_eq!(c1.end_token, 10); // 43 / 4 - assert_eq!(c1.embedding.len(), 4); - - srv.stop().await; -} - -/// Server reports `dimension = 0` but still ships non-empty per-chunk vectors. -/// The client must reject this as an inconsistent response rather than pass -/// "dimension 0" metadata downstream alongside populated embeddings. 
-#[tokio::test] -async fn test_embed_late_chunked_rejects_zero_dimension_with_values() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: Some(0), - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let err = client - .embed_late_chunked("the quick brown fox jumps", "retrieval.passage") - .await - .expect_err("dimension=0 with non-empty chunks should be rejected"); - assert!( - matches!(err, EmbeddingError::InvalidResponse(ref m) if m.contains("dimension is 0")), - "expected InvalidResponse with dimension=0 diagnostic, got: {err:?}" - ); - - srv.stop().await; -} - -#[tokio::test] -async fn test_ensure_ready_ok_when_loaded() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let info = client - .ensure_ready() - .await - .expect("service should be ready"); - assert!(info.model_loaded); - - srv.stop().await; -} - -#[tokio::test] -async fn test_ensure_ready_rejects_unloaded_model() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: false, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let err = client - .ensure_ready() - .await - .expect_err("unloaded model should error"); - assert!(matches!(err, EmbeddingError::NotAvailable(_))); - - srv.stop().await; -} - -#[tokio::test] -async fn test_graceful_shutdown_roundtrip() { - use mock_server::{MockService, spawn}; - use std::sync::{ - Arc, - atomic::{AtomicU32, Ordering}, - }; - use std::time::Duration; - - let counter = Arc::new(AtomicU32::new(0)); - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: counter.clone(), - dimension_override: None, - info_error: None, - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let summary = client - .graceful_shutdown(Duration::from_millis(250)) - .await - .unwrap(); - assert!(summary.drained); - assert!(summary.message.contains("drained")); - assert_eq!(counter.load(Ordering::SeqCst), 1); - - srv.stop().await; -} - -// ----------------------------------------------------------------------------- -// `ensure_ready` Transport-downcast regression tests (PR-0.5.B). -// -// PR-0.5.B abstracted `EmbeddingError` over transport via -// `Transport(Box)`. `EmbeddingClient::ensure_ready` -// downcasts the boxed error back to concrete tonic types -// (`tonic::Status`, `tonic::transport::Error`) so that -// UNAVAILABLE/DEADLINE_EXCEEDED/FAILED_PRECONDITION codes and raw -// transport failures both surface as `EmbeddingError::NotAvailable` (the -// shape callers pattern-match on). These tests pin that downcast + -// remap behavior so refactors of the abstract `Transport` payload -// can't silently break it. 
-// ----------------------------------------------------------------------------- - -#[tokio::test] -async fn test_ensure_ready_maps_unavailable_status_to_not_available() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use tonic::Status; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: Some(Status::unavailable("backend warming up")), - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let err = client.ensure_ready().await.unwrap_err(); - match err { - EmbeddingError::NotAvailable(msg) => { - assert!( - msg.contains("Unavailable") || msg.contains("warming up"), - "msg = {msg:?}" - ); - } - other => panic!("expected NotAvailable, got {other:?}"), - } - - srv.stop().await; -} - -#[tokio::test] -async fn test_ensure_ready_maps_deadline_exceeded_status_to_not_available() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use tonic::Status; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: Some(Status::deadline_exceeded("timed out")), - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let err = client.ensure_ready().await.unwrap_err(); - assert!( - matches!(err, EmbeddingError::NotAvailable(_)), - "expected NotAvailable for DEADLINE_EXCEEDED, got {err:?}" - ); - - srv.stop().await; -} - -#[tokio::test] -async fn test_ensure_ready_maps_failed_precondition_status_to_not_available() { - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use tonic::Status; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: Some(Status::failed_precondition("model not loaded yet")), - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let err = client.ensure_ready().await.unwrap_err(); - assert!( - matches!(err, EmbeddingError::NotAvailable(_)), - "expected NotAvailable for FAILED_PRECONDITION, got {err:?}" - ); - - srv.stop().await; -} - -#[tokio::test] -async fn test_ensure_ready_propagates_other_status_codes_as_transport() { - // Negative case: a non-UNAVAILABLE-class status code (e.g. INTERNAL) - // should propagate as `Transport`, not get re-mapped to NotAvailable. - // This pins the discrimination logic — callers shouldn't see - // arbitrary backend errors masked as "service is briefly down." 
- use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use tonic::Status; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: Some(Status::internal("oops")), - }) - .await; - - let client = connect_tcp_client(srv.tcp_url.clone()).await; - let err = client.ensure_ready().await.unwrap_err(); - assert!( - matches!(err, EmbeddingError::Transport(_)), - "expected Transport for INTERNAL, got {err:?}" - ); - - srv.stop().await; -} - -#[tokio::test] -async fn test_ensure_ready_maps_dropped_server_to_not_available() { - // Raw `tonic::transport::Error` path: spin up the server, connect, - // stop the server, then call `ensure_ready`. The follow-up RPC - // fails at the transport layer (server gone), exercising the - // downcast-to-`tonic::transport::Error` arm in - // `ensure_ready`'s match. - use mock_server::{MockService, spawn}; - use std::sync::{Arc, atomic::AtomicU32}; - use weaver_spu::encoder::grpc_client_legacy::EmbeddingError; - - let srv = spawn(MockService { - model_loaded: true, - shutdown_count: Arc::new(AtomicU32::new(0)), - dimension_override: None, - info_error: None, - }) - .await; - - let tcp_url = srv.tcp_url.clone(); - let client = connect_tcp_client(tcp_url).await; - // First call succeeds — confirms client + server are healthy. - client.info().await.expect("info() succeeds while server is up"); - - // Drop the server, then issue ensure_ready. Whether the next call - // surfaces as a tonic::transport::Error or an UNAVAILABLE Status - // depends on tonic's internal state machine, but BOTH paths must - // map to NotAvailable per ensure_ready's contract. - srv.stop().await; - let err = client.ensure_ready().await.unwrap_err(); - assert!( - matches!(err, EmbeddingError::NotAvailable(_)), - "expected NotAvailable after server drop, got {err:?}" - ); -} diff --git a/proto/persephone/common/common.proto b/proto/persephone/common/common.proto deleted file mode 100644 index 099c2da5..00000000 --- a/proto/persephone/common/common.proto +++ /dev/null @@ -1,36 +0,0 @@ -// Common types shared across Persephone services. - -syntax = "proto3"; - -package persephone.common; - -// Text chunking strategy. -enum ChunkingStrategy { - // Late chunking: encode full document, pool token embeddings per chunk. - // Preserves full document context in every chunk embedding. - CHUNKING_STRATEGY_LATE = 0; - - // Semantic chunking: split on paragraph/sentence boundaries. - CHUNKING_STRATEGY_SEMANTIC = 1; - - // Sliding window: fixed token window with overlap. - CHUNKING_STRATEGY_SLIDING = 2; - - // Token-based: split at token boundaries with configurable size. - CHUNKING_STRATEGY_TOKEN = 3; -} - -// Metadata for a single text chunk within a document. -message ChunkMetadata { - // Zero-based index of this chunk within the document. - uint32 chunk_index = 1; - - // Total number of chunks in the document. - uint32 total_chunks = 2; - - // Start character offset in the original document text. - uint64 start_char = 3; - - // End character offset (exclusive) in the original document text. - uint64 end_char = 4; -} diff --git a/proto/persephone/embedding/embedding.proto b/proto/persephone/embedding/embedding.proto deleted file mode 100644 index 290a9f90..00000000 --- a/proto/persephone/embedding/embedding.proto +++ /dev/null @@ -1,165 +0,0 @@ -// Persephone Embedding Service — vector embedding generation. 
-// -// Wraps the Jina V4 FP16 embedder behind a typed gRPC contract. The service -// runs on a Unix socket (/run/weaver/embedder.sock) and manages GPU memory -// lifecycle internally. -// -// The embedder is harness-scoped mandatory infrastructure (same tier as -// ArangoDB CE) — one embedder, pinned to Jina V4 FP16, not swappable. The -// harness owns lifecycle via systemd BindsTo=. - -syntax = "proto3"; - -package persephone.embedding; - -// Embedding service for generating vector representations of text. -service EmbeddingService { - // Embed a batch of texts into vectors. - rpc Embed(EmbedRequest) returns (EmbedResponse); - - // Embed a document with late chunking — returns per-chunk embeddings whose - // contextual representation reflects the whole document, not just the chunk. - rpc EmbedLateChunked(EmbedLateChunkedRequest) returns (EmbedLateChunkedResponse); - - // Query provider capabilities and model metadata. - rpc Info(InfoRequest) returns (InfoResponse); - - // Gracefully drain in-flight requests and shut down the service. - rpc GracefulShutdown(GracefulShutdownRequest) returns (GracefulShutdownResponse); -} - -// Request to embed one or more text strings. -message EmbedRequest { - // Text strings to embed. - repeated string texts = 1; - - // Embedding task label — controls model behavior. - // Common values: "retrieval.passage", "retrieval.query", "code", "text-matching" - string task = 2; - - // Optional batch size hint for the provider. - // The provider may ignore this and use its own default. - uint32 batch_size = 3; -} - -// Response containing computed embeddings. -message EmbedResponse { - // One Embedding per input text, in the same order. - repeated Embedding embeddings = 1; - - // Model identifier used to generate the embeddings. - string model = 2; - - // Embedding vector dimension. - uint32 dimension = 3; - - // Wall-clock time for the embedding computation in milliseconds. - uint64 duration_ms = 4; -} - -// Request for late-chunked embedding of a full document. -message EmbedLateChunkedRequest { - // Full document text. Late chunking encodes the whole document once, - // then derives per-chunk embeddings from the contextual token states. - string text = 1; - - // Embedding task label (see EmbedRequest.task). - string task = 2; -} - -// Response containing per-chunk embeddings produced by late chunking. -message EmbedLateChunkedResponse { - // One chunk per span, in document order. - repeated LateChunk chunks = 1; - - // Model identifier used to generate the embeddings. - string model = 2; - - // Embedding vector dimension. - uint32 dimension = 3; - - // Wall-clock time in milliseconds. - uint64 duration_ms = 4; - - // Number of tokens used in the full-document context window. - uint32 context_window_used = 5; -} - -// A single late-chunked span: the chunk text, its span in the source document, -// and its context-aware embedding. -message LateChunk { - // Chunk text (substring of the source document). - string text = 1; - - // Context-aware embedding for this chunk. - Embedding embedding = 2; - - // Character-level span within the source document. - uint32 start_char = 3; - uint32 end_char = 4; - - // Token-level span within the source document. - uint32 start_token = 5; - uint32 end_token = 6; - - // Position of this chunk among the document's chunks. - uint32 chunk_index = 7; - uint32 total_chunks = 8; -} - -// A single embedding vector. -message Embedding { - // The embedding values (length == dimension from EmbedResponse). 
- repeated float values = 1; -} - -// Empty request for the Info RPC. -message InfoRequest {} - -// Provider capabilities and model metadata. -message InfoResponse { - // Model identifier (e.g. "jinaai/jina-embeddings-v4"). - string model_name = 1; - - // Output embedding dimension (e.g. 2048). - uint32 dimension = 2; - - // Maximum input sequence length in tokens. - uint32 max_seq_length = 3; - - // Supported task labels. - repeated string supported_tasks = 4; - - // Device the model is loaded on (e.g. "cuda:0", "cpu"). - string device = 5; - - // Whether the model is currently loaded in memory. - bool model_loaded = 6; - - // Service uptime in seconds. - double uptime_seconds = 7; - - // HuggingFace snapshot revision SHA for the loaded weights — read from - // /models--<org>--<name>/refs/main at load time. Empty if - // unknown (offline load, non-HF source, corrupt cache). The harness - // refuse-to-start check at /opt/weaver/state/embedder.pin.json compares - // this against the pinned value — a mismatch means the weights changed - // under a previously-pinned model_name and the corpus must be re-embedded. - string weight_revision = 8; -} - -// Request to gracefully shut down the service. -message GracefulShutdownRequest { - // Maximum time to wait for in-flight requests to drain before - // forcing shutdown. 0 = drain without timeout. - uint32 drain_timeout_ms = 1; -} - -// Response from a graceful shutdown request. -message GracefulShutdownResponse { - // True if the service drained cleanly; false if the drain timeout expired. - bool drained = 1; - - // Human-readable summary (e.g. "drained 3 in-flight requests in 412ms"). - string message = 2; -} diff --git a/proto/persephone/extraction/extraction.proto b/proto/persephone/extraction/extraction.proto deleted file mode 100644 index 119980fc..00000000 --- a/proto/persephone/extraction/extraction.proto +++ /dev/null @@ -1,143 +0,0 @@ -// Persephone Extraction Service — document content extraction. -// -// Wraps document extraction backends (Docling VLM, LaTeX parser, -// TreeSitter AST) behind a typed gRPC contract. The service -// handles source type routing internally. - -syntax = "proto3"; - -package persephone.extraction; - -// Extraction service for converting documents into structured text. -service ExtractionService { - // Extract structured content from a document. - rpc Extract(ExtractRequest) returns (ExtractResponse); - - // Query extractor capabilities (supported formats, features). - rpc Capabilities(CapabilitiesRequest) returns (ExtractorInfo); -} - -// Source document type. -enum SourceType { - // Auto-detect from file extension. - SOURCE_TYPE_UNKNOWN = 0; - - // PDF document (routes to Docling VLM). - SOURCE_TYPE_PDF = 1; - - // LaTeX source archive (.tar.gz). - SOURCE_TYPE_LATEX = 2; - - // Source code file (routes to TreeSitter AST). - SOURCE_TYPE_CODE = 3; - - // Markdown document. - SOURCE_TYPE_MARKDOWN = 4; - - // Plain text. - SOURCE_TYPE_TEXT = 5; -} - -// Request to extract content from a document. -message ExtractRequest { - // Path to the file on the local filesystem. - // The extraction service must have read access to this path. - string file_path = 1; - - // Optional raw content bytes (alternative to file_path). - // When set, file_path is used only for type detection. - bytes content = 2; - - // Source type hint. If UNKNOWN, auto-detect from file extension. - SourceType source_type = 3; - - // Whether to extract tables from the document. - bool extract_tables = 4; - - // Whether to extract equations from the document.
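
(Editorial aside on the EmbeddingService contract whose deletion closes just above, before extraction.proto begins: the removed proto defined a small, easily-driven API. The sketch below is a minimal, hypothetical Python client against the stubs the deleted regen script generated — it assumes grpcio plus the generated persephone/embedding modules, and is illustrative only; the repo's real consumer was the Rust grpc_client_legacy.)

    # Hypothetical client sketch for the removed EmbeddingService.
    # Assumes stubs generated into persephone/embedding by the deleted
    # services/embedder regen script; not code that shipped in this repo.
    import grpc
    from persephone.embedding import embedding_pb2, embedding_pb2_grpc

    channel = grpc.insecure_channel("unix:///run/weaver/embedder.sock")
    stub = embedding_pb2_grpc.EmbeddingServiceStub(channel)

    # Batch embedding: one Embedding per input text, in the same order.
    reply = stub.Embed(embedding_pb2.EmbedRequest(
        texts=["first passage", "second passage"],
        task="retrieval.passage",
    ))
    print(reply.model, reply.dimension, len(reply.embeddings))

    # Late chunking: encode the whole document once, get per-chunk
    # embeddings whose context covers the full document.
    chunked = stub.EmbedLateChunked(embedding_pb2.EmbedLateChunkedRequest(
        text="a long document ...",
        task="retrieval.passage",
    ))
    for chunk in chunked.chunks:
        print(chunk.chunk_index, chunk.start_char, chunk.end_char)
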
- bool extract_equations = 5; - - // Whether to extract images/figures. - bool extract_images = 6; - - // Whether to use OCR for scanned content. - bool use_ocr = 7; -} - -// Extracted content from a document. -message ExtractResponse { - // Full extracted text content. - string full_text = 1; - - // Extracted tables. - repeated Table tables = 2; - - // Extracted equations. - repeated Equation equations = 3; - - // Extracted image references. - repeated ImageRef images = 4; - - // Additional metadata from the extraction process. - map<string, string> metadata = 5; - - // Source type that was used for extraction. - SourceType source_type = 6; -} - -// A table extracted from the document. -message Table { - // Table content as text (may be markdown or CSV formatted). - string content = 1; - - // Table caption or title, if available. - string caption = 2; - - // Zero-based index of this table in the document. - uint32 index = 3; -} - -// An equation extracted from the document. -message Equation { - // Equation content (LaTeX notation). - string latex = 1; - - // Plain text representation, if available. - string text = 2; - - // Zero-based index of this equation in the document. - uint32 index = 3; - - // Whether this is an inline or display equation. - bool is_inline = 4; -} - -// A reference to an extracted image/figure. -message ImageRef { - // Path or identifier for the extracted image. - string path = 1; - - // Figure caption, if available. - string caption = 2; - - // Zero-based index of this image in the document. - uint32 index = 3; -} - -// Empty request for the Capabilities RPC. -message CapabilitiesRequest {} - -// Extractor capabilities and supported formats. -message ExtractorInfo { - // Supported file extensions (e.g. [".pdf", ".tex", ".py"]). - repeated string supported_extensions = 1; - - // Supported source types. - repeated SourceType supported_types = 2; - - // Available features. - repeated string features = 3; - - // Whether GPU is available for extraction. - bool gpu_available = 4; -} diff --git a/proto/persephone/training/training.proto b/proto/persephone/training/training.proto deleted file mode 100644 index 2ca77637..00000000 --- a/proto/persephone/training/training.proto +++ /dev/null @@ -1,245 +0,0 @@ -// Persephone Training Service — RGCN link prediction on knowledge graphs. -// -// The Rust orchestrator drives the training loop: it loads the graph from -// ArangoDB, computes edge splits and negative samples via safetensors IPC, -// and issues per-step RPCs to the Python provider which owns the GPU, model -// parameters, and optimizer state. -// -// Lifecycle: -// 1. InitModel — provider creates encoder + predictor on GPU -// 2. LoadGraph — provider loads safetensors file into GPU memory -// 3. TrainStep* — repeated forward+backward+optimizer steps (per epoch) -// 4. Evaluate — score on val/test edges (no gradients) -// 5. GetEmbeddings — full forward pass, return structural embeddings -// 6. Checkpoint / LoadCheckpoint — save/restore model state - -syntax = "proto3"; - -package persephone.training; - -// --------------------------------------------------------------------------- -// Service -// --------------------------------------------------------------------------- - -// Training service for RGCN link prediction. -service TrainingService { - // Initialize the model with the given architecture config. - // Must be called before LoadGraph or TrainStep. - rpc InitModel(InitModelRequest) returns (InitModelResponse); - - // Load a safetensors graph file into GPU memory.
- // Must be called after InitModel and before TrainStep. - rpc LoadGraph(LoadGraphRequest) returns (LoadGraphResponse); - - // Execute one training step: forward → loss → backward → optimizer.step(). - // Returns the training loss for the step. - rpc TrainStep(TrainStepRequest) returns (TrainStepResponse); - - // Evaluate on a set of edges (val or test) without gradient computation. - // Returns loss, accuracy, and AUC. - rpc Evaluate(EvaluateRequest) returns (EvaluateResponse); - - // Run a full forward pass and return structural embeddings for all nodes. - // No gradients. Used after training completes to export embeddings. - rpc GetEmbeddings(GetEmbeddingsRequest) returns (GetEmbeddingsResponse); - - // Save the current model state (encoder + predictor) to a checkpoint file. - rpc Checkpoint(CheckpointRequest) returns (CheckpointResponse); - - // Load a previously saved checkpoint, restoring model weights. - rpc LoadCheckpoint(LoadCheckpointRequest) returns (LoadCheckpointResponse); -} - -// --------------------------------------------------------------------------- -// InitModel -// --------------------------------------------------------------------------- - -// Model architecture configuration. -message ModelConfig { - // Number of edge relation types (default: 22). - uint32 num_relations = 1; - - // Number of distinct vertex collection types. - uint32 num_collection_types = 2; - - // Hidden layer dimension (default: 256). - uint32 hidden_dim = 3; - - // Output structural embedding dimension (default: 128). - uint32 embed_dim = 4; - - // Number of basis matrices for RGCN decomposition (default: 21). - uint32 num_bases = 5; - - // Dropout rate (default: 0.2). - float dropout = 6; -} - -// Optimizer configuration. -message OptimizerConfig { - // Learning rate (default: 0.01). - float learning_rate = 1; - - // L2 regularization weight (default: 5e-4). - float weight_decay = 2; -} - -message InitModelRequest { - // Model architecture. - ModelConfig model = 1; - - // Optimizer settings. - OptimizerConfig optimizer = 2; - - // Target device (e.g. "cuda:0", "cuda:2", "cpu"). - string device = 3; -} - -message InitModelResponse { - // Total number of trainable parameters. - uint64 num_parameters = 1; - - // Device the model was placed on. - string device = 2; -} - -// --------------------------------------------------------------------------- -// LoadGraph -// --------------------------------------------------------------------------- - -message LoadGraphRequest { - // Absolute path to the safetensors file produced by hades-prefetch. - string safetensors_path = 1; -} - -message LoadGraphResponse { - // Number of nodes loaded. - uint64 num_nodes = 1; - - // Number of edges loaded. - uint64 num_edges = 2; - - // Feature dimension. - uint32 feature_dim = 3; - - // GPU memory used in bytes (approximate). - uint64 gpu_memory_bytes = 4; -} - -// --------------------------------------------------------------------------- -// TrainStep -// --------------------------------------------------------------------------- - -// Training step input. -// -// neg_src and neg_dst are parallel arrays of equal length: each negative -// sample is the pair (neg_src[i], neg_dst[i]). They are sampled globally -// (not per positive edge) and are independent of train_edge_indices. -message TrainStepRequest { - // Edge indices for the positive (training) edges in this step. - // Indices into the edge arrays loaded via LoadGraph. 
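
(Editorial aside on the TrainStepRequest contract documented in the comment above: negatives are sampled globally over node ids, neg_src/neg_dst are parallel arrays, and neither is tied to the positive edge batch. The numpy sketch below shows one way an orchestrator could assemble a step payload under those rules; the uniform sampler and the sizes are illustrative assumptions, not the repo's actual Rust orchestrator.)

    # Sketch: assembling one TrainStep payload per the contract above.
    # num_nodes / num_train_edges would come from LoadGraphResponse and the
    # edge split; uniform global negative sampling is an illustrative choice.
    import numpy as np

    rng = np.random.default_rng(0)
    num_nodes, num_train_edges = 10_000, 50_000
    batch_size, num_negatives = 4_096, 4_096

    # Positive edges: indices into the edge arrays loaded via LoadGraph.
    train_edge_indices = rng.choice(num_train_edges, size=batch_size, replace=False)

    # Negatives: sampled globally over node ids, NOT per positive edge.
    # neg_src[i] and neg_dst[i] together form one negative pair.
    neg_src = rng.integers(0, num_nodes, size=num_negatives)
    neg_dst = rng.integers(0, num_nodes, size=num_negatives)

    step_payload = {
        "train_edge_indices": train_edge_indices.tolist(),
        "neg_src": neg_src.tolist(),
        "neg_dst": neg_dst.tolist(),
    }
    # step_payload maps 1:1 onto TrainStepRequest's three repeated uint32 fields.
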
- repeated uint32 train_edge_indices = 1; - - // Negative sample source node indices (parallel with neg_dst). - repeated uint32 neg_src = 2; - - // Negative sample destination node indices (parallel with neg_src). - repeated uint32 neg_dst = 3; -} - -message TrainStepResponse { - // Binary cross-entropy loss for this step. - float loss = 1; - - // Training accuracy for this step (fraction of correct predictions). - float accuracy = 2; -} - -// --------------------------------------------------------------------------- -// Evaluate -// --------------------------------------------------------------------------- - -message EvaluateRequest { - // Edge indices for the evaluation split (val or test). - repeated uint32 edge_indices = 1; - - // Negative sample source node indices for evaluation. - repeated uint32 neg_src = 2; - - // Negative sample destination node indices for evaluation. - repeated uint32 neg_dst = 3; -} - -message EvaluateResponse { - // Binary cross-entropy loss. - float loss = 1; - - // Classification accuracy (threshold 0.5). - float accuracy = 2; - - // Area under the ROC curve. - float auc = 3; -} - -// --------------------------------------------------------------------------- -// GetEmbeddings -// --------------------------------------------------------------------------- - -message GetEmbeddingsRequest { - // Optional: path to write embeddings as a safetensors file. - // If empty, embeddings are returned inline in the response. - string output_path = 1; -} - -message GetEmbeddingsResponse { - // Number of nodes embedded. - uint64 num_nodes = 1; - - // Embedding dimension. - uint32 embed_dim = 2; - - // Structural embeddings [N * embed_dim] flattened, F32. - // Only populated if output_path was empty in the request. - // For large graphs, prefer file output to avoid gRPC message size limits. - bytes embeddings = 3; - - // Path where embeddings were written (if output_path was set). - string output_path = 4; -} - -// --------------------------------------------------------------------------- -// Checkpoint -// --------------------------------------------------------------------------- - -message CheckpointRequest { - // Path to save the checkpoint file (.pt). - string path = 1; -} - -message CheckpointResponse { - // Path where the checkpoint was saved. - string path = 1; - - // Size of the checkpoint file in bytes. - uint64 size_bytes = 2; -} - -message LoadCheckpointRequest { - // Path to the checkpoint file to load. - string path = 1; - - // Target device to load onto (e.g. "cuda:0"). - // If empty, uses the device from InitModel. - string device = 2; -} - -message LoadCheckpointResponse { - // Model config from the loaded checkpoint. - ModelConfig model_config = 1; - - // Number of trainable parameters. - uint64 num_parameters = 2; - - // Device the model was loaded onto (e.g. "cuda:0"). - string device = 3; -} diff --git a/scripts/install.sh b/scripts/install.sh index 4af45f3e..3119d285 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -39,10 +39,9 @@ LIB_SRC=$(ls -dt "$RELEASE_DIR"/build/llama-cpp-sys-2-*/out/build/bin 2>/dev/nul command -v patchelf >/dev/null || die "patchelf not installed (apt install patchelf)" # Required system identity. weaver-admin owns the runtime tree + state dir -# and runs BOTH weaver-infer.service and weaver-embedder.service under the -# single-user collapse (weaver-embedder is no longer a distinct system user; -# the dedicated-user split was dropped when we consolidated on shared-socket -# IPC between the two services). 
Provisioned out-of-band. Fail fast with a +# and runs the harness services. Post-PR-1.J the embedder is in-process +# inside weaver-infer; the dedicated weaver-embedder.service retired +# alongside the Python service. Provisioned out-of-band. Fail fast with a # clear message instead of a cryptic `install -o: invalid user` later. id -u weaver-admin >/dev/null 2>&1 || die "weaver-admin user missing — see docs/infrastructure/ for provisioning" getent group weaver-admin >/dev/null 2>&1 || die "weaver-admin group missing — see docs/infrastructure/ for provisioning" @@ -195,23 +194,27 @@ sudo install -m 0755 -o root -g root \ # top of this script, so we can install unconditionally here. echo "==> Installing systemd units" UNIT_DIR=/etc/systemd/system -sudo install -m 644 "$REPO_ROOT/services/weaver-infer.service" "$UNIT_DIR/weaver-infer.service" -sudo install -m 644 "$REPO_ROOT/services/weaver-daemon.service" "$UNIT_DIR/weaver-daemon.service" -sudo install -m 644 "$REPO_ROOT/services/weaver.target" "$UNIT_DIR/weaver.target" -sudo install -m 644 "$REPO_ROOT/services/embedder/weaver-embedder.service" "$UNIT_DIR/weaver-embedder.service" +sudo install -m 644 "$REPO_ROOT/services/weaver-infer.service" "$UNIT_DIR/weaver-infer.service" +sudo install -m 644 "$REPO_ROOT/services/weaver-daemon.service" "$UNIT_DIR/weaver-daemon.service" +sudo install -m 644 "$REPO_ROOT/services/weaver.target" "$UNIT_DIR/weaver.target" +# Post-PR-1.J the external Python `weaver-embedder.service` retired — +# the embedder runs in-process inside `weaver serve`. If a previous +# install left the unit on disk, remove it so the harness target's +# Wants= chain doesn't reference a stale unit. +sudo rm -f "$UNIT_DIR/weaver-embedder.service" sudo systemctl daemon-reload echo "==> Installed:" ls -l "$PREFIX/bin/weaver" "$PREFIX/lib/" echo "systemd units in $UNIT_DIR:" -ls -l "$UNIT_DIR"/weaver-infer.service "$UNIT_DIR"/weaver-daemon.service "$UNIT_DIR"/weaver.target "$UNIT_DIR"/weaver-embedder.service +ls -l "$UNIT_DIR"/weaver-infer.service "$UNIT_DIR"/weaver-daemon.service "$UNIT_DIR"/weaver.target echo "==> Smoke test as weaver-admin:" sudo -u weaver-admin "$PREFIX/bin/weaver" --version 2>&1 || true # weaver.target is a synthetic on-demand target (no [Install] section), so it -# cannot be enabled — only started. weaver-embedder.service, weaver-infer.service, -# and weaver-daemon.service are pulled in when the target starts. +# cannot be enabled — only started. weaver-infer.service and +# weaver-daemon.service are pulled in when the target starts. cat < Next steps: @@ -219,9 +222,10 @@ cat </dev/null && pwd)" -REPO_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)" -PROTO_DIR="$REPO_ROOT/proto" -OUT_DIR="$SCRIPT_DIR/core/proto" -PY="$SCRIPT_DIR/.venv/bin/python" - -if [[ ! -x "$PY" ]]; then - echo "embedder venv not found at $PY — create it first (see services/embedder/README)" >&2 - exit 1 -fi - -mkdir -p "$OUT_DIR/persephone/embedding" -touch "$OUT_DIR/__init__.py" \ - "$OUT_DIR/persephone/__init__.py" \ - "$OUT_DIR/persephone/embedding/__init__.py" - -"$PY" -m grpc_tools.protoc \ - --proto_path="$PROTO_DIR" \ - --python_out="$OUT_DIR" \ - --grpc_python_out="$OUT_DIR" \ - "$PROTO_DIR/persephone/embedding/embedding.proto" - -# Patch the generated grpc module to use a relative import so the generated -# tree works without adding the output root to sys.path. -GRPC_FILE="$OUT_DIR/persephone/embedding/embedding_pb2_grpc.py" -sed -i 's|^from persephone.embedding import embedding_pb2 as |from . 
import embedding_pb2 as |' "$GRPC_FILE" - -echo "Regenerated gRPC stubs at $OUT_DIR" diff --git a/services/embedder/core/__init__.py b/services/embedder/core/__init__.py deleted file mode 100644 index 614bed7b..00000000 --- a/services/embedder/core/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Minimal core package for the embedder service. diff --git a/services/embedder/core/cli/__init__.py b/services/embedder/core/cli/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/services/embedder/core/cli/config.py b/services/embedder/core/cli/config.py deleted file mode 100644 index 6138a889..00000000 --- a/services/embedder/core/cli/config.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Embedder service configuration. - -Reads the `[embedder]` block from the harness `server.toml` — the single -source of truth shared with `weaver-inference` (see -`crates/weaver-inference/src/multi_model.rs::EmbedderConfig`). - -Resolution order (highest priority first): - 1. Environment variables (WEAVER_EMBEDDER_*) - 2. `[embedder]` block in server.toml (located via WEAVER_SERVER_TOML or - the default `/opt/weaver/server.toml`) - 3. Built-in defaults matching the Rust `EmbedderConfig` defaults - -weaver.yaml is no longer consulted for embedder fields. The `embedding:` -block was removed when the config folded into server.toml; remaining -weaver.yaml sections (database, vector_index, search, etc.) belong to -the HADES subsystem of WeaverTools and will migrate into server.toml as -each subsystem lands in the harness. -""" - -from __future__ import annotations - -import logging -import os -import tomllib -from pathlib import Path -from typing import Any - -logger = logging.getLogger(__name__) - -DEFAULT_SERVER_TOML = Path("/opt/weaver/server.toml") - -# Defaults mirror `crates/weaver-inference/src/multi_model.rs`. Keeping them -# in sync with the Rust side is load-bearing — the two config readers must -# agree when the TOML omits an optional field. -_DEFAULT_SOCKET = "/run/weaver/embedder.sock" -_DEFAULT_MODEL_NAME = "jinaai/jina-embeddings-v4" -_DEFAULT_BATCH_SIZE = 48 -_DEFAULT_IDLE_TIMEOUT = 0 -_DEFAULT_USE_FP16 = True - - -def _server_toml_path() -> Path: - override = os.environ.get("WEAVER_SERVER_TOML") - return Path(override) if override else DEFAULT_SERVER_TOML - - -def _load_server_toml(path: Path | None = None) -> dict[str, Any]: - """Load the `[embedder]` block from server.toml. Returns {} on miss.""" - target = path or _server_toml_path() - if not target.exists(): - logger.debug(f"server.toml not found at {target}; using defaults + env") - return {} - try: - with open(target, "rb") as f: - loaded = tomllib.load(f) - except tomllib.TOMLDecodeError as e: - raise ValueError(f"Invalid TOML in {target}: {e}") from e - embedder = loaded.get("embedder", {}) - if not isinstance(embedder, dict): - raise TypeError( - f"{target}: [embedder] must be a table, got {type(embedder).__name__}" - ) - return embedder - - -def _parse_bool_env(name: str, default: bool) -> bool: - raw = os.environ.get(name) - if raw is None: - return default - return raw.strip().lower() in ("true", "1", "yes") - - -def get_embedder_service_config(path: Path | None = None) -> dict[str, Any]: - """Resolve embedder service configuration. - - Env vars override server.toml; server.toml overrides built-in defaults. - Returns the dict shape consumed by `embedder_grpc._load_model` and - `_serve`: {device, model_name, use_fp16, batch_size, idle_timeout, - socket}. 
- """ - tconf = _load_server_toml(path) - - # Device — env > toml (derived from `gpu` ordinal) > fallback - device = os.environ.get("WEAVER_EMBEDDER_DEVICE") - if not device: - gpu = tconf.get("gpu") - if isinstance(gpu, int): - device = f"cuda:{gpu}" - if not device: - device = "cuda:2" - - # Honor an explicit disable. No yaml fallback any more — if the operator - # wants CPU, WEAVER_USE_GPU=false is the knob. - use_gpu_env = os.environ.get("WEAVER_USE_GPU") - if use_gpu_env is not None and use_gpu_env.strip().lower() in ("false", "0", "no"): - device = "cpu" - - batch_size_env = os.environ.get("WEAVER_EMBEDDER_BATCH_SIZE") - if batch_size_env: - try: - batch_size = int(batch_size_env) - except ValueError as e: - raise ValueError( - f"WEAVER_EMBEDDER_BATCH_SIZE must be an integer, got: {batch_size_env}" - ) from e - else: - batch_size = int(tconf.get("batch_size", _DEFAULT_BATCH_SIZE)) - if batch_size <= 0: - raise ValueError(f"batch_size must be a positive integer, got: {batch_size}") - - idle_timeout_env = os.environ.get("WEAVER_EMBEDDER_IDLE_TIMEOUT") - if idle_timeout_env: - try: - idle_timeout = int(idle_timeout_env) - except ValueError as e: - raise ValueError( - f"WEAVER_EMBEDDER_IDLE_TIMEOUT must be an integer, got: {idle_timeout_env}" - ) from e - else: - idle_timeout = int(tconf.get("idle_timeout_seconds", _DEFAULT_IDLE_TIMEOUT)) - - model_name = ( - os.environ.get("WEAVER_EMBEDDER_MODEL") - or tconf.get("model_name") - or _DEFAULT_MODEL_NAME - ) - - use_fp16 = _parse_bool_env( - "WEAVER_EMBEDDER_FP16", - default=bool(tconf.get("use_fp16", _DEFAULT_USE_FP16)), - ) - - socket = ( - os.environ.get("WEAVER_EMBEDDER_SOCKET") - or tconf.get("socket") - or _DEFAULT_SOCKET - ) - - return { - "device": device, - "model_name": model_name, - "use_fp16": use_fp16, - "batch_size": batch_size, - "idle_timeout": idle_timeout, - "socket": socket, - } diff --git a/services/embedder/core/config/weaver.yaml b/services/embedder/core/config/weaver.yaml deleted file mode 100644 index a58a2b0e..00000000 --- a/services/embedder/core/config/weaver.yaml +++ /dev/null @@ -1,125 +0,0 @@ -# HADES Configuration -# =================== -# Default configuration for the HADES subsystem of WeaverTools. Values can -# be overridden by environment variables or CLI arguments. -# -# Override priority (highest wins): -# 1. CLI arguments (--gpu, --limit, etc.) -# 2. Environment variables (WEAVER_*, ARANGO_*) -# 3. This config file -# -# MIGRATION NOTE (2026-04-20): -# The `embedding:` block used to live here but now belongs to the harness -# `server.toml` under `[embedder]` (single source of truth with the Rust -# MultiModelConfig). The Python embedder service reads server.toml via -# tomllib — see `core/cli/config.py`. -# -# Remaining sections (database, vector_index, search, rocchio, sync, -# arxiv, logging) stage for migration into server.toml as each HADES -# subsystem gets absorbed into the harness. Until then they stay here as -# the HADES subsystem's own config. 
- -# ----------------------------------------------------------------------------- -# Database Configuration -# ----------------------------------------------------------------------------- -database: - # ArangoDB connection settings - host: localhost - port: 8529 - database: NestedLearning - username: root - # password: Set via ARANGO_PASSWORD environment variable (required) - - # Unix socket paths for high-performance local connections - # These override host:port when set - sockets: - readonly: /run/weaver/readonly/arangod.sock - readwrite: /run/weaver/readwrite/arangod.sock - -# ----------------------------------------------------------------------------- -# Embedding Service Configuration — MIGRATED to server.toml [embedder] -# ----------------------------------------------------------------------------- -# The embedder block moved to /opt/weaver/server.toml under [embedder]. See -# crates/weaver-inference/src/multi_model.rs::EmbedderConfig for the schema -# and services/embedder/core/cli/config.py for the Python reader. -# Env overrides (WEAVER_EMBEDDER_*) still apply. - -# ----------------------------------------------------------------------------- -# GPU Configuration -# ----------------------------------------------------------------------------- -# Multi-GPU systems should assign specific GPUs for different workloads: -# - Training workloads: cuda:0, cuda:1 (high-memory GPUs) -# - Inference/embedding: cuda:2 (dedicated inference GPU) -# -# This separation prevents embedding operations from interrupting model training. -gpu: - # Default device for inference/embedding operations - # Using explicit GPU index (cuda:2) keeps training GPUs (cuda:0, cuda:1) free - device: cuda:2 - # Use GPU by default - enabled: true - -# ----------------------------------------------------------------------------- -# Vector Index Settings -# ----------------------------------------------------------------------------- -# ArangoDB 3.12+ FAISS-backed vector indexes for server-side ANN search. -# When a vector index exists, `hades db query` uses APPROX_NEAR_COSINE() -# (or L2/innerProduct variant) instead of brute-force. Create with `hades db create-index`. 
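
(Editorial aside on the vector-index comment above: with an ArangoDB 3.12+ vector index in place, `hades db query` switches from brute-force scoring to the server-side APPROX_NEAR_COSINE function. The python-arango sketch below shows the general shape of such a query; the collection and field names are placeholders, and nProbe/metric tuning is assumed to live in the index definition rather than inline.)

    # Sketch: server-side ANN search once a vector index exists (ArangoDB 3.12+).
    # "papers" and "embedding" are illustrative placeholders.
    import os
    from arango import ArangoClient

    db = ArangoClient(hosts="http://localhost:8529").db(
        "NestedLearning", username="root", password=os.environ["ARANGO_PASSWORD"]
    )

    aql = """
    FOR doc IN papers
      SORT APPROX_NEAR_COSINE(doc.embedding, @query_vec) DESC
      LIMIT @k
      RETURN { key: doc._key, title: doc.title }
    """
    cursor = db.aql.execute(aql, bind_vars={"query_vec": [0.0] * 2048, "k": 10})
    for row in cursor:
        print(row)
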
-vector_index: - # Default nProbe for ANN search (higher = better recall, slower) - default_n_probe: 10 - # Default metric (cosine, l2, innerProduct) - metric: cosine - # Auto-calculate nLists from collection size (N/15) - auto_n_lists: true - -# ----------------------------------------------------------------------------- -# Search Defaults -# ----------------------------------------------------------------------------- -search: - # Default number of results - limit: 10 - # Maximum allowed results - max_limit: 100 - # Hybrid search settings - hybrid: - # Weight for vector similarity (0-1) - vector_weight: 0.7 - # Weight for keyword matching (0-1) - keyword_weight: 0.3 - -# ----------------------------------------------------------------------------- -# Relevance Feedback (Rocchio Algorithm) -# ----------------------------------------------------------------------------- -rocchio: - # Weight for original query - alpha: 1.0 - # Weight for positive exemplars - beta: 0.75 - # Weight for negative exemplars - gamma: 0.15 - -# ----------------------------------------------------------------------------- -# Sync Settings -# ----------------------------------------------------------------------------- -sync: - # Default lookback period for incremental sync (days) - default_lookback_days: 7 - # Default batch size for embedding during sync - batch_size: 8 - # Default max results per sync - max_results: 1000 - -# ----------------------------------------------------------------------------- -# ArXiv Data Paths -# ----------------------------------------------------------------------------- -arxiv: - pdf_base_path: /bulk-store/arxiv-data/pdf - latex_base_path: /bulk-store/arxiv-data/src - -# ----------------------------------------------------------------------------- -# Logging -# ----------------------------------------------------------------------------- -logging: - level: INFO - format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" diff --git a/services/embedder/core/embedders/__init__.py b/services/embedder/core/embedders/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/services/embedder/core/embedders/embedders_base.py b/services/embedder/core/embedders/embedders_base.py deleted file mode 100644 index 6372ddb3..00000000 --- a/services/embedder/core/embedders/embedders_base.py +++ /dev/null @@ -1,168 +0,0 @@ -#!/usr/bin/env python3 -""" -Base Embedder Interface - -Defines the contract for all embedding implementations in HADES. -Embedders transform text content into vector representations while -preserving semantic relationships for similarity search. -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any - -import numpy as np - - -@dataclass -class EmbeddingConfig: - """Configuration for embedders.""" - model_name: str - device: str = "cuda" - batch_size: int = 32 - max_seq_length: int = 8192 - use_fp16: bool = True - chunk_size_tokens: int | None = None - chunk_overlap_tokens: int | None = None - - -class EmbedderBase(ABC): - """ - Abstract base class for all embedders. - - Defines the interface that all embedding implementations must follow - to ensure consistency across different models and approaches. - """ - - def __init__(self, config: EmbeddingConfig | None = None): - """ - Initialize embedder with configuration. 
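
(Editorial aside on the search and rocchio defaults in the weaver.yaml deleted above: the hybrid weights combine vector and keyword scores linearly, and the alpha/beta/gamma values are the standard Rocchio relevance-feedback weights. The numpy sketch below spells out both formulas with the file's defaults; it is an illustration, not code from the repo.)

    # Sketch: hybrid scoring and Rocchio feedback with the yaml defaults above.
    # hybrid:  score = 0.7 * vector_sim + 0.3 * keyword_score
    # rocchio: q' = alpha*q + beta*mean(positives) - gamma*mean(negatives)
    import numpy as np

    def hybrid_score(vector_sim, keyword_score, vector_weight=0.7, keyword_weight=0.3):
        return vector_weight * vector_sim + keyword_weight * keyword_score

    def rocchio(query, positives, negatives, alpha=1.0, beta=0.75, gamma=0.15):
        updated = alpha * query + beta * positives.mean(axis=0) - gamma * negatives.mean(axis=0)
        return updated / (np.linalg.norm(updated) + 1e-8)

    q = np.random.default_rng(0).normal(size=2048).astype(np.float32)
    pos = np.random.default_rng(1).normal(size=(5, 2048)).astype(np.float32)
    neg = np.random.default_rng(2).normal(size=(3, 2048)).astype(np.float32)
    q_next = rocchio(q, pos, neg)  # fed back into the next vector search
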
- - Args: - config: Embedding configuration - """ - self.config = config or EmbeddingConfig(model_name="default") - - @abstractmethod - def embed_texts(self, - texts: list[str], - task: str = "retrieval", - batch_size: int | None = None) -> np.ndarray: - """ - Embed a list of texts. - - Args: - texts: List of texts to embed - task: Task type (retrieval, classification, etc.) - batch_size: Override default batch size - - Returns: - Array of embeddings (N x D) - """ - pass - - @abstractmethod - def embed_single(self, - text: str, - task: str = "retrieval") -> np.ndarray: - """ - Embed a single text. - - Args: - text: Text to embed - task: Task type - - Returns: - Embedding vector (1D array) - """ - pass - - def embed_queries(self, - queries: list[str], - batch_size: int | None = None) -> np.ndarray: - """ - Embed search queries (convenience method). - - Args: - queries: List of search queries - batch_size: Override default batch size - - Returns: - Array of query embeddings - """ - return self.embed_texts(queries, task="retrieval.query", batch_size=batch_size) - - def embed_documents(self, - documents: list[str], - batch_size: int | None = None) -> np.ndarray: - """ - Embed documents for retrieval (convenience method). - - Args: - documents: List of documents - batch_size: Override default batch size - - Returns: - Array of document embeddings - """ - return self.embed_texts(documents, task="retrieval.passage", batch_size=batch_size) - - @property - @abstractmethod - def embedding_dimension(self) -> int: - """Get the dimension of embeddings produced by this model.""" - pass - - @property - @abstractmethod - def max_sequence_length(self) -> int: - """Get the maximum sequence length supported.""" - pass - - @property - def supports_late_chunking(self) -> bool: - """Whether this embedder supports late chunking.""" - return False - - def embed_with_late_chunking(self, text: str) -> list[Any]: - """ - Embed text using late chunking strategy. - - Late chunking embeds the full document first, then extracts chunk - embeddings from the encoded representation, preserving document context. - - Args: - text: Full document text to process - - Returns: - List of ChunkWithEmbedding objects (or equivalent) - - Raises: - NotImplementedError: If embedder doesn't support late chunking - """ - raise NotImplementedError( - f"{self.__class__.__name__} does not support late chunking. " - "Use embed_texts() with pre-chunked text instead." - ) - - @property - def supports_multimodal(self) -> bool: - """Whether this embedder supports multimodal inputs.""" - return False - - def get_model_info(self) -> dict[str, Any]: - """ - Get information about the model. - - Returns: - Dictionary with model metadata - """ - return { - "model_name": self.config.model_name, - "embedding_dimension": self.embedding_dimension, - "max_sequence_length": self.max_sequence_length, - "supports_late_chunking": self.supports_late_chunking, - "supports_multimodal": self.supports_multimodal, - "device": self.config.device, - "use_fp16": self.config.use_fp16 - } diff --git a/services/embedder/core/embedders/embedders_jina.py b/services/embedder/core/embedders/embedders_jina.py deleted file mode 100644 index 73b731b0..00000000 --- a/services/embedder/core/embedders/embedders_jina.py +++ /dev/null @@ -1,1019 +0,0 @@ -#!/usr/bin/env python3 -""" -Jina v4 Embedder with proper API usage, fp16 support, and LATE CHUNKING. - -Late chunking preserves contextual relationships across text boundaries. 
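
(Editorial aside before the Jina implementation below: the EmbedderBase ABC deleted above only pins four members — embed_texts, embed_single, embedding_dimension, max_sequence_length — and everything else has defaults. A toy subclass makes that contract concrete; the hashing "model" is a stand-in invented for illustration, and the flat import assumes the deleted module were still importable.)

    # Sketch: smallest possible EmbedderBase subclass, to show the contract.
    import numpy as np
    from embedders_base import EmbedderBase, EmbeddingConfig  # deleted module, shown above

    class HashingEmbedder(EmbedderBase):
        DIM = 256

        def embed_texts(self, texts, task="retrieval", batch_size=None):
            return np.stack([self.embed_single(t, task) for t in texts])

        def embed_single(self, text, task="retrieval"):
            vec = np.zeros(self.DIM, dtype=np.float32)
            for token in text.split():
                vec[hash(token) % self.DIM] += 1.0
            norm = np.linalg.norm(vec)
            return vec / norm if norm > 0 else vec

        @property
        def embedding_dimension(self):
            return self.DIM

        @property
        def max_sequence_length(self):
            return 8192

    emb = HashingEmbedder(EmbeddingConfig(model_name="toy-hash"))
    print(emb.embed_documents(["late chunking demo"]).shape)  # (1, 256)
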
-The 32k token context window allows processing entire documents at once, -then intelligently chunking while preserving cross-boundary semantic relationships. -""" - -# cspell:ignore jina Jina embedder Embedder - -import base64 -import io -import logging -from dataclasses import dataclass -from typing import Any - -import numpy as np -import torch -from PIL import Image -from transformers import AutoModel, AutoTokenizer - -from .embedders_base import EmbedderBase, EmbeddingConfig - -logger = logging.getLogger(__name__) - - -@dataclass -class ChunkWithEmbedding: - """ - Represents a text chunk with its late-chunked embedding. - - The embedding preserves awareness of surrounding context from the - full document, enabling superior semantic search compared to - independently embedded chunks. - """ - text: str - embedding: np.ndarray - start_char: int - end_char: int - start_token: int - end_token: int - chunk_index: int - total_chunks: int - context_window_used: int # How many tokens were in the full context - - -class JinaV4Embedder(EmbedderBase): - """ - Jina v4 embedder with late chunking support. - - This embedder supports both: - 1. Traditional embedding (for short texts, backward compatibility) - 2. Late chunking (for long documents, superior context preservation) - """ - - # Jina v4 constants - MAX_TOKENS = 32768 # Jina v4's context window - EMBEDDING_DIM = 2048 - - def __init__(self, config: EmbeddingConfig | dict[str, Any] | None = None, **kwargs: Any) -> None: - """ - Initialize Jina v4 embedder with late chunking support. - - Args: - config: EmbeddingConfig object or dict/None for defaults - **kwargs: Additional configuration overrides - """ - # Build config dict from various input formats - config_dict: dict[str, Any] = {} - - if config is None: - config_dict = {} - elif isinstance(config, EmbeddingConfig): - # Extract values from EmbeddingConfig object - config_dict = { - 'device': config.device, - 'use_fp16': config.use_fp16, - 'batch_size': config.batch_size, - 'chunk_size_tokens': config.chunk_size_tokens, - 'chunk_overlap_tokens': config.chunk_overlap_tokens, - 'model_name': config.model_name, - 'max_seq_length': config.max_seq_length, - } - elif isinstance(config, dict): - config_dict = config.copy() - else: - # Old-style single param (device) - for backwards compatibility - config_dict = {'device': str(config)} # type: ignore[unreachable] - - # Apply kwargs overrides - config_dict.update(kwargs) - - # Determine device with hard CPU fallback - requested_device = config_dict.get('device', 'cuda') - if requested_device.startswith('cuda') and not torch.cuda.is_available(): - logger.warning("CUDA requested but not available, forcing CPU fallback") - self.device = 'cpu' - else: - self.device = requested_device - - self.model_name = config_dict.get('model_name', 'jinaai/jina-embeddings-v4') - self.batch_size = config_dict.get('batch_size', 128) # Default to 128 for better throughput - - # Create proper EmbeddingConfig for base class - base_config = EmbeddingConfig( - model_name=self.model_name, - device=self.device, - batch_size=self.batch_size, - max_seq_length=config_dict.get('max_seq_length', self.MAX_TOKENS), - use_fp16=config_dict.get('use_fp16', True), - chunk_size_tokens=config_dict.get('chunk_size_tokens', 500), - chunk_overlap_tokens=config_dict.get('chunk_overlap_tokens', 200), - ) - - # Initialize base class - super().__init__(base_config) - - # Set and validate chunking parameters from config - self.chunk_size_tokens = config_dict.get('chunk_size_tokens', 500) - 
self.chunk_overlap_tokens = config_dict.get('chunk_overlap_tokens', 200) - if self.chunk_size_tokens <= 0: - raise ValueError(f"chunk_size_tokens must be > 0, got {self.chunk_size_tokens}") - if self.chunk_overlap_tokens < 0: - raise ValueError(f"chunk_overlap_tokens must be >= 0, got {self.chunk_overlap_tokens}") - if self.chunk_overlap_tokens >= self.chunk_size_tokens: - raise ValueError( - f"chunk_overlap_tokens ({self.chunk_overlap_tokens}) must be < " - f"chunk_size_tokens ({self.chunk_size_tokens})" - ) - use_fp16 = config_dict.get('use_fp16', True) - - # Load model with appropriate dtype - # Check if device starts with "cuda" to handle cuda:0, cuda:1, etc. - dtype = torch.float16 if (use_fp16 and self.device.startswith("cuda")) else torch.float32 - - logger.info(f"Loading {self.model_name} on {self.device} with dtype={dtype}") - logger.info(f"Batch size for embedding: {self.batch_size}") - logger.info(f"Late chunking config: {self.chunk_size_tokens} tokens/chunk, {self.chunk_overlap_tokens} overlap") - - # Load tokenizer first for late chunking - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_name, - trust_remote_code=True - ) - - # Load model and then move to target device - # device_map should be "auto" or a dict, not a device string - self.model = AutoModel.from_pretrained( - self.model_name, - trust_remote_code=True, - torch_dtype=dtype - ) - - # Move model to the target device if not CPU - if self.device and self.device != "cpu": - self.model = self.model.to(self.device) - - self.model.eval() - - # Log actual model configuration - logger.info("Jina v4 model loaded with late chunking support") - logger.info(f"Model dtype: {next(self.model.parameters()).dtype}") - logger.info("Embedding dimension: %s", self.EMBEDDING_DIM) - logger.info(f"Model has encode_text method: {hasattr(self.model, 'encode_text')}") - - # Check if Flash Attention is available and being used - try: - import flash_attn - logger.info("Flash Attention 2 is available - model should use it automatically") - except ImportError: - logger.warning("Flash Attention 2 not available - performance may be limited") - - def embed_texts(self, - texts: list[str], - task: str = "retrieval.passage", - batch_size: int | None = None) -> np.ndarray: - """ - Embed texts using Jina v4. - - Args: - texts: List of texts to embed - task: Task type (retrieval, text-matching, code) - batch_size: Batch size for processing - - Returns: - Numpy array of embeddings (N x 2048) - """ - all_embeddings = [] - # Cap client-requested batch_size by the server-configured ceiling. - # Without min(), a client can override the operator's resource limit. - batch_size = min(batch_size or self.batch_size, self.batch_size) - - # Commented out for performance - this was logging 30+ times per second - # logger.info(f"Processing {len(texts)} texts with batch_size={batch_size}") - - with torch.no_grad(): - # Process in batches - for i in range(0, len(texts), batch_size): - batch = texts[i:i+batch_size] - - # Use Jina's encode_text method if available - if hasattr(self.model, 'encode_text'): - # Jina v4 encode method accepts task in {'retrieval','text-matching','code'} - # and prompt_name in {'query','passage'} for retrieval prefix. 
- jina_task_map = { - 'retrieval.passage': 'retrieval', - 'retrieval.query': 'retrieval', - 'retrieval': 'retrieval', - 'text-matching': 'text-matching', - 'code': 'code', - } - jina_task = jina_task_map.get(task, 'retrieval') - prompt_name = 'query' if task == 'retrieval.query' else 'passage' - embeddings = self.model.encode_text( - batch, - task=jina_task, - prompt_name=prompt_name, - ) - else: - # Jina v4 requires task_label when using forward pass. - # Valid LoRA adapters are: 'retrieval', 'text-matching', 'code' - # (query vs passage is a prompt prefix, not a separate adapter) - task_mapping = { - 'retrieval.passage': 'retrieval', - 'retrieval.query': 'retrieval', - 'retrieval': 'retrieval', - 'text-matching': 'text-matching', - 'code': 'code', - } - task_label = task_mapping.get(task, 'retrieval') - - inputs = self.tokenizer( - batch, - return_tensors="pt", - padding=True, - truncation=True, - max_length=self.MAX_TOKENS - ).to(self.device) - - # Add task_label to inputs - must be a string for LoRA adapter selection - outputs = self.model(**inputs, task_label=task_label) - - # Jina v4 returns single_vec_emb for 2048-dimensional embeddings - if hasattr(outputs, 'single_vec_emb') and outputs.single_vec_emb is not None: - embeddings = outputs.single_vec_emb - else: - # Log error with available attributes for debugging - available_attrs = [attr for attr in dir(outputs) if not attr.startswith('_')] - raise AttributeError( - f"Expected 'single_vec_emb' in JinaV4 output, but got: {available_attrs}. " - f"Output type: {type(outputs).__name__}" - ) - - if torch.is_tensor(embeddings): - if embeddings.is_cuda: - embeddings = embeddings.cpu() - embeddings = embeddings.numpy() - - # Debug: Check what type of object we got - # logger.debug(f"Embeddings type: {type(embeddings)}") - - # Handle different return types - prioritize torch.is_tensor check - if torch.is_tensor(embeddings): - # Handle PyTorch tensors directly - if embeddings.is_cuda: - embeddings = embeddings.cpu() - embeddings = embeddings.numpy() - elif hasattr(embeddings, 'detach'): # Other tensor-like objects - embeddings = embeddings.detach() - if hasattr(embeddings, 'is_cuda') and embeddings.is_cuda: - embeddings = embeddings.cpu() - embeddings = embeddings.numpy() - elif isinstance(embeddings, list): - # If it's a list of tensors - processed = [] - for e in embeddings: - if torch.is_tensor(e): - if e.is_cuda: - e = e.cpu() - processed.append(e.numpy()) - elif hasattr(e, 'detach'): - e = e.detach() - if hasattr(e, 'is_cuda') and e.is_cuda: - e = e.cpu() - processed.append(e.numpy()) - else: - processed.append(np.array(e)) - embeddings = np.vstack(processed) - elif not isinstance(embeddings, np.ndarray): - # Try to convert to numpy - try: - embeddings = np.array(embeddings) - except Exception as e: - logger.error(f"Cannot convert embeddings of type {type(embeddings)} to numpy: {e}") - raise - - all_embeddings.append(embeddings.astype(np.float32, copy=False)) - - # DO NOT clear GPU cache after each batch - this kills performance! - # PyTorch's allocator efficiently reuses memory. - # Only clear cache if encountering OOM errors. 
- # if torch.cuda.is_available(): - # torch.cuda.empty_cache() - - # Concatenate all batches - if all_embeddings: - result = np.vstack(all_embeddings).astype(np.float32, copy=False) - else: - result = np.empty((0, self.EMBEDDING_DIM), dtype=np.float32) - - return result - - def embed_single(self, text: str, task: str = "retrieval.passage") -> np.ndarray: - """ - Embed a single text (required by EmbedderBase interface). - - Args: - text: Text to embed - task: Task type - - Returns: - 1D embedding array - """ - embeddings = self.embed_texts([text], task=task, batch_size=1) - return embeddings[0] if embeddings.size > 0 else np.zeros(self.EMBEDDING_DIM, dtype=np.float32) - - @property - def embedding_dimension(self) -> int: - """Get the dimension of embeddings produced by this model.""" - return self.EMBEDDING_DIM - - @property - def max_sequence_length(self) -> int: - """Get the maximum sequence length supported.""" - return self.MAX_TOKENS - - @property - def supports_late_chunking(self) -> bool: - """Whether this embedder supports late chunking.""" - return True - - @property - def supports_multimodal(self) -> bool: - """Whether this embedder supports multimodal inputs.""" - return True # Jina v4 supports images - - def embed_code(self, - code_snippets: list[str], - batch_size: int = 4) -> np.ndarray: - """ - Embed code using the code-specific task. - - Args: - code_snippets: List of code snippets - batch_size: Batch size for processing - - Returns: - Numpy array of embeddings (N x 2048) - """ - return self.embed_texts(code_snippets, task="code", batch_size=batch_size) - - def embed_images(self, images: list[bytes | Image.Image | str]) -> np.ndarray: - """ - Embed images using Jina v4's multimodal capabilities. - - Args: - images: List of images as bytes, PIL Images, or base64 strings - - Returns: - L2-normalized embeddings as numpy array - """ - processed_images = [] - - for img in images: - if isinstance(img, bytes): - # Convert bytes to PIL Image - pil_img = Image.open(io.BytesIO(img)) - elif isinstance(img, str): - # Assume base64 encoded - img_bytes = base64.b64decode(img) - pil_img = Image.open(io.BytesIO(img_bytes)) - elif isinstance(img, Image.Image): - pil_img = img - else: - raise ValueError(f"Unsupported image type: {type(img)}") - - # Convert to RGB if necessary - if pil_img.mode != 'RGB': - pil_img = pil_img.convert('RGB') - - processed_images.append(pil_img) - - # Encode images using Jina v4's encode_image method - with torch.no_grad(): - embeddings = self.model.encode_image( - images=processed_images, - task="retrieval" - ) - - # Handle CUDA tensors properly - if torch.is_tensor(embeddings): - if embeddings.is_cuda: - embeddings = embeddings.cpu() - embeddings = embeddings.numpy() - elif hasattr(embeddings, 'detach'): - embeddings = embeddings.detach() - if hasattr(embeddings, 'is_cuda') and embeddings.is_cuda: - embeddings = embeddings.cpu() - embeddings = embeddings.numpy() - elif not isinstance(embeddings, np.ndarray): - embeddings = np.array(embeddings) - - # L2 normalize - norms = np.linalg.norm(embeddings, axis=1, keepdims=True) - embeddings = embeddings / (norms + 1e-8) - - return embeddings - - def embed_multimodal(self, pairs: list[dict[str, str | list[bytes | Image.Image | str]]]) -> np.ndarray: - """ - Create unified embeddings for text+image pairs. 
- - Args: - pairs: List of dicts with 'text' and optional 'images' keys - - Returns: - L2-normalized multimodal embeddings - """ - embeddings_list = [] - - for pair in pairs: - text: str = str(pair.get('text', '')) - images_raw = pair.get('images', []) - images: list[bytes | Image.Image | str] = [] - if images_raw: - if isinstance(images_raw, list): - images = images_raw - else: - images = [images_raw] - - if not text and not images: - # Empty pair, return zero vector with correct dimensionality - embeddings_list.append(np.zeros(self.EMBEDDING_DIM, dtype=np.float32)) - continue - - # Use late fusion as Jina v4 doesn't have true multimodal yet - components = [] - weights = [] - - if text: - text_emb = self.embed_texts([text])[0] - components.append(text_emb) - weights.append(0.7) # Default text weight - - if images: - img_embs = self.embed_images(images) - # Average multiple images - img_emb = np.mean(img_embs, axis=0) - components.append(img_emb) - weights.append(0.3) # Default image weight - - # Weighted combination - if len(components) == 1: - combined = components[0] - else: - weights = np.array(weights) / np.sum(weights) # Normalize weights - combined = np.sum([w * c for w, c in zip(weights, components, strict=False)], axis=0) - - # L2 normalize - norm = np.linalg.norm(combined) - combined = combined / (norm + 1e-8) - - embeddings_list.append(combined) - - return np.array(embeddings_list) - - def embed_with_late_chunking(self, - text: str, - task: str = "retrieval.passage") -> list[ChunkWithEmbedding]: - """ - Implement PROPER late chunking as mandated by CLAUDE.md. - - For now, we use a hybrid approach: - 1. Split text into overlapping windows that fit within model context - 2. Each window gets FULL contextual encoding (not just the chunk) - 3. This ensures chunks have awareness of surrounding context - - While not perfect late chunking (which requires hidden states), - this is much better than naive chunking and works with Jina's API. 
- - Args: - text: Full document text to process - task: Task type (retrieval, text-matching, separation, classification) - - Returns: - List of ChunkWithEmbedding objects with context-aware embeddings - """ - if not text: - return [] - - # Encode the full document once to obtain contextual token embeddings - token_embeddings, metadata = self.encode_full_document(text, task) - - if token_embeddings.numel() == 0: - # Fallback: model returned no embeddings (extremely short text) - embedding = self.embed_texts([text], task=task)[0] - embedding = embedding if isinstance(embedding, np.ndarray) else np.asarray(embedding) - if embedding.size == 0: - embedding = np.zeros(self.EMBEDDING_DIM, dtype=np.float32) - return [ChunkWithEmbedding( - text=text, - embedding=embedding, - start_char=0, - end_char=len(text), - start_token=0, - end_token=metadata.get('num_tokens', len(text) // 4), - chunk_index=0, - total_chunks=1, - context_window_used=metadata.get('num_tokens', len(text) // 4) - )] - - # Ensure embeddings are on CPU for downstream numpy ops - if hasattr(token_embeddings, 'is_cuda') and token_embeddings.is_cuda: - token_embeddings = token_embeddings.cpu() - - # Run the second stage of late chunking directly on the cached embeddings - chunks = self.embed_chunks_from_tokens( - token_embeddings=token_embeddings, - metadata=metadata, - text=text, - chunk_size_tokens=self.chunk_size_tokens, - chunk_overlap_tokens=self.chunk_overlap_tokens - ) - - return chunks - - def embed_batch_with_late_chunking(self, - texts: list[str], - task: str = "retrieval.passage") -> list[list[ChunkWithEmbedding]]: - """ - Batch version of embed_with_late_chunking for GPU efficiency. - - Processes multiple documents with proper late chunking: - 1. Each document is fully encoded to get contextualized representations - 2. Then chunks are created from the contextualized hidden states - 3. This preserves full document context in every chunk - - Args: - texts: List of full document texts to process - task: Task type (retrieval, text-matching, separation, classification) - - Returns: - List of lists - one ChunkWithEmbedding list per input document - """ - if not texts: - return [] - - all_results: list[list[ChunkWithEmbedding]] = [] - - # Process each document with proper late chunking - # We process individually to maintain full context for each document - for text in texts: - if not text: - all_results.append([]) - continue - - # Use the proper late chunking implementation - chunks = self.embed_with_late_chunking(text, task=task) - all_results.append(chunks) - - return all_results - - - def _prepare_simple_chunks(self, text: str) -> list[dict]: - """ - Simple chunking method that always works. - Creates chunks based on chunk_size_tokens with overlap. - """ - chunks = [] - - # Character-based chunking (rough token estimate) - chars_per_token = 4 - chunk_size_chars = self.chunk_size_tokens * chars_per_token - overlap_chars = self.chunk_overlap_tokens * chars_per_token - - start_char = 0 - chunk_index = 0 - - while start_char < len(text): - # Define chunk boundaries - end_char = min(start_char + chunk_size_chars, len(text)) - - # Try to break at sentence boundary if possible - if end_char < len(text): - # Look for sentence end near boundary - search_start = max(start_char, end_char - 100) - sentence_end = text.find('. 
', search_start, end_char) - if sentence_end != -1: - end_char = sentence_end + 2 - - chunk_text = text[start_char:end_char] - - chunks.append({ - 'text': chunk_text, - 'start_char': start_char, - 'end_char': end_char, - 'start_token': start_char // chars_per_token, - 'end_token': end_char // chars_per_token, - 'chunk_index': chunk_index, - 'total_chunks': 0, # Will be updated - 'context_size': len(chunk_text) // chars_per_token - }) - - # Move to next chunk with overlap - if end_char >= len(text): - break - - start_char = end_char - overlap_chars - chunk_index += 1 - - # Update total chunks - total = len(chunks) - for chunk in chunks: - chunk['total_chunks'] = total - - return chunks - - def _prepare_chunks_for_batch(self, text: str, doc_idx: int) -> list[dict]: - """ - Prepare chunks and context windows for a single document in batch processing. - - Returns list of dictionaries with chunk and context information. - """ - # Estimate chunk size in characters (rough: ~4 chars per token) - chunk_size_chars = self.chunk_size_tokens * 4 - - chunks = [] - chunk_index = 0 - start_char = 0 - - while start_char < len(text): - # Define chunk boundaries - end_char = min(start_char + chunk_size_chars, len(text)) - chunk_text = text[start_char:end_char] - - # Define context window (chunk + surrounding text) - context_start = max(0, start_char - chunk_size_chars) - context_end = min(len(text), end_char + chunk_size_chars) - context_text = text[context_start:context_end] - - chunk_info = { - 'chunk_text': chunk_text, - 'context_text': context_text, - 'start_char': start_char, - 'end_char': end_char, - 'start_token': start_char // 4, # Rough estimate - 'end_token': end_char // 4, - 'chunk_index': chunk_index, - 'total_chunks': 0, # Will be updated later - 'context_window_used': len(context_text) // 4 # Rough token estimate - } - - chunks.append(chunk_info) - - # Move to next chunk with overlap - if end_char >= len(text): - break - start_char = end_char - (self.chunk_overlap_tokens * 4) # Convert overlap to chars - chunk_index += 1 - - # Update total chunks count - for chunk in chunks: - chunk['total_chunks'] = len(chunks) - - logger.debug(f"Prepared {len(chunks)} chunks for document {doc_idx}") - return chunks - - def _chunk_with_context_windows(self, - text: str, - task: str = "retrieval.passage") -> list[ChunkWithEmbedding]: - """ - DEPRECATED: This method has O(N^2) complexity due to redundant encoding. - - Use embed_with_late_chunking() instead which properly chunks without - redundant model calls. This method is kept only for backward compatibility - and will be removed in future versions. - """ - logger.warning("_chunk_with_context_windows is deprecated due to O(N^2) complexity. " - "Using embed_with_late_chunking instead for better performance.") - - # Redirect to the efficient implementation - return self.embed_with_late_chunking(text, task) - - def encode_full_document(self, - text: str, - task: str = "retrieval.passage") -> tuple[torch.Tensor, dict]: - """ - Encode a full document to get token-level embeddings (first step of late chunking). - - This is the critical first step of proper late chunking: - 1. Process the entire document through the transformer - 2. Get contextualized token embeddings for the whole document - 3. Return these for subsequent chunking with context preservation - - Args: - text: Full document text - task: Task type (retrieval, code, etc.) 
- - Returns: - Tuple of (token_embeddings, metadata_dict) - where metadata_dict contains token offsets and other info - """ - if not text: - return torch.empty(0, self.EMBEDDING_DIM), {} - - # Check if truncation will be needed - estimated_tokens = len(self.tokenizer.encode(text, add_special_tokens=False)) - if estimated_tokens > self.MAX_TOKENS: - logger.warning( - f"Document will be truncated from ~{estimated_tokens} to {self.MAX_TOKENS} tokens. " - f"Consider using process_long_document() for documents > 32k tokens." - ) - - # Tokenize the full document - tokens = self.tokenizer( - text, - return_tensors="pt", - padding=False, - truncation=True, # Truncate to MAX_TOKENS if needed - max_length=self.MAX_TOKENS, - return_offsets_mapping=True, - return_attention_mask=True, - return_special_tokens_mask=True, - ) - - # Move to device - input_ids = tokens['input_ids'].to(self.model.device) - attention_mask = tokens['attention_mask'].to(self.model.device) - offset_mapping = tokens['offset_mapping'][0].cpu().numpy() - special_tokens_mask = tokens['special_tokens_mask'][0].cpu().numpy() - - with torch.no_grad(): - # Map task to Jina v4 task labels - # Map task to Jina v4 LoRA adapter names. - # Valid adapters: 'retrieval', 'text-matching', 'code' - task_mapping = { - 'retrieval.passage': 'retrieval', - 'retrieval.query': 'retrieval', - 'retrieval': 'retrieval', - 'text-matching': 'text-matching', - 'code': 'code', - } - task_label = task_mapping.get(task, 'retrieval') - - # For Jina v4, we need to pass task_label to the model - # Access the underlying transformer and add task_label - if hasattr(self.model, 'model'): - # Call the underlying model with task_label - outputs = self.model.model( - input_ids=input_ids, - attention_mask=attention_mask, - task_label=task_label, - output_hidden_states=True - ) - else: - # Fallback to direct model call - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - task_label=task_label, - output_hidden_states=True - ) - - # Get the last hidden state (contextualized token embeddings) - if hasattr(outputs, 'last_hidden_state') and outputs.last_hidden_state is not None: - token_embeddings = outputs.last_hidden_state[0] # Shape: [seq_len, hidden_dim] - elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: - # Some models put it in hidden_states - token_embeddings = outputs.hidden_states[-1][0] # Last layer, first batch - elif hasattr(outputs, 'vlm_last_hidden_states') and outputs.vlm_last_hidden_states is not None: - # Jina v4 specific: use vlm_last_hidden_states for token-level embeddings - token_embeddings = outputs.vlm_last_hidden_states[0] # Shape: [seq_len, hidden_dim] - elif hasattr(outputs, 'multi_vec_emb') and outputs.multi_vec_emb is not None: - # Jina v4: multi_vec_emb contains token embeddings before pooling - token_embeddings = outputs.multi_vec_emb[0] - elif isinstance(outputs, tuple): - # For tuple outputs, first element is usually the embeddings - token_embeddings = outputs[0][0] if len(outputs[0].shape) > 2 else outputs[0] - else: - # For custom output types like JinaEmbeddingsV4ModelOutput - # Try to get the first available tensor attribute - found_embeddings = False - for attr_name in ['embeddings', 'last_hidden_state', 'hidden_states', 'pooler_output', 'single_vec_emb']: - if hasattr(outputs, attr_name): - attr_value = getattr(outputs, attr_name) - if attr_value is not None: - if isinstance(attr_value, (list, tuple)) and len(attr_value) > 0: - token_embeddings = attr_value[-1][0] if 
hasattr(attr_value[-1], 'shape') else attr_value[0] - else: - token_embeddings = attr_value[0] if hasattr(attr_value, 'shape') and len(attr_value.shape) > 2 else attr_value - found_embeddings = True - break - - if not found_embeddings: - # Last resort: raise a more informative error - available_attrs = [attr for attr in dir(outputs) if not attr.startswith('_')] - raise ValueError(f"Could not extract token embeddings from {type(outputs).__name__}. Available attributes: {available_attrs}") - - # Apply pooling or projection if needed to get to 2048 dims - # (Jina v4 uses a projection layer for final embeddings) - if hasattr(self.model, 'encode'): - # Use Jina's projection mechanism - # We need to pool to get document embedding first - pooled = token_embeddings.mean(dim=0, keepdim=True) - # For now, we'll just use the pooled embeddings directly - # since we can't pass pre-computed embeddings to encode - # Use this as a reference for projection - else: - # Direct use of token embeddings - pass - - metadata = { - 'offset_mapping': offset_mapping, - 'num_tokens': len(input_ids[0]), - 'text_length': len(text), - 'task': task, - 'special_tokens_mask': special_tokens_mask, - } - - return token_embeddings, metadata - - def embed_chunks_from_tokens(self, - token_embeddings: torch.Tensor, - metadata: dict, - text: str, - chunk_size_tokens: int | None = None, - chunk_overlap_tokens: int | None = None) -> list[ChunkWithEmbedding]: - """ - Create chunks with embeddings from pre-computed token embeddings (second step). - - This is the second step of proper late chunking: - 1. Take the contextualized token embeddings from step 1 - 2. Create chunks by slicing the token embeddings - 3. Pool each chunk's tokens to get chunk embedding - 4. Each chunk embedding preserves full document context - - Args: - token_embeddings: Token-level embeddings from encode_full_document - metadata: Metadata dict from encode_full_document - text: Original text for creating chunk text - chunk_size_tokens: Override default chunk size - chunk_overlap_tokens: Override default overlap - - Returns: - List of ChunkWithEmbedding objects with context-aware embeddings - """ - chunk_size = chunk_size_tokens or self.chunk_size_tokens - overlap = chunk_overlap_tokens or self.chunk_overlap_tokens - - offset_mapping = metadata['offset_mapping'] - num_tokens = metadata['num_tokens'] - special_mask = metadata.get('special_tokens_mask') - - chunks = [] - chunk_index = 0 - start_token = 0 - - while start_token < num_tokens: - # Define chunk token boundaries - end_token = min(start_token + chunk_size, num_tokens) - - # Get character boundaries from offset mapping - valid_start = start_token - if special_mask is not None: - while valid_start < end_token and special_mask[valid_start] == 1: - valid_start += 1 - if valid_start >= end_token: - if end_token >= num_tokens: - break - start_token = end_token - continue - - valid_end = end_token - 1 - if special_mask is not None: - while valid_end > valid_start and special_mask[valid_end] == 1: - valid_end -= 1 - - start_offset = offset_mapping[valid_start] - end_offset = offset_mapping[valid_end] - - # Handle if these are still tensors - if hasattr(start_offset, 'cpu'): - start_offset = start_offset.cpu().numpy() - if hasattr(end_offset, 'cpu'): - end_offset = end_offset.cpu().numpy() - - start_char = int(start_offset[0]) - end_char = int(end_offset[1]) - - # Extract chunk text - chunk_text = text[start_char:end_char] - - # Get chunk embedding by pooling token embeddings - chunk_token_embeddings = 
token_embeddings[valid_start:valid_end + 1] - if chunk_token_embeddings.shape[0] == 0: - if end_token >= num_tokens: - break - start_token = end_token - continue - - # Mean pooling over tokens in the chunk - chunk_embedding = chunk_token_embeddings.mean(dim=0) - - # Ensure tensor is on CPU before numpy conversion - if hasattr(chunk_embedding, 'is_cuda') and chunk_embedding.is_cuda: - chunk_embedding = chunk_embedding.cpu() - - # Convert to numpy and normalize - chunk_embedding_np = chunk_embedding.numpy().astype(np.float32, copy=False) - norm = np.linalg.norm(chunk_embedding_np) - if norm > 0: - chunk_embedding_np = chunk_embedding_np / norm - - # Create ChunkWithEmbedding object - chunks.append(ChunkWithEmbedding( - text=chunk_text, - embedding=chunk_embedding_np, - start_char=start_char, - end_char=end_char, - start_token=valid_start, - end_token=valid_end + 1, - chunk_index=chunk_index, - total_chunks=0, # Will update after loop - context_window_used=num_tokens # Full document context - )) - - # Move to next chunk with overlap - if end_token >= num_tokens: - break - start_token = max(end_token - overlap, 0) - chunk_index += 1 - - # Update total chunks count - for chunk in chunks: - chunk.total_chunks = len(chunks) - - logger.debug(f"Created {len(chunks)} chunks from {num_tokens} tokens with full context") - return chunks - - def process_long_document(self, - text: str, - task: str = "retrieval.passage") -> list[ChunkWithEmbedding]: - """ - Process a document that may exceed 32k tokens. - - For documents longer than 32k tokens, we process in windows with - overlap to maintain some context across boundaries. - - Args: - text: Document text (can be very long) - task: Task type - - Returns: - List of all chunks with embeddings - """ - # Quick token count estimate (rough: ~4 chars per token) - estimated_tokens = len(text) // 4 - - logger.debug(f"process_long_document: text length={len(text)}, estimated tokens={estimated_tokens}") - - # For Jina v4, we use context window approach which already provides - # excellent context preservation through overlapping windows - # The model's encode_text method handles the embedding internally - - if estimated_tokens <= self.MAX_TOKENS: - # Document fits in model's context window - return self._chunk_with_context_windows(text, task) - - # Process very long documents in overlapping windows - logger.info(f"Document too long (~{estimated_tokens} tokens), processing in windows") - - all_chunks = [] - window_size_chars = self.MAX_TOKENS * 4 # Rough estimate - window_overlap_chars = 1000 * 4 # 1000 token overlap - - start = 0 - window_index = 0 - - while start < len(text): - end = min(start + window_size_chars, len(text)) - window_text = text[start:end] - - # Process this window with context windows approach - window_chunks = self._chunk_with_context_windows(window_text, task) - - # Adjust character positions to be relative to full document - for chunk in window_chunks: - chunk.start_char += start - chunk.end_char += start - chunk.start_token += start // 4 - chunk.end_token += start // 4 - - all_chunks.extend(window_chunks) - - if end >= len(text): - break - - start = end - window_overlap_chars - window_index += 1 - - # Re-index chunks - for i, chunk in enumerate(all_chunks): - chunk.chunk_index = i - chunk.total_chunks = len(all_chunks) - - logger.info(f"Processed {window_index + 1} windows, total {len(all_chunks)} chunks") - - return all_chunks diff --git a/services/embedder/core/proto/__init__.py b/services/embedder/core/proto/__init__.py deleted file mode 
100644 index e69de29b..00000000 diff --git a/services/embedder/core/proto/persephone/__init__.py b/services/embedder/core/proto/persephone/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/services/embedder/core/proto/persephone/embedding/__init__.py b/services/embedder/core/proto/persephone/embedding/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/services/embedder/core/proto/persephone/embedding/embedding_pb2.py b/services/embedder/core/proto/persephone/embedding/embedding_pb2.py deleted file mode 100644 index 2d28d422..00000000 --- a/services/embedder/core/proto/persephone/embedding/embedding_pb2.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# NO CHECKED-IN PROTOBUF GENCODE -# source: persephone/embedding/embedding.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'persephone/embedding/embedding.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n$persephone/embedding/embedding.proto\x12\x14persephone.embedding\"?\n\x0c\x45mbedRequest\x12\r\n\x05texts\x18\x01 \x03(\t\x12\x0c\n\x04task\x18\x02 \x01(\t\x12\x12\n\nbatch_size\x18\x03 \x01(\r\"{\n\rEmbedResponse\x12\x33\n\nembeddings\x18\x01 \x03(\x0b\x32\x1f.persephone.embedding.Embedding\x12\r\n\x05model\x18\x02 \x01(\t\x12\x11\n\tdimension\x18\x03 \x01(\r\x12\x13\n\x0b\x64uration_ms\x18\x04 \x01(\x04\"5\n\x17\x45mbedLateChunkedRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x0c\n\x04task\x18\x02 \x01(\t\"\x9f\x01\n\x18\x45mbedLateChunkedResponse\x12/\n\x06\x63hunks\x18\x01 \x03(\x0b\x32\x1f.persephone.embedding.LateChunk\x12\r\n\x05model\x18\x02 \x01(\t\x12\x11\n\tdimension\x18\x03 \x01(\r\x12\x13\n\x0b\x64uration_ms\x18\x04 \x01(\x04\x12\x1b\n\x13\x63ontext_window_used\x18\x05 \x01(\r\"\xc6\x01\n\tLateChunk\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x32\n\tembedding\x18\x02 \x01(\x0b\x32\x1f.persephone.embedding.Embedding\x12\x12\n\nstart_char\x18\x03 \x01(\r\x12\x10\n\x08\x65nd_char\x18\x04 \x01(\r\x12\x13\n\x0bstart_token\x18\x05 \x01(\r\x12\x11\n\tend_token\x18\x06 \x01(\r\x12\x13\n\x0b\x63hunk_index\x18\x07 \x01(\r\x12\x14\n\x0ctotal_chunks\x18\x08 \x01(\r\"\x1b\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\"\r\n\x0bInfoRequest\"\xbd\x01\n\x0cInfoResponse\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x11\n\tdimension\x18\x02 \x01(\r\x12\x16\n\x0emax_seq_length\x18\x03 \x01(\r\x12\x17\n\x0fsupported_tasks\x18\x04 \x03(\t\x12\x0e\n\x06\x64\x65vice\x18\x05 \x01(\t\x12\x14\n\x0cmodel_loaded\x18\x06 \x01(\x08\x12\x16\n\x0euptime_seconds\x18\x07 \x01(\x01\x12\x17\n\x0fweight_revision\x18\x08 \x01(\t\"3\n\x17GracefulShutdownRequest\x12\x18\n\x10\x64rain_timeout_ms\x18\x01 \x01(\r\"<\n\x18GracefulShutdownResponse\x12\x0f\n\x07\x64rained\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t2\x99\x03\n\x10\x45mbeddingService\x12P\n\x05\x45mbed\x12\".persephone.embedding.EmbedRequest\x1a#.persephone.embedding.EmbedResponse\x12q\n\x10\x45mbedLateChunked\x12-.persephone.embedding.EmbedLateChunkedRequest\x1a..persephone.embedding.EmbedLateChunkedResponse\x12M\n\x04Info\x12!.persephone.embedding.InfoRequest\x1a\".persephone.embedding.InfoResponse\x12q\n\x10GracefulShutdown\x12-.persephone.embedding.GracefulShutdownRequest\x1a..persephone.embedding.GracefulShutdownResponseb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'persephone.embedding.embedding_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_EMBEDREQUEST']._serialized_start=62 - _globals['_EMBEDREQUEST']._serialized_end=125 - _globals['_EMBEDRESPONSE']._serialized_start=127 - _globals['_EMBEDRESPONSE']._serialized_end=250 - _globals['_EMBEDLATECHUNKEDREQUEST']._serialized_start=252 - _globals['_EMBEDLATECHUNKEDREQUEST']._serialized_end=305 - _globals['_EMBEDLATECHUNKEDRESPONSE']._serialized_start=308 - _globals['_EMBEDLATECHUNKEDRESPONSE']._serialized_end=467 - _globals['_LATECHUNK']._serialized_start=470 - _globals['_LATECHUNK']._serialized_end=668 - _globals['_EMBEDDING']._serialized_start=670 - _globals['_EMBEDDING']._serialized_end=697 - _globals['_INFOREQUEST']._serialized_start=699 - _globals['_INFOREQUEST']._serialized_end=712 - _globals['_INFORESPONSE']._serialized_start=715 - _globals['_INFORESPONSE']._serialized_end=904 - _globals['_GRACEFULSHUTDOWNREQUEST']._serialized_start=906 - _globals['_GRACEFULSHUTDOWNREQUEST']._serialized_end=957 - _globals['_GRACEFULSHUTDOWNRESPONSE']._serialized_start=959 - _globals['_GRACEFULSHUTDOWNRESPONSE']._serialized_end=1019 - _globals['_EMBEDDINGSERVICE']._serialized_start=1022 - _globals['_EMBEDDINGSERVICE']._serialized_end=1431 -# @@protoc_insertion_point(module_scope) diff --git a/services/embedder/core/proto/persephone/embedding/embedding_pb2_grpc.py b/services/embedder/core/proto/persephone/embedding/embedding_pb2_grpc.py deleted file mode 100644 index ffd288b1..00000000 --- a/services/embedder/core/proto/persephone/embedding/embedding_pb2_grpc.py +++ /dev/null @@ -1,234 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from . import embedding_pb2 as persephone_dot_embedding_dot_embedding__pb2 - -GRPC_GENERATED_VERSION = '1.80.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in persephone/embedding/embedding_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class EmbeddingServiceStub(object): - """Embedding service for generating vector representations of text. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Embed = channel.unary_unary( - '/persephone.embedding.EmbeddingService/Embed', - request_serializer=persephone_dot_embedding_dot_embedding__pb2.EmbedRequest.SerializeToString, - response_deserializer=persephone_dot_embedding_dot_embedding__pb2.EmbedResponse.FromString, - _registered_method=True) - self.EmbedLateChunked = channel.unary_unary( - '/persephone.embedding.EmbeddingService/EmbedLateChunked', - request_serializer=persephone_dot_embedding_dot_embedding__pb2.EmbedLateChunkedRequest.SerializeToString, - response_deserializer=persephone_dot_embedding_dot_embedding__pb2.EmbedLateChunkedResponse.FromString, - _registered_method=True) - self.Info = channel.unary_unary( - '/persephone.embedding.EmbeddingService/Info', - request_serializer=persephone_dot_embedding_dot_embedding__pb2.InfoRequest.SerializeToString, - response_deserializer=persephone_dot_embedding_dot_embedding__pb2.InfoResponse.FromString, - _registered_method=True) - self.GracefulShutdown = channel.unary_unary( - '/persephone.embedding.EmbeddingService/GracefulShutdown', - request_serializer=persephone_dot_embedding_dot_embedding__pb2.GracefulShutdownRequest.SerializeToString, - response_deserializer=persephone_dot_embedding_dot_embedding__pb2.GracefulShutdownResponse.FromString, - _registered_method=True) - - -class EmbeddingServiceServicer(object): - """Embedding service for generating vector representations of text. - """ - - def Embed(self, request, context): - """Embed a batch of texts into vectors. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def EmbedLateChunked(self, request, context): - """Embed a document with late chunking — returns per-chunk embeddings whose - contextual representation reflects the whole document, not just the chunk. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Info(self, request, context): - """Query provider capabilities and model metadata. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GracefulShutdown(self, request, context): - """Gracefully drain in-flight requests and shut down the service. 
- """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_EmbeddingServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Embed': grpc.unary_unary_rpc_method_handler( - servicer.Embed, - request_deserializer=persephone_dot_embedding_dot_embedding__pb2.EmbedRequest.FromString, - response_serializer=persephone_dot_embedding_dot_embedding__pb2.EmbedResponse.SerializeToString, - ), - 'EmbedLateChunked': grpc.unary_unary_rpc_method_handler( - servicer.EmbedLateChunked, - request_deserializer=persephone_dot_embedding_dot_embedding__pb2.EmbedLateChunkedRequest.FromString, - response_serializer=persephone_dot_embedding_dot_embedding__pb2.EmbedLateChunkedResponse.SerializeToString, - ), - 'Info': grpc.unary_unary_rpc_method_handler( - servicer.Info, - request_deserializer=persephone_dot_embedding_dot_embedding__pb2.InfoRequest.FromString, - response_serializer=persephone_dot_embedding_dot_embedding__pb2.InfoResponse.SerializeToString, - ), - 'GracefulShutdown': grpc.unary_unary_rpc_method_handler( - servicer.GracefulShutdown, - request_deserializer=persephone_dot_embedding_dot_embedding__pb2.GracefulShutdownRequest.FromString, - response_serializer=persephone_dot_embedding_dot_embedding__pb2.GracefulShutdownResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'persephone.embedding.EmbeddingService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('persephone.embedding.EmbeddingService', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class EmbeddingService(object): - """Embedding service for generating vector representations of text. 
- """ - - @staticmethod - def Embed(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/persephone.embedding.EmbeddingService/Embed', - persephone_dot_embedding_dot_embedding__pb2.EmbedRequest.SerializeToString, - persephone_dot_embedding_dot_embedding__pb2.EmbedResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def EmbedLateChunked(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/persephone.embedding.EmbeddingService/EmbedLateChunked', - persephone_dot_embedding_dot_embedding__pb2.EmbedLateChunkedRequest.SerializeToString, - persephone_dot_embedding_dot_embedding__pb2.EmbedLateChunkedResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Info(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/persephone.embedding.EmbeddingService/Info', - persephone_dot_embedding_dot_embedding__pb2.InfoRequest.SerializeToString, - persephone_dot_embedding_dot_embedding__pb2.InfoResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GracefulShutdown(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/persephone.embedding.EmbeddingService/GracefulShutdown', - persephone_dot_embedding_dot_embedding__pb2.GracefulShutdownRequest.SerializeToString, - persephone_dot_embedding_dot_embedding__pb2.GracefulShutdownResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/services/embedder/core/services/__init__.py b/services/embedder/core/services/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/services/embedder/core/services/embedder_grpc.py b/services/embedder/core/services/embedder_grpc.py deleted file mode 100644 index 997bae21..00000000 --- a/services/embedder/core/services/embedder_grpc.py +++ /dev/null @@ -1,605 +0,0 @@ -"""Weaver Embedder Service — gRPC over Unix socket. - -Replaces the legacy FastAPI/HTTP service (embedder_service.py). The wire -protocol is now gRPC matching proto/persephone/embedding/embedding.proto, so -the Rust EmbeddingClient in crates/weaver-database/src/persephone/embedding.rs -can actually talk to this service instead of the dead-letter HTTP shape. 
- -RPCs: - * Embed — batch text → pooled embeddings - * EmbedLateChunked — full document → per-chunk context-aware embeddings - * Info — provider capability + uptime - * GracefulShutdown — drain in-flight + exit - -Run: python -m core.services.embedder_grpc -Socket: /run/weaver/embedder.sock (override via WEAVER_EMBEDDER_SOCKET) - -This service is harness-scoped mandatory infrastructure — same tier as -ArangoDB CE, pinned to Jina V4 FP16 safetensors, not swappable. -""" - -from __future__ import annotations - -import asyncio -import logging -import os -import signal -import socket -import stat -import time -from dataclasses import dataclass -from pathlib import Path -from collections.abc import Callable -from typing import Any - -import grpc -import numpy as np - -from core.cli.config import get_embedder_service_config -from core.proto.persephone.embedding import embedding_pb2, embedding_pb2_grpc - -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -@dataclass -class ServiceState: - embedder: Any | None = None - model_loaded: bool = False - device: str = "unknown" - model_name: str = "unknown" - # HuggingFace snapshot revision SHA for the currently-loaded weights. - # Populated during _load_model by reading /hub//refs/main. - # Empty string when the snapshot cannot be resolved (offline load, - # non-HF source, missing refs file). The harness pin-check tolerates - # this but logs a warning — a pinned revision simply can't be enforced - # against unknown live revision. - weight_revision: str = "" - start_time: float = 0.0 - last_request_time: float = 0.0 - idle_timeout_seconds: float = 0.0 - last_error: str | None = None - shutdown_event: asyncio.Event | None = None - lifecycle_lock: asyncio.Lock | None = None - # Two counters, not one: - # inflight_requests — RPC handler lifetime. Bumped at handler entry, - # dropped in a coroutine-level finally. Covers the full handler, - # including response marshaling AFTER the worker thread returns. - # inflight_workers — worker-thread (GPU) lifetime. Bumped before - # submit, dropped from the thread's own finally via - # call_soon_threadsafe. Covers the cancellation gap where the - # coroutine's finally has already run but the thread is still - # holding the GPU. - # GracefulShutdown and _idle_monitor must wait on both — a premature - # drain complete in either window would SIGTERM a live request or - # unload a model still in use. - inflight_requests: int = 0 - inflight_workers: int = 0 - # Drain gate: flipped False at the start of GracefulShutdown so new - # Embed/EmbedLateChunked calls are refused before the wait loop runs. - # Without this, a request arriving mid-drain can still be cut by the - # SIGTERM the shutdown RPC schedules. - accepting_requests: bool = True - - -state = ServiceState() - - -async def _unload_model() -> None: - if state.embedder is None: - return - del state.embedder - state.embedder = None - state.model_loaded = False - try: - import gc - - import torch - - if torch.cuda.is_available(): - gc.collect() - torch.cuda.empty_cache() - except Exception as e: - logger.warning(f"Could not clear CUDA cache: {e}") - logger.info("Model unloaded, GPU memory freed") - - -def _read_hf_snapshot_revision(model_name: str) -> str: - """Return the HuggingFace snapshot revision SHA for `model_name`, or "". 
- - The HF cache layout is: - $HF_HOME/hub/models----/refs/main -> "<40-char SHA>" - We read that file directly rather than round-tripping through - huggingface_hub's Python API — cheaper at boot, and "get the SHA the - loader just used" is not a feature that API exposes cleanly. Any - failure yields "" and the harness falls back to identity-without-rev. - """ - try: - hf_home = os.environ.get("HF_HOME") or os.path.expanduser("~/.cache/huggingface") - slug = "models--" + model_name.replace("/", "--") - ref_path = os.path.join(hf_home, "hub", slug, "refs", "main") - with open(ref_path, encoding="ascii") as fh: - return fh.read().strip() - except Exception as e: - logger.warning(f"Could not read HF snapshot revision for {model_name}: {e}") - return "" - - -async def _load_model() -> None: - from core.embedders.embedders_jina import JinaV4Embedder - - config = get_embedder_service_config() - device = config["device"] - use_fp16 = config["use_fp16"] - batch_size = config["batch_size"] - model_name = config["model_name"] - - logger.info(f"Loading model: {model_name} (device={device}, fp16={use_fp16})") - start = time.time() - state.embedder = JinaV4Embedder( - {"device": device, "use_fp16": use_fp16, "batch_size": batch_size, "model_name": model_name} - ) - load_time = time.time() - start - - state.model_loaded = True - state.last_error = None - state.device = state.embedder.device - state.model_name = model_name - state.weight_revision = _read_hf_snapshot_revision(model_name) - logger.info( - f"Model loaded in {load_time:.2f}s " - f"(revision={state.weight_revision or 'unknown'})" - ) - - -def _is_draining() -> bool: - """Either counter being non-zero means a request or GPU thread is still - in flight. Callers waiting for quiescence must poll this, not either - counter alone.""" - return state.inflight_requests > 0 or state.inflight_workers > 0 - - -async def _idle_monitor() -> None: - while True: - await asyncio.sleep(60) - if ( - state.model_loaded - and state.last_request_time > 0 - and state.idle_timeout_seconds > 0 - and (time.time() - state.last_request_time) > state.idle_timeout_seconds - ): - assert state.lifecycle_lock is not None - async with state.lifecycle_lock: - if not state.model_loaded or _is_draining(): - continue - logger.info(f"Model idle > {state.idle_timeout_seconds}s; unloading") - await _unload_model() - - -def _inc_worker() -> None: - """Bump the worker-thread counter. Scheduled by worker threads via - loop.call_soon_threadsafe so the counter only rises once the thread - has actually picked up the executor task.""" - state.inflight_workers += 1 - - -def _dec_worker() -> None: - """Decrement the worker-thread counter. Scheduled by worker threads via - loop.call_soon_threadsafe so the counter tracks the real GPU-holding - thread's lifetime even when the awaiting coroutine is cancelled.""" - state.inflight_workers -= 1 - - -async def _run_tracked(fn: Callable[[], Any]) -> Any: - """Run fn in a worker thread. Bump + dec are BOTH scheduled from the - thread via `loop.call_soon_threadsafe` — inc just before fn() runs, - dec in the thread's own `finally`. Moving the bump inside `_wrapped` - closes the "queued-then-cancelled" leak: if the asyncio task is - cancelled while `_wrapped` is still sitting in the executor's queue - (not yet picked up), the underlying concurrent.futures.Future can be - cancelled before it ever runs, and nothing schedules either inc or - dec — counter stays at 0, no leak. 
- - Safety of the "thread holds GPU but counter is briefly 0" window: - the window exists (thread starts fn() before _inc_worker has run on - the event loop), but `inflight_requests` is still 1 during it — the - handler coroutine is suspended on `await run_in_executor`, so its - `_ensure_loaded` bump is still in effect. `_is_draining()` ORs both - counters, so drain sees the work. - - GracefulShutdown's drain loop must poll both counters — see - `_is_draining`. - """ - loop = asyncio.get_running_loop() - - def _wrapped() -> Any: - loop.call_soon_threadsafe(_inc_worker) - try: - return fn() - finally: - loop.call_soon_threadsafe(_dec_worker) - - return await loop.run_in_executor(None, _wrapped) - - -async def _ensure_loaded(context: grpc.aio.ServicerContext) -> Any: - """Admit the RPC, ensure the model is loaded, and bump inflight_requests - under the lifecycle lock. Returns the embedder on success; aborts with - UNAVAILABLE on draining (`accepting_requests=False`) or repeated load - failure. Only admitted work ever touches the counter — abort paths - roll back the bump via the try/finally below. - - Bump ordering: increment happens IMMEDIATELY after the `accepting_requests` - gate, BEFORE `await _load_model()`. A cold load can suspend for many - seconds; counting it as "admitted" means GracefulShutdown's - `_is_draining()` sees the work for its full lifetime, not just the - post-load window (see CR #3 on PR #124). - - Cancellation between the bump and the `return state.embedder` at - async-with exit is not a hazard — `asyncio.Lock.__aexit__` has no awaits - in its body, so the increment and return run synchronously together. - Cancellation (or any exception) during `await _load_model()` flows - through the try/finally and decrements before propagating. - - `_run_tracked` still owns `inflight_workers` separately so a cancelled - coroutine doesn't drop drain bookkeeping while the worker thread is - still holding the GPU. - """ - assert state.lifecycle_lock is not None - async with state.lifecycle_lock: - if not state.accepting_requests: - await context.abort( - grpc.StatusCode.UNAVAILABLE, "service is draining for shutdown" - ) - state.inflight_requests += 1 - admitted = False - try: - if not state.model_loaded or state.embedder is None: - logger.info("Model not loaded; loading on demand") - try: - await _load_model() - except Exception as e: - state.last_error = str(e) - await context.abort( - grpc.StatusCode.UNAVAILABLE, f"Failed to load model: {e}" - ) - admitted = True - return state.embedder - finally: - if not admitted: - state.inflight_requests -= 1 - - -class EmbeddingServicer(embedding_pb2_grpc.EmbeddingServiceServicer): - async def Embed( - self, - request: embedding_pb2.EmbedRequest, - context: grpc.aio.ServicerContext, - ) -> embedding_pb2.EmbedResponse: - if not request.texts: - await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "No texts provided") - batch_size = request.batch_size or None - if batch_size is not None and batch_size == 0: - batch_size = None - state.last_request_time = time.time() - - # Admission + inflight bump happen atomically inside _ensure_loaded - # under lifecycle_lock. If it aborts (draining / load failure) the - # counter is never touched, so the finally below is only reached - # on the admitted path. Keeping _ensure_loaded OUTSIDE the try is - # what enforces that symmetry. 
- embedder = await _ensure_loaded(context) - try: - start = time.time() - try: - embeddings: np.ndarray = await _run_tracked( - lambda: embedder.embed_texts( - list(request.texts), - task=request.task or "retrieval.passage", - batch_size=batch_size, - ) - ) - except Exception as e: - logger.exception("Embed failed") - await context.abort(grpc.StatusCode.INTERNAL, str(e)) - duration_ms = int((time.time() - start) * 1000) - pb_embeddings = [ - embedding_pb2.Embedding(values=list(map(float, e))) for e in embeddings - ] - dim = ( - int(embeddings.shape[1]) - if len(embeddings.shape) > 1 - else int(len(embeddings[0])) - ) - return embedding_pb2.EmbedResponse( - embeddings=pb_embeddings, - model=state.model_name, - dimension=dim, - duration_ms=duration_ms, - ) - finally: - state.inflight_requests -= 1 - - async def EmbedLateChunked( - self, - request: embedding_pb2.EmbedLateChunkedRequest, - context: grpc.aio.ServicerContext, - ) -> embedding_pb2.EmbedLateChunkedResponse: - if not request.text: - await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "empty text") - state.last_request_time = time.time() - - # See Embed() for why admission/bump lives inside _ensure_loaded. - embedder = await _ensure_loaded(context) - try: - start = time.time() - try: - chunks = await _run_tracked( - lambda: embedder.embed_with_late_chunking( - request.text, - task=request.task or "retrieval.passage", - ) - ) - except Exception as e: - logger.exception("EmbedLateChunked failed") - await context.abort(grpc.StatusCode.INTERNAL, str(e)) - duration_ms = int((time.time() - start) * 1000) - if not chunks: - return embedding_pb2.EmbedLateChunkedResponse( - chunks=[], - model=state.model_name, - dimension=0, - duration_ms=duration_ms, - context_window_used=0, - ) - pb_chunks = [ - embedding_pb2.LateChunk( - text=c.text, - embedding=embedding_pb2.Embedding(values=list(map(float, c.embedding))), - start_char=int(c.start_char), - end_char=int(c.end_char), - start_token=int(c.start_token), - end_token=int(c.end_token), - chunk_index=int(c.chunk_index), - total_chunks=int(c.total_chunks), - ) - for c in chunks - ] - dim = int(len(chunks[0].embedding)) - ctx_used = int(chunks[0].context_window_used) - return embedding_pb2.EmbedLateChunkedResponse( - chunks=pb_chunks, - model=state.model_name, - dimension=dim, - duration_ms=duration_ms, - context_window_used=ctx_used, - ) - finally: - state.inflight_requests -= 1 - - async def Info( - self, - request: embedding_pb2.InfoRequest, - context: grpc.aio.ServicerContext, - ) -> embedding_pb2.InfoResponse: - uptime = time.time() - state.start_time if state.start_time > 0 else 0.0 - return embedding_pb2.InfoResponse( - model_name=state.model_name, - dimension=2048, - max_seq_length=32768, - supported_tasks=[ - "retrieval.passage", - "retrieval.query", - "code", - "text-matching", - ], - device=state.device, - model_loaded=state.model_loaded, - uptime_seconds=uptime, - weight_revision=state.weight_revision, - ) - - async def GracefulShutdown( - self, - request: embedding_pb2.GracefulShutdownRequest, - context: grpc.aio.ServicerContext, - ) -> embedding_pb2.GracefulShutdownResponse: - timeout_ms = int(request.drain_timeout_ms) - # Flip the drain gate BEFORE we start counting. Any Embed or - # EmbedLateChunked call arriving from now on gets UNAVAILABLE in - # _ensure_loaded, so the inflight count is bounded by what was - # already admitted when we got here. 
- state.accepting_requests = False - logger.info( - f"GracefulShutdown requested (" - f"requests={state.inflight_requests}, " - f"workers={state.inflight_workers}, " - f"drain_timeout_ms={timeout_ms})" - ) - drained = True - drain_start = time.time() - # Poll BOTH counters — `inflight_requests` covers the RPC handler - # (including response marshaling after the worker returns), and - # `inflight_workers` covers the cancellation gap where the handler's - # finally has fired but the GPU thread is still running. Draining on - # just one can SIGTERM live work. - while _is_draining(): - elapsed_ms = int((time.time() - drain_start) * 1000) - if timeout_ms > 0 and elapsed_ms >= timeout_ms: - drained = False - break - await asyncio.sleep(0.05) - elapsed_ms = int((time.time() - drain_start) * 1000) - if state.shutdown_event is not None: - state.shutdown_event.set() - asyncio.get_running_loop().call_later( - 0.2, lambda: os.kill(os.getpid(), signal.SIGTERM) - ) - return embedding_pb2.GracefulShutdownResponse( - drained=drained, - message=f"drain={'clean' if drained else 'timeout'}, elapsed_ms={elapsed_ms}", - ) - - -def get_socket_path() -> str: - # Resolved the same way as every other embedder config field: env var - # wins, then server.toml [embedder].socket, then built-in default. - # Single-source-of-truth with the Rust side. - return get_embedder_service_config()["socket"] - - -def ensure_socket_dir(socket_path: str) -> None: - d = Path(socket_path).parent - if not d.exists(): - logger.info(f"Creating socket directory: {d}") - d.mkdir(parents=True, mode=0o755) - - -def cleanup_stale_socket(socket_path: str) -> None: - """Remove the socket file ONLY if nothing is listening on it. - - Unlinking a live socket lets us bind a second server on the same - address while the first keeps running invisibly — two processes - racing on the same RPC surface. Probe with a short-timeout connect - and bail out if something answers. - """ - p = Path(socket_path) - if not p.exists(): - return - try: - if not stat.S_ISSOCK(p.stat().st_mode): - logger.warning( - f"Path at {socket_path} exists but is not a socket; leaving untouched" - ) - return - except OSError as e: - logger.warning(f"Could not stat {socket_path}: {e}; leaving untouched") - return - - probe = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - probe.settimeout(0.5) - try: - probe.connect(socket_path) - except (ConnectionRefusedError, FileNotFoundError): - # No one is listening — stale. Safe to remove. - pass - except OSError as e: - logger.warning( - f"Unexpected probe error on {socket_path}: {e}; refusing to unlink" - ) - return - else: - probe.close() - raise RuntimeError( - f"Another embedder is already bound to {socket_path}; " - "refusing to unlink a live socket. Stop the existing service first." - ) - finally: - probe.close() - - logger.info(f"Removing stale socket: {socket_path}") - p.unlink() - - -def _sd_notify(message: str) -> None: - """Send a notification message to systemd via NOTIFY_SOCKET. - - Implements the sd_notify protocol directly (no python-systemd dep). - Silently no-ops when NOTIFY_SOCKET is unset — i.e. when the service is - run outside systemd (manual invocation, test). With Type=notify in the - unit file, systemd blocks dependent units until "READY=1" arrives here; - forgetting the call would cause weaver-infer.service to hit its default - 90s TimeoutStartSec and fail with "start operation timed out". 
- """ - path = os.environ.get("NOTIFY_SOCKET") - if not path: - return - if path[0] == "@": - path = "\0" + path[1:] - try: - with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as s: - s.connect(path) - s.sendall(message.encode("utf-8")) - except OSError as e: - logger.warning(f"sd_notify({message!r}) failed: {e}") - - -async def _serve() -> None: - state.start_time = time.time() - state.last_request_time = time.time() - state.shutdown_event = asyncio.Event() - state.lifecycle_lock = asyncio.Lock() - - config = get_embedder_service_config() - state.idle_timeout_seconds = float(config.get("idle_timeout", 300.0)) - - model_loaded = False - try: - await _load_model() - model_loaded = True - except Exception as e: - logger.error(f"Failed to load model at startup: {e}") - state.last_error = str(e) - - monitor_task = asyncio.create_task(_idle_monitor()) - - server = grpc.aio.server() - embedding_pb2_grpc.add_EmbeddingServiceServicer_to_server(EmbeddingServicer(), server) - socket_path = get_socket_path() - ensure_socket_dir(socket_path) - cleanup_stale_socket(socket_path) - server.add_insecure_port(f"unix://{socket_path}") - os.umask(0o077) - - loop = asyncio.get_running_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - loop.add_signal_handler(sig, lambda: state.shutdown_event.set()) - - await server.start() - # 0660 is retained from the weaver-embedder/weaver-admin split era — - # both services now run as weaver-admin, so 0600 would suffice, but - # 0660 keeps the door open for a future agent-ownership model where - # peer agents want to share the embedder via group membership. - try: - os.chmod(socket_path, 0o660) - except OSError as e: - logger.warning(f"Could not chmod socket {socket_path}: {e}") - logger.info(f"gRPC embedder service listening on unix://{socket_path}") - - # Under Type=notify, this is what releases weaver-infer.service's - # After=/Requires= hold. Fire it AFTER chmod — not just after server - # start — so a dependent unit probing the socket never races against - # the 0077 default mode. Only signal READY if the model actually loaded; - # otherwise systemd will time out and surface the real failure. - if model_loaded: - _sd_notify("READY=1") - else: - logger.error("Model failed to load; not signaling READY=1 to systemd") - - try: - await state.shutdown_event.wait() - finally: - logger.info("Stopping server (grace=5s)") - monitor_task.cancel() - try: - await monitor_task - except asyncio.CancelledError: - pass - await server.stop(grace=5) - await _unload_model() - cleanup_stale_socket(socket_path) - logger.info("Embedder service stopped") - - -def main() -> None: - try: - asyncio.run(_serve()) - except KeyboardInterrupt: - pass - - -if __name__ == "__main__": - main() diff --git a/services/embedder/core/services/embedder_service.py b/services/embedder/core/services/embedder_service.py deleted file mode 100644 index 2b1aaa7f..00000000 --- a/services/embedder/core/services/embedder_service.py +++ /dev/null @@ -1,450 +0,0 @@ -"""Weaver Embedder Service Daemon — LEGACY HTTP/FastAPI implementation. - -DEPRECATED (2026-04-20 embedder-gate pivot): the supported wire protocol is -now gRPC, implemented in `core.services.embedder_grpc`, matching the Rust -EmbeddingClient in crates/weaver-database/src/persephone/embedding.rs. - -This module is retained for emergency rollback only. Do NOT start it alongside -the gRPC service — both bind to /run/weaver/embedder.sock and will conflict. -New development should target `embedder_grpc`. 
- -Run (legacy, HTTP over UDS at /run/weaver/embedder.sock): - python -m core.services.embedder_service -""" - -from __future__ import annotations - -import asyncio -import logging -import os -import secrets -import signal -import sys -import time -from contextlib import asynccontextmanager -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -import numpy as np -import uvicorn -from fastapi import FastAPI, HTTPException -from pydantic import BaseModel, Field - -from core.cli.config import get_embedder_service_config - -# Configure logging before imports that might log -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - - -# ============================================================================= -# Request/Response Models -# ============================================================================= - - -class EmbedRequest(BaseModel): - """Request to embed texts.""" - - texts: list[str] = Field(..., description="List of texts to embed") - task: str = Field( - default="retrieval.passage", - description="Task type: retrieval.passage, retrieval.query, text-matching", - ) - batch_size: int | None = Field(default=None, description="Batch size (uses service default if not specified)") - - -class EmbedResponse(BaseModel): - """Response containing embeddings.""" - - embeddings: list[list[float]] = Field(..., description="List of embedding vectors") - model: str = Field(..., description="Model name used") - dimension: int = Field(..., description="Embedding dimension") - count: int = Field(..., description="Number of embeddings returned") - duration_ms: float = Field(..., description="Processing time in milliseconds") - - -class HealthResponse(BaseModel): - """Health check response.""" - - status: str = Field(..., description="Service status: ready, loading, idle, error") - model_loaded: bool = Field(..., description="Whether model is loaded in memory") - device: str = Field(..., description="Device model is running on") - model_name: str = Field(..., description="Name of loaded model") - uptime_seconds: float = Field(..., description="Service uptime in seconds") - idle_timeout_seconds: float = Field(..., description="Idle timeout before model unload") - model_idle_seconds: float | None = Field(None, description="Seconds since model was unloaded due to idle") - - -class ShutdownResponse(BaseModel): - """Shutdown response.""" - - message: str = Field(..., description="Shutdown message") - - -# ============================================================================= -# Service State -# ============================================================================= - - -@dataclass -class ServiceState: - """Global service state.""" - - embedder: Any | None = None # JinaV4Embedder instance - model_loaded: bool = False - device: str = "unknown" - model_name: str = "unknown" - start_time: float = 0.0 - last_request_time: float = 0.0 - idle_timeout_seconds: float = 0.0 # 0 = never unload; overridden by config - last_error: str | None = None # Most recent load failure message - shutdown_event: asyncio.Event | None = None - lifecycle_lock: asyncio.Lock | None = None # Guards model load/unload vs in-flight requests - inflight_requests: int = 0 # Number of requests currently using the embedder - - -state = ServiceState() - - -# ============================================================================= -# Lifespan Management -# 
============================================================================= - - -async def unload_model() -> None: - """Unload the model from GPU memory. - - Caller must hold state.lifecycle_lock when in-flight requests are possible. - """ - if state.embedder is not None: - del state.embedder - state.embedder = None - state.model_loaded = False - try: - import gc - - import torch - - if torch.cuda.is_available(): - gc.collect() # Break reference cycles before releasing CUDA cache - torch.cuda.empty_cache() - logger.info("CUDA cache cleared") - except Exception as e: - logger.warning(f"Could not clear CUDA cache: {e}") - logger.info("Model unloaded, GPU memory freed") - - -async def idle_monitor() -> None: - """Background task to unload model after idle timeout.""" - while True: - await asyncio.sleep(60) - if ( - state.model_loaded - and state.last_request_time > 0 - and state.idle_timeout_seconds > 0 - and (time.time() - state.last_request_time) > state.idle_timeout_seconds - ): - assert state.lifecycle_lock is not None - async with state.lifecycle_lock: - # Re-check after acquiring lock — a request may have started - if not state.model_loaded or state.inflight_requests > 0: - continue - logger.info( - f"Model idle for {state.idle_timeout_seconds}s. " - "Unloading to free GPU memory..." - ) - await unload_model() - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Manage service lifespan - load model on startup, cleanup on shutdown.""" - state.start_time = time.time() - state.last_request_time = time.time() - state.shutdown_event = asyncio.Event() - state.lifecycle_lock = asyncio.Lock() - - # Load idle timeout from config (YAML with env var override) - config = get_embedder_service_config() - state.idle_timeout_seconds = float(config["idle_timeout"]) - - # Load model on startup - logger.info("Starting HADES Embedding Service...") - try: - await load_model() - except Exception as e: - logger.error(f"Failed to load model: {e}") - # Continue running so health check can report error state - state.model_loaded = False - state.last_error = str(e) - - monitor_task = asyncio.create_task(idle_monitor()) - - yield - - # Cleanup on shutdown - logger.info("Shutting down HADES Embedding Service...") - monitor_task.cancel() - try: - await monitor_task - except asyncio.CancelledError: - pass - await unload_model() - logger.info("Embedding service stopped") - - -async def load_model() -> None: - """Load the embedding model into GPU memory.""" - from core.embedders.embedders_jina import JinaV4Embedder - - # Load config from YAML with env var overrides - # Priority: env vars > yaml config > defaults - config = get_embedder_service_config() - device = config["device"] - use_fp16 = config["use_fp16"] - batch_size = config["batch_size"] - model_name = config["model_name"] - - logger.info(f"Loading model: {model_name}") - logger.info(f"Device: {device}, FP16: {use_fp16}, Batch size: {batch_size}") - - # Load model (this is the slow part - 3-5 seconds) - start = time.time() - state.embedder = JinaV4Embedder( - { - "device": device, - "use_fp16": use_fp16, - "batch_size": batch_size, - "model_name": model_name, - } - ) - load_time = time.time() - start - - state.model_loaded = True - state.last_error = None - state.device = state.embedder.device # Actual runtime device (may differ from config on CPU fallback) - state.model_name = model_name - - logger.info(f"Model loaded in {load_time:.2f}s - ready to serve requests") - - -# 
============================================================================= -# FastAPI Application -# ============================================================================= - - -app = FastAPI( - title="HADES Embedding Service", - description="Persistent embedding service with GPU-accelerated Jina V4", - version="1.0.0", - lifespan=lifespan, -) - - -@app.get("/health", response_model=HealthResponse) -async def health() -> HealthResponse: - """Health check endpoint.""" - uptime = time.time() - state.start_time if state.start_time > 0 else 0 - - # Determine idle seconds (time since model was unloaded) - model_idle_seconds = None - if state.model_loaded: - status = "ready" - elif state.last_error is not None: - status = "error" - elif state.embedder is None and state.last_request_time > 0 and uptime >= 60: - # Model was unloaded due to idle - status = "idle" - model_idle_seconds = time.time() - state.last_request_time - elif state.embedder is None and uptime < 60: - status = "loading" - else: - status = "error" - - return HealthResponse( - status=status, - model_loaded=state.model_loaded, - device=state.device, - model_name=state.model_name, - uptime_seconds=uptime, - idle_timeout_seconds=state.idle_timeout_seconds, - model_idle_seconds=round(model_idle_seconds, 1) if model_idle_seconds is not None else None, - ) - - -@app.post("/embed", response_model=EmbedResponse) -async def embed(request: EmbedRequest) -> EmbedResponse: - """Embed texts using the loaded model.""" - # Validate request before touching state or triggering model reload. - if not request.texts: - raise HTTPException(status_code=400, detail="No texts provided") - if request.batch_size is not None and request.batch_size <= 0: - raise HTTPException(status_code=400, detail="batch_size must be a positive integer") - - # Reset idle timer after validation so invalid requests can't - # reset the idle timer or force a model reload. - state.last_request_time = time.time() - - assert state.lifecycle_lock is not None - async with state.lifecycle_lock: - # Reload model if it was unloaded due to idle - if not state.model_loaded or state.embedder is None: - logger.info("Model unloaded (idle). Reloading for incoming request...") - try: - await load_model() - except Exception as e: - state.last_error = str(e) - raise HTTPException( - status_code=503, - detail=f"Failed to reload model: {e}", - ) from e - - # Capture reference and bump in-flight count while holding the lock - # so idle_monitor cannot unload between our check and the executor call. 
- embedder = state.embedder - state.inflight_requests += 1 - - start = time.time() - - try: - # Run embedding in thread pool to avoid blocking event loop - embeddings: np.ndarray = await asyncio.get_running_loop().run_in_executor( - None, - lambda: embedder.embed_texts( - request.texts, - task=request.task, - batch_size=request.batch_size, - ), - ) - - duration_ms = (time.time() - start) * 1000 - - return EmbedResponse( - embeddings=[e.tolist() for e in embeddings], - model=state.model_name, - dimension=embeddings.shape[1] if len(embeddings.shape) > 1 else len(embeddings[0]), - count=len(embeddings), - duration_ms=round(duration_ms, 2), - ) - - except Exception as e: - logger.error(f"Embedding failed: {e}") - raise HTTPException(status_code=500, detail=str(e)) from e - - finally: - state.inflight_requests -= 1 - - -class ShutdownRequest(BaseModel): - """Shutdown request with optional token.""" - - token: str | None = Field(default=None, description="Shutdown token for authentication") - - -# Shutdown token from environment (if set, token is required) -SHUTDOWN_TOKEN = os.environ.get("WEAVER_EMBEDDER_SHUTDOWN_TOKEN") - - -@app.post("/shutdown", response_model=ShutdownResponse) -async def shutdown(request: ShutdownRequest | None = None) -> ShutdownResponse: - """Gracefully shutdown the service. - - If WEAVER_EMBEDDER_SHUTDOWN_TOKEN is set, a matching token must be provided. - """ - # Validate token if configured - if SHUTDOWN_TOKEN: - provided_token = request.token if request else None - if not provided_token or not secrets.compare_digest(provided_token, SHUTDOWN_TOKEN): - raise HTTPException( - status_code=403, - detail="Invalid or missing shutdown token", - ) - - logger.info("Shutdown requested via API") - - if state.shutdown_event: - state.shutdown_event.set() - - # Schedule shutdown after response is sent - asyncio.get_running_loop().call_later(0.5, lambda: os.kill(os.getpid(), signal.SIGTERM)) - - return ShutdownResponse(message="Shutdown initiated") - - -# ============================================================================= -# Server Configuration -# ============================================================================= - - -def get_socket_path() -> str: - """Get the Unix socket path from service config (env > server.toml > default).""" - return get_embedder_service_config()["socket"] - - -def ensure_socket_dir(socket_path: str) -> None: - """Ensure the socket directory exists with proper permissions.""" - socket_dir = Path(socket_path).parent - if not socket_dir.exists(): - logger.info(f"Creating socket directory: {socket_dir}") - socket_dir.mkdir(parents=True, mode=0o755) - - -def cleanup_stale_socket(socket_path: str) -> None: - """Remove stale socket file if it exists.""" - path = Path(socket_path) - if path.exists(): - logger.info(f"Removing stale socket: {socket_path}") - path.unlink() - - -def run_server() -> None: - """Run the embedding service.""" - socket_path = get_socket_path() - - # Ensure directory exists and clean up stale socket - ensure_socket_dir(socket_path) - cleanup_stale_socket(socket_path) - - logger.info(f"Starting server on Unix socket: {socket_path}") - - # Configure uvicorn - config = uvicorn.Config( - app, - uds=socket_path, - log_level="info", - access_log=True, - # Limit workers to 1 to avoid loading model multiple times - workers=1, - ) - - server = uvicorn.Server(config) - - # Note: uvicorn handles SIGTERM/SIGINT internally for graceful shutdown - # No custom signal handlers needed - - try: - server.run() - finally: - # Ensure socket is 
cleaned up - cleanup_stale_socket(socket_path) - - -# ============================================================================= -# Entry Point -# ============================================================================= - - -if __name__ == "__main__": - # Allow running with optional port for development/testing - if len(sys.argv) > 1 and sys.argv[1] == "--http": - # HTTP mode for testing (not recommended for production) - port = int(sys.argv[2]) if len(sys.argv) > 2 else 8000 - logger.info(f"Running in HTTP mode on port {port} (for testing only)") - uvicorn.run(app, host="127.0.0.1", port=port) - else: - # Default: Unix socket mode - run_server() diff --git a/services/embedder/requirements.txt b/services/embedder/requirements.txt deleted file mode 100644 index 5c6108b8..00000000 --- a/services/embedder/requirements.txt +++ /dev/null @@ -1,33 +0,0 @@ -# Minimal dependencies for the Weaver embedder service. -# GPU inference: torch, transformers, pillow, peft -# Wire protocol: grpcio (Unix socket gRPC server matching -# proto/persephone/embedding/embedding.proto) -# Config: pyyaml -# Numerics: numpy -numpy -torch -torchvision -transformers>=4.52,<5 -pillow -peft -# Flash Attention 2 — O(N) attention memory, required to fit Jina V4 within 16 GiB. -# Build requires CUDA toolkit matching torch's CUDA version. Install with -# --no-build-isolation and optionally TORCH_CUDA_ARCH_LIST limited to your GPU arch. -flash-attn>=2.8,<3 -grpcio>=1.80 -grpcio-tools>=1.80 -# 6.33.5 patches GHSA-7gcm-g887-7qv7 / CVE-2026-0994 (protobuf parser OOM -# on untrusted input). Floor must stay at-or-above 6.33.5 — do not lower. -protobuf>=6.33.5 -pyyaml -# Legacy rollback path. core/services/embedder_service.py is the pre-pivot -# FastAPI embedder and still imports these; keeping it importable means -# operators can revert the unit to the old ExecStart line if the gRPC -# path hits a production bug. Remove these once embedder_service.py is -# deleted outright (tracked in the embedder gate-exit PR follow-up). -fastapi -uvicorn[standard] -pydantic -# Dev/test-only -pytest -pytest-asyncio diff --git a/services/embedder/tests/__init__.py b/services/embedder/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/services/embedder/tests/conftest.py b/services/embedder/tests/conftest.py deleted file mode 100644 index 3713e864..00000000 --- a/services/embedder/tests/conftest.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Pytest config for the embedder service tests. - -Puts the service root on sys.path so tests can import `core.*` without -running from that directory. -""" - -from __future__ import annotations - -import sys -from pathlib import Path - -SERVICE_ROOT = Path(__file__).resolve().parent.parent -if str(SERVICE_ROOT) not in sys.path: - sys.path.insert(0, str(SERVICE_ROOT)) diff --git a/services/embedder/tests/test_embedder_grpc.py b/services/embedder/tests/test_embedder_grpc.py deleted file mode 100644 index 05420b77..00000000 --- a/services/embedder/tests/test_embedder_grpc.py +++ /dev/null @@ -1,297 +0,0 @@ -"""End-to-end gRPC tests for the Weaver embedder service. - -Runs the real grpc.aio server over a temp Unix socket with a stub embedder -(so no GPU is required) and exercises all four RPCs: - * Embed - * EmbedLateChunked - * Info - * GracefulShutdown (drain semantics) - -Tests the wire contract that crates/weaver-database/src/persephone/embedding.rs -actually talks to. If any of these break, the Rust client is also broken. 
-""" - -from __future__ import annotations - -import asyncio -import hashlib -import os -import time -from dataclasses import dataclass -from pathlib import Path - -import grpc -import numpy as np -import pytest -import pytest_asyncio - -from core.proto.persephone.embedding import embedding_pb2, embedding_pb2_grpc -from core.services import embedder_grpc - - -# --------------------------------------------------------------------------- -# Stub embedder — no GPU required -# --------------------------------------------------------------------------- - - -@dataclass -class _StubChunk: - text: str - embedding: np.ndarray - start_char: int - end_char: int - start_token: int - end_token: int - chunk_index: int - total_chunks: int - context_window_used: int - - -class StubEmbedder: - """Deterministic embedder for wire-contract tests.""" - - device = "cpu" - EMBEDDING_DIM = 2048 - - def __init__(self, _config): - pass - - def embed_texts(self, texts, task="retrieval.passage", batch_size=None): - # Deterministic: each embedding is texts[i]-length repeated / pad. - # hashlib (not builtin hash()) — the latter is PYTHONHASHSEED-randomized - # per process, which would make the stub non-reproducible across runs. - out = np.zeros((len(texts), 2048), dtype=np.float32) - task_digest = int(hashlib.sha256(task.encode("utf-8")).hexdigest()[:8], 16) - for i, t in enumerate(texts): - out[i, 0] = float(len(t)) - out[i, 1] = float(task_digest % 1000) / 1000.0 - return out - - def embed_with_late_chunking(self, text, task="retrieval.passage"): - # Split into 3 equal chunks for deterministic span math. - n = max(len(text), 3) - step = max(n // 3, 1) - chunks = [] - for i in range(3): - start = i * step - end = min((i + 1) * step, len(text)) if i < 2 else len(text) - emb = np.zeros(2048, dtype=np.float32) - emb[0] = float(end - start) - emb[1] = float(i) - chunks.append( - _StubChunk( - text=text[start:end], - embedding=emb, - start_char=start, - end_char=end, - start_token=start // 4, - end_token=end // 4, - chunk_index=i, - total_chunks=3, - context_window_used=len(text) // 4, - ) - ) - return chunks - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def _stub_embedder(monkeypatch): - # Patch the import target used inside _load_model. - import core.embedders.embedders_jina as jmod - - monkeypatch.setattr(jmod, "JinaV4Embedder", StubEmbedder) - # Also short-circuit get_embedder_service_config so it doesn't require - # a YAML file or env vars beyond what we set here. - import core.cli.config as cfg - - def fake_config(path=None): - return { - "device": "cpu", - "use_fp16": False, - "batch_size": 8, - "model_name": "stub-model", - "idle_timeout": 0, - "socket": "/run/weaver/embedder.sock", - } - - monkeypatch.setattr(cfg, "get_embedder_service_config", fake_config) - monkeypatch.setattr(embedder_grpc, "get_embedder_service_config", fake_config) - - -@pytest_asyncio.fixture -async def grpc_server(tmp_path, monkeypatch, _stub_embedder): - socket_path = tmp_path / "embedder.sock" - # Reset global state — the service module uses a module-level ServiceState. 
- embedder_grpc.state = embedder_grpc.ServiceState() - embedder_grpc.state.start_time = time.time() - embedder_grpc.state.last_request_time = time.time() - embedder_grpc.state.shutdown_event = asyncio.Event() - embedder_grpc.state.lifecycle_lock = asyncio.Lock() - embedder_grpc.state.idle_timeout_seconds = 0 - - # monkeypatch restores the prior value automatically on fixture teardown, - # so a crashing test can't leak WEAVER_EMBEDDER_SOCKET into the next one. - monkeypatch.setenv("WEAVER_EMBEDDER_SOCKET", str(socket_path)) - - await embedder_grpc._load_model() - - server = grpc.aio.server() - embedding_pb2_grpc.add_EmbeddingServiceServicer_to_server( - embedder_grpc.EmbeddingServicer(), server - ) - server.add_insecure_port(f"unix://{socket_path}") - await server.start() - # Assert the socket is reachable before yielding — silent timeout here - # would surface later as a misleading connection-refused in the test body. - for _ in range(40): - if Path(socket_path).exists(): - break - await asyncio.sleep(0.05) - else: - raise RuntimeError( - f"grpc server did not publish socket {socket_path} within 2s" - ) - - try: - yield socket_path - finally: - await server.stop(grace=1) - await embedder_grpc._unload_model() - - -async def _client(socket_path): - channel = grpc.aio.insecure_channel(f"unix://{socket_path}") - return embedding_pb2_grpc.EmbeddingServiceStub(channel), channel - - -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_info_reports_model_metadata(grpc_server): - stub, channel = await _client(grpc_server) - try: - info = await stub.Info(embedding_pb2.InfoRequest()) - assert info.model_loaded is True - assert info.model_name == "stub-model" - assert info.dimension == 2048 - assert info.max_seq_length == 32768 - assert "retrieval.passage" in info.supported_tasks - assert info.uptime_seconds >= 0.0 - finally: - await channel.close() - - -@pytest.mark.asyncio -async def test_embed_returns_batch(grpc_server): - stub, channel = await _client(grpc_server) - try: - req = embedding_pb2.EmbedRequest( - texts=["hello world", "a longer passage to embed"], - task="retrieval.passage", - ) - resp = await stub.Embed(req) - assert len(resp.embeddings) == 2 - assert resp.dimension == 2048 - # Stub encodes len(text) in position 0. - assert resp.embeddings[0].values[0] == pytest.approx(len("hello world")) - assert resp.embeddings[1].values[0] == pytest.approx(len("a longer passage to embed")) - finally: - await channel.close() - - -@pytest.mark.asyncio -async def test_embed_rejects_empty(grpc_server): - stub, channel = await _client(grpc_server) - try: - with pytest.raises(grpc.aio.AioRpcError) as excinfo: - await stub.Embed(embedding_pb2.EmbedRequest(texts=[], task="retrieval.passage")) - assert excinfo.value.code() == grpc.StatusCode.INVALID_ARGUMENT - finally: - await channel.close() - - -@pytest.mark.asyncio -async def test_late_chunked_returns_spans(grpc_server): - stub, channel = await _client(grpc_server) - try: - doc = "the quick brown fox jumps over the lazy dog " * 6 - resp = await stub.EmbedLateChunked( - embedding_pb2.EmbedLateChunkedRequest(text=doc, task="retrieval.passage") - ) - assert len(resp.chunks) == 3 - # Spans cover the document, in order. 
- assert resp.chunks[0].start_char == 0 - assert resp.chunks[-1].end_char == len(doc) - assert resp.chunks[0].chunk_index == 0 - assert resp.chunks[0].total_chunks == 3 - assert resp.dimension == 2048 - assert resp.chunks[0].embedding.values[0] > 0 - finally: - await channel.close() - - -@pytest.mark.asyncio -async def test_late_chunked_rejects_empty(grpc_server): - stub, channel = await _client(grpc_server) - try: - with pytest.raises(grpc.aio.AioRpcError) as excinfo: - await stub.EmbedLateChunked( - embedding_pb2.EmbedLateChunkedRequest(text="", task="retrieval.passage") - ) - assert excinfo.value.code() == grpc.StatusCode.INVALID_ARGUMENT - finally: - await channel.close() - - -@pytest.mark.asyncio -async def test_graceful_shutdown_drains_when_idle(grpc_server, monkeypatch): - # Patch os.kill to avoid actually signalling the test process. - calls: list[tuple] = [] - monkeypatch.setattr(embedder_grpc.os, "kill", lambda pid, sig: calls.append((pid, sig))) - - stub, channel = await _client(grpc_server) - try: - resp = await stub.GracefulShutdown( - embedding_pb2.GracefulShutdownRequest(drain_timeout_ms=500) - ) - assert resp.drained is True - assert "clean" in resp.message - # Let the scheduled os.kill fire. - await asyncio.sleep(0.35) - assert len(calls) == 1 - finally: - await channel.close() - - -@pytest.mark.asyncio -async def test_graceful_shutdown_waits_for_workers_when_requests_idle( - grpc_server, monkeypatch -): - """The cancellation-gap scenario: the RPC handler coroutine has already - returned (inflight_requests == 0) but a worker thread is still on the - GPU (inflight_workers > 0). Drain must keep waiting — a premature - SIGTERM here is exactly what the two-counter split exists to prevent. - """ - monkeypatch.setattr(embedder_grpc.os, "kill", lambda pid, sig: None) - # Simulate a thread still running after the coroutine finished. - embedder_grpc.state.inflight_workers = 1 - - stub, channel = await _client(grpc_server) - try: - resp = await stub.GracefulShutdown( - embedding_pb2.GracefulShutdownRequest(drain_timeout_ms=200) - ) - assert resp.drained is False - assert "timeout" in resp.message - finally: - # Reset so fixture teardown doesn't hang on the lingering counter. - embedder_grpc.state.inflight_workers = 0 - await channel.close() diff --git a/services/embedder/weaver-embedder.service b/services/embedder/weaver-embedder.service deleted file mode 100644 index cdea57e6..00000000 --- a/services/embedder/weaver-embedder.service +++ /dev/null @@ -1,98 +0,0 @@ -[Unit] -Description=Weaver Embedder Service (Jina V4) -Documentation=https://github.com/toddwbucy/WeaverTools -Documentation=file:///opt/weavertools/docs/memory-and-sleep-implementation-plan-2026-04-18.md - -# Embedder is mandatory harness infrastructure (gate direction update 2026-04-20) -# — same tier as ArangoDB. The harness is the lifecycle owner; this unit is -# started/stopped alongside weaver-infer, not on boot independent of it. -# -# Ordering: embedder comes up FIRST, then weaver-infer. `weaver serve` probes -# the embedder at startup to verify the cohort-identity pin, and refuses to -# start if the pin exists but the embedder is unreachable. For that probe to -# have a chance of succeeding, the embedder must already be past `exec`d by -# the time infer starts — hence the dependency direction is expressed on -# weaver-infer.service (After=/Requires=/BindsTo=weaver-embedder.service). -# This unit just sits on network.target. -After=network.target - -# Systemctl stop/restart on weaver.target must propagate here. 
WantedBy= in -# [Install] only handles start-on-enable; PartOf= is what ties us to the -# target's stop/reload lifecycle. Without it, `systemctl stop weaver.target` -# would leave the embedder running because weaver-infer BindsTo us (pass-2 -# ordering flip), not the other way around. -PartOf=weaver.target - -[Service] -# Type=notify: the embedder calls sd_notify("READY=1") at the end of -# startup (after the Jina V4 weights are resident on GPU) so systemd -# holds weaver-infer.service's After=/Requires= chain until the embedder -# is genuinely ready to answer the cohort-identity probe. With -# Type=simple, `weaver serve` would race against an un-loaded embedder -# and hit "Embedder not reachable" on cold boot. -Type=notify - -# Runs as weaver-admin — the same identity as weaver-infer. This is a -# deliberate collapse of the previous weaver-embedder/weaver-admin split: -# the two services share a socket for every request, so separating their -# identities bought group-ACL gymnastics without a real privilege gain. -# Keeping them on one user simplifies the IPC and aligns with the -# harness-owned lifecycle (embedder is mandatory infrastructure, not a -# separately-sandboxed peer). -User=weaver-admin -Group=weaver-admin - -# Create runtime directory for socket (mode 0755; socket itself is -# chmod 0660 by the service at startup so only weaver-admin can connect). -# Preserve=yes keeps /run/weaver/ alive across a restart of either -# unit — without it, systemd would tear down the directory (and the -# peer's socket still bound inside it) when this unit stops. -RuntimeDirectory=weaver -RuntimeDirectoryMode=0755 -RuntimeDirectoryPreserve=yes - -WorkingDirectory=/opt/weavertools/services/embedder -ExecStart=/opt/weavertools/services/embedder/.venv/bin/python -m core.services.embedder_grpc - -# Keep Jina V4 weights resident on GPU for the lifetime of this unit. -# The embedder service has a configurable idle auto-unload (default 300s, -# guarded at embedder_grpc.py:171 by `state.idle_timeout_seconds > 0`). -# Under the harness model the embedder is always-on infrastructure — -# same tier as ArangoDB — so we set the threshold to 0, which short- -# circuits the auto-unload path. Re-loading Jina V4 mid-session is a -# 3–5s latency spike that every downstream caller has to eat, and -# defeats the "stays loaded until weaver serve shuts down" contract. -# Operators who genuinely need eviction can override in -# /etc/weaver/embedder.conf (loaded below, which wins over Environment=). -Environment=WEAVER_EMBEDDER_IDLE_TIMEOUT=0 - -# Shared config from /etc/weaver/embedder.conf (optional). Values here -# override the Environment= defaults above. -EnvironmentFile=-/etc/weaver/embedder.conf - -# Restart policy -Restart=on-failure -RestartSec=10 - -# Resource limits -LimitNOFILE=65536 -LimitMEMLOCK=infinity - -# Security hardening -NoNewPrivileges=true -ProtectSystem=strict -ProtectHome=read-only -PrivateTmp=true -ReadWritePaths=/run/weaver /var/cache/weaver - -# Logging -StandardOutput=journal -StandardError=journal -SyslogIdentifier=weaver-embedder - -[Install] -# Bound to the weaver.target so `systemctl start weaver.target` brings this -# up alongside weaver-infer, and `systemctl stop weaver.target` tears it down. -# Do NOT add multi-user.target — independent boot-time autostart is a -# lifecycle bug under the 2026-04-20 pivot. 
-WantedBy=weaver.target diff --git a/services/weaver-infer.service b/services/weaver-infer.service index e3e7a5f3..755c1a78 100644 --- a/services/weaver-infer.service +++ b/services/weaver-infer.service @@ -1,20 +1,12 @@ [Unit] Description=Weaver Inference Server (weaver serve) Documentation=file:///opt/weavertools/docs/memory-and-sleep-implementation-plan-2026-04-18.md -After=network.target weaver-embedder.service -# The embedder must be running before we start: `weaver serve` probes the -# embedder at startup to verify the cohort-identity pin and refuses to start -# on mismatch or when a pin exists but the embedder is unreachable. -# - Requires= pulls the embedder up when this unit is started. -# - After= orders our exec after the embedder's is scheduled. -# - BindsTo= propagates embedder failure: if the embedder dies or is stopped -# mid-session, stop us too, since serving vectors without the pinned -# embedder would silently diverge from everything already in the memory graph. -# (Earlier iterations inverted this — embedder BoundTo weaver-infer — but -# that left the probe racing against an embedder not yet started. Fixed -# 2026-04-20 alongside the pin-refuse-to-start logic.) -Requires=weaver-embedder.service -BindsTo=weaver-embedder.service +After=network.target +# Post-PR-1.J the embedder is in-process inside `weaver serve` itself; +# the external Python `weaver-embedder.service` retired alongside its +# gRPC client. The boot-time cohort-pin probe runs against the +# in-process EmbedderClient before the daemon listens, so no separate +# unit ordering is needed. PartOf=weaver.target [Service] diff --git a/services/weaver.target b/services/weaver.target index da90b313..b05aa35f 100644 --- a/services/weaver.target +++ b/services/weaver.target @@ -2,24 +2,20 @@ Description=Weaver Harness Documentation=file:///opt/weavertools/docs/memory-and-sleep-implementation-plan-2026-04-18.md # Synthetic target that owns the harness lifecycle. Starting this target -# brings up weaver-infer and the embedder together; stopping it tears -# them down in the right order via the BindsTo chain. +# brings up weaver-infer + weaver-daemon; stopping it tears them down. # # This is deliberately not WantedBy=multi-user.target — the harness is # an on-demand runtime, not a general system service. An operator starts # it explicitly with `systemctl start weaver.target`. -# Explicit Wants= is load-bearing: WantedBy= in the two services is only an +# +# Explicit Wants= is load-bearing: WantedBy= in the services is only an # enable-time symlink, not a runtime pull. Without these, starting # weaver.target directly wouldn't bring up either unit. # -# Startup ordering (embedder → infer) is enforced by the services themselves: -# weaver-infer has After=weaver-embedder.service + Requires= + BindsTo=, so -# pulling up weaver.target schedules the embedder first, then weaver-infer. -# The After= entries here just pin the target-reached event to "both -# services active", so `systemctl is-active weaver.target` reflects reality. -Wants=weaver-embedder.service +# Post-PR-1.J the external Python `weaver-embedder.service` retired — +# the embedder runs in-process inside `weaver serve`. The harness now +# has just two units to schedule. Wants=weaver-infer.service Wants=weaver-daemon.service -After=weaver-embedder.service After=weaver-infer.service After=weaver-daemon.service
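
The retired test suite above pins down the drain contract the old service enforced: GracefulShutdown may report a clean drain only once both the RPC-handler counter (`inflight_requests`) and the worker-thread counter (`inflight_workers`) are back at zero, and must time out if a worker thread is still on the GPU after the handler coroutine has returned. As a standalone illustration of that two-counter pattern — a minimal sketch reconstructed from the tests, not the deleted `embedder_grpc` implementation; `ServiceState` and `drain` below are hypothetical names — the idea looks roughly like this:

```python
# Minimal sketch of the two-counter drain pattern exercised by the deleted
# tests. Illustrative only; this is not the retired embedder_grpc code.
import asyncio
from dataclasses import dataclass


@dataclass
class ServiceState:
    # Incremented/decremented around each RPC handler coroutine.
    inflight_requests: int = 0
    # Incremented/decremented around the thread-pool work itself, so a
    # cancelled handler cannot make the service look idle while a worker
    # thread is still busy (the "cancellation gap" the tests describe).
    inflight_workers: int = 0


async def drain(state: ServiceState, timeout_ms: int, poll_s: float = 0.05) -> bool:
    """Return True once both counters reach zero, False on timeout."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_ms / 1000.0
    while loop.time() < deadline:
        if state.inflight_requests == 0 and state.inflight_workers == 0:
            return True
        await asyncio.sleep(poll_s)
    return state.inflight_requests == 0 and state.inflight_workers == 0


async def _demo() -> None:
    state = ServiceState(inflight_workers=1)  # worker still on the GPU
    assert await drain(state, timeout_ms=200) is False  # must not drain yet
    state.inflight_workers = 0
    assert await drain(state, timeout_ms=200) is True


if __name__ == "__main__":
    asyncio.run(_demo())
```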
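
The deleted unit file leaned on Type=notify so that systemd would hold the dependent units until the embedder sent READY=1 after the model weights were resident. The sd_notify datagram protocol behind that handshake is small enough to show; the helper below is a generic sketch of it under the standard NOTIFY_SOCKET convention, not code taken from the retired service:

```python
# Generic sd_notify(READY=1) sketch — the readiness handshake that
# Type=notify units depend on. Illustrative only.
import os
import socket


def sd_notify_ready() -> bool:
    """Send READY=1 to systemd via $NOTIFY_SOCKET; no-op outside systemd."""
    addr = os.environ.get("NOTIFY_SOCKET")
    if not addr:
        return False
    # A leading '@' denotes an abstract-namespace socket (leading NUL byte).
    if addr.startswith("@"):
        addr = "\0" + addr[1:]
    with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
        sock.sendto(b"READY=1", addr)
    return True


if __name__ == "__main__":
    # Called once the service can actually answer requests (e.g. after the
    # model is loaded and the socket is bound), so units ordered After= it
    # do not race against a half-started peer.
    sd_notify_ready()
```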