diff --git a/routellm/routers/matrix_factorization/model.py b/routellm/routers/matrix_factorization/model.py
index 09fbb25..f292b8d 100644
--- a/routellm/routers/matrix_factorization/model.py
+++ b/routellm/routers/matrix_factorization/model.py
@@ -1,7 +1,7 @@
 import torch
 from huggingface_hub import PyTorchModelHubMixin
 
-from routellm.routers.similarity_weighted.utils import OPENAI_CLIENT
+from routellm.routers.similarity_weighted.utils import get_openai_client
 
 MODEL_IDS = {
     "RWKV-4-Raven-14B": 0,
@@ -110,7 +110,8 @@ def forward(self, model_id, prompt):
         model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1)
 
         prompt_embed = (
-            OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model)
+            get_openai_client()
+            .embeddings.create(input=[prompt], model=self.embedding_model)
             .data[0]
             .embedding
         )
diff --git a/routellm/routers/routers.py b/routellm/routers/routers.py
index 0096c0a..37e07cc 100644
--- a/routellm/routers/routers.py
+++ b/routellm/routers/routers.py
@@ -16,9 +16,9 @@
 from routellm.routers.causal_llm.model import CausalLLMClassifier
 from routellm.routers.matrix_factorization.model import MODEL_IDS, MFModel
 from routellm.routers.similarity_weighted.utils import (
-    OPENAI_CLIENT,
     compute_elo_mle_with_tie,
     compute_tiers,
+    get_openai_client,
     preprocess_battles,
 )
 
@@ -180,7 +180,7 @@ def calculate_strong_win_rate(
     ):
         prompt_emb = (
             (
-                OPENAI_CLIENT.embeddings.create(
+                get_openai_client().embeddings.create(
                     input=[prompt], model=self.embedding_model
                 )
             )
diff --git a/routellm/routers/similarity_weighted/utils.py b/routellm/routers/similarity_weighted/utils.py
index 19035ce..435fd07 100644
--- a/routellm/routers/similarity_weighted/utils.py
+++ b/routellm/routers/similarity_weighted/utils.py
@@ -1,3 +1,4 @@
+import functools
 import json
 import math
 import os
@@ -8,7 +9,18 @@
 from sklearn.linear_model import LogisticRegression
 
 choices = ["A", "B", "C", "D"]
-OPENAI_CLIENT = OpenAI()
+
+
+@functools.lru_cache(maxsize=1)
+def get_openai_client():
+    """Lazily instantiate (and cache) the OpenAI client.
+
+    Constructing ``OpenAI()`` requires ``OPENAI_API_KEY`` to be set. Creating it
+    at import time would force every router to need an API key, even local ones
+    (e.g. ``bert``, ``causal_llm``) that never call OpenAI. Deferring creation to
+    first use keeps those routers importable without any credentials.
+    """
+    return OpenAI()
 
 
 def compute_tiers(model_ratings, num_tiers):