Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/source/features/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,26 @@ This pass only supports HuggingFace transformer PyTorch models.

This pass supports ONNX models and can quantize `MatMul` and `Gather` nodes to 4 or 8 bits with block-wise quantization.

For multi-component (composite) models, such as vision-language models (VLMs) exported with `MobiusBuilder`, use `components_to_skip` to bypass quantization for specific components that must stay in higher precision. Skipped components are copied unchanged to the output. Names in `components_to_skip` that do not match any component emit a warning and are otherwise ignored. `components_to_skip` has no effect on single-component (non-composite) models.

### Example Configuration
```json
{
"type": "OnnxBlockWiseRtnQuantization"
}
```

### Skipping Components (Composite Models)
```json
{
"type": "OnnxBlockWiseRtnQuantization",
"block_size": 128,
"is_symmetric": true,
"accuracy_level": 4,
"components_to_skip": ["embedding"]
}
```

## HQQ
`HQQ (Half-Quadratic Quantization)` is a fast, calibration-free weight quantization method that enables low-bit quantization of large models without relying on gradient-based optimization. Unlike data-dependent approaches like GPTQ, [HQQ](https://dropbox.github.io/hqq_blog/) uses half-quadratic splitting to minimize weight quantization error efficiently.

Expand Down
129 changes: 128 additions & 1 deletion olive/passes/onnx/rtn_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import shutil
from pathlib import Path
from typing import Optional

Expand All @@ -25,7 +26,30 @@


class OnnxBlockWiseRtnQuantization(Pass):
"""Quantize ONNX models with weight-only block-wise RTN algorithm."""
"""Quantize ONNX models with weight-only block-wise RTN algorithm.

Quantizes ``MatMul`` and ``Gather`` nodes to 4-bit or 8-bit weights using
the Round-To-Nearest (RTN) algorithm. Supports both plain
:class:`~olive.model.ONNXModelHandler` and multi-component
:class:`~olive.model.handler.composite.CompositeModelHandler` inputs.

Use ``components_to_skip`` to bypass quantization for specific components
of a composite model. Skipped components are copied unchanged to the
output directory. This is useful when certain components must stay in
higher precision (e.g. an ``embedding`` component where
``GatherBlockQuantized`` may not be supported)::

{
"type": "OnnxBlockWiseRtnQuantization",
"block_size": 128,
"is_symmetric": true,
"components_to_skip": ["embedding"]
}

Unknown component names in ``components_to_skip`` emit a warning and are
otherwise ignored. ``components_to_skip`` has no effect on non-composite
(single-component) models.
"""

@classmethod
def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]:
Expand Down Expand Up @@ -69,9 +93,112 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=None,
description="List of node names to include in quantization.",
),
"components_to_skip": PassConfigParam(
type_=list[str],
default_value=None,
description=(
"Optional list of component names to skip quantization for "
"(e.g. ['embedding'] to pass the embedding model through unchanged). "
"When a composite model component's name matches an entry in this list, "
"its files are copied to the output path without modification. "
"When not set, all components are quantized (default, backward compatible). "
"Has no effect on single-component (non-composite) models."
),
),
**get_external_data_config(),
}

def run(self, model, output_model_path: str):
    """Run quantization, copying components listed in ``components_to_skip`` unchanged.

    Overrides the base ``Pass.run()`` only for the case where ``model`` is a
    ``CompositeModelHandler`` *and* ``components_to_skip`` is non-empty. In every
    other case the call is delegated to ``super().run()`` untouched.

    Names in ``components_to_skip`` that do not match any component produce a
    warning, not an error: skipping is non-fatal, so misspellings are surfaced
    without aborting the run.

    Args:
        model: Input model handler (single or composite).
        output_model_path: Output location. For composite inputs its suffix is
            stripped and the result is used as the directory holding one
            sub-directory per component.

    Returns:
        The handler returned by ``super().run()`` for the delegated cases, otherwise
        a new ``CompositeModelHandler`` whose components keep the input order and names.
    """
    from olive.model import CompositeModelHandler

    components_to_skip: set[str] = set(self.config.components_to_skip or [])
    if not components_to_skip or not isinstance(model, CompositeModelHandler):
        return super().run(model, output_model_path)

    # Materialize once: the components are iterated twice below (name check + processing).
    all_components = list(model.get_model_components())

    # Warn about names that won't match anything; they would otherwise be ignored silently.
    all_component_names = {name for name, _ in all_components}
    unknown_skips = components_to_skip - all_component_names
    if unknown_skips:
        logger.warning(
            "OnnxBlockWiseRtnQuantization: components_to_skip contains name(s) not found "
            "in this composite model: %s. Available components: %s",
            sorted(unknown_skips),
            sorted(all_component_names),
        )

    # The composite path below does not go through super().run() for the model as a
    # whole, so the lazy-initialization guard is replicated here.
    if not self._initialized:
        self._initialize()
        self._initialized = True

    model_dir = Path(output_model_path).with_suffix("")
    model_dir.mkdir(parents=True, exist_ok=True)

    components = []
    component_names = []
    for component_name, component_model in all_components:
        component_output_path = model_dir / component_name
        if component_name in components_to_skip:
            logger.info(
                "OnnxBlockWiseRtnQuantization: skipping quantization for component '%s'.",
                component_name,
            )
            output_component = self._copy_skipped_component(component_model, component_output_path)
        else:
            # A component is not a CompositeModelHandler here, so this re-entry
            # takes the early-return branch above and ends up in super().run().
            output_component = self.run(component_model, str(component_output_path))
        components.append(output_component)
        component_names.append(component_name)

    output_model = CompositeModelHandler(components, component_names, model_path=model_dir)
    output_model.model_attributes = output_model.model_attributes or model.model_attributes
    Pass._carry_forward_additional_files(model, output_model)
    return output_model


@staticmethod
def _copy_skipped_component(component_model, component_output_path: Path):
    """Copy one component's files to ``component_output_path`` without quantizing it.

    Args:
        component_model: Source component handler; its ``model_path`` may point at
            either a directory or a single ONNX file.
        component_output_path: Destination directory for this component.

    Returns:
        A new ``ONNXModelHandler`` rooted at ``component_output_path`` that carries
        over the source ``model_attributes`` and additional files.
    """
    src = Path(component_model.model_path)
    onnx_file_name = getattr(component_model, "onnx_file_name", None)

    if src.is_dir():
        # src is the component directory: replace the destination with a full copy,
        # unless source and destination are already the same location.
        if src.resolve() != component_output_path.resolve():
            shutil.rmtree(component_output_path, ignore_errors=True)
            shutil.copytree(src, component_output_path)
        # Deliberately no basename fallback here: the basename of a directory is not
        # an ONNX file name.
        # NOTE(review): assumes ONNXModelHandler resolves the .onnx file inside a
        # directory when onnx_file_name is None - confirm against the handler.
    else:
        # src is the ONNX file itself: fall back to its basename rather than
        # hard-coding 'model.onnx'.
        onnx_file_name = onnx_file_name or src.name
        destination = component_output_path / src.name
        if src.resolve() != destination.resolve():
            # Copy only this file (and its '.data' sidecar, if present) so sibling
            # files living in src.parent are not dragged along.
            shutil.rmtree(component_output_path, ignore_errors=True)
            component_output_path.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, destination)
            data_sidecar = Path(str(src) + ".data")
            if data_sidecar.exists():
                shutil.copy2(data_sidecar, component_output_path / data_sidecar.name)

    output_component = ONNXModelHandler(
        model_path=str(component_output_path),
        onnx_file_name=onnx_file_name,
        model_attributes=component_model.model_attributes,
    )
    Pass._carry_forward_additional_files(component_model, output_component)
    return output_component

def _run_for_config(
self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
) -> ONNXModelHandler:
Expand Down
146 changes: 146 additions & 0 deletions test/passes/onnx/test_rtn_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,149 @@ def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, t
assert "weight" not in init_names, (
f"Original FP32 'weight' initializer should have been removed, found: {init_names}"
)


class TestRTNQuantizationComponentsToSkip:
    """Tests for the components_to_skip parameter on OnnxBlockWiseRtnQuantization."""

    @staticmethod
    def _make_matmul_model(tmp_path, name: str) -> ONNXModelHandler:
        """Create a tiny MatMul ONNX model under ``tmp_path / name`` and return its handler."""
        # Seeded generator keeps the fixture weights reproducible from run to run.
        weight = np.random.default_rng(0).standard_normal((64, 128)).astype(np.float32)
        inp = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 64])
        out = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 128])
        weight_init = onnx.helper.make_tensor(
            name="weight",
            data_type=onnx.TensorProto.FLOAT,
            dims=[64, 128],
            vals=weight.flatten().tolist(),
        )
        node = onnx.helper.make_node("MatMul", ["input", "weight"], ["output"], name="MatMul_Node")
        graph = onnx.helper.make_graph([node], "g", [inp], [out], initializer=[weight_init])
        model_def = onnx.helper.make_model(graph, producer_name="test")
        model_def.opset_import[0].version = 13

        model_dir = tmp_path / name
        model_dir.mkdir(parents=True, exist_ok=True)
        onnx.save(model_def, str(model_dir / "model.onnx"))
        return ONNXModelHandler(model_path=str(model_dir), onnx_file_name="model.onnx")

    @classmethod
    def _make_composite(cls, src_dir, component_names, model_path=None):
        """Build a composite model with one tiny MatMul component per name.

        ``model_path`` is forwarded to ``CompositeModelHandler`` only when given, so a
        test can also exercise a composite that was created without one.
        """
        from olive.model.handler.composite import CompositeModelHandler

        components = [cls._make_matmul_model(src_dir, name) for name in component_names]
        extra_kwargs = {} if model_path is None else {"model_path": model_path}
        return CompositeModelHandler(
            model_components=components,
            model_component_names=list(component_names),
            **extra_kwargs,
        )

    @staticmethod
    def _make_pass(components_to_skip=None) -> OnnxBlockWiseRtnQuantization:
        """Create the pass with a fixed 4-bit config, optionally setting components_to_skip."""
        accelerator_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
        config = {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}
        if components_to_skip is not None:
            config["components_to_skip"] = components_to_skip
        return create_pass_from_dict(
            OnnxBlockWiseRtnQuantization, config, disable_search=True, accelerator_spec=accelerator_spec
        )

    @staticmethod
    def _op_types(model_handler) -> set:
        """Return the set of op types present in the handler's ONNX graph."""
        return {n.op_type for n in ir.load(model_handler.model_path).graph.all_nodes()}

    def test_components_to_skip_passes_component_through_unchanged(self, tmp_path):
        """Skipped component's model files are copied without quantization."""
        from pathlib import Path

        from olive.model.handler.composite import CompositeModelHandler

        src_dir = tmp_path / "src"
        composite = self._make_composite(src_dir, ["decoder", "embedding"], model_path=str(src_dir))

        p = self._make_pass(components_to_skip=["embedding"])
        result = p.run(composite, str(tmp_path / "out"))

        assert isinstance(result, CompositeModelHandler)
        assert result.model_component_names == ["decoder", "embedding"]

        outputs = dict(result.get_model_components())

        # decoder should be quantized (MatMulNBits present)
        assert str(OpType.MatMulNBits) in self._op_types(outputs["decoder"]), (
            "decoder should be quantized (MatMulNBits expected)"
        )

        # embedding should be unchanged (original MatMul still present)
        embedding_ops = self._op_types(outputs["embedding"])
        assert str(OpType.MatMul) in embedding_ops, "embedding should still contain the original MatMul op"
        assert str(OpType.MatMulNBits) not in embedding_ops, (
            "embedding should not be quantized (no MatMulNBits expected)"
        )

        # "Unchanged" means byte-identical to the source file, not merely un-quantized.
        source_bytes = (src_dir / "embedding" / "model.onnx").read_bytes()
        assert Path(outputs["embedding"].model_path).read_bytes() == source_bytes, (
            "embedding model file should be a byte-for-byte copy of the source"
        )

    def test_components_to_skip_none_quantizes_all(self, tmp_path):
        """When components_to_skip is not set, all composite components are quantized."""
        from olive.model.handler.composite import CompositeModelHandler

        src_dir = tmp_path / "src"
        composite = self._make_composite(src_dir, ["decoder", "embedding"], model_path=str(src_dir))

        p = self._make_pass(components_to_skip=None)
        result = p.run(composite, str(tmp_path / "out"))

        assert isinstance(result, CompositeModelHandler)

        for name, component in result.get_model_components():
            assert str(OpType.MatMulNBits) in self._op_types(component), (
                f"component '{name}' should be quantized when components_to_skip is None"
            )

    def test_components_to_skip_does_not_affect_single_model(self, tmp_path):
        """components_to_skip has no effect on non-composite (single) models."""
        model = self._make_matmul_model(tmp_path, "single")
        p = self._make_pass(components_to_skip=["single"])
        result = p.run(model, str(tmp_path / "out"))

        # Single model should still be quantized despite its path matching the skip list
        assert str(OpType.MatMulNBits) in self._op_types(result), (
            "Single-component model should be quantized even when components_to_skip is set"
        )

    def test_components_to_skip_in_default_config(self):
        """components_to_skip must appear in _default_config with None as default."""
        accelerator_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
        config = OnnxBlockWiseRtnQuantization._default_config(accelerator_spec)  # pylint: disable=protected-access
        assert "components_to_skip" in config
        assert config["components_to_skip"].default_value is None
        assert config["components_to_skip"].required is False

    def test_components_to_skip_unknown_name_warns(self, tmp_path):
        """Misspelled or missing component names in components_to_skip must log a warning."""
        import logging

        # Composite deliberately created without a model_path.
        composite = self._make_composite(tmp_path / "src", ["decoder", "vision"])

        p = self._make_pass(components_to_skip=["typo_component"])

        records = []

        class _Handler(logging.Handler):
            def emit(self, record):
                records.append(record.getMessage())

        # Attach directly to the module logger and accept WARNING and above only,
        # so an INFO/DEBUG message mentioning the name cannot satisfy the assertion.
        handler = _Handler(level=logging.WARNING)
        rtn_logger = logging.getLogger("olive.passes.onnx.rtn_quantization")
        rtn_logger.addHandler(handler)
        try:
            p.run(composite, str(tmp_path / "out"))
        finally:
            # Remove exactly the handler installed above; leave every other handler alone.
            rtn_logger.removeHandler(handler)

        assert any("typo_component" in msg for msg in records), (
            f"Expected warning about unknown component name 'typo_component', got: {records}"
        )
Loading