Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions crates/weaver-spu/src/encoder/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
//! In-process candle-backed embedder.
//!
//! `EmbedderClient` is the post-Phase-1 production embedder backend
//! per `docs/specs/weaver-spu-Spec.md` §10 PR-1.F + the
//! `embedder-oxidization-Spec.md` cutover plan. It implements the
//! backend-agnostic [`weaver_core::embedder::Embedder`] trait against
//! [`JinaV4Embedder`] (in-process candle / CUDA), replacing the
//! gRPC client to the Python `weaver-embedder.service`.
//!
//! ## Why in-process
//!
//! Per the project mantra "latency is the enemy of agency": every
//! cross-process hop (gRPC + Unix socket + Python tokenizer + Python
//! tensor marshalling) is dead cost on every memory read. The SPU is
//! treated as a unit — encoder and decoder co-resident on the same
//! GPU, sharing process and address space. If the embedder crashes,
//! the inference stack dies with it; that's the desired failure mode
//! (fail-loud, not silent fallback).
//!
//! ## Concurrency model
//!
//! [`JinaV4Embedder::encode_text`] needs `&mut self` because it
//! mutates the model's KV cache (cleared on entry and exit per
//! independent-sequence semantics). The trait surfaces `&self`, so
//! we wrap the embedder behind an `Arc<parking_lot::Mutex<…>>` and
//! hop through `spawn_blocking` for the actual forward — the GPU
//! is single-threaded anyway, so callers serialize on the lock with
//! no observable throughput cost. `parking_lot::Mutex` (not
//! `std::sync::Mutex`) so a panic in `encode_text` doesn't poison
//! the slot and brick the embedder for the rest of the daemon's
//! lifetime.
//!
//! ## Late-chunked embeddings
//!
//! `embed_late_chunked` returns
//! [`EmbeddingError::NotAvailable`] for now. The Python service
//! exposed token-level forward outputs over its proto; the in-process
//! path needs an additional surface on `JinaV4Embedder` (encode →
//! per-token vectors → `late_chunk_embeddings`) that hasn't landed
//! yet. Tracked in a follow-up; consumers that need late chunking
//! must keep `backend = "python"` until it lands.

use std::path::Path;
use std::sync::Arc;
use std::time::Instant;

use anyhow::Context;
use async_trait::async_trait;
use candle_core::{DType, Device};
use parking_lot::Mutex;
use tracing::{debug, instrument};

use weaver_core::embedder::{
EmbedResult, Embedder, EmbedderInfo, EmbeddingError, LateChunkedResult,
};

use crate::encoder::jina_v4::{JinaV4Embedder, Task};
use crate::encoder::matryoshka;

/// Logical model name surfaced via [`Embedder::info`] — kept identical to
/// what the Python `weaver-embedder.service` reports, so cohort-pin
/// identity comparisons stay backend-agnostic across the Rust and Python
/// backends.
const MODEL_NAME: &str = "jinaai/jina-embeddings-v4";

/// In-process candle-backed embedder.
///
/// One per daemon. Construction loads the snapshot's safetensors
/// (~470 MB mmap), tokenizer, and adapter weights — measured at ~1.2 s
/// on warm cache, dominated by mmap fault-in. Subsequent `embed`
/// calls run a forward pass on the GPU; ~0.16 s for "hello world" on
/// an A6000.
pub struct EmbedderClient {
    /// The model behind an `Arc<parking_lot::Mutex<…>>`: `encode_text`
    /// needs `&mut self` (KV-cache mutation) while the trait surfaces
    /// `&self`, and the `Arc` lets a clone cross the `spawn_blocking`
    /// boundary. `parking_lot` so a panic can't poison the lock.
    inner: Arc<Mutex<JinaV4Embedder>>,
    /// Cached identity surfaced by `info()`. Static for the lifetime
    /// of the client — the snapshot is mmapped read-only, no swap
    /// path.
    info: EmbedderInfo,
}

impl EmbedderClient {
    /// Load a Jina V4 snapshot in-process and return a ready-to-use
    /// client.
    ///
    /// `snapshot_dir` must point at an HF snapshot directory
    /// (typically under `/opt/weaver/huggingface/.../snapshots/<sha>/`).
    /// `dtype` is the in-VRAM precision; production default is `bf16`
    /// per `docs/specs/weaver-spu-Spec.md` §3. `device` is typically
    /// `Device::new_cuda(idx)` for the GPU the SPU is bound to.
    ///
    /// `weight_revision` in the resulting [`EmbedderInfo`] is derived
    /// from the snapshot directory's basename — that's the SHA in the
    /// HF cache layout. Empty string if the path has no usable
    /// basename. The cohort-pin guard reads this for identity-drift
    /// detection.
    pub fn from_snapshot(
        snapshot_dir: &Path,
        dtype: DType,
        device: &Device,
    ) -> Result<Self, EmbeddingError> {
        // Derive the revision before the (expensive) model load; in the
        // HF cache layout the snapshot dir's basename is the weight SHA.
        let weight_revision = snapshot_dir
            .file_name()
            .and_then(|name| name.to_str())
            .map(str::to_owned)
            .unwrap_or_default();

        let embedder = match JinaV4Embedder::from_snapshot(snapshot_dir, dtype, device) {
            Ok(embedder) => embedder,
            Err(e) => return Err(EmbeddingError::Transport(Box::new(IoError(e)))),
        };

        let info = EmbedderInfo {
            model_name: MODEL_NAME.to_string(),
            model_loaded: true,
            dimension: matryoshka::FULL_DIM as u32,
            max_seq_length: embedder.max_seq_len() as u32,
            weight_revision,
        };

        Ok(Self {
            inner: Arc::new(Mutex::new(embedder)),
            info,
        })
    }

    /// Snapshot a separate Arc to the inner embedder for use across a
    /// `spawn_blocking` boundary. Cheap (Arc clone).
    fn handle(&self) -> Arc<Mutex<JinaV4Embedder>> {
        self.inner.clone()
    }
}

#[async_trait]
impl Embedder for EmbedderClient {
#[instrument(skip(self, texts), fields(n = texts.len(), task))]
async fn embed(
&self,
texts: &[String],
task: &str,
_batch_size: Option<u32>,
) -> Result<EmbedResult, EmbeddingError> {
let parsed_task: Task = task
.parse()
.map_err(|e: anyhow::Error| EmbeddingError::InvalidResponse(e.to_string()))?;

// Clone owned inputs into the blocking task. The trait gives
// us `&[String]`; spawn_blocking needs `'static` captures.
let texts_owned: Vec<String> = texts.to_vec();
let handle = self.handle();
let dimension = self.info.dimension;
let model_name = self.info.model_name.clone();

let start = Instant::now();
let embeddings: Vec<Vec<f32>> = tokio::task::spawn_blocking(move || {
let mut emb = handle.lock();
let mut out = Vec::with_capacity(texts_owned.len());
for text in &texts_owned {
// truncate_dim = None → return full 2048-d. Callers
// that want matryoshka truncation will get a separate
// surface (or a per-call param) when that's wired in.
let v = emb
.encode_text(text, parsed_task, None)
.with_context(|| format!("encode_text len={}", text.len()))?;
out.push(v);
}
anyhow::Ok(out)
})
.await
.map_err(|e| EmbeddingError::Transport(Box::new(IoError(e.into()))))?
.map_err(|e| EmbeddingError::Transport(Box::new(IoError(e))))?;
Comment thread
coderabbitai[bot] marked this conversation as resolved.

let duration_ms = start.elapsed().as_millis() as u64;

// Defence-in-depth: check we got the expected number and shape
// of vectors before handing the batch to consumers. A bug in
// `encode_text` that returns the wrong dim should surface as
// an error here, not as silent dimension drift downstream.
if embeddings.len() != texts.len() {
return Err(EmbeddingError::InvalidResponse(format!(
"expected {} embeddings, got {}",
texts.len(),
embeddings.len()
)));
}
for (i, v) in embeddings.iter().enumerate() {
if v.len() != dimension as usize {
return Err(EmbeddingError::InvalidResponse(format!(
"embedding[{i}] has dim {}, expected {dimension}",
v.len()
)));
}
}

debug!(
n = embeddings.len(),
dimension,
duration_ms,
"in-process embed complete"
);

Ok(EmbedResult {
embeddings,
model: model_name,
dimension,
duration_ms,
})
}

async fn embed_late_chunked(
&self,
_text: &str,
_task: &str,
) -> Result<LateChunkedResult, EmbeddingError> {
// The in-process path doesn't yet expose a token-level forward
// surface to drive `late_chunk::late_chunk_embeddings`. Until
// that lands, fail loud rather than fall back to single-vector
// pooling — silently widening the contract would break
// downstream consumers that index per-chunk vectors.
Err(EmbeddingError::NotAvailable(
"embed_late_chunked not yet implemented in the in-process backend; \
keep `backend = \"python\"` if late chunking is required"
.to_string(),
))
}

async fn info(&self) -> Result<EmbedderInfo, EmbeddingError> {
Ok(self.info.clone())
}

/// Always-true once construction succeeds. The model is mmapped
/// for the daemon's lifetime; there's no runtime path that can
/// transition `model_loaded` to `false`. Override of the trait
/// default avoids the round-trip through `info()`.
async fn health_check(&self) -> bool {
true
}
}

/// Thin wrapper that lets us shove `anyhow::Error` and
/// `tokio::task::JoinError` into [`EmbeddingError::Transport`]'s
/// `Box<dyn std::error::Error + Send + Sync>` slot. The trait error
/// type is deliberately abstract over transport so this crate doesn't
/// bake the concrete error tree into its public surface; the wrapper
/// preserves the `Display` chain via `anyhow`.
#[derive(Debug)]
struct IoError(anyhow::Error);

impl std::fmt::Display for IoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Alternate form (`{:#}`) renders the full anyhow context chain
        // ("outer: cause1: cause2") instead of just the top message.
        f.write_fmt(format_args!("{:#}", self.0))
    }
}

// Intentionally no `source()`: the whole chain is already flattened
// into `Display` above, so exposing causes here would double-report.
impl std::error::Error for IoError {}
7 changes: 7 additions & 0 deletions crates/weaver-spu/src/encoder/jina_v4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,13 @@ impl JinaV4Embedder {
self.model.dtype()
}

/// Architectural sequence-length ceiling read from the snapshot's
/// `config.json` (`max_position_embeddings`). Used by the
/// `EmbedderClient` to populate `EmbedderInfo::max_seq_length`.
/// Plain field getter — no recomputation, no I/O.
pub fn max_seq_len(&self) -> usize {
self.max_seq_len
}

/// Diagnostic-only entry point for the task #114 layer-by-layer
/// bisection (`crates/weaver-spu/tests/jina_v4_layerwise_bisection.rs`).
/// Takes pre-tokenized `input_ids` (so the Rust path uses the
Expand Down
5 changes: 5 additions & 0 deletions crates/weaver-spu/src/encoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ pub mod qwen25vl_loraed;
// pooling + matryoshka, single-vector path.
#[cfg(feature = "candle")]
pub mod jina_v4;
// In-process EmbedderClient — implements weaver_core::embedder::Embedder
// against JinaV4Embedder. Replaces grpc_client_legacy at the daemon
// wiring point post-Phase-1 cutover.
#[cfg(feature = "candle")]
pub mod client;

// `grpc_client_legacy` is NOT feature-gated: it's the production
// embedder backend during the migration window (talks to the Python
Expand Down
1 change: 1 addition & 0 deletions crates/weaver-spu/src/encoder/qwen25vl_loraed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ impl LoraAttention {
/// `forward_with_intermediates` so the two paths can never desync
/// (the previous duplicated implementation was a maintenance
/// hazard called out in PR #290's review).
#[allow(clippy::too_many_arguments)]
fn run_attention_post_proj(
&mut self,
query_states: Tensor,
Expand Down