
Declarative MTP-Llama head stack + Gemma4 weight aggregation #528

Merged
jlamypoirier merged 7 commits into main from jlp_declarative_aggregator_converters on May 29, 2026

Collaborator

@jlamypoirier jlamypoirier commented May 29, 2026

Continues the declarative weight-conversion initiative (#523, #527): removes the last imperative get_converters overrides that hand-rolled weight aggregation, and makes HF conversion work and stay green across the supported transformers range (setup.cfg: >=4.57.3,<6.0.0).

Changes

Gemma4 weight aggregation → declarative. Gemma4BaseModelConverter declares embeddings + decoder as NestedWeightConverters; get_converters collapses to the canonical emit + head form, mirroring LlamaBaseModelConverter. Emitted converters unchanged.

RepeatWeightConverter primitive + MTP-Llama migration. New structural primitive that fans a sub-section converter over a config-driven count with index-templated prefixes, reusing one sub-config per iteration (unlike BlockSequenceWeightConverter, which walks a per-position config list). Used to declare MTP-Llama's per-prediction-distance blocks and norms on the base-model converter — their correct home, since they live at the model root as multi_token_prediction.*, not under head — dropping the imperative get_converters override on the head converter. Emitted converters unchanged.
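The fan-out idea behind this primitive can be sketched in a few lines. This is a minimal illustration, not the Fast-LLM implementation: `WeightMapping`, `repeat_section`, `norm_section`, and the `norms.{index}` / `model.mtp_norms.{index}` templates are all hypothetical names chosen for the example; only the `multi_token_prediction.` root prefix comes from the description above.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class WeightMapping:
    """One (internal name, exported name) pair; a hypothetical stand-in for an emitted converter."""

    internal_name: str
    export_name: str


def repeat_section(
    count: int,
    internal_template: str,
    export_template: str,
    section: Callable[[str, str], list[WeightMapping]],
) -> list[WeightMapping]:
    """Fan `section` over `count` iterations, formatting both prefixes with the index."""
    mappings: list[WeightMapping] = []
    for index in range(count):
        mappings.extend(
            section(
                internal_template.format(index=index),
                export_template.format(index=index),
            )
        )
    return mappings


def norm_section(internal_prefix: str, export_prefix: str) -> list[WeightMapping]:
    # The same sub-section is reused for every iteration; only the prefixes change.
    return [WeightMapping(f"{internal_prefix}.weight", f"{export_prefix}.weight")]


# `count` would come from the model config at runtime; 2 is an arbitrary example value.
mappings = repeat_section(
    count=2,
    internal_template="multi_token_prediction.norms.{index}",  # hypothetical layout
    export_template="model.mtp_norms.{index}",  # hypothetical layout
    section=norm_section,
)
for mapping in mappings:
    print(mapping.internal_name, "->", mapping.export_name)
```

The contrast with a per-position walker is that `section` receives no per-index config: one sub-section definition serves every iteration, and the count alone drives the repetition.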

HF metadata allowlist derived from PretrainedConfig. The coverage check rejected configs from supported transformers builds: the hand-curated _HF_METADATA_ALLOWLIST omits generic PretrainedConfig fields (torchscript + generation kwargs) that v4 dumps into to_dict() but v5 moved to GenerationConfig. As a result, every format's HF import failed on v4 with unknown key 'torchscript'. The check now unions the static allowlist with the live PretrainedConfig().to_dict() key set (generic by definition, never architecture), so it adapts across the whole supported range.
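A rough sketch of the union, under stated assumptions: `build_metadata_allowlist` and `find_unknown_keys` are hypothetical helper names, and `bare_v4_like` is a hand-written stand-in for what `PretrainedConfig().to_dict()` might return on a v4 build, so the example runs without transformers installed.

```python
def build_metadata_allowlist(static_allowlist, bare_config_dict):
    """Union the curated keys with every key a bare config carries."""
    return frozenset(static_allowlist) | frozenset(bare_config_dict)


def find_unknown_keys(hf_config_dict, consumed_keys, allowlist):
    """Return config keys that are neither converted nor allowlisted metadata."""
    return sorted(set(hf_config_dict) - set(consumed_keys) - set(allowlist))


# Stand-in for PretrainedConfig().to_dict() on a v4 build (assumed, abbreviated).
bare_v4_like = {"torchscript": False, "architectures": None, "model_type": ""}
# Keys a bare config does not carry, so they stay in the static list.
static = {"auto_map", "torch_dtype", "use_cache"}

hf_config = {"hidden_size": 64, "torchscript": False, "use_cache": True, "mystery_field": 1}
consumed = {"hidden_size"}  # keys an architecture converter actually reads

static_only = find_unknown_keys(hf_config, consumed, static)
with_union = find_unknown_keys(
    hf_config, consumed, build_metadata_allowlist(static, bare_v4_like)
)
print(static_only)  # ['mystery_field', 'torchscript']
print(with_union)   # ['mystery_field']
```

With the static list alone, `torchscript` is flagged as unknown; with the union, only the truly unrecognized key remains, which is the behavior the coverage check is meant to enforce.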

Gate optional Gemma4 import + conditional convert-skip. The top-level from transformers import Gemma4ForCausalLM, Gemma4TextConfig (from #492) made test_hf_roundtrip fail to collect on builds without Gemma 4; guarded behind try/except with per-case skip. Symmetrically, testing_group_enabled now skips the convert group when the format's HF config class can't be imported (mirroring the existing requires_cuda env-skip), so test_conversion[gemma4] skips on transformers v4 instead of erroring.
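The guard-then-skip pattern might look roughly like the following. This is a stdlib-only sketch, not the test suite's code: `optional_attr` and `should_skip` are hypothetical helpers, and stdlib modules stand in for transformers so the example runs anywhere.

```python
import importlib


def optional_attr(module_name: str, attr_name: str):
    """Return `module.attr`, or None if the module or the attribute is unavailable."""
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr_name)
    except (ImportError, AttributeError):
        return None


def should_skip(required_class) -> bool:
    """A test case is skipped when the class it needs could not be resolved."""
    return required_class is None


# In the PR's setting the lookup would target transformers, e.g.
# optional_attr("transformers", "Gemma4TextConfig"); stdlib names are used here.
available = optional_attr("json", "JSONDecoder")
missing = optional_attr("json", "NoSuchClass")

for name, cls in [("present", available), ("absent", missing)]:
    print(name, "skip" if should_skip(cls) else "run")
```

In a pytest suite, the `should_skip` branch would typically call `pytest.skip(...)` inside the test or attach a `pytest.mark.skipif` marker, so that a missing optional class produces a skip rather than a collection error.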

Verification (both supported transformers versions)

| | v4 (4.57.5) | v5 (5.8.1) |
|---|---|---|
| test_converters.py walker | 37 passed | 37 passed |
| test_hf_roundtrip.py | 7 passed, gemma4 skipped | 8 passed (gemma4 runs) |
| test_checkpoint.py (gpt formats) | green; gemma4 convert skipped | green; gemma4 convert runs |

gemma4 conversion is exercised on v5 and cleanly skipped on v4 (no Gemma 4 class there). The torchscript failures that previously hit every format — including untouched llama — are gone.

🤖 Generated with Claude Code

jlamypoirier and others added 7 commits May 29, 2026 13:49
Express the embeddings/decoder weight mapping as NestedWeightConverter declarations in
_create_weight_converters and reduce get_converters to the canonical emit+head form, mirroring
LlamaBaseModelConverter. No change to emitted converters.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…erter

Add RepeatWeightConverter, a structural primitive that fans a sub-section converter over a
config-driven count with index-templated prefixes (reusing one sub-config per iteration, unlike
BlockSequenceWeightConverter which walks a per-position config list).

Use it to declare MTP-Llama's per-prediction-distance blocks and norms on the base-model converter
(where they belong: they live at the model root as multi_token_prediction.*, not under head),
dropping the imperative get_converters override on the head converter. Emitted converters are
unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The top-level `from transformers import Gemma4ForCausalLM, Gemma4TextConfig` (added with Gemma 4
support in #492) makes the entire test_hf_roundtrip module fail to collect on transformers builds
without Gemma 4 — taking the llama/mistral/qwen2/mixtral/mtp_llama round-trips down with it. Guard
the import (matching the try/except pattern used elsewhere in tests) and skip any round-trip case
whose model class is unavailable.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ness

The HF-coverage check rejected configs from supported transformers builds: the hand-curated
_HF_METADATA_ALLOWLIST omits generic PretrainedConfig fields (torchscript and the generation kwargs)
that v4 dumps into to_dict() but v5 moved out to GenerationConfig. With setup.cfg pinning
transformers>=4.57.3,<6.0.0, every format's HF import failed on v4 with "unknown key 'torchscript'".

Union the static allowlist with the live PretrainedConfig().to_dict() key set — every key a bare
PretrainedConfig carries is generic metadata, never architecture — so the check adapts across the
whole supported range instead of pinning a version-specific set.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ig class

Conversion tests instantiate the format's transformers config class; on supported-but-older
transformers builds the class may not exist (e.g. no Gemma 4 before v5), making test_conversion[gemma4]
error with `module transformers has no attribute Gemma4TextConfig`. Mirror the existing requires_cuda
env-skip: in testing_group_enabled, skip the convert group when the class can't be imported, so the
suite stays green across the supported transformers range (>=4.57,<6.0).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…dConfig union

The static _HF_METADATA_ALLOWLIST kept 13 generic fields (architectures, model_type, dtype, …) that
a bare PretrainedConfig().to_dict() carries on every supported transformers version, so the dynamic
union already covers them. Drop those; keep only what the union misses across the range: auto_map /
torch_dtype / use_cache (absent from a bare config on both v4 and v5), the token ids (absent from a
bare v5 config), and the model-specific init/pretraining metadata. Behavior-preserving.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The example named a downstream consumer (prediction-distance head stack); the preceding sentence
already conveys the generic intent (runtime count drives the repeat). Per the style guide, explanatory
text should not reference specific consumers.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier merged commit 109d9f7 into main May 29, 2026
1 of 2 checks passed
@jlamypoirier jlamypoirier deleted the jlp_declarative_aggregator_converters branch May 29, 2026 22:30