Skip to content
49 changes: 49 additions & 0 deletions fast_llm/engine/checkpoint/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -1144,6 +1144,55 @@ def _emit(
return out


class RepeatWeightConverter(WeightConverter):
"""Repeat a sub-section converter a config-driven number of times, with index-templated prefixes.

Unlike :class:`BlockSequenceWeightConverter` — which fans a converter over a materialized per-position
config list — every iteration recurses into the *same* sub-config; only the emitted prefixes change
with the index. Used where a runtime count, not a block list, drives the repeat.

``count`` and ``sub_config`` are resolved from the live section config. ``fast_llm_prefix`` and
``hf_prefix`` map a 0-based iteration index to the section-relative prefixes; the two need not share
the same index arithmetic — e.g. an HF-side stack whose element 0 is declared elsewhere is reached by
offsetting the index here.
"""

def __init__(
self,
sub_converter_class: type["ConfigSectionConverter"],
*,
count: typing.Callable[[Config], int],
sub_config: typing.Callable[[Config], Config],
fast_llm_prefix: typing.Callable[[int], str],
hf_prefix: typing.Callable[[int], str],
):
super().__init__((), ())
self._sub_converter_class = sub_converter_class
self._count = count
self._sub_config = sub_config
self._fast_llm_prefix = fast_llm_prefix
self._hf_prefix = hf_prefix

def _emit(
self,
config: Config,
fast_llm_prefix: str,
hf_prefix: str,
*,
root_config: Config,
) -> list[WeightConverter]:
sub_config = self._sub_config(config)
out: list[WeightConverter] = []
for index in range(self._count(config)):
out += self._sub_converter_class.emit_weight_converters(
sub_config,
join_prefix(fast_llm_prefix, self._fast_llm_prefix(index)),
join_prefix(hf_prefix, self._hf_prefix(index)),
root_config=root_config,
)
return out


class DispatchWeightConverter(WeightConverter):
"""Dispatch a single sub-section converter based on the live config's runtime type.

Expand Down
43 changes: 24 additions & 19 deletions fast_llm/engine/checkpoint/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import functools
import json
import pathlib
import shutil
Expand Down Expand Up @@ -121,29 +122,20 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
},
)

# Top-level HF metadata keys that are always permitted, regardless of the converter tree.
# Covers transformers' generic ``PretrainedConfig`` fields (always present after ``to_dict()``)
# plus a handful of widely-shared metadata that Fast-LLM intentionally does not store.
# HF metadata keys that are always permitted, regardless of the converter tree. The generic
# ``PretrainedConfig`` fields are added dynamically (see :meth:`_hf_metadata_allowlist`) because the
# exact set drifts across the supported transformers range — e.g. the generation kwargs and
# ``torchscript`` that v4 dumps into ``to_dict()`` were moved out to ``GenerationConfig`` in v5. This
# static set covers the widely-shared metadata that Fast-LLM intentionally does not store but that a
# bare ``PretrainedConfig`` does not carry (model-specific defaults like ``max_position_embeddings``).
_HF_METADATA_ALLOWLIST: typing.ClassVar[frozenset[str]] = frozenset(
{
# transformers PretrainedConfig
"_name_or_path",
"architectures",
# transformers metadata Fast-LLM does not store that a bare ``PretrainedConfig().to_dict()``
# omits across the supported range (so the dynamic union would miss them).
"auto_map",
"chunk_size_feed_forward",
"dtype",
"id2label",
"is_encoder_decoder",
"label2id",
"model_type",
"output_attentions",
"output_hidden_states",
"problem_type",
"return_dict",
"torch_dtype",
"transformers_version",
"use_cache",
# Token ids — generation/inference, not architecture.
# Token ids — generation/inference, not architecture (a bare v5 config omits these).
"bos_token_id",
"decoder_start_token_id",
"eos_token_id",
Expand All @@ -156,14 +148,27 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
}
)

@classmethod
@functools.cache
def _hf_metadata_allowlist(cls) -> frozenset[str]:
"""Static allowlist unioned with the live ``PretrainedConfig`` field set.

Every key a bare ``PretrainedConfig`` carries is generic transformers metadata, never
architecture, so deriving them from the installed transformers keeps the coverage check correct
across the supported version range instead of hard-coding a version-specific set.
"""
import transformers

return cls._HF_METADATA_ALLOWLIST | frozenset(transformers.PretrainedConfig().to_dict())

@classmethod
def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None:
"""Run the HF-side coverage check at the import boundary.

Subclasses that override :meth:`_import_config` should call this explicitly to keep the check
active.
"""
cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST)
cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._hf_metadata_allowlist())

@classmethod
def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
Expand Down
18 changes: 12 additions & 6 deletions fast_llm/models/gpt/conversion/gemma4.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,15 +701,21 @@ def _head_import(hf_dict: dict) -> dict:
"vocab_size_per_layer_input": IgnoredConfigConverter(hf_paths=(("vocab_size_per_layer_input",),)),
}

@classmethod
@functools.cache
def _create_weight_converters(cls) -> dict[str, WeightConverter]:
# ``head`` is added at the aggregator level (in :meth:`get_converters`) because the head
# converter takes the full base-model config so subclasses extending the head can read
# sibling sections.
return {
"embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class),
"decoder": NestedWeightConverter("decoder", "model.layers", cls.decoder_converter_class),
}

@classmethod
def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]:
return [
*cls.embeddings_converter_class.emit_weight_converters(
config.embeddings, "embeddings", "model", root_config=config
),
*cls.decoder_converter_class.emit_weight_converters(
config.decoder, "decoder", "model.layers", root_config=config
),
*cls.emit_weight_converters(config, "", ""),
*cls.head_converter_class.get_converters(config),
]

Expand Down
54 changes: 29 additions & 25 deletions fast_llm/models/gpt/conversion/mtp_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
NestedWeightConverter,
OutputProjectionWeightConverter,
RenameConfigConverter,
RepeatWeightConverter,
WeightConverter,
)
from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig
from fast_llm.models.gpt.config import GPTModelConfig
from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat
from fast_llm.models.gpt.conversion.llama import (
LlamaBaseModelConverter,
Expand All @@ -36,41 +37,44 @@ def _create_config_converters(cls) -> dict:
@classmethod
@functools.cache
def _create_weight_converters(cls) -> dict[str, WeightConverter]:
# MTP-Llama places the first prediction head's final norm under ``model.mtp_norms.0`` instead
# of the standard ``model.norm``; the additional MTP blocks/norms come from the imperative
# ``get_converters`` override below since their count depends on ``head.prediction_heads``.
# MTP-Llama places the first prediction head's final norm under ``model.mtp_norms.0`` instead of
# the standard ``model.norm``. The additional per-prediction-distance blocks and norms are
# declared on the base-model converter (they live at the model root, not under ``head``).
return {
"final_norm": NestedWeightConverter(
"final_norm", "model.mtp_norms.0", cls.normalization_converter_class, config_attr="normalization"
),
"output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"),
}

@classmethod
def get_converters(
cls,
config: GPTBaseModelConfig,
) -> list[WeightConverter]:
converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config))
for prediction_distance in range(2, config.head.prediction_heads + 1):
converters += cls.block_converter_class.emit_weight_converters(
config.decoder.last_block_config,
f"multi_token_prediction.blocks.{prediction_distance - 2}",
f"model.mtp_heads.{prediction_distance - 2}",
root_config=config,
)
converters += cls.normalization_converter_class.emit_weight_converters(
config.head.normalization,
f"multi_token_prediction.heads.{prediction_distance - 2}.final_norm",
f"model.mtp_norms.{prediction_distance - 1}",
root_config=config,
)
return converters


class MTPLlamaBaseModelConverter(LlamaBaseModelConverter):
head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter

@classmethod
@functools.cache
def _create_weight_converters(cls) -> dict[str, WeightConverter]:
# The extra prediction-distance heads (distances 2..prediction_heads) repeat the main decoder's
# last block and the head normalization. They sit at the model root as ``multi_token_prediction.*``,
# interleaved on the HF side with the base head's ``model.mtp_norms.0`` (declared on the head).
return {
**super()._create_weight_converters(),
"multi_token_prediction_blocks": RepeatWeightConverter(
cls.block_converter_class,
count=lambda config: config.head.prediction_heads - 1,
sub_config=lambda config: config.decoder.last_block_config,
fast_llm_prefix=lambda index: f"multi_token_prediction.blocks.{index}",
hf_prefix=lambda index: f"model.mtp_heads.{index}",
),
"multi_token_prediction_norms": RepeatWeightConverter(
cls.head_converter_class.normalization_converter_class,
count=lambda config: config.head.prediction_heads - 1,
sub_config=lambda config: config.head.normalization,
fast_llm_prefix=lambda index: f"multi_token_prediction.heads.{index}.final_norm",
hf_prefix=lambda index: f"model.mtp_norms.{index + 1}",
),
}


class MTPLlamaHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler):
format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaCheckpointFormat
Expand Down
22 changes: 19 additions & 3 deletions tests/models/test_hf_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import torch
from transformers import (
AutoConfig,
Gemma4ForCausalLM,
Gemma4TextConfig,
LlamaConfig,
LlamaForCausalLM,
MistralConfig,
Expand Down Expand Up @@ -54,6 +52,12 @@
from fast_llm_external_models.mtp_llama.configuration_mtp_llama import MTPLlamaConfig
from fast_llm_external_models.mtp_llama.modeling_mtp_llama import MTPLlamaForCausalLM

try:
# Available only in transformers builds that ship Gemma 4.
from transformers import Gemma4ForCausalLM, Gemma4TextConfig
except ImportError:
Gemma4ForCausalLM = Gemma4TextConfig = None


@dataclasses.dataclass(frozen=True)
class HFRoundtripCase:
Expand Down Expand Up @@ -247,7 +251,19 @@ def make_model(self) -> PreTrainedModel:
]


@pytest.mark.parametrize("case", [pytest.param(c, id=c.name) for c in _HF_ROUNDTRIP_CASES])
@pytest.mark.parametrize(
"case",
[
pytest.param(
case,
id=case.name,
marks=pytest.mark.skipif(
case.model_class is None, reason="transformers build does not provide this model class"
),
)
for case in _HF_ROUNDTRIP_CASES
],
)
def test_hf_roundtrip(case: HFRoundtripCase, result_path: pathlib.Path):
"""HF model survives HF → Fast-LLM → HF with identical config and weights."""
base = result_path / "hf_roundtrip" / case.name
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,20 @@ def model_testing_config(request) -> ModelTestingConfig:
return MODEL_CONFIGS[request.param]


@functools.cache
def _hf_config_class_available(checkpoint_format: type[CheckpointFormat]) -> bool:
"""Whether the installed transformers provides the format's HF config class.

Conversion tests need it; older-but-supported transformers builds may lack a recent model
(e.g. no Gemma 4 before transformers v5), in which case the convert group skips rather than errors.
"""
try:
checkpoint_format.get_handler_class().get_transformers_configuration_class()
except (ImportError, AttributeError):
return False
return True


def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slow: bool, show_skipped: bool) -> bool:
if "model_testing_group" in item.keywords:
assert hasattr(item, "callspec") and "model_testing_config" in item.callspec.params, item.nodeid
Expand All @@ -1120,6 +1134,15 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo
if model_config.requires_cuda and not torch.cuda.is_available():
item.add_marker(pytest.mark.skip(reason=f"Cuda not available."))
for group in groups:
if (
group == ModelTestingGroup.convert
and model_config.checkpoint_format is not None
and not _hf_config_class_available(model_config.checkpoint_format)
):
item.add_marker(
pytest.mark.skip(reason=f"transformers build lacks the HF config class for {model_testing_config}")
)
continue
action = model_config.groups.get(group, ModelTestingGroupAction.unimportant)
if action == ModelTestingGroupAction.main:
pass
Expand Down
Loading