Skip to content
This repository was archived by the owner on Jan 15, 2026. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 86 additions & 29 deletions fastembed.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ const (
// MLE5Large EmbeddingModel = "fast-multilingual-e5-large"
)

// modelToHuggingFaceID maps EmbeddingModel values to the HuggingFace
// repository ID their files are downloaded from. downloadFromHuggingFace
// looks the model up here and returns an error when there is no entry.
var modelToHuggingFaceID = map[EmbeddingModel]string{
AllMiniLML6V2: "Qdrant/all-MiniLM-L6-v2-onnx",
BGEBaseEN: "Qdrant/fast-bge-base-en",
BGEBaseENV15: "Qdrant/bge-base-en-v1.5-onnx-Q",
BGESmallEN: "Qdrant/bge-small-en",
BGESmallENV15: "Qdrant/bge-small-en-v1.5-onnx-Q",
BGESmallZH: "Qdrant/bge-small-zh-v1.5",
}

// modelsUsingStandardOnnx is the set of models that use model.onnx instead of
// model_optimized.onnx. For these models downloadFromHuggingFace fetches
// model.onnx and then renames it to model_optimized.onnx in the model
// directory, so the on-disk file name is the same for every model.
var modelsUsingStandardOnnx = map[EmbeddingModel]bool{
AllMiniLML6V2: true,
}

// Struct to interface with a FastEmbed model.
type FlagEmbedding struct {
tokenizer *tokenizer.Tokenizer
Expand Down Expand Up @@ -417,48 +432,90 @@ func retrieveModel(model EmbeddingModel, cacheDir string, showDownloadProgress b
if _, err := os.Stat(filepath.Join(cacheDir, string(model))); !errors.Is(err, fs.ErrNotExist) {
return filepath.Join(cacheDir, string(model)), nil
}
return downloadFromGcs(model, cacheDir, showDownloadProgress)
return downloadFromHuggingFace(model, cacheDir, showDownloadProgress)
}

// Private function to download the model from Google Cloud Storage.
func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) {
// The MLE5Large model URL doesn't follow the same naming convention as the other models
// So, we transform "fast-multilingual-e5-large" -> "intfloat-multilingual-e5-large" in the download URL
// The model directory name in the GCS storage is "fast-multilingual-e5-large", like the others
// modelName := model
// if model == MLE5Large {
// modelName = "intfloat" + model[strings.Index(string(model), "-"):]
// }

downloadURL := fmt.Sprintf("https://storage.googleapis.com/qdrant-fastembed/%s.tar.gz", model)
// Private function to download the model from HuggingFace.
func downloadFromHuggingFace(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) {
hfModelID, ok := modelToHuggingFaceID[model]
if !ok {
return "", fmt.Errorf("no HuggingFace model ID found for %s", model)
}

response, err := http.Get(downloadURL)
if err != nil {
modelDir := filepath.Join(cacheDir, string(model))
if err := os.MkdirAll(modelDir, 0755); err != nil {
return "", err
}
defer response.Body.Close()

if response.StatusCode < 200 || response.StatusCode > 299 {
return "", fmt.Errorf("model download failed: %s", response.Status)
// Determine which ONNX model file to download
modelFile := "model_optimized.onnx"
if modelsUsingStandardOnnx[model] {
modelFile = "model.onnx"
}

// List of files to download from HuggingFace
files := []string{
modelFile,
"tokenizer.json",
"config.json",
"tokenizer_config.json",
"special_tokens_map.json",
}

if showDownloadProgress {
bar := progressbar.DefaultBytes(
response.ContentLength,
"Downloading "+string(model),
)
reader := progressbar.NewReader(response.Body, bar)
err = untar(&reader, cacheDir)
} else {
fmt.Printf("Downloading %s...", model)
err = untar(response.Body, cacheDir)
fmt.Printf("Downloading %s from HuggingFace...\n", model)
}

if err != nil {
return "", err
// Download each file
for _, filename := range files {
downloadURL := fmt.Sprintf("https://huggingface.co/%s/resolve/main/%s", hfModelID, filename)

response, err := http.Get(downloadURL)
if err != nil {
return "", fmt.Errorf("failed to download %s: %w", filename, err)
}

if response.StatusCode < 200 || response.StatusCode > 299 {
response.Body.Close()
return "", fmt.Errorf("failed to download %s: %s", filename, response.Status)
}

destPath := filepath.Join(modelDir, filename)
destFile, err := os.Create(destPath)
if err != nil {
response.Body.Close()
return "", fmt.Errorf("failed to create %s: %w", filename, err)
}

if showDownloadProgress {
bar := progressbar.DefaultBytes(
response.ContentLength,
fmt.Sprintf("Downloading %s", filename),
)
reader := progressbar.NewReader(response.Body, bar)
_, err = io.Copy(destFile, &reader)
} else {
_, err = io.Copy(destFile, response.Body)
}

destFile.Close()
response.Body.Close()

if err != nil {
return "", fmt.Errorf("failed to write %s: %w", filename, err)
}
}

// If we downloaded model.onnx, rename it to model_optimized.onnx for consistency
if modelsUsingStandardOnnx[model] {
oldPath := filepath.Join(modelDir, "model.onnx")
newPath := filepath.Join(modelDir, "model_optimized.onnx")
if err := os.Rename(oldPath, newPath); err != nil {
return "", fmt.Errorf("failed to rename model.onnx: %w", err)
}
}

return filepath.Join(cacheDir, string(model)), nil
return modelDir, nil
}

// Private function to untar the downloaded model from a .tar.gz file.
Expand Down