diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index f854ca694..9d00f27fa 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -46,7 +46,7 @@ jobs: enable-cache: true - name: Install the project - run: uv sync --extra dev-gpu + run: uv sync --extra dev-gpu --extra litellm - name: Ensure cache directories exist run: mkdir -p cache/models cache/datasets diff --git a/.gitignore b/.gitignore index 77e347cf5..b1077b2a7 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ prod_env logs _logs outputs +.litellm_cache/ diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 0025c04f6..859411105 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -38,7 +38,7 @@ Lighteval provides several optional dependency groups that you can install based |-------|-------------|--------------| | `vllm` | Use VLLM as backend for high-performance inference | vllm>=0.10.0, ray, more_itertools | | `tgi` | Use Text Generation Inference API | text-generation>=0.6.0 | -| `litellm` | Use LiteLLM for unified API access | litellm, diskcache | +| `litellm` | Use LiteLLM for unified API access (generative + loglikelihood for completion-capable models) | litellm, diskcache | | `optimum` | Use Optimum for optimized models | optimum==1.12.0 | | `quantization` | Evaluate quantized models | bitsandbytes>=0.41.0, auto-gptq>=0.4.2 | | `adapters` | Evaluate adapter models (PEFT, Delta) | peft==0.3.0 | diff --git a/docs/source/package_reference/models.mdx b/docs/source/package_reference/models.mdx index 7655396e5..e205c2274 100644 --- a/docs/source/package_reference/models.mdx +++ b/docs/source/package_reference/models.mdx @@ -38,6 +38,7 @@ set in the `model-args` or in the model yaml file (see example ### Litellm Model [[autodoc]] models.endpoints.litellm_model.LiteLLMModelConfig +[[autodoc]] models.endpoints.litellm_model.LiteLLMClient ## Custom Model [[autodoc]] 
models.custom.custom_model.CustomModelConfig diff --git a/docs/source/use-litellm-as-backend.mdx b/docs/source/use-litellm-as-backend.mdx index 257da5d32..bdc000c85 100644 --- a/docs/source/use-litellm-as-backend.mdx +++ b/docs/source/use-litellm-as-backend.mdx @@ -7,6 +7,49 @@ OpenAI, Groq, and many others. > [!TIP] > Documentation for available APIs and compatible endpoints can be found [here](https://docs.litellm.ai/docs/). +## Supported Evaluation Modes + +The LiteLLM backend supports two evaluation modes depending on the model and provider: + +| Mode | Method | Benchmarks | Provider requirement | +|------|--------|-----------|----------------------| +| **Generative** | `greedy_until` | GSM8K, HLE, IFEval, … | Any chat-completion provider | +| **Log-likelihood** | `loglikelihood` / `loglikelihood_rolling` | MMLU, ARC, HellaSwag, … | `/v1/completions` endpoint with `echo=True` | + +### Generative evaluation — all providers + +Generative benchmarks route through `POST /v1/chat/completions`. Any model from +any provider works. + +```bash +lighteval endpoint litellm \ + "provider=openai,model_name=gpt-4o" \ + gsm8k +``` + +### Log-likelihood evaluation — completion-capable models only + +MCQ benchmarks (MMLU, ARC, HellaSwag, …) and perplexity benchmarks require +`POST /v1/completions` with `echo=True` and `logprobs=1`. **Chat-only models +such as gpt-4o, Claude, or Gemini do not expose this endpoint** and will +produce all-`-inf` results with a warning. + +Supported models: +- **OpenAI**: `gpt-3.5-turbo-instruct` +- **Local servers**: any OpenAI-compatible server whose `/v1/completions` endpoint honours `echo=True` and `logprobs=1` — llama.cpp, `vllm serve`, etc. (Ollama is generative only.) + +```bash +# Run MCQ benchmarks (requires completion endpoint) +lighteval endpoint litellm \ + examples/model_configs/litellm_completion_model.yaml \ + "mmlu|0" "arc|0" "hellaswag|0" +``` + +> [!WARNING] +> Lighteval automatically detects unsupported chat-only models via +> `litellm.get_model_info()` and emits a WARNING before the evaluation starts. 
+> Results will be `-inf` for every choice and metrics will be at chance level. + ## Basic Usage ```bash @@ -18,9 +61,9 @@ lighteval endpoint litellm \ ## Using a Configuration File LiteLLM allows generation with any OpenAI-compatible endpoint. For example, you -can evaluate a model running on a local VLLM server. +can evaluate a model running on a local vLLM server. -To do so, you will need to use a configuration file like this: +**Generative tasks** (`examples/model_configs/litellm_model.yaml`): ```yaml model_parameters: @@ -37,25 +80,39 @@ model_parameters: frequency_penalty: 0.0 ``` +**Log-likelihood / MCQ tasks** (`examples/model_configs/litellm_completion_model.yaml`): + +```yaml +model_parameters: + model_name: "gpt-3.5-turbo-instruct" + provider: "openai" + concurrent_requests: 10 + generation_parameters: + seed: 42 +``` + ## Supported Providers LiteLLM supports a wide range of LLM providers: ### Cloud Providers -all cloud providers can be found in the [litellm documentation](https://docs.litellm.ai/docs/providers). +All cloud providers can be found in the [litellm documentation](https://docs.litellm.ai/docs/providers). ### Local/On-Premise -- **VLLM**: Local VLLM servers -- **Hugging Face**: Local Hugging Face models +- **vLLM**: Local vLLM servers (supports both generative and log-likelihood) +- **llama.cpp**: OpenAI-compatible server (supports both generative and log-likelihood) +- **Ollama**: OpenAI-compatible endpoint (generative only) - **Custom endpoints**: Any OpenAI-compatible API ## Using with Local Models -### VLLM Server -To use with a local VLLM server: +### vLLM Server (generative + log-likelihood) + +Local vLLM servers expose both `/v1/chat/completions` and `/v1/completions`, so +they support all evaluation modes. -1. Start your VLLM server: +1. 
Start your vLLM server: ```bash vllm serve HuggingFaceH4/zephyr-7b-beta --host 0.0.0.0 --port 8000 ``` @@ -67,6 +124,46 @@ model_parameters: model_name: "hosted_vllm/HuggingFaceH4/zephyr-7b-beta" base_url: "http://localhost:8000/v1" api_key: "" + generation_parameters: + seed: 42 +``` + +3. Run any benchmark: +```bash +# Generative +lighteval endpoint litellm my_config.yaml "gsm8k|0" + +# MCQ (log-likelihood) +lighteval endpoint litellm my_config.yaml "mmlu|0" "arc|0" ``` +### llama.cpp Server (generative + log-likelihood) + +```bash +./llama-server -m model.gguf --port 8080 +``` + +```yaml +model_parameters: + model_name: "openai/local" + base_url: "http://localhost:8080/v1" + api_key: "none" +``` + +## Generation Parameters + +All parameters in `generation_parameters` are forwarded appropriately to the +underlying API call. + +| Parameter | Generative (`/chat/completions`) | Log-likelihood (`/completions`) | +|-----------|----------------------------------|----------------------------------| +| `seed` | ✅ | ✅ | +| `temperature` | ✅ | hardcoded `0.0` (deterministic scoring) | +| `max_new_tokens` | ✅ | hardcoded `1` (only logprobs needed) | +| `stop_tokens` | ✅ | ✅ | +| `top_p` | ✅ | ✅ | +| `frequency_penalty` | ✅ | ✅ | +| `presence_penalty` | ✅ | ✅ | +| `repetition_penalty` | ✅ | ❌ (not part of `/v1/completions`) | + For more detailed error handling and debugging, refer to the [LiteLLM documentation](https://docs.litellm.ai/docs/). diff --git a/examples/model_configs/litellm_completion_model.yaml b/examples/model_configs/litellm_completion_model.yaml new file mode 100644 index 000000000..a00ec392b --- /dev/null +++ b/examples/model_configs/litellm_completion_model.yaml @@ -0,0 +1,26 @@ +# LiteLLM configuration for a model that supports loglikelihood evaluation. +# +# loglikelihood() and loglikelihood_rolling() require the /v1/completions +# endpoint with echo=True and logprobs=1. Only "completion-style" models +# expose this endpoint. 
Examples: +# - OpenAI: gpt-3.5-turbo-instruct +# - Local (llama.cpp): openai/local-model (with base_url pointing to server) +# - Local (vLLM serve): openai/local-model (with base_url pointing to server) +# +# Chat-only models (gpt-4o, claude-*, gemini-*) do NOT support this endpoint +# and will produce -inf loglikelihoods. Use the standard litellm_model.yaml +# for generative (greedy_until) evaluations with those models. +# +# Usage: +# lighteval endpoint litellm examples/model_configs/litellm_completion_model.yaml \ +# "mmlu|0" "arc|0" "hellaswag|0" + +model_parameters: + model_name: "gpt-3.5-turbo-instruct" + provider: "openai" + concurrent_requests: 10 + api_max_retry: 8 + api_retry_sleep: 1.0 + api_retry_multiplier: 2.0 + generation_parameters: + seed: 42 diff --git a/src/lighteval/models/endpoints/inference_providers_model.py b/src/lighteval/models/endpoints/inference_providers_model.py index 54790e45b..0fe599801 100644 --- a/src/lighteval/models/endpoints/inference_providers_model.py +++ b/src/lighteval/models/endpoints/inference_providers_model.py @@ -258,12 +258,30 @@ def max_length(self) -> int: @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: - """Tokenize the context and continuation and compute the log likelihood of those - tokenized sequences. + """Not supported for HuggingFace Inference Providers. + + The HF Inference Providers API exposes only ``/v1/chat/completions``. + That endpoint does not support ``echo=True`` or per-prompt token + log-probabilities, which are required for loglikelihood evaluation + (MCQ benchmarks such as MMLU, ARC, HellaSwag). + + Use the LiteLLM backend (``lighteval endpoint litellm``) with a + model that supports the ``/v1/completions`` endpoint — for example + ``gpt-3.5-turbo-instruct`` or any OpenAI-compatible local server — + to run loglikelihood evaluations over a remote API. 
""" - raise NotImplementedError + raise NotImplementedError( + "loglikelihood is not supported for the HuggingFace Inference Providers backend. " + "The provider API exposes only /v1/chat/completions, which does not return " + "per-prompt token log-probabilities. " + "Use `lighteval endpoint litellm` with a completion-capable model instead " + "(e.g. gpt-3.5-turbo-instruct or a local OpenAI-compatible server)." + ) @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: - """This function is used to compute the log likelihood of the context for perplexity metrics.""" - raise NotImplementedError + """Not supported for HuggingFace Inference Providers — see ``loglikelihood`` for details.""" + raise NotImplementedError( + "loglikelihood_rolling is not supported for the HuggingFace Inference Providers backend. " + "See loglikelihood() for the full explanation." + ) diff --git a/src/lighteval/models/endpoints/litellm_model.py b/src/lighteval/models/endpoints/litellm_model.py index 87332d1d7..d63198136 100644 --- a/src/lighteval/models/endpoints/litellm_model.py +++ b/src/lighteval/models/endpoints/litellm_model.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import asyncio import logging import time from concurrent.futures import ThreadPoolExecutor @@ -28,7 +29,7 @@ import requests from tqdm import tqdm -from lighteval.data import GenerativeTaskDataset +from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset from lighteval.models.abstract_model import LightevalModel, ModelConfig from lighteval.models.model_output import ModelResponse from lighteval.tasks.prompt_manager import PromptManager @@ -103,17 +104,32 @@ class LiteLLMModelConfig(ModelConfig): This prompt sets the behavior and context for the model during evaluation. cache_dir (str, optional, defaults to "~/.cache/huggingface/lighteval"): Directory to cache the model. 
+ Supported evaluation modes: + - ``greedy_until`` (generative): all models and providers supported. + - ``loglikelihood`` (MCQ ranking) and ``loglikelihood_rolling`` (perplexity): + requires the ``/v1/completions`` endpoint with ``echo=True`` and + ``logprobs=1``. Supported by ``gpt-3.5-turbo-instruct`` and any + OpenAI-compatible local server (llama.cpp, vLLM serve, etc.). + Chat-only models (gpt-4o, Claude, Gemini) are **not** supported for + these modes; the backend will warn at eval start and return ``-inf``. + Example: ```python + # Generative tasks only (any provider) config = LiteLLMModelConfig( model_name="gpt-4", provider="openai", base_url="https://api.openai.com/v1", concurrent_requests=5, - generation_parameters=GenerationParameters( - temperature=0.7, - max_new_tokens=100 - ) + generation_parameters=GenerationParameters(temperature=0.7, max_new_tokens=100), + ) + + # MCQ / perplexity tasks (requires /v1/completions support) + config = LiteLLMModelConfig( + model_name="gpt-3.5-turbo-instruct", + provider="openai", + concurrent_requests=10, + generation_parameters=GenerationParameters(seed=42), ) ``` """ @@ -347,7 +363,7 @@ def greedy_until( position=0, disable=self.disable_tqdm, ): - contexts = [self.prompt_manager.prepare_prompt_api(doc) for doc in dataset] + contexts = [self.prompt_manager.prepare_prompt_api(doc) for doc in split] max_new_tokens = split[0].generation_size # could be none return_logits = split[0].use_logits num_samples = split[0].num_samples @@ -403,14 +419,378 @@ def max_length(self) -> int: return max_tokens + # ------------------------------------------------------------------ + # Token Alignment Engine helpers + # ------------------------------------------------------------------ + + @staticmethod + def _find_continuation_start(logprobs_obj, context_str: str, model: str) -> int: + """Return the index in the token sequence where the continuation begins. + + Two-layer strategy: + 1. 
Character-offset alignment via ``text_offset`` (exact, uses the + API's own tokenisation — preferred when available). + 2. Tiktoken-count fallback via ``litellm.encode`` (reliable for all + OpenAI-family models when ``text_offset`` is absent). + """ + # Layer 1 — character-offset alignment + text_offset = getattr(logprobs_obj, "text_offset", None) + if text_offset: + context_char_len = len(context_str) + for i, offset in enumerate(text_offset): + if offset >= context_char_len: + return i + return len(text_offset) # empty continuation + + # Layer 2 — tiktoken fallback + try: + ctx_toks = encode(model, context_str) + return len(ctx_toks) + except Exception: + logger.warning( + "Could not align continuation tokens via text_offset or tiktoken. " + "Logprob results may be inaccurate for this provider." + ) + return 0 + + @staticmethod + def _check_argmax( + tokens: list[str], + token_logprobs: list, + top_logprobs: list, + cont_start: int, + ) -> bool: + """Return True if every continuation token was the model's top-1 prediction. + + Mirrors vLLM's ``rank == 1`` check. The last token in ``tokens`` is the + newly generated token (from ``max_tokens=1``) and is excluded from the + check. + """ + if not top_logprobs: + return False + + cont_end = len(tokens) - 1 # exclude the single generated token at the end + if cont_start >= cont_end: + return True # empty continuation trivially matches + + for i in range(cont_start, cont_end): + if i >= len(top_logprobs): + return False + top_dict = top_logprobs[i] + if not top_dict: + return False + # With logprobs=1, top_dict has exactly one key: the top-1 token. 
+ top_token = next(iter(top_dict)) + if i >= len(tokens) or tokens[i] != top_token: + return False + + return True + + # ------------------------------------------------------------------ + # Async loglikelihood implementation + # ------------------------------------------------------------------ + + async def _call_api_text_completion_async( + self, + full_text: str, + semaphore: asyncio.Semaphore, + ): + """Async call to ``litellm.atext_completion`` with exponential backoff. + + Uses ``echo=True``, ``logprobs=1``, ``max_tokens=1``, and + ``temperature=0.0`` to retrieve per-token log-probabilities for every + token in ``full_text`` (context + continuation). + + HTTP 429 (``RateLimitError``) is handled explicitly before the generic + exception handler so that rate-limit pauses are clearly logged. + """ + async with semaphore: + for attempt in range(self.API_MAX_RETRY): + try: + response = await litellm.atext_completion( + model=self.model, + prompt=full_text, + max_tokens=1, # generate exactly 1 token (echo gives prompt logprobs) + echo=True, # return prompt tokens with their log-probabilities + logprobs=1, # top-1 logprob per position for argmax check + temperature=0.0, # deterministic scoring + base_url=self.base_url, + api_key=self.api_key, + caching=True, + timeout=self.timeout, + **self.generation_parameters.to_litellm_text_completion_dict(), + ) + return response + except litellm.RateLimitError: + wait_time = min(64.0, self.API_RETRY_SLEEP * (self.API_RETRY_MULTIPLIER**attempt)) + logger.warning( + f"Rate limit (HTTP 429) on loglikelihood call — " + f"backing off {wait_time:.1f}s " + f"(attempt {attempt + 1}/{self.API_MAX_RETRY})" + ) + await asyncio.sleep(wait_time) + except Exception as e: + wait_time = min(64.0, self.API_RETRY_SLEEP * (self.API_RETRY_MULTIPLIER**attempt)) + logger.warning( + f"Error in loglikelihood API call: {e} — " + f"backing off {wait_time:.1f}s " + f"(attempt {attempt + 1}/{self.API_MAX_RETRY})" + ) + await asyncio.sleep(wait_time) + 
+ logger.error( + f"Loglikelihood API call failed after {self.API_MAX_RETRY} attempts, returning None." + ) + return None + + async def _process_doc_loglikelihood_async( + self, + doc: Doc, + context_str: str, + semaphore: asyncio.Semaphore, + ) -> ModelResponse: + """Compute logprobs for all choices of a single doc concurrently. + + All (context + choice) API calls for this doc are fired at once via + ``asyncio.gather``, bounded by the shared semaphore. Returns a + ``ModelResponse`` with ``logprobs`` and ``argmax_logits_eq_gold`` + populated per-choice, matching the VLLMModel data contract exactly. + """ + # Soft length check using the longest choice as a conservative estimate + if doc.choices: + longest = max(doc.choices, key=len) + self._warn_if_too_long(context_str + longest, label=f"doc '{doc.id}' longest choice") + + tasks = [ + self._call_api_text_completion_async(context_str + choice, semaphore) + for choice in doc.choices + ] + responses = await asyncio.gather(*tasks) + + logprobs_per_choice: list[float] = [] + argmax_per_choice: list[bool] = [] + + for choice, response in zip(doc.choices, responses): + if response is None or not getattr(response, "choices", None): + logprobs_per_choice.append(float("-inf")) + argmax_per_choice.append(False) + continue + + lp_obj = getattr(response.choices[0], "logprobs", None) + if lp_obj is None or not getattr(lp_obj, "token_logprobs", None): + logprobs_per_choice.append(float("-inf")) + argmax_per_choice.append(False) + continue + + tokens: list[str] = list(lp_obj.tokens or []) + token_logprobs: list = list(lp_obj.token_logprobs or []) + top_logprobs: list = list(lp_obj.top_logprobs or []) + + cont_start = self._find_continuation_start(lp_obj, context_str, self.model) + + # token_logprobs[cont_start:-1] isolates the continuation slice. + # The -1 excludes the single token generated by max_tokens=1 which + # is appended at the very end of the echoed sequence. 
+ cont_lp_slice = token_logprobs[cont_start:-1] + valid_lp = [v for v in cont_lp_slice if v is not None] + total_logprob = sum(valid_lp) if valid_lp else float("-inf") + + is_argmax = self._check_argmax(tokens, token_logprobs, top_logprobs, cont_start) + + logprobs_per_choice.append(total_logprob) + argmax_per_choice.append(is_argmax) + + return ModelResponse( + input=context_str, + logprobs=logprobs_per_choice, + argmax_logits_eq_gold=argmax_per_choice, + ) + + async def _loglikelihood_async(self, docs: list[Doc]) -> list[ModelResponse]: + """Async coordinator: process every doc in parallel, bounded by the semaphore. + + ``asyncio.gather`` preserves input order, so the returned list aligns + 1-to-1 with ``docs``. + """ + semaphore = asyncio.Semaphore(self.concurrent_requests) + tasks = [ + self._process_doc_loglikelihood_async( + doc=doc, + context_str=self.prompt_manager._prepare_plain_text(doc), + semaphore=semaphore, + ) + for doc in docs + ] + return list(await asyncio.gather(*tasks)) + + # ------------------------------------------------------------------ + # Provider compatibility guard + # ------------------------------------------------------------------ + + def _check_text_completion_support(self) -> None: + """Warn if the model is known to be chat-only and cannot serve loglikelihoods. + + ``loglikelihood`` and ``loglikelihood_rolling`` use + ``litellm.atext_completion`` (``/v1/completions``) with ``echo=True`` + and ``logprobs=1``. Models whose ``mode`` is ``"chat"`` in litellm's + registry (gpt-4o, Claude, Gemini, …) do not expose this endpoint and + will return all-``-inf`` results. + + We use ``litellm.get_model_info()`` so we only warn for *positively + identified* chat-only models, avoiding false positives on completion + models like ``gpt-3.5-turbo-instruct`` whose params list may not + enumerate ``echo`` explicitly. 
+ """ + try: + model_info = litellm.get_model_info(model=self.model, custom_llm_provider=self.provider) or {} + mode = model_info.get("mode", "") + if mode == "chat": + logger.warning( + f"Model '{self.model}' is registered as a chat-only model (mode='chat'). " + "loglikelihood and loglikelihood_rolling require the /v1/completions " + "endpoint with echo=True — chat-only models do not support this. " + "Results will be all -inf. " + "Use 'gpt-3.5-turbo-instruct' or any OpenAI-compatible local server " + "(llama.cpp, vLLM serve, etc.) instead." + ) + except Exception: + pass # Registry lookup failed — proceed silently, never crash an eval + + def _warn_if_too_long(self, text: str, label: str = "") -> None: + """Warn when ``text`` is estimated to exceed the model's context window. + + Uses tiktoken (via ``litellm.encode``) for a fast, local token count. + Silently skips if ``max_length`` is unknown or encoding fails. + """ + try: + n_tokens = len(encode(self.model, text)) + limit = self.max_length + if limit and n_tokens > limit: + tag = f" [{label}]" if label else "" + logger.warning( + f"Input{tag} is ~{n_tokens} tokens, which exceeds max_length={limit}. " + "The API may truncate or reject the request. " + "Consider shortening your context or choice strings." + ) + except Exception: + pass # Encoding unavailable for this model — skip silently + + # ------------------------------------------------------------------ + # loglikelihood (MCQ / log-prob ranking) + # ------------------------------------------------------------------ + @cached(SamplingMethod.LOGPROBS) def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: - """Tokenize the context and continuation and compute the log likelihood of those - tokenized sequences. + """Compute log-likelihoods for MCQ-style tasks via ``litellm.text_completion``. 
+ + Uses ``echo=True`` and ``logprobs=1`` to retrieve per-token log + probabilities for the full (context + choice) string, then isolates the + continuation slice with the Token Alignment Engine (see + ``_find_continuation_start``). + + Provider requirement: + The underlying model must support the ``/v1/completions`` endpoint + with ``echo`` and ``logprobs`` parameters — for example + ``gpt-3.5-turbo-instruct`` or any OpenAI-compatible local server + (llama.cpp, vLLM serving, etc.). Chat-only models such as + ``gpt-4o`` or Claude do not expose this endpoint and are **not** + supported by this method. + """ + self._check_text_completion_support() + + dataset = LoglikelihoodDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for split in tqdm( + dataset.splits_iterator(), + total=dataset.num_dataset_splits, + desc="Loglikelihood splits", + position=0, + disable=self.disable_tqdm, + ): + split_docs = list(split) + split_results = asyncio.run(self._loglikelihood_async(split_docs)) + results.extend(split_results) + + return dataset.get_original_order(results) + + # ------------------------------------------------------------------ + # loglikelihood_rolling (perplexity) + # ------------------------------------------------------------------ + + async def _process_doc_rolling_async( + self, + doc: Doc, + semaphore: asyncio.Semaphore, + ) -> ModelResponse: + """Compute per-token log-probabilities for the entire document text. + + Sends the full document as the prompt with ``echo=True`` and collects + one logprob per token (skipping the leading null and the trailing + generated token appended by ``max_tokens=1``). + + The returned ``ModelResponse.logprobs`` is a list of per-token floats. + ``PerplexityPreparator`` sums this list to obtain the document-level + log-likelihood, which is then used to compute perplexity, weighted + perplexity, or bits-per-byte. 
""" - raise NotImplementedError + doc_text = self.prompt_manager._prepare_plain_text(doc) + self._warn_if_too_long(doc_text, label=f"doc '{doc.id}'") + response = await self._call_api_text_completion_async(doc_text, semaphore) + + if response is None or not getattr(response, "choices", None): + return ModelResponse(input=doc_text, logprobs=[float("-inf")]) + + lp_obj = getattr(response.choices[0], "logprobs", None) + if lp_obj is None or not getattr(lp_obj, "token_logprobs", None): + return ModelResponse(input=doc_text, logprobs=[float("-inf")]) + + token_logprobs: list = list(lp_obj.token_logprobs or []) + + # token_logprobs[0] → always None (first token has no prior context) + # token_logprobs[1:-1] → per-token log-probs for the full document + # token_logprobs[-1] → the 1 newly generated token from max_tokens=1 (discard) + rolling_logprobs = [v for v in token_logprobs[1:-1] if v is not None] + + return ModelResponse( + input=doc_text, + logprobs=rolling_logprobs, + ) + + async def _loglikelihood_rolling_async(self, docs: list[Doc]) -> list[ModelResponse]: + """Async coordinator for rolling perplexity: one API call per doc.""" + semaphore = asyncio.Semaphore(self.concurrent_requests) + tasks = [self._process_doc_rolling_async(doc=doc, semaphore=semaphore) for doc in docs] + return list(await asyncio.gather(*tasks)) @cached(SamplingMethod.PERPLEXITY) def loglikelihood_rolling(self, docs: list[Doc]) -> list[ModelResponse]: - """This function is used to compute the log likelihood of the context for perplexity metrics.""" - raise NotImplementedError + """Compute rolling log-likelihoods for perplexity-style evaluation. + + Each document is sent as a single prompt with ``echo=True`` so that the + API returns per-token log-probabilities for the whole text. 
The result + is a ``ModelResponse`` whose ``logprobs`` list holds one float per + token; downstream preparators (``PerplexityPreparator``, + ``TargetPerplexityPreparator``) sum these to produce the document-level + log-likelihood used by perplexity / bits-per-byte metrics. + + Provider requirement: same as ``loglikelihood`` — requires the + ``/v1/completions`` endpoint with ``echo`` and ``logprobs``. + """ + self._check_text_completion_support() + + dataset = LoglikelihoodDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS) + results = [] + + for split in tqdm( + dataset.splits_iterator(), + total=dataset.num_dataset_splits, + desc="Loglikelihood rolling splits", + position=0, + disable=self.disable_tqdm, + ): + split_docs = list(split) + split_results = asyncio.run(self._loglikelihood_rolling_async(split_docs)) + results.extend(split_results) + + return dataset.get_original_order(results) diff --git a/src/lighteval/models/model_input.py b/src/lighteval/models/model_input.py index ad41c23eb..a71064725 100644 --- a/src/lighteval/models/model_input.py +++ b/src/lighteval/models/model_input.py @@ -118,6 +118,29 @@ def to_litellm_dict(self) -> dict: "seed": self.seed, "repetition_penalty": self.repetition_penalty, "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + } + return {k: v for k, v in args.items() if v is not None} + + def to_litellm_text_completion_dict(self) -> dict: + """Selects parameters relevant to the ``/v1/completions`` (text completion) endpoint. + + Used by the LiteLLM loglikelihood implementation which calls + ``litellm.atext_completion``. The caller always overrides ``max_tokens``, + ``echo``, ``logprobs``, and ``temperature`` for deterministic scoring, so + those are intentionally excluded here. + + Doc: https://docs.litellm.ai/docs/text_completion + + Returns: + dict: Parameters forwarded to ``litellm.atext_completion``. 
+ """ + args = { + "seed": self.seed, + "stop": self.stop_tokens, + "top_p": self.top_p, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, } return {k: v for k, v in args.items() if v is not None} diff --git a/tests/unit/models/endpoints/test_litellm_loglikelihood.py b/tests/unit/models/endpoints/test_litellm_loglikelihood.py new file mode 100644 index 000000000..5601e35f7 --- /dev/null +++ b/tests/unit/models/endpoints/test_litellm_loglikelihood.py @@ -0,0 +1,1291 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Unit tests for the LiteLLM loglikelihood implementation. + +All litellm API calls are mocked — no network requests are made. +Async helpers are exercised via asyncio.run() to remain dependency-free +(no pytest-asyncio required). 
+""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch + +import pytest + +from lighteval.models.model_input import GenerationParameters +from lighteval.utils.imports import is_package_available + + +pytestmark = pytest.mark.skipif( + not is_package_available("litellm"), + reason="litellm not installed — run `pip install lighteval[litellm]` to enable these tests", +) + +from lighteval.models.endpoints.litellm_model import LiteLLMClient # noqa: E402 +from lighteval.models.model_output import ModelResponse # noqa: E402 +from lighteval.tasks.requests import Doc # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers — build fake litellm text_completion response objects +# --------------------------------------------------------------------------- + + +def make_logprobs(tokens, token_logprobs, top_logprobs=None, text_offset=None): + """Return a SimpleNamespace mimicking litellm's logprobs object.""" + return SimpleNamespace( + tokens=tokens, + token_logprobs=token_logprobs, + top_logprobs=top_logprobs if top_logprobs is not None else [None] * len(tokens), + text_offset=text_offset, + ) + + +def make_response(tokens, token_logprobs, top_logprobs=None, text_offset=None): + """Return a SimpleNamespace mimicking a litellm text_completion response.""" + lp = make_logprobs(tokens, token_logprobs, top_logprobs, text_offset) + choice = SimpleNamespace(logprobs=lp) + return SimpleNamespace(choices=[choice]) + + +def make_doc(query, choices, gold_index=0, task_name="test_task", doc_id="0"): + doc = Doc(query=query, choices=choices, gold_index=gold_index, task_name=task_name) + doc.id = doc_id + return doc + + +def make_bare_client( + model="gpt-3.5-turbo-instruct", + concurrent_requests=10, + api_max_retry=3, + api_retry_sleep=0.0, # instant retries in tests + api_retry_multiplier=1.0, +): + """Construct a LiteLLMClient instance bypassing __init__ to avoid 
real API calls.""" + client = object.__new__(LiteLLMClient) + client.model = model + client.provider = "openai" + client.base_url = None + client.api_key = None + client.timeout = None + client.concurrent_requests = concurrent_requests + client.API_MAX_RETRY = api_max_retry + client.API_RETRY_SLEEP = api_retry_sleep + client.API_RETRY_MULTIPLIER = api_retry_multiplier + client._max_length = 4096 + client.generation_parameters = GenerationParameters() + # disable_tqdm is a read-only @property on LightevalModel (returns False). + # We leave it as-is; tqdm output in tests is harmless. + + # Minimal PromptManager stub: _prepare_plain_text returns doc.query directly + pm = MagicMock() + pm._prepare_plain_text = lambda doc: doc.query + client.prompt_manager = pm + + # Disable the @cached decorator by setting _cache = None + client._cache = None + return client + + +# --------------------------------------------------------------------------- +# 1. _find_continuation_start — Token Alignment Engine unit tests +# --------------------------------------------------------------------------- + + +class TestFindContinuationStart: + """Pure function tests — only the tokenizer fallback (encode) is mocked.""" + + def test_text_offset_layer1_exact_boundary(self): + """Continuation starts exactly at len(context_str) characters.""" + # context = "Q:" (2 chars), continuation = " A" + context_str = "Q:" + lp = make_logprobs( + tokens=["Q", ":", " A", "_gen"], + token_logprobs=[None, -0.1, -0.5, -0.9], + text_offset=[0, 1, 2, 4], # " A" starts at offset 2 == len("Q:") + ) + result = LiteLLMClient._find_continuation_start(lp, context_str, "gpt-3.5-turbo-instruct") + assert result == 2 + + def test_text_offset_layer1_midpoint(self): + """Works when the context ends in a trailing space and the boundary falls at a later token.""" + context_str = "Hello world " # 12 chars + lp = make_logprobs( + tokens=["Hello", " world", " ", "foo"], + token_logprobs=[None, -0.2, -0.1, -0.3], + text_offset=[0, 5, 11, 12], # "foo" starts at 12 ==
len(context) + ) + result = LiteLLMClient._find_continuation_start(lp, context_str, "gpt-3.5-turbo-instruct") + assert result == 3 + + def test_text_offset_all_context_empty_continuation(self): + """All tokens belong to context (empty continuation) → returns len(tokens).""" + context_str = "ABCD" + lp = make_logprobs( + tokens=["A", "B", "C", "D"], + token_logprobs=[None] * 4, + text_offset=[0, 1, 2, 3], # no token starts at offset >= 4 + ) + result = LiteLLMClient._find_continuation_start(lp, context_str, "gpt-3.5-turbo-instruct") + assert result == 4 # == len(tokens), signals empty continuation + + def test_tiktoken_fallback_called_when_no_text_offset(self): + """Layer 2: litellm.encode is called when text_offset is absent.""" + context_str = "Hello world" + lp = make_logprobs( + tokens=["Hello", " world", " foo"], + token_logprobs=[None, -0.5, -0.3], + text_offset=None, # force fallback + ) + with patch("lighteval.models.endpoints.litellm_model.encode", return_value=[1, 2]) as mock_enc: + result = LiteLLMClient._find_continuation_start(lp, context_str, "gpt-3.5-turbo-instruct") + + mock_enc.assert_called_once_with("gpt-3.5-turbo-instruct", context_str) + assert result == 2 # len([1, 2]) + + def test_tiktoken_fallback_called_when_text_offset_is_empty_list(self): + """Empty text_offset list is falsy → Layer 2 fallback.""" + context_str = "ctx" + lp = make_logprobs(["ctx", "cont"], [None, -0.3], text_offset=[]) + with patch("lighteval.models.endpoints.litellm_model.encode", return_value=[99]) as mock_enc: + result = LiteLLMClient._find_continuation_start(lp, context_str, "gpt-3.5-turbo-instruct") + assert result == 1 + mock_enc.assert_called_once() + + def test_tiktoken_fallback_encode_failure_returns_zero(self): + """If encode raises, we fall back gracefully to 0 (log a warning).""" + context_str = "ctx" + lp = make_logprobs(["ctx", "cont"], [None, -0.3], text_offset=None) + with patch("lighteval.models.endpoints.litellm_model.encode", 
side_effect=RuntimeError("tiktoken unavailable")): + result = LiteLLMClient._find_continuation_start(lp, context_str, "unknown-model") + assert result == 0 + + +# --------------------------------------------------------------------------- +# 2. _check_argmax — Argmax unit tests +# --------------------------------------------------------------------------- + + +class TestCheckArgmax: + """Mirrors vLLM's `rank == 1` semantics. Last token is excluded (max_tokens=1 artifact).""" + + def test_all_continuation_tokens_match_top1(self): + # tokens: [ctx, contA, contB, generated] + # cont_start=1, cont_end=3 → check positions 1 and 2 + tokens = ["ctx", " A", " B", "_gen"] + top_logprobs = [ + {"ctx": -0.1}, + {" A": -0.5}, # actual token " A" IS top-1 ✓ + {" B": -0.3}, # actual token " B" IS top-1 ✓ + {"_gen": -0.9}, + ] + result = LiteLLMClient._check_argmax(tokens, [], top_logprobs, cont_start=1) + assert result is True + + def test_first_continuation_token_does_not_match(self): + tokens = ["ctx", " A", "_gen"] + top_logprobs = [ + {"ctx": -0.1}, + {" B": -0.2}, # top-1 is " B" but actual is " A" ✗ + {"_gen": -0.9}, + ] + result = LiteLLMClient._check_argmax(tokens, [], top_logprobs, cont_start=1) + assert result is False + + def test_partial_match_fails_overall(self): + # Two continuation tokens; first matches, second does not + tokens = ["ctx", " A", " C", "_gen"] + top_logprobs = [ + {"ctx": -0.1}, + {" A": -0.5}, # ✓ + {" B": -0.3}, # top-1 is " B" but actual is " C" ✗ + {"_gen": -0.9}, + ] + result = LiteLLMClient._check_argmax(tokens, [], top_logprobs, cont_start=1) + assert result is False + + def test_empty_continuation_returns_true(self): + # cont_start == len(tokens) - 1 means no continuation tokens + tokens = ["ctx", "_gen"] + result = LiteLLMClient._check_argmax(tokens, [], [{"ctx": -0.1}, {"_gen": -0.9}], cont_start=1) + assert result is True + + def test_empty_top_logprobs_returns_false(self): + result = LiteLLMClient._check_argmax(["a", "b", "_gen"], [], 
[], cont_start=0) + assert result is False + + def test_none_top_dict_at_position_returns_false(self): + tokens = ["ctx", " A", "_gen"] + top_logprobs = [{"ctx": -0.1}, None, {"_gen": -0.5}] + result = LiteLLMClient._check_argmax(tokens, [], top_logprobs, cont_start=1) + assert result is False + + def test_cont_start_beyond_tokens_length_returns_false(self): + tokens = ["a"] + top_logprobs = [{"a": -0.1}] + # cont_start == len(tokens) - 1 means empty continuation → True + # cont_start > len(tokens) - 1 means start >= end → True + result = LiteLLMClient._check_argmax(tokens, [], top_logprobs, cont_start=5) + assert result is True + + +# --------------------------------------------------------------------------- +# 3. _call_api_text_completion_async — backoff and retry tests +# --------------------------------------------------------------------------- + + +class TestCallApiTextCompletionAsync: + """Tests the async API caller: success, 429 backoff, total failure, semaphore.""" + + # Custom exception standing in for litellm.RateLimitError in tests + class _FakeRateLimitError(Exception): + pass + + class _FakeGenericError(Exception): + pass + + def _run_async(self, coro): + return asyncio.run(coro) + + def test_success_on_first_attempt(self): + client = make_bare_client() + fake_resp = make_response(["Hello", " world", "_gen"], [None, -0.5, -0.1]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock(return_value=fake_resp) + return await client._call_api_text_completion_async("Hello world", sem) + + result = self._run_async(run()) + assert result is fake_resp + + def test_rate_limit_429_then_success(self): + """First call raises RateLimitError; second call succeeds.""" + client = make_bare_client(api_max_retry=3) + fake_resp = make_response(["tok", "_gen"], [None, -0.3]) + + async def run(): + sem = 
asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock( + side_effect=[self._FakeRateLimitError("429 rate limited"), fake_resp] + ) + return await client._call_api_text_completion_async("test prompt", sem) + + result = self._run_async(run()) + assert result is fake_resp + + def test_generic_error_then_success(self): + """Non-rate-limit transient error is also retried with backoff.""" + client = make_bare_client(api_max_retry=3) + fake_resp = make_response(["t"], [None]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock( + side_effect=[self._FakeGenericError("timeout"), fake_resp] + ) + return await client._call_api_text_completion_async("test", sem) + + result = self._run_async(run()) + assert result is fake_resp + + def test_all_retries_exhausted_returns_none(self): + """All API_MAX_RETRY attempts fail → returns None gracefully.""" + client = make_bare_client(api_max_retry=3) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock( + side_effect=self._FakeGenericError("permanent failure") + ) + return await client._call_api_text_completion_async("test", sem) + + result = self._run_async(run()) + assert result is None + + def test_rate_limit_all_retries_exhausted_returns_none(self): + """Persistent 429 across all retries → None, not an exception.""" + client = make_bare_client(api_max_retry=2) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + 
mock_lit.atext_completion = AsyncMock( + side_effect=self._FakeRateLimitError("perpetual 429") + ) + return await client._call_api_text_completion_async("test", sem) + + result = self._run_async(run()) + assert result is None + + def test_semaphore_limits_concurrency(self): + """With Semaphore(1) all three gathered calls still complete and return results (serialisation itself is not asserted).""" + client = make_bare_client(concurrent_requests=1) + call_order = [] + + async def fake_atext(*args, **kwargs): + call_order.append(kwargs.get("prompt", "?")) + return make_response(["t"], [None]) + + async def run(): + sem = asyncio.Semaphore(1) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = fake_atext + results = await asyncio.gather( + client._call_api_text_completion_async("A", sem), + client._call_api_text_completion_async("B", sem), + client._call_api_text_completion_async("C", sem), + ) + return results + + results = asyncio.run(run()) + assert len(results) == 3 + assert all(r is not None for r in results) + assert len(call_order) == 3 + + def test_correct_api_parameters_passed(self): + """Verifies echo=True, logprobs=1, max_tokens=1, temperature=0.0 are sent.""" + client = make_bare_client() + client.model = "gpt-3.5-turbo-instruct" + client.api_key = "sk-test" + fake_resp = make_response(["t"], [None]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = Exception + mock_lit.atext_completion = AsyncMock(return_value=fake_resp) + await client._call_api_text_completion_async("the prompt", sem) + call_kwargs = mock_lit.atext_completion.call_args.kwargs + return call_kwargs + + kw = asyncio.run(run()) + assert kw["echo"] is True + assert kw["logprobs"] == 1 + assert kw["max_tokens"] == 1 + assert kw["temperature"] == 0.0 + assert kw["prompt"] == "the prompt" + assert kw["model"] ==
"gpt-3.5-turbo-instruct" + assert kw["api_key"] == "sk-test" + + +# --------------------------------------------------------------------------- +# 4. _process_doc_loglikelihood_async — per-doc processing tests +# --------------------------------------------------------------------------- + + +class TestProcessDocLoglikelihoodAsync: + """Tests the per-doc async processor using patched _call_api_text_completion_async.""" + + def _run_async(self, coro): + return asyncio.run(coro) + + def _make_doc_with_two_choices(self): + return make_doc("Q:", [" A", " B"], gold_index=0) + + def test_basic_two_choices_correct_logprobs(self): + """Correct logprob sums and argmax booleans for a 2-choice doc.""" + client = make_bare_client() + + # context = "Q:" (2 chars), text_offset: " A" starts at offset 2 + resp_a = make_response( + tokens=["Q", ":", " A", "_gen"], + token_logprobs=[None, -0.1, -0.5, -0.9], + top_logprobs=[{"Q": -0.05}, {":": -0.1}, {" A": -0.5}, {"_gen": -0.9}], + text_offset=[0, 1, 2, 4], + ) + # For choice B: " B" is NOT the top-1 (top-1 would be " A") + resp_b = make_response( + tokens=["Q", ":", " B", "_gen"], + token_logprobs=[None, -0.1, -2.0, -0.9], + top_logprobs=[{"Q": -0.05}, {":": -0.1}, {" A": -0.5}, {"_gen": -0.9}], + text_offset=[0, 1, 2, 4], + ) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object( + client, + "_call_api_text_completion_async", + AsyncMock(side_effect=[resp_a, resp_b]), + ): + return await client._process_doc_loglikelihood_async( + self._make_doc_with_two_choices(), "Q:", sem + ) + + result = self._run_async(run()) + + assert isinstance(result, ModelResponse) + assert len(result.logprobs) == 2 + assert result.logprobs[0] == pytest.approx(-0.5) # only continuation token for A + assert result.logprobs[1] == pytest.approx(-2.0) # only continuation token for B + assert result.argmax_logits_eq_gold[0] is True # " A" was top-1 + assert result.argmax_logits_eq_gold[1] is False # " B" was NOT top-1 + + def 
test_multi_token_continuation_sums_correctly(self): + """Continuation with two tokens: sum of both logprobs.""" + client = make_bare_client() + + resp = make_response( + tokens=["Q", ":", " yes", " sir", "_gen"], + token_logprobs=[None, -0.1, -0.4, -0.6, -0.9], + top_logprobs=[ + {"Q": -0.05}, {":": -0.1}, + {" yes": -0.4}, {" sir": -0.6}, {"_gen": -0.9}, + ], + text_offset=[0, 1, 2, 6, 10], # context "Q:" ends at char 2 + ) + + doc = make_doc("Q:", [" yes sir"], gold_index=0) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", AsyncMock(return_value=resp)): + return await client._process_doc_loglikelihood_async(doc, "Q:", sem) + + result = self._run_async(run()) + # logprobs for " yes" and " sir" → -0.4 + -0.6 = -1.0 + assert result.logprobs[0] == pytest.approx(-1.0) + assert result.argmax_logits_eq_gold[0] is True + + def test_failed_api_call_returns_neg_inf_sentinel(self): + """If the API call returns None, logprob = -inf, argmax = False.""" + client = make_bare_client() + doc = self._make_doc_with_two_choices() + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object( + client, "_call_api_text_completion_async", AsyncMock(return_value=None) + ): + return await client._process_doc_loglikelihood_async(doc, "Q:", sem) + + result = self._run_async(run()) + assert result.logprobs == [float("-inf"), float("-inf")] + assert result.argmax_logits_eq_gold == [False, False] + + def test_none_logprobs_object_in_response_returns_sentinel(self): + """Response with logprobs=None → sentinel values.""" + client = make_bare_client() + doc = make_doc("Q:", [" A"], gold_index=0) + + bad_resp = SimpleNamespace(choices=[SimpleNamespace(logprobs=None)]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object( + client, "_call_api_text_completion_async", AsyncMock(return_value=bad_resp) + ): + return await client._process_doc_loglikelihood_async(doc, "Q:", sem) + + result = 
self._run_async(run()) + assert result.logprobs == [float("-inf")] + assert result.argmax_logits_eq_gold == [False] + + def test_empty_choices_list_returns_empty_response(self): + """Doc with no choices → empty lists (no API calls fired).""" + client = make_bare_client() + doc = make_doc("Q:", choices=[]) + + async def run(): + sem = asyncio.Semaphore(10) + # We verify that atext_completion is never called when choices is empty + with patch.object( + client, "_call_api_text_completion_async", AsyncMock() + ) as mock_call: + result = await client._process_doc_loglikelihood_async(doc, "Q:", sem) + assert mock_call.call_count == 0 + return result + + result = self._run_async(run()) + assert result.logprobs == [] + assert result.argmax_logits_eq_gold == [] + + def test_context_is_prepended_to_each_choice(self): + """Verifies full_text = context + choice is sent for each choice.""" + client = make_bare_client() + doc = make_doc("CTX", [" X", " Y"], gold_index=0) + fake_resp = make_response(["C", "T", "X", " X", "_gen"], [None, -0.1, -0.1, -0.5, -0.9], + text_offset=[0, 1, 2, 3, 5]) + + captured_prompts = [] + + async def fake_call(full_text, semaphore): + captured_prompts.append(full_text) + return fake_resp + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", side_effect=fake_call): + return await client._process_doc_loglikelihood_async(doc, "CTX", sem) + + self._run_async(run()) + assert captured_prompts == ["CTX X", "CTX Y"] + + def test_result_order_matches_choice_order(self): + """asyncio.gather preserves order; results align 1-to-1 with choices.""" + client = make_bare_client() + doc = make_doc("Q:", [" A", " B", " C"], gold_index=1) + + # Each choice has a distinct logprob so we can verify ordering. + # actual_tok is always " A" (a valid string); top_tok is " A" when + # is_top=True (so actual==top → argmax True) or " Z" otherwise. 
+ def make_choice_resp(choice_lp, is_top): + actual_tok = " A" + top_tok = " A" if is_top else " Z" + return make_response( + tokens=["Q", ":", actual_tok, "_gen"], + token_logprobs=[None, -0.1, float(choice_lp), -0.9], + top_logprobs=[{"Q": -0.1}, {":": -0.1}, {top_tok: float(choice_lp)}, {"_gen": -0.9}], + text_offset=[0, 1, 2, 3], + ) + + resps = [make_choice_resp(-1.0, True), make_choice_resp(-2.0, False), make_choice_resp(-3.0, False)] + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object( + client, "_call_api_text_completion_async", AsyncMock(side_effect=resps) + ): + return await client._process_doc_loglikelihood_async(doc, "Q:", sem) + + result = self._run_async(run()) + assert result.logprobs[0] == pytest.approx(-1.0) + assert result.logprobs[1] == pytest.approx(-2.0) + assert result.logprobs[2] == pytest.approx(-3.0) + assert result.argmax_logits_eq_gold == [True, False, False] + + +# --------------------------------------------------------------------------- +# 5. loglikelihood (full integration) — end-to-end with fully mocked pipeline +# --------------------------------------------------------------------------- + + +class TestLoglikelihoodIntegration: + """Integration tests: exercise loglikelihood() top to bottom. + + All async work is mocked at _process_doc_loglikelihood_async so we validate + the orchestration layer (LoglikelihoodDataset, original ordering, output shape) + without needing live API credentials. 
+ """ + + def _make_known_response(self, logprob_vals, argmax_vals, context="ctx"): + """Convenience: build a ModelResponse with explicit per-choice values.""" + return ModelResponse( + input=context, + logprobs=list(logprob_vals), + argmax_logits_eq_gold=list(argmax_vals), + ) + + def test_two_docs_output_shape_and_order(self): + """Two docs returned in original input order.""" + client = make_bare_client() + + docs = [ + make_doc("Question 1:", [" A", " B", " C", " D"], gold_index=0, doc_id="0"), + make_doc("Q2:", [" W", " X"], gold_index=1, doc_id="1"), + ] + + resp0 = self._make_known_response([-0.3, -1.5, -2.0, -3.0], [True, False, False, False]) + resp1 = self._make_known_response([-1.8, -0.2], [False, True]) + + # Map doc_id → pre-built response + responses_by_id = {"0": resp0, "1": resp1} + + async def fake_process_doc(doc, context_str, semaphore): + return responses_by_id[doc.id] + + with patch.object(client, "_process_doc_loglikelihood_async", side_effect=fake_process_doc), \ + patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.loglikelihood(docs) + + assert len(results) == 2 + + # Original order preserved (doc "0" first, "1" second) + assert results[0].logprobs == pytest.approx([-0.3, -1.5, -2.0, -3.0]) + assert results[1].logprobs == pytest.approx([-1.8, -0.2]) + assert results[0].argmax_logits_eq_gold == [True, False, False, False] + assert results[1].argmax_logits_eq_gold == [False, True] + + def test_single_doc_four_choices(self): + """Single doc, 4 choices, correct output shape.""" + client = make_bare_client() + doc = make_doc("Q:", [" A", " B", " C", " D"], gold_index=2, doc_id="0") + + pre_resp = self._make_known_response([-2.0, -1.0, -0.5, -3.0], [False, False, True, False]) + + async def fake_process(doc, context_str, semaphore): + return pre_resp + + with patch.object(client, "_process_doc_loglikelihood_async", side_effect=fake_process), \ + patch.object(type(client), "disable_tqdm", 
new_callable=PropertyMock, return_value=True): + results = client.loglikelihood([doc]) + + assert len(results) == 1 + assert len(results[0].logprobs) == 4 + # Choice index 2 has the highest (least negative) logprob + assert results[0].logprobs.index(max(results[0].logprobs)) == 2 + + def test_original_order_restored_after_dataset_sorting(self): + """LoglikelihoodDataset sorts by prompt length; loglikelihood must un-sort.""" + client = make_bare_client() + + # Deliberately create docs with different query lengths so the dataset + # reorders them. The long query will be sorted first by LoglikelihoodDataset. + short_doc = make_doc("Q?", [" Y", " N"], gold_index=0, doc_id="short") + long_doc = make_doc( + "This is a much longer question that triggers reordering:", + [" A", " B"], + gold_index=1, + doc_id="long", + ) + + resp_short = self._make_known_response([-0.1, -1.0], [True, False], context="Q?") + resp_long = self._make_known_response([-2.0, -0.3], [False, True]) + + responses_by_id = {"short": resp_short, "long": resp_long} + + async def fake_process(doc, context_str, semaphore): + return responses_by_id[doc.id] + + # Input order: [short, long]. Dataset sorts long first. loglikelihood must + # restore original order → results[0] is short_doc's response.
+ with patch.object(client, "_process_doc_loglikelihood_async", side_effect=fake_process), \ + patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.loglikelihood([short_doc, long_doc]) + + assert len(results) == 2 + # results[0] must correspond to short_doc (original position 0) + assert results[0].logprobs == pytest.approx([-0.1, -1.0]) + # results[1] must correspond to long_doc (original position 1) + assert results[1].logprobs == pytest.approx([-2.0, -0.3]) + + def test_all_api_failures_return_neg_inf_per_choice(self): + """Graceful degradation: all docs return -inf when API is completely down.""" + client = make_bare_client() + docs = [make_doc("Q:", [" A", " B"], gold_index=0, doc_id=str(i)) for i in range(3)] + + async def fake_process(doc, context_str, semaphore): + return ModelResponse( + input=context_str, + logprobs=[float("-inf"), float("-inf")], + argmax_logits_eq_gold=[False, False], + ) + + with patch.object(client, "_process_doc_loglikelihood_async", side_effect=fake_process), \ + patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.loglikelihood(docs) + + assert len(results) == 3 + for r in results: + assert r.logprobs == [float("-inf"), float("-inf")] + assert r.argmax_logits_eq_gold == [False, False] + + +# --------------------------------------------------------------------------- +# 6. 
_check_text_completion_support — provider guard tests +# --------------------------------------------------------------------------- + + +class TestCheckTextCompletionSupport: + """The guard should warn when litellm.get_model_info() reports a chat-only model and + stay silent (no exception) when it does not or the check itself fails.""" + + def _make_client(self, model="gpt-3.5-turbo-instruct", provider="openai"): + client = make_bare_client(model=model) + client.provider = provider + return client + + def test_no_warning_when_echo_supported(self, caplog): + client = self._make_client() + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_supported_openai_params = MagicMock(return_value=["echo", "logprobs", "max_tokens"]) + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._check_text_completion_support() + assert "echo" not in caplog.text or "does not list" not in caplog.text + + def test_warning_emitted_when_echo_not_supported(self, caplog): + client = self._make_client(model="gpt-4o", provider="openai") + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(return_value={"mode": "chat"}) + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._check_text_completion_support() + assert "chat-only" in caplog.text + + def test_no_crash_when_registry_lookup_raises(self): + """If litellm's param registry explodes, the guard must stay silent.""" + client = self._make_client() + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_supported_openai_params = MagicMock(side_effect=RuntimeError("registry unavailable")) + client._check_text_completion_support() # must not raise + + def test_no_crash_when_params_returns_none(self): + client = self._make_client() + with
patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_supported_openai_params = MagicMock(return_value=None) + client._check_text_completion_support() # None → treated as empty list, no crash + + def test_warning_contains_model_name(self, caplog): + client = self._make_client(model="claude-3-opus", provider="anthropic") + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(return_value={"mode": "chat"}) + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._check_text_completion_support() + assert "claude-3-opus" in caplog.text + + def test_loglikelihood_calls_guard(self): + """loglikelihood() must call _check_text_completion_support before processing.""" + client = make_bare_client() + guard_called = {"n": 0} + + def fake_guard(self_inner): + guard_called["n"] += 1 + + async def fake_process(doc, context_str, semaphore): + return ModelResponse(input=context_str, logprobs=[-0.5], argmax_logits_eq_gold=[True]) + + with patch.object(LiteLLMClient, "_check_text_completion_support", fake_guard), \ + patch.object(client, "_process_doc_loglikelihood_async", side_effect=fake_process), \ + patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + client.loglikelihood([make_doc("Q:", [" A"], gold_index=0, doc_id="0")]) + + assert guard_called["n"] == 1 + + def test_loglikelihood_rolling_calls_guard(self): + """loglikelihood_rolling() must also call _check_text_completion_support.""" + client = make_bare_client() + guard_called = {"n": 0} + + def fake_guard(self_inner): + guard_called["n"] += 1 + + async def fake_process(doc, semaphore): + return ModelResponse(input=doc.query, logprobs=[-0.1, -0.2]) + + with patch.object(LiteLLMClient, "_check_text_completion_support", fake_guard), \ + patch.object(client, "_process_doc_rolling_async", side_effect=fake_process), \ + 
patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + client.loglikelihood_rolling([make_doc("Hello world", choices=[], gold_index=0, doc_id="0")]) + + assert guard_called["n"] == 1 + + +# --------------------------------------------------------------------------- +# 7. _process_doc_rolling_async — per-token perplexity tests +# --------------------------------------------------------------------------- + + +class TestProcessDocRollingAsync: + """Tests the per-document rolling log-likelihood computation.""" + + def _run(self, coro): + return asyncio.run(coro) + + def test_basic_rolling_sums_all_token_logprobs(self): + """token_logprobs[1:-1] are the valid rolling logprobs.""" + client = make_bare_client() + doc = make_doc("Hello world", choices=[], gold_index=0) + + # 5 tokens: [null, -0.3, -0.5, -0.2, generated] + # rolling = [-0.3, -0.5, -0.2] (indices 1..3, excluding last) + resp = make_response( + tokens=["Hello", " world", " foo", " bar", "_gen"], + token_logprobs=[None, -0.3, -0.5, -0.2, -0.9], + ) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", AsyncMock(return_value=resp)): + return await client._process_doc_rolling_async(doc, sem) + + result = self._run(run()) + assert isinstance(result, ModelResponse) + assert result.logprobs == pytest.approx([-0.3, -0.5, -0.2]) + + def test_single_token_doc_returns_empty_logprobs(self): + """A 1-token document: token_logprobs = [None, generated] → nothing to sum.""" + client = make_bare_client() + doc = make_doc("Hi", choices=[], gold_index=0) + + resp = make_response( + tokens=["Hi", "_gen"], + token_logprobs=[None, -0.9], + ) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", AsyncMock(return_value=resp)): + return await client._process_doc_rolling_async(doc, sem) + + result = self._run(run()) + assert result.logprobs == [] + + def 
test_failed_api_call_returns_neg_inf(self): + client = make_bare_client() + doc = make_doc("Hello", choices=[], gold_index=0) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", AsyncMock(return_value=None)): + return await client._process_doc_rolling_async(doc, sem) + + result = self._run(run()) + assert result.logprobs == [float("-inf")] + + def test_null_logprobs_in_response(self): + client = make_bare_client() + doc = make_doc("Hello", choices=[], gold_index=0) + bad_resp = SimpleNamespace(choices=[SimpleNamespace(logprobs=None)]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", AsyncMock(return_value=bad_resp)): + return await client._process_doc_rolling_async(doc, sem) + + result = self._run(run()) + assert result.logprobs == [float("-inf")] + + def test_none_values_in_token_logprobs_are_filtered(self): + """Unexpected None values mid-sequence are skipped gracefully.""" + client = make_bare_client() + doc = make_doc("A B C", choices=[], gold_index=0) + + resp = make_response( + tokens=["A", " B", " C", "_gen"], + token_logprobs=[None, -0.4, None, -0.9], # middle None is unusual but handled + ) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", AsyncMock(return_value=resp)): + return await client._process_doc_rolling_async(doc, sem) + + result = self._run(run()) + # token_logprobs[1:-1] = [-0.4, None] → filter None → [-0.4] + assert result.logprobs == pytest.approx([-0.4]) + + def test_correct_prompt_sent_to_api(self): + """The full plain-text doc is sent as the prompt.""" + client = make_bare_client() + doc = make_doc("The quick brown fox", choices=[], gold_index=0) + captured = {} + + async def fake_call(full_text, semaphore): + captured["prompt"] = full_text + return make_response(["The", " quick", "_gen"], [None, -0.3, -0.9]) + + async def run(): + sem = 
asyncio.Semaphore(10) + with patch.object(client, "_call_api_text_completion_async", side_effect=fake_call): + return await client._process_doc_rolling_async(doc, sem) + + self._run(run()) + assert captured["prompt"] == "The quick brown fox" + + +# --------------------------------------------------------------------------- +# 8. loglikelihood_rolling integration +# --------------------------------------------------------------------------- + + +class TestLoglikelihoodRollingIntegration: + """End-to-end tests for loglikelihood_rolling() orchestration.""" + + def test_three_docs_correct_shape_and_order(self): + client = make_bare_client() + docs = [ + make_doc("Doc one text", choices=[], gold_index=0, doc_id="0"), + make_doc("A much longer second document for sorting test", choices=[], gold_index=0, doc_id="1"), + make_doc("Short", choices=[], gold_index=0, doc_id="2"), + ] + + # Pre-built per-doc responses keyed by doc id + responses = { + "0": ModelResponse(input="Doc one text", logprobs=[-0.3, -0.5]), + "1": ModelResponse(input="...", logprobs=[-0.1, -0.2, -0.4]), + "2": ModelResponse(input="Short", logprobs=[-0.8]), + } + + async def fake_rolling(doc, semaphore): + return responses[doc.id] + + with patch.object(client, "_process_doc_rolling_async", side_effect=fake_rolling), \ + patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.loglikelihood_rolling(docs) + + assert len(results) == 3 + # Original order must be preserved (LoglikelihoodDataset sorts internally) + assert results[0].logprobs == pytest.approx([-0.3, -0.5]) + assert results[1].logprobs == pytest.approx([-0.1, -0.2, -0.4]) + assert results[2].logprobs == pytest.approx([-0.8]) + + def test_perplexity_sum_compatible(self): + """np.sum(result.logprobs) must give the total document log-likelihood.""" + import numpy as np + + client = make_bare_client() + doc = make_doc("Hello world", choices=[], gold_index=0, doc_id="0") + + async def 
fake_rolling(doc, semaphore): + return ModelResponse(input=doc.query, logprobs=[-0.3, -0.5, -0.2]) + + with patch.object(client, "_process_doc_rolling_async", side_effect=fake_rolling), \ + patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.loglikelihood_rolling([doc]) + + total_logprob = float(np.sum(results[0].logprobs)) + assert total_logprob == pytest.approx(-1.0) + + +# --------------------------------------------------------------------------- +# 9. Provider guard — improved mode-based detection +# --------------------------------------------------------------------------- + + +class TestCheckTextCompletionSupportModeDetection: + """The guard uses litellm.get_model_info() 'mode' field, not the params list.""" + + def _make_client(self, model="gpt-3.5-turbo-instruct", provider="openai"): + client = make_bare_client(model=model) + client.provider = provider + return client + + def test_no_warning_for_completion_mode_model(self, caplog): + """mode='completion' → no warning (correct: model supports text_completion).""" + client = self._make_client("gpt-3.5-turbo-instruct") + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(return_value={"mode": "completion"}) + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._check_text_completion_support() + assert "chat-only" not in caplog.text + + def test_warning_for_chat_mode_model(self, caplog): + """mode='chat' → warning emitted.""" + client = self._make_client("gpt-4o", "openai") + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(return_value={"mode": "chat"}) + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._check_text_completion_support() + assert "chat-only" in caplog.text + assert "gpt-4o" in 
caplog.text + + def test_no_warning_when_mode_field_absent(self, caplog): + """mode not present in model_info → unknown model, proceed silently.""" + client = self._make_client("custom-model") + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(return_value={"max_tokens": 4096}) + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._check_text_completion_support() + assert "chat-only" not in caplog.text + + def test_no_crash_when_get_model_info_raises(self): + """Any exception in model_info lookup must not propagate.""" + client = self._make_client() + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(side_effect=KeyError("model not in registry")) + client._check_text_completion_support() # must not raise + + def test_no_crash_when_get_model_info_returns_none(self): + client = self._make_client() + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.get_model_info = MagicMock(return_value=None) + client._check_text_completion_support() # None → treated as {} → no warning + + +# --------------------------------------------------------------------------- +# 10. 
Seed forwarding +# --------------------------------------------------------------------------- + + +class TestSeedForwarding: + """generation_parameters.seed must be passed to every text_completion call.""" + + class _FakeRateLimitError(Exception): + pass + + def test_seed_forwarded_via_text_completion_dict(self): + """seed from generation_parameters flows into the API call via to_litellm_text_completion_dict.""" + client = make_bare_client() + client.generation_parameters = GenerationParameters(seed=42) + + fake_resp = make_response(["t"], [None]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock(return_value=fake_resp) + await client._call_api_text_completion_async("hello", sem) + kw = mock_lit.atext_completion.call_args.kwargs + return kw + + kw = asyncio.run(run()) + assert kw["seed"] == 42 + + def test_no_seed_key_when_seed_not_set(self): + """When seed is None, to_litellm_text_completion_dict omits it entirely + (litellm.drop_params handles the rest).""" + client = make_bare_client() + client.generation_parameters = GenerationParameters() # seed=None by default + + fake_resp = make_response(["t"], [None]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock(return_value=fake_resp) + await client._call_api_text_completion_async("hello", sem) + kw = mock_lit.atext_completion.call_args.kwargs + return kw + + kw = asyncio.run(run()) + # seed was not set → not in the dict (omitted by to_litellm_text_completion_dict) + assert "seed" not in kw + + def test_stop_tokens_forwarded(self): + """stop_tokens from generation_parameters flows through to the API call.""" + client = make_bare_client() + client.generation_parameters = 
GenerationParameters(stop_tokens=["\n", "END"]) + fake_resp = make_response(["t"], [None]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch("lighteval.models.endpoints.litellm_model.litellm") as mock_lit: + mock_lit.RateLimitError = self._FakeRateLimitError + mock_lit.atext_completion = AsyncMock(return_value=fake_resp) + await client._call_api_text_completion_async("hello", sem) + return mock_lit.atext_completion.call_args.kwargs + + kw = asyncio.run(run()) + assert kw["stop"] == ["\n", "END"] + + +# --------------------------------------------------------------------------- +# 11. Input length guard +# --------------------------------------------------------------------------- + + +class TestWarnIfTooLong: + """_warn_if_too_long emits a WARNING when encode() returns more tokens than max_length.""" + + def _make_client(self, max_length=10): + client = make_bare_client() + client._max_length = max_length + return client + + def test_warning_when_over_limit(self, caplog): + client = self._make_client(max_length=3) + with patch("lighteval.models.endpoints.litellm_model.encode", return_value=list(range(5))): + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._warn_if_too_long("some long text", label="test") + assert "exceeds max_length" in caplog.text + assert "5" in caplog.text # token count shown + assert "3" in caplog.text # max_length shown + + def test_no_warning_when_within_limit(self, caplog): + client = self._make_client(max_length=100) + with patch("lighteval.models.endpoints.litellm_model.encode", return_value=list(range(5))): + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._warn_if_too_long("short text") + assert "exceeds" not in caplog.text + + def test_no_crash_when_encode_raises(self): + client = self._make_client() + with patch("lighteval.models.endpoints.litellm_model.encode", 
side_effect=RuntimeError("tiktoken missing")): + client._warn_if_too_long("text") # must not raise + + def test_no_warning_when_max_length_is_none(self, caplog): + """Unknown context window → skip silently.""" + client = make_bare_client() + client._max_length = None + with patch("lighteval.models.endpoints.litellm_model.encode", return_value=list(range(999))): + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._warn_if_too_long("very long text") + assert "exceeds" not in caplog.text + + def test_label_appears_in_warning(self, caplog): + client = self._make_client(max_length=1) + with patch("lighteval.models.endpoints.litellm_model.encode", return_value=[1, 2, 3]): + import logging + with caplog.at_level(logging.WARNING, logger="lighteval.models.endpoints.litellm_model"): + client._warn_if_too_long("text", label="doc '42' longest choice") + assert "doc '42'" in caplog.text + + def test_length_guard_called_in_process_doc_loglikelihood(self): + """_warn_if_too_long is called once per doc using the longest choice.""" + client = make_bare_client() + warn_calls = [] + + def fake_warn(text, label=""): + warn_calls.append((text, label)) + + # Use choices of clearly different lengths so max(key=len) is deterministic + doc = make_doc("Q:", [" A", " longer_choice"], gold_index=0) + fake_resp = make_response(["Q", ":", " A", "_gen"], [None, -0.1, -0.5, -0.9], + text_offset=[0, 1, 2, 4]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_warn_if_too_long", side_effect=fake_warn), \ + patch.object(client, "_call_api_text_completion_async", + AsyncMock(return_value=fake_resp)): + return await client._process_doc_loglikelihood_async(doc, "Q:", sem) + + asyncio.run(run()) + assert len(warn_calls) == 1 # called exactly once per doc + assert "longer_choice" in warn_calls[0][0] # longest choice was used + assert warn_calls[0][0].startswith("Q:") # context is prepended + + def 
test_length_guard_called_in_process_doc_rolling(self): + """_warn_if_too_long is called in _process_doc_rolling_async.""" + client = make_bare_client() + warn_calls = [] + + def fake_warn(text, label=""): + warn_calls.append(text) + + doc = make_doc("Hello world", choices=[], gold_index=0) + fake_resp = make_response(["Hello", " world", "_gen"], [None, -0.3, -0.9]) + + async def run(): + sem = asyncio.Semaphore(10) + with patch.object(client, "_warn_if_too_long", side_effect=fake_warn), \ + patch.object(client, "_call_api_text_completion_async", + AsyncMock(return_value=fake_resp)): + return await client._process_doc_rolling_async(doc, sem) + + asyncio.run(run()) + assert len(warn_calls) == 1 + assert warn_calls[0] == "Hello world" + + +# --------------------------------------------------------------------------- +# Regression: PR #1192 — greedy_until iterates split, not full dataset +# --------------------------------------------------------------------------- + + +class TestGreedyUntilSplitFix: + """Regression test: greedy_until must build contexts from the current split only. + + The bug: `contexts = [prepare_prompt_api(doc) for doc in dataset]` iterated + the entire dataset on every split iteration, causing each doc to be processed + `num_splits` times instead of once. The fix uses `for doc in split`. + """ + + def _make_greedy_doc(self, query, generation_size=32, doc_id="0"): + doc = Doc(query=query, choices=[], gold_index=0, task_name="test", generation_size=generation_size) + doc.id = doc_id + return doc + + def _make_chat_response(self, content="answer"): + choice = MagicMock() + choice.message.content = content + choice.message.reasoning_content = None + resp = MagicMock() + resp.choices = [choice] + return resp + + def test_each_doc_prepared_exactly_once_across_two_splits(self): + """Docs with different generation_size land in separate splits. + prepare_prompt_api must be called once per doc, not once per split × N docs. 
+ """ + client = make_bare_client() + + # Different generation_size → GenerativeTaskDataset puts them in different splits + doc_a = self._make_greedy_doc("Prompt A", generation_size=16, doc_id="a") + doc_b = self._make_greedy_doc("Prompt B", generation_size=32, doc_id="b") + docs = [doc_a, doc_b] + + prepared = [] + + def tracking_prepare(doc): + prepared.append(doc.query) + return [{"role": "user", "content": doc.query}] + + client.prompt_manager.prepare_prompt_api.side_effect = tracking_prepare + + def fake_parallel(contexts, *args, **kwargs): + return [self._make_chat_response() for _ in contexts] + + with patch.object( + client, "_LiteLLMClient__call_api_parallel", side_effect=fake_parallel + ), patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.greedy_until(docs) + + assert len(prepared) == 2, ( + f"Expected 2 prepare_prompt_api calls (one per doc), got {len(prepared)}. " + "Regression: greedy_until iterated full dataset instead of current split." 
+ ) + assert set(prepared) == {"Prompt A", "Prompt B"} + assert len(results) == 2 + + def test_single_split_all_docs_processed(self): + """Sanity: one split (same generation_size) — all docs processed correctly.""" + client = make_bare_client() + + docs = [ + self._make_greedy_doc("Prompt A", generation_size=32, doc_id="a"), + self._make_greedy_doc("Prompt B", generation_size=32, doc_id="b"), + self._make_greedy_doc("Prompt C", generation_size=32, doc_id="c"), + ] + + prepared = [] + + def tracking_prepare(doc): + prepared.append(doc.query) + return [{"role": "user", "content": doc.query}] + + client.prompt_manager.prepare_prompt_api.side_effect = tracking_prepare + + def fake_parallel(contexts, *args, **kwargs): + return [self._make_chat_response() for _ in contexts] + + with patch.object( + client, "_LiteLLMClient__call_api_parallel", side_effect=fake_parallel + ), patch.object(type(client), "disable_tqdm", new_callable=PropertyMock, return_value=True): + results = client.greedy_until(docs) + + assert len(prepared) == 3 + assert set(prepared) == {"Prompt A", "Prompt B", "Prompt C"} + assert len(results) == 3 diff --git a/tests/unit/models/test_model_input.py b/tests/unit/models/test_model_input.py index 7c06df445..4dfdd414b 100644 --- a/tests/unit/models/test_model_input.py +++ b/tests/unit/models/test_model_input.py @@ -47,3 +47,106 @@ def test_extract_num_samples(self, model_args: str, expected): gen = GenerationParameters.from_model_args(model_args) for k, v in expected.items(): assert getattr(gen, k) == v + + +class TestToLitellmTextCompletionDict: + """Tests for GenerationParameters.to_litellm_text_completion_dict().""" + + def test_all_none_returns_empty_dict(self): + gen = GenerationParameters() + result = gen.to_litellm_text_completion_dict() + assert result == {} + + def test_seed_included_when_set(self): + gen = GenerationParameters(seed=42) + assert gen.to_litellm_text_completion_dict()["seed"] == 42 + + def test_stop_tokens_included_when_set(self): + 
gen = GenerationParameters(stop_tokens=["\n", "END"]) + assert gen.to_litellm_text_completion_dict()["stop"] == ["\n", "END"] + + def test_top_p_included_when_set(self): + gen = GenerationParameters(top_p=0.9) + assert gen.to_litellm_text_completion_dict()["top_p"] == pytest.approx(0.9) + + def test_frequency_penalty_included(self): + gen = GenerationParameters(frequency_penalty=0.5) + assert gen.to_litellm_text_completion_dict()["frequency_penalty"] == pytest.approx(0.5) + + def test_presence_penalty_included(self): + gen = GenerationParameters(presence_penalty=0.3) + assert gen.to_litellm_text_completion_dict()["presence_penalty"] == pytest.approx(0.3) + + def test_max_new_tokens_not_included(self): + """max_new_tokens belongs to the caller (hardcoded to 1 for loglikelihood).""" + gen = GenerationParameters(max_new_tokens=256) + assert "max_new_tokens" not in gen.to_litellm_text_completion_dict() + assert "max_tokens" not in gen.to_litellm_text_completion_dict() + assert "max_completion_tokens" not in gen.to_litellm_text_completion_dict() + + def test_temperature_not_included(self): + """temperature is hardcoded to 0.0 by the caller for deterministic scoring.""" + gen = GenerationParameters(temperature=0.7) + assert "temperature" not in gen.to_litellm_text_completion_dict() + + def test_chat_only_params_not_included(self): + """repetition_penalty is chat-specific and absent from text_completion.""" + gen = GenerationParameters(repetition_penalty=1.2) + assert "repetition_penalty" not in gen.to_litellm_text_completion_dict() + + def test_full_config_only_returns_non_none(self): + gen = GenerationParameters(seed=1, top_p=0.95, stop_tokens=["\n"]) + result = gen.to_litellm_text_completion_dict() + assert set(result.keys()) == {"seed", "top_p", "stop"} + assert result["seed"] == 1 + assert result["top_p"] == pytest.approx(0.95) + assert result["stop"] == ["\n"] + + +class TestToLitellmDict: + """Tests for GenerationParameters.to_litellm_dict() — regression coverage 
for PR #1193.""" + + def test_presence_penalty_included(self): + """presence_penalty must not be silently dropped (bug fix PR #1193).""" + gen = GenerationParameters(presence_penalty=0.4) + result = gen.to_litellm_dict() + assert "presence_penalty" in result + assert result["presence_penalty"] == pytest.approx(0.4) + + def test_temperature_zero_included(self): + """temperature=0 (the default) is included since 0 is not None.""" + gen = GenerationParameters() + result = gen.to_litellm_dict() + assert "temperature" in result + assert result["temperature"] == 0 + + def test_all_params_forwarded(self): + """All chat-completion params are forwarded with correct key names.""" + gen = GenerationParameters( + max_new_tokens=200, + stop_tokens=["\n", "END"], + temperature=0.8, + top_p=0.95, + seed=7, + repetition_penalty=1.05, + frequency_penalty=0.1, + presence_penalty=0.2, + ) + result = gen.to_litellm_dict() + assert result["max_completion_tokens"] == 200 + assert result["stop"] == ["\n", "END"] + assert result["temperature"] == pytest.approx(0.8) + assert result["top_p"] == pytest.approx(0.95) + assert result["seed"] == 7 + assert result["repetition_penalty"] == pytest.approx(1.05) + assert result["frequency_penalty"] == pytest.approx(0.1) + assert result["presence_penalty"] == pytest.approx(0.2) + + def test_none_values_excluded(self): + """Parameters left as None are not forwarded.""" + gen = GenerationParameters(temperature=0.5) + result = gen.to_litellm_dict() + assert "max_completion_tokens" not in result + assert "stop" not in result + assert "seed" not in result + assert "top_p" not in result diff --git a/tests/unit/utils/test_caching.py b/tests/unit/utils/test_caching.py index 7ab8644be..d1748448e 100644 --- a/tests/unit/utils/test_caching.py +++ b/tests/unit/utils/test_caching.py @@ -23,7 +23,7 @@ import tempfile import unittest from dataclasses import asdict -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest 
import torch @@ -83,6 +83,11 @@ def test_cache_directory_structure(self): DummyModelConfig, ] + if is_package_available("litellm"): + from lighteval.models.endpoints.litellm_model import LiteLLMModelConfig + + model_configs.append(LiteLLMModelConfig) + for model_config in model_configs: with self.subTest(model_config=model_config): with tempfile.TemporaryDirectory() as temp_dir: @@ -127,6 +132,11 @@ def test_cache_decorator_presence(self): SGLangModel, DummyModel, ] + + if is_package_available("litellm"): + from lighteval.models.endpoints.litellm_model import LiteLLMClient + + model_classes.append(LiteLLMClient) methods_to_check = ["greedy_until", "loglikelihood", "loglikelihood_rolling"] for model_class in model_classes: @@ -343,3 +353,28 @@ def test_cache_vlm_transformers(self, mock_create_model, mock_accelerator, mock_ ("greedy_until", SamplingMethod.GENERATIVE), ], ) + + def test_cache_litellm(self): + """Test that @cached works correctly for LiteLLMClient loglikelihood methods.""" + if not is_package_available("litellm"): + self.skipTest("litellm not installed") + + from lighteval.models.endpoints.litellm_model import LiteLLMClient, LiteLLMModelConfig + + with tempfile.TemporaryDirectory() as temp_dir: + config = LiteLLMModelConfig(model_name="gpt-3.5-turbo-instruct", cache_dir=temp_dir) + model = LiteLLMClient(config) + + # _loglikelihood_async / _loglikelihood_rolling_async are the internal + # async methods called by the public @cached-decorated methods. + # AsyncMock makes them return coroutines that resolve to our test responses. + with patch.object(model, "_loglikelihood_async", AsyncMock(return_value=self.model_responses)), \ + patch.object(model, "_loglikelihood_rolling_async", AsyncMock(return_value=self.model_responses)), \ + patch.object(model, "_check_text_completion_support"): # suppress provider warning + self._test_cache( + model, + [ + ("loglikelihood", SamplingMethod.LOGPROBS), + ("loglikelihood_rolling", SamplingMethod.PERPLEXITY), + ], + )