Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions helix-db/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ reqwest = { version = "0.12.15", features = [
"blocking",
], optional = true }
url = { version = "2.5", optional = true }
model2vec-rs = { version = "0.1", optional = true }
tokio-util = { version = "0.7.15", features = ["compat"] }
axum = "0.8.4"
tracing = "0.1.41"
Expand Down Expand Up @@ -79,6 +80,7 @@ api-key = []
build = ["compiler"]
vectors = ["cosine", "url"]
server = ["build", "compiler", "vectors", "reqwest"]
model2vec = ["model2vec-rs"]
full = ["build", "compiler", "vectors"]
bench = ["polars"]
dev = ["debug-output", "server", "bench"]
Expand Down
158 changes: 158 additions & 0 deletions helix-db/src/helix_gateway/embedding_providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,76 @@ use sonic_rs::{JsonContainerTrait, json};
use std::env;
use url::Url;

#[cfg(feature = "model2vec")]
use model2vec_rs::model::StaticModel;

/// Embedding providers for generating text embeddings.
///
/// HelixDB supports four embedding providers:
///
/// ## OpenAI (requires `reqwest` feature)
/// - Format: `"openai:{model}"` or just `"{model}"` for default
/// - Requires: `OPENAI_API_KEY` environment variable
/// - Example: `"text-embedding-ada-002"`, `"openai:text-embedding-3-small"`
/// - Network: External API call to api.openai.com
/// - Cost: Paid per token
///
/// ## Gemini (requires `reqwest` feature)
/// - Format: `"gemini:{model}"` or `"gemini:{model}:{task_type}"`
/// - Requires: `GEMINI_API_KEY` environment variable
/// - Example: `"gemini:gemini-embedding-001"`, `"gemini:gemini-embedding-001:SEMANTIC_SIMILARITY"`
/// - Network: External API call to Google's API
/// - Cost: Paid per character
/// - Task types: `RETRIEVAL_QUERY`, `RETRIEVAL_DOCUMENT`, `SEMANTIC_SIMILARITY`, `CLASSIFICATION`, `CLUSTERING`
///
/// ## Local (requires `reqwest` feature)
/// - Format: `"local"`
/// - Requires: Local HTTP server running on `http://localhost:8699/embed`
/// - Network: HTTP call to localhost
/// - Cost: Free (self-hosted)
/// - Note: You must run your own embedding server
///
/// ## Model2Vec (requires `model2vec` feature)
/// - Format: `"model2vec:{model}"` or `"model2vec:"` for default
/// - Requires: No API key, no server
/// - Example: `"model2vec:minishlab/potion-base-8M"`, `"model2vec:minishlab/potion-base-32M"`
/// - Default: `"minishlab/potion-base-32M"` (768 dimensions)
/// - Network: Downloads model from HuggingFace on first use, then fully offline
/// - Cost: Free (in-process)
/// - Speed: <1ms inference after model load
/// - Models cached in: `~/.cache/huggingface/`
/// - Available models:
/// - `minishlab/potion-base-2M` (2MB, 256 dims, fastest)
/// - `minishlab/potion-base-8M` (8MB, 256 dims, balanced)
/// - `minishlab/potion-base-32M` (32MB, 768 dims, recommended)
/// - `minishlab/potion-retrieval-32M` (32MB, 768 dims, optimized for retrieval)
///
/// # Usage
///
/// Configure in `config.hx.json`:
/// ```json
/// {
/// "embedding_model": "model2vec:minishlab/potion-base-32M"
/// }
/// ```
///
/// Or use in HelixQL queries:
/// ```hql
/// #[model("model2vec:minishlab/potion-base-8M")]
/// QUERY search(query: String) =>
/// results <- SearchV<Document>(Embed(query), 10)
/// RETURN results
/// ```
///
/// # Feature Flags
///
/// - `server`: Enables OpenAI, Gemini, and Local providers (requires `reqwest`)
/// - `model2vec`: Enables Model2Vec provider (requires `model2vec-rs`)
///
/// Build with both:
/// ```bash
/// cargo build --features server,model2vec
/// ```
/// Parse an API error response and return a descriptive GraphError
fn parse_api_error(provider: &str, status: u16, body: &str) -> GraphError {
// Try to extract error message from JSON response
Expand Down Expand Up @@ -46,6 +116,10 @@ pub enum EmbeddingProvider {
deployment_id: String,
},
Local,
#[cfg(feature = "model2vec")]
Model2Vec {
model_name: String,
},
}

pub struct EmbeddingModelImpl {
Expand All @@ -54,6 +128,8 @@ pub struct EmbeddingModelImpl {
client: Client,
pub(crate) model: String,
pub(crate) url: Option<String>,
#[cfg(feature = "model2vec")]
pub(crate) model2vec: Option<StaticModel>,
}

impl EmbeddingModelImpl {
Expand Down Expand Up @@ -86,6 +162,8 @@ impl EmbeddingModelImpl {
Some(key)
}
EmbeddingProvider::Local => None,
#[cfg(feature = "model2vec")]
EmbeddingProvider::Model2Vec { .. } => None,
};

let url = match &provider {
Expand All @@ -97,12 +175,38 @@ impl EmbeddingModelImpl {
_ => None,
};

// Load model2vec model if using Model2Vec provider
#[cfg(feature = "model2vec")]
let model2vec = match &provider {
EmbeddingProvider::Model2Vec { model_name } => {
Some(
StaticModel::from_pretrained(
model_name, None, // No HF token needed for public models
None, // Use model's default normalization
None, // No subfolder
)
.map_err(|e| {
GraphError::from(format!(
"Failed to load model2vec model '{}': {}",
model_name, e
))
})?,
)
}
_ => None,
};

#[cfg(not(feature = "model2vec"))]
let _model2vec: Option<()> = None;

Ok(EmbeddingModelImpl {
provider,
api_key,
client: Client::new(),
model: model_name,
url,
#[cfg(feature = "model2vec")]
model2vec,
})
}

Expand Down Expand Up @@ -160,6 +264,37 @@ impl EmbeddingModelImpl {
}
Some("local") => Ok((EmbeddingProvider::Local, "local".to_string())),

// Model2Vec provider (in-process, local embedding generation)
// Format: "model2vec:{model_name}"
// Example: "model2vec:minishlab/potion-base-8M"
// Default model: "minishlab/potion-base-32M"
//
// Features:
// - No API key required
// - No network calls (after initial model download)
// - Works fully offline
// - Fast inference (<1ms after model load)
// - Models cached in ~/.cache/huggingface/
//
// Available models:
// - minishlab/potion-base-2M (2MB, 256 dims)
// - minishlab/potion-base-8M (8MB, 256 dims)
// - minishlab/potion-base-32M (32MB, 768 dims) [recommended]
// - minishlab/potion-retrieval-32M (32MB, 768 dims)
#[cfg(feature = "model2vec")]
Some(m) if m.starts_with("model2vec:") => {
let model_name = m
.strip_prefix("model2vec:")
.filter(|s| !s.is_empty())
.unwrap_or("minishlab/potion-base-32M");
Ok((
EmbeddingProvider::Model2Vec {
model_name: model_name.to_string(),
},
model_name.to_string(),
))
}

Some(_) => Ok((
EmbeddingProvider::OpenAI,
"text-embedding-ada-002".to_string(),
Expand Down Expand Up @@ -421,6 +556,28 @@ impl EmbeddingModel for EmbeddingModelImpl {

Ok(embedding)
}

#[cfg(feature = "model2vec")]
EmbeddingProvider::Model2Vec { .. } => {
let model = self
.model2vec
.as_ref()
.ok_or_else(|| GraphError::from("Model2Vec model not loaded"))?;

// Clone for blocking task (cheap Arc-based clone)
let text_owned = text.to_string();
let model_clone = model.clone();

// Run on blocking threadpool to avoid blocking async runtime
let embedding = tokio::task::spawn_blocking(move || -> Vec<f64> {
let embedding_f32 = model_clone.encode_single(&text_owned);
embedding_f32.into_iter().map(|v| v as f64).collect()
})
.await
.map_err(|e| GraphError::from(format!("Model2Vec task failed: {}", e)))?;

Ok(embedding)
}
}
}
}
Expand All @@ -447,6 +604,7 @@ pub fn get_embedding_model(
/// let query = embed!("Hello, world!");
/// let embedding = embed!("Hello, world!", "text-embedding-ada-002");
/// let embedding = embed!("Hello, world!", "gemini:gemini-embedding-001:SEMANTIC_SIMILARITY");
/// let embedding = embed!("Hello, world!", "model2vec:minishlab/potion-base-32M");
/// let embedding = embed!("Hello, world!", "text-embedding-ada-002", "http://localhost:8699/embed");
/// ```
macro_rules! embed {
Expand Down
43 changes: 43 additions & 0 deletions helix-db/src/helix_gateway/tests/embedding_providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,49 @@ fn test_parse_local_provider() {
assert_eq!(model, "local");
}

#[test]
#[cfg(feature = "model2vec")]
fn test_parse_model2vec_provider() {
let result =
EmbeddingModelImpl::parse_provider_and_model(Some("model2vec:minishlab/potion-base-8M"));
assert!(result.is_ok());
let (provider, model) = result.unwrap();
match provider {
EmbeddingProvider::Model2Vec { model_name } => {
assert_eq!(model_name, "minishlab/potion-base-8M");
}
_ => panic!("Expected Model2Vec provider"),
}
assert_eq!(model, "minishlab/potion-base-8M");
}

#[test]
#[cfg(feature = "model2vec")]
fn test_parse_model2vec_default() {
let result = EmbeddingModelImpl::parse_provider_and_model(Some("model2vec:"));
assert!(result.is_ok());
let (provider, model) = result.unwrap();
match provider {
EmbeddingProvider::Model2Vec { model_name } => {
assert_eq!(model_name, "minishlab/potion-base-32M");
}
_ => panic!("Expected Model2Vec provider"),
}
assert_eq!(model, "minishlab/potion-base-32M");
}

#[test]
#[cfg(feature = "model2vec")]
#[ignore]
fn test_model2vec_embedding() {
let model =
get_embedding_model(None, Some("model2vec:minishlab/potion-base-2M"), None).unwrap();

let embedding = model.fetch_embedding("test").unwrap();
assert!(!embedding.is_empty());
assert!(embedding.iter().all(|&v| v.is_finite()));
}

#[test]
fn test_parse_unknown_provider_defaults_to_openai() {
let result = EmbeddingModelImpl::parse_provider_and_model(Some("unknown-provider"));
Expand Down
Loading