Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions olive/evaluator/lmeval_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,10 @@ def __init__(
self.config.set_provider_option(ep, key, value)
self.model = og.Model(self.config)
self.tokenizer = og.Tokenizer(self.model)
# HF tokenizer kept solely to render `apply_chat_template`; generation
# still uses og.Tokenizer above.
self._pretrained = str(pretrained)
self._hf_tokenizer = AutoTokenizer.from_pretrained(self._pretrained)
Comment thread
ykhrustalev marked this conversation as resolved.
Outdated

# consider adding auto batch sizes
self.batch_size = int(batch_size)
Expand All @@ -521,6 +525,24 @@ def __init__(
self.device = device
self._returns_full_logits = self._detect_full_logits()

@property
def tokenizer_name(self) -> str:
    """Return the pretrained identifier with every "/" replaced by "__".

    lm-eval uses this value for its chat-template-aware caching.
    """
    # Splitting on "/" and re-joining with "__" swaps each separator in
    # one pass while leaving every other character untouched.
    segments = self._pretrained.split("/")
    return "__".join(segments)
Comment thread
ykhrustalev marked this conversation as resolved.
Outdated

def apply_chat_template(self, chat_history: list[dict], add_generation_prompt: bool = True) -> str:
    """Format a chat history as a single prompt string via the HF chat template.

    lm-eval needs this method when `apply_chat_template=True` is passed to
    `simple_evaluate`; without it, lm-eval raises NotImplementedError.

    Args:
        chat_history: Chat messages, forwarded unchanged to the HF tokenizer.
        add_generation_prompt: Forwarded to the HF tokenizer as-is. Its
            negation is forwarded as `continue_final_message`.

    Returns:
        Whatever the HF tokenizer's `apply_chat_template` returns when
        called with `tokenize=False`.
    """
    render = self._hf_tokenizer.apply_chat_template
    # The two flags are always passed as opposites: when no generation
    # prompt is requested, the final message is continued instead.
    continue_final = not add_generation_prompt
    return render(
        chat_history,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final,
    )

def _detect_full_logits(self) -> bool:
"""Check if the model returns logits for all input positions or only the last."""
try:
Expand Down
Loading