diff --git a/docs/source/features/quantization.md b/docs/source/features/quantization.md index 64c9b8f4b..b3fd55f06 100644 --- a/docs/source/features/quantization.md +++ b/docs/source/features/quantization.md @@ -64,6 +64,8 @@ This pass only supports HuggingFace transformer PyTorch models. This pass supports ONNX models and can quantize `MatMul` and `Gather` nodes to 4 or 8 bits with block-wise quantization. +For multi-component models (e.g. VLMs exported with `MobiusBuilder`), use `components_to_skip` to bypass quantization for specific components that must stay in higher precision. Skipped components are copied unchanged to the output. Unknown component names in `components_to_skip` emit a warning and are otherwise ignored. + ### Example Configuration ```json { @@ -71,6 +73,17 @@ This pass supports ONNX models and can quantize `MatMul` and `Gather` nodes to 4 } ``` +### Skipping Components (Composite Models) +```json +{ + "type": "OnnxBlockWiseRtnQuantization", + "block_size": 128, + "is_symmetric": true, + "accuracy_level": 4, + "components_to_skip": ["embedding"] +} +``` + ## HQQ `HQQ (Half-Quadratic Quantization)` is a fast, calibration-free weight quantization method that enables low-bit quantization of large models without relying on gradient-based optimization. Unlike data-dependent approaches like GPTQ, [HQQ](https://dropbox.github.io/hqq_blog/) uses half-quadratic splitting to minimize weight quantization error efficiently. diff --git a/olive/passes/onnx/rtn_quantization.py b/olive/passes/onnx/rtn_quantization.py index e66cec112..7be24047b 100644 --- a/olive/passes/onnx/rtn_quantization.py +++ b/olive/passes/onnx/rtn_quantization.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- import logging +import shutil from pathlib import Path from typing import Optional @@ -25,7 +26,30 @@ class OnnxBlockWiseRtnQuantization(Pass): - """Quantize ONNX models with weight-only block-wise RTN algorithm.""" + """Quantize ONNX models with weight-only block-wise RTN algorithm. + + Quantizes ``MatMul`` and ``Gather`` nodes to 4-bit or 8-bit weights using + the Round-To-Nearest (RTN) algorithm. Supports both plain + :class:`~olive.model.ONNXModelHandler` and multi-component + :class:`~olive.model.handler.composite.CompositeModelHandler` inputs. + + Use ``components_to_skip`` to bypass quantization for specific components + of a composite model. Skipped components are copied unchanged to the + output directory. This is useful when certain components must stay in + higher precision (e.g. an ``embedding`` component where + ``GatherBlockQuantized`` may not be supported):: + + { + "type": "OnnxBlockWiseRtnQuantization", + "block_size": 128, + "is_symmetric": true, + "components_to_skip": ["embedding"] + } + + Unknown component names in ``components_to_skip`` emit a warning and are + otherwise ignored. ``components_to_skip`` has no effect on non-composite + (single-component) models. + """ @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: @@ -69,9 +93,112 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon default_value=None, description="List of node names to include in quantization.", ), + "components_to_skip": PassConfigParam( + type_=list[str], + default_value=None, + description=( + "Optional list of component names to skip quantization for " + "(e.g. ['embedding'] to pass the embedding model through unchanged). " + "When a composite model component's name matches an entry in this list, " + "its files are copied to the output path without modification. 
def run(self, model, output_model_path: str):
    """Run the pass, passing components named in ``components_to_skip`` through unquantized.

    When ``components_to_skip`` is empty or ``model`` is not a
    ``CompositeModelHandler``, this defers to the base ``run`` unchanged.
    Otherwise every component of the composite model is handled in turn:

    * a component whose name is in ``components_to_skip`` has its files copied
      into ``<output dir>/<component name>`` and is wrapped in a new
      ``ONNXModelHandler`` pointing at the copy;
    * any other component is sent back through ``run`` with that same
      per-component output path.

    Names in ``components_to_skip`` that match no component are reported with
    a warning and otherwise ignored.

    Args:
        model: Input model handler.
        output_model_path: Output location. Its suffix is stripped to form the
            directory that receives one sub-directory per component.

    Returns:
        The base ``run`` result on the pass-through path, otherwise a new
        ``CompositeModelHandler`` holding the processed components in their
        original order.

    Raises:
        ValueError: If a skipped component's source path and its output
            directory contain one another, in which case copying would
            destroy the source.
    """
    from olive.model import CompositeModelHandler

    skip_names = set(self.config.components_to_skip or [])
    if not skip_names or not isinstance(model, CompositeModelHandler):
        return super().run(model, output_model_path)

    # Materialize once: the (name, handler) pairs are walked twice below.
    all_components = list(model.get_model_components())

    # A misspelled name would otherwise silently result in the component being quantized.
    known_names = {name for name, _ in all_components}
    unknown_skips = skip_names - known_names
    if unknown_skips:
        logger.warning(
            "OnnxBlockWiseRtnQuantization: components_to_skip contains name(s) not found "
            "in this composite model: %s. Available components: %s",
            sorted(unknown_skips),
            sorted(known_names),
        )

    # super().run() is not called for the composite itself, so the one-time
    # initialization guard has to be applied here before any component is processed.
    if not self._initialized:
        self._initialize()
        self._initialized = True

    model_dir = Path(output_model_path).with_suffix("")
    model_dir.mkdir(parents=True, exist_ok=True)

    components = []
    component_names = []
    for component_name, component_model in all_components:
        component_output_path = model_dir / component_name
        if component_name in skip_names:
            logger.info(
                "OnnxBlockWiseRtnQuantization: skipping quantization for component '%s'.",
                component_name,
            )
            src = Path(component_model.model_path)
            src_is_dir = src.is_dir()
            self._copy_skipped_component_files(src, component_output_path)

            # Prefer the file name recorded on the source handler. Only fall back to
            # the basename of model_path when that path is a file: for a directory the
            # basename is the directory's name, never an ONNX file name.
            # NOTE(review): for a directory source without onnx_file_name this now passes
            # None and relies on ONNXModelHandler resolving the file itself -- confirm.
            onnx_file_name = getattr(component_model, "onnx_file_name", None)
            if not onnx_file_name and not src_is_dir:
                onnx_file_name = src.name

            output_component = ONNXModelHandler(
                model_path=str(component_output_path),
                onnx_file_name=onnx_file_name,
                model_attributes=component_model.model_attributes,
            )
            Pass._carry_forward_additional_files(component_model, output_component)
        else:
            # Not a CompositeModelHandler, so this call takes the super().run() path above.
            output_component = self.run(component_model, str(component_output_path))
        components.append(output_component)
        component_names.append(component_name)

    output_model = CompositeModelHandler(components, component_names, model_path=model_dir)
    output_model.model_attributes = output_model.model_attributes or model.model_attributes
    Pass._carry_forward_additional_files(model, output_model)
    return output_model

@staticmethod
def _copy_skipped_component_files(src: Path, dest_dir: Path) -> None:
    """Copy the files of an unquantized component from ``src`` into ``dest_dir``.

    A directory ``src`` is copied as a whole to ``dest_dir``. A file ``src`` is
    copied to ``dest_dir / src.name`` together with its ``<src>.data`` sidecar
    when that file exists. Any previous content of ``dest_dir`` is removed
    first. Nothing is done when ``src`` already is the copy target.

    Raises:
        ValueError: If ``src`` lies inside ``dest_dir`` or ``dest_dir`` lies
            inside ``src``; clearing or filling ``dest_dir`` would then delete
            or recursively duplicate the source.
    """
    src_is_dir = src.is_dir()
    src_resolved = src.resolve()
    dest_resolved = dest_dir.resolve()
    target = dest_resolved if src_is_dir else dest_resolved / src.name
    if src_resolved == target:
        # Source already sits at the destination; copying onto itself would fail.
        return

    if dest_resolved in src_resolved.parents or src_resolved in dest_resolved.parents:
        raise ValueError(
            f"Cannot copy skipped component: source '{src}' and output directory "
            f"'{dest_dir}' overlap."
        )

    shutil.rmtree(dest_dir, ignore_errors=True)
    if src_is_dir:
        shutil.copytree(src, dest_dir)
        return

    # Copy only the model file (plus sidecar) so unrelated siblings in src.parent stay behind.
    # NOTE(review): external data stored under any name other than "<file>.data" is not copied.
    dest_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dest_dir / src.name)
    data_sidecar = Path(str(src) + ".data")
    if data_sidecar.exists():
        shutil.copy2(data_sidecar, dest_dir / data_sidecar.name)
@staticmethod
def _make_pass(components_to_skip=None) -> OnnxBlockWiseRtnQuantization:
    """Build a 4-bit symmetric RTN pass for the CPU execution provider.

    ``components_to_skip`` is added to the pass config only when it is not
    ``None``, so the default-config path is exercised when it is omitted.
    """
    pass_config = {
        "bits": 4,
        "block_size": 128,
        "axis": 0,
        "is_symmetric": True,
        **({} if components_to_skip is None else {"components_to_skip": components_to_skip}),
    }
    cpu_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
    return create_pass_from_dict(
        OnnxBlockWiseRtnQuantization,
        pass_config,
        disable_search=True,
        accelerator_spec=cpu_spec,
    )
def test_components_to_skip_none_quantizes_all(self, tmp_path):
    """Without components_to_skip, every component of a composite model gets quantized."""
    from olive.model.handler.composite import CompositeModelHandler

    source_root = tmp_path / "src"
    names = ["decoder", "embedding"]
    handlers = [self._make_matmul_model(source_root, component) for component in names]
    composite = CompositeModelHandler(
        model_components=handlers,
        model_component_names=list(names),
        model_path=str(source_root),
    )

    quantized = self._make_pass(components_to_skip=None).run(composite, str(tmp_path / "out"))

    assert isinstance(quantized, CompositeModelHandler)

    nbits_op = str(OpType.MatMulNBits)
    for name, component in quantized.get_model_components():
        graph_nodes = ir.load(component.model_path).graph.all_nodes()
        assert any(node.op_type == nbits_op for node in graph_nodes), (
            f"component '{name}' should be quantized when components_to_skip is None"
        )
def test_components_to_skip_unknown_name_warns(self, tmp_path):
    """A name in components_to_skip that matches no component must be logged at WARNING level or above."""
    import logging

    from olive.model.handler.composite import CompositeModelHandler

    decoder = self._make_matmul_model(tmp_path / "src", "decoder")
    vision = self._make_matmul_model(tmp_path / "src", "vision")
    composite = CompositeModelHandler(
        model_components=[decoder, vision],
        model_component_names=["decoder", "vision"],
    )

    p = self._make_pass(components_to_skip=["typo_component"])

    warning_messages = []

    class _Collector(logging.Handler):
        def emit(self, record):
            warning_messages.append(record.getMessage())

    rtn_logger = logging.getLogger("olive.passes.onnx.rtn_quantization")
    # Handler-level threshold: info/debug records mentioning the name must not satisfy the test.
    collector = _Collector(level=logging.WARNING)
    rtn_logger.addHandler(collector)
    try:
        p.run(composite, str(tmp_path / "out"))
    finally:
        # Detach exactly the handler added above instead of rewriting the logger's handler list.
        rtn_logger.removeHandler(collector)

    assert any("typo_component" in msg for msg in warning_messages), (
        f"Expected warning about unknown component name 'typo_component', got: {warning_messages}"
    )