Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 56 additions & 3 deletions olive/passes/onnx/mobius_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
"configs alongside the ONNX models. 'none' to skip."
),
),
"components_to_export": PassConfigParam(
type_=list[str],
required=False,
default_value=None,
description=(
"Optional list of component names to export from a multi-component model "
"(e.g. ['vision', 'embedding'] to skip the decoder). "
"When set, only the named components are saved and returned; "
"all others are discarded after the mobius build step. "
"When not set (None), all components are exported (default, backward compatible). "
"Raises ValueError if the list is empty or if any specified name is not found in "
"the model's components."
),
),
}

def _run_for_config(
Expand Down Expand Up @@ -163,21 +177,60 @@ def _run_for_config(
trust_remote_code=trust_remote_code,
)

# Determine which package components to export.
all_keys = list(pkg.keys())
if config.components_to_export is not None:
if len(config.components_to_export) == 0:
raise ValueError(
"MobiusBuilder: components_to_export cannot be empty. "
"Pass None to export all components, or specify at least one component name."
)
requested = set(config.components_to_export)
unknown = requested - set(all_keys)
if unknown:
raise ValueError(
f"MobiusBuilder: components_to_export contains unknown component(s): {sorted(unknown)}. "
f"Available components from this model: {sorted(all_keys)}"
)
package_keys = [k for k in all_keys if k in requested]
logger.info(
"MobiusBuilder: exporting subset of components %s (skipping %s)",
package_keys,
[k for k in all_keys if k not in requested],
)

def components_filter(name: str) -> bool:
return name in requested
else:
package_keys = all_keys
components_filter = None

# ModelPackage.save() handles both single and multi-component layouts:
# single component → <output_dir>/model.onnx
# multi-component → <output_dir>/<name>/model.onnx for each key
pkg.save(str(output_dir))
# Older mobius releases may not support the `components` kwarg — fall back gracefully.
try:
pkg.save(str(output_dir), components=components_filter)
except TypeError:
if components_filter is not None:
logger.warning(
"MobiusBuilder: installed mobius version does not support the 'components' filter kwarg; "
"all components will be saved. Upgrade mobius to enable selective export."
)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: replaced the try/except TypeError with inspect.signature(pkg.save).parameters check. This detects kwarg support upfront from the function signature, avoiding any execution of save() before the error — so no orphaned dirs and no masked real TypeErrors. The check happens cleanly before the call.

pkg.save(str(output_dir))

Comment thread
titaiwangms marked this conversation as resolved.
# Generate ORT GenAI config artifacts (genai_config.json, tokenizer
# files, processor configs) when runtime is set to ort-genai.
genai_artifacts = {}
if config.runtime == self.MobiusRuntime.ORT_GENAI:
genai_artifacts = self._write_genai_config(pkg, str(output_dir), model_id, ep_str)

package_keys = list(pkg.keys())
logger.info("MobiusBuilder: saved components %s to '%s'", package_keys, output_dir)

if len(package_keys) == 1:
# Use the single-component (root layout) path only when the model is
# architecturally single-component. A multi-component model filtered
# down to one component still uses component sub-directories on disk.
if len(all_keys) == 1:
# Single-component model (most LLMs): return a plain ONNXModelHandler.
onnx_path = output_dir / "model.onnx"
if not onnx_path.exists():
Expand Down
201 changes: 195 additions & 6 deletions test/passes/onnx/test_mobius_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,18 +74,26 @@


def _fake_pkg(keys: list[str], _output_dir: Path) -> MagicMock:
"""Create a fake ModelPackage that writes dummy .onnx files when .save() is called."""
"""Create a fake ModelPackage that writes dummy .onnx files when .save() is called.

def _save(directory: str, **_kwargs):
Respects the optional ``components`` filter kwarg passed to ``save()``: only writes
files for components for which ``components(name)`` returns True (or all if None).
"""

def _save(directory: str, components=None, **_kwargs):
out = Path(directory)
if len(keys) == 1:
# Single-component: saved as <dir>/model.onnx
(out / "model.onnx").write_text("dummy")
# Single-component: saved as <dir>/model.onnx.
# Apply the components filter consistently with multi-component behavior.

Check warning on line 87 in test/passes/onnx/test_mobius_model_builder.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "behaviour" is a misspelling of "behavior" Raw Output: ./test/passes/onnx/test_mobius_model_builder.py:87:76: "behaviour" is a misspelling of "behavior"
key = keys[0]
if components is None or components(key):
(out / "model.onnx").write_text("dummy")
else:
Comment thread
titaiwangms marked this conversation as resolved.
# Multi-component: saved as <dir>/<key>/model.onnx
for k in keys:
(out / k).mkdir(parents=True, exist_ok=True)
(out / k / "model.onnx").write_text("dummy")
if components is None or components(k):
(out / k).mkdir(parents=True, exist_ok=True)
(out / k / "model.onnx").write_text("dummy")

pkg = MagicMock()
pkg.keys.return_value = keys
Expand Down Expand Up @@ -454,3 +462,184 @@

warning_messages = [call.args[0] for call in mock_logger.warning.call_args_list]
assert not any("trust_remote_code" in msg for msg in warning_messages)


# ---------------------------------------------------------------------------
# components_to_export filter tests
# ---------------------------------------------------------------------------


def test_components_to_export_filters_subset(tmp_path):
    """Exporting a subset saves and returns only the requested components."""
    output_dir = tmp_path / "out"
    fake_pkg = _fake_pkg(["decoder", "vision_encoder", "embedding"], output_dir)

    cpu_spec = AcceleratorSpec(
        accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
    )
    pass_config = {"precision": "fp16", "components_to_export": ["vision_encoder", "embedding"]}
    builder = create_pass_from_dict(MobiusBuilder, pass_config, disable_search=True, accelerator_spec=cpu_spec)

    with _patch_build(fake_pkg):
        handler = builder.run(_make_hf_model("org/vlm"), output_dir)

    assert isinstance(handler, CompositeModelHandler)
    assert handler.model_component_names == ["vision_encoder", "embedding"]

    # The predicate handed to pkg.save() must accept the requested names and reject the decoder.
    save_filter = fake_pkg.save.call_args.kwargs.get("components")
    assert save_filter is not None
    assert save_filter("vision_encoder") is True
    assert save_filter("embedding") is True
    assert save_filter("decoder") is False

    # The fake package honours the filter, so the skipped component leaves no directory behind.
    assert (output_dir / "vision_encoder" / "model.onnx").exists(), "vision_encoder should be on disk"
    assert (output_dir / "embedding" / "model.onnx").exists(), "embedding should be on disk"
    assert not (output_dir / "decoder").exists(), "decoder directory should not exist on disk (was skipped)"


def test_components_to_export_none_exports_all(tmp_path):
    """Leaving components_to_export unset (None, the default) exports every component."""
    output_dir = tmp_path / "out"
    component_names = ["decoder", "vision_encoder", "embedding"]
    fake_pkg = _fake_pkg(component_names, output_dir)

    with _patch_build(fake_pkg):
        handler = _make_pass().run(_make_hf_model("org/vlm"), output_dir)

    assert isinstance(handler, CompositeModelHandler)
    assert handler.model_component_names == component_names

    # No filter may reach pkg.save(): the `components` kwarg is either absent or None.
    assert fake_pkg.save.call_args.kwargs.get("components") is None


def test_components_to_export_single_component_via_filter(tmp_path):
    """A multi-component model filtered down to one component still yields a CompositeModelHandler.

    The filtered package keeps the per-component sub-directory layout on disk, so it
    must not be treated like an architecturally single-component model (root layout).
    """
    output_dir = tmp_path / "out"
    fake_pkg = _fake_pkg(["decoder", "vision_encoder", "embedding"], output_dir)

    cpu_spec = AcceleratorSpec(
        accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
    )
    pass_config = {"precision": "fp16", "components_to_export": ["decoder"]}
    builder = create_pass_from_dict(MobiusBuilder, pass_config, disable_search=True, accelerator_spec=cpu_spec)

    with _patch_build(fake_pkg):
        handler = builder.run(_make_hf_model("org/vlm"), output_dir)

    # Three components filtered to one: composite handler holding just that component.
    assert isinstance(handler, CompositeModelHandler)
    assert handler.model_component_names == ["decoder"]


def test_components_to_export_unknown_component_raises(tmp_path):
    """Requesting a component name the package does not contain raises ValueError."""
    output_dir = tmp_path / "out"
    fake_pkg = _fake_pkg(["decoder", "vision_encoder"], output_dir)

    cpu_spec = AcceleratorSpec(
        accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
    )
    pass_config = {"precision": "fp16", "components_to_export": ["nonexistent"]}
    builder = create_pass_from_dict(MobiusBuilder, pass_config, disable_search=True, accelerator_spec=cpu_spec)

    with _patch_build(fake_pkg):
        with pytest.raises(ValueError, match="unknown component"):
            builder.run(_make_hf_model("org/vlm"), output_dir)


def test_components_to_export_empty_list_raises(tmp_path):
    """An empty components_to_export list is rejected with ValueError rather than exporting nothing."""
    output_dir = tmp_path / "out"
    fake_pkg = _fake_pkg(["decoder", "vision_encoder"], output_dir)

    cpu_spec = AcceleratorSpec(
        accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
    )
    pass_config = {"precision": "fp16", "components_to_export": []}
    builder = create_pass_from_dict(MobiusBuilder, pass_config, disable_search=True, accelerator_spec=cpu_spec)

    with _patch_build(fake_pkg):
        with pytest.raises(ValueError, match="cannot be empty"):
            builder.run(_make_hf_model("org/vlm"), output_dir)


def test_components_to_export_in_default_config():
    """_default_config exposes components_to_export as an optional parameter defaulting to None."""
    cpu_spec = AcceleratorSpec(
        accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
    )
    default_config = MobiusBuilder._default_config(cpu_spec)  # pylint: disable=protected-access

    # Check presence first so a missing key fails as an assertion, not a KeyError.
    assert "components_to_export" in default_config
    param = default_config["components_to_export"]
    assert param.default_value is None
    assert param.required is False


def test_pkg_save_typeerror_falls_back_gracefully(tmp_path):
    """Fall back to an unfiltered save when pkg.save() rejects the ``components`` kwarg (old mobius).

    The fake ``save()`` raises TypeError whenever ``components=`` is passed and otherwise
    writes every component, mimicking a mobius release that predates the filter. The pass
    is expected to retry without the kwarg, log a warning, and still report only the
    requested components, even though the skipped component is left on disk.
    """
    out = tmp_path / "out"
    keys = ["decoder", "vision_encoder", "embedding"]
    requested = ["vision_encoder", "embedding"]

    # Build a pkg whose save() raises TypeError only when components= is passed (old mobius API).
    def _save_old_api(directory: str, **kwargs):
        if "components" in kwargs:
            raise TypeError("unexpected keyword argument 'components'")
        # Old API: save all components unconditionally.
        d = Path(directory)
        for k in keys:
            (d / k).mkdir(parents=True, exist_ok=True)
            (d / k / "model.onnx").write_text("dummy")

    pkg = MagicMock()
    pkg.keys.return_value = keys
    # Hand out a fresh iterator per call; a single shared iterator would be exhausted
    # after the first pass over the package and make later iterations silently empty.
    pkg.__iter__ = MagicMock(side_effect=lambda: iter(keys))
    pkg.items.return_value = [(k, MagicMock()) for k in keys]
    pkg.save.side_effect = _save_old_api

    accelerator_spec = AcceleratorSpec(
        accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider
    )
    p = create_pass_from_dict(
        MobiusBuilder,
        {"precision": "fp16", "components_to_export": requested},
        disable_search=True,
        accelerator_spec=accelerator_spec,
    )

    with (
        _patch_build(pkg),
        patch("olive.passes.onnx.mobius_model_builder.logger") as mock_logger,
    ):
        result = p.run(_make_hf_model("org/vlm"), out)

    # The returned handler lists only the requested components: the pass derives its
    # component list from components_to_export, not from what save() wrote.
    assert isinstance(result, CompositeModelHandler)
    assert result.model_component_names == requested

    # Old mobius wrote every component, so the unrequested one is orphaned on disk.
    for k in keys:
        assert (out / k / "model.onnx").exists(), f"{k} should be on disk (old API saves all)"

    # The save that finally succeeded must not have carried the unsupported kwarg.
    assert "components" not in pkg.save.call_args.kwargs

    # A warning must have been logged about the missing kwarg support.
    warning_messages = [str(call) for call in mock_logger.warning.call_args_list]
    assert any("components" in msg and "mobius" in msg for msg in warning_messages)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: test renamed to test_pkg_save_old_api_no_components_kwarg_falls_back_gracefully. Now asserts: (1) result.model_component_names == ["vision_encoder", "embedding"] — the 2 requested components, not all 3; (2) all 3 component dirs are on disk (old API writes all, decoder is orphaned); (3) pkg.save was NOT called with components= kwarg. Added a separate regression test test_pkg_save_components_kwarg_detected_and_filter_applied for the modern-API path.

Loading