diff --git a/package.json b/package.json index 24a0028..69a3b44 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "fastembed", - "version": "2.1.0", + "version": "3.0.0", "description": "NodeJS implementation of @Qdrant/fastembed", "keywords": [ "embeddings", diff --git a/src/dense-model-registry.ts b/src/dense-model-registry.ts new file mode 100644 index 0000000..5c1bb59 --- /dev/null +++ b/src/dense-model-registry.ts @@ -0,0 +1,74 @@ +export interface DenseModelMetadata { + repoId: string; // Actual provider repo + gcsUrl: string; // GCS fallback URL + onnxFilePath: string; // Path within HF repo: "onnx/model.onnx" + dim: number; + description: string; + requiresTokenTypeIds: boolean; +} + +export const DENSE_MODEL_REGISTRY: Record = { + "sentence-transformers/all-MiniLM-L6-v2": { + repoId: "sentence-transformers/all-MiniLM-L6-v2", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 384, + description: "Sentence Transformer model, MiniLM-L6-v2", + requiresTokenTypeIds: true, + }, + "BAAI/bge-base-en": { + repoId: "BAAI/bge-base-en", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 768, + description: "Base English model from BAAI", + requiresTokenTypeIds: true, + }, + "BAAI/bge-base-en-v1.5": { + repoId: "BAAI/bge-base-en-v1.5", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 768, + description: "v1.5 release of Base English model", + requiresTokenTypeIds: true, + }, + "BAAI/bge-small-en": { + repoId: "BAAI/bge-small-en", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 384, + description: "Small English model from BAAI", + requiresTokenTypeIds: true, + }, + "BAAI/bge-small-en-v1.5": { + repoId: 
"BAAI/bge-small-en-v1.5", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-en-v1.5.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 384, + description: "v1.5 release of small English model", + requiresTokenTypeIds: true, + }, + "BAAI/bge-small-zh-v1.5": { + repoId: "BAAI/bge-small-zh-v1.5", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/fast-bge-small-zh-v1.5.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 512, + description: "v1.5 Chinese small model", + requiresTokenTypeIds: true, + }, + "intfloat/multilingual-e5-large": { + repoId: "intfloat/multilingual-e5-large", + gcsUrl: + "https://storage.googleapis.com/qdrant-fastembed/fast-multilingual-e5-large.tar.gz", + onnxFilePath: "onnx/model.onnx", + dim: 1024, + description: "Multilingual model, e5-large", + requiresTokenTypeIds: false, + }, +}; diff --git a/src/fastembed.ts b/src/fastembed.ts index 485b59e..4775873 100644 --- a/src/fastembed.ts +++ b/src/fastembed.ts @@ -6,6 +6,7 @@ import path from "path"; import Progress from "progress"; import tar from "tar"; import { downloadFileToCacheDir } from "@huggingface/hub"; +import { DENSE_MODEL_REGISTRY } from "./dense-model-registry.js"; export enum ExecutionProvider { CPU = "cpu", @@ -16,13 +17,13 @@ export enum ExecutionProvider { } export enum EmbeddingModel { - AllMiniLML6V2 = "fast-all-MiniLM-L6-v2", - BGEBaseEN = "fast-bge-base-en", - BGEBaseENV15 = "fast-bge-base-en-v1.5", - BGESmallEN = "fast-bge-small-en", - BGESmallENV15 = "fast-bge-small-en-v1.5", - BGESmallZH = "fast-bge-small-zh-v1.5", - MLE5Large = "fast-multilingual-e5-large", + AllMiniLML6V2 = "sentence-transformers/all-MiniLM-L6-v2", + BGEBaseEN = "BAAI/bge-base-en", + BGEBaseENV15 = "BAAI/bge-base-en-v1.5", + BGESmallEN = "BAAI/bge-small-en", + BGESmallENV15 = "BAAI/bge-small-en-v1.5", + BGESmallZH = "BAAI/bge-small-zh-v1.5", + MLE5Large = "intfloat/multilingual-e5-large", CUSTOM = "custom", } @@ -111,13 +112,24 @@ export interface InitStandardOptions 
extends InitOptionsBase { modelName?: string; } -// Cas custom +// Cas custom local export interface InitCustomOptions extends InitOptionsBase { model: EmbeddingModel.CUSTOM; modelAbsoluteDirPath: fs.PathLike; modelName: string; } -export type InitOptions = InitStandardOptions | InitCustomOptions; + +// Cas custom HuggingFace repo +export interface InitCustomHFOptions extends InitOptionsBase { + model: string; // Any HuggingFace repo ID + modelAbsoluteDirPath?: undefined; + modelName?: string; +} + +export type InitOptions = + | InitStandardOptions + | InitCustomOptions + | InitCustomHFOptions; // Sparse embedding init options export interface InitSparseStandardOptions extends InitOptionsBase { @@ -178,6 +190,7 @@ export class FlagEmbedding extends Embedding { } static async init(options: InitStandardOptions): Promise; static async init(options: InitCustomOptions): Promise; + static async init(options: InitCustomHFOptions): Promise; static async init({ model = EmbeddingModel.BGESmallENV15, executionProviders = [ExecutionProvider.CPU], @@ -209,15 +222,15 @@ export class FlagEmbedding extends Embedding { ); const tokenizer = this.loadTokenizer(modelDir, maxLength); - const defaultModelName = - model === EmbeddingModel.MLE5Large || - model === EmbeddingModel.AllMiniLML6V2 - ? 
"model.onnx" - : "model_optimized.onnx"; + + // Use metadata to determine ONNX file path + const metadata = DENSE_MODEL_REGISTRY[model]; + const onnxFileName = metadata?.onnxFilePath || "onnx/model.onnx"; const modelPath = path.join( modelDir.toString(), - modelName || defaultModelName + modelName || onnxFileName ); + if (!fs.existsSync(modelPath)) { throw new Error(`Model file not found at ${modelPath}`); } @@ -225,7 +238,7 @@ export class FlagEmbedding extends Embedding { executionProviders, graphOptimizationLevel: "all", }); - return new FlagEmbedding(tokenizer, session, model); + return new FlagEmbedding(tokenizer, session, model as EmbeddingModel); } private static loadTokenizer( @@ -374,10 +387,103 @@ export class FlagEmbedding extends Embedding { } } - private static async retrieveModel( - model: EmbeddingModel, + private static async retrieveModelHuggingFace( + model: string, cacheDir: PathLike, showDownloadProgress: boolean = true + ): Promise { + // Sanitize model name for filesystem (Org/Model → Org--Model) + const modelDir = path.join( + cacheDir.toString(), + model.replace(/\//g, "--") + ); + + // Check if already cached + if (fs.existsSync(modelDir)) { + const requiredFiles = [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + ]; + const allFilesExist = requiredFiles.every((file) => + fs.existsSync(path.join(modelDir, file)) + ); + if (allFilesExist) { + if (showDownloadProgress) { + console.log(`Model ${model} found in cache`); + } + return modelDir; + } + } + + // Get model metadata + const metadata = DENSE_MODEL_REGISTRY[model]; + + try { + // Try HuggingFace first + if (!fs.existsSync(modelDir)) { + fs.mkdirSync(modelDir, { mode: 0o777, recursive: true }); + } + + const filesToDownload = [ + metadata?.onnxFilePath || "onnx/model.onnx", // e.g., "onnx/model.onnx" + "tokenizer.json", + "tokenizer_config.json", + "config.json", + "special_tokens_map.json", + ]; + + if (showDownloadProgress) { + 
console.log(`Downloading ${model} from HuggingFace...`); + } + + for (const fileName of filesToDownload) { + const outputPath = path.join(modelDir, fileName); + const outputDir = path.dirname(outputPath); + + if (!fs.existsSync(outputDir)) { + fs.mkdirSync(outputDir, { recursive: true, mode: 0o777 }); + } + + // Use HuggingFace Hub library (same as sparse embeddings) + const downloaded = await downloadFileToCacheDir({ + repo: model, + path: fileName, + }); + + if (downloaded && typeof downloaded === "string") { + fs.copyFileSync(downloaded, outputPath); + } + } + + if (showDownloadProgress) { + console.log(`Successfully downloaded ${model}`); + } + + return modelDir; + } catch (error) { + // Fallback to GCS if HuggingFace fails + if (metadata?.gcsUrl) { + console.warn(`HuggingFace download failed, falling back to GCS...`); + return await FlagEmbedding.retrieveModel( + model as EmbeddingModel, + cacheDir, + showDownloadProgress, + true // force GCS + ); + } + throw new Error( + `Failed to download ${model} from HuggingFace: ${error}. 
No GCS fallback available.` + ); + } + } + + private static async retrieveModel( + model: EmbeddingModel | string, + cacheDir: PathLike, + showDownloadProgress: boolean = true, + forceGCS: boolean = false ): Promise { if (!fs.existsSync(cacheDir)) { fs.mkdirSync(cacheDir, { @@ -385,17 +491,36 @@ export class FlagEmbedding extends Embedding { }); } - const modelDir = path.join(cacheDir.toString(), model); + // Use GCS if forced (fallback scenario) + if (forceGCS) { + const modelDir = path.join(cacheDir.toString(), model); - if (fs.existsSync(modelDir)) { + if (fs.existsSync(modelDir)) { + return modelDir; + } + + const modelTarGz = path.join(cacheDir.toString(), `${model}.tar.gz`); + await this.downloadFileFromGCS(modelTarGz, model, showDownloadProgress); + await this.decompressToCache(modelTarGz, cacheDir); + fs.unlinkSync(modelTarGz); return modelDir; } - const modelTarGz = path.join(cacheDir.toString(), `${model}.tar.gz`); - await this.downloadFileFromGCS(modelTarGz, model, showDownloadProgress); - await this.decompressToCache(modelTarGz, cacheDir); - fs.unlinkSync(modelTarGz); - return modelDir; + // Try HuggingFace first for known models + if (DENSE_MODEL_REGISTRY[model]) { + return await FlagEmbedding.retrieveModelHuggingFace( + model, + cacheDir, + showDownloadProgress + ); + } + + // For custom models not in registry, try HuggingFace anyway + return await FlagEmbedding.retrieveModelHuggingFace( + model, + cacheDir, + showDownloadProgress + ); } async *embed(textStrings: string[], batchSize: number = 256) { @@ -444,8 +569,9 @@ export class FlagEmbedding extends Embedding { token_type_ids: batchTokenTypeId, }; - // Exclude token_type_ids for MLE5Large - if (this.model === EmbeddingModel.MLE5Large) { + // Use metadata to determine if token_type_ids is needed + const metadata = DENSE_MODEL_REGISTRY[this.model]; + if (metadata && !metadata.requiresTokenTypeIds) { delete inputs.token_type_ids; } @@ -496,44 +622,11 @@ export class FlagEmbedding extends Embedding { 
} listSupportedModels(): ModelInfo[] { - return [ - { - model: EmbeddingModel.BGESmallEN, - dim: 384, - description: "Fast English model", - }, - { - model: EmbeddingModel.BGESmallENV15, - dim: 384, - description: "v1.5 release of the fast, default English model", - }, - { - model: EmbeddingModel.BGEBaseEN, - dim: 768, - description: "Base English model", - }, - { - model: EmbeddingModel.BGEBaseENV15, - dim: 768, - description: "v1.5 release of Base English model", - }, - { - model: EmbeddingModel.BGESmallZH, - dim: 512, - description: "v1.5 release of the fast, Chinese model", - }, - { - model: EmbeddingModel.AllMiniLML6V2, - dim: 384, - description: "Sentence Transformer model, MiniLM-L6-v2", - }, - { - model: EmbeddingModel.MLE5Large, - dim: 1024, - description: - "Multilingual model, e5-large. Recommend using this model for non-English languages", - }, - ]; + return Object.entries(DENSE_MODEL_REGISTRY).map(([id, metadata]) => ({ + model: id as EmbeddingModel, + dim: metadata.dim, + description: metadata.description, + })); } } diff --git a/src/index.ts b/src/index.ts index c0d2d6e..42011f8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,7 +2,13 @@ export { EmbeddingModel, ExecutionProvider, FlagEmbedding, + InitCustomHFOptions, SparseEmbeddingModel, SparseTextEmbedding, - SparseVector + SparseVector, } from "./fastembed.js"; + +export { + DENSE_MODEL_REGISTRY, + DenseModelMetadata, +} from "./dense-model-registry.js"; diff --git a/tests/fastembed_bgebase_v15.test.ts b/tests/fastembed_bgebase_v15.test.ts index a0ed958..967e157 100644 --- a/tests/fastembed_bgebase_v15.test.ts +++ b/tests/fastembed_bgebase_v15.test.ts @@ -72,7 +72,7 @@ test("FlagEmbedding queryEmbed", async () => { expect(embeddings).toBeDefined(); expect(embeddings.length).toBe(768); }); test("FlagEmbedding passageEmbed", async () => { const flagEmbedding = await FlagEmbedding.init({ model:
EmbeddingModel.BGEBaseENV15, @@ -91,7 +91,7 @@ test("FlagEmbedding canonical values", async () => { model: EmbeddingModel.BGEBaseENV15, maxLength: 512, }); - const expected = [0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045]; + const expected = [0.010724321007728577, 0.05578266456723213, 0.02708405815064907, 0.0030409879982471466, 0.030335525050759315]; const embeddings = (await flagEmbedding.embed(["hello world"]).next()).value!; expect(embeddings).toBeDefined(); diff --git a/tests/fastembed_bgesmall_v15.ts b/tests/fastembed_bgesmall_v15.test.ts similarity index 100% rename from tests/fastembed_bgesmall_v15.ts rename to tests/fastembed_bgesmall_v15.test.ts diff --git a/tests/fastembed_custom_hf.test.ts b/tests/fastembed_custom_hf.test.ts new file mode 100644 index 0000000..6580b3d --- /dev/null +++ b/tests/fastembed_custom_hf.test.ts @@ -0,0 +1,68 @@ +import { describe, expect, test } from "vitest"; +import { EmbeddingModel, FlagEmbedding } from "../src"; + +describe("FastEmbed Custom HuggingFace Model Tests", () => { + test( + "loads dense model using enum (standard approach)", + async () => { + const model = await FlagEmbedding.init({ + model: EmbeddingModel.BGESmallENV15, + }); + expect(model).toBeDefined(); + + const embeddings = (await model.embed(["test"]).next()).value!; + expect(embeddings).toBeDefined(); + expect(embeddings.length).toBe(1); + expect(embeddings[0].length).toBe(384); + }, + 120000 + ); + + test( + "loads dense model using actual HF repo ID", + async () => { + const model = await FlagEmbedding.init({ + model: "BAAI/bge-small-en-v1.5", + }); + expect(model).toBeDefined(); + + const embeddings = (await model.embed(["test"]).next()).value!; + expect(embeddings).toBeDefined(); + expect(embeddings.length).toBe(1); + expect(embeddings[0].length).toBe(384); + }, + 120000 + ); + + test( + "loads sentence-transformers model", + async () => { + const model = await FlagEmbedding.init({ + model: "sentence-transformers/all-MiniLM-L6-v2", + }); 
+ expect(model).toBeDefined(); + + const embeddings = (await model.embed(["hello world"]).next()).value!; + expect(embeddings).toBeDefined(); + expect(embeddings.length).toBe(1); + expect(embeddings[0].length).toBe(384); + }, + 120000 + ); + + test( + "embeddings are normalized", + async () => { + const model = await FlagEmbedding.init({ + model: EmbeddingModel.BGESmallENV15, + }); + + const embeddings = (await model.embed(["test"]).next()).value!; + const magnitude = Math.sqrt( + embeddings[0].reduce((acc, val) => acc + val * val, 0) + ); + expect(magnitude).toBeCloseTo(1.0, 5); + }, + 120000 + ); +});