Skip to content
Open
84 changes: 84 additions & 0 deletions olive/passes/onnx/rtn_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import shutil
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -69,9 +70,92 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=None,
description="List of node names to include in quantization.",
),
"components_to_skip": PassConfigParam(
type_=list[str],
default_value=None,
description=(
"Optional list of component names to skip quantization for "
"(e.g. ['embedding'] to pass the embedding model through unchanged). "
"When a composite model component's name matches an entry in this list, "
"its files are copied to the output path without modification. "
"When not set, all components are quantized (default, backward compatible). "
"Has no effect on single-component (non-composite) models."
),
),
**get_external_data_config(),
}

def run(self, model, output_model_path: str):
    """Run quantization, passing through components listed in ``components_to_skip``.

    Overrides the base ``Pass.run()`` to intercept ``CompositeModelHandler`` inputs.
    Components whose names appear in ``config.components_to_skip`` are copied to the
    output path unchanged; every other component is quantized by calling ``run``
    on it individually.

    Args:
        model: Input model handler. Anything that is not a ``CompositeModelHandler``
            (or any input when ``components_to_skip`` is empty/unset) is delegated
            to ``super().run()`` untouched.
        output_model_path: Destination path. Its suffix is stripped and the result
            is used as the directory holding one sub-directory per component.

    Returns:
        The handler returned by ``super().run()`` for the delegated cases, otherwise
        a new ``CompositeModelHandler`` with the components in their original order.
    """
    from olive.model import CompositeModelHandler

    components_to_skip: set[str] = set(self.config.components_to_skip or [])
    if not components_to_skip or not isinstance(model, CompositeModelHandler):
        return super().run(model, output_model_path)

    # Enumerate the components once; the same pairs feed both the name check
    # and the processing loop below.
    model_components = list(model.get_model_components())

    # Warn about component names that won't match anything - misspellings are
    # silently ignored otherwise since skipping is non-fatal.
    all_component_names = {name for name, _ in model_components}
    unknown_skips = components_to_skip - all_component_names
    if unknown_skips:
        logger.warning(
            "OnnxBlockWiseRtnQuantization: components_to_skip contains name(s) not found "
            "in this composite model: %s. Available components: %s",
            sorted(unknown_skips),
            sorted(all_component_names),
        )

    # super().run() is bypassed for the composite itself, so replicate the
    # initialization guard here.
    # NOTE(review): this mirrors what the base Pass.run() is expected to do before
    # calling _run_for_config - confirm against the base class if it changes.
    if not self._initialized:
        self._initialize()
        self._initialized = True

    model_dir = Path(output_model_path).with_suffix("")
    model_dir.mkdir(parents=True, exist_ok=True)

    components = []
    component_names = []
    for component_name, component_model in model_components:
        component_output_path = model_dir / component_name
        if component_name in components_to_skip:
            logger.info(
                "OnnxBlockWiseRtnQuantization: skipping quantization for component '%s'.",
                component_name,
            )
            src = Path(component_model.model_path)
            # model_path may point to the .onnx file rather than its parent dir
            src_is_file = src.is_file()
            src_dir = src.parent if src_is_file else src
            # Compare resolved paths: a relative/absolute or symlinked spelling of the
            # same directory must not be treated as a different location, otherwise
            # the rmtree below would delete the source before it can be copied.
            if src_dir.resolve() != component_output_path.resolve():
                if component_output_path.exists():
                    shutil.rmtree(str(component_output_path))
                shutil.copytree(str(src_dir), str(component_output_path))
            # onnx_file_name may be None if the handler was created without an explicit
            # name. When model_path was a file, its own name is the file that was just
            # copied; otherwise fall back to 'model.onnx'.
            onnx_file_name = getattr(component_model, "onnx_file_name", None) or (
                src.name if src_is_file else "model.onnx"
            )
            output_component = ONNXModelHandler(
                model_path=str(component_output_path),
                onnx_file_name=onnx_file_name,
                model_attributes=component_model.model_attributes,
            )
            Pass._carry_forward_additional_files(component_model, output_component)
        else:
            # Re-enter run() so a non-composite component takes the normal
            # super().run() path and a nested composite is handled recursively.
            output_component = self.run(component_model, str(component_output_path))
        components.append(output_component)
        component_names.append(component_name)

    output_model = CompositeModelHandler(components, component_names, model_path=model_dir)
    output_model.model_attributes = output_model.model_attributes or model.model_attributes
    Pass._carry_forward_additional_files(model, output_model)
    return output_model

def _run_for_config(
self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
) -> ONNXModelHandler:
Expand Down
146 changes: 146 additions & 0 deletions test/passes/onnx/test_rtn_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,149 @@ def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, t
assert "weight" not in init_names, (
f"Original FP32 'weight' initializer should have been removed, found: {init_names}"
)


class TestRTNQuantizationComponentsToSkip:
    """Exercise the ``components_to_skip`` option of OnnxBlockWiseRtnQuantization."""

    @staticmethod
    def _make_matmul_model(tmp_path, name: str) -> ONNXModelHandler:
        """Save a single-MatMul ONNX model under ``tmp_path / name`` and wrap it in a handler."""
        rows, cols = 64, 128
        weight_values = np.random.randn(rows, cols).astype(np.float32)

        matmul = onnx.helper.make_node("MatMul", ["input", "weight"], ["output"], name="MatMul_Node")
        graph = onnx.helper.make_graph(
            [matmul],
            "g",
            [onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, rows])],
            [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, cols])],
            initializer=[
                onnx.helper.make_tensor(
                    name="weight",
                    data_type=onnx.TensorProto.FLOAT,
                    dims=[rows, cols],
                    vals=weight_values.flatten().tolist(),
                )
            ],
        )
        model_proto = onnx.helper.make_model(graph, producer_name="test")
        model_proto.opset_import[0].version = 13

        target_dir = tmp_path / name
        target_dir.mkdir(parents=True, exist_ok=True)
        onnx.save(model_proto, str(target_dir / "model.onnx"))
        return ONNXModelHandler(model_path=str(target_dir), onnx_file_name="model.onnx")

    @staticmethod
    def _make_pass(components_to_skip=None) -> OnnxBlockWiseRtnQuantization:
        """Build a 4-bit RTN pass for CPU, adding ``components_to_skip`` only when given."""
        pass_config = {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}
        if components_to_skip is not None:
            pass_config["components_to_skip"] = components_to_skip
        cpu_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
        return create_pass_from_dict(
            OnnxBlockWiseRtnQuantization, pass_config, disable_search=True, accelerator_spec=cpu_spec
        )

    @staticmethod
    def _op_types(handler) -> set:
        """Return the set of op types found in the graph stored at ``handler.model_path``."""
        return {node.op_type for node in ir.load(handler.model_path).graph.all_nodes()}

    @staticmethod
    def _component(composite, wanted: str):
        """Return the first component of ``composite`` whose name equals ``wanted``."""
        return next(handler for name, handler in composite.get_model_components() if name == wanted)

    def test_components_to_skip_passes_component_through_unchanged(self, tmp_path):
        """A skipped component keeps its original graph while the others are quantized."""
        from olive.model.handler.composite import CompositeModelHandler

        source_root = tmp_path / "src"
        composite = CompositeModelHandler(
            model_components=[
                self._make_matmul_model(source_root, "decoder"),
                self._make_matmul_model(source_root, "embedding"),
            ],
            model_component_names=["decoder", "embedding"],
            model_path=str(source_root),
        )

        quant_pass = self._make_pass(components_to_skip=["embedding"])
        result = quant_pass.run(composite, str(tmp_path / "out"))

        assert isinstance(result, CompositeModelHandler)
        assert result.model_component_names == ["decoder", "embedding"]

        # The decoder was not skipped, so it must contain the quantized op.
        decoder_ops = self._op_types(self._component(result, "decoder"))
        assert str(OpType.MatMulNBits) in decoder_ops, "decoder should be quantized (MatMulNBits expected)"

        # The embedding was skipped, so the plain MatMul must survive untouched.
        embedding_ops = self._op_types(self._component(result, "embedding"))
        assert str(OpType.MatMul) in embedding_ops, "embedding should still contain the original MatMul op"
        assert str(OpType.MatMulNBits) not in embedding_ops, (
            "embedding should not be quantized (no MatMulNBits expected)"
        )

    def test_components_to_skip_none_quantizes_all(self, tmp_path):
        """Without ``components_to_skip`` every component of a composite is quantized."""
        from olive.model.handler.composite import CompositeModelHandler

        source_root = tmp_path / "src"
        composite = CompositeModelHandler(
            model_components=[
                self._make_matmul_model(source_root, "decoder"),
                self._make_matmul_model(source_root, "embedding"),
            ],
            model_component_names=["decoder", "embedding"],
            model_path=str(source_root),
        )

        quant_pass = self._make_pass(components_to_skip=None)
        result = quant_pass.run(composite, str(tmp_path / "out"))

        assert isinstance(result, CompositeModelHandler)

        for name, component in result.get_model_components():
            assert str(OpType.MatMulNBits) in self._op_types(component), (
                f"component '{name}' should be quantized when components_to_skip is None"
            )

    def test_components_to_skip_does_not_affect_single_model(self, tmp_path):
        """A plain (non-composite) model is quantized regardless of ``components_to_skip``."""
        single_model = self._make_matmul_model(tmp_path, "single")
        quant_pass = self._make_pass(components_to_skip=["single"])
        result = quant_pass.run(single_model, str(tmp_path / "out"))

        # The directory name matches the skip list, but skipping only applies to composites.
        assert str(OpType.MatMulNBits) in self._op_types(result), (
            "Single-component model should be quantized even when components_to_skip is set"
        )

    def test_components_to_skip_in_default_config(self):
        """``components_to_skip`` is an optional config entry defaulting to None."""
        cpu_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
        default_config = OnnxBlockWiseRtnQuantization._default_config(cpu_spec)  # pylint: disable=protected-access
        assert "components_to_skip" in default_config
        skip_param = default_config["components_to_skip"]
        assert skip_param.default_value is None
        assert skip_param.required is False

    def test_components_to_skip_unknown_name_warns(self, tmp_path):
        """A name in ``components_to_skip`` that matches no component is reported in the log."""
        import logging

        from olive.model.handler.composite import CompositeModelHandler

        source_root = tmp_path / "src"
        composite = CompositeModelHandler(
            model_components=[
                self._make_matmul_model(source_root, "decoder"),
                self._make_matmul_model(source_root, "vision"),
            ],
            model_component_names=["decoder", "vision"],
        )

        quant_pass = self._make_pass(components_to_skip=["typo_component"])

        captured = []

        class _Handler(logging.Handler):
            def emit(self, record):
                captured.append(record.getMessage())

        # Attach a handler directly to the pass module's logger and always detach it
        # afterwards so other tests are not affected.
        collector = _Handler()
        rtn_logger = logging.getLogger("olive.passes.onnx.rtn_quantization")
        rtn_logger.addHandler(collector)
        try:
            quant_pass.run(composite, str(tmp_path / "out"))
        finally:
            rtn_logger.removeHandler(collector)

        assert any("typo_component" in message for message in captured), (
            f"Expected warning about unknown component name 'typo_component', got: {captured}"
        )
Loading