Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions olive/passes/onnx/mobius_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,19 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"configs alongside the ONNX models. 'none' to skip."
),
),
"components_to_export": PassConfigParam(
type_=list[str],
required=False,
default_value=None,
description=(
"Optional list of component names to export from a multi-component model "
"(e.g. ['vision', 'embedding'] to skip the decoder). "
"When set, only the named components are saved and returned; "
"all others are discarded after the mobius build step. "
"When not set (None), all components are exported (default, backward compatible). "
"Raises ValueError if any specified name is not found in the model's components."
Comment thread
titaiwangms marked this conversation as resolved.
Outdated
),
),
}

def _run_for_config(
Expand Down Expand Up @@ -163,21 +176,51 @@ def _run_for_config(
trust_remote_code=trust_remote_code,
)

# Determine which package components to export.
all_keys = list(pkg.keys())
if config.components_to_export is not None:
if len(config.components_to_export) == 0:
raise ValueError(
"MobiusBuilder: components_to_export cannot be empty. "
"Pass None to export all components, or specify at least one component name."
)
requested = set(config.components_to_export)
unknown = requested - set(all_keys)
if unknown:
raise ValueError(
f"MobiusBuilder: components_to_export contains unknown component(s): {sorted(unknown)}. "
f"Available components from this model: {sorted(all_keys)}"
)
package_keys = [k for k in all_keys if k in requested]
logger.info(
"MobiusBuilder: exporting subset of components %s (skipping %s)",
package_keys,
[k for k in all_keys if k not in requested],
)

def components_filter(name: str) -> bool:
return name in requested
else:
package_keys = all_keys
components_filter = None

# ModelPackage.save() handles both single and multi-component layouts:
# single component → <output_dir>/model.onnx
# multi-component → <output_dir>/<name>/model.onnx for each key
pkg.save(str(output_dir))
pkg.save(str(output_dir), components=components_filter)

Comment thread
titaiwangms marked this conversation as resolved.
# Generate ORT GenAI config artifacts (genai_config.json, tokenizer
# files, processor configs) when runtime is set to ort-genai.
genai_artifacts = {}
if config.runtime == self.MobiusRuntime.ORT_GENAI:
genai_artifacts = self._write_genai_config(pkg, str(output_dir), model_id, ep_str)

package_keys = list(pkg.keys())
logger.info("MobiusBuilder: saved components %s to '%s'", package_keys, output_dir)

if len(package_keys) == 1:
# Use the single-component (root layout) path only when the model is
# architecturally single-component. A multi-component model filtered
# down to one component still uses component sub-directories on disk.
if len(all_keys) == 1:
# Single-component model (most LLMs): return a plain ONNXModelHandler.
onnx_path = output_dir / "model.onnx"
if not onnx_path.exists():
Expand Down
150 changes: 146 additions & 4 deletions test/passes/onnx/test_mobius_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,23 @@ def _make_pass(ep: str = ExecutionProvider.CPUExecutionProvider) -> MobiusBuilde


def _fake_pkg(keys: list[str], _output_dir: Path) -> MagicMock:
"""Create a fake ModelPackage that writes dummy .onnx files when .save() is called."""
"""Create a fake ModelPackage that writes dummy .onnx files when .save() is called.

def _save(directory: str, **_kwargs):
Respects the optional ``components`` filter kwarg passed to ``save()``: only writes
files for components for which ``components(name)`` returns True (or all if None).
"""

def _save(directory: str, components=None, **_kwargs):
out = Path(directory)
if len(keys) == 1:
# Single-component: saved as <dir>/model.onnx
(out / "model.onnx").write_text("dummy")
else:
Comment thread
titaiwangms marked this conversation as resolved.
# Multi-component: saved as <dir>/<key>/model.onnx
for k in keys:
(out / k).mkdir(parents=True, exist_ok=True)
(out / k / "model.onnx").write_text("dummy")
if components is None or components(k):
(out / k).mkdir(parents=True, exist_ok=True)
(out / k / "model.onnx").write_text("dummy")

pkg = MagicMock()
pkg.keys.return_value = keys
Expand Down Expand Up @@ -454,3 +459,140 @@ def test_no_warning_when_trust_remote_code_false(tmp_path):

warning_messages = [call.args[0] for call in mock_logger.warning.call_args_list]
assert not any("trust_remote_code" in msg for msg in warning_messages)


# ---------------------------------------------------------------------------
# components_to_export filter tests
# ---------------------------------------------------------------------------


def test_components_to_export_filters_subset(tmp_path):
"""Only requested components are saved and returned when components_to_export is set."""
out = tmp_path / "out"
keys = ["decoder", "vision_encoder", "embedding"]
pkg = _fake_pkg(keys, out)

accelerator_spec = AcceleratorSpec(
accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
)
p = create_pass_from_dict(
MobiusBuilder,
{"precision": "fp16", "components_to_export": ["vision_encoder", "embedding"]},
disable_search=True,
accelerator_spec=accelerator_spec,
)

with _patch_build(pkg):
result = p.run(_make_hf_model("org/vlm"), out)

assert isinstance(result, CompositeModelHandler)
assert result.model_component_names == ["vision_encoder", "embedding"]

# pkg.save must have been called with a components filter that excludes decoder
save_kwargs = pkg.save.call_args.kwargs
components_filter = save_kwargs.get("components")
assert components_filter is not None
assert components_filter("vision_encoder") is True
assert components_filter("embedding") is True
assert components_filter("decoder") is False

# Verify skipped component directory is absent from disk
assert (out / "vision_encoder" / "model.onnx").exists(), "vision_encoder should be on disk"
assert (out / "embedding" / "model.onnx").exists(), "embedding should be on disk"
assert not (out / "decoder").exists(), "decoder directory should not exist on disk (was skipped)"


def test_components_to_export_none_exports_all(tmp_path):
"""All components are exported when components_to_export is None (default)."""
out = tmp_path / "out"
keys = ["decoder", "vision_encoder", "embedding"]
pkg = _fake_pkg(keys, out)

with _patch_build(pkg):
result = _make_pass().run(_make_hf_model("org/vlm"), out)

assert isinstance(result, CompositeModelHandler)
assert result.model_component_names == keys
# pkg.save must have been called without a filter (components=None)
save_kwargs = pkg.save.call_args.kwargs
assert save_kwargs.get("components") is None


def test_components_to_export_single_component_via_filter(tmp_path):
"""Filtering a multi-component model to one component returns CompositeModelHandler with one component.

Unlike an architecturally single-component model (which uses root layout), a
filtered multi-component model still uses the component sub-directory layout,
so we always return CompositeModelHandler for multi-component packages.
"""
out = tmp_path / "out"
keys = ["decoder", "vision_encoder", "embedding"]
pkg = _fake_pkg(keys, out)

accelerator_spec = AcceleratorSpec(
accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
)
p = create_pass_from_dict(
MobiusBuilder,
{"precision": "fp16", "components_to_export": ["decoder"]},
disable_search=True,
accelerator_spec=accelerator_spec,
)

with _patch_build(pkg):
result = p.run(_make_hf_model("org/vlm"), out)

# Multi-component model filtered to 1 → still CompositeModelHandler (component sub-dir layout)
assert isinstance(result, CompositeModelHandler)
assert result.model_component_names == ["decoder"]


def test_components_to_export_unknown_component_raises(tmp_path):
"""ValueError when components_to_export names a component not in the package."""
out = tmp_path / "out"
keys = ["decoder", "vision_encoder"]
pkg = _fake_pkg(keys, out)

accelerator_spec = AcceleratorSpec(
accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
)
p = create_pass_from_dict(
MobiusBuilder,
{"precision": "fp16", "components_to_export": ["nonexistent"]},
disable_search=True,
accelerator_spec=accelerator_spec,
)

with _patch_build(pkg), pytest.raises(ValueError, match="unknown component"):
p.run(_make_hf_model("org/vlm"), out)


def test_components_to_export_empty_list_raises(tmp_path):
"""components_to_export=[] must raise ValueError — empty list is always a mistake."""
out = tmp_path / "out"
keys = ["decoder", "vision_encoder"]
pkg = _fake_pkg(keys, out)

accelerator_spec = AcceleratorSpec(
accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
)
p = create_pass_from_dict(
MobiusBuilder,
{"precision": "fp16", "components_to_export": []},
disable_search=True,
accelerator_spec=accelerator_spec,
)

with _patch_build(pkg), pytest.raises(ValueError, match="cannot be empty"):
p.run(_make_hf_model("org/vlm"), out)


def test_components_to_export_in_default_config():
"""components_to_export parameter must appear in _default_config with None default."""
accelerator_spec = AcceleratorSpec(
accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
)
config = MobiusBuilder._default_config(accelerator_spec) # pylint: disable=protected-access
assert "components_to_export" in config
assert config["components_to_export"].default_value is None
assert config["components_to_export"].required is False
Loading