Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions routellm/routers/matrix_factorization/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from huggingface_hub import PyTorchModelHubMixin

from routellm.routers.similarity_weighted.utils import OPENAI_CLIENT
from routellm.routers.similarity_weighted.utils import get_openai_client

MODEL_IDS = {
"RWKV-4-Raven-14B": 0,
Expand Down Expand Up @@ -110,7 +110,8 @@ def forward(self, model_id, prompt):
model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1)

prompt_embed = (
OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model)
get_openai_client()
.embeddings.create(input=[prompt], model=self.embedding_model)
.data[0]
.embedding
)
Expand Down
4 changes: 2 additions & 2 deletions routellm/routers/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from routellm.routers.causal_llm.model import CausalLLMClassifier
from routellm.routers.matrix_factorization.model import MODEL_IDS, MFModel
from routellm.routers.similarity_weighted.utils import (
OPENAI_CLIENT,
compute_elo_mle_with_tie,
compute_tiers,
get_openai_client,
preprocess_battles,
)

Expand Down Expand Up @@ -180,7 +180,7 @@ def calculate_strong_win_rate(
):
prompt_emb = (
(
OPENAI_CLIENT.embeddings.create(
get_openai_client().embeddings.create(
input=[prompt], model=self.embedding_model
)
)
Expand Down
14 changes: 13 additions & 1 deletion routellm/routers/similarity_weighted/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import json
import math
import os
Expand All @@ -8,7 +9,18 @@
from sklearn.linear_model import LogisticRegression

# Option labels "A" through "D"; the code that consumes them is outside this hunk.
choices = ["A", "B", "C", "D"]
# Pre-change side of the diff: the client is constructed eagerly, at import time.
# get_openai_client() below replaces this with lazy, cached construction.
OPENAI_CLIENT = OpenAI()


@functools.lru_cache(maxsize=1)
def get_openai_client():
    """Return the shared OpenAI client, building it on the first call.

    ``OpenAI()`` needs ``OPENAI_API_KEY`` to be set when it is constructed, so
    building the client while this module is imported would make every router
    depend on that key, including local ones such as ``bert`` and
    ``causal_llm`` that never talk to OpenAI. Construction is therefore
    postponed until a caller actually asks for the client, and the
    single-entry ``lru_cache`` hands the same instance back on later calls.
    """
    client = OpenAI()
    return client


def compute_tiers(model_ratings, num_tiers):
Expand Down