Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 163 additions & 3 deletions olive/evaluator/lmeval_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,151 @@
self._batch_size = batch_size


@register_model("ort-multimodal")
class LMEvalORTMultimodalEvaluator(LMEvalOnnxBase):
    """Evaluate a multimodal ONNX model using direct ORT InferenceSession.

    Designed for ORT GenAI multimodal packages (e.g. Gemma4) that have separate
    decoder and embedding ONNX models. Uses direct session.run() instead of
    GenAI's Generator API, avoiding the overhead of loading all sub-models
    and creating Generator objects per call.

    Supports models with heterogeneous KV cache head dimensions (e.g. Gemma4
    with head_dim=256 for sliding attention and head_dim=512 for full attention),
    which the standard 'ort' backend cannot handle.
    """

    # ONNX Runtime tensor type strings mapped to NumPy dtype names. Used for the
    # floating-point tensors this class synthesizes itself (empty KV cache,
    # empty image/audio features, fallback embeddings).
    _ORT_FLOAT_DTYPES = {
        "tensor(float16)": "float16",
        "tensor(float)": "float32",
        "tensor(double)": "float64",
    }
    # Dtype used when a session input is absent or has a type not listed above.
    _DEFAULT_FLOAT_DTYPE = "float16"

    def __init__(
        self,
        pretrained: str,
        batch_size: int | str = 1,
        max_length: int | None = None,
        ep: str | None = None,
        ep_options: dict | None = None,
        **kwargs,
    ):
        """Initialize the evaluator.

        :param pretrained: Path to the ORT GenAI model directory containing
            genai_config.json, decoder/, embedding/, and tokenizer files.
        :param batch_size: Batch size for evaluation. Must be convertible with int().
        :param max_length: Maximum sequence length. Defaults to the config value,
            capped at 2048.
        :param ep: Execution provider (e.g. 'CUDAExecutionProvider').
        :param ep_options: Provider options dict, applied to ``ep``.
        :param kwargs: Accepted for call-site compatibility; not used here.
        """
        import onnxruntime as ort

        super().__init__()

        model_dir = Path(pretrained)

        # Load genai_config to find model paths and metadata
        with (model_dir / "genai_config.json").open() as f:
            genai_config = json.load(f)

        model_config = genai_config["model"]
        decoder_config = model_config["decoder"]

        # Resolve max_length
        if max_length:
            self._max_length = max_length
        else:
            self._max_length = min(
                genai_config.get("search", {}).get("max_length", 2048),
                2048,  # Cap at 2048 for eval efficiency
            )

        # EOS token handling (can be list or scalar)
        eot = model_config["eos_token_id"]
        self._eot_token_id = eot[0] if isinstance(eot, list) else eot

        # Set up execution providers. The requested provider carries its options
        # as a (name, options) tuple so that ep_options actually reaches ORT.
        providers = []
        if ep:
            providers.append((ep, dict(ep_options)) if ep_options else ep)
        # CPU is the fallback; do not list it twice when it was requested explicitly.
        if ep != "CPUExecutionProvider":
            providers.append("CPUExecutionProvider")

        # Load decoder session
        decoder_path = str(model_dir / decoder_config["filename"])
        logger.info("Loading decoder from %s", decoder_path)
        self._decoder_sess = ort.InferenceSession(decoder_path, providers=providers)

        # Remember which inputs the decoder declares (name -> ORT type string) so
        # model_call only feeds names the session knows about.
        decoder_inputs = self._decoder_sess.get_inputs()
        self._decoder_inputs = {inp.name: inp.type for inp in decoder_inputs}

        # Detect per-layer KV cache shapes (supports heterogeneous head_dim) and
        # the floating-point dtype each cache input expects.
        self._kv_shapes = {
            inp.name: {
                "num_kv_heads": inp.shape[1],
                "head_dim": inp.shape[3],
                "dtype": self._float_dtype(inp.type),
            }
            for inp in decoder_inputs
            if inp.name.startswith("past_key_values")
        }

        # Load embedding session if available
        self._embedding_sess = None
        self._embedding_inputs = {}
        self._hidden_size = decoder_config["hidden_size"]
        embedding_config = model_config.get("embedding")
        if embedding_config:
            emb_path = str(model_dir / embedding_config["filename"])
            logger.info("Loading embedding from %s", emb_path)
            self._embedding_sess = ort.InferenceSession(emb_path, providers=providers)
            self._embedding_inputs = {inp.name: inp.type for inp in self._embedding_sess.get_inputs()}

        # Load tokenizer from model directory
        self._tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
        self.batch_size = int(batch_size)

    @classmethod
    def _float_dtype(cls, ort_type):
        """Return the NumPy dtype name for an ORT tensor type string, or the default."""
        return cls._ORT_FLOAT_DTYPES.get(ort_type, cls._DEFAULT_FLOAT_DTYPE)

    @property
    def max_length(self) -> int:
        """Maximum sequence length resolved in ``__init__``."""
        return self._max_length

    @property
    def eot_token_id(self) -> int:
        """End-of-text token id read from genai_config.json."""
        return self._eot_token_id

    def tok_encode(self, string: str, **kwargs) -> list[int]:
        """Tokenize ``string`` without adding special tokens."""
        return self._tokenizer.encode(string, add_special_tokens=False)

    def prepare(self, requests: list[LogLikelihoodInputs]):
        """No preparation is needed for this backend."""

    def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
        """Run one prompt-only forward pass and return the decoder's ``logits`` output.

        :param input_ids: 2-D tensor of token ids (unpacked as batch, sequence).
        :param cont_len: Unused by this backend.
        :return: The ``logits`` output of the decoder session as a torch tensor.
        """
        import numpy as np

        batch_size, seq_len = input_ids.shape
        ids_np = input_ids.cpu().numpy().astype(np.int64)

        # Get embeddings if embedding model is available
        if self._embedding_sess is not None:
            emb_candidates = {"input_ids": ids_np}
            # Text-only evaluation: image/audio features are empty (zero rows).
            for feature_name in ("image_features", "audio_features"):
                emb_candidates[feature_name] = np.zeros(
                    (0, self._hidden_size),
                    dtype=self._float_dtype(self._embedding_inputs.get(feature_name)),
                )
            # ORT rejects feed names a session does not declare, so drop them.
            emb_feed = {name: value for name, value in emb_candidates.items() if name in self._embedding_inputs}
            inputs_embeds = self._embedding_sess.run(None, emb_feed)[0]
        else:
            # NOTE: without an embedding model there is nothing to embed with;
            # zeros are only a placeholder for decoders that declare this input.
            inputs_embeds = np.zeros(
                (batch_size, seq_len, self._hidden_size),
                dtype=self._float_dtype(self._decoder_inputs.get("inputs_embeds")),
            )

        # Build decoder feed, keeping only the inputs the decoder declares
        dec_candidates = {
            "input_ids": ids_np,
            "inputs_embeds": inputs_embeds,
            "attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
            "position_ids": np.broadcast_to(
                np.arange(seq_len, dtype=np.int64).reshape(1, -1),
                (batch_size, seq_len),
            ).copy(),
        }
        dec_feed = {name: value for name, value in dec_candidates.items() if name in self._decoder_inputs}

        # Empty (zero-length) past KV cache with per-layer shapes and dtypes
        for name, info in self._kv_shapes.items():
            dec_feed[name] = np.zeros(
                (batch_size, info["num_kv_heads"], 0, info["head_dim"]),
                dtype=info["dtype"],
            )

        result = self._decoder_sess.run(["logits"], dec_feed)
        return torch.from_numpy(result[0])

    def complete(self):
        """Nothing to release for this backend."""


@register_model("ortgenai")
class LMEvalORTGenAIEvaluator(LMEvalOnnxBase):
"""Evaluate a model using ONNX Runtime GenAI."""
Expand Down Expand Up @@ -520,6 +665,7 @@

self.device = device
self._returns_full_logits = self._detect_full_logits()
self._cached_generator = None

def _detect_full_logits(self) -> bool:
"""Check if the model returns logits for all input positions or only the last."""
Expand All @@ -546,10 +692,24 @@
def prepare(self, requests: list[LogLikelihoodInputs]):
    """Accept the upcoming requests without doing any work (no-op hook)."""
    return None

def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
def _get_generator(self, batch_size: int) -> og.Generator:
    """Get a Generator, reusing the cached one via rewind_to(0) when possible.

    The cached generator is only reused when it was created for the same
    ``batch_size``; otherwise (or when rewinding fails) a new Generator is
    built from ``self.model`` and ``self.params`` and cached in its place.

    :param batch_size: Batch size the returned generator must be configured for.
    :return: A generator positioned at the start of the sequence.
    """
    cached = self._cached_generator
    # getattr: the batch-size marker is only written by this method, so it may
    # not exist yet on an instance that has never cached a generator.
    if cached is not None and getattr(self, "_cached_generator_batch_size", None) == batch_size:
        try:
            cached.rewind_to(0)
        except Exception:
            # rewind_to not supported for this model — fall back to new Generator
            logger.debug("Could not rewind cached generator; creating a new one.", exc_info=True)
        else:
            return cached

    # Drop any stale generator before building the replacement.
    self._cached_generator = None

    self.params.set_search_options(batch_size=batch_size)
    generator = og.Generator(self.model, self.params)
    self._cached_generator = generator
    self._cached_generator_batch_size = batch_size
    return generator

def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
generator = self._get_generator(batch_size)

if self._returns_full_logits:
generator.append_tokens(input_ids.tolist())
Expand Down Expand Up @@ -579,7 +739,7 @@
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def complete(self):
    """Drop the reference to the cached generator."""
    self._cached_generator = None

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
"""Generate text until a stop sequence is reached.
Expand Down
6 changes: 6 additions & 0 deletions olive/evaluator/olive_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1572,6 +1572,12 @@ def evaluate(
"ep_options": self.ep_options,
"device": device,
}
elif self.model_class == "ort-multimodal":
init_args = {
"pretrained": str(Path(model.model_path).parent),
"ep": self.ep or execution_providers,
"ep_options": self.ep_options,
}
Comment on lines +1575 to +1580
Comment on lines +1575 to +1580
else:
raise ValueError(f"Unknown model class: {self.model_class}")

Expand Down
Loading