diff --git a/olive/olive_config.json b/olive/olive_config.json index 50e1f36d6..8e6ae9274 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -520,6 +520,15 @@ "supported_quantization_encodings": [ ], "extra_dependencies": [ "qairt-dev" ] }, + "QairtPipelinePass": { + "module_path": "olive.passes.qairt.pipeline.QairtPipelinePass", + "supported_providers": [ "QNNExecutionProvider" ], + "supported_accelerators": [ "npu" ], + "supported_precisions": [ "*" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "extra_dependencies": [ "qairt-dev" ] + }, "QairtPreparation": { "module_path": "olive.passes.qairt.preparation.QairtPreparation", "supported_providers": [ "QNNExecutionProvider" ], diff --git a/olive/passes/qairt/pipeline.py b/olive/passes/qairt/pipeline.py new file mode 100644 index 000000000..ca630774e --- /dev/null +++ b/olive/passes/qairt/pipeline.py @@ -0,0 +1,157 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: MIT +# -------------------------------------------------------------------------- + +import logging +import shutil +from pathlib import Path + +from olive.common.config_utils import ParamCategory +from olive.hardware.accelerator import AcceleratorSpec +from olive.model import HfModelHandler, QairtModelHandler +from olive.passes import Pass +from olive.passes.pass_config import BasePassConfig, PassConfigParam +from olive.passes.qairt.utils import QairtLogLevel + +logger = logging.getLogger(__name__) + + +class QairtPipelinePass(Pass): + """Run a QairtPipeline from a YAML recipe on a HuggingFace model. + + Executes the full LLMPipeline workflow (model loading, quantization, compilation) + defined by the recipe and exports the result as a QairtModelHandler. This pass + is intended to replace the QairtPreparation -> QairtGenAIBuilder workflow. + + The input HfModelHandler is the authoritative source for the model identity. + If the recipe also specifies model_id_or_path and it differs from the handler's + path, an error is raised. If the recipe omits model_id_or_path, the handler's + path is used. + """ + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + "recipe": PassConfigParam( + type_=str, + required=True, + category=ParamCategory.PATH, + description="Path to the YAML recipe file that defines the LLM pipeline stages " + "(model loading, quantization, genai_builder, etc.).", + ), + "cache_dir": PassConfigParam( + type_=str, + required=False, + default_value=None, + description="Directory for pipeline intermediate artifacts. " + "Overrides the recipe's cache_dir field when set.", + ), + "log_level": PassConfigParam( + type_=QairtLogLevel, + required=False, + default_value=None, + description="Log level for underlying QAIRT pipeline components. " + "Valid values: DEBUG, INFO, WARNING, ERROR, TRACE. " + "Overrides the recipe's log_level field when set.", + ), + } + + @classmethod + def validate_config( + cls, + config: type[BasePassConfig], + accelerator_spec: AcceleratorSpec, + ) -> bool: + # Only validates the top-level qairt import. The qairt.experimental.pipeline.* + # sub-modules are not checked here; if they are absent (e.g. older SDK), the + # error surfaces in _run_for_config instead. + try: + import qairt # noqa: F401 # pylint: disable=unused-import + except ImportError as exc: + raise ImportError( + "Failed to import QAIRT SDK - please install olive-ai[qairt] to use QAIRT passes. " + "If already installed, please run `qairt-vm -i` for help troubleshooting issues." + ) from exc + + return True + + def _run_for_config( + self, + model: HfModelHandler, + config: type[BasePassConfig], + output_model_path: str, + ) -> QairtModelHandler: + try: + import qairt # noqa: F401 # pylint: disable=unused-import + from qairt.experimental.pipeline.torch.common.recipe import Recipe + from qairt.experimental.pipeline.torch.llm.pipeline import LLMPipeline + except ImportError as exc: + raise ImportError( + "Failed to import QAIRT Pipeline API - please install olive-ai[qairt] to use QAIRT passes. " + "If already installed, please run `qairt-vm -i` for help troubleshooting issues." + ) from exc + + if not isinstance(model, HfModelHandler): + raise ValueError(f"QairtPipelinePass requires HfModelHandler as input, got {type(model).__name__}") + + recipe_path = Path(config.recipe).resolve() + if not recipe_path.exists(): + raise ValueError(f"Recipe file not found at: {recipe_path}") + + recipe_data = dict(Recipe.from_file(recipe_path)) + + recipe_model_id = recipe_data.get("model_id_or_path") + if recipe_model_id and recipe_model_id != model.model_path: + raise ValueError( + f"Conflict between recipe model_id_or_path '{recipe_model_id}' and input model " + f"path '{model.model_path}'. Remove model_id_or_path from the recipe or ensure " + "it matches the input model path." + ) + + if config.cache_dir is not None: + recipe_data["cache_dir"] = config.cache_dir + if config.log_level is not None: + recipe_data["log_level"] = config.log_level + + pipe = LLMPipeline.from_pretrained(model.model_path, recipe=recipe_data) + pipe.construct() + + Path(output_model_path).mkdir(parents=True, exist_ok=True) + pipe.export(output_model_path) + + # QairtEncapsulation needs config.json and generation_config.json to generate + # genai_config.json. Resolve the local HF cache path (model.model_path may be a + # HuggingFace repo ID rather than a local directory) and copy if not already present. + try: + from huggingface_hub import snapshot_download + + local_model_path = snapshot_download( + model.model_path, + local_files_only=True, + ignore_patterns=["*.pt", "*.bin", "*.safetensors"], + ) + except Exception as e: + logger.warning( + "Failed to resolve local HF cache for '%s': %s. File copy will be skipped.", + model.model_path, + e, + ) + local_model_path = model.model_path + + for fname in ("config.json", "generation_config.json"): + src = Path(local_model_path) / fname + dst = Path(output_model_path) / fname + if src.exists() and not dst.exists(): + shutil.copy2(src, dst) + + # The pipeline exports chat_template files into a chat_template/ subdirectory. + # QairtEncapsulation expects these as flat files in the model root. + chat_template_dir = Path(output_model_path) / "chat_template" + for fname in ("chat_template.jinja", "tokenizer_config.json"): + src = chat_template_dir / fname + dst = Path(output_model_path) / fname + if src.exists() and not dst.exists(): + shutil.copy2(src, dst) + + return QairtModelHandler(model_path=output_model_path) diff --git a/test/passes/qairt/test_pipeline_pass.py b/test/passes/qairt/test_pipeline_pass.py new file mode 100644 index 000000000..09ab4c096 --- /dev/null +++ b/test/passes/qairt/test_pipeline_pass.py @@ -0,0 +1,249 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: MIT +# -------------------------------------------------------------------------- +# pylint: disable=protected-access + +import builtins +from unittest.mock import MagicMock, patch + +import pytest + +from olive.model import QairtModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.qairt.pipeline import QairtPipelinePass + + +@pytest.fixture(name="mock_pipeline_modules") +def mock_pipeline_modules_fixture(): + """Mock qairt and the LLMPipeline API.""" + mock_qairt = MagicMock() + mock_recipe_cls = MagicMock() + mock_pipeline_cls = MagicMock() + + mock_pipeline = MagicMock() + mock_pipeline_cls.from_pretrained.return_value = mock_pipeline + + with ( + patch.dict("sys.modules", {"qairt": mock_qairt}), + patch( + "qairt.experimental.pipeline.torch.common.recipe.Recipe", + mock_recipe_cls, + create=True, + ), + patch( + "qairt.experimental.pipeline.torch.llm.pipeline.LLMPipeline", + mock_pipeline_cls, + create=True, + ), + patch.dict( + "sys.modules", + { + "qairt.experimental.pipeline.torch.common.recipe": MagicMock(Recipe=mock_recipe_cls), + "qairt.experimental.pipeline.torch.llm.pipeline": MagicMock(LLMPipeline=mock_pipeline_cls), + }, + ), + ): + yield { + "qairt": mock_qairt, + "Recipe": mock_recipe_cls, + "LLMPipeline": mock_pipeline_cls, + "pipeline": mock_pipeline, + } + + +@pytest.fixture(name="recipe_file") +def recipe_file_fixture(tmp_path): + # Content is irrelevant — Recipe.from_file is mocked in every test that uses this fixture. + # The file must exist so that the recipe_path.exists() guard in _run_for_config passes. + path = tmp_path / "recipe.yaml" + path.write_text("") + return path + + +@pytest.fixture(name="recipe_file_with_model_id") +def recipe_file_with_model_id_fixture(tmp_path): + # Content is irrelevant — Recipe.from_file is mocked in every test that uses this fixture. + # The file must exist so that the recipe_path.exists() guard in _run_for_config passes. + path = tmp_path / "recipe_with_model.yaml" + path.write_text("") + return path + + +def test_pipeline_pass_default_config(mock_accelerator_spec): + """Test that the default config has the expected parameters.""" + config = QairtPipelinePass._default_config(mock_accelerator_spec) + + assert "recipe" in config + assert config["recipe"].required is True + assert "cache_dir" in config + assert config["cache_dir"].default_value is None + assert "log_level" in config + assert config["log_level"].default_value is None + + +def test_pipeline_pass_success(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules): + """Test successful pass execution with no model_id_or_path in recipe.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "cache_dir": "./pipeline_cache", + "backend": "HTP", + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + result = pipeline_pass.run(mock_hf_model, str(output_path)) + + assert isinstance(result, QairtModelHandler) + assert result.model_path == str(output_path) + mock_pipeline_modules["LLMPipeline"].from_pretrained.assert_called_once_with( + mock_hf_model.model_path, + recipe={"cache_dir": "./pipeline_cache", "backend": "HTP", "stages": {}}, + ) + mock_pipeline_modules["pipeline"].construct.assert_called_once() + mock_pipeline_modules["pipeline"].export.assert_called_once_with(str(output_path)) + + +def test_pipeline_pass_recipe_model_id_matches_handler( + tmp_path, mock_hf_model, recipe_file_with_model_id, mock_pipeline_modules +): + """Test that no error is raised when recipe model_id_or_path matches the handler path.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "model_id_or_path": mock_hf_model.model_path, + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file_with_model_id)}, + disable_search=True, + ) + + result = pipeline_pass.run(mock_hf_model, str(output_path)) + assert isinstance(result, QairtModelHandler) + + +def test_pipeline_pass_recipe_model_id_conflict_raises( + tmp_path, mock_hf_model, recipe_file_with_model_id, mock_pipeline_modules +): + """Test that a ValueError is raised when recipe model_id_or_path conflicts with handler path.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "model_id_or_path": "meta-llama/Llama-3.2-3B-Instruct", + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file_with_model_id)}, + disable_search=True, + ) + + with pytest.raises(ValueError, match="Conflict between recipe model_id_or_path"): + pipeline_pass.run(mock_hf_model, str(output_path)) + + +def test_pipeline_pass_cache_dir_override(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules): + """Test that Olive-level cache_dir overrides the recipe's cache_dir.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "cache_dir": "./recipe_cache", + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file), "cache_dir": "/custom/cache"}, + disable_search=True, + ) + + pipeline_pass.run(mock_hf_model, str(output_path)) + + call_kwargs = mock_pipeline_modules["LLMPipeline"].from_pretrained.call_args + recipe_arg = call_kwargs.kwargs["recipe"] + assert recipe_arg["cache_dir"] == "/custom/cache" + + +def test_pipeline_pass_log_level_override(tmp_path, mock_hf_model, recipe_file, mock_pipeline_modules): + """Test that Olive-level log_level overrides the recipe's log_level.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = { + "log_level": "warn", + "stages": {}, + } + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file), "log_level": "DEBUG"}, + disable_search=True, + ) + + pipeline_pass.run(mock_hf_model, str(output_path)) + + call_kwargs = mock_pipeline_modules["LLMPipeline"].from_pretrained.call_args + recipe_arg = call_kwargs.kwargs["recipe"] + assert recipe_arg["log_level"] == "DEBUG" + + +def test_pipeline_pass_invalid_input_model(tmp_path, mock_qairt_prepared_model, recipe_file, mock_pipeline_modules): + """Test that ValueError is raised when input is not HfModelHandler.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {}} + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + with pytest.raises(ValueError, match="QairtPipelinePass requires HfModelHandler"): + pipeline_pass.run(mock_qairt_prepared_model, str(output_path)) + + +def test_pipeline_pass_missing_recipe_file(tmp_path, mock_hf_model, mock_pipeline_modules): + """Test that ValueError is raised when recipe file does not exist.""" + output_path = tmp_path / "output" + + mock_pipeline_modules["Recipe"].from_file.return_value = {"stages": {}} + + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(tmp_path / "nonexistent_recipe.yaml")}, + disable_search=True, + ) + + with pytest.raises(ValueError, match="Recipe file not found"): + pipeline_pass.run(mock_hf_model, str(output_path)) + + +def test_pipeline_pass_import_error(tmp_path, mock_hf_model, recipe_file): + """Test that ImportError is raised if qairt cannot be imported.""" + + def import_side_effect(name, *args, **kwargs): + if "qairt" in name: + raise ImportError("Mock import error") + return original_import(name, *args, **kwargs) + + original_import = builtins.__import__ + + with patch("builtins.__import__", side_effect=import_side_effect): + pipeline_pass = create_pass_from_dict( + QairtPipelinePass, + {"recipe": str(recipe_file)}, + disable_search=True, + ) + + with pytest.raises(ImportError, match="Failed to import QAIRT Pipeline API"): + pipeline_pass.run(mock_hf_model, str(tmp_path / "output"))