feat(spu): in-process EmbedderClient backend (PR-1.F) #294
Merged
```rust
//! In-process candle-backed embedder.
//!
//! `EmbedderClient` is the post-Phase-1 production embedder backend
//! per `docs/specs/weaver-spu-Spec.md` §10 PR-1.F + the
//! `embedder-oxidization-Spec.md` cutover plan. It implements the
//! backend-agnostic [`weaver_core::embedder::Embedder`] trait against
//! [`JinaV4Embedder`] (in-process candle / CUDA), replacing the
//! gRPC client to the Python `weaver-embedder.service`.
//!
//! ## Why in-process
//!
//! Per the project mantra "latency is the enemy of agency": every
//! cross-process hop (gRPC + Unix socket + Python tokenizer + Python
//! tensor marshalling) is dead cost on every memory read. The SPU is
//! treated as a unit — encoder and decoder co-resident on the same
//! GPU, sharing process and address space. If the embedder crashes,
//! the inference stack dies with it; that's the desired failure mode
//! (fail-loud, not silent fallback).
//!
//! ## Concurrency model
//!
//! [`JinaV4Embedder::encode_text`] needs `&mut self` because it
//! mutates the model's KV cache (cleared on entry and exit per
//! independent-sequence semantics). The trait surfaces `&self`, so
//! we wrap the embedder behind an `Arc<parking_lot::Mutex<…>>` and
//! hop through `spawn_blocking` for the actual forward — the GPU
//! is single-threaded anyway, so callers serialize on the lock with
//! no observable throughput cost. `parking_lot::Mutex` (not
//! `std::sync::Mutex`) so a panic in `encode_text` doesn't poison
//! the slot and brick the embedder for the rest of the daemon's
//! lifetime.
//!
//! ## Late-chunked embeddings
//!
//! `embed_late_chunked` returns
//! [`EmbeddingError::NotAvailable`] for now. The Python service
//! exposed token-level forward outputs over its proto; the in-process
//! path needs an additional surface on `JinaV4Embedder` (encode →
//! per-token vectors → `late_chunk_embeddings`) that hasn't landed
//! yet. Tracked in a follow-up; consumers that need late chunking
//! must keep `backend = "python"` until it lands.

use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

use anyhow::Context;
use async_trait::async_trait;
use candle_core::{DType, Device};
use parking_lot::Mutex;
use tracing::{debug, instrument};

use weaver_core::embedder::{
    EmbedResult, Embedder, EmbedderInfo, EmbeddingError, LateChunkedResult,
};

use crate::encoder::jina_v4::{JinaV4Embedder, Task};
use crate::encoder::matryoshka;

/// Logical model name surfaced via [`Embedder::info`] — matches what
/// the Python `weaver-embedder.service` reports, so cohort-pin
/// comparisons stay backend-agnostic.
const MODEL_NAME: &str = "jinaai/jina-embeddings-v4";

/// In-process candle-backed embedder.
///
/// One per daemon. Construction loads the snapshot's safetensors
/// (~470 MB mmap), tokenizer, and adapter weights — measured at ~1.2 s
/// on warm cache, dominated by mmap fault-in. Subsequent `embed`
/// calls run a forward pass on the GPU; ~0.16 s for "hello world" on
/// an A6000.
pub struct EmbedderClient {
    inner: Arc<Mutex<JinaV4Embedder>>,
    /// Cached identity surfaced by `info()`. Static for the lifetime
    /// of the client — the snapshot is mmapped read-only, no swap
    /// path.
    info: EmbedderInfo,
}

impl EmbedderClient {
    /// Load a Jina V4 snapshot in-process and return a ready-to-use
    /// client.
    ///
    /// `snapshot_dir` must point at an HF snapshot directory
    /// (typically under `/opt/weaver/huggingface/.../snapshots/<sha>/`).
    /// `dtype` is the in-VRAM precision; production default is `bf16`
    /// per `docs/specs/weaver-spu-Spec.md` §3. `device` is typically
    /// `Device::new_cuda(idx)` for the GPU the SPU is bound to.
    ///
    /// `weight_revision` in the resulting [`EmbedderInfo`] is derived
    /// from the snapshot directory's basename — that's the SHA in the
    /// HF cache layout. Empty string if the path has no usable
    /// basename. The cohort-pin guard reads this for identity-drift
    /// detection.
    pub fn from_snapshot(
        snapshot_dir: &Path,
        dtype: DType,
        device: &Device,
    ) -> Result<Self, EmbeddingError> {
        let embedder = JinaV4Embedder::from_snapshot(snapshot_dir, dtype, device)
            .map_err(|e| EmbeddingError::Transport(Box::new(IoError(e))))?;

        let weight_revision = snapshot_dir
            .file_name()
            .and_then(|s| s.to_str())
            .unwrap_or("")
            .to_string();

        let info = EmbedderInfo {
            model_name: MODEL_NAME.to_string(),
            model_loaded: true,
            dimension: matryoshka::FULL_DIM as u32,
            max_seq_length: embedder.max_seq_len() as u32,
            weight_revision,
        };

        Ok(Self {
            inner: Arc::new(Mutex::new(embedder)),
            info,
        })
    }

    /// Snapshot a separate Arc to the inner embedder for use across a
    /// `spawn_blocking` boundary. Cheap (Arc clone).
    fn handle(&self) -> Arc<Mutex<JinaV4Embedder>> {
        Arc::clone(&self.inner)
    }
}

#[async_trait]
impl Embedder for EmbedderClient {
    #[instrument(skip(self, texts), fields(n = texts.len(), task))]
    async fn embed(
        &self,
        texts: &[String],
        task: &str,
        _batch_size: Option<u32>,
    ) -> Result<EmbedResult, EmbeddingError> {
        let parsed_task: Task = task
            .parse()
            .map_err(|e: anyhow::Error| EmbeddingError::InvalidResponse(e.to_string()))?;

        // Clone owned inputs into the blocking task. The trait gives
        // us `&[String]`; spawn_blocking needs `'static` captures.
        let texts_owned: Vec<String> = texts.to_vec();
        let handle = self.handle();
        let dimension = self.info.dimension;
        let model_name = self.info.model_name.clone();

        let start = Instant::now();
        let embeddings: Vec<Vec<f32>> = tokio::task::spawn_blocking(move || {
            let mut emb = handle.lock();
            let mut out = Vec::with_capacity(texts_owned.len());
            for text in &texts_owned {
                // truncate_dim = None → return full 2048-d. Callers
                // that want matryoshka truncation will get a separate
                // surface (or a per-call param) when that's wired in.
                let v = emb
                    .encode_text(text, parsed_task, None)
                    .with_context(|| format!("encode_text len={}", text.len()))?;
                out.push(v);
            }
            anyhow::Ok(out)
        })
        .await
        .map_err(|e| EmbeddingError::Transport(Box::new(IoError(e.into()))))?
        .map_err(|e| EmbeddingError::Transport(Box::new(IoError(e))))?;

        let duration_ms = start.elapsed().as_millis() as u64;

        // Defence-in-depth: check we got the expected number and shape
        // of vectors before handing the batch to consumers. A bug in
        // `encode_text` that returns the wrong dim should surface as
        // an error here, not as silent dimension drift downstream.
        if embeddings.len() != texts.len() {
            return Err(EmbeddingError::InvalidResponse(format!(
                "expected {} embeddings, got {}",
                texts.len(),
                embeddings.len()
            )));
        }
        for (i, v) in embeddings.iter().enumerate() {
            if v.len() != dimension as usize {
                return Err(EmbeddingError::InvalidResponse(format!(
                    "embedding[{i}] has dim {}, expected {dimension}",
                    v.len()
                )));
            }
        }

        debug!(
            n = embeddings.len(),
            dimension,
            duration_ms,
            "in-process embed complete"
        );

        Ok(EmbedResult {
            embeddings,
            model: model_name,
            dimension,
            duration_ms,
        })
    }

    async fn embed_late_chunked(
        &self,
        _text: &str,
        _task: &str,
    ) -> Result<LateChunkedResult, EmbeddingError> {
        // The in-process path doesn't yet expose a token-level forward
        // surface to drive `late_chunk::late_chunk_embeddings`. Until
        // that lands, fail loud rather than fall back to single-vector
        // pooling — silently widening the contract would break
        // downstream consumers that index per-chunk vectors.
        Err(EmbeddingError::NotAvailable(
            "embed_late_chunked not yet implemented in the in-process backend; \
             keep `backend = \"python\"` if late chunking is required"
                .to_string(),
        ))
    }

    async fn info(&self) -> Result<EmbedderInfo, EmbeddingError> {
        Ok(self.info.clone())
    }

    /// Always true once construction succeeds. The model is mmapped
    /// for the daemon's lifetime; there's no runtime path that can
    /// transition `model_loaded` to `false`. Override of the trait
    /// default avoids the round-trip through `info()`.
    async fn health_check(&self) -> bool {
        true
    }
}

/// Thin wrapper that lets us shove `anyhow::Error` and
/// `tokio::task::JoinError` into [`EmbeddingError::Transport`]'s
/// `Box<dyn std::error::Error + Send + Sync>` slot. The trait error
/// type is deliberately abstract over transport so this crate doesn't
/// bake the concrete error tree into its public surface; the wrapper
/// preserves the `Display` chain via `anyhow`.
#[derive(Debug)]
struct IoError(anyhow::Error);

impl std::fmt::Display for IoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:#}", self.0)
    }
}

impl std::error::Error for IoError {}
```
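The concurrency model described in the module docs, many async callers funnelled through one `&mut` model behind a single lock, can be sketched in miniature with plain std threads. The toy `Model` below is hypothetical (the real code holds a `JinaV4Embedder` behind `parking_lot::Mutex` and hops through `tokio::task::spawn_blocking`); the point is that every forward serializes on the lock, so a shared mutable model stays consistent under concurrent callers:

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical stand-in for JinaV4Embedder: encoding needs &mut self.
struct Model {
    calls: usize,
}

impl Model {
    fn encode(&mut self, _text: &str) -> Vec<f32> {
        self.calls += 1;
        vec![0.0; 4] // dummy embedding
    }
}

// Spawn n callers; each forward runs under the lock, so the call
// count is exact even under contention.
fn run_parallel(n: usize) -> usize {
    let model = Arc::new(Mutex::new(Model { calls: 0 }));
    let handles: Vec<_> = (0..n)
        .map(|i| {
            let m = Arc::clone(&model);
            thread::spawn(move || {
                let mut guard = m.lock().unwrap();
                guard.encode(&format!("text {i}"));
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    let calls = model.lock().unwrap().calls;
    calls
}

fn main() {
    assert_eq!(run_parallel(8), 8);
    println!("all 8 callers serialized on the lock");
}
```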
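`from_snapshot` derives `weight_revision` from the snapshot directory's basename via `Path::file_name`. A small sketch of the same chain, showing the two behaviors the doc comment promises (SHA for an HF-style path, empty string when there is no usable basename); the paths below are illustrative only:

```rust
use std::path::Path;

// Same chain as from_snapshot: basename → &str → owned String,
// with "" as the fallback for pathological paths.
fn weight_revision(snapshot_dir: &Path) -> String {
    snapshot_dir
        .file_name()
        .and_then(|s| s.to_str())
        .unwrap_or("")
        .to_string()
}

fn main() {
    // Hypothetical HF cache path; the basename is the snapshot SHA.
    let p = Path::new("/opt/weaver/huggingface/snapshots/0123abcd");
    assert_eq!(weight_revision(p), "0123abcd");
    // A root path has no basename, so the revision is empty.
    assert_eq!(weight_revision(Path::new("/")), "");
}
```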
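The embed path pins `truncate_dim = None`, always returning the full 2048-d vector. For readers unfamiliar with the matryoshka scheme the inline comment alludes to: matryoshka-trained embeddings can be cut to a leading prefix and re-normalized to get a cheaper low-dimensional vector. A generic sketch of that consumption pattern, not the crate's `matryoshka` module (whose API is not shown in this diff):

```rust
// Truncate an embedding to its first `dim` components and L2-renormalize,
// the usual way matryoshka-trained vectors are consumed at lower dims.
fn truncate_and_renormalize(v: &[f32], dim: usize) -> Vec<f32> {
    let mut out = v[..dim.min(v.len())].to_vec();
    let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in &mut out {
            *x /= norm;
        }
    }
    out
}

fn main() {
    let full = vec![3.0, 4.0, 0.0, 0.0]; // stand-in for a 2048-d vector
    let cut = truncate_and_renormalize(&full, 2);
    assert_eq!(cut.len(), 2);
    // ||(3, 4)|| = 5, so the prefix renormalizes to (0.6, 0.8).
    assert!((cut[0] - 0.6).abs() < 1e-6 && (cut[1] - 0.8).abs() < 1e-6);
}
```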