From e6053b60e7b6c08049c288c5921f2940268eab9b Mon Sep 17 00:00:00 2001 From: "Frank Chiarulli Jr." Date: Fri, 10 Oct 2025 00:55:41 -0400 Subject: [PATCH 1/3] update to pull from huggingface --- fastembed.go | 95 ++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 66 insertions(+), 29 deletions(-) diff --git a/fastembed.go b/fastembed.go index 362074e..6eb6702 100644 --- a/fastembed.go +++ b/fastembed.go @@ -36,6 +36,16 @@ const ( // MLE5Large EmbeddingModel = "fast-multilingual-e5-large" ) +// Map of model names to their HuggingFace repository IDs +var modelToHuggingFaceID = map[EmbeddingModel]string{ + AllMiniLML6V2: "sentence-transformers/all-MiniLM-L6-v2", + BGEBaseEN: "BAAI/bge-base-en", + BGEBaseENV15: "BAAI/bge-base-en-v1.5", + BGESmallEN: "BAAI/bge-small-en", + BGESmallENV15: "BAAI/bge-small-en-v1.5", + BGESmallZH: "BAAI/bge-base-zh-v1.5", +} + // Struct to interface with a FastEmbed model. type FlagEmbedding struct { tokenizer *tokenizer.Tokenizer @@ -417,48 +427,75 @@ func retrieveModel(model EmbeddingModel, cacheDir string, showDownloadProgress b if _, err := os.Stat(filepath.Join(cacheDir, string(model))); !errors.Is(err, fs.ErrNotExist) { return filepath.Join(cacheDir, string(model)), nil } - return downloadFromGcs(model, cacheDir, showDownloadProgress) + return downloadFromHuggingFace(model, cacheDir, showDownloadProgress) } -// Private function to download the model from Google Cloud Storage. -func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) { - // The MLE5Large model URL doesn't follow the same naming convention as the other models - // So, we tranform "fast-multilingual-e5-large" -> "intfloat-multilingual-e5-large" in the download URL - // The model directory name in the GCS storage is "fast-multilingual-e5-large", like the others - // modelName := model - // if model == MLE5Large { - // modelName = "intfloat" + model[strings.Index(string(model), "-"):] - // } - - downloadURL := fmt.Sprintf("https://storage.googleapis.com/qdrant-fastembed/%s.tar.gz", model) +// Private function to download the model from HuggingFace. +func downloadFromHuggingFace(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) { + hfModelID, ok := modelToHuggingFaceID[model] + if !ok { + return "", fmt.Errorf("no HuggingFace model ID found for %s", model) + } - response, err := http.Get(downloadURL) - if err != nil { + modelDir := filepath.Join(cacheDir, string(model)) + if err := os.MkdirAll(modelDir, 0755); err != nil { return "", err } - defer response.Body.Close() - if response.StatusCode < 200 || response.StatusCode > 299 { - return "", fmt.Errorf("model download failed: %s", response.Status) + // List of files to download from HuggingFace + files := []string{ + "model_optimized.onnx", + "tokenizer.json", + "config.json", + "tokenizer_config.json", + "special_tokens_map.json", } if showDownloadProgress { - bar := progressbar.DefaultBytes( - response.ContentLength, - "Downloading "+string(model), - ) - reader := progressbar.NewReader(response.Body, bar) - err = untar(&reader, cacheDir) - } else { - fmt.Printf("Downloading %s...", model) - err = untar(response.Body, cacheDir) + fmt.Printf("Downloading %s from HuggingFace...\n", model) } - if err != nil { - return "", err + // Download each file + for _, filename := range files { + downloadURL := fmt.Sprintf("https://huggingface.co/%s/resolve/main/%s", hfModelID, filename) + + response, err := http.Get(downloadURL) + if err != nil { + return "", fmt.Errorf("failed to download %s: %w", filename, err) + } + + if response.StatusCode < 200 || response.StatusCode > 299 { + response.Body.Close() + return "", fmt.Errorf("failed to download %s: %s", filename, response.Status) + } + + destPath := filepath.Join(modelDir, filename) + destFile, err := os.Create(destPath) + if err != nil { + response.Body.Close() + return "", fmt.Errorf("failed to create %s: %w", filename, err) + } + + if showDownloadProgress { + bar := progressbar.DefaultBytes( + response.ContentLength, + fmt.Sprintf("Downloading %s", filename), + ) + reader := progressbar.NewReader(response.Body, bar) + _, err = io.Copy(destFile, &reader) + } else { + _, err = io.Copy(destFile, response.Body) + } + + destFile.Close() + response.Body.Close() + + if err != nil { + return "", fmt.Errorf("failed to write %s: %w", filename, err) + } } - return filepath.Join(cacheDir, string(model)), nil + return modelDir, nil } // Private function to untar the downloaded model from a .tar.gz file. From e7dd3dba93cc96235a2ebf20bda67025787b4942 Mon Sep 17 00:00:00 2001 From: "Frank Chiarulli Jr." Date: Fri, 10 Oct 2025 00:59:17 -0400 Subject: [PATCH 2/3] use correct Qdrant onnx models --- fastembed.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/fastembed.go b/fastembed.go index 6eb6702..3bd8ce7 100644 --- a/fastembed.go +++ b/fastembed.go @@ -38,12 +38,12 @@ const ( // Map of model names to their HuggingFace repository IDs var modelToHuggingFaceID = map[EmbeddingModel]string{ - AllMiniLML6V2: "sentence-transformers/all-MiniLM-L6-v2", - BGEBaseEN: "BAAI/bge-base-en", - BGEBaseENV15: "BAAI/bge-base-en-v1.5", - BGESmallEN: "BAAI/bge-small-en", - BGESmallENV15: "BAAI/bge-small-en-v1.5", - BGESmallZH: "BAAI/bge-base-zh-v1.5", + AllMiniLML6V2: "Qdrant/fast-all-MiniLM-L6-v2", + BGEBaseEN: "Qdrant/fast-bge-base-en", + BGEBaseENV15: "Qdrant/fast-bge-base-en-v1.5", + BGESmallEN: "Qdrant/fast-bge-small-en", + BGESmallENV15: "Qdrant/fast-bge-small-en-v1.5", + BGESmallZH: "Qdrant/fast-bge-small-zh-v1.5", } // Struct to interface with a FastEmbed model. From 6d61baa73477d12a294560072e97b56ce8f4e690 Mon Sep 17 00:00:00 2001 From: "Frank Chiarulli Jr." Date: Fri, 10 Oct 2025 01:04:52 -0400 Subject: [PATCH 3/3] update to correct huggingface paths --- fastembed.go | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/fastembed.go b/fastembed.go index 3bd8ce7..6f94b2f 100644 --- a/fastembed.go +++ b/fastembed.go @@ -38,12 +38,17 @@ const ( // Map of model names to their HuggingFace repository IDs var modelToHuggingFaceID = map[EmbeddingModel]string{ - AllMiniLML6V2: "Qdrant/fast-all-MiniLM-L6-v2", + AllMiniLML6V2: "Qdrant/all-MiniLM-L6-v2-onnx", BGEBaseEN: "Qdrant/fast-bge-base-en", - BGEBaseENV15: "Qdrant/fast-bge-base-en-v1.5", - BGESmallEN: "Qdrant/fast-bge-small-en", - BGESmallENV15: "Qdrant/fast-bge-small-en-v1.5", - BGESmallZH: "Qdrant/fast-bge-small-zh-v1.5", + BGEBaseENV15: "Qdrant/bge-base-en-v1.5-onnx-Q", + BGESmallEN: "Qdrant/bge-small-en", + BGESmallENV15: "Qdrant/bge-small-en-v1.5-onnx-Q", + BGESmallZH: "Qdrant/bge-small-zh-v1.5", +} + +// Map of models that use model.onnx instead of model_optimized.onnx +var modelsUsingStandardOnnx = map[EmbeddingModel]bool{ + AllMiniLML6V2: true, } // Struct to interface with a FastEmbed model. @@ -442,9 +447,15 @@ func downloadFromHuggingFace(model EmbeddingModel, cacheDir string, showDownload return "", err } + // Determine which ONNX model file to download + modelFile := "model_optimized.onnx" + if modelsUsingStandardOnnx[model] { + modelFile = "model.onnx" + } + // List of files to download from HuggingFace files := []string{ - "model_optimized.onnx", + modelFile, "tokenizer.json", "config.json", "tokenizer_config.json", @@ -495,6 +506,15 @@ func downloadFromHuggingFace(model EmbeddingModel, cacheDir string, showDownload } } + // If we downloaded model.onnx, rename it to model_optimized.onnx for consistency + if modelsUsingStandardOnnx[model] { + oldPath := filepath.Join(modelDir, "model.onnx") + newPath := filepath.Join(modelDir, "model_optimized.onnx") + if err := os.Rename(oldPath, newPath); err != nil { + return "", fmt.Errorf("failed to rename model.onnx: %w", err) + } + } + return modelDir, nil }