diff --git a/csrc/models/backend_plugin_loader.cpp b/csrc/models/backend_plugin_loader.cpp new file mode 100644 index 000000000..3d8e772ce --- /dev/null +++ b/csrc/models/backend_plugin_loader.cpp @@ -0,0 +1,100 @@ +#include "backend_plugin_loader.hpp" + +#include +#include +#include +#include +#include +#include +#include + +namespace infinilm::models { +namespace { + +using PluginInitFn = void (*)(); + +std::mutex &loader_mutex() { + static std::mutex mutex; + return mutex; +} + +std::unordered_map &loaded_handles() { + static std::unordered_map handles; + return handles; +} + +std::string trim(std::string value) { + const auto begin = value.find_first_not_of(" \t\n\r"); + if (begin == std::string::npos) { + return ""; + } + const auto end = value.find_last_not_of(" \t\n\r"); + return value.substr(begin, end - begin + 1); +} + +std::vector split_plugins(const char *env_value) { + std::vector plugins; + if (env_value == nullptr || *env_value == '\0') { + return plugins; + } + + std::stringstream stream(env_value); + std::string item; + while (std::getline(stream, item, ',')) { + item = trim(item); + if (!item.empty()) { + plugins.push_back(item); + } + } + return plugins; +} + +} // namespace + +void load_backend_plugin(const std::string &plugin_path) { + const std::string path = trim(plugin_path); + if (path.empty()) { + return; + } + + std::lock_guard lock(loader_mutex()); + auto &handles = loaded_handles(); + if (handles.find(path) != handles.end()) { + return; + } + + void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle == nullptr) { + const char *error = dlerror(); + throw std::runtime_error( + "infinilm::models::load_backend_plugin: failed to load " + path + + ": " + (error == nullptr ? "unknown dlopen error" : std::string(error))); + } + + dlerror(); + auto init_fn = reinterpret_cast(dlsym(handle, "infinilm_backend_plugin_init")); + const char *dlsym_error = dlerror(); + if (dlsym_error == nullptr && init_fn != nullptr) { + init_fn(); + } + + handles[path] = handle; +} + +void load_backend_plugins_from_env() { + for (const auto &plugin : split_plugins(std::getenv("INFINILM_BACKEND_PLUGINS"))) { + load_backend_plugin(plugin); + } +} + +std::vector loaded_backend_plugins() { + std::lock_guard lock(loader_mutex()); + std::vector plugins; + plugins.reserve(loaded_handles().size()); + for (const auto &[path, _] : loaded_handles()) { + plugins.push_back(path); + } + return plugins; +} + +} // namespace infinilm::models diff --git a/csrc/models/backend_plugin_loader.hpp b/csrc/models/backend_plugin_loader.hpp new file mode 100644 index 000000000..26ad616ca --- /dev/null +++ b/csrc/models/backend_plugin_loader.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +namespace infinilm::models { + +/** + * Load one out-of-tree backend plugin shared object. + * + * The plugin may either rely on static initializers that call + * `register_causal_lm_model` / `register_model_config`, or export an optional + * `extern "C" void infinilm_backend_plugin_init()` function. Loading is + * idempotent for each path. + */ +void load_backend_plugin(const std::string &plugin_path); + +/** + * Load backend plugins from `INFINILM_BACKEND_PLUGINS`. + * + * The environment variable accepts comma-separated shared object paths. + */ +void load_backend_plugins_from_env(); + +/** + * Return plugin paths that have already been loaded. + */ +std::vector loaded_backend_plugins(); + +} // namespace infinilm::models diff --git a/csrc/models/infinilm_model.cpp b/csrc/models/infinilm_model.cpp index 8429fffba..a3a6863c0 100644 --- a/csrc/models/infinilm_model.cpp +++ b/csrc/models/infinilm_model.cpp @@ -18,6 +18,13 @@ void InfinilmModel::reset_cache(const cache::CacheConfig *cache_config) { kv_cache_vec = std::move(default_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend)); } +void InfinilmModel::load_parameters_no_sync( + const std::unordered_map ¶ms) { + for (const auto &[name, param] : params) { + load_parameter(name, param); + } +} + std::vector InfinilmModel::default_allocate_kv_cache_tensors( const cache::CacheConfig *cache_config, const std::shared_ptr &text_config, diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index c76a29f08..f5f92bfcb 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -7,6 +7,7 @@ #include "infinicore/tensor.hpp" #include +#include #include namespace infinilm { @@ -57,6 +58,8 @@ class InfinilmModel : public infinicore::nn::Module { return cache_config_.get(); } + void load_parameters_no_sync( + const std::unordered_map ¶ms); void process_weights_after_loading(); void reset_runtime_state() const; diff --git a/csrc/pybind11/bindings.cc b/csrc/pybind11/bindings.cc index 63846338b..e6e2d2d76 100644 --- a/csrc/pybind11/bindings.cc +++ b/csrc/pybind11/bindings.cc @@ -2,6 +2,7 @@ #include "cache/cache.hpp" #include "engine/engine.hpp" +#include "../models/backend_plugin_loader.hpp" namespace py = pybind11; @@ -12,4 +13,11 @@ PYBIND11_MODULE(_infinilm, m) { infinilm::engine::bind_hook_registry(m); infinilm::engine::distributed::bind_dist_config(m); infinilm::engine::bind_infer_engine(m); + + m.def("load_backend_plugin", &infinilm::models::load_backend_plugin, + "Load one InfiniLM C++ backend plugin shared object."); + m.def("load_backend_plugins_from_env", &infinilm::models::load_backend_plugins_from_env, + "Load InfiniLM C++ backend plugins from INFINILM_BACKEND_PLUGINS."); + m.def("loaded_backend_plugins", &infinilm::models::loaded_backend_plugins, + "Return paths of loaded InfiniLM C++ backend plugins."); } diff --git a/python/infinilm/__init__.py b/python/infinilm/__init__.py index f552a2cc9..544d6e074 100644 --- a/python/infinilm/__init__.py +++ b/python/infinilm/__init__.py @@ -1,17 +1,41 @@ -from .models import AutoLlamaModel -from . import distributed -from . import cache -from . import llm -from . import base_config - -from .llm import ( - LLM, - AsyncLLMEngine, - SamplingParams, - RequestOutput, - TokenOutput, +from importlib import import_module + +from .plugins import ( + ModelSpec, + load_plugin, + load_plugins, + register_model, + registered_model_types, ) + +_LAZY_ATTRS = { + "AutoLlamaModel": ("infinilm.models", "AutoLlamaModel"), + "LLM": ("infinilm.llm", "LLM"), + "AsyncLLMEngine": ("infinilm.llm", "AsyncLLMEngine"), + "SamplingParams": ("infinilm.llm", "SamplingParams"), + "RequestOutput": ("infinilm.llm", "RequestOutput"), + "TokenOutput": ("infinilm.llm", "TokenOutput"), +} + +_LAZY_MODULES = {"distributed", "cache", "llm", "base_config"} + + +def __getattr__(name): + if name in _LAZY_MODULES: + module = import_module(f".{name}", __name__) + globals()[name] = module + return module + + target = _LAZY_ATTRS.get(name) + if target is not None: + module_name, attr_name = target + value = getattr(import_module(module_name), attr_name) + globals()[name] = value + return value + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + __all__ = [ "AutoLlamaModel", "distributed", @@ -24,4 +48,10 @@ "SamplingParams", "RequestOutput", "TokenOutput", + # Out-of-tree model plugins + "ModelSpec", + "load_plugin", + "load_plugins", + "register_model", + "registered_model_types", ] diff --git a/python/infinilm/backend_plugins.py b/python/infinilm/backend_plugins.py new file mode 100644 index 000000000..5c0f8da62 --- /dev/null +++ b/python/infinilm/backend_plugins.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +import os +from collections.abc import Sequence + + +def _split_plugin_list(value: str | None) -> list[str]: + if not value: + return [] + return [item.strip() for item in value.split(",") if item.strip()] + + +def _backend_module(): + from infinilm.lib import _infinilm + + return _infinilm + + +def load_backend_plugin(plugin: str | os.PathLike[str]) -> None: + """Load one InfiniLM C++ backend plugin shared object.""" + + _backend_module().load_backend_plugin(os.fspath(plugin)) + + +def load_backend_plugins(plugins: Sequence[str | os.PathLike[str]] | str | None = None) -> tuple[str, ...]: + """Load explicitly requested InfiniLM C++ backend plugins.""" + + requested: list[str] = [] + if isinstance(plugins, (str, os.PathLike)): + requested.extend(_split_plugin_list(os.fspath(plugins))) + elif plugins: + requested.extend(os.fspath(plugin) for plugin in plugins) + + for plugin in requested: + load_backend_plugin(plugin) + return loaded_backend_plugins() + + +def load_backend_plugins_from_env() -> tuple[str, ...]: + """Load backend plugins from `INFINILM_BACKEND_PLUGINS`. + + This is an explicit compatibility helper for command-line or embedding + workflows. Core config/model factories do not read environment variables + implicitly. + """ + + return load_backend_plugins(os.environ.get("INFINILM_BACKEND_PLUGINS")) + + +def loaded_backend_plugins() -> tuple[str, ...]: + """Return paths of C++ backend plugins already loaded in this process.""" + + return tuple(_backend_module().loaded_backend_plugins()) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 17ee6c12f..71df05b59 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -3,9 +3,11 @@ import infinicore -from infinilm.cache import PagedKVCacheConfig +from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig +from infinilm.backend_plugins import load_backend_plugins from infinilm.distributed import DistConfig from infinilm.lib import _infinilm +from infinilm.plugins import adapt_config, load_plugins from .modeling_utils import parse_dtype from .exception_utils import handle_oom_and_exit @@ -67,10 +69,14 @@ def __init__( enable_graph_compiling=False, attention_backend="default", kv_cache_dtype=None, + backend_plugins=None, use_mla=False, ): - self.hf_config = read_hf_config(model_path) + load_plugins() + self.hf_config = adapt_config(read_hf_config(model_path)) self.hf_generation_config = read_hf_generation_config(model_path) + load_backend_plugins(self.hf_config.get("_infinilm_backend_plugins")) + load_backend_plugins(backend_plugins) if device is None: device = infinicore.device() diff --git a/python/infinilm/lib/__init__.py b/python/infinilm/lib/__init__.py index 67c9ce400..8d125f3ab 100644 --- a/python/infinilm/lib/__init__.py +++ b/python/infinilm/lib/__init__.py @@ -6,6 +6,9 @@ import os from pathlib import Path +# Register shared pybind11 types used by the InfiniLM extension. +import infinicore # noqa: F401 + # Ensure the directory containing this __init__.py is on sys.path # This allows importing the .so file from the same directory _lib_dir = Path(__file__).parent diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index ea2518708..c8ff9e28e 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -21,6 +21,20 @@ def _get_scale_emb(model_path: str) -> float: return config.get("scale_emb", 1.0) +def _load_adapted_hf_config(model_path: str) -> Dict: + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + raise FileNotFoundError(f"config.json not found at {config_path}") + + with open(config_path, "r") as f: + config = json.load(f) + + from infinilm.plugins import adapt_config, load_plugins + + load_plugins() + return adapt_config(config) + + def parse_dtype(dtype_str: str): if dtype_str == "float32": return infinicore.float32 @@ -135,6 +149,15 @@ def get_model_state_dict( load_state_dict(file_path, device=torch_device, dtype=torch_dtype) ) + hf_config = _load_adapted_hf_config(model_path) + model_type, backend_model_type = _weight_model_types(hf_config) + model_param = _apply_model_weight_remapping( + model_param, + hf_config, + model_type, + backend_model_type, + ) + # Apply scale_emb for fm9g models (embed_tokens uses lookup, not GEMM) scale_emb = _get_scale_emb(model_path) embed_tokens_unscaled = None @@ -173,7 +196,7 @@ def load_model_state_dict_by_file( print(" load weights ......") t1 = time.time() - model_type = model.hf_config.get("model_type", "") + model_type, backend_model_type = _weight_model_types(model.hf_config) torch_device = "cpu" torch_dtype = infinicore.utils.to_torch_dtype(dtype) @@ -195,10 +218,12 @@ def load_model_state_dict_by_file( file_path, device=torch_device, dtype=torch_dtype ) - # Apply model-specific weight remapping - remapper = _WEIGHT_REMAPPER.get(model_type) - if remapper is not None: - model_param = remapper(model_param, config=model.hf_config) + model_param = _apply_model_weight_remapping( + model_param, + model.hf_config, + model_type, + backend_model_type, + ) already_loaded_keys.extend(model_param.keys()) @@ -238,10 +263,12 @@ def load_model_state_dict_by_file( file_path = os.path.join(model_path, "pytorch_model.bin") model_params = torch.load(file_path, weights_only=True, map_location="cpu") - # Apply model-specific weight remapping - remapper = _WEIGHT_REMAPPER.get(model_type) - if remapper is not None: - model_params = remapper(model_params, config=model.hf_config) + model_params = _apply_model_weight_remapping( + model_params, + model.hf_config, + model_type, + backend_model_type, + ) # Scale embed_tokens on torch side before converting if "model.embed_tokens.weight" in model_params: @@ -285,6 +312,34 @@ def load_model_state_dict_by_file( print(f" load weights over! {(t2 - t1) * 1000} ms \n") +def _weight_model_types(config): + original = config.get("_infinilm_original_model_type") or config.get("model_type", "") + backend = config.get("_infinilm_backend_model_type") or config.get("model_type", "") + return str(original).lower(), str(backend).lower() + + +def _apply_model_weight_remapping( + state_dict, + config, + model_type, + backend_model_type, +): + from infinilm.plugins import apply_weight_remapping, get_model_spec, load_plugins + + load_plugins() + spec = get_model_spec(model_type) + state_dict = apply_weight_remapping(model_type, state_dict, config) + if spec is not None and not spec.use_builtin_weight_remapper: + return state_dict + + remapper = _WEIGHT_REMAPPER.get(model_type) or _WEIGHT_REMAPPER.get( + backend_model_type + ) + if remapper is not None: + state_dict = remapper(state_dict, config=config) + return state_dict + + def load_model_state_dict_by_tensor( model: infinicore.nn.Module, model_path: str, @@ -297,6 +352,7 @@ def load_model_state_dict_by_tensor( print(" load weights ......") t1 = time.time() + model_type, backend_model_type = _weight_model_types(model.hf_config) torch_dtype = infinicore.utils.to_torch_dtype(dtype) model_keys = model.state_dict_keyname() scale_emb = _get_scale_emb(model_path) @@ -308,23 +364,38 @@ def load_model_state_dict_by_tensor( for file_path in tqdm(file_list, desc="Processing files"): tqdm.write(f"Processing: {os.path.basename(file_path)}") - with safe_open(file_path, "pt", "cpu") as f: - for name in f.keys(): - tensor = f.get_tensor(name).to(dtype=torch_dtype) + model_param = load_state_dict(file_path, device="cpu", dtype=torch_dtype) + model_param = _apply_model_weight_remapping( + model_param, + model.hf_config, + model_type, + backend_model_type, + ) + + for name, tensor in model_param.items(): + tensor = tensor.to(dtype=torch_dtype) - if name == "model.embed_tokens.weight": - embed_tokens_torch_unscaled = tensor - if scale_emb != 1.0: - tensor = tensor * float(scale_emb) + if name == "model.embed_tokens.weight": + embed_tokens_torch_unscaled = tensor + if scale_emb != 1.0: + tensor = tensor * float(scale_emb) - weight_infini = infinicore.from_torch(tensor) - model.load_param(name, weight_infini) - already_loaded_keys.append(name) - infinicore.sync_stream() + weight_infini = infinicore.from_torch(tensor) + model.load_param(name, weight_infini) + already_loaded_keys.append(name) + infinicore.sync_stream() + + del model_param elif os.path.exists(os.path.join(model_path, "pytorch_model.bin")): file_path = os.path.join(model_path, "pytorch_model.bin") model_params = torch.load(file_path, weights_only=True, map_location="cpu") + model_params = _apply_model_weight_remapping( + model_params, + model.hf_config, + model_type, + backend_model_type, + ) for key in model_params.keys(): tensor = model_params[key].to(dtype=torch_dtype) diff --git a/python/infinilm/plugins/__init__.py b/python/infinilm/plugins/__init__.py new file mode 100644 index 000000000..fd521a67a --- /dev/null +++ b/python/infinilm/plugins/__init__.py @@ -0,0 +1,31 @@ +from .model_spec import ( + ModelSpec, + adapt_config, + apply_weight_remapping, + get_model_spec, + load_plugin, + load_plugins, + register_model, + registered_model_types, +) +from infinilm.backend_plugins import ( + load_backend_plugin, + load_backend_plugins, + load_backend_plugins_from_env, + loaded_backend_plugins, +) + +__all__ = [ + "ModelSpec", + "adapt_config", + "apply_weight_remapping", + "get_model_spec", + "load_backend_plugin", + "load_backend_plugins", + "load_backend_plugins_from_env", + "load_plugin", + "load_plugins", + "loaded_backend_plugins", + "register_model", + "registered_model_types", +] diff --git a/python/infinilm/plugins/model_spec.py b/python/infinilm/plugins/model_spec.py new file mode 100644 index 000000000..6d2211d60 --- /dev/null +++ b/python/infinilm/plugins/model_spec.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import importlib +import importlib.util +import inspect +import os +import sys +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass, field +from pathlib import Path +from types import ModuleType +from typing import Any + + +StateDict = dict[str, Any] +ConfigDict = dict[str, Any] +ConfigAdapter = Callable[[ConfigDict], ConfigDict] | Mapping[str, Any] +WeightRemapper = Callable[..., StateDict] + + +@dataclass(slots=True) +class ModelSpec: + """Out-of-tree model registration contract. + + `model_type`/`model_types` identify HuggingFace config model_type values. + `backend_model_type` is the InfiniLM C++ model implementation to reuse. + Python callbacks run only while loading config or weights, not in the + token-by-token inference hot path. + """ + + model_type: str | None = None + model_types: Sequence[str] | None = None + backend_model_type: str | None = None + config_adapter: ConfigAdapter | None = None + weight_remapper: WeightRemapper | None = None + weight_rules: Sequence[WeightRemapper] = field(default_factory=tuple) + processor_cls: type[Any] | None = None + processor: str | None = None + backend_plugin: str | os.PathLike[str] | None = None + backend_plugins: Sequence[str | os.PathLike[str]] | None = None + use_builtin_weight_remapper: bool = True + metadata: Mapping[str, Any] = field(default_factory=dict) + + def normalized_model_types(self) -> tuple[str, ...]: + names: list[str] = [] + if self.model_type: + names.append(self.model_type) + if self.model_types: + names.extend(self.model_types) + + normalized = tuple(dict.fromkeys(name.lower() for name in names if name)) + if not normalized: + raise ValueError("ModelSpec requires model_type or model_types.") + return normalized + + +_MODEL_SPECS: dict[str, ModelSpec] = {} +_LOADED_PLUGINS: dict[str, ModuleType] = {} + + +def register_model(spec: ModelSpec | None = None, **kwargs: Any) -> ModelSpec: + """Register a ModelSpec and return it. + + Examples: + register_model(ModelSpec(model_type="foo", backend_model_type="llama")) + register_model(model_type="foo", backend_model_type="llama") + """ + + if spec is None: + spec = ModelSpec(**kwargs) + elif kwargs: + raise TypeError("Pass either a ModelSpec or keyword arguments, not both.") + + for model_type in spec.normalized_model_types(): + previous = _MODEL_SPECS.get(model_type) + if previous is not None and previous is not spec: + raise ValueError(f"Duplicate ModelSpec registration for {model_type!r}.") + _MODEL_SPECS[model_type] = spec + return spec + + +def get_model_spec(model_type: str | None) -> ModelSpec | None: + if not model_type: + return None + return _MODEL_SPECS.get(model_type.lower()) + + +def registered_model_types() -> tuple[str, ...]: + return tuple(sorted(_MODEL_SPECS)) + + +def _split_plugin_list(value: str | None) -> list[str]: + if not value: + return [] + parts: list[str] = [] + for chunk in value.split(","): + chunk = chunk.strip() + if chunk: + parts.append(chunk) + return parts + + +def load_plugins(plugins: Sequence[str] | str | None = None) -> tuple[str, ...]: + """Load plugin modules from INFINILM_PLUGINS and/or explicit names. + + INFINILM_PLUGINS accepts comma-separated Python module names or .py paths. + Loading is idempotent for a process. + """ + + requested: list[str] = _split_plugin_list(os.environ.get("INFINILM_PLUGINS")) + if isinstance(plugins, str): + requested.extend(_split_plugin_list(plugins)) + elif plugins: + requested.extend(str(plugin) for plugin in plugins) + + for plugin in requested: + load_plugin(plugin) + return tuple(_LOADED_PLUGINS) + + +def load_plugin(plugin: str | os.PathLike[str]) -> ModuleType: + plugin_name = os.fspath(plugin) + if plugin_name in _LOADED_PLUGINS: + return _LOADED_PLUGINS[plugin_name] + + if plugin_name.endswith(".py") or Path(plugin_name).expanduser().exists(): + module = _load_plugin_file(plugin_name) + else: + try: + module = importlib.import_module(plugin_name) + except ModuleNotFoundError as exc: + candidate = Path(*plugin_name.split(".")).with_suffix(".py") + if exc.name == plugin_name.split(".")[0] and candidate.exists(): + module = _load_plugin_file(str(candidate)) + else: + raise + + _LOADED_PLUGINS[plugin_name] = module + return module + + +def _load_plugin_file(plugin_path: str) -> ModuleType: + path = Path(plugin_path).expanduser().resolve() + module_name = f"_infinilm_plugin_{abs(hash(str(path)))}" + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise ImportError(f"Unable to import InfiniLM plugin from {path}.") + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def adapt_config(config: Mapping[str, Any]) -> ConfigDict: + """Return an InfiniLM-ready config for a registered model_type.""" + + config_dict = dict(config) + original_model_type = config_dict.get("model_type") + spec = get_model_spec(str(original_model_type) if original_model_type else None) + if spec is None: + return config_dict + + adapted = _apply_config_adapter(config_dict, spec.config_adapter) + adapted["_infinilm_original_model_type"] = original_model_type + backend_plugins = _backend_plugins_for_spec(spec) + if backend_plugins: + adapted["_infinilm_backend_plugins"] = backend_plugins + if spec.backend_model_type: + backend_model_type = spec.backend_model_type.lower() + adapted["_infinilm_backend_model_type"] = backend_model_type + adapted["model_type"] = backend_model_type + return adapted + + +def _backend_plugins_for_spec(spec: ModelSpec) -> list[str]: + requested: list[str | os.PathLike[str]] = [] + if spec.backend_plugin: + requested.append(spec.backend_plugin) + if spec.backend_plugins: + requested.extend(spec.backend_plugins) + return [os.fspath(plugin) for plugin in requested] + + +def _apply_config_adapter( + config: ConfigDict, + adapter: ConfigAdapter | None, +) -> ConfigDict: + if adapter is None: + return dict(config) + + if callable(adapter): + adapted = adapter(dict(config)) + if adapted is None: + raise ValueError("ModelSpec.config_adapter must return a config dict.") + return dict(adapted) + + adapted = dict(config) + for key, value in adapter.items(): + if callable(value): + adapted[key] = value(config) + elif isinstance(value, str) and value.startswith("$"): + adapted[key] = config[value[1:]] + else: + adapted[key] = value + return adapted + + +def apply_weight_remapping( + model_type: str | None, + state_dict: StateDict, + config: Mapping[str, Any] | None = None, +) -> StateDict: + spec = get_model_spec(model_type) + if spec is None: + return state_dict + + result = state_dict + if spec.weight_remapper is not None: + result = _call_weight_remapper(spec.weight_remapper, result, config) + for rule in spec.weight_rules: + result = _call_weight_remapper(rule, result, config) + return result + + +def _call_weight_remapper( + remapper: WeightRemapper, + state_dict: StateDict, + config: Mapping[str, Any] | None, +) -> StateDict: + try: + signature = inspect.signature(remapper) + except (TypeError, ValueError): + return remapper(state_dict, config=config) + + params = signature.parameters.values() + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in params): + return remapper(state_dict, config=config) + + params = signature.parameters.values() + if any(param.name == "config" for param in params): + return remapper(state_dict, config=config) + + required_positionals = [ + param + for param in signature.parameters.values() + if param.default is inspect.Parameter.empty + and param.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ] + if len(required_positionals) >= 2: + return remapper(state_dict, config) + return remapper(state_dict) diff --git a/python/infinilm/processors/__init__.py b/python/infinilm/processors/__init__.py index 67f97acac..056135892 100644 --- a/python/infinilm/processors/__init__.py +++ b/python/infinilm/processors/__init__.py @@ -1,7 +1,9 @@ import importlib +import json import pkgutil from pathlib import Path from transformers import AutoConfig +from infinilm.plugins import get_model_spec, load_plugins from .processor import get_processor_class, InfinilmProcessor # --------------------------------------------------------------------------- @@ -13,10 +15,14 @@ # without requiring manual imports for each new model. # --------------------------------------------------------------------------- _current_dir = Path(__file__).resolve().parent +_PROCESSOR_IMPORT_ERRORS = {} for _module_info in pkgutil.iter_modules([str(_current_dir)]): if _module_info.name.endswith("_processor"): - importlib.import_module(f".{_module_info.name}", __package__) + try: + importlib.import_module(f".{_module_info.name}", __package__) + except Exception as exc: + _PROCESSOR_IMPORT_ERRORS[_module_info.name] = exc class AutoInfinilmProcessor: @@ -30,8 +36,48 @@ def from_pretrained(cls, model_dir_path: str, **kwargs) -> InfinilmProcessor: registered Processor. Falls back to the registered default processor for unregistered or standard architectures. """ - config = AutoConfig.from_pretrained(model_dir_path, trust_remote_code=True) - model_type = config.model_type.lower() + load_plugins() + model_type = _read_model_type(model_dir_path) - processor_cls = get_processor_class(model_type) + spec = get_model_spec(model_type) + if spec is not None and spec.processor_cls is not None: + processor_cls = spec.processor_cls + else: + processor_type = model_type + if spec is not None: + processor_type = ( + spec.processor + or spec.backend_model_type + or processor_type + ) + _raise_processor_import_error_if_requested(processor_type) + processor_cls = get_processor_class(processor_type.lower()) return processor_cls(model_dir_path) + + +def _raise_processor_import_error_if_requested(processor_type: str) -> None: + normalized = processor_type.lower().replace("-", "_") + module_names = { + normalized, + f"{normalized}_processor", + } + for module_name in module_names: + exc = _PROCESSOR_IMPORT_ERRORS.get(module_name) + if exc is not None: + raise ImportError( + f"Failed to import processor module '{module_name}' while " + f"resolving processor '{processor_type}'." + ) from exc + + +def _read_model_type(model_dir_path: str) -> str: + config_path = Path(model_dir_path) / "config.json" + if config_path.exists(): + with config_path.open("r") as config_file: + config = json.load(config_file) + model_type = config.get("model_type") + if model_type: + return model_type.lower() + + config = AutoConfig.from_pretrained(model_dir_path, trust_remote_code=True) + return config.model_type.lower() diff --git a/setup.py b/setup.py index ac73b956a..da95e7a4e 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ import subprocess from pathlib import Path -from setuptools import setup +from setuptools import find_packages, setup from setuptools.command.build import build from setuptools.command.develop import develop from setuptools.command.egg_info import egg_info @@ -37,7 +37,7 @@ def run(self): version="0.1.0", description="InfiniLM model implementations", package_dir={"": "python"}, - packages=["infinilm", "infinilm.models", "infinilm.lib", "infinilm.distributed"], + packages=find_packages(where="python"), cmdclass={ "build": Build, "develop": Develop,