Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
ef42e47
Initial plan
Copilot May 11, 2026
00571f0
feat: add CLI test-model config for HF inputs
Copilot May 11, 2026
485dfbf
test: broaden HF test-model coverage
Copilot May 11, 2026
a6fa34a
chore: polish test model config handling
Copilot May 11, 2026
273850c
fix: fail fast for HF test model loading
Copilot May 11, 2026
318fcbe
refactor: remove nested try from HF test loading
Copilot May 11, 2026
40b0740
test: cover trust_remote_code helper behavior
Copilot May 11, 2026
386ff01
feat: persist reusable HF test model path
Copilot May 11, 2026
09fac8c
fix: tighten HF test model path handling
Copilot May 11, 2026
09df0a7
refactor: simplify test model path handling
Copilot May 11, 2026
d4ebad5
lintrunner
xadupre May 11, 2026
6321d32
lint
xadupre May 11, 2026
6709852
docs: add phi test conversion how-to
Copilot May 11, 2026
3f8f8fc
feat: support run command test models
Copilot May 11, 2026
189289a
chore: address review nits for run test support
Copilot May 11, 2026
c90c520
chore: simplify run test override handling
Copilot May 11, 2026
5cec47f
chore: polish run test support follow-up
Copilot May 11, 2026
eaf0a16
fix: use saved test checkpoint in model builder
Copilot May 11, 2026
17cb075
chore: tidy model builder test fixture
Copilot May 11, 2026
996f633
chore: clarify model builder test model errors
Copilot May 11, 2026
00732b7
fix dtype: auto
xadupre May 11, 2026
76ee4ef
update documentation
xadupre May 11, 2026
9ba38c7
docs: clarify phi smoke test output path
Copilot May 11, 2026
fa6bee7
docs: switch smoke test how-to to qwen
Copilot May 11, 2026
ffda0dc
test: cover documented llm smoke flow
Copilot May 11, 2026
d0f868f
test: polish documented smoke flow test
Copilot May 11, 2026
a408b63
test: rename smoke flow cli test
Copilot May 11, 2026
e272c2a
test: refine smoke flow workflow stubs
Copilot May 11, 2026
c901b63
test: tidy smoke flow helper names
Copilot May 11, 2026
36410cd
test: clarify smoke flow mocks
Copilot May 11, 2026
e16cb82
test: polish documented smoke flow test naming
Copilot May 11, 2026
7507604
test: lift smoke flow imports and mock defaults
Copilot May 11, 2026
ac7840f
fix: keep qwen test layer types in sync
Copilot May 11, 2026
f165dda
Merge origin/main and resolve model builder test conflict
Copilot May 12, 2026
5daba5d
Merge origin/main into copilot/fr-add-model-to-config-json
Copilot May 12, 2026
8941efb
Potential fix for pull request finding
xadupre May 12, 2026
be35ef4
Merge branch 'main' into copilot/fr-add-model-to-config-json
xadupre May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions olive/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def _get_hf_input_model(args: Namespace, model_path: OLIVE_RESOURCE_ANNOTATIONS)
input_model["adapter_path"] = args.adapter_path
if getattr(args, "trust_remote_code", None) is not None:
input_model["load_kwargs"]["trust_remote_code"] = args.trust_remote_code
if getattr(args, "test", False):
input_model["test_model_config"] = {"hidden_layers": 2}
return input_model


Expand Down Expand Up @@ -371,6 +373,11 @@ def add_input_model_options(
model_group.add_argument(
"--trust_remote_code", action="store_true", help="Trust remote code when loading a huggingface model."
)
model_group.add_argument(
Comment thread
xadupre marked this conversation as resolved.
"--test",
action="store_true",
help="Use a randomly initialized test model with the same Hugging Face architecture and 2 hidden layers.",
)

if enable_hf_adapter:
assert enable_hf, "enable_hf must be True when enable_hf_adapter is True."
Expand Down
8 changes: 6 additions & 2 deletions olive/common/hf/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_model_io_config(
model_name: str,
task: str,
model: Optional["PreTrainedModel"] = None,
test_model_config: Optional[dict[str, Any]] = None,
**kwargs,
) -> Optional[dict[str, Any]]:
"""Get the input/output config for the model and task.
Expand All @@ -35,6 +36,7 @@ def get_model_io_config(
model_name: The model name or path.
task: The task type (e.g., "text-generation", "text-classification").
model: Optional loaded model for input signature inspection.
test_model_config: Optional overrides for creating a lightweight random test model from the same config.
**kwargs: Additional arguments including use_cache.

Returns:
Expand Down Expand Up @@ -68,7 +70,7 @@ def get_model_io_config(
return None

# Get model config
model_config = get_model_config(model_name, **kwargs)
model_config = get_model_config(model_name, test_model_config=test_model_config, **kwargs)

# Handle PEFT models
actual_model = model
Expand All @@ -92,6 +94,7 @@ def get_model_dummy_input(
model_name: str,
task: str,
model: Optional["PreTrainedModel"] = None,
test_model_config: Optional[dict[str, Any]] = None,
**kwargs,
) -> Optional[dict[str, Any]]:
"""Get dummy inputs for the model and task.
Expand All @@ -100,6 +103,7 @@ def get_model_dummy_input(
model_name: The model name or path.
task: The task type.
model: Optional loaded model for input signature inspection.
test_model_config: Optional overrides for creating a lightweight random test model from the same config.
**kwargs: Additional arguments including use_cache, batch_size, sequence_length.

Returns:
Expand Down Expand Up @@ -133,7 +137,7 @@ def get_model_dummy_input(
return None

# Get model config (handles MLflow paths)
model_config = get_model_config(model_name, **kwargs)
model_config = get_model_config(model_name, test_model_config=test_model_config, **kwargs)

# Handle PEFT models
actual_model = model
Expand Down
56 changes: 49 additions & 7 deletions olive/common/hf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

from transformers import AutoConfig, AutoModel, AutoTokenizer, GenerationConfig

Expand All @@ -18,7 +19,40 @@
logger = logging.getLogger(__name__)


def load_model_from_task(task: str, model_name_or_path: str, **kwargs) -> "PreTrainedModel":
def _apply_test_model_config(
model_config: "PretrainedConfig", test_model_config: Optional[dict[str, Any]] = None
) -> "PretrainedConfig":
"""Apply lightweight test-model overrides to a model config."""
if not test_model_config:
return model_config

model_config = deepcopy(model_config)
if "hidden_layers" in test_model_config:
hidden_layers = test_model_config["hidden_layers"]
elif "num_hidden_layers" in test_model_config:
hidden_layers = test_model_config["num_hidden_layers"]
else:
hidden_layers = 2
if hidden_layers < 1:
raise ValueError("test_model_config.hidden_layers must be greater than 0.")

updated = False
# Common Hugging Face configs do not use a single canonical field:
# BERT-style models use num_hidden_layers while GPT-style models often use n_layer/n_layers/num_layers.
for attr_name in ("num_hidden_layers", "num_layers", "n_layer", "n_layers"):
if hasattr(model_config, attr_name):
setattr(model_config, attr_name, hidden_layers)
updated = True

if not updated:
raise ValueError("Unable to create a test model because the config does not expose a hidden-layer count.")

return model_config


def load_model_from_task(
task: str, model_name_or_path: str, test_model_config: Optional[dict[str, Any]] = None, **kwargs
) -> "PreTrainedModel":
"""Load huggingface model from task and model_name_or_path."""
from transformers.pipelines import check_task

Expand All @@ -31,7 +65,7 @@ def load_model_from_task(task: str, model_name_or_path: str, **kwargs) -> "PreTr
else:
raise ValueError("unsupported transformers version")

model_config = get_model_config(model_name_or_path, **kwargs)
model_config = get_model_config(model_name_or_path, test_model_config=test_model_config, **kwargs)
if getattr(model_config, "quantization_config", None):
if not isinstance(model_config.quantization_config, dict):
model_config.quantization_config = model_config.quantization_config.to_dict()
Expand Down Expand Up @@ -59,7 +93,13 @@ def load_model_from_task(task: str, model_name_or_path: str, **kwargs) -> "PreTr
model = None
for i, model_class in enumerate(class_tuple):
try:
model = from_pretrained(model_class, model_name_or_path, "model", **kwargs)
if test_model_config:
try:
Comment thread
xadupre marked this conversation as resolved.
Outdated
model = model_class.from_config(model_config, trust_remote_code=kwargs.get("trust_remote_code"))
except TypeError:
model = model_class.from_config(model_config)
else:
model = from_pretrained(model_class, model_name_or_path, "model", **kwargs)
logger.debug("Loaded model %s with name_or_path %s", model_class, model_name_or_path)
break
except (OSError, ValueError) as e:
Expand Down Expand Up @@ -94,14 +134,16 @@ def from_pretrained(cls, model_name_or_path: str, mlflow_dir: str, **kwargs):
return cls.from_pretrained(get_pretrained_name_or_path(model_name_or_path, mlflow_dir), **kwargs)


def get_model_config(model_name_or_path: str, **kwargs) -> "PretrainedConfig":
def get_model_config(
model_name_or_path: str, test_model_config: Optional[dict[str, Any]] = None, **kwargs
) -> "PretrainedConfig":
"""Get HF Config for the given model_name_or_path."""
model_config = from_pretrained(AutoConfig, model_name_or_path, "config", **kwargs)

# add quantization config
quantization_config = kwargs.get("quantization_config")
if not quantization_config:
return model_config
return _apply_test_model_config(model_config, test_model_config)

if hasattr(model_config, "quantization_config") and model_config.quantization_config:
logger.warning(
Expand All @@ -111,7 +153,7 @@ def get_model_config(model_name_or_path: str, **kwargs) -> "PretrainedConfig":
)
else:
model_config.quantization_config = quantization_config
return model_config
return _apply_test_model_config(model_config, test_model_config)


def save_model_config(config: Union["PretrainedConfig", "GenerationConfig"], output_dir: str, **kwargs):
Expand Down
11 changes: 9 additions & 2 deletions olive/model/handler/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
@model_handler_registry("HFModel")
class HfModelHandler(PyTorchModelHandlerBase, MLFlowTransformersMixin, HfMixin): # pylint: disable=too-many-ancestors
resource_keys: tuple[str, ...] = ("model_path", "adapter_path")
json_config_keys: tuple[str, ...] = ("task", "load_kwargs")
json_config_keys: tuple[str, ...] = ("task", "load_kwargs", "test_model_config")

def __init__(
self,
Expand All @@ -37,6 +37,7 @@ def __init__(
load_kwargs: Union[dict[str, Any], HfLoadKwargs] = None,
io_config: Union[dict[str, Any], IoConfig, str] = None,
adapter_path: OLIVE_RESOURCE_ANNOTATIONS = None,
test_model_config: Optional[dict[str, Any]] = None,
model_attributes: Optional[dict[str, Any]] = None,
):
super().__init__(
Expand All @@ -48,6 +49,7 @@ def __init__(
self.add_resources(locals())
self.task = task
self.load_kwargs = validate_config(load_kwargs, HfLoadKwargs, warn_unused_keys=False) if load_kwargs else None
self.test_model_config = test_model_config

self.model_attributes = {**self.get_hf_model_config().to_dict(), **(self.model_attributes or {})}

Expand All @@ -72,7 +74,12 @@ def load_model(self, rank: int = None, cache_model: bool = True) -> "torch.nn.Mo
if self.model:
model = self.model
else:
model = load_model_from_task(self.task, self.model_path, **self.get_load_kwargs())
model = load_model_from_task(
self.task,
self.model_path,
test_model_config=self.test_model_config,
**self.get_load_kwargs(),
)

# we only have peft adapters for now
if self.adapter_path:
Expand Down
15 changes: 13 additions & 2 deletions olive/model/handler/mixin/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ def get_hf_model_config(self, exclude_load_keys: Optional[list[str]] = None) ->
:param exclude_load_keys: list of keys to exclude from load_kwargs
:return: model config
"""
return get_model_config(self.model_path, **self.get_load_kwargs(exclude_load_keys))
return get_model_config(
self.model_path,
test_model_config=getattr(self, "test_model_config", None),
**self.get_load_kwargs(exclude_load_keys),
)

def get_hf_generation_config(self, exclude_load_keys: Optional[list[str]] = None) -> Optional["GenerationConfig"]:
"""Get generation config for the model if it exists.
Expand Down Expand Up @@ -114,14 +118,21 @@ def save_metadata(self, output_dir: str, exclude_load_keys: Optional[list[str]]

def get_hf_io_config(self) -> Optional[dict[str, Any]]:
"""Get Io config for the model."""
return get_model_io_config(self.model_path, self.task, self.load_model(), **self.get_load_kwargs())
return get_model_io_config(
self.model_path,
self.task,
self.load_model(),
test_model_config=getattr(self, "test_model_config", None),
**self.get_load_kwargs(),
)

def get_hf_dummy_inputs(self) -> Optional[dict[str, Any]]:
"""Get dummy inputs for the model."""
return get_model_dummy_input(
self.model_path,
self.task,
model=self.load_model(),
test_model_config=getattr(self, "test_model_config", None),
**self.get_load_kwargs(),
)

Expand Down
16 changes: 16 additions & 0 deletions test/cli/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,22 @@ def test_insert_input_model_invalid_hf_model_name():
get_input_model_config(args)


@patch("huggingface_hub.repo_exists", return_value=True)
def test_get_input_model_config_hf_test_model(_):
args = SimpleNamespace(
model_name_or_path="hf_model",
trust_remote_code=False,
task="text-generation",
model_script=None,
script_dir=None,
test=True,
)

config = get_input_model_config(args)

assert config["test_model_config"] == {"hidden_layers": 2}


def test_insert_input_model_cli_output_model():
# setup
model_path = str(Path(__file__).parent.resolve() / "output_model")
Expand Down
19 changes: 19 additions & 0 deletions test/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,25 @@ def test_finetune_command(_, mock_run, tmp_path):
assert mock_run.call_count == 1


@patch("huggingface_hub.repo_exists", return_value=True)
def test_optimize_command_test_model_config(_, tmp_path):
output_dir = tmp_path / "output_dir"
command_args = [
"optimize",
"-m",
"dummy-model-id",
"--test",
"--dry_run",
"-o",
str(output_dir),
]

cli_main(command_args)

config = json.loads((output_dir / "config.json").read_text())
assert config["input_model"]["test_model_config"] == {"hidden_layers": 2}


Comment thread
xadupre marked this conversation as resolved.
@patch("olive.workflows.run")
@patch("olive.model.handler.diffusers.is_valid_diffusers_model", return_value=True)
def test_diffusion_lora_command(_, mock_run, tmp_path):
Expand Down
27 changes: 27 additions & 0 deletions test/common/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import torch
from transformers import BertConfig, GPT2Config

from olive.common.hf.model_io import get_model_dummy_input, get_model_io_config
from olive.common.hf.utils import load_model_from_task
Expand All @@ -21,6 +22,32 @@ def test_load_model_from_task():
assert isinstance(model, torch.nn.Module)


@pytest.mark.parametrize(
("model_config", "hidden_layers_attr"),
[
(BertConfig(num_hidden_layers=12), "num_hidden_layers"),
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
(GPT2Config(n_layer=12), "n_layer"),
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
],
)
def test_load_model_from_task_test_model_config(model_config, hidden_layers_attr):
created_model = MagicMock(spec=torch.nn.Module)

with (
patch("transformers.pipelines.check_task") as mock_check_task,
patch("olive.common.hf.utils.from_pretrained", return_value=model_config) as mock_from_pretrained,
):
mock_model_class = MagicMock()
mock_model_class.from_config.return_value = created_model
mock_check_task.return_value = ("text-classification", {"pt": (mock_model_class,)}, None)

model = load_model_from_task("text-classification", "dummy-model", test_model_config={"hidden_layers": 2})

assert model is created_model
mock_from_pretrained.assert_called_once()
mock_model_class.from_config.assert_called_once()
assert getattr(mock_model_class.from_config.call_args.args[0], hidden_layers_attr) == 2


@pytest.mark.parametrize(
("exceptions", "expected_exception", "expected_message"),
[
Expand Down