-
Notifications
You must be signed in to change notification settings - Fork 297
feat: add components_to_export filter to MobiusBuilder pass #2456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 8 commits
89a3410
6bfdff4
d1207e1
c6bb2c0
87e41ff
37b5709
ec35cec
a452b34
79eb44f
dcd23b4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -74,18 +74,26 @@ | |
|
|
||
|
|
||
| def _fake_pkg(keys: list[str], _output_dir: Path) -> MagicMock: | ||
| """Create a fake ModelPackage that writes dummy .onnx files when .save() is called.""" | ||
| """Create a fake ModelPackage that writes dummy .onnx files when .save() is called. | ||
|
|
||
| def _save(directory: str, **_kwargs): | ||
| Respects the optional ``components`` filter kwarg passed to ``save()``: only writes | ||
| files for components for which ``components(name)`` returns True (or all if None). | ||
| """ | ||
|
|
||
| def _save(directory: str, components=None, **_kwargs): | ||
| out = Path(directory) | ||
| if len(keys) == 1: | ||
| # Single-component: saved as <dir>/model.onnx | ||
| (out / "model.onnx").write_text("dummy") | ||
| # Single-component: saved as <dir>/model.onnx. | ||
| # Apply the components filter consistently with multi-component behaviour. | ||
|
Check warning on line 87 in test/passes/onnx/test_mobius_model_builder.py
|
||
| key = keys[0] | ||
| if components is None or components(key): | ||
| (out / "model.onnx").write_text("dummy") | ||
| else: | ||
|
titaiwangms marked this conversation as resolved.
|
||
| # Multi-component: saved as <dir>/<key>/model.onnx | ||
| for k in keys: | ||
| (out / k).mkdir(parents=True, exist_ok=True) | ||
| (out / k / "model.onnx").write_text("dummy") | ||
| if components is None or components(k): | ||
| (out / k).mkdir(parents=True, exist_ok=True) | ||
| (out / k / "model.onnx").write_text("dummy") | ||
|
|
||
| pkg = MagicMock() | ||
| pkg.keys.return_value = keys | ||
|
|
@@ -454,3 +462,184 @@ | |
|
|
||
| warning_messages = [call.args[0] for call in mock_logger.warning.call_args_list] | ||
| assert not any("trust_remote_code" in msg for msg in warning_messages) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # components_to_export filter tests | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def test_components_to_export_filters_subset(tmp_path): | ||
| """Only requested components are saved and returned when components_to_export is set.""" | ||
| out = tmp_path / "out" | ||
| keys = ["decoder", "vision_encoder", "embedding"] | ||
| pkg = _fake_pkg(keys, out) | ||
|
|
||
| accelerator_spec = AcceleratorSpec( | ||
| accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider | ||
| ) | ||
| p = create_pass_from_dict( | ||
| MobiusBuilder, | ||
| {"precision": "fp16", "components_to_export": ["vision_encoder", "embedding"]}, | ||
| disable_search=True, | ||
| accelerator_spec=accelerator_spec, | ||
| ) | ||
|
|
||
| with _patch_build(pkg): | ||
| result = p.run(_make_hf_model("org/vlm"), out) | ||
|
|
||
| assert isinstance(result, CompositeModelHandler) | ||
| assert result.model_component_names == ["vision_encoder", "embedding"] | ||
|
|
||
| # pkg.save must have been called with a components filter that excludes decoder | ||
| save_kwargs = pkg.save.call_args.kwargs | ||
| components_filter = save_kwargs.get("components") | ||
| assert components_filter is not None | ||
| assert components_filter("vision_encoder") is True | ||
| assert components_filter("embedding") is True | ||
| assert components_filter("decoder") is False | ||
|
|
||
| # Verify skipped component directory is absent from disk | ||
| assert (out / "vision_encoder" / "model.onnx").exists(), "vision_encoder should be on disk" | ||
| assert (out / "embedding" / "model.onnx").exists(), "embedding should be on disk" | ||
| assert not (out / "decoder").exists(), "decoder directory should not exist on disk (was skipped)" | ||
|
|
||
|
|
||
| def test_components_to_export_none_exports_all(tmp_path): | ||
| """All components are exported when components_to_export is None (default).""" | ||
| out = tmp_path / "out" | ||
| keys = ["decoder", "vision_encoder", "embedding"] | ||
| pkg = _fake_pkg(keys, out) | ||
|
|
||
| with _patch_build(pkg): | ||
| result = _make_pass().run(_make_hf_model("org/vlm"), out) | ||
|
|
||
| assert isinstance(result, CompositeModelHandler) | ||
| assert result.model_component_names == keys | ||
| # pkg.save must have been called without a filter (components=None) | ||
| save_kwargs = pkg.save.call_args.kwargs | ||
| assert save_kwargs.get("components") is None | ||
|
|
||
|
|
||
| def test_components_to_export_single_component_via_filter(tmp_path): | ||
| """Filtering a multi-component model to one component returns CompositeModelHandler with one component. | ||
|
|
||
| Unlike an architecturally single-component model (which uses root layout), a | ||
| filtered multi-component model still uses the component sub-directory layout, | ||
| so we always return CompositeModelHandler for multi-component packages. | ||
| """ | ||
| out = tmp_path / "out" | ||
| keys = ["decoder", "vision_encoder", "embedding"] | ||
| pkg = _fake_pkg(keys, out) | ||
|
|
||
| accelerator_spec = AcceleratorSpec( | ||
| accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider | ||
| ) | ||
| p = create_pass_from_dict( | ||
| MobiusBuilder, | ||
| {"precision": "fp16", "components_to_export": ["decoder"]}, | ||
| disable_search=True, | ||
| accelerator_spec=accelerator_spec, | ||
| ) | ||
|
|
||
| with _patch_build(pkg): | ||
| result = p.run(_make_hf_model("org/vlm"), out) | ||
|
|
||
| # Multi-component model filtered to 1 → still CompositeModelHandler (component sub-dir layout) | ||
| assert isinstance(result, CompositeModelHandler) | ||
| assert result.model_component_names == ["decoder"] | ||
|
|
||
|
|
||
| def test_components_to_export_unknown_component_raises(tmp_path): | ||
| """ValueError when components_to_export names a component not in the package.""" | ||
| out = tmp_path / "out" | ||
| keys = ["decoder", "vision_encoder"] | ||
| pkg = _fake_pkg(keys, out) | ||
|
|
||
| accelerator_spec = AcceleratorSpec( | ||
| accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider | ||
| ) | ||
| p = create_pass_from_dict( | ||
| MobiusBuilder, | ||
| {"precision": "fp16", "components_to_export": ["nonexistent"]}, | ||
| disable_search=True, | ||
| accelerator_spec=accelerator_spec, | ||
| ) | ||
|
|
||
| with _patch_build(pkg), pytest.raises(ValueError, match="unknown component"): | ||
| p.run(_make_hf_model("org/vlm"), out) | ||
|
|
||
|
|
||
| def test_components_to_export_empty_list_raises(tmp_path): | ||
| """components_to_export=[] must raise ValueError — empty list is always a mistake.""" | ||
| out = tmp_path / "out" | ||
| keys = ["decoder", "vision_encoder"] | ||
| pkg = _fake_pkg(keys, out) | ||
|
|
||
| accelerator_spec = AcceleratorSpec( | ||
| accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider | ||
| ) | ||
| p = create_pass_from_dict( | ||
| MobiusBuilder, | ||
| {"precision": "fp16", "components_to_export": []}, | ||
| disable_search=True, | ||
| accelerator_spec=accelerator_spec, | ||
| ) | ||
|
|
||
| with _patch_build(pkg), pytest.raises(ValueError, match="cannot be empty"): | ||
| p.run(_make_hf_model("org/vlm"), out) | ||
|
|
||
|
|
||
| def test_components_to_export_in_default_config(): | ||
| """components_to_export parameter must appear in _default_config with None default.""" | ||
| accelerator_spec = AcceleratorSpec( | ||
| accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider | ||
| ) | ||
| config = MobiusBuilder._default_config(accelerator_spec) # pylint: disable=protected-access | ||
| assert "components_to_export" in config | ||
| assert config["components_to_export"].default_value is None | ||
| assert config["components_to_export"].required is False | ||
|
|
||
|
|
||
| def test_pkg_save_typeerror_falls_back_gracefully(tmp_path): | ||
| """When pkg.save() raises TypeError for the components= kwarg (old mobius), fall back without filter.""" | ||
| out = tmp_path / "out" | ||
| keys = ["decoder", "vision_encoder", "embedding"] | ||
|
|
||
| # Build a pkg whose save() raises TypeError only when components= is passed (old mobius API). | ||
| def _save_old_api(directory: str, **kwargs): | ||
| if "components" in kwargs: | ||
| raise TypeError("unexpected keyword argument 'components'") | ||
| # Old API: save all components unconditionally. | ||
| d = Path(directory) | ||
| for k in keys: | ||
| (d / k).mkdir(parents=True, exist_ok=True) | ||
| (d / k / "model.onnx").write_text("dummy") | ||
|
|
||
| pkg = MagicMock() | ||
| pkg.keys.return_value = keys | ||
| pkg.__iter__ = MagicMock(return_value=iter(keys)) | ||
| pkg.items.return_value = [(k, MagicMock()) for k in keys] | ||
| pkg.save.side_effect = _save_old_api | ||
|
|
||
| accelerator_spec = AcceleratorSpec( | ||
| accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider | ||
| ) | ||
| p = create_pass_from_dict( | ||
| MobiusBuilder, | ||
| {"precision": "fp16", "components_to_export": ["vision_encoder", "embedding"]}, | ||
| disable_search=True, | ||
| accelerator_spec=accelerator_spec, | ||
| ) | ||
|
|
||
| with ( | ||
| _patch_build(pkg), | ||
| patch("olive.passes.onnx.mobius_model_builder.logger") as mock_logger, | ||
| ): | ||
| result = p.run(_make_hf_model("org/vlm"), out) | ||
|
|
||
| # Old mobius saved all; pass returns CompositeModelHandler with all keys (filter not enforced). | ||
| assert isinstance(result, CompositeModelHandler) | ||
| # A warning must have been logged about the missing kwarg support. | ||
| warning_messages = [str(call) for call in mock_logger.warning.call_args_list] | ||
| assert any("components" in msg and "mobius" in msg for msg in warning_messages) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixed: test renamed to |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed: replaced the `try/except TypeError` with an `inspect.signature(pkg.save).parameters` check. This detects kwarg support upfront from the function signature, avoiding any execution of `save()` before the error — so no orphaned dirs and no masked real `TypeError`s. The check happens cleanly before the call.