diff --git a/routellm/controller.py b/routellm/controller.py
index 8a02a05..6c8f837 100644
--- a/routellm/controller.py
+++ b/routellm/controller.py
@@ -168,3 +168,45 @@ async def acompletion(
             kwargs["messages"], router, threshold
         )
         return await acompletion(api_base=self.api_base, api_key=self.api_key, **kwargs)
+
+    # --- NEW: lightweight helper -------------------------------------------
+    def invoke(
+        self,
+        prompt: str,
+        *,
+        router: str = "mf",
+        threshold: float = 0.5,
+        return_score: bool = False,
+    ):
+        """
+        Fast path that runs the router **only** and returns the routed model
+        name; no completion request is made. Set `return_score=True` to also
+        get the strong-model win-rate.
+
+        Parameters
+        ----------
+        prompt : str
+            Text passed to the router's ``calculate_strong_win_rate``.
+        router : str
+            Key into ``self.routers`` selecting the router to run.
+        threshold : float
+            The strong model is chosen when the win-rate is >= this value.
+        return_score : bool
+            If true, return ``(model, win_rate)`` instead of the model alone.
+
+        Returns
+        -------
+        The ``strong`` or ``weak`` entry of ``self.model_pair``, or a
+        ``(model, win_rate)`` tuple when ``return_score`` is true.
+
+        Example
+        -------
+        >>> ctrl = Controller(routers=["mf"], strong_model="gpt-4o", weak_model="llama3-8b")
+        >>> ctrl.invoke("Write a haiku about routing.")  # doctest: +SKIP
+        'llama3-8b'
+        """
+        self._validate_router_threshold(router, threshold)
+        router_inst = self.routers[router]
+        win_rate = router_inst.calculate_strong_win_rate(prompt)
+        chosen = self.model_pair.strong if win_rate >= threshold else self.model_pair.weak
+        return (chosen, win_rate) if return_score else chosen
diff --git a/routellm/routers/routers.py b/routellm/routers/routers.py
index 0096c0a..765586c 100644
--- a/routellm/routers/routers.py
+++ b/routellm/routers/routers.py
@@ -140,6 +140,7 @@ def __init__(
         strong_model="gpt-4-1106-preview",
         weak_model="mixtral-8x7b-instruct-v0.1",
         num_tiers=10,
+        embedding_model: str = "text-embedding-3-small",
     ):
         self.strong_model = strong_model
         self.weak_model = weak_model
@@ -154,7 +155,7 @@ def __init__(
             for dataset in arena_embedding_datasets
         ]
         self.arena_conv_embedding = np.concatenate(embeddings)
-        self.embedding_model = "text-embedding-3-small"
+        self.embedding_model = embedding_model
 
         assert len(self.arena_df) == len(
             self.arena_conv_embedding
@@ -181,7 +182,7 @@ def calculate_strong_win_rate(
         prompt_emb = (
             (
                 OPENAI_CLIENT.embeddings.create(
-                    input=[prompt], model=self.embedding_model
+                    input=[prompt], model=self.embedding_model, encoding_format="float"
                 )
             )
             .data[0]
diff --git a/routellm/routers/similarity_weighted/utils.py b/routellm/routers/similarity_weighted/utils.py
index 19035ce..e703079 100644
--- a/routellm/routers/similarity_weighted/utils.py
+++ b/routellm/routers/similarity_weighted/utils.py
@@ -8,7 +8,8 @@ from sklearn.linear_model import LogisticRegression
 
 choices = ["A", "B", "C", "D"]
 
-OPENAI_CLIENT = OpenAI()
+# The API key is read from the OPENAI_API_KEY environment variable; never commit credentials.
+OPENAI_CLIENT = OpenAI(base_url="https://api.tapsage.com/openai/v1")
 
 
 def compute_tiers(model_ratings, num_tiers):