From 2d245cc03bd473f896a47d1ec3683b88edc81702 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 29 Dec 2025 13:33:17 -0500 Subject: [PATCH 01/51] conversion for deepseekv3 models between hf and lit; support for loading finegrained_fp8 weights --- litgpt/config.py | 41 +++++++ litgpt/scripts/convert_hf_checkpoint.py | 76 ++++++++++++ litgpt/scripts/convert_lit_checkpoint.py | 62 ++++++++++ tests/test_model_deepseek_v3.py | 143 +++++++++++++++++++++++ 4 files changed, 322 insertions(+) create mode 100644 tests/test_model_deepseek_v3.py diff --git a/litgpt/config.py b/litgpt/config.py index da7d3ee5bb..9e024e0c3e 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3148,4 +3148,45 @@ def norm_class(self) -> Type: configs.extend(r1_distill_llama) +deepseek_v3 = [ + # https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json + dict( + name="DeepSeek-V3", + hf_config=dict(org="deepseek-ai", name="DeepSeek-V3"), + block_size=163840, + vocab_size=129280, + padded_vocab_size=129280, + n_layer=61, + n_head=128, + n_embd=7168, + n_query_groups=128, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + norm_class_name="RMSNorm", + norm_eps=1e-6, + mlp_class_name="LLaMAMoE", + intermediate_size=18432, + rope_base=10000, + latent_attention=dict( + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + ), + n_expert=256, + n_shared_expert=1, + n_expert_per_token=8, + n_expert_groups=8, + n_topk_groups=4, + n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM + moe_intermediate_size=2048, + first_k_dense_replace=3, + norm_topk_prob=True, + routed_scaling_factor=2.5, + rope_adjustments=dict(factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096), + ), +] + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 7a39c14a58..fc624d1ed3 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -717,6 +717,78 @@ def copy_weights_qwen_3( if progress_per_file is not None: pbar.update(progress_per_file) +def copy_weights_deepseek_v3( + config: Config, + qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, + pbar: Optional[tqdm] = None, + progress_per_file: Optional[float] = None, + debug_mode: Optional[bool] = False, +) -> None: + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.self_attn.q_a_proj.weight": "transformer.h.{}.attn.q_a_proj.weight", + "model.layers.{}.self_attn.q_a_proj.weight_scale_inv": "transformer.h.{}.attn.q_a_proj.weight_scale_inv", + "model.layers.{}.self_attn.q_a_layernorm.weight": "transformer.h.{}.attn.q_a_norm.weight", + "model.layers.{}.self_attn.q_b_proj.weight": "transformer.h.{}.attn.q_b_proj.weight", + "model.layers.{}.self_attn.q_b_proj.weight_scale_inv": "transformer.h.{}.attn.q_b_proj.weight_scale_inv", + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "transformer.h.{}.attn.kv_a_proj_with_mqa.weight", + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight_scale_inv": "transformer.h.{}.attn.kv_a_proj_with_mqa.weight_scale_inv", + "model.layers.{}.self_attn.kv_a_layernorm.weight": "transformer.h.{}.attn.kv_a_norm.weight", + "model.layers.{}.self_attn.kv_b_proj.weight": "transformer.h.{}.attn.kv_b_proj.weight", + "model.layers.{}.self_attn.kv_b_proj.weight_scale_inv": "transformer.h.{}.attn.kv_b_proj.weight_scale_inv", + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", + "model.layers.{}.self_attn.o_proj.weight_scale_inv": "transformer.h.{}.attn.proj.weight_scale_inv", + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.norm.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + if config.mlp_class_name == "LLaMAMoE": + weight_map.update( + { + "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", + "model.layers.{}.mlp.experts.{}.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.fc_1.weight_scale_inv", + "model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", + "model.layers.{}.mlp.experts.{}.up_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.fc_2.weight_scale_inv", + "model.layers.{}.mlp.experts.{}.down_proj.weight": "transformer.h.{}.mlp.experts.{}.proj.weight", + "model.layers.{}.mlp.experts.{}.down_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.proj.weight_scale_inv", + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "transformer.h.{}.mlp.shared_experts.fc_1.weight", + "model.layers.{}.mlp.shared_experts.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.fc_1.weight_scale_inv", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "transformer.h.{}.mlp.shared_experts.fc_2.weight", + "model.layers.{}.mlp.shared_experts.up_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.fc_2.weight_scale_inv", + "model.layers.{}.mlp.shared_experts.down_proj.weight": "transformer.h.{}.mlp.shared_experts.proj.weight", + "model.layers.{}.mlp.shared_experts.down_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.proj.weight_scale_inv", + "model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight", + "model.layers.{}.mlp.e_score_correction_bias": "transformer.h.{}.mlp.e_score_correction_bias", + } + ) + else: + raise NotImplementedError + + if progress_per_file is not None: + progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) + + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if to_name is None: + continue + to_name = to_name.format(*ids) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + if progress_per_file is not None: + pbar.update(progress_per_file) + + if "lm_head.weight" not in state_dict: + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + def qkv_reassemble( param: Union[torch.Tensor, NotYetLoadedTensor], config: Config @@ -817,6 +889,10 @@ def convert_hf_checkpoint( # holder to reconstitute the split q, k, v qkv_weights = {} copy_fn = partial(copy_weights_qwen_3, config, qkv_weights) + elif model_name.lower().startswith("deepseek"): + # holder to reconstitute the split q, k, v + qkv_weights = {} + copy_fn = partial(copy_weights_deepseek_v3, config, qkv_weights) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): # holder to reconstitute the split q, k, v qkv_weights = {} diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index d7ce885c55..357d90b240 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -516,6 +516,66 @@ def copy_weights_qwen_3( param = saver.store_early(param) state_dict[to_name] = param +def copy_weights_deepseek_v3( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + untie_weights: bool = False, + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.attn.q_a_proj.weight": "model.layers.{}.self_attn.q_a_proj.weight", + "transformer.h.{}.attn.q_a_proj.weight_scale_inv": "model.layers.{}.self_attn.q_a_proj.weight_scale_inv", + "transformer.h.{}.attn.q_a_norm.weight": "model.layers.{}.self_attn.q_a_layernorm.weight", + "transformer.h.{}.attn.q_b_proj.weight": "model.layers.{}.self_attn.q_b_proj.weight", + "transformer.h.{}.attn.q_b_proj.weight_scale_inv": "model.layers.{}.self_attn.q_b_proj.weight_scale_inv", + "transformer.h.{}.attn.kv_a_proj_with_mqa.weight": "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight", + "transformer.h.{}.attn.kv_a_proj_with_mqa.weight_scale_inv": "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight_scale_inv", + "transformer.h.{}.attn.kv_a_norm.weight": "model.layers.{}.self_attn.kv_a_layernorm.weight", + "transformer.h.{}.attn.kv_b_proj.weight": "model.layers.{}.self_attn.kv_b_proj.weight", + "transformer.h.{}.attn.kv_b_proj.weight_scale_inv": "model.layers.{}.self_attn.kv_b_proj.weight_scale_inv", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.attn.proj.weight_scale_inv": "model.layers.{}.self_attn.o_proj.weight_scale_inv", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.ln_f.weight": "model.norm.weight", + "lm_head.weight": "lm_head.weight", + } + if config.mlp_class_name == "LLaMAMoE": + weight_map.update( + { + "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.mlp.experts.{}.gate_proj.weight", + "transformer.h.{}.mlp.experts.{}.fc_1.weight_scale_inv": "model.layers.{}.mlp.experts.{}.gate_proj.weight_scale_inv", + "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.mlp.experts.{}.up_proj.weight", + "transformer.h.{}.mlp.experts.{}.fc_2.weight_scale_inv": "model.layers.{}.mlp.experts.{}.up_proj.weight_scale_inv", + "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.mlp.experts.{}.down_proj.weight", + "transformer.h.{}.mlp.experts.{}.proj.weight_scale_inv": "model.layers.{}.mlp.experts.{}.down_proj.weight_scale_inv", + "transformer.h.{}.mlp.shared_experts.fc_1.weight": "model.layers.{}.mlp.shared_experts.gate_proj.weight", + "transformer.h.{}.mlp.shared_experts.fc_1.weight_scale_inv": "model.layers.{}.mlp.shared_experts.gate_proj.weight_scale_inv", + "transformer.h.{}.mlp.shared_experts.fc_2.weight": "model.layers.{}.mlp.shared_experts.up_proj.weight", + "transformer.h.{}.mlp.shared_experts.fc_2.weight_scale_inv": "model.layers.{}.mlp.shared_experts.up_proj.weight_scale_inv", + "transformer.h.{}.mlp.shared_experts.proj.weight": "model.layers.{}.mlp.shared_experts.down_proj.weight", + "transformer.h.{}.mlp.shared_experts.proj.weight_scale_inv": "model.layers.{}.mlp.shared_experts.down_proj.weight_scale_inv", + "transformer.h.{}.mlp.gate.weight": "model.layers.{}.mlp.gate.weight", + "transformer.h.{}.mlp.e_score_correction_bias": "model.layers.{}.mlp.e_score_correction_bias", + } + ) + else: + raise NotImplementedError + + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: + continue + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. @@ -565,6 +625,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: copy_fn = partial(copy_weights_olmo2, config) elif config.name.lower().startswith("qwen3"): copy_fn = partial(copy_weights_qwen_3, config) + elif config.name.lower().startswith("deepseek"): + copy_fn = partial(copy_weights_deepseek_v3, config) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py new file mode 100644 index 0000000000..41174342ef --- /dev/null +++ b/tests/test_model_deepseek_v3.py @@ -0,0 +1,143 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +from copy import deepcopy +from functools import partial +from unittest import mock + +import pytest +import torch +from lightning import Fabric +from lightning.fabric.utilities.imports import _IS_WINDOWS +from lightning.fabric.utilities.init import _materialize_meta_tensors +from torch._dynamo.backends import debugging +from torch.backends.cuda import ( + SDPAParams, + SDPBackend, + can_use_efficient_attention, + can_use_flash_attention, + flash_sdp_enabled, + math_sdp_enabled, + mem_efficient_sdp_enabled, +) +from transformers import AutoConfig, AutoModelForCausalLM +from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM +from transformers.integrations.finegrained_fp8 import FP8Linear + +import litgpt.config as config_module +from litgpt import GPT, Config +from litgpt.model import CausalSelfAttention, batched_index_copy_ +from litgpt.scripts.convert_hf_checkpoint import ( + copy_weights_deepseek_v3, +) +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved +from litgpt.utils import _RunIf + + +@torch.inference_mode() +@pytest.mark.parametrize( + "model_name", ["DeepSeek-V3"] +) +@pytest.mark.parametrize( + ("device", "dtype"), + [ + (torch.device("cpu"), torch.float32), + pytest.param( + torch.device("cuda"), + torch.float16, + marks=[ + # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input + # is slightly different + pytest.mark.xfail(raises=AssertionError, strict=False), + _RunIf(min_cuda_gpus=1), + ], + ), + ], +) +def test_against_original_deepseek_v3(model_name, device, dtype): + torch.set_default_dtype(dtype) + + T = 20 + ours_config = Config.from_name( + model_name, + block_size=T, + n_layer=2, + n_head=16, + n_embd=32, + n_query_groups=4, + intermediate_size=86, + moe_intermediate_size=20, + n_expert=4, + n_shared_expert=1, + n_expert_per_token=2, + n_expert_groups=2, + n_topk_groups=2, + n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM + first_k_dense_replace=1, + latent_attention=dict( + q_lora_rank=16, + kv_lora_rank=16, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=16, + ), + ) + theirs_config = DeepseekV3Config( + vocab_size=ours_config.padded_vocab_size, + hidden_size=ours_config.n_embd, + head_dim=ours_config.head_size, + num_attention_heads=ours_config.n_head, + num_hidden_layers=ours_config.n_layer, + intermediate_size=ours_config.intermediate_size, + moe_intermediate_size=ours_config.moe_intermediate_size, + max_position_embeddings=ours_config.block_size, + rms_norm_eps=ours_config.norm_eps, + num_key_value_heads=ours_config.n_query_groups, + rope_theta=ours_config.rope_base, + tie_word_embeddings=False, + num_experts_per_tok=ours_config.n_expert_per_token, + norm_topk_prob=True, + n_routed_experts=ours_config.n_expert, # 256 + n_shared_experts=ours_config.n_shared_expert, # 1 + n_group=ours_config.n_expert_groups, + topk_group=ours_config.n_topk_groups, + routed_scaling_factor=ours_config.routed_scaling_factor, # 2.5 + first_k_dense_replace=ours_config.first_k_dense_replace, + qk_nope_head_dim=ours_config.latent_attention["qk_nope_head_dim"], # 128 + qk_rope_head_dim=ours_config.latent_attention["qk_rope_head_dim"], + v_head_dim=ours_config.latent_attention["v_head_dim"], + q_lora_rank=ours_config.latent_attention["q_lora_rank"], + kv_lora_rank=ours_config.latent_attention["kv_lora_rank"], + ) + + theirs_model = DeepseekV3ForCausalLM(theirs_config).to(device) + theirs_state_dict = theirs_model.state_dict() + state_dict = {} + copy_weights_deepseek_v3(ours_config, {}, state_dict, theirs_state_dict) + ours_model = GPT(ours_config) + ours_model = patch_deepseek_v3(ours_model) + ours_model.to(device) + ours_model.load_state_dict(state_dict) + + # test end to end + x = torch.randint(low=0, high=ours_config.padded_vocab_size, size=(T,), device=device).unsqueeze(0) + assert x.size(1) == T + ours_y = ours_model(x) + theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float + torch.testing.assert_close(ours_y, theirs_y) + +def patch_deepseek_v3(model: GPT): + to_replace = ["attn.q_a_proj", "attn.q_b_proj", "attn.kv_a_proj_with_mqa", "attn.kv_b_proj", "attn.proj", "gate_proj", "up_proj", "down_proj"] + for name, module in model.named_modules(): + new_module = None + with torch.device("meta"): + if isinstance(module, nn.Linear) and any(to_replace) in name: + new_module = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme="dynamic", + block_size=(128, 128), + ) + if new_module is not None: + model.replace_module(name, new_module) + return model From 18c851fc367c10ab47e8a57e320ae3c8a8106aac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Dec 2025 18:37:50 +0000 Subject: [PATCH 02/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/config.py | 2 +- litgpt/scripts/convert_hf_checkpoint.py | 1 + litgpt/scripts/convert_lit_checkpoint.py | 2 ++ tests/test_model_deepseek_v3.py | 40 +++++++++--------------- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 9e024e0c3e..ce6b9d8973 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3180,7 +3180,7 @@ def norm_class(self) -> Type: n_expert_per_token=8, n_expert_groups=8, n_topk_groups=4, - n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM + n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM moe_intermediate_size=2048, first_k_dense_replace=3, norm_topk_prob=True, diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index fc624d1ed3..5a6537b863 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -717,6 +717,7 @@ def copy_weights_qwen_3( if progress_per_file is not None: pbar.update(progress_per_file) + def copy_weights_deepseek_v3( config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index 357d90b240..53580100bf 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -516,6 +516,7 @@ def copy_weights_qwen_3( param = saver.store_early(param) state_dict[to_name] = param + def copy_weights_deepseek_v3( config: Config, state_dict: Dict[str, torch.Tensor], @@ -577,6 +578,7 @@ def copy_weights_deepseek_v3( param = saver.store_early(param) state_dict[to_name] = param + def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: """Reassemble from a normal to an interleaved placement in a QKV matrix. [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 41174342ef..77c8cee510 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -1,42 +1,20 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. -from copy import deepcopy -from functools import partial -from unittest import mock import pytest import torch -from lightning import Fabric -from lightning.fabric.utilities.imports import _IS_WINDOWS -from lightning.fabric.utilities.init import _materialize_meta_tensors -from torch._dynamo.backends import debugging -from torch.backends.cuda import ( - SDPAParams, - SDPBackend, - can_use_efficient_attention, - can_use_flash_attention, - flash_sdp_enabled, - math_sdp_enabled, - mem_efficient_sdp_enabled, -) -from transformers import AutoConfig, AutoModelForCausalLM -from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM from transformers.integrations.finegrained_fp8 import FP8Linear +from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM -import litgpt.config as config_module from litgpt import GPT, Config -from litgpt.model import CausalSelfAttention, batched_index_copy_ from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_deepseek_v3, ) -from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from litgpt.utils import _RunIf @torch.inference_mode() -@pytest.mark.parametrize( - "model_name", ["DeepSeek-V3"] -) +@pytest.mark.parametrize("model_name", ["DeepSeek-V3"]) @pytest.mark.parametrize( ("device", "dtype"), [ @@ -71,7 +49,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): n_expert_per_token=2, n_expert_groups=2, n_topk_groups=2, - n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM + n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM first_k_dense_replace=1, latent_attention=dict( q_lora_rank=16, @@ -125,8 +103,18 @@ def test_against_original_deepseek_v3(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + def patch_deepseek_v3(model: GPT): - to_replace = ["attn.q_a_proj", "attn.q_b_proj", "attn.kv_a_proj_with_mqa", "attn.kv_b_proj", "attn.proj", "gate_proj", "up_proj", "down_proj"] + to_replace = [ + "attn.q_a_proj", + "attn.q_b_proj", + "attn.kv_a_proj_with_mqa", + "attn.kv_b_proj", + "attn.proj", + "gate_proj", + "up_proj", + "down_proj", + ] for name, module in model.named_modules(): new_module = None with torch.device("meta"): From edd45a98016f108588e437c82b140edb3afaaacc Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Feb 2026 20:56:07 -0500 Subject: [PATCH 03/51] fix n_query_group == n_head for deepseekv3 --- litgpt/config.py | 2 ++ litgpt/scripts/convert_hf_checkpoint.py | 8 +++++++- tests/test_model_deepseek_v3.py | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 5df5369210..5d2c27ca7e 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3224,4 +3224,6 @@ def check_indicator_and_length( ), ] +configs.extend(deepseek_v3) + name_to_config = {config["name"]: config for config in configs} diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 1e2b2159c5..4fa6a818fc 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -754,9 +754,15 @@ def copy_weights_deepseek_v3( "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } - if config.mlp_class_name == "LLaMAMoE": + if config.mlp_class_name == "LLaMAMoE": # Deepseek V3 has both MoE and MLP layers (specified with `first_k_dense_replace`), but we treat it as a LLaMAMoE type weight_map.update( { + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.gate_proj.weight", + "model.layers.{}.mlp.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.gate_proj.weight_scale_inv", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.up_proj.weight", + "model.layers.{}.mlp.up_proj.weight_scale_inv": "transformer.h.{}.mlp.up_proj.weight_scale_inv", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.down_proj.weight", + "model.layers.{}.mlp.down_proj.weight_scale_inv": "transformer.h.{}.mlp.down_proj.weight_scale_inv", "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", "model.layers.{}.mlp.experts.{}.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.fc_1.weight_scale_inv", "model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 77c8cee510..18e7d03ac6 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -41,7 +41,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): n_layer=2, n_head=16, n_embd=32, - n_query_groups=4, + n_query_groups=16, intermediate_size=86, moe_intermediate_size=20, n_expert=4, From d17ff83a109bd1864c28da4f66abed1fdd19cadf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Feb 2026 01:56:24 +0000 Subject: [PATCH 04/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/scripts/convert_hf_checkpoint.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 4fa6a818fc..d93fc97881 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -754,7 +754,9 @@ def copy_weights_deepseek_v3( "model.norm.weight": "transformer.ln_f.weight", "lm_head.weight": "lm_head.weight", } - if config.mlp_class_name == "LLaMAMoE": # Deepseek V3 has both MoE and MLP layers (specified with `first_k_dense_replace`), but we treat it as a LLaMAMoE type + if ( + config.mlp_class_name == "LLaMAMoE" + ): # Deepseek V3 has both MoE and MLP layers (specified with `first_k_dense_replace`), but we treat it as a LLaMAMoE type weight_map.update( { "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.gate_proj.weight", From 063d71b233b3f015a066a27301cff03ced1f431b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Feb 2026 21:40:40 -0500 Subject: [PATCH 05/51] minor fixes --- litgpt/scripts/convert_hf_checkpoint.py | 14 +++++++------- litgpt/scripts/convert_lit_checkpoint.py | 8 +++++++- tests/test_model_deepseek_v3.py | 8 +++++--- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 4fa6a818fc..97fcd80912 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -757,12 +757,12 @@ def copy_weights_deepseek_v3( if config.mlp_class_name == "LLaMAMoE": # Deepseek V3 has both MoE and MLP layers (specified with `first_k_dense_replace`), but we treat it as a LLaMAMoE type weight_map.update( { - "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.gate_proj.weight", - "model.layers.{}.mlp.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.gate_proj.weight_scale_inv", - "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.up_proj.weight", - "model.layers.{}.mlp.up_proj.weight_scale_inv": "transformer.h.{}.mlp.up_proj.weight_scale_inv", - "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.down_proj.weight", - "model.layers.{}.mlp.down_proj.weight_scale_inv": "transformer.h.{}.mlp.down_proj.weight_scale_inv", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.fc_1.weight_scale_inv", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.up_proj.weight_scale_inv": "transformer.h.{}.mlp.fc_2.weight_scale_inv", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", + "model.layers.{}.mlp.down_proj.weight_scale_inv": "transformer.h.{}.mlp.proj.weight_scale_inv", "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", "model.layers.{}.mlp.experts.{}.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.fc_1.weight_scale_inv", "model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", @@ -776,7 +776,7 @@ def copy_weights_deepseek_v3( "model.layers.{}.mlp.shared_experts.down_proj.weight": "transformer.h.{}.mlp.shared_experts.proj.weight", "model.layers.{}.mlp.shared_experts.down_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.proj.weight_scale_inv", "model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight", - "model.layers.{}.mlp.e_score_correction_bias": "transformer.h.{}.mlp.e_score_correction_bias", + "model.layers.{}.mlp.gate.e_score_correction_bias": "transformer.h.{}.mlp.gate.e_score_correction_bias", } ) else: diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index a266c78df9..b3fe4507d8 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -546,6 +546,12 @@ def copy_weights_deepseek_v3( if config.mlp_class_name == "LLaMAMoE": weight_map.update( { + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_1.weight_scale_inv": "model.layers.{}.mlp.gate_proj.weight_scale_inv", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.fc_2.weight_scale_inv": "model.layers.{}.mlp.up_proj.weight_scale_inv", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", + "transformer.h.{}.mlp.proj.weight_scale_inv": "model.layers.{}.mlp.down_proj.weight_scale_inv", "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.mlp.experts.{}.gate_proj.weight", "transformer.h.{}.mlp.experts.{}.fc_1.weight_scale_inv": "model.layers.{}.mlp.experts.{}.gate_proj.weight_scale_inv", "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.mlp.experts.{}.up_proj.weight", @@ -559,7 +565,7 @@ def copy_weights_deepseek_v3( "transformer.h.{}.mlp.shared_experts.proj.weight": "model.layers.{}.mlp.shared_experts.down_proj.weight", "transformer.h.{}.mlp.shared_experts.proj.weight_scale_inv": "model.layers.{}.mlp.shared_experts.down_proj.weight_scale_inv", "transformer.h.{}.mlp.gate.weight": "model.layers.{}.mlp.gate.weight", - "transformer.h.{}.mlp.e_score_correction_bias": "model.layers.{}.mlp.e_score_correction_bias", + "transformer.h.{}.mlp.gate.e_score_correction_bias": "model.layers.{}.mlp.gate.e_score_correction_bias", } ) else: diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 18e7d03ac6..33802dc4de 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -111,9 +111,11 @@ def patch_deepseek_v3(model: GPT): "attn.kv_a_proj_with_mqa", "attn.kv_b_proj", "attn.proj", - "gate_proj", - "up_proj", - "down_proj", + "mlp.fc_1", + "mlp.fc_2", + "mlp.proj", + "mlp.experts", + "mlp.shared_experts" ] for name, module in model.named_modules(): new_module = None From c7a852a716ee34c28f6e06db86807a1227390005 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Feb 2026 02:44:14 +0000 Subject: [PATCH 06/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model_deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 33802dc4de..2c1fb434b2 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -115,7 +115,7 @@ def patch_deepseek_v3(model: GPT): "mlp.fc_2", "mlp.proj", "mlp.experts", - "mlp.shared_experts" + "mlp.shared_experts", ] for name, module in model.named_modules(): new_module = None From ea6075806bcf4a6822fed8b6c5e2e47f5416eb9a Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Feb 2026 21:50:32 -0500 Subject: [PATCH 07/51] fix import in test_model_deepseek_v3 --- tests/test_model_deepseek_v3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 33802dc4de..06f724c380 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -3,6 +3,7 @@ import pytest import torch +import torch.nn as nn from transformers.integrations.finegrained_fp8 import FP8Linear from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM From 1b12177b506e426c8b9699decafb1985a2ebcf23 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Feb 2026 21:54:29 -0500 Subject: [PATCH 08/51] fix(deepseek-v3): correct weight mappings and FP8Linear patching --- tests/test_model_deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index eaaec1f0fc..184beba381 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -121,7 +121,7 @@ def patch_deepseek_v3(model: GPT): for name, module in model.named_modules(): new_module = None with torch.device("meta"): - if isinstance(module, nn.Linear) and any(to_replace) in name: + if isinstance(module, nn.Linear) and any(target in name for target in to_replace): new_module = FP8Linear( in_features=module.in_features, out_features=module.out_features, From a8b47ec716294ac3b5531c97ac8192ea6d6a42de Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Feb 2026 22:01:58 -0500 Subject: [PATCH 09/51] fix: replace_module -> set_submodule --- tests/test_model_deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 184beba381..512086fa40 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -130,5 +130,5 @@ def patch_deepseek_v3(model: GPT): block_size=(128, 128), ) if new_module is not None: - model.replace_module(name, new_module) + model.set_submodule(name, new_module) return model From 8a392fb56f04031f4217d38fec0e590a63ff8160 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 7 Feb 2026 22:05:01 -0500 Subject: [PATCH 10/51] fix --- tests/test_model_deepseek_v3.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 512086fa40..d0b92cfc38 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -118,17 +118,28 @@ def patch_deepseek_v3(model: GPT): "mlp.experts", "mlp.shared_experts", ] + modules_to_replace = [] for name, module in model.named_modules(): - new_module = None + if isinstance(module, nn.Linear) and any(target in name for target in to_replace): + modules_to_replace.append((name, module)) + + for name, module in modules_to_replace: with torch.device("meta"): - if isinstance(module, nn.Linear) and any(target in name for target in to_replace): - new_module = FP8Linear( - in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - activation_scheme="dynamic", - block_size=(128, 128), - ) - if new_module is not None: - model.set_submodule(name, new_module) + new_module = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme="dynamic", + block_size=(128, 128), + ) + + # Use to_empty() to move from meta device + new_module = new_module.to_empty(device=module.weight.device, dtype=module.weight.dtype) + + # Copy weights and bias + new_module.weight.data = module.weight.data.clone() + if module.bias is not None: + new_module.bias.data = module.bias.data.clone() + + model.set_submodule(name, new_module) return model From 7a4798118a7fa541f2ba41fbbf400695c3a295f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Feb 2026 03:05:18 +0000 Subject: [PATCH 11/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model_deepseek_v3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index d0b92cfc38..a404ad5d79 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -122,7 +122,7 @@ def patch_deepseek_v3(model: GPT): for name, module in model.named_modules(): if isinstance(module, nn.Linear) and any(target in name for target in to_replace): modules_to_replace.append((name, module)) - + for name, module in modules_to_replace: with torch.device("meta"): new_module = FP8Linear( @@ -132,14 +132,14 @@ def patch_deepseek_v3(model: GPT): activation_scheme="dynamic", block_size=(128, 128), ) - + # Use to_empty() to move from meta device new_module = new_module.to_empty(device=module.weight.device, dtype=module.weight.dtype) - + # Copy weights and bias new_module.weight.data = module.weight.data.clone() if module.bias is not None: new_module.bias.data = module.bias.data.clone() - + model.set_submodule(name, new_module) return model From 7d61199bff59351cb4ba976f76b5010b7957be5c Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 8 Feb 2026 03:40:30 +0000 Subject: [PATCH 12/51] . --- litgpt/config.py | 2 + litgpt/model.py | 114 ++++++++++++++++++++++++++++++-- tests/test_model_deepseek_v3.py | 4 +- 3 files changed, 114 insertions(+), 6 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index 5d2c27ca7e..ec772f10f7 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -113,6 +113,7 @@ class Config: # `rope_base` is used, for 1 `rope_local_base_freq` is used. If # `len(rope_indices) > n_layer`, we only use the initial part. rope_indices: Optional[List[int]] = None + rope_interleaved: bool = False def __post_init__(self): if not self.name: @@ -3221,6 +3222,7 @@ def check_indicator_and_length( norm_topk_prob=True, routed_scaling_factor=2.5, rope_adjustments=dict(factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096), + rope_interleaved=True ), ] diff --git a/litgpt/model.py b/litgpt/model.py index 73b6a29290..c977735a4f 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -220,6 +220,14 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso base=self.config.rope_base, extra_config=extra_config, rope_local_base_freq=self.config.rope_local_base_freq, + ) if not self.config.rope_interleaved else build_rope_cache_interleaved( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + extra_config=extra_config, + rope_local_base_freq=self.config.rope_local_base_freq, ) def rope_cache_length(self) -> int: @@ -457,8 +465,8 @@ def forward( k = self.norm_k(k) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) - k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) + q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(q[..., :rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(k[..., :rope_n_elem], cos, sin) q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) @@ -657,8 +665,8 @@ def forward( k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim) # (B, 1, T, qk_rope_head_dim) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = apply_rope(q_rot, cos, sin) - k_roped = apply_rope(k_rot, cos, sin) + q_roped = apply_rope(q_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(q_rot, cos, sin) + k_roped = apply_rope(k_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(k_rot, cos, sin) k_roped = k_roped.expand(*k_pass.shape[:-1], -1) # (B, n_head, T, qk_rope_head_dim) q = torch.cat((q_pass, q_roped), dim=-1) @@ -938,6 +946,62 @@ def build_rope_cache( return torch.cos(idx_theta), torch.sin(idx_theta) +def build_rope_cache_interleaved( + seq_len: int, + n_elem: int, + device: Optional[torch.device] = None, + base: int = 10000, + condense_ratio: int = 1, + extra_config: Optional[dict] = None, + rope_local_base_freq: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + # [Identical logic to original for calculating theta] + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + + if extra_config is not None: + factor = extra_config["factor"] + if "original_max_seq_len" in extra_config: + orig_context_len = extra_config["original_max_seq_len"] + low_freq_factor = extra_config["low_freq_factor"] + high_freq_factor = extra_config["high_freq_factor"] + + wavelen = 2 * torch.pi / theta + ratio = orig_context_len / wavelen + smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) + adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta + theta = adjusted_theta + else: + theta = theta / factor + + seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio + + # --- CHANGED SECTION START --- + # Original: idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + # New: Repeat interleave to get [theta0, theta0, theta1, theta1...] + idx_theta = torch.outer(seq_idx, theta) + idx_theta = torch.repeat_interleave(idx_theta, 2, dim=-1) + # --- CHANGED SECTION END --- + + if idx_theta.shape[-1] > n_elem > 1: + idx_theta = idx_theta[..., :n_elem] + + if rope_local_base_freq is not None: + local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + local_idx_theta = torch.outer(seq_idx, local_theta) + # --- CHANGED SECTION START --- + local_idx_theta = torch.repeat_interleave(local_idx_theta, 2, dim=-1) + # --- CHANGED SECTION END --- + + if local_idx_theta.shape[-1] > n_elem > 1: + local_idx_theta = local_idx_theta[..., :n_elem] + + idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + def batched_index_select(t, dim, idx): """index_select for batched index and unbatched t""" if idx.dim() == 1: @@ -1039,6 +1103,48 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) +def apply_rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """ + Applies Interleaved RoPE transform to `x`. + """ + if cos.dim() != 3: + raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}") + if cos.shape != sin.shape: + raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}") + + # Handle shape mismatches (same as original model.py) + dims_diff = x.dim() - cos.dim() + if dims_diff > 0: + new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:] + cos = cos.view(*new_shape) + sin = sin.view(*new_shape) + + # --- CHANGED SECTION START --- + # Original Logic: + # head_size_half = x.size(-1) // 2 + # x1 = x[..., :head_size_half] + # x2 = x[..., head_size_half:] + # rotated = torch.cat((-x2, x1), dim=-1) + + # New Interleaved Logic: + # 1. Reshape to group pairs: (..., Head_Dim) -> (..., Head_Dim/2, 2) + # 2. Select evens (x) and odds (y) + # 3. Construct rotated pairs (-y, x) + x_reshaped = x.view(*x.shape[:-1], -1, 2) + + # x_reshaped[..., 0] is the "real" part (even indices) + # x_reshaped[..., 1] is the "imag" part (odd indices) + # Rotation: (x, y) -> (-y, x) + rotated_reshaped = torch.stack((-x_reshaped[..., 1], x_reshaped[..., 0]), dim=-1) + + # Flatten back to original shape + rotated = rotated_reshaped.view_as(x) + # --- CHANGED SECTION END --- + + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: return torch.tanh(x / thresh) * thresh diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index a404ad5d79..fd61455fe4 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -93,7 +93,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): state_dict = {} copy_weights_deepseek_v3(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config) - ours_model = patch_deepseek_v3(ours_model) + #ours_model = patch_deepseek_v3(ours_model) ours_model.to(device) ours_model.load_state_dict(state_dict) @@ -134,7 +134,7 @@ def patch_deepseek_v3(model: GPT): ) # Use to_empty() to move from meta device - new_module = new_module.to_empty(device=module.weight.device, dtype=module.weight.dtype) + new_module = new_module.to_empty(device=module.weight.device) # Copy weights and bias new_module.weight.data = module.weight.data.clone() From e22a879b4ec43da225a1b11c6cfc586c776c70a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 8 Feb 2026 03:40:44 +0000 Subject: [PATCH 13/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/config.py | 2 +- litgpt/model.py | 63 ++++++++++++++++++++------------- tests/test_model_deepseek_v3.py | 2 +- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index ec772f10f7..2957292a30 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3222,7 +3222,7 @@ def check_indicator_and_length( norm_topk_prob=True, routed_scaling_factor=2.5, rope_adjustments=dict(factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096), - rope_interleaved=True + rope_interleaved=True, ), ] diff --git a/litgpt/model.py b/litgpt/model.py index c977735a4f..e24da49bc6 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -212,22 +212,26 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso "All adjusted RoPE parameters must be specified together." ) - return build_rope_cache( - seq_len=self.max_seq_length, - n_elem=self.config.rope_n_elem, - device=device, - condense_ratio=self.config.rope_condense_ratio, - base=self.config.rope_base, - extra_config=extra_config, - rope_local_base_freq=self.config.rope_local_base_freq, - ) if not self.config.rope_interleaved else build_rope_cache_interleaved( - seq_len=self.max_seq_length, - n_elem=self.config.rope_n_elem, - device=device, - condense_ratio=self.config.rope_condense_ratio, - base=self.config.rope_base, - extra_config=extra_config, - rope_local_base_freq=self.config.rope_local_base_freq, + return ( + build_rope_cache( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + extra_config=extra_config, + rope_local_base_freq=self.config.rope_local_base_freq, + ) + if not self.config.rope_interleaved + else build_rope_cache_interleaved( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + extra_config=extra_config, + rope_local_base_freq=self.config.rope_local_base_freq, + ) ) def rope_cache_length(self) -> int: @@ -465,8 +469,16 @@ def forward( k = self.norm_k(k) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(q[..., :rope_n_elem], cos, sin) - k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(k[..., :rope_n_elem], cos, sin) + q_roped = ( + apply_rope(q[..., :rope_n_elem], cos, sin) + if not self.config.rope_interleaved + else apply_rope_interleaved(q[..., :rope_n_elem], cos, sin) + ) + k_roped = ( + apply_rope(k[..., :rope_n_elem], cos, sin) + if not self.config.rope_interleaved + else apply_rope_interleaved(k[..., :rope_n_elem], cos, sin) + ) q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) @@ -665,8 +677,12 @@ def forward( k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim) # (B, 1, T, qk_rope_head_dim) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = apply_rope(q_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(q_rot, cos, sin) - k_roped = apply_rope(k_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(k_rot, cos, sin) + q_roped = ( + apply_rope(q_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(q_rot, cos, sin) + ) + k_roped = ( + apply_rope(k_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(k_rot, cos, sin) + ) k_roped = k_roped.expand(*k_pass.shape[:-1], -1) # (B, n_head, T, qk_rope_head_dim) q = torch.cat((q_pass, q_roped), dim=-1) @@ -955,7 +971,6 @@ def build_rope_cache_interleaved( extra_config: Optional[dict] = None, rope_local_base_freq: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # [Identical logic to original for calculating theta] theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) @@ -993,7 +1008,7 @@ def build_rope_cache_interleaved( # --- CHANGED SECTION START --- local_idx_theta = torch.repeat_interleave(local_idx_theta, 2, dim=-1) # --- CHANGED SECTION END --- - + if local_idx_theta.shape[-1] > n_elem > 1: local_idx_theta = local_idx_theta[..., :n_elem] @@ -1131,12 +1146,12 @@ def apply_rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor # 2. Select evens (x) and odds (y) # 3. Construct rotated pairs (-y, x) x_reshaped = x.view(*x.shape[:-1], -1, 2) - + # x_reshaped[..., 0] is the "real" part (even indices) # x_reshaped[..., 1] is the "imag" part (odd indices) # Rotation: (x, y) -> (-y, x) rotated_reshaped = torch.stack((-x_reshaped[..., 1], x_reshaped[..., 0]), dim=-1) - + # Flatten back to original shape rotated = rotated_reshaped.view_as(x) # --- CHANGED SECTION END --- diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index fd61455fe4..7ef7822325 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -93,7 +93,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): state_dict = {} copy_weights_deepseek_v3(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config) - #ours_model = patch_deepseek_v3(ours_model) + # ours_model = patch_deepseek_v3(ours_model) ours_model.to(device) ours_model.load_state_dict(state_dict) From cc4a3f349522c1963a792f8c6c0d8d609ab14e63 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 14 Feb 2026 22:00:00 -0500 Subject: [PATCH 14/51] apply_rope_interleave in multihead latent attn --- litgpt/config.py | 4 +- litgpt/configuration_deepseek_v3.py | 249 ++++++++ litgpt/model.py | 166 ++--- litgpt/modeling_deepseek_v3.py | 772 +++++++++++++++++++++++ tests/test_multihead_latent_attention.py | 2 +- 5 files changed, 1071 insertions(+), 122 deletions(-) create mode 100644 litgpt/configuration_deepseek_v3.py create mode 100644 litgpt/modeling_deepseek_v3.py diff --git a/litgpt/config.py b/litgpt/config.py index 2957292a30..d4af544a51 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -83,6 +83,7 @@ class Config: rope_base: int = 10000 rotary_percentage: float = 0.25 rope_condense_ratio: int = 1 + rope_interleave: bool = False rope_adjustments: Optional[dict] = None # Transformer block (MLP) intermediate_size: Optional[int] = None @@ -113,7 +114,6 @@ class Config: # `rope_base` is used, for 1 `rope_local_base_freq` is used. If # `len(rope_indices) > n_layer`, we only use the initial part. rope_indices: Optional[List[int]] = None - rope_interleaved: bool = False def __post_init__(self): if not self.name: @@ -3204,6 +3204,7 @@ def check_indicator_and_length( mlp_class_name="LLaMAMoE", intermediate_size=18432, rope_base=10000, + rope_interleave=True, latent_attention=dict( q_lora_rank=1536, kv_lora_rank=512, @@ -3222,7 +3223,6 @@ def check_indicator_and_length( norm_topk_prob=True, routed_scaling_factor=2.5, rope_adjustments=dict(factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096), - rope_interleaved=True, ), ] diff --git a/litgpt/configuration_deepseek_v3.py b/litgpt/configuration_deepseek_v3.py new file mode 100644 index 0000000000..c09f3dfe2c --- /dev/null +++ b/litgpt/configuration_deepseek_v3.py @@ -0,0 +1,249 @@ +# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DeepSeekV3 model configuration""" + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters + + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class DeepseekV3Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + e.g. [bzantium/tiny-deepseek-v3](https://huggingface.co/bzantium/tiny-deepseek-v3) + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PreTrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 7168): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18432): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 61): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 128): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts. + n_routed_experts (`int`, *optional*, defaults to 256): + Number of routed experts. + routed_scaling_factor (`float`, *optional*, defaults to 2.5): + Scaling factor or routed experts. + kv_lora_rank (`int`, *optional*, defaults to 512): + Rank of the LoRA matrices for key and value projections. + q_lora_rank (`int`, *optional*, defaults to 1536): + Rank of the LoRA matrices for query projections. + qk_rope_head_dim (`int`, *optional*, defaults to 64): + Dimension of the query/key heads that use rotary position embeddings. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of the value heads. + qk_nope_head_dim (`int`, *optional*, defaults to 128): + Dimension of the query/key heads that don't use rotary position embeddings. + n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 4): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of selected experts, None means dense model. + first_k_dense_replace (`int`, *optional*, defaults to 3): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the weights of the routed experts. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 0): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.mlp.experts.gate_up_proj": "rowwise", + "layers.*.mlp.experts.down_proj": "rowwise", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + def __init__( + self, + vocab_size: int | None = 129280, + hidden_size: int | None = 7168, + intermediate_size: int | None = 18432, + moe_intermediate_size: int | None = 2048, + num_hidden_layers: int | None = 61, + num_attention_heads: int | None = 128, + num_key_value_heads: int | None = 128, + n_shared_experts: int | None = 1, + n_routed_experts: int | None = 256, + routed_scaling_factor: float | None = 2.5, + kv_lora_rank: int | None = 512, + q_lora_rank: int | None = 1536, + qk_rope_head_dim: int | None = 64, + v_head_dim: int | None = 128, + qk_nope_head_dim: int | None = 128, + n_group: int | None = 8, + topk_group: int | None = 4, + num_experts_per_tok: int | None = 8, + first_k_dense_replace: int | None = 3, + norm_topk_prob: bool | None = True, + hidden_act: str | None = "silu", + max_position_embeddings: int | None = 4096, + initializer_range: float | None = 0.02, + rms_norm_eps: int | None = 1e-6, + use_cache: bool | None = True, + pad_token_id: int | None = None, + bos_token_id: int | None = 0, + eos_token_id: int | None = 1, + pretraining_tp: int | None = 1, + tie_word_embeddings: bool | None = False, + rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, + rope_interleave: bool | None = True, + attention_bias: bool | None = False, + attention_dropout: float | None = 0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.head_dim = qk_rope_head_dim + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.rope_interleave = rope_interleave + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_parameters = rope_parameters + + self.tie_word_embeddings = tie_word_embeddings + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + super().__init__(**kwargs) + + def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + # Standardize and validate the correctness of rotary position embeddings parameters + self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) + self.standardize_rope_params() + self.validate_rope(ignore_keys=ignore_keys_at_rope_validation) + + # Convert to float because RoPE fn expect a float. Models on the hub were saved as int + for key in ["beta_fast", "beta_slow", "factor"]: + if key in self.rope_parameters: + self.rope_parameters[key] = float(self.rope_parameters[key]) + return kwargs + + +__all__ = ["DeepseekV3Config"] \ No newline at end of file diff --git a/litgpt/model.py b/litgpt/model.py index e24da49bc6..2cf8aa8217 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -212,26 +212,14 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso "All adjusted RoPE parameters must be specified together." ) - return ( - build_rope_cache( - seq_len=self.max_seq_length, - n_elem=self.config.rope_n_elem, - device=device, - condense_ratio=self.config.rope_condense_ratio, - base=self.config.rope_base, - extra_config=extra_config, - rope_local_base_freq=self.config.rope_local_base_freq, - ) - if not self.config.rope_interleaved - else build_rope_cache_interleaved( - seq_len=self.max_seq_length, - n_elem=self.config.rope_n_elem, - device=device, - condense_ratio=self.config.rope_condense_ratio, - base=self.config.rope_base, - extra_config=extra_config, - rope_local_base_freq=self.config.rope_local_base_freq, - ) + return build_rope_cache( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + extra_config=extra_config, + rope_local_base_freq=self.config.rope_local_base_freq, ) def rope_cache_length(self) -> int: @@ -469,16 +457,12 @@ def forward( k = self.norm_k(k) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = ( - apply_rope(q[..., :rope_n_elem], cos, sin) - if not self.config.rope_interleaved - else apply_rope_interleaved(q[..., :rope_n_elem], cos, sin) - ) - k_roped = ( - apply_rope(k[..., :rope_n_elem], cos, sin) - if not self.config.rope_interleaved - else apply_rope_interleaved(k[..., :rope_n_elem], cos, sin) - ) + if self.config.rope_interleave: + q_roped = apply_rope_interleave(q[..., :rope_n_elem], cos, sin) + k_roped = apply_rope_interleave(k[..., :rope_n_elem], cos, sin) + else: + q_roped = apply_rope(q[..., :rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., :rope_n_elem], cos, sin) q = torch.cat((q_roped, q[..., rope_n_elem:]), dim=-1) # (B, nh_q, T, hs) k = torch.cat((k_roped, k[..., rope_n_elem:]), dim=-1) # (B, nh_k, T, hs) @@ -677,12 +661,12 @@ def forward( k_rot = k_rot.view(B, 1, T, self.config.qk_rope_head_dim) # (B, 1, T, qk_rope_head_dim) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = ( - apply_rope(q_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(q_rot, cos, sin) - ) - k_roped = ( - apply_rope(k_rot, cos, sin) if not self.config.rope_interleaved else apply_rope_interleaved(k_rot, cos, sin) - ) + if self.config.rope_interleave: + q_roped = apply_rope_interleave(q_rot, cos, sin) + k_roped = apply_rope_interleave(k_rot, cos, sin) + else: + q_roped = apply_rope(q_rot, cos, sin) + k_roped = apply_rope(k_rot, cos, sin) k_roped = k_roped.expand(*k_pass.shape[:-1], -1) # (B, n_head, T, qk_rope_head_dim) q = torch.cat((q_pass, q_roped), dim=-1) @@ -926,7 +910,7 @@ def build_rope_cache( ratio = orig_context_len / wavelen smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) - + # Compute adjusted_theta without masked indexing adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta theta = adjusted_theta @@ -962,61 +946,6 @@ def build_rope_cache( return torch.cos(idx_theta), torch.sin(idx_theta) -def build_rope_cache_interleaved( - seq_len: int, - n_elem: int, - device: Optional[torch.device] = None, - base: int = 10000, - condense_ratio: int = 1, - extra_config: Optional[dict] = None, - rope_local_base_freq: Optional[float] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - # [Identical logic to original for calculating theta] - theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) - - if extra_config is not None: - factor = extra_config["factor"] - if "original_max_seq_len" in extra_config: - orig_context_len = extra_config["original_max_seq_len"] - low_freq_factor = extra_config["low_freq_factor"] - high_freq_factor = extra_config["high_freq_factor"] - - wavelen = 2 * torch.pi / theta - ratio = orig_context_len / wavelen - smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) - smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) - adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta - theta = adjusted_theta - else: - theta = theta / factor - - seq_idx = torch.arange(seq_len, device=device).float() / condense_ratio - - # --- CHANGED SECTION START --- - # Original: idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) - # New: Repeat interleave to get [theta0, theta0, theta1, theta1...] - idx_theta = torch.outer(seq_idx, theta) - idx_theta = torch.repeat_interleave(idx_theta, 2, dim=-1) - # --- CHANGED SECTION END --- - - if idx_theta.shape[-1] > n_elem > 1: - idx_theta = idx_theta[..., :n_elem] - - if rope_local_base_freq is not None: - local_theta = 1.0 / (rope_local_base_freq ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) - local_idx_theta = torch.outer(seq_idx, local_theta) - # --- CHANGED SECTION START --- - local_idx_theta = torch.repeat_interleave(local_idx_theta, 2, dim=-1) - # --- CHANGED SECTION END --- - - if local_idx_theta.shape[-1] > n_elem > 1: - local_idx_theta = local_idx_theta[..., :n_elem] - - idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1) - - return torch.cos(idx_theta), torch.sin(idx_theta) - - def batched_index_select(t, dim, idx): """index_select for batched index and unbatched t""" if idx.dim() == 1: @@ -1118,44 +1047,43 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) -def apply_rope_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - """ - Applies Interleaved RoPE transform to `x`. +def apply_rope_interleave(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """Apply rotary position embeddings with interleaved tensor layout. + + This version rearranges the input tensor to group even/odd indices separately + before applying the standard RoPE rotation, matching HuggingFace's + apply_rotary_pos_emb_interleave behavior. + + Args: + x: Input tensor of shape (..., seq_len, head_dim) + cos: Cosine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim) + sin: Sine component of shape (B, seq_len, head_dim) or (1, seq_len, head_dim) + + Returns: + Tensor with RoPE applied, same shape as input """ if cos.dim() != 3: raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}") if cos.shape != sin.shape: raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}") - # Handle shape mismatches (same as original model.py) + # Rearrange tensor to group even/odd indices: [x0,x1,x2,x3,...] -> [x0,x2,x4,...,x1,x3,x5,...] + *batch_dims, d = x.shape + x = x.view(*batch_dims, d // 2, 2).transpose(-1, -2).reshape(*batch_dims, d) + + # Standard rotation logic (same as apply_rope) + head_size_half = x.size(-1) // 2 + x1 = x[..., :head_size_half] + x2 = x[..., head_size_half:] + rotated = torch.cat((-x2, x1), dim=-1) + + # Auto-detect dimension mismatch and reshape cos/sin dims_diff = x.dim() - cos.dim() if dims_diff > 0: new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:] cos = cos.view(*new_shape) sin = sin.view(*new_shape) - - # --- CHANGED SECTION START --- - # Original Logic: - # head_size_half = x.size(-1) // 2 - # x1 = x[..., :head_size_half] - # x2 = x[..., head_size_half:] - # rotated = torch.cat((-x2, x1), dim=-1) - - # New Interleaved Logic: - # 1. Reshape to group pairs: (..., Head_Dim) -> (..., Head_Dim/2, 2) - # 2. Select evens (x) and odds (y) - # 3. Construct rotated pairs (-y, x) - x_reshaped = x.view(*x.shape[:-1], -1, 2) - - # x_reshaped[..., 0] is the "real" part (even indices) - # x_reshaped[..., 1] is the "imag" part (odd indices) - # Rotation: (x, y) -> (-y, x) - rotated_reshaped = torch.stack((-x_reshaped[..., 1], x_reshaped[..., 0]), dim=-1) - - # Flatten back to original shape - rotated = rotated_reshaped.view_as(x) - # --- CHANGED SECTION END --- - + roped = (x * cos) + (rotated * sin) return roped.to(dtype=x.dtype) diff --git a/litgpt/modeling_deepseek_v3.py b/litgpt/modeling_deepseek_v3.py new file mode 100644 index 0000000000..c7b036ced1 --- /dev/null +++ b/litgpt/modeling_deepseek_v3.py @@ -0,0 +1,772 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_v3.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available +from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults +from ...utils.output_capturing import capture_outputs +from .configuration_deepseek_v3 import DeepseekV3Config + + +@use_kernel_forward_from_hub("RMSNorm") +class DeepseekV3RMSNorm(nn.Module): + def __init__(self, hidden_size, eps: float = 1e-6) -> None: + """ + DeepseekV3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# ============================================================================ +# ROPE COMPONENT 1: Core RoPE Class - Rotary Position Embedding Implementation +# ============================================================================ +class DeepseekV3RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DeepseekV3Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) + + # ROPE: Compute inverse frequencies for rotary embeddings + @staticmethod + def compute_default_rope_parameters( + config: DeepseekV3Config | None = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + # ROPE: Forward pass - generates cos/sin embeddings from position IDs + # build_rope_cache() + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +# ============================================================================ +# END ROPE COMPONENT 1 +# ============================================================================ + + +class DeepseekV3MLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class DeepseekV3TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + return router_logits + + +@use_experts_implementation +class DeepseekV3NaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.moe_intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class DeepseekV3MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = DeepseekV3NaiveMoe(config) + self.gate = DeepseekV3TopkRouter(config) + self.shared_experts = DeepseekV3MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + self.n_routed_experts = config.n_routed_experts + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.top_k = config.num_experts_per_tok + + def route_tokens_to_experts(self, router_logits): + router_logits = router_logits.sigmoid() + router_logits_for_choice = router_logits + self.gate.e_score_correction_bias + group_scores = ( + router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = self.gate(hidden_states) + topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +# ============================================================================ +# ROPE COMPONENT 2: RoPE Helper Functions +# ============================================================================ +# ROPE: Rotation helper - splits tensor and rotates [-x2, x1] +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# ROPE: Main function to apply RoPE to query and key tensors +# apply_rope +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +# ROPE: Alternative interleaved RoPE application (with view/transpose for efficiency) +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# ROPE: YaRN (Yet another RoPE extensioN) scaling function for extended context +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 +# ============================================================================ +# END ROPE COMPONENT 2 +# ============================================================================ + + +class DeepseekV3Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim # ROPE: dimension for rotary embeddings + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + # ROPE: Initialize attention scaling (potentially adjusted by YaRN) + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_parameters.get("rope_type", "default") != "default": + mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_parameters["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) # ROPE: Apply YaRN scaling + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_values: Cache | None = None, + cache_position: torch.LongTensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + else: + q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q_states = q_states.view(query_shape).transpose(1, 2) + # ROPE: Split query into non-RoPE (q_pass) and RoPE (q_rot) parts + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + # ROPE: Split key into non-RoPE (k_pass) and RoPE (k_rot) parts + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + # ======================================================================== + # ROPE COMPONENT 3: Apply RoPE to Query and Key tensors + # ======================================================================== + cos, sin = position_embeddings # ROPE: Get cos/sin from rotary embeddings + if self.config.rope_interleave: # support using interleaved weights for efficiency + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) # ROPE: Interleaved version + else: + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) # ROPE: Standard version + # ======================================================================== + # END ROPE COMPONENT 3 + # ======================================================================== + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + # ROPE: Concatenate non-RoPE and RoPE parts back together + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + if past_key_values is not None: + # ROPE: sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( + self.config._attn_implementation, eager_attention_forward + ) + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class DeepseekV3DecoderLayer(GradientCheckpointingLayer): + def __init__(self, config: DeepseekV3Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) + + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV3MoE(config) + else: + self.mlp = DeepseekV3MLP(config) + + self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@auto_docstring +class DeepseekV3PreTrainedModel(PreTrainedModel): + config: DeepseekV3Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV3DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = ( + is_grouped_mm_available() + ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": DeepseekV3DecoderLayer, + "attentions": DeepseekV3Attention, + } + _keep_in_fp32_modules_strict = ["e_score_correction_bias"] + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"] + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, DeepseekV3TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + init.zeros_(module.e_score_correction_bias) + elif isinstance(module, DeepseekV3NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class DeepseekV3Model(DeepseekV3PreTrainedModel): + def __init__(self, config: DeepseekV3Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # ======================================================================== + # ROPE COMPONENT 4: RoPE Instantiation in Model + # ======================================================================== + self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) # ROPE: Create rotary embedding instance + # ======================================================================== + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @merge_with_config_defaults + @capture_outputs + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + cache_position: torch.LongTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + # ROPE: Generate position embeddings (cos/sin) for all positions + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_embeddings=position_embeddings, # ROPE: Pass to each layer + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + ) + + +@auto_docstring +class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} + _tp_plan = {"lm_head": "colwise_gather_output"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV3Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> CausalLMOutputWithPast: + r""" + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM + + >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, DeepseekV3PreTrainedModel): + pass + + +class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3PreTrainedModel): + pass + + +__all__ = [ + "DeepseekV3PreTrainedModel", + "DeepseekV3Model", + "DeepseekV3ForCausalLM", + "DeepseekV3ForSequenceClassification", + "DeepseekV3ForTokenClassification", +] \ No newline at end of file diff --git a/tests/test_multihead_latent_attention.py b/tests/test_multihead_latent_attention.py index 5d040f4d09..43eb606818 100644 --- a/tests/test_multihead_latent_attention.py +++ b/tests/test_multihead_latent_attention.py @@ -105,7 +105,7 @@ def test_multihead_latent_attention_litgpt_vs_hf(batch_size, seq_len, device): qk_rope_head_dim=8, qk_nope_head_dim=8, v_head_dim=16, - rope_interleave=False, + rope_interleave=True, ) mla_litgpt = MultiheadLatentAttention(config_litgpt, block_idx=0).to(device) From 078f939659e665770a361ddc5ed7236ab4444146 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 14 Feb 2026 22:07:21 -0500 Subject: [PATCH 15/51] test: for deepseek v3 block --- tests/test_multihead_latent_attention.py | 99 ++++++++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/tests/test_multihead_latent_attention.py b/tests/test_multihead_latent_attention.py index 43eb606818..1deb3bcf41 100644 --- a/tests/test_multihead_latent_attention.py +++ b/tests/test_multihead_latent_attention.py @@ -148,3 +148,102 @@ def sync_weights(litgpt_model, hf_model): hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight) hf_model.o_proj.weight.copy_(litgpt_model.proj.weight) print("Synchronization complete.") + + +@torch.inference_mode() +@pytest.mark.parametrize("batch_size", (1, 2)) +@pytest.mark.parametrize("seq_len", (8, 16)) +@pytest.mark.parametrize("device", [torch.device("cpu")]) +def test_deepseek_v3_block(batch_size, seq_len, device): + """Test DeepSeek V3 block (attention + MLP + norms) litgpt vs hf""" + from litgpt.model import Block + + # Use layer_idx=0 to test dense MLP instead of MoE + layer_idx = 0 + + config_litgpt = Config( + n_embd=64, + n_head=4, + n_query_groups=4, + head_size=16, + norm_eps=1e-6, + norm_class_name="RMSNorm", + bias=False, + parallel_residual=False, + mlp_class_name="LLaMAMLP", + intermediate_size=128, + rope_interleave=True, + latent_attention={ + "q_lora_rank": 32, + "kv_lora_rank": 16, + "qk_rope_head_dim": 8, + "qk_nope_head_dim": 8, + "v_head_dim": 16, + }, + first_k_dense_replace=3, # Use dense MLP for first 3 layers + ) + + config_hf = DeepseekV3Config( + padded_vocab_size=10000, + num_hidden_layers=1, + vocab_size=10000, + hidden_size=64, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=4, + q_lora_rank=32, + kv_lora_rank=16, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=16, + rope_interleave=True, + first_k_dense_replace=3, + rms_norm_eps=1e-6, + ) + + block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device) + model_hf = DeepseekV3ForCausalLM(config_hf).to(device) + block_hf = model_hf.model.layers[layer_idx] + + block_litgpt.eval() + block_hf.eval() + + sync_block_weights(block_litgpt, block_hf) + + hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device) + + # Prepare RoPE sin/cos tables + rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"] + cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) + sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1 + ) + attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) + + # Run forward passes + output_litgpt = block_litgpt(hidden_states, cos, sin) + output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask) + + assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), \ + f"Max diff: {(output_litgpt - output_hf).abs().max()}" + + +def sync_block_weights(block_litgpt, block_hf): + """Synchronize all weights from LitGPT block to HF block.""" + print("Synchronizing block weights...") + with torch.no_grad(): + # Sync attention weights + sync_weights(block_litgpt.attn, block_hf.self_attn) + + # Sync MLP weights (assumes dense MLP, not MoE) + block_hf.mlp.gate_proj.weight.copy_(block_litgpt.mlp.fc_1.weight) + block_hf.mlp.up_proj.weight.copy_(block_litgpt.mlp.fc_2.weight) + block_hf.mlp.down_proj.weight.copy_(block_litgpt.mlp.proj.weight) + + # Sync normalization layers + block_hf.input_layernorm.weight.copy_(block_litgpt.norm_1.weight) + block_hf.post_attention_layernorm.weight.copy_(block_litgpt.norm_2.weight) + + print("Block synchronization complete.") From 8d4f5f704e7e8c97ef056514ef80517f5951ba1f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 03:07:49 +0000 Subject: [PATCH 16/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/configuration_deepseek_v3.py | 3 +-- litgpt/model.py | 8 ++++---- litgpt/modeling_deepseek_v3.py | 6 +++++- tests/test_multihead_latent_attention.py | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/litgpt/configuration_deepseek_v3.py b/litgpt/configuration_deepseek_v3.py index c09f3dfe2c..ea56e5a88d 100644 --- a/litgpt/configuration_deepseek_v3.py +++ b/litgpt/configuration_deepseek_v3.py @@ -18,7 +18,6 @@ from ...configuration_utils import PreTrainedConfig from ...modeling_rope_utils import RopeParameters - DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} @@ -246,4 +245,4 @@ def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None return kwargs -__all__ = ["DeepseekV3Config"] \ No newline at end of file +__all__ = ["DeepseekV3Config"] diff --git a/litgpt/model.py b/litgpt/model.py index 2cf8aa8217..c58b655b0b 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -910,7 +910,7 @@ def build_rope_cache( ratio = orig_context_len / wavelen smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) - + # Compute adjusted_theta without masked indexing adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta theta = adjusted_theta @@ -1070,20 +1070,20 @@ def apply_rope_interleave(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) # Rearrange tensor to group even/odd indices: [x0,x1,x2,x3,...] -> [x0,x2,x4,...,x1,x3,x5,...] *batch_dims, d = x.shape x = x.view(*batch_dims, d // 2, 2).transpose(-1, -2).reshape(*batch_dims, d) - + # Standard rotation logic (same as apply_rope) head_size_half = x.size(-1) // 2 x1 = x[..., :head_size_half] x2 = x[..., head_size_half:] rotated = torch.cat((-x2, x1), dim=-1) - + # Auto-detect dimension mismatch and reshape cos/sin dims_diff = x.dim() - cos.dim() if dims_diff > 0: new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:] cos = cos.view(*new_shape) sin = sin.view(*new_shape) - + roped = (x * cos) + (rotated * sin) return roped.to(dtype=x.dtype) diff --git a/litgpt/modeling_deepseek_v3.py b/litgpt/modeling_deepseek_v3.py index c7b036ced1..d4151f39c4 100644 --- a/litgpt/modeling_deepseek_v3.py +++ b/litgpt/modeling_deepseek_v3.py @@ -124,6 +124,8 @@ def forward(self, x, position_ids): sin = emb.sin() * self.attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + # ============================================================================ # END ROPE COMPONENT 1 # ============================================================================ @@ -376,6 +378,8 @@ def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 + + # ============================================================================ # END ROPE COMPONENT 2 # ============================================================================ @@ -769,4 +773,4 @@ class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3 "DeepseekV3ForCausalLM", "DeepseekV3ForSequenceClassification", "DeepseekV3ForTokenClassification", -] \ No newline at end of file +] diff --git a/tests/test_multihead_latent_attention.py b/tests/test_multihead_latent_attention.py index 1deb3bcf41..d45c6d650a 100644 --- a/tests/test_multihead_latent_attention.py +++ b/tests/test_multihead_latent_attention.py @@ -226,8 +226,9 @@ def test_deepseek_v3_block(batch_size, seq_len, device): output_litgpt = block_litgpt(hidden_states, cos, sin) output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask) - assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), \ + assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), ( f"Max diff: {(output_litgpt - output_hf).abs().max()}" + ) def sync_block_weights(block_litgpt, block_hf): From abbfa2b5671b33776520ba888bdd241ecefad14c Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 03:56:36 +0000 Subject: [PATCH 17/51] milestone: deepseek v3 assertClose now at 96% --- tests/test_model_deepseek_v3.py | 12 ++++++------ tests/test_multihead_latent_attention.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 7ef7822325..18fc6003e4 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -40,9 +40,9 @@ def test_against_original_deepseek_v3(model_name, device, dtype): model_name, block_size=T, n_layer=2, - n_head=16, - n_embd=32, - n_query_groups=16, + n_head=8, # Reduced to make n_embd work out + n_embd=32, # 8 heads * 4 head_dim (qk_rope=4 + qk_nope=8) + n_query_groups=8, intermediate_size=86, moe_intermediate_size=20, n_expert=4, @@ -55,9 +55,9 @@ def test_against_original_deepseek_v3(model_name, device, dtype): latent_attention=dict( q_lora_rank=16, kv_lora_rank=16, - qk_rope_head_dim=8, - qk_nope_head_dim=8, - v_head_dim=16, + qk_rope_head_dim=4, # Maintain 1:2 ratio with qk_nope_head_dim + qk_nope_head_dim=8, # 2x qk_rope_head_dim (matching DeepSeek V3 architecture) + v_head_dim=8, # Same as qk_nope_head_dim ), ) theirs_config = DeepseekV3Config( diff --git a/tests/test_multihead_latent_attention.py b/tests/test_multihead_latent_attention.py index d45c6d650a..e7a299fb48 100644 --- a/tests/test_multihead_latent_attention.py +++ b/tests/test_multihead_latent_attention.py @@ -91,6 +91,7 @@ def test_multihead_latent_attention_litgpt_vs_hf(batch_size, seq_len, device): "qk_nope_head_dim": 8, "v_head_dim": 16, }, + rope_interleave=True, ) config_hf = DeepseekV3Config( @@ -170,7 +171,7 @@ def test_deepseek_v3_block(batch_size, seq_len, device): norm_class_name="RMSNorm", bias=False, parallel_residual=False, - mlp_class_name="LLaMAMLP", + mlp_class_name="LLaMAMoE", intermediate_size=128, rope_interleave=True, latent_attention={ From 243dd0dbd45cf5b0803e0908d1b31253f2bf7237 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 04:18:01 +0000 Subject: [PATCH 18/51] deepseek v3 passes completely without rope scaling --- tests/test_model_deepseek_v3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 18fc6003e4..99f942d87f 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -59,6 +59,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): qk_nope_head_dim=8, # 2x qk_rope_head_dim (matching DeepSeek V3 architecture) v_head_dim=8, # Same as qk_nope_head_dim ), + rope_adjustments=None ) theirs_config = DeepseekV3Config( vocab_size=ours_config.padded_vocab_size, From 34368bb6bcc174b402f135837463b30eb0825d57 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 04:18:15 +0000 Subject: [PATCH 19/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model_deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 99f942d87f..ff36a60dae 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -59,7 +59,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): qk_nope_head_dim=8, # 2x qk_rope_head_dim (matching DeepSeek V3 architecture) v_head_dim=8, # Same as qk_nope_head_dim ), - rope_adjustments=None + rope_adjustments=None, ) theirs_config = DeepseekV3Config( vocab_size=ours_config.padded_vocab_size, From 84a72ff134da46c9ac27ed69c24c3f5094bdc962 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 00:01:03 -0500 Subject: [PATCH 20/51] mscale --- litgpt/model.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/litgpt/model.py b/litgpt/model.py index c58b655b0b..9e905a6ff0 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -378,6 +378,13 @@ def __init__(self, config: Config, block_idx: int) -> None: else: self.norm_q = self.norm_k = None + mscale_all_dim = config.rope_adjustments.get("mscale_all_dim",None) + scaling_factor = config.rope_adjustments.get("factor", None) + if mscale_all_dim and scaling_factor: #YaRN + self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + else: + self.mscale = 1.0 + self.config = config self.block_idx = block_idx @@ -531,6 +538,7 @@ def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) + scale = scale * self.mscale * self.mscale # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: @@ -622,6 +630,13 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None + mscale_all_dim = config.rope_adjustments.get("mscale_all_dim",None) + scaling_factor = config.rope_adjustments.get("factor", None) + if mscale_all_dim and scaling_factor: #YaRN + self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + else: + self.mscale = 1.0 + self.config = config self.block_idx = block_idx @@ -707,6 +722,7 @@ def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.qk_head_dim) + scale = scale * self.mscale * self.mscale # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: @@ -871,6 +887,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return topk_weights, topk_indices +# ROPE: YaRN (Yet another RoPE extensioN) scaling function for extended context +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + def build_rope_cache( seq_len: int, n_elem: int, From d921c6baf88533e9b7eac6ccb869d46a7ab868f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 05:01:21 +0000 Subject: [PATCH 21/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 9e905a6ff0..0e3bcacfee 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -378,9 +378,9 @@ def __init__(self, config: Config, block_idx: int) -> None: else: self.norm_q = self.norm_k = None - mscale_all_dim = config.rope_adjustments.get("mscale_all_dim",None) + mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None) scaling_factor = config.rope_adjustments.get("factor", None) - if mscale_all_dim and scaling_factor: #YaRN + if mscale_all_dim and scaling_factor: # YaRN self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) else: self.mscale = 1.0 @@ -630,9 +630,9 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None - mscale_all_dim = config.rope_adjustments.get("mscale_all_dim",None) + mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None) scaling_factor = config.rope_adjustments.get("factor", None) - if mscale_all_dim and scaling_factor: #YaRN + if mscale_all_dim and scaling_factor: # YaRN self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) else: self.mscale = 1.0 From d92c421d4f0e61c814c519d4b09782b1f1c36254 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 00:09:59 -0500 Subject: [PATCH 22/51] fix --- litgpt/config.py | 4 +++- litgpt/model.py | 22 ++++++++++++++-------- tests/test_model_deepseek_v3.py | 16 +++++++++++++++- 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/litgpt/config.py b/litgpt/config.py index d4af544a51..a80daea4fb 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3222,7 +3222,9 @@ def check_indicator_and_length( first_k_dense_replace=3, norm_topk_prob=True, routed_scaling_factor=2.5, - rope_adjustments=dict(factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096), + rope_adjustments=dict( + factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096, mscale_all_dim=1.0 + ), ), ] diff --git a/litgpt/model.py b/litgpt/model.py index 9e905a6ff0..90d3146c26 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -378,10 +378,13 @@ def __init__(self, config: Config, block_idx: int) -> None: else: self.norm_q = self.norm_k = None - mscale_all_dim = config.rope_adjustments.get("mscale_all_dim",None) - scaling_factor = config.rope_adjustments.get("factor", None) - if mscale_all_dim and scaling_factor: #YaRN - self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + if config.rope_adjustments is not None: + mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None) + scaling_factor = config.rope_adjustments.get("factor", None) + if mscale_all_dim and scaling_factor: # YaRN + self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + else: + self.mscale = 1.0 else: self.mscale = 1.0 @@ -630,10 +633,13 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None - mscale_all_dim = config.rope_adjustments.get("mscale_all_dim",None) - scaling_factor = config.rope_adjustments.get("factor", None) - if mscale_all_dim and scaling_factor: #YaRN - self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + if config.rope_adjustments is not None: + mscale_all_dim = config.rope_adjustments.get("mscale_all_dim", None) + scaling_factor = config.rope_adjustments.get("factor", None) + if mscale_all_dim and scaling_factor: # YaRN + self.mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + else: + self.mscale = 1.0 else: self.mscale = 1.0 diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index ff36a60dae..29c6a1b73e 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -59,7 +59,9 @@ def test_against_original_deepseek_v3(model_name, device, dtype): qk_nope_head_dim=8, # 2x qk_rope_head_dim (matching DeepSeek V3 architecture) v_head_dim=8, # Same as qk_nope_head_dim ), - rope_adjustments=None, + rope_adjustments=dict( + factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096, mscale_all_dim=1.0 + ), ) theirs_config = DeepseekV3Config( vocab_size=ours_config.padded_vocab_size, @@ -87,6 +89,18 @@ def test_against_original_deepseek_v3(model_name, device, dtype): v_head_dim=ours_config.latent_attention["v_head_dim"], q_lora_rank=ours_config.latent_attention["q_lora_rank"], kv_lora_rank=ours_config.latent_attention["kv_lora_rank"], + rope_parameters=( + { + "type": "yarn", + "factor": ours_config.rope_adjustments["factor"], + "beta_slow": ours_config.rope_adjustments["low_freq_factor"], + "beta_fast": ours_config.rope_adjustments["high_freq_factor"], + "original_max_position_embeddings": ours_config.rope_adjustments["original_max_seq_len"], + "mscale_all_dim": 1.0, + } + if ours_config.rope_adjustments + else None + ), ) theirs_model = DeepseekV3ForCausalLM(theirs_config).to(device) From 4b2db6459e757f91418ca9a1bb71ac707ce2f295 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 14:34:30 -0500 Subject: [PATCH 23/51] yarn rope --- litgpt/model.py | 129 ++++- litgpt/modeling_rope_util.py | 945 +++++++++++++++++++++++++++++++++++ 2 files changed, 1057 insertions(+), 17 deletions(-) create mode 100644 litgpt/modeling_rope_util.py diff --git a/litgpt/model.py b/litgpt/model.py index 90d3146c26..0ad8e739b7 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -189,28 +189,64 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso extra_config = None else: - adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"] - params_present = [param in self.config.rope_adjustments for param in adjusted_params_required] - num_params_present = sum(params_present) + # Check for mutually exclusive parameter sets + llama3_params = ["low_freq_factor", "high_freq_factor"] + yarn_params = ["beta_fast", "beta_slow"] - if num_params_present == 0: - extra_config = None # uses standard RoPE - elif num_params_present == 4: - # These parameters should always be used together so that we don't interfere with standard rope - extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required} + has_llama3 = any(param in self.config.rope_adjustments for param in llama3_params) + has_yarn = any(param in self.config.rope_adjustments for param in yarn_params) + + if has_llama3 and has_yarn: + raise ValueError( + "RoPE adjustments cannot contain both Llama3 parameters (low_freq_factor, high_freq_factor) " + "and YaRN parameters (beta_fast, beta_slow). These are mutually exclusive." + ) + + # Llama3-style RoPE + if has_llama3: + adjusted_params_required = ["factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len"] + params_present = [param in self.config.rope_adjustments for param in adjusted_params_required] + if all(params_present): + extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required} + else: + missing_params = [ + param for param, present in zip(adjusted_params_required, params_present) if not present + ] + raise ValueError( + f"The following Llama3 RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. " + "All Llama3 parameters must be specified together." + ) + + # YaRN-style RoPE + elif has_yarn: + # Required: factor, beta_fast, beta_slow, original_max_seq_len + # Optional: mscale, mscale_all_dim + yarn_required_params = ["factor", "beta_fast", "beta_slow", "original_max_seq_len"] + params_present = [param in self.config.rope_adjustments for param in yarn_required_params] + + if not all(params_present): + missing_params = [ + param for param, present in zip(yarn_required_params, params_present) if not present + ] + raise ValueError( + f"The following YaRN RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. " + "All YaRN required parameters must be specified together." + ) + + extra_config = {name: self.config.rope_adjustments[name] for name in yarn_required_params} + + # Add optional YaRN parameters + for param in ["mscale", "mscale_all_dim"]: + if param in self.config.rope_adjustments: + extra_config[param] = self.config.rope_adjustments[param] + + # Linear or standard RoPE elif "factor" in self.config.rope_adjustments: # linear RoPE adjusted_params_required = ["factor"] extra_config = {name: self.config.rope_adjustments[name] for name in adjusted_params_required} else: - # Some but not all parameters are specified; raise an error - missing_params = [ - param for param, present in zip(adjusted_params_required, params_present) if not present - ] - raise ValueError( - f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. " - "All adjusted RoPE parameters must be specified together." - ) + extra_config = None # uses standard RoPE return build_rope_cache( seq_len=self.max_seq_length, @@ -928,6 +964,9 @@ def build_rope_cache( # Compute the inverse frequencies theta theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem)) + # Initialize attention scaling factor (modified for YaRN) + attention_scaling = 1.0 + if extra_config is not None: factor = extra_config["factor"] if "original_max_seq_len" in extra_config: @@ -943,7 +982,61 @@ def build_rope_cache( # Compute adjusted_theta without masked indexing adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta theta = adjusted_theta + elif "beta_fast" in extra_config: + # YaRN-style RoPE scaling + beta_fast = extra_config["beta_fast"] + beta_slow = extra_config["beta_slow"] + original_max_seq_len = extra_config["original_max_seq_len"] + + # Calculate attention scaling factor based on mscale and mscale_all_dim + mscale = extra_config.get("mscale") + mscale_all_dim = extra_config.get("mscale_all_dim") + if mscale and mscale_all_dim: + attention_scaling = yarn_get_mscale(factor, mscale) / yarn_get_mscale(factor, mscale_all_dim) + elif mscale_all_dim: + attention_scaling = yarn_get_mscale(factor, mscale_all_dim) + elif mscale: + attention_scaling = yarn_get_mscale(factor, mscale) + # else: attention_scaling remains 1.0 + + # Create two frequency sets: extrapolation (unscaled) and interpolation (scaled) + pos_freqs = base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem) + theta_extrapolation = 1.0 / pos_freqs + theta_interpolation = 1.0 / (factor * pos_freqs) + + # Find correction range based on rotation counts + # Inverse dimension formula to find dimension based on number of rotations + def find_correction_dim(num_rotations, dim, base_val, max_pos): + return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base_val)) + + low_dim = find_correction_dim(beta_fast, n_elem, base, original_max_seq_len) + high_dim = find_correction_dim(beta_slow, n_elem, base, original_max_seq_len) + + # Apply truncation if specified + if extra_config.get("truncate", True): + low_dim = math.floor(low_dim) + high_dim = math.ceil(high_dim) + + low_dim = max(low_dim, 0) + high_dim = min(high_dim, n_elem // 2 - 1) + + # Create linear ramp factor for blending + dim_range = torch.arange(n_elem // 2, device=device, dtype=torch.float32) + if low_dim == high_dim: + high_dim += 0.001 # Prevent singularity + + linear_func = (dim_range - low_dim) / (high_dim - low_dim) + ramp_func = torch.clamp(linear_func, 0.0, 1.0) + + # Blend extrapolation and interpolation frequencies + # ramp_func = 0 -> use interpolation (scaled), ramp_func = 1 -> use extrapolation (unscaled) + theta_extrapolation_factor = ramp_func + theta = ( + theta_interpolation * (1 - theta_extrapolation_factor) + + theta_extrapolation * theta_extrapolation_factor + ) else: + # Linear scaling fallback theta = theta / factor # Create position indices `[0, 1, ..., seq_len - 1]` @@ -972,7 +1065,9 @@ def build_rope_cache( idx_theta = torch.stack((idx_theta, local_idx_theta), dim=-1) - return torch.cos(idx_theta), torch.sin(idx_theta) + cos = torch.cos(idx_theta) * attention_scaling + sin = torch.sin(idx_theta) * attention_scaling + return cos, sin def batched_index_select(t, dim, idx): diff --git a/litgpt/modeling_rope_util.py b/litgpt/modeling_rope_util.py new file mode 100644 index 0000000000..48403e23b3 --- /dev/null +++ b/litgpt/modeling_rope_util.py @@ -0,0 +1,945 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from functools import wraps +from typing import TYPE_CHECKING, Optional, TypedDict + +from .utils import is_torch_available, logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + +if TYPE_CHECKING: + from .configuration_utils import PreTrainedConfig + + +def dynamic_rope_update(rope_forward): + """ + Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE + (i.e. a RoPE implementation that may recompute its frequencies in the forward pass). + + Args: + rope_forward (Callable): + The forward pass of the RoPE implementation. + + Returns: + The decorated forward pass. + """ + + def longrope_frequency_update(self, position_ids, device, layer_type=None): + """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" + seq_len = torch.max(position_ids) + 1 + + if layer_type is None: + rope_type = self.rope_type + original_inv_freq = self.original_inv_freq + prefix = "" + original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"] + else: + rope_type = self.rope_type[layer_type] + original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") + prefix = f"{layer_type}_" + original_max_position_embeddings = self.config.rope_parameters[layer_type][ + "original_max_position_embeddings" + ] + + if seq_len > original_max_position_embeddings: + if not hasattr(self, f"{layer_type}_long_inv_freq"): + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + long_inv_freq, _ = rope_init_fn( + self.config, + device, + seq_len=original_max_position_embeddings + 1, + layer_type=layer_type, + ) + self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False) + setattr(self, f"{prefix}long_inv_freq", long_inv_freq) + else: + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + original_inv_freq = original_inv_freq.to(device) + self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) + setattr(self, f"{prefix}original_inv_freq", original_inv_freq) + + def dynamic_frequency_update(self, position_ids, device, layer_type=None): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if layer_type is None: + rope_type = self.rope_type + max_seq_len_cached = self.max_seq_len_cached + original_inv_freq = self.original_inv_freq + prefix = "" + else: + rope_type = self.rope_type[layer_type] + max_seq_len_cached = getattr(self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached) + original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") + prefix = f"{layer_type}_" + + if seq_len > max_seq_len_cached: # growth + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] + inv_freq, self.attention_scaling = rope_init_fn( + self.config, + device, + seq_len=seq_len, + layer_type=layer_type, + ) + # TODO joao: may break with compilation + self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False) + setattr(self, f"{layer_type}_max_seq_len_cached", seq_len) + + if seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len: # reset + # This .to() is needed if the model has been moved to a device after being initialized (because + # the buffer is automatically moved, but not the original copy) + original_inv_freq = original_inv_freq.to(device) + self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) + setattr(self, f"{prefix}original_inv_freq", original_inv_freq) + setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len) + + @wraps(rope_forward) + def wrapper(self, x, position_ids, layer_type=None): + rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type] + kwargs = {"layer_type": layer_type} if layer_type is not None else {} + if "dynamic" in rope_type: + dynamic_frequency_update(self, position_ids, device=x.device, **kwargs) + elif rope_type == "longrope": + longrope_frequency_update(self, position_ids, device=x.device, **kwargs) + return rope_forward(self, x, position_ids, **kwargs) + + return wrapper + + +def _compute_linear_scaling_rope_parameters( + config: Optional["PreTrainedConfig"] = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev + Args: + config ([`~transformers."PreTrainedConfig"`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + config.standardize_rope_params() + rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + factor = rope_parameters_dict["factor"] + + # Gets the default RoPE parameters + base = rope_parameters_dict["rope_theta"] + partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + + # Then applies linear scaling to the frequencies. + # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so + # applying scaling to the inverse frequencies is equivalent. + inv_freq /= factor + return inv_freq, attention_factor + + +def _compute_dynamic_ntk_parameters( + config: Optional["PreTrainedConfig"] = None, + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla + + Args: + config ([`~transformers."PreTrainedConfig"`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at + inference time + * rope_parameters (`dict[str, float]`): The standard RoPE scaling parameters, from which `factor` + will be accessed. The value of `factor` is used to determine the new base frequency, along with the + current sequence length (seq_len), the maximum positional embeddings (max_position_embeddings), and the + computed dimensionality (dim) of the rotary embeddings. If seq_len <= max_position_embeddings, this + factor has no effect. If seq_len <= max_position_embeddings, this factor effectively stretches the + context window using an exponent derived from `dim`. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length, used to update the dynamic RoPE at inference time. If `None` or shorter than + max_position_embeddings, this value will be overridden by max_position_embeddings. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + config.standardize_rope_params() + rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + + base = rope_parameters_dict["rope_theta"] + partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + factor = rope_parameters_dict["factor"] + attention_factor = 1.0 # Unused in this type of RoPE + + # seq_len: default to max_position_embeddings, e.g. at init time + if seq_len is None: + seq_len = config.max_position_embeddings + elif isinstance(seq_len, torch.Tensor): + seq_len = torch.maximum( + seq_len, + torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), + ) + else: + seq_len = max(seq_len, config.max_position_embeddings) + + # Compute the inverse frequencies + base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + return inv_freq, attention_factor + + +def _compute_yarn_parameters( + config: "PreTrainedConfig", + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + + Args: + config ([`~transformers."PreTrainedConfig"`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * max_position_embeddings (`int`): The maximum length of the positional embeddings. + * rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following + keys will be accessed: + * `attention_factor` (`float`, *optional*): The scaling factor to be applied to the computed cos/sin. + If None, the value is inferred from `factor`, `mscale`, and `mscale_all_dim` as available. + * `beta_fast` (`float`, *optional*, defaults to 32): Parameter to set the boundary for extrapolation + (only) in the linear ramp function. + * `beta_slow` (`float`, *optional*, defaults to 1): Parameter to set the boundary for interpolation + (only) in the linear ramp function. + * `factor` (`float`, *optional*): The scaling factor applied when interpolating the position IDs to + extend the possible context length. Additionally, if `attention_factor` is None, the log of this + value is used to compute a value for `attention_factor`, possibly in conjunciton with `mscale` and + `mscale_all_dim`, if provided. + * `mscale` (`float`, *optional*): If `attention_factor` is None and both `mscale` and + `mscale_all_dim` are provided, `mscale` acts scalar augmenting `log(factor)` when computing the + numerator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be + calculated based on `factor` only. + * `mscale_all_dim` (`float`, *optional*): If `attention_factor` is None and both `mscale` and + `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing + the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor` + will be calculated based on `factor` only. + * `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining. + * `truncate` (`bool`, *optional*): Whether to truncate the correction range. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies + will be returned for the first fraction of the head_dim. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + config.standardize_rope_params() + rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + + base = rope_parameters_dict["rope_theta"] + partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + factor = rope_parameters_dict["factor"] + attention_factor = rope_parameters_dict.get("attention_factor") + mscale = rope_parameters_dict.get("mscale") + mscale_all_dim = rope_parameters_dict.get("mscale_all_dim") + original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"] + + # NOTE: DeekSeek-V3 (and potentially other models) have `original_max_position_embeddings` field + # containing the pretrained value. They use the ratio between `max_position_embeddings` and this value + # to compute the default attention scaling factor, instead of using `factor`. + if factor is None: + factor = config.max_position_embeddings / original_max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = rope_parameters_dict.get("beta_fast") or 32 + beta_slow = rope_parameters_dict.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate): + """Find dimension range bounds based on rotations""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # to expand the possible context length. In other words, interpolation = apply scaling factor. + pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + truncate = config.rope_parameters.get("truncate", True) + low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate) + + # Get n-dimensional rotational scaling corrected for extrapolation + inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + + inv_freq_extrapolation * inv_freq_extrapolation_factor + ) + return inv_freq, attention_factor + + +def _compute_longrope_parameters( + config: "PreTrainedConfig", + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies with LongRoPE scaling. Please refer to the + [original implementation](https://github.com/microsoft/LongRoPE) + + Args: + config ([`~transformers."PreTrainedConfig"`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * max_position_embeddings (`int`): The maximum length of the positional embeddings. + * original_max_position_embeddings (`int`, *optional*): The original max position embeddings used during + pretraining. If not provided, defaults to `max_position_embeddings`. + * rope_parameters (`dict[str, float]`): The standard RoPE scaling parameters, from which the following keys + will be accessed: + * `attention_factor` (`float`, *optional*): The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, inferred from + the value of `factor`. + * `factor` (`float`, *optional*): The scaling factor to apply to the RoPE embeddings. If both + `max_position_embeddings` and `original_max_position_embeddings` are provided, this value will be + overridden s the ratio between those values. + * `long_factor` (`float`, *optional*): The scale factor applied when computing the inverse + frequencies if `seq_len` is provided and greater than `original_max_position_embeddings`. + * `short_factor` (`float`, *optional*): The scale factor applied when computing the inverse + frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies + will be returned for the first fraction of the head_dim. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. + + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + config.standardize_rope_params() + rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + + base = rope_parameters_dict["rope_theta"] + partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + long_factor = rope_parameters_dict["long_factor"] + short_factor = rope_parameters_dict["short_factor"] + factor = rope_parameters_dict.get("factor") + attention_factor = rope_parameters_dict.get("attention_factor") + original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"] + + # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a + # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two + # values to compute the default attention scaling factor, instead of using `factor`. + if factor is None: + factor = config.max_position_embeddings / original_max_position_embeddings + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if factor <= 1.0: + attention_factor = 1.0 + else: + attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) + + # Compute the inverse frequencies -- scaled based on the target sequence length + if seq_len and seq_len > original_max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + return inv_freq, attention_factor + + +def _compute_llama3_parameters( + config: "PreTrainedConfig", + device: Optional["torch.device"] = None, + seq_len: int | None = None, + layer_type: str | None = None, +) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies for llama 3.1. + + Args: + config ([`~transformers."PreTrainedConfig"`]): + The model configuration. This function assumes that the config will provide at least the following + properties: + + * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. + * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. + * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. + * rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following + keys will be accessed: + * `factor` (`float`, *optional*): The scaling factor applied to the inverse frequencies when 1) the + wavelength is greater than `low_freq_wavelen` prior to smoothing, and 2) to all inverse frequencies + during smoothing. + * `high_freq_factor` (`float`): The scale factor used to compute `high_freq_wavelen` and + the value for the denominator of the smoothing factor prior to the `low_freq_factor` shift. + * `low_freq_factor` (`float`): The scale factor used to compute `low_freq_wavelen` and + the shift applied to the numerator and denominator of the smoothing factor. + frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`. + * `original_max_position_embeddings` (`int`): The original max position embeddings used + during pretraining. If not provided, the function falls back to `max_position_embeddings`. + + Additionally, this function will make use of the following properties if they are found in the config: + + * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be + derived as hidden_size // num_attention_heads. + * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for + the first fraction of the head_dim. Defaults to 1.0. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin. + """ + # For backward compatibility standardize the `rope_parameters_dict` if it uses old format + config.standardize_rope_params() + rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters + + # Gets the default RoPE parameters + base = rope_parameters_dict["rope_theta"] + partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) + head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + dim = int(head_dim * partial_rotary_factor) + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) + + factor = rope_parameters_dict["factor"] # `8` in the original implementation + low_freq_factor = rope_parameters_dict["low_freq_factor"] # `1` in the original implementation + high_freq_factor = rope_parameters_dict["high_freq_factor"] # `4` in the original implementation + old_context_len = rope_parameters_dict["original_max_position_embeddings"] # `8192` in the original implementation + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by factor + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + # otherwise: interpolate between the two, using a smooth factor + smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + + return inv_freq_llama, attention_factor + + +# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters +# from the model config. You can append new {'rope_type': callable} pairs to this rope_parameters to enable custom RoPE +# parameterizations, as long as the callable has the same signature. +ROPE_INIT_FUNCTIONS = { + "linear": _compute_linear_scaling_rope_parameters, + "dynamic": _compute_dynamic_ntk_parameters, + "yarn": _compute_yarn_parameters, + "longrope": _compute_longrope_parameters, + "llama3": _compute_llama3_parameters, +} + + +class RopeParameters(TypedDict, total=False): + """ + Args: + rope_theta (`float`): + The base period of the RoPE embeddings. + rope_type (`str`, *optional*, defaults to "default"): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + partial_rotary_factor (`float`, *optional*): + The percentage of the query and key head embedding on which RoPE will be applied. + factor (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + original_max_position_embeddings (`int`, *optional*): + Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + attention_factor (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + beta_fast (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + beta_slow (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + short_factor (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + long_factor (`list[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + low_freq_factor (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + high_freq_factor (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + """ + + rope_theta: float + rope_type: str | None + partial_rotary_factor: float | None + factor: float | None + original_max_position_embeddings: int | None + attention_factor: float | None + beta_fast: float | None + beta_slow: float | None + short_factor: list[float] | None + long_factor: list[float] | None + low_freq_factor: float | None + high_freq_factor: float | None + + +class RotaryEmbeddingConfigMixin: + """ + A Mixin containing the functionality to standardize and validate RoPE parameters. + """ + + default_theta = 10_000.0 + + def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + # Standardize and validate the correctness of rotary position embeddings parameters. Priority for these parameters is: + # 1. Values in `rope_parameters` dict (where they should be after standardization) + # 2. Values in `kwargs` (i.e. it's in config.json but not MyConfig.__init__'s args) + # 3. Values in the config's attributes (i.e. it's in MyConfig.__init__'s args) + # 4. Default values (i.e. not present at all but other RoPE parameters are present) + rope_theta = kwargs.pop("rope_theta", getattr(self, "rope_theta", self.default_theta)) + self.rope_parameters.setdefault("rope_theta", rope_theta) + + partial_rotary_factor = kwargs.get("partial_rotary_factor", getattr(self, "partial_rotary_factor", None)) + if partial_rotary_factor is not None: + self.rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) + ignore_keys_at_rope_validation = ( + set() if ignore_keys_at_rope_validation is None else ignore_keys_at_rope_validation + ) + ignore_keys_at_rope_validation = ignore_keys_at_rope_validation | {"partial_rotary_factor"} + + self.standardize_rope_params() + self.validate_rope(ignore_keys=ignore_keys_at_rope_validation) + return kwargs + + def standardize_rope_params(self): + """ + Helper to standardize the config's rope params field by ensuring the params are defined for each + later type. For old model the fn will duplicate a single rope param in each layer type (backward compatibility) + """ + # Move `rope_theta` and `partial_rotary_factor` to the `rope_parameters`, if not there yet + rope_theta = getattr(self, "rope_theta", None) + partial_rotary_factor = getattr(self, "partial_rotary_factor", None) + rope_parameters = getattr(self, "rope_parameters", None) or {} + layer_types = getattr(self, "layer_types", None) + + # Case 0: no RoPE params defined + if not (rope_parameters or rope_theta): + # partial_rotary_factor without rope_theta is invalid, so we don't check for it here + logger.warning("`standardize_rope_params` was called but no RoPE parameters were found.") + return + # Case 1: RoPE param keys do not intersect with possible `layer_types` -> one global dict + elif layer_types is None or rope_parameters == {} or not set(rope_parameters.keys()).issubset(layer_types): + rope_parameters.setdefault("rope_type", rope_parameters.get("type", "default")) + rope_parameters.setdefault("rope_theta", rope_theta) + if partial_rotary_factor is not None: + rope_parameters["partial_rotary_factor"] = partial_rotary_factor + + # Move pretraining-time maximum length to rope parameter dict for RoPE types with scaling + if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]: + if hasattr(self, "original_max_position_embeddings"): + # NOTE: Phi3 (and potentially other models) save `original_max_position_embeddings` field + # containing the pretrained value outside rope parameters. This is an exception case where we + # give priority to `self.original_max_position_embeddings + self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings + else: + self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings) + + # Case 2: different RoPE for each layer -> several params as nested dict + else: + for layer_type in set(layer_types): + rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default")) + rope_parameters[layer_type].setdefault("rope_theta", rope_theta) + if partial_rotary_factor is not None: + rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor + + if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]: + self.rope_parameters[layer_type].setdefault( + "original_max_position_embeddings", self.max_position_embeddings + ) + + self.rope_parameters = rope_parameters + + def validate_rope(self: "PreTrainedConfig", ignore_keys: set | None = None): + """ + Validate the RoPE config arguments, given a `"PreTrainedConfig"` object + """ + rope_parameters_dict = self.rope_parameters + if rope_parameters_dict is None: + return + + if getattr(self, "layer_types", None) is not None and set(rope_parameters_dict.keys()).issubset( + self.layer_types + ): + pass + else: + rope_parameters_dict = {"full_attention": rope_parameters_dict} + + for rope_parameters in rope_parameters_dict.values(): + rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default")) + validation_fn = getattr(self, f"_validate_{rope_type}_rope_parameters", None) + rope_parameters["rope_type"] = rope_type + + if validation_fn is not None: + validation_fn(rope_parameters, ignore_keys=ignore_keys) + else: + logger.warning( + f"Missing validation function in 'RotaryEmbeddingConfigMixin' for 'rope_type'='{rope_type}'" + ) + + def _validate_default_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = {"rope_type", "rope_theta"} + received_keys = set(rope_parameters.keys()) + rope_type = rope_parameters["rope_type"] + self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = {"rope_type", "factor", "rope_theta"} + received_keys = set(rope_parameters.keys()) + rope_type = rope_parameters["rope_type"] + self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_parameters["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") + + def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = {"rope_type", "factor"} + received_keys = set(rope_parameters.keys()) + rope_type = rope_parameters["rope_type"] + self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_parameters["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") + + def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"} + optional_keys = { + "attention_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + "truncate", + } + received_keys = set(rope_parameters.keys()) + rope_type = rope_parameters["rope_type"] + self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_parameters["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_parameters.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_parameters.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_parameters`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_parameters.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_parameters`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_parameters`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) + + # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths. + # NOTE: we might get `implicit_factor == 1` if config's `original_max_position_embeddings` was + # inferred from `max_position_embeddings` during standardization + original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"] + implicit_factor = self.max_position_embeddings / original_max_position_embeddings + if implicit_factor != factor and implicit_factor != 1: + logger.warning_once( + f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match " + "the ratio implicitly set by other parameters (implicit factor = " + "post-yarn context length / pre-yarn context length = " + "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = " + f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected " + "behaviour in model usage, please correct the 'original_max_position_embeddings' fields in the model config." + ) + + def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"} + optional_keys = {"attention_factor", "factor"} + received_keys = set(rope_parameters.keys()) + rope_type = rope_parameters["rope_type"] + self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) + head_dim = getattr(self, "head_dim", self.hidden_size // self.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + short_factor = rope_parameters.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_parameters`'s short_factor field must be a list of numbers, got {short_factor}") + if len(short_factor) != dim // 2: + logger.warning( + f"`rope_parameters`'s short_factor field must have length {dim // 2}, got {len(short_factor)}" + ) + + long_factor = rope_parameters.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_parameters`'s long_factor field must be a list of numbers, got {long_factor}") + if len(long_factor) != dim // 2: + logger.warning( + f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}" + ) + + factor = rope_parameters.get("factor") + original_max_position_embeddings = rope_parameters["original_max_position_embeddings"] + + # Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter is undesirable + if factor is None and original_max_position_embeddings is not None: + logger.warning_once( + "This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + elif factor is None and original_max_position_embeddings is None: + logger.warning("Missing required keys in `rope_parameters`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_parameters.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0.0): + logger.warning( + f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + + def _validate_llama3_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): + required_keys = { + "rope_type", + "factor", + "original_max_position_embeddings", + "low_freq_factor", + "high_freq_factor", + "rope_theta", + } + rope_type = rope_parameters["rope_type"] + received_keys = set(rope_parameters.keys()) + self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_parameters["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_parameters["low_freq_factor"] + high_freq_factor = rope_parameters["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_parameters`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_parameters`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor <= low_freq_factor: + logger.warning( + "`rope_parameters`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) + + original_max_position_embeddings = rope_parameters["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_parameters`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= self.max_position_embeddings: + logger.warning( + "`rope_parameters`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={self.max_position_embeddings}" + ) + + @staticmethod + def _check_received_keys( + rope_type: str, + received_keys: set, + required_keys: set, + optional_keys: set | None = None, + ignore_keys: set | None = None, + ): + """Compare the received keys in `config.rope_parameters` against the expected and optional keys""" + # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present + if "type" in received_keys: + received_keys -= {"type"} + required_keys.add("rope_type") + + optional_keys = optional_keys or set() + if "partial_rotary_factor" not in optional_keys: + optional_keys.add("partial_rotary_factor") + + # Some models need to store model-specific keys, and we don't want to throw warning at them + if ignore_keys is not None: + received_keys -= ignore_keys + + missing_keys = required_keys - received_keys + if missing_keys: + raise KeyError(f"Missing required keys in `rope_parameters` for 'rope_type'='{rope_type}': {missing_keys}") + + unused_keys = received_keys - required_keys - optional_keys + if unused_keys: + logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}") + + +def rope_config_validation(config: RotaryEmbeddingConfigMixin, ignore_keys: set | None = None): + """ + This is a deprecated function. + It has been kept for backward compatibility with custom code models. + """ + warnings.warn( + "`rope_config_validation` is deprecated and has been removed. " + "Its functionality has been moved to RotaryEmbeddingConfigMixin.validate_rope method. " + "PreTrainedConfig inherits this class, so please call self.validate_rope() instead. " + "Also, make sure to use the new rope_parameters syntax. " + "You can call self.standardize_rope_params() in the meantime.", + FutureWarning, + ) + config.standardize_rope_params() + config.validate_rope(ignore_keys=ignore_keys) \ No newline at end of file From a4d515d6f5fa1d4d68981e2289b7d08d8cb763ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 19:34:54 +0000 Subject: [PATCH 24/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/modeling_rope_util.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/litgpt/modeling_rope_util.py b/litgpt/modeling_rope_util.py index 48403e23b3..a0e2a8559e 100644 --- a/litgpt/modeling_rope_util.py +++ b/litgpt/modeling_rope_util.py @@ -19,7 +19,6 @@ from .utils import is_torch_available, logging - logger = logging.get_logger(__name__) @@ -829,9 +828,7 @@ def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): logger.warning(f"`rope_parameters`'s long_factor field must be a list of numbers, got {long_factor}") if len(long_factor) != dim // 2: - logger.warning( - f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}" - ) + logger.warning(f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") factor = rope_parameters.get("factor") original_max_position_embeddings = rope_parameters["original_max_position_embeddings"] @@ -942,4 +939,4 @@ def rope_config_validation(config: RotaryEmbeddingConfigMixin, ignore_keys: set FutureWarning, ) config.standardize_rope_params() - config.validate_rope(ignore_keys=ignore_keys) \ No newline at end of file + config.validate_rope(ignore_keys=ignore_keys) From d514423fc2b0441ee1d8ead42198830a74018dfd Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 14:45:02 -0500 Subject: [PATCH 25/51] test: test_yarn (with deepseekv3 block) --- litgpt/config.py | 2 +- tests/test_yarn.py | 144 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 tests/test_yarn.py diff --git a/litgpt/config.py b/litgpt/config.py index a80daea4fb..ac46e51875 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3223,7 +3223,7 @@ def check_indicator_and_length( norm_topk_prob=True, routed_scaling_factor=2.5, rope_adjustments=dict( - factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096, mscale_all_dim=1.0 + factor=40.0, beta_slow=1.0, beta_fast=32.0, original_max_seq_len=4096, mscale=1.0, mscale_all_dim=1.0 ), ), ] diff --git a/tests/test_yarn.py b/tests/test_yarn.py new file mode 100644 index 0000000000..297a822727 --- /dev/null +++ b/tests/test_yarn.py @@ -0,0 +1,144 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import pytest +import torch +from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM + +from litgpt import Config +from litgpt.model import Block + + +@torch.inference_mode() +@pytest.mark.parametrize("batch_size", (1, 2)) +@pytest.mark.parametrize("seq_len", (8, 16)) +@pytest.mark.parametrize("device", [torch.device("cpu")]) +def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): + """Test DeepSeek V3 block (attention + MLP + norms) with YaRN RoPE scaling - litgpt vs hf""" + # Use layer_idx=0 to test dense MLP instead of MoE + layer_idx = 0 + + # YaRN configuration + yarn_config = dict( + factor=8.0, + beta_fast=32.0, + beta_slow=1.0, + original_max_seq_len=4096, + mscale=1.0, + mscale_all_dim=0.8, + ) + + config_litgpt = Config( + n_embd=64, + n_head=4, + n_query_groups=4, + head_size=16, + norm_eps=1e-6, + norm_class_name="RMSNorm", + bias=False, + parallel_residual=False, + mlp_class_name="LLaMAMoE", + intermediate_size=128, + rope_interleave=True, + rope_adjustments=yarn_config, # YaRN config + latent_attention={ + "q_lora_rank": 32, + "kv_lora_rank": 16, + "qk_rope_head_dim": 8, + "qk_nope_head_dim": 8, + "v_head_dim": 16, + }, + first_k_dense_replace=3, # Use dense MLP for first 3 layers + ) + + # HF config with YaRN + rope_parameters = { + "rope_type": "yarn", + "rope_theta": 10000.0, + "factor": yarn_config["factor"], + "beta_fast": yarn_config["beta_fast"], + "beta_slow": yarn_config["beta_slow"], + "original_max_position_embeddings": yarn_config["original_max_seq_len"], + "mscale": yarn_config["mscale"], + "mscale_all_dim": yarn_config["mscale_all_dim"], + } + + config_hf = DeepseekV3Config( + padded_vocab_size=10000, + num_hidden_layers=1, + vocab_size=10000, + hidden_size=64, + intermediate_size=128, + num_attention_heads=4, + num_key_value_heads=4, + q_lora_rank=32, + kv_lora_rank=16, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=16, + rope_interleave=True, + first_k_dense_replace=3, + rms_norm_eps=1e-6, + rope_parameters=rope_parameters, # YaRN config + ) + + block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device) + model_hf = DeepseekV3ForCausalLM(config_hf).to(device) + block_hf = model_hf.model.layers[layer_idx] + + block_litgpt.eval() + block_hf.eval() + + sync_block_weights(block_litgpt, block_hf) + + hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device) + + # Prepare RoPE sin/cos tables + rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"] + cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) + sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) + + causal_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1 + ) + attention_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, 1, -1, -1) + + # Run forward passes + output_litgpt = block_litgpt(hidden_states, cos, sin) + output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask) + + assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), ( + f"Max diff: {(output_litgpt - output_hf).abs().max()}" + ) + + +def sync_weights(litgpt_model, hf_model): + """Copies weights from lit-gpt model to HF model.""" + print("Synchronizing weights...") + with torch.no_grad(): + hf_model.q_a_proj.weight.copy_(litgpt_model.q_a_proj.weight) + hf_model.q_a_layernorm.weight.copy_(litgpt_model.q_a_norm.weight) + hf_model.q_b_proj.weight.copy_(litgpt_model.q_b_proj.weight) + hf_model.kv_a_proj_with_mqa.weight.copy_(litgpt_model.kv_a_proj_with_mqa.weight) + hf_model.kv_a_layernorm.weight.copy_(litgpt_model.kv_a_norm.weight) + hf_model.kv_b_proj.weight.copy_(litgpt_model.kv_b_proj.weight) + hf_model.o_proj.weight.copy_(litgpt_model.proj.weight) + print("Synchronization complete.") + + +def sync_block_weights(block_litgpt, block_hf): + """Synchronize all weights from LitGPT block to HF block.""" + print("Synchronizing block weights...") + with torch.no_grad(): + # Sync attention weights + sync_weights(block_litgpt.attn, block_hf.self_attn) + + # Sync MLP weights (assumes dense MLP, not MoE) + block_hf.mlp.gate_proj.weight.copy_(block_litgpt.mlp.fc_1.weight) + block_hf.mlp.up_proj.weight.copy_(block_litgpt.mlp.fc_2.weight) + block_hf.mlp.down_proj.weight.copy_(block_litgpt.mlp.proj.weight) + + # Sync normalization layers + block_hf.input_layernorm.weight.copy_(block_litgpt.norm_1.weight) + block_hf.post_attention_layernorm.weight.copy_(block_litgpt.norm_2.weight) + + print("Block synchronization complete.") From 0a961592db2b42db8573cf42c76c8a48368095ae Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 15:08:19 -0500 Subject: [PATCH 26/51] test --- litgpt/model.py | 30 ++++++++++++------------ tests/test_yarn.py | 57 ++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 18 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 0ad8e739b7..f9d9857511 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -969,20 +969,8 @@ def build_rope_cache( if extra_config is not None: factor = extra_config["factor"] - if "original_max_seq_len" in extra_config: - orig_context_len = extra_config["original_max_seq_len"] - low_freq_factor = extra_config["low_freq_factor"] - high_freq_factor = extra_config["high_freq_factor"] - - wavelen = 2 * torch.pi / theta - ratio = orig_context_len / wavelen - smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) - smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) - - # Compute adjusted_theta without masked indexing - adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta - theta = adjusted_theta - elif "beta_fast" in extra_config: + # Check YaRN first (has beta_fast/beta_slow) + if "beta_fast" in extra_config or "beta_slow" in extra_config: # YaRN-style RoPE scaling beta_fast = extra_config["beta_fast"] beta_slow = extra_config["beta_slow"] @@ -1035,6 +1023,20 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos): theta_interpolation * (1 - theta_extrapolation_factor) + theta_extrapolation * theta_extrapolation_factor ) + elif "original_max_seq_len" in extra_config: + # Llama3-style RoPE scaling + orig_context_len = extra_config["original_max_seq_len"] + low_freq_factor = extra_config["low_freq_factor"] + high_freq_factor = extra_config["high_freq_factor"] + + wavelen = 2 * torch.pi / theta + ratio = orig_context_len / wavelen + smooth_factor = (ratio - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth_factor = torch.clamp(smooth_factor, min=0.0, max=1.0) + + # Compute adjusted_theta without masked indexing + adjusted_theta = (1 - smooth_factor) * (theta / factor) + smooth_factor * theta + theta = adjusted_theta else: # Linear scaling fallback theta = theta / factor diff --git a/tests/test_yarn.py b/tests/test_yarn.py index 297a822727..8491934ef8 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -92,10 +92,46 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): hidden_states = torch.randn(batch_size, seq_len, config_litgpt.n_embd, device=device) - # Prepare RoPE sin/cos tables + # Prepare RoPE sin/cos tables using YaRN computation + from litgpt.model import build_rope_cache + rope_head_dim = config_litgpt.latent_attention["qk_rope_head_dim"] - cos = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) - sin = torch.randn(batch_size, seq_len, rope_head_dim, device=device, dtype=hidden_states.dtype) + + # Build YaRN RoPE cache for LitGPT + cos_litgpt, sin_litgpt = build_rope_cache( + seq_len=seq_len, + n_elem=rope_head_dim, + device=device, + base=config_litgpt.rope_base, + extra_config={ + "factor": yarn_config["factor"], + "beta_fast": yarn_config["beta_fast"], + "beta_slow": yarn_config["beta_slow"], + "original_max_seq_len": yarn_config["original_max_seq_len"], + "mscale": yarn_config["mscale"], + "mscale_all_dim": yarn_config["mscale_all_dim"], + }, + ) + + # Get YaRN RoPE embeddings from HF + rotary_emb = model_hf.model.layers[layer_idx].self_attn.rotary_emb + position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + cos_hf, sin_hf = rotary_emb(hidden_states, position_ids) + + # Expand dimensions for batch and broadcast + cos_litgpt = cos_litgpt.unsqueeze(0).expand(batch_size, -1, -1) + sin_litgpt = sin_litgpt.unsqueeze(0).expand(batch_size, -1, -1) + + # Compare RoPE embeddings first + print(f"\n=== RoPE Embeddings Comparison ===") + print(f"LitGPT cos/sin shape: {cos_litgpt.shape}, {sin_litgpt.shape}") + print(f"HF cos/sin shape: {cos_hf.shape}, {sin_hf.shape}") + print(f"Cos max diff: {(cos_litgpt - cos_hf).abs().max()}") + print(f"Sin max diff: {(sin_litgpt - sin_hf).abs().max()}") + + # Use the same embeddings for both (LitGPT's) + cos = cos_litgpt + sin = sin_litgpt causal_mask = torch.triu( torch.full((seq_len, seq_len), float("-inf"), device=device, dtype=hidden_states.dtype), diagonal=1 @@ -106,8 +142,21 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): output_litgpt = block_litgpt(hidden_states, cos, sin) output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask) + max_diff = (output_litgpt - output_hf).abs().max() + print(f"\n=== DEBUG INFO ===") + print(f"Max diff: {max_diff}") + print(f"Output litgpt mean: {output_litgpt.mean()}, std: {output_litgpt.std()}") + print(f"Output hf mean: {output_hf.mean()}, std: {output_hf.std()}") + print(f"Cos/sin shape: {cos.shape}, {sin.shape}") + print(f"Hidden states shape: {hidden_states.shape}") + + # Check if the issue is in attention or MLP + if hasattr(output_litgpt, 'shape') and hasattr(output_hf, 'shape'): + if output_litgpt.shape != output_hf.shape: + print(f"Shape mismatch! litgpt: {output_litgpt.shape}, hf: {output_hf.shape}") + assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), ( - f"Max diff: {(output_litgpt - output_hf).abs().max()}" + f"FAILED: Max diff: {max_diff}" ) From 2d91d5be8893ffee03a305bbc74bde96cddbe928 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:09:23 +0000 Subject: [PATCH 27/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_yarn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index 8491934ef8..5b21bb0517 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -123,7 +123,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): sin_litgpt = sin_litgpt.unsqueeze(0).expand(batch_size, -1, -1) # Compare RoPE embeddings first - print(f"\n=== RoPE Embeddings Comparison ===") + print("\n=== RoPE Embeddings Comparison ===") print(f"LitGPT cos/sin shape: {cos_litgpt.shape}, {sin_litgpt.shape}") print(f"HF cos/sin shape: {cos_hf.shape}, {sin_hf.shape}") print(f"Cos max diff: {(cos_litgpt - cos_hf).abs().max()}") @@ -143,7 +143,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask) max_diff = (output_litgpt - output_hf).abs().max() - print(f"\n=== DEBUG INFO ===") + print("\n=== DEBUG INFO ===") print(f"Max diff: {max_diff}") print(f"Output litgpt mean: {output_litgpt.mean()}, std: {output_litgpt.std()}") print(f"Output hf mean: {output_hf.mean()}, std: {output_hf.std()}") @@ -151,13 +151,11 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): print(f"Hidden states shape: {hidden_states.shape}") # Check if the issue is in attention or MLP - if hasattr(output_litgpt, 'shape') and hasattr(output_hf, 'shape'): + if hasattr(output_litgpt, "shape") and hasattr(output_hf, "shape"): if output_litgpt.shape != output_hf.shape: print(f"Shape mismatch! litgpt: {output_litgpt.shape}, hf: {output_hf.shape}") - assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), ( - f"FAILED: Max diff: {max_diff}" - ) + assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), f"FAILED: Max diff: {max_diff}" def sync_weights(litgpt_model, hf_model): From 05902fefae57a4a7a9621cb5353b1a166c2ba400 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 15:11:44 -0500 Subject: [PATCH 28/51] fix --- tests/test_yarn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index 8491934ef8..39f30e4dd7 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -113,8 +113,8 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): }, ) - # Get YaRN RoPE embeddings from HF - rotary_emb = model_hf.model.layers[layer_idx].self_attn.rotary_emb + # Get YaRN RoPE embeddings from HF (rotary_emb is on model level, not layer level) + rotary_emb = model_hf.model.rotary_emb position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) cos_hf, sin_hf = rotary_emb(hidden_states, position_ids) From 73492a5944084ac0062b9f0f448fd342b8d7e16b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 15:16:33 -0500 Subject: [PATCH 29/51] debug --- litgpt/model.py | 4 ++++ tests/test_yarn.py | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/litgpt/model.py b/litgpt/model.py index f9d9857511..6a7f7fc73d 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -1023,6 +1023,10 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos): theta_interpolation * (1 - theta_extrapolation_factor) + theta_extrapolation * theta_extrapolation_factor ) + # Debug: print YaRN theta + if extra_config.get("_debug"): + print(f"[YaRN Debug] theta (inv_freq): {theta}") + print(f"[YaRN Debug] attention_scaling: {attention_scaling}") elif "original_max_seq_len" in extra_config: # Llama3-style RoPE scaling orig_context_len = extra_config["original_max_seq_len"] diff --git a/tests/test_yarn.py b/tests/test_yarn.py index 5007f2bbe2..c943c994d8 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -110,6 +110,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): "original_max_seq_len": yarn_config["original_max_seq_len"], "mscale": yarn_config["mscale"], "mscale_all_dim": yarn_config["mscale_all_dim"], + "_debug": True, # Enable debug prints }, ) @@ -128,6 +129,16 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): print(f"HF cos/sin shape: {cos_hf.shape}, {sin_hf.shape}") print(f"Cos max diff: {(cos_litgpt - cos_hf).abs().max()}") print(f"Sin max diff: {(sin_litgpt - sin_hf).abs().max()}") + print(f"\nLitGPT cos sample [0,0,:4]: {cos_litgpt[0, 0, :4]}") + print(f"HF cos sample [0,0,:4]: {cos_hf[0, 0, :4]}") + print(f"LitGPT cos min/max: {cos_litgpt.min():.4f} / {cos_litgpt.max():.4f}") + print(f"HF cos min/max: {cos_hf.min():.4f} / {cos_hf.max():.4f}") + + # Check inv_freq from both + print(f"\n=== Checking inv_freq ===") + print(f"HF rotary_emb.inv_freq shape: {rotary_emb.inv_freq.shape}") + print(f"HF inv_freq: {rotary_emb.inv_freq}") + print(f"HF attention_scaling: {rotary_emb.attention_scaling}") # Use the same embeddings for both (LitGPT's) cos = cos_litgpt From 7500b9444355ca775f4f570abdeef143ef237abb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:16:52 +0000 Subject: [PATCH 30/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_yarn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index c943c994d8..f069d38ea3 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -135,7 +135,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): print(f"HF cos min/max: {cos_hf.min():.4f} / {cos_hf.max():.4f}") # Check inv_freq from both - print(f"\n=== Checking inv_freq ===") + print("\n=== Checking inv_freq ===") print(f"HF rotary_emb.inv_freq shape: {rotary_emb.inv_freq.shape}") print(f"HF inv_freq: {rotary_emb.inv_freq}") print(f"HF attention_scaling: {rotary_emb.attention_scaling}") From 523d815d6e61e31ced3c9a0443a96d6478ed3e48 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 15:19:40 -0500 Subject: [PATCH 31/51] debug --- tests/test_yarn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index c943c994d8..eee9d1d70c 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -81,6 +81,10 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): rope_parameters=rope_parameters, # YaRN config ) + # Debug: Check if HF config has rope_parameters + print(f"\n=== HF Config Debug ===") + print(f"config_hf.rope_parameters: {config_hf.rope_parameters}") + block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device) model_hf = DeepseekV3ForCausalLM(config_hf).to(device) block_hf = model_hf.model.layers[layer_idx] From 968eea72bac326dd237eb2ae47e2f402d1a9680a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:21:42 +0000 Subject: [PATCH 32/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_yarn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index ebae3fc9ae..71b46fcbf5 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -82,7 +82,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): ) # Debug: Check if HF config has rope_parameters - print(f"\n=== HF Config Debug ===") + print("\n=== HF Config Debug ===") print(f"config_hf.rope_parameters: {config_hf.rope_parameters}") block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device) From 90b7dd488fb5e093900625881ac77bd8aac74a6b Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 20:38:38 +0000 Subject: [PATCH 33/51] fix test_yarn --- tests/test_yarn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index ebae3fc9ae..062848896f 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -52,7 +52,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): # HF config with YaRN rope_parameters = { - "rope_type": "yarn", + "type": "yarn", "rope_theta": 10000.0, "factor": yarn_config["factor"], "beta_fast": yarn_config["beta_fast"], @@ -78,12 +78,12 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): rope_interleave=True, first_k_dense_replace=3, rms_norm_eps=1e-6, - rope_parameters=rope_parameters, # YaRN config + rope_scaling=rope_parameters, # YaRN config ) # Debug: Check if HF config has rope_parameters print(f"\n=== HF Config Debug ===") - print(f"config_hf.rope_parameters: {config_hf.rope_parameters}") + print(f"config_hf.rope_parameters: {config_hf.rope_scaling}") block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device) model_hf = DeepseekV3ForCausalLM(config_hf).to(device) From 7c29e7178d4e9fcc7c8988f23e29650ab7575cff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Feb 2026 20:41:38 +0000 Subject: [PATCH 34/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_yarn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_yarn.py b/tests/test_yarn.py index 062848896f..0fccfedf24 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -82,7 +82,7 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): ) # Debug: Check if HF config has rope_parameters - print(f"\n=== HF Config Debug ===") + print("\n=== HF Config Debug ===") print(f"config_hf.rope_parameters: {config_hf.rope_scaling}") block_litgpt = Block(config_litgpt, block_idx=layer_idx).to(device) From c96b79c8a2b4a44a62193b3c43c2d2594e9b0bc2 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 15 Feb 2026 17:41:35 -0500 Subject: [PATCH 35/51] rm debug --- litgpt/model.py | 4 ---- tests/test_yarn.py | 1 - 2 files changed, 5 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 6a7f7fc73d..f9d9857511 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -1023,10 +1023,6 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos): theta_interpolation * (1 - theta_extrapolation_factor) + theta_extrapolation * theta_extrapolation_factor ) - # Debug: print YaRN theta - if extra_config.get("_debug"): - print(f"[YaRN Debug] theta (inv_freq): {theta}") - print(f"[YaRN Debug] attention_scaling: {attention_scaling}") elif "original_max_seq_len" in extra_config: # Llama3-style RoPE scaling orig_context_len = extra_config["original_max_seq_len"] diff --git a/tests/test_yarn.py b/tests/test_yarn.py index 0fccfedf24..a613d590ca 100644 --- a/tests/test_yarn.py +++ b/tests/test_yarn.py @@ -114,7 +114,6 @@ def test_deepseek_v3_block_with_yarn(batch_size, seq_len, device): "original_max_seq_len": yarn_config["original_max_seq_len"], "mscale": yarn_config["mscale"], "mscale_all_dim": yarn_config["mscale_all_dim"], - "_debug": True, # Enable debug prints }, ) From 943a55e053168360b82e7bcbb827b0ebadda4451 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Fri, 20 Mar 2026 07:38:16 -0700 Subject: [PATCH 36/51] fix deepseekv3 test rope params --- tests/test_model_deepseek_v3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 29c6a1b73e..7cd2f672ec 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -60,7 +60,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): v_head_dim=8, # Same as qk_nope_head_dim ), rope_adjustments=dict( - factor=40.0, low_freq_factor=1.0, high_freq_factor=32.0, original_max_seq_len=4096, mscale_all_dim=1.0 + factor=40.0, beta_slow=1.0, beta_fast=32.0, original_max_seq_len=4096, mscale_all_dim=1.0 ), ) theirs_config = DeepseekV3Config( From d45e69a55bd8df88deedf24b2be1dc1658991376 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Fri, 20 Mar 2026 12:13:18 -0700 Subject: [PATCH 37/51] deepseek v3 support --- litgpt/api.py | 6 ++++ litgpt/generate/sequentially.py | 4 +++ litgpt/generate/tp.py | 4 +++ litgpt/model.py | 9 +++--- litgpt/utils.py | 48 ++++++++++++++++++++++++++++++ tests/test_model_deepseek_v3.py | 52 ++++----------------------------- 6 files changed, 71 insertions(+), 52 deletions(-) diff --git a/litgpt/api.py b/litgpt/api.py index 32cc196603..46ae8dbf28 100644 --- a/litgpt/api.py +++ b/litgpt/api.py @@ -22,6 +22,7 @@ from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style, save_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + _has_fp8_weights, auto_download_checkpoint, check_file_size_on_cpu_and_warn, check_nvlink_connectivity, @@ -30,6 +31,7 @@ extend_checkpoint_dir, get_default_supported_precision, load_checkpoint, + patch_linear_for_fp8, save_config, ) @@ -396,6 +398,8 @@ def distribute( state_dict = torch.load( str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False ) + if _has_fp8_weights(state_dict): + patch_linear_for_fp8(model) model.load_state_dict(state_dict, assign=True) model = fabric.setup_module(model, move_to_device=False) @@ -420,6 +424,8 @@ def distribute( map_location="cpu", weights_only=False, ) + if _has_fp8_weights(state_dict): + patch_linear_for_fp8(model) model.load_state_dict(state_dict, assign=True) # cannot use `.setup_module` because it will wrap with DDP diff --git a/litgpt/generate/sequentially.py b/litgpt/generate/sequentially.py index 04a60bacae..e83ef39783 100644 --- a/litgpt/generate/sequentially.py +++ b/litgpt/generate/sequentially.py @@ -26,9 +26,11 @@ from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + _has_fp8_weights, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, + patch_linear_for_fp8, ) @@ -240,6 +242,8 @@ def main( t0 = time.perf_counter() state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu") + if _has_fp8_weights(state_dict): + patch_linear_for_fp8(model) # TODO: this assumes that the model fits on CPU. Use lazy_load and make the materialization checkpoint aware model.load_state_dict(state_dict, assign=True) print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index a4030eec1f..d3329f77c2 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -23,10 +23,12 @@ from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style from litgpt.tokenizer import Tokenizer from litgpt.utils import ( + _has_fp8_weights, check_nvlink_connectivity, check_valid_checkpoint_dir, extend_checkpoint_dir, get_default_supported_precision, + patch_linear_for_fp8, ) @@ -205,6 +207,8 @@ def main( if fabric.global_rank == rank: t0 = time.perf_counter() state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu") + if _has_fp8_weights(state_dict): + patch_linear_for_fp8(model) model.load_state_dict(state_dict, assign=True) print(f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr) diff --git a/litgpt/model.py b/litgpt/model.py index f9d9857511..86693b82d2 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -1006,7 +1006,7 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos): high_dim = math.ceil(high_dim) low_dim = max(low_dim, 0) - high_dim = min(high_dim, n_elem // 2 - 1) + high_dim = min(high_dim, n_elem - 1) # Create linear ramp factor for blending dim_range = torch.arange(n_elem // 2, device=device, dtype=torch.float32) @@ -1017,11 +1017,10 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos): ramp_func = torch.clamp(linear_func, 0.0, 1.0) # Blend extrapolation and interpolation frequencies - # ramp_func = 0 -> use interpolation (scaled), ramp_func = 1 -> use extrapolation (unscaled) - theta_extrapolation_factor = ramp_func + # ramp_func = 0 -> use extrapolation (unscaled), ramp_func = 1 -> use interpolation (scaled) theta = ( - theta_interpolation * (1 - theta_extrapolation_factor) - + theta_extrapolation * theta_extrapolation_factor + theta_interpolation * ramp_func + + theta_extrapolation * (1 - ramp_func) ) elif "original_max_seq_len" in extra_config: # Llama3-style RoPE scaling diff --git a/litgpt/utils.py b/litgpt/utils.py index f5702043de..87fcf0138b 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -379,11 +379,57 @@ def get_default_supported_precision(training: bool) -> str: return "bf16-mixed" if training else "bf16-true" +def _has_fp8_weights(state_dict: Dict[str, Any]) -> bool: + """Check if a state dict contains FP8 weight_scale_inv tensors.""" + return any(k.endswith(".weight_scale_inv") for k in state_dict) + + +def patch_linear_for_fp8(model: nn.Module) -> nn.Module: + """Replace nn.Linear modules with FP8Linear for loading FP8 weights (e.g. DeepSeek-V3). + + Must be called after GPT(config) but before load_state_dict(). + Only replaces modules whose names match known FP8-quantized layers. + """ + from transformers.integrations.finegrained_fp8 import FP8Linear + + fp8_targets = ( + "q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj", "proj", + "fc_1", "fc_2", + ) + + modules_to_replace = [] + for name, module in model.named_modules(): + if isinstance(module, nn.Linear) and name.split(".")[-1] in fp8_targets: + modules_to_replace.append((name, module)) + + for name, module in modules_to_replace: + with torch.device("meta"): + new_module = FP8Linear( + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + activation_scheme="dynamic", + block_size=(128, 128), + ) + new_module = new_module.to_empty(device=module.weight.device) + + # Navigate to parent and replace the child module + parts = name.split(".") + parent = model + for part in parts[:-1]: + parent = getattr(parent, part) + setattr(parent, parts[-1], new_module) + + return model + + def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None: if isinstance(fabric.strategy, FSDPStrategy): fabric.load_raw(checkpoint_path, model, strict=strict) elif isinstance(fabric.strategy, ModelParallelStrategy): state_dict = torch.load(checkpoint_path, mmap=True) + if _has_fp8_weights(state_dict): + patch_linear_for_fp8(model) load_from_full_model_state_dict( model=model, full_sd=state_dict, @@ -394,6 +440,8 @@ def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, s else: state_dict = lazy_load(checkpoint_path) state_dict = state_dict.get("model", state_dict) + if _has_fp8_weights(state_dict): + patch_linear_for_fp8(model) model.load_state_dict(state_dict, strict=strict) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 7cd2f672ec..6089bbaa86 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -3,8 +3,6 @@ import pytest import torch -import torch.nn as nn -from transformers.integrations.finegrained_fp8 import FP8Linear from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM from litgpt import GPT, Config @@ -60,7 +58,7 @@ def test_against_original_deepseek_v3(model_name, device, dtype): v_head_dim=8, # Same as qk_nope_head_dim ), rope_adjustments=dict( - factor=40.0, beta_slow=1.0, beta_fast=32.0, original_max_seq_len=4096, mscale_all_dim=1.0 + factor=40.0, beta_slow=1.0, beta_fast=32.0, original_max_seq_len=4096, mscale=1.0, mscale_all_dim=1.0 ), ) theirs_config = DeepseekV3Config( @@ -89,13 +87,14 @@ def test_against_original_deepseek_v3(model_name, device, dtype): v_head_dim=ours_config.latent_attention["v_head_dim"], q_lora_rank=ours_config.latent_attention["q_lora_rank"], kv_lora_rank=ours_config.latent_attention["kv_lora_rank"], - rope_parameters=( + rope_scaling=( { "type": "yarn", "factor": ours_config.rope_adjustments["factor"], - "beta_slow": ours_config.rope_adjustments["low_freq_factor"], - "beta_fast": ours_config.rope_adjustments["high_freq_factor"], + "beta_slow": ours_config.rope_adjustments["beta_slow"], + "beta_fast": ours_config.rope_adjustments["beta_fast"], "original_max_position_embeddings": ours_config.rope_adjustments["original_max_seq_len"], + "mscale": 1.0, "mscale_all_dim": 1.0, } if ours_config.rope_adjustments @@ -108,7 +107,6 @@ def test_against_original_deepseek_v3(model_name, device, dtype): state_dict = {} copy_weights_deepseek_v3(ours_config, {}, state_dict, theirs_state_dict) ours_model = GPT(ours_config) - # ours_model = patch_deepseek_v3(ours_model) ours_model.to(device) ours_model.load_state_dict(state_dict) @@ -118,43 +116,3 @@ def test_against_original_deepseek_v3(model_name, device, dtype): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) - - -def patch_deepseek_v3(model: GPT): - to_replace = [ - "attn.q_a_proj", - "attn.q_b_proj", - "attn.kv_a_proj_with_mqa", - "attn.kv_b_proj", - "attn.proj", - "mlp.fc_1", - "mlp.fc_2", - "mlp.proj", - "mlp.experts", - "mlp.shared_experts", - ] - modules_to_replace = [] - for name, module in model.named_modules(): - if isinstance(module, nn.Linear) and any(target in name for target in to_replace): - modules_to_replace.append((name, module)) - - for name, module in modules_to_replace: - with torch.device("meta"): - new_module = FP8Linear( - in_features=module.in_features, - out_features=module.out_features, - bias=module.bias is not None, - activation_scheme="dynamic", - block_size=(128, 128), - ) - - # Use to_empty() to move from meta device - new_module = new_module.to_empty(device=module.weight.device) - - # Copy weights and bias - new_module.weight.data = module.weight.data.clone() - if module.bias is not None: - new_module.bias.data = module.bias.data.clone() - - model.set_submodule(name, new_module) - return model From 647c2c21c946a6c055034bcfe64fe5657a71affc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 19:13:34 +0000 Subject: [PATCH 38/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/model.py | 5 +---- litgpt/utils.py | 9 +++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/litgpt/model.py b/litgpt/model.py index 86693b82d2..509b6ef66a 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -1018,10 +1018,7 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos): # Blend extrapolation and interpolation frequencies # ramp_func = 0 -> use extrapolation (unscaled), ramp_func = 1 -> use interpolation (scaled) - theta = ( - theta_interpolation * ramp_func - + theta_extrapolation * (1 - ramp_func) - ) + theta = theta_interpolation * ramp_func + theta_extrapolation * (1 - ramp_func) elif "original_max_seq_len" in extra_config: # Llama3-style RoPE scaling orig_context_len = extra_config["original_max_seq_len"] diff --git a/litgpt/utils.py b/litgpt/utils.py index 87fcf0138b..6398578c06 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -393,8 +393,13 @@ def patch_linear_for_fp8(model: nn.Module) -> nn.Module: from transformers.integrations.finegrained_fp8 import FP8Linear fp8_targets = ( - "q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj", "proj", - "fc_1", "fc_2", + "q_a_proj", + "q_b_proj", + "kv_a_proj_with_mqa", + "kv_b_proj", + "proj", + "fc_1", + "fc_2", ) modules_to_replace = [] From c246bfe91d901865a9aba5f03a5c11a98df6d7ab Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Fri, 20 Mar 2026 12:21:37 -0700 Subject: [PATCH 39/51] feat(marketing): add deepseek models to markdowns --- README.md | 2 ++ tutorials/download_model_weights.md | 2 ++ 2 files changed, 4 insertions(+) diff --git a/README.md b/README.md index 6723fb7146..473f190dae 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,7 @@ Every model is written from scratch to maximize performance and remove layers of | Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) | | Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) | | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) | +| DeepSeek-V3 | 671B | DeepSeek AI | [DeepSeek AI 2024](https://huggingface.co/deepseek-ai/DeepSeek-V3) | | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) | | ... | ... | ... | ... | @@ -129,6 +130,7 @@ Every model is written from scratch to maximize performance and remove layers of |----|----|----|----| | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | +| DeepSeek-V3 | 671B | DeepSeek AI | [DeepSeek AI 2024](https://huggingface.co/deepseek-ai/DeepSeek-V3) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | | Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | | FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) | diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md index fcd3111ea6..584ec7e524 100644 --- a/tutorials/download_model_weights.md +++ b/tutorials/download_model_weights.md @@ -10,6 +10,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights. | CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) | | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) | | Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) | +| DeepSeek-V3 | 671B | DeepSeek AI | [DeepSeek AI 2024](https://huggingface.co/deepseek-ai/DeepSeek-V3) | | Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) | | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) | | Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) | @@ -99,6 +100,7 @@ databricks/dolly-v2-3b databricks/dolly-v2-7b deepseek-ai/DeepSeek-R1-Distill-Llama-8B deepseek-ai/DeepSeek-R1-Distill-Llama-70B +deepseek-ai/DeepSeek-V3 EleutherAI/pythia-1.4b EleutherAI/pythia-1.4b-deduped EleutherAI/pythia-12b From 7e6829a1b66ff9b61042fe09c0655e90af2f8394 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Fri, 20 Mar 2026 12:25:31 -0700 Subject: [PATCH 40/51] clean up --- litgpt/configuration_deepseek_v3.py | 248 -------- litgpt/modeling_deepseek_v3.py | 776 ----------------------- litgpt/modeling_rope_util.py | 942 ---------------------------- 3 files changed, 1966 deletions(-) delete mode 100644 litgpt/configuration_deepseek_v3.py delete mode 100644 litgpt/modeling_deepseek_v3.py delete mode 100644 litgpt/modeling_rope_util.py diff --git a/litgpt/configuration_deepseek_v3.py b/litgpt/configuration_deepseek_v3.py deleted file mode 100644 index ea56e5a88d..0000000000 --- a/litgpt/configuration_deepseek_v3.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2025 bzantium and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on the DeepSeekV3 implementations from the DeepSeek AI team. (https://huggingface.co/deepseek-ai/DeepSeek-V3) - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""DeepSeekV3 model configuration""" - -from ...configuration_utils import PreTrainedConfig -from ...modeling_rope_utils import RopeParameters - -DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class DeepseekV3Config(PreTrainedConfig): - r""" - This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the DeepSeek-V3. - e.g. [bzantium/tiny-deepseek-v3](https://huggingface.co/bzantium/tiny-deepseek-v3) - Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PreTrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 129280): - Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`DeepseekV3Model`] - hidden_size (`int`, *optional*, defaults to 7168): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 18432): - Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 2048): - Dimension of the MoE representations. - num_hidden_layers (`int`, *optional*, defaults to 61): - Number of hidden layers in the Transformer decoder. - num_attention_heads (`int`, *optional*, defaults to 128): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 128): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to - `num_attention_heads`. - n_shared_experts (`int`, *optional*, defaults to 1): - Number of shared experts. - n_routed_experts (`int`, *optional*, defaults to 256): - Number of routed experts. - routed_scaling_factor (`float`, *optional*, defaults to 2.5): - Scaling factor or routed experts. - kv_lora_rank (`int`, *optional*, defaults to 512): - Rank of the LoRA matrices for key and value projections. - q_lora_rank (`int`, *optional*, defaults to 1536): - Rank of the LoRA matrices for query projections. - qk_rope_head_dim (`int`, *optional*, defaults to 64): - Dimension of the query/key heads that use rotary position embeddings. - v_head_dim (`int`, *optional*, defaults to 128): - Dimension of the value heads. - qk_nope_head_dim (`int`, *optional*, defaults to 128): - Dimension of the query/key heads that don't use rotary position embeddings. - n_group (`int`, *optional*, defaults to 8): - Number of groups for routed experts. - topk_group (`int`, *optional*, defaults to 4): - Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). - num_experts_per_tok (`int`, *optional*, defaults to 8): - Number of selected experts, None means dense model. - first_k_dense_replace (`int`, *optional*, defaults to 3): - Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). - \--k dense layers--/ - norm_topk_prob (`bool`, *optional*, defaults to `True`): - Whether to normalize the weights of the routed experts. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 4096): - The maximum sequence length that this model might ever be used with. - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - pad_token_id (`int`, *optional*): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 0): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 1): - End of stream token id. - pretraining_tp (`int`, *optional*, defaults to 1): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - tie_word_embeddings (`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_parameters (`RopeParameters`, *optional*): - Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain - a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE - with longer `max_position_embeddings`. - rope_interleave (`bool`, *optional*, defaults to `True`): - Whether to interleave the rotary position embeddings. - attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): - Whether to use a bias in the query, key, value and output projection layers during self-attention. - attention_dropout (`float`, *optional*, defaults to 0.0): - The dropout ratio for the attention probabilities. - - ```python - >>> from transformers import DeepseekV3Model, DeepseekV3Config - - >>> # Initializing a Deepseek-V3 style configuration - >>> configuration = DeepseekV3Config() - - >>> # Accessing the model configuration - >>> configuration = model.config - ```""" - - model_type = "deepseek_v3" - keys_to_ignore_at_inference = ["past_key_values"] - base_model_tp_plan = { - "layers.*.mlp.experts.gate_up_proj": "rowwise", - "layers.*.mlp.experts.down_proj": "rowwise", - "layers.*.mlp.shared_experts.gate_proj": "colwise", - "layers.*.mlp.shared_experts.up_proj": "colwise", - "layers.*.mlp.shared_experts.down_proj": "rowwise", - "layers.*.mlp.gate_proj": "colwise", - "layers.*.mlp.up_proj": "colwise", - "layers.*.mlp.down_proj": "rowwise", - } - base_model_pp_plan = { - "embed_tokens": (["input_ids"], ["inputs_embeds"]), - "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), - "norm": (["hidden_states"], ["hidden_states"]), - } - attribute_map = { - "num_local_experts": "n_routed_experts", - } - - def __init__( - self, - vocab_size: int | None = 129280, - hidden_size: int | None = 7168, - intermediate_size: int | None = 18432, - moe_intermediate_size: int | None = 2048, - num_hidden_layers: int | None = 61, - num_attention_heads: int | None = 128, - num_key_value_heads: int | None = 128, - n_shared_experts: int | None = 1, - n_routed_experts: int | None = 256, - routed_scaling_factor: float | None = 2.5, - kv_lora_rank: int | None = 512, - q_lora_rank: int | None = 1536, - qk_rope_head_dim: int | None = 64, - v_head_dim: int | None = 128, - qk_nope_head_dim: int | None = 128, - n_group: int | None = 8, - topk_group: int | None = 4, - num_experts_per_tok: int | None = 8, - first_k_dense_replace: int | None = 3, - norm_topk_prob: bool | None = True, - hidden_act: str | None = "silu", - max_position_embeddings: int | None = 4096, - initializer_range: float | None = 0.02, - rms_norm_eps: int | None = 1e-6, - use_cache: bool | None = True, - pad_token_id: int | None = None, - bos_token_id: int | None = 0, - eos_token_id: int | None = 1, - pretraining_tp: int | None = 1, - tie_word_embeddings: bool | None = False, - rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None, - rope_interleave: bool | None = True, - attention_bias: bool | None = False, - attention_dropout: float | None = 0.0, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.moe_intermediate_size = moe_intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.n_shared_experts = n_shared_experts - self.n_routed_experts = n_routed_experts - self.routed_scaling_factor = routed_scaling_factor - self.kv_lora_rank = kv_lora_rank - self.q_lora_rank = q_lora_rank - self.qk_rope_head_dim = qk_rope_head_dim - self.v_head_dim = v_head_dim - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.head_dim = qk_rope_head_dim - self.n_group = n_group - self.topk_group = topk_group - self.num_experts_per_tok = num_experts_per_tok - self.first_k_dense_replace = first_k_dense_replace - self.norm_topk_prob = norm_topk_prob - self.rope_interleave = rope_interleave - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.attention_bias = attention_bias - self.attention_dropout = attention_dropout - self.rope_parameters = rope_parameters - - self.tie_word_embeddings = tie_word_embeddings - self.pad_token_id = pad_token_id - self.bos_token_id = bos_token_id - self.eos_token_id = eos_token_id - super().__init__(**kwargs) - - def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): - rope_scaling = kwargs.pop("rope_scaling", None) - self.rope_parameters = rope_scaling or self.rope_parameters - self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} - - # Standardize and validate the correctness of rotary position embeddings parameters - self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) - self.standardize_rope_params() - self.validate_rope(ignore_keys=ignore_keys_at_rope_validation) - - # Convert to float because RoPE fn expect a float. Models on the hub were saved as int - for key in ["beta_fast", "beta_slow", "factor"]: - if key in self.rope_parameters: - self.rope_parameters[key] = float(self.rope_parameters[key]) - return kwargs - - -__all__ = ["DeepseekV3Config"] diff --git a/litgpt/modeling_deepseek_v3.py b/litgpt/modeling_deepseek_v3.py deleted file mode 100644 index d4151f39c4..0000000000 --- a/litgpt/modeling_deepseek_v3.py +++ /dev/null @@ -1,776 +0,0 @@ -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -# This file was automatically generated from src/transformers/models/deepseek_v3/modular_deepseek_v3.py. -# Do NOT edit this file manually as any edits will be overwritten by the generation of -# the file from the modular. If any change should be done, please apply the change to the -# modular_deepseek_v3.py file directly. One of our CI enforces this. -# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 -import math -from collections.abc import Callable -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from ... import initialization as init -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...generation import GenerationMixin -from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernel_func_from_hub -from ...masking_utils import create_causal_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs -from ...modeling_layers import ( - GenericForSequenceClassification, - GenericForTokenClassification, - GradientCheckpointingLayer, -) -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available -from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults -from ...utils.output_capturing import capture_outputs -from .configuration_deepseek_v3 import DeepseekV3Config - - -@use_kernel_forward_from_hub("RMSNorm") -class DeepseekV3RMSNorm(nn.Module): - def __init__(self, hidden_size, eps: float = 1e-6) -> None: - """ - DeepseekV3RMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - def extra_repr(self): - return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" - - -# ============================================================================ -# ROPE COMPONENT 1: Core RoPE Class - Rotary Position Embedding Implementation -# ============================================================================ -class DeepseekV3RotaryEmbedding(nn.Module): - inv_freq: torch.Tensor # fix linting for `register_buffer` - - def __init__(self, config: DeepseekV3Config, device=None): - super().__init__() - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - - self.rope_type = self.config.rope_parameters["rope_type"] - rope_init_fn: Callable = self.compute_default_rope_parameters - if self.rope_type != "default": - rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = rope_init_fn(self.config, device) - - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False) - - # ROPE: Compute inverse frequencies for rotary embeddings - @staticmethod - def compute_default_rope_parameters( - config: DeepseekV3Config | None = None, - device: Optional["torch.device"] = None, - seq_len: int | None = None, - ) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies according to the original RoPE implementation - Args: - config ([`~transformers.PreTrainedConfig`]): - The model configuration. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - base = config.rope_parameters["rope_theta"] - dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - - attention_factor = 1.0 # Unused in this type - - # Compute the inverse frequencies - inv_freq = 1.0 / ( - base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) - ) - return inv_freq, attention_factor - - # ROPE: Forward pass - generates cos/sin embeddings from position IDs - # build_rope_cache() - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with maybe_autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# ============================================================================ -# END ROPE COMPONENT 1 -# ============================================================================ - - -class DeepseekV3MLP(nn.Module): - def __init__(self, config, intermediate_size=None): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - return down_proj - - -class DeepseekV3TopkRouter(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.n_routed_experts = config.n_routed_experts - - self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) - self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) - - def forward(self, hidden_states): - hidden_states = hidden_states.view(-1, self.config.hidden_size) - router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) - return router_logits - - -@use_experts_implementation -class DeepseekV3NaiveMoe(nn.Module): - """Collection of expert weights stored as 3D tensors.""" - - def __init__(self, config): - super().__init__() - self.num_experts = config.num_local_experts - self.hidden_dim = config.hidden_size - self.intermediate_dim = config.moe_intermediate_size - self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) - self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) - self.act_fn = ACT2FN[config.hidden_act] - - def forward( - self, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, - ) -> torch.Tensor: - final_hidden_states = torch.zeros_like(hidden_states) - with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) - expert_mask = expert_mask.permute(2, 1, 0) - expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() - - for expert_idx in expert_hit: - expert_idx = expert_idx[0] - if expert_idx == self.num_experts: - continue - top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) - current_state = hidden_states[token_idx] - gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) - current_hidden_states = self.act_fn(gate) * up - current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] - final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) - - return final_hidden_states - - -class DeepseekV3MoE(nn.Module): - """ - A mixed expert module containing shared experts. - """ - - def __init__(self, config): - super().__init__() - self.config = config - self.experts = DeepseekV3NaiveMoe(config) - self.gate = DeepseekV3TopkRouter(config) - self.shared_experts = DeepseekV3MLP( - config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts - ) - self.n_routed_experts = config.n_routed_experts - self.n_group = config.n_group - self.topk_group = config.topk_group - self.norm_topk_prob = config.norm_topk_prob - self.routed_scaling_factor = config.routed_scaling_factor - self.top_k = config.num_experts_per_tok - - def route_tokens_to_experts(self, router_logits): - router_logits = router_logits.sigmoid() - router_logits_for_choice = router_logits + self.gate.e_score_correction_bias - group_scores = ( - router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) - .topk(2, dim=-1)[0] - .sum(dim=-1) - ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - score_mask = ( - group_mask.unsqueeze(-1) - .expand(-1, self.n_group, self.n_routed_experts // self.n_group) - .reshape(-1, self.n_routed_experts) - ) - scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] - topk_weights = router_logits.gather(1, topk_indices) - if self.norm_topk_prob: - denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 - topk_weights /= denominator - topk_weights = topk_weights * self.routed_scaling_factor - return topk_indices, topk_weights - - def forward(self, hidden_states): - residuals = hidden_states - orig_shape = hidden_states.shape - router_logits = self.gate(hidden_states) - topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) - hidden_states = hidden_states + self.shared_experts(residuals) - return hidden_states - - -# ============================================================================ -# ROPE COMPONENT 2: RoPE Helper Functions -# ============================================================================ -# ROPE: Rotation helper - splits tensor and rotates [-x2, x1] -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# ROPE: Main function to apply RoPE to query and key tensors -# apply_rope -@use_kernel_func_from_hub("rotary_pos_emb") -def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -def eager_attention_forward( - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: torch.Tensor | None, - scaling: float, - dropout: float = 0.0, - **kwargs: Unpack[TransformersKwargs], -): - key_states = repeat_kv(key, module.num_key_value_groups) - value_states = repeat_kv(value, module.num_key_value_groups) - - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - - return attn_output, attn_weights - - -# ROPE: Alternative interleaved RoPE application (with view/transpose for efficiency) -def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - r""" - TODO let's just use the original freqcis computation to not have the view - transpose + reshape! This is not optimized! - Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - - b, h, s, d = q.shape - q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# ROPE: YaRN (Yet another RoPE extensioN) scaling function for extended context -def yarn_get_mscale(scale=1, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - -# ============================================================================ -# END ROPE COMPONENT 2 -# ============================================================================ - - -class DeepseekV3Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: DeepseekV3Config, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.attention_dropout = config.attention_dropout - self.num_heads = config.num_attention_heads - - self.q_lora_rank = config.q_lora_rank - self.qk_rope_head_dim = config.qk_rope_head_dim # ROPE: dimension for rotary embeddings - self.kv_lora_rank = config.kv_lora_rank - self.v_head_dim = config.v_head_dim - self.qk_nope_head_dim = config.qk_nope_head_dim - self.qk_head_dim = config.qk_head_dim - - self.is_causal = True - if self.q_lora_rank is None: - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) - else: - self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) - self.q_a_layernorm = DeepseekV3RMSNorm(config.q_lora_rank) - self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) - - self.kv_a_proj_with_mqa = nn.Linear( - config.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=config.attention_bias, - ) - self.kv_a_layernorm = DeepseekV3RMSNorm(self.kv_lora_rank) - self.kv_b_proj = nn.Linear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - ) - - self.o_proj = nn.Linear( - self.num_heads * self.v_head_dim, - config.hidden_size, - bias=config.attention_bias, - ) - - # ROPE: Initialize attention scaling (potentially adjusted by YaRN) - self.scaling = self.qk_head_dim ** (-0.5) - if self.config.rope_parameters.get("rope_type", "default") != "default": - mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_parameters["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) # ROPE: Apply YaRN scaling - self.scaling = self.scaling * mscale * mscale - - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_values: Cache | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]: - batch_size, seq_length = hidden_states.shape[:-1] - query_shape = (batch_size, seq_length, -1, self.qk_head_dim) - key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) - - if self.q_lora_rank is None: - q_states = self.q_proj(hidden_states) - else: - q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q_states = q_states.view(query_shape).transpose(1, 2) - # ROPE: Split query into non-RoPE (q_pass) and RoPE (q_rot) parts - q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - # ROPE: Split key into non-RoPE (k_pass) and RoPE (k_rot) parts - k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - - k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) - k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) - - # ======================================================================== - # ROPE COMPONENT 3: Apply RoPE to Query and Key tensors - # ======================================================================== - cos, sin = position_embeddings # ROPE: Get cos/sin from rotary embeddings - if self.config.rope_interleave: # support using interleaved weights for efficiency - q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) # ROPE: Interleaved version - else: - q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) # ROPE: Standard version - # ======================================================================== - # END ROPE COMPONENT 3 - # ======================================================================== - k_rot = k_rot.expand(*k_pass.shape[:-1], -1) - - # ROPE: Concatenate non-RoPE and RoPE parts back together - query_states = torch.cat((q_pass, q_rot), dim=-1) - key_states = torch.cat((k_pass, k_rot), dim=-1) - - if past_key_values is not None: - # ROPE: sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) - - if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) - - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface( - self.config._attn_implementation, eager_attention_forward - ) - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) - - if is_flash_attention_requested(self.config) and self.qk_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] - - attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class DeepseekV3DecoderLayer(GradientCheckpointingLayer): - def __init__(self, config: DeepseekV3Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx) - - if layer_idx >= config.first_k_dense_replace: - self.mlp = DeepseekV3MoE(config) - else: - self.mlp = DeepseekV3MLP(config) - - self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - use_cache: bool | None = False, - cache_position: torch.LongTensor | None = None, - position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - # Self Attention - hidden_states, _ = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -@auto_docstring -class DeepseekV3PreTrainedModel(PreTrainedModel): - config: DeepseekV3Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["DeepseekV3DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn = True - _supports_sdpa = True - _supports_flex_attn = True - _can_compile_fullgraph = ( - is_grouped_mm_available() - ) # https://huggingface.co/docs/transformers/experts_interface#torchcompile - _supports_attention_backend = True - _can_record_outputs = { - "hidden_states": DeepseekV3DecoderLayer, - "attentions": DeepseekV3Attention, - } - _keep_in_fp32_modules_strict = ["e_score_correction_bias"] - _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"] - - @torch.no_grad() - def _init_weights(self, module): - super()._init_weights(module) - if isinstance(module, DeepseekV3TopkRouter): - init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) - init.zeros_(module.e_score_correction_bias) - elif isinstance(module, DeepseekV3NaiveMoe): - init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) - init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) - - -@auto_docstring -class DeepseekV3Model(DeepseekV3PreTrainedModel): - def __init__(self, config: DeepseekV3Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [DeepseekV3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # ======================================================================== - # ROPE COMPONENT 4: RoPE Instantiation in Model - # ======================================================================== - self.rotary_emb = DeepseekV3RotaryEmbedding(config=config) # ROPE: Create rotary embedding instance - # ======================================================================== - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - @merge_with_config_defaults - @capture_outputs - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - cache_position: torch.LongTensor | None = None, - use_cache: bool | None = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithPast: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache(config=self.config) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position: torch.Tensor = ( - torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - hidden_states = inputs_embeds - # ROPE: Generate position embeddings (cos/sin) for all positions - position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_embeddings=position_embeddings, # ROPE: Pass to each layer - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - ) - - -@auto_docstring -class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = DeepseekV3Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - cache_position: torch.LongTensor | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, DeepseekV3ForCausalLM - - >>> model = DeepseekV3ForCausalLM.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-deepseek_v3/DeepseekV3-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class DeepseekV3ForSequenceClassification(GenericForSequenceClassification, DeepseekV3PreTrainedModel): - pass - - -class DeepseekV3ForTokenClassification(GenericForTokenClassification, DeepseekV3PreTrainedModel): - pass - - -__all__ = [ - "DeepseekV3PreTrainedModel", - "DeepseekV3Model", - "DeepseekV3ForCausalLM", - "DeepseekV3ForSequenceClassification", - "DeepseekV3ForTokenClassification", -] diff --git a/litgpt/modeling_rope_util.py b/litgpt/modeling_rope_util.py deleted file mode 100644 index a0e2a8559e..0000000000 --- a/litgpt/modeling_rope_util.py +++ /dev/null @@ -1,942 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings -from functools import wraps -from typing import TYPE_CHECKING, Optional, TypedDict - -from .utils import is_torch_available, logging - -logger = logging.get_logger(__name__) - - -if is_torch_available(): - import torch - -if TYPE_CHECKING: - from .configuration_utils import PreTrainedConfig - - -def dynamic_rope_update(rope_forward): - """ - Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE - (i.e. a RoPE implementation that may recompute its frequencies in the forward pass). - - Args: - rope_forward (Callable): - The forward pass of the RoPE implementation. - - Returns: - The decorated forward pass. - """ - - def longrope_frequency_update(self, position_ids, device, layer_type=None): - """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise.""" - seq_len = torch.max(position_ids) + 1 - - if layer_type is None: - rope_type = self.rope_type - original_inv_freq = self.original_inv_freq - prefix = "" - original_max_position_embeddings = self.config.rope_parameters["original_max_position_embeddings"] - else: - rope_type = self.rope_type[layer_type] - original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") - prefix = f"{layer_type}_" - original_max_position_embeddings = self.config.rope_parameters[layer_type][ - "original_max_position_embeddings" - ] - - if seq_len > original_max_position_embeddings: - if not hasattr(self, f"{layer_type}_long_inv_freq"): - rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] - long_inv_freq, _ = rope_init_fn( - self.config, - device, - seq_len=original_max_position_embeddings + 1, - layer_type=layer_type, - ) - self.register_buffer(f"{prefix}inv_freq", long_inv_freq, persistent=False) - setattr(self, f"{prefix}long_inv_freq", long_inv_freq) - else: - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - original_inv_freq = original_inv_freq.to(device) - self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) - setattr(self, f"{prefix}original_inv_freq", original_inv_freq) - - def dynamic_frequency_update(self, position_ids, device, layer_type=None): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if layer_type is None: - rope_type = self.rope_type - max_seq_len_cached = self.max_seq_len_cached - original_inv_freq = self.original_inv_freq - prefix = "" - else: - rope_type = self.rope_type[layer_type] - max_seq_len_cached = getattr(self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached) - original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq") - prefix = f"{layer_type}_" - - if seq_len > max_seq_len_cached: # growth - rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] - inv_freq, self.attention_scaling = rope_init_fn( - self.config, - device, - seq_len=seq_len, - layer_type=layer_type, - ) - # TODO joao: may break with compilation - self.register_buffer(f"{prefix}inv_freq", inv_freq, persistent=False) - setattr(self, f"{layer_type}_max_seq_len_cached", seq_len) - - if seq_len < self.original_max_seq_len and max_seq_len_cached > self.original_max_seq_len: # reset - # This .to() is needed if the model has been moved to a device after being initialized (because - # the buffer is automatically moved, but not the original copy) - original_inv_freq = original_inv_freq.to(device) - self.register_buffer(f"{prefix}inv_freq", original_inv_freq, persistent=False) - setattr(self, f"{prefix}original_inv_freq", original_inv_freq) - setattr(self, f"{layer_type}_max_seq_len_cached", self.original_max_seq_len) - - @wraps(rope_forward) - def wrapper(self, x, position_ids, layer_type=None): - rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type] - kwargs = {"layer_type": layer_type} if layer_type is not None else {} - if "dynamic" in rope_type: - dynamic_frequency_update(self, position_ids, device=x.device, **kwargs) - elif rope_type == "longrope": - longrope_frequency_update(self, position_ids, device=x.device, **kwargs) - return rope_forward(self, x, position_ids, **kwargs) - - return wrapper - - -def _compute_linear_scaling_rope_parameters( - config: Optional["PreTrainedConfig"] = None, - device: Optional["torch.device"] = None, - seq_len: int | None = None, - layer_type: str | None = None, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev - Args: - config ([`~transformers."PreTrainedConfig"`]): - The model configuration. This function assumes that the config will provide at least the following - properties: - - * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. - * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. - * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. - - Additionally, this function will make use of the following properties if they are found in the config: - - * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be - derived as hidden_size // num_attention_heads. - * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for - the first fraction of the head_dim. Defaults to 1.0. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - # For backward compatibility standardize the `rope_parameters_dict` if it uses old format - config.standardize_rope_params() - rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters - factor = rope_parameters_dict["factor"] - - # Gets the default RoPE parameters - base = rope_parameters_dict["rope_theta"] - partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - - # Then applies linear scaling to the frequencies. - # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so - # applying scaling to the inverse frequencies is equivalent. - inv_freq /= factor - return inv_freq, attention_factor - - -def _compute_dynamic_ntk_parameters( - config: Optional["PreTrainedConfig"] = None, - device: Optional["torch.device"] = None, - seq_len: int | None = None, - layer_type: str | None = None, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla - - Args: - config ([`~transformers."PreTrainedConfig"`]): - The model configuration. This function assumes that the config will provide at least the following - properties: - - * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. - * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. - * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. - * max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at - inference time - * rope_parameters (`dict[str, float]`): The standard RoPE scaling parameters, from which `factor` - will be accessed. The value of `factor` is used to determine the new base frequency, along with the - current sequence length (seq_len), the maximum positional embeddings (max_position_embeddings), and the - computed dimensionality (dim) of the rotary embeddings. If seq_len <= max_position_embeddings, this - factor has no effect. If seq_len <= max_position_embeddings, this factor effectively stretches the - context window using an exponent derived from `dim`. - - Additionally, this function will make use of the following properties if they are found in the config: - - * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be - derived as hidden_size // num_attention_heads. - * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for - the first fraction of the head_dim. Defaults to 1.0. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length, used to update the dynamic RoPE at inference time. If `None` or shorter than - max_position_embeddings, this value will be overridden by max_position_embeddings. - - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). - """ - # For backward compatibility standardize the `rope_parameters_dict` if it uses old format - config.standardize_rope_params() - rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters - - base = rope_parameters_dict["rope_theta"] - partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - factor = rope_parameters_dict["factor"] - attention_factor = 1.0 # Unused in this type of RoPE - - # seq_len: default to max_position_embeddings, e.g. at init time - if seq_len is None: - seq_len = config.max_position_embeddings - elif isinstance(seq_len, torch.Tensor): - seq_len = torch.maximum( - seq_len, - torch.tensor(config.max_position_embeddings, dtype=seq_len.dtype, device=seq_len.device), - ) - else: - seq_len = max(seq_len, config.max_position_embeddings) - - # Compute the inverse frequencies - base = base * ((factor * seq_len / config.max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - return inv_freq, attention_factor - - -def _compute_yarn_parameters( - config: "PreTrainedConfig", - device: Optional["torch.device"] = None, - seq_len: int | None = None, - layer_type: str | None = None, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://huggingface.co/papers/2309.00071) - - Args: - config ([`~transformers."PreTrainedConfig"`]): - The model configuration. This function assumes that the config will provide at least the following - properties: - - * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. - * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. - * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. - * max_position_embeddings (`int`): The maximum length of the positional embeddings. - * rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following - keys will be accessed: - * `attention_factor` (`float`, *optional*): The scaling factor to be applied to the computed cos/sin. - If None, the value is inferred from `factor`, `mscale`, and `mscale_all_dim` as available. - * `beta_fast` (`float`, *optional*, defaults to 32): Parameter to set the boundary for extrapolation - (only) in the linear ramp function. - * `beta_slow` (`float`, *optional*, defaults to 1): Parameter to set the boundary for interpolation - (only) in the linear ramp function. - * `factor` (`float`, *optional*): The scaling factor applied when interpolating the position IDs to - extend the possible context length. Additionally, if `attention_factor` is None, the log of this - value is used to compute a value for `attention_factor`, possibly in conjunciton with `mscale` and - `mscale_all_dim`, if provided. - * `mscale` (`float`, *optional*): If `attention_factor` is None and both `mscale` and - `mscale_all_dim` are provided, `mscale` acts scalar augmenting `log(factor)` when computing the - numerator for the inferred value of `attention_factor`. If not provided, `attention_factor` will be - calculated based on `factor` only. - * `mscale_all_dim` (`float`, *optional*): If `attention_factor` is None and both `mscale` and - `mscale_all_dim` are provided, `mscale_all_dim` acts scalar augmenting `log(factor)` when computing - the denominator for the inferred value of `attention_factor`. If not provided, `attention_factor` - will be calculated based on `factor` only. - * `original_max_position_embeddings` (`int`): The original max position embeddings used during pretraining. - * `truncate` (`bool`, *optional*): Whether to truncate the correction range. - - Additionally, this function will make use of the following properties if they are found in the config: - - * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be - derived as hidden_size // num_attention_heads. - * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies - will be returned for the first fraction of the head_dim. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # For backward compatibility standardize the `rope_parameters_dict` if it uses old format - config.standardize_rope_params() - rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters - - base = rope_parameters_dict["rope_theta"] - partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - factor = rope_parameters_dict["factor"] - attention_factor = rope_parameters_dict.get("attention_factor") - mscale = rope_parameters_dict.get("mscale") - mscale_all_dim = rope_parameters_dict.get("mscale_all_dim") - original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"] - - # NOTE: DeekSeek-V3 (and potentially other models) have `original_max_position_embeddings` field - # containing the pretrained value. They use the ratio between `max_position_embeddings` and this value - # to compute the default attention scaling factor, instead of using `factor`. - if factor is None: - factor = config.max_position_embeddings / original_max_position_embeddings - - def get_mscale(scale, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * math.log(scale) + 1.0 - - # Sets the attention factor as suggested in the paper - if attention_factor is None: - if mscale and mscale_all_dim: - attention_factor = float(get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim)) - else: - attention_factor = get_mscale(factor) - - # Optional config options - # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = rope_parameters_dict.get("beta_fast") or 32 - beta_slow = rope_parameters_dict.get("beta_slow") or 1 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) - - def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings, truncate): - """Find dimension range bounds based on rotations""" - low = find_correction_dim(low_rot, dim, base, max_position_embeddings) - high = find_correction_dim(high_rot, dim, base, max_position_embeddings) - if truncate: - low = math.floor(low) - high = math.ceil(high) - return max(low, 0), min(high, dim - 1) - - def linear_ramp_factor(min, max, dim): - if min == max: - max += 0.001 # Prevent singularity - - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # to expand the possible context length. In other words, interpolation = apply scaling factor. - pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) - inv_freq_extrapolation = 1.0 / pos_freqs - inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - truncate = config.rope_parameters.get("truncate", True) - low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate) - - # Get n-dimensional rotational scaling corrected for extrapolation - inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) - inv_freq = ( - inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - + inv_freq_extrapolation * inv_freq_extrapolation_factor - ) - return inv_freq, attention_factor - - -def _compute_longrope_parameters( - config: "PreTrainedConfig", - device: Optional["torch.device"] = None, - seq_len: int | None = None, - layer_type: str | None = None, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies with LongRoPE scaling. Please refer to the - [original implementation](https://github.com/microsoft/LongRoPE) - - Args: - config ([`~transformers."PreTrainedConfig"`]): - The model configuration. This function assumes that the config will provide at least the following - properties: - - * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. - * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. - * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. - * max_position_embeddings (`int`): The maximum length of the positional embeddings. - * original_max_position_embeddings (`int`, *optional*): The original max position embeddings used during - pretraining. If not provided, defaults to `max_position_embeddings`. - * rope_parameters (`dict[str, float]`): The standard RoPE scaling parameters, from which the following keys - will be accessed: - * `attention_factor` (`float`, *optional*): The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, inferred from - the value of `factor`. - * `factor` (`float`, *optional*): The scaling factor to apply to the RoPE embeddings. If both - `max_position_embeddings` and `original_max_position_embeddings` are provided, this value will be - overridden s the ratio between those values. - * `long_factor` (`float`, *optional*): The scale factor applied when computing the inverse - frequencies if `seq_len` is provided and greater than `original_max_position_embeddings`. - * `short_factor` (`float`, *optional*): The scale factor applied when computing the inverse - frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`. - - Additionally, this function will make use of the following properties if they are found in the config: - - * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be - derived as hidden_size // num_attention_heads. - * partial_rotary_factor (`float`, *optional*, defaults to 1.0): If less than 1.0, inverse frequencies - will be returned for the first fraction of the head_dim. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. - - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # For backward compatibility standardize the `rope_parameters_dict` if it uses old format - config.standardize_rope_params() - rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters - - base = rope_parameters_dict["rope_theta"] - partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - long_factor = rope_parameters_dict["long_factor"] - short_factor = rope_parameters_dict["short_factor"] - factor = rope_parameters_dict.get("factor") - attention_factor = rope_parameters_dict.get("attention_factor") - original_max_position_embeddings = rope_parameters_dict["original_max_position_embeddings"] - - # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a - # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two - # values to compute the default attention scaling factor, instead of using `factor`. - if factor is None: - factor = config.max_position_embeddings / original_max_position_embeddings - - # Sets the attention factor as suggested in the paper - if attention_factor is None: - if factor <= 1.0: - attention_factor = 1.0 - else: - attention_factor = math.sqrt(1 + math.log(factor) / math.log(original_max_position_embeddings)) - - # Compute the inverse frequencies -- scaled based on the target sequence length - if seq_len and seq_len > original_max_position_embeddings: - ext_factors = torch.tensor(long_factor, dtype=torch.float32, device=device) - else: - ext_factors = torch.tensor(short_factor, dtype=torch.float32, device=device) - inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim - inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) - - return inv_freq, attention_factor - - -def _compute_llama3_parameters( - config: "PreTrainedConfig", - device: Optional["torch.device"] = None, - seq_len: int | None = None, - layer_type: str | None = None, -) -> tuple["torch.Tensor", float]: - """ - Computes the inverse frequencies for llama 3.1. - - Args: - config ([`~transformers."PreTrainedConfig"`]): - The model configuration. This function assumes that the config will provide at least the following - properties: - - * rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived. - * hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly. - * num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly. - * rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following - keys will be accessed: - * `factor` (`float`, *optional*): The scaling factor applied to the inverse frequencies when 1) the - wavelength is greater than `low_freq_wavelen` prior to smoothing, and 2) to all inverse frequencies - during smoothing. - * `high_freq_factor` (`float`): The scale factor used to compute `high_freq_wavelen` and - the value for the denominator of the smoothing factor prior to the `low_freq_factor` shift. - * `low_freq_factor` (`float`): The scale factor used to compute `low_freq_wavelen` and - the shift applied to the numerator and denominator of the smoothing factor. - frequencies if `seq_len` is None or less-than-or-equal-to `original_max_position_embeddings`. - * `original_max_position_embeddings` (`int`): The original max position embeddings used - during pretraining. If not provided, the function falls back to `max_position_embeddings`. - - Additionally, this function will make use of the following properties if they are found in the config: - - * head_dim (`int`, *optional*): The size of the key-value heads in the model. If None, this value will be - derived as hidden_size // num_attention_heads. - * partial_rotary_factor (`float`, *optional*): If less than 1.0, inverse frequencies will be returned for - the first fraction of the head_dim. Defaults to 1.0. - device (`torch.device`): - The device to use for initialization of the inverse frequencies. - seq_len (`int`, *optional*): - The current sequence length. Unused for this type of RoPE. - Returns: - Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the - post-processing scaling factor applied to the computed cos/sin. - """ - # For backward compatibility standardize the `rope_parameters_dict` if it uses old format - config.standardize_rope_params() - rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters - - # Gets the default RoPE parameters - base = rope_parameters_dict["rope_theta"] - partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0) - head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads - dim = int(head_dim * partial_rotary_factor) - attention_factor = 1.0 # Unused in this type of RoPE - - # Compute the inverse frequencies - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)) - - factor = rope_parameters_dict["factor"] # `8` in the original implementation - low_freq_factor = rope_parameters_dict["low_freq_factor"] # `1` in the original implementation - high_freq_factor = rope_parameters_dict["high_freq_factor"] # `4` in the original implementation - old_context_len = rope_parameters_dict["original_max_position_embeddings"] # `8192` in the original implementation - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - - wavelen = 2 * math.pi / inv_freq - # wavelen < high_freq_wavelen: do nothing - # wavelen > low_freq_wavelen: divide by factor - inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) - # otherwise: interpolate between the two, using a smooth factor - smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) - smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama - is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) - inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) - - return inv_freq_llama, attention_factor - - -# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters -# from the model config. You can append new {'rope_type': callable} pairs to this rope_parameters to enable custom RoPE -# parameterizations, as long as the callable has the same signature. -ROPE_INIT_FUNCTIONS = { - "linear": _compute_linear_scaling_rope_parameters, - "dynamic": _compute_dynamic_ntk_parameters, - "yarn": _compute_yarn_parameters, - "longrope": _compute_longrope_parameters, - "llama3": _compute_llama3_parameters, -} - - -class RopeParameters(TypedDict, total=False): - """ - Args: - rope_theta (`float`): - The base period of the RoPE embeddings. - rope_type (`str`, *optional*, defaults to "default"): - The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', - 'llama3'], with 'default' being the original RoPE implementation. - partial_rotary_factor (`float`, *optional*): - The percentage of the query and key head embedding on which RoPE will be applied. - factor (`float`, *optional*): - Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In - most scaling types, a `factor` of x will enable the model to handle sequences of length x * - original maximum pre-trained length. - original_max_position_embeddings (`int`, *optional*): - Used with 'yarn', 'longrope' and 'llama3'. The original max position embeddings used during - pretraining. - attention_factor (`float`, *optional*): - Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention - computation. If unspecified, it defaults to value recommended by the implementation, using the - `factor` field to infer the suggested value. - beta_fast (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear - ramp function. If unspecified, it defaults to 32. - beta_slow (`float`, *optional*): - Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear - ramp function. If unspecified, it defaults to 1. - short_factor (`list[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to short contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - long_factor (`list[float]`, *optional*): - Only used with 'longrope'. The scaling factor to be applied to long contexts (< - `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden - size divided by the number of attention heads divided by 2 - low_freq_factor (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE - high_freq_factor (`float`, *optional*): - Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE - """ - - rope_theta: float - rope_type: str | None - partial_rotary_factor: float | None - factor: float | None - original_max_position_embeddings: int | None - attention_factor: float | None - beta_fast: float | None - beta_slow: float | None - short_factor: list[float] | None - long_factor: list[float] | None - low_freq_factor: float | None - high_freq_factor: float | None - - -class RotaryEmbeddingConfigMixin: - """ - A Mixin containing the functionality to standardize and validate RoPE parameters. - """ - - default_theta = 10_000.0 - - def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs): - rope_scaling = kwargs.pop("rope_scaling", None) - self.rope_parameters = rope_scaling or self.rope_parameters - self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} - - # Standardize and validate the correctness of rotary position embeddings parameters. Priority for these parameters is: - # 1. Values in `rope_parameters` dict (where they should be after standardization) - # 2. Values in `kwargs` (i.e. it's in config.json but not MyConfig.__init__'s args) - # 3. Values in the config's attributes (i.e. it's in MyConfig.__init__'s args) - # 4. Default values (i.e. not present at all but other RoPE parameters are present) - rope_theta = kwargs.pop("rope_theta", getattr(self, "rope_theta", self.default_theta)) - self.rope_parameters.setdefault("rope_theta", rope_theta) - - partial_rotary_factor = kwargs.get("partial_rotary_factor", getattr(self, "partial_rotary_factor", None)) - if partial_rotary_factor is not None: - self.rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) - ignore_keys_at_rope_validation = ( - set() if ignore_keys_at_rope_validation is None else ignore_keys_at_rope_validation - ) - ignore_keys_at_rope_validation = ignore_keys_at_rope_validation | {"partial_rotary_factor"} - - self.standardize_rope_params() - self.validate_rope(ignore_keys=ignore_keys_at_rope_validation) - return kwargs - - def standardize_rope_params(self): - """ - Helper to standardize the config's rope params field by ensuring the params are defined for each - later type. For old model the fn will duplicate a single rope param in each layer type (backward compatibility) - """ - # Move `rope_theta` and `partial_rotary_factor` to the `rope_parameters`, if not there yet - rope_theta = getattr(self, "rope_theta", None) - partial_rotary_factor = getattr(self, "partial_rotary_factor", None) - rope_parameters = getattr(self, "rope_parameters", None) or {} - layer_types = getattr(self, "layer_types", None) - - # Case 0: no RoPE params defined - if not (rope_parameters or rope_theta): - # partial_rotary_factor without rope_theta is invalid, so we don't check for it here - logger.warning("`standardize_rope_params` was called but no RoPE parameters were found.") - return - # Case 1: RoPE param keys do not intersect with possible `layer_types` -> one global dict - elif layer_types is None or rope_parameters == {} or not set(rope_parameters.keys()).issubset(layer_types): - rope_parameters.setdefault("rope_type", rope_parameters.get("type", "default")) - rope_parameters.setdefault("rope_theta", rope_theta) - if partial_rotary_factor is not None: - rope_parameters["partial_rotary_factor"] = partial_rotary_factor - - # Move pretraining-time maximum length to rope parameter dict for RoPE types with scaling - if rope_parameters["rope_type"] in ["llama3", "yarn", "longrope"]: - if hasattr(self, "original_max_position_embeddings"): - # NOTE: Phi3 (and potentially other models) save `original_max_position_embeddings` field - # containing the pretrained value outside rope parameters. This is an exception case where we - # give priority to `self.original_max_position_embeddings - self.rope_parameters["original_max_position_embeddings"] = self.original_max_position_embeddings - else: - self.rope_parameters.setdefault("original_max_position_embeddings", self.max_position_embeddings) - - # Case 2: different RoPE for each layer -> several params as nested dict - else: - for layer_type in set(layer_types): - rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default")) - rope_parameters[layer_type].setdefault("rope_theta", rope_theta) - if partial_rotary_factor is not None: - rope_parameters[layer_type]["partial_rotary_factor"] = partial_rotary_factor - - if rope_parameters[layer_type]["rope_type"] in ["llama3", "yarn", "longrope"]: - self.rope_parameters[layer_type].setdefault( - "original_max_position_embeddings", self.max_position_embeddings - ) - - self.rope_parameters = rope_parameters - - def validate_rope(self: "PreTrainedConfig", ignore_keys: set | None = None): - """ - Validate the RoPE config arguments, given a `"PreTrainedConfig"` object - """ - rope_parameters_dict = self.rope_parameters - if rope_parameters_dict is None: - return - - if getattr(self, "layer_types", None) is not None and set(rope_parameters_dict.keys()).issubset( - self.layer_types - ): - pass - else: - rope_parameters_dict = {"full_attention": rope_parameters_dict} - - for rope_parameters in rope_parameters_dict.values(): - rope_type = rope_parameters.get("rope_type", rope_parameters.get("type", "default")) - validation_fn = getattr(self, f"_validate_{rope_type}_rope_parameters", None) - rope_parameters["rope_type"] = rope_type - - if validation_fn is not None: - validation_fn(rope_parameters, ignore_keys=ignore_keys) - else: - logger.warning( - f"Missing validation function in 'RotaryEmbeddingConfigMixin' for 'rope_type'='{rope_type}'" - ) - - def _validate_default_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): - required_keys = {"rope_type", "rope_theta"} - received_keys = set(rope_parameters.keys()) - rope_type = rope_parameters["rope_type"] - self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): - required_keys = {"rope_type", "factor", "rope_theta"} - received_keys = set(rope_parameters.keys()) - rope_type = rope_parameters["rope_type"] - self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - factor = rope_parameters["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") - - def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): - required_keys = {"rope_type", "factor"} - received_keys = set(rope_parameters.keys()) - rope_type = rope_parameters["rope_type"] - self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - factor = rope_parameters["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") - - def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): - required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"} - optional_keys = { - "attention_factor", - "beta_fast", - "beta_slow", - "mscale", - "mscale_all_dim", - "truncate", - } - received_keys = set(rope_parameters.keys()) - rope_type = rope_parameters["rope_type"] - self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - factor = rope_parameters["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_parameters.get("attention_factor") - if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): - logger.warning( - f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - beta_fast = rope_parameters.get("beta_fast") - if beta_fast is not None and not isinstance(beta_fast, float): - logger.warning(f"`rope_parameters`'s beta_fast field must be a float, got {beta_fast}") - beta_slow = rope_parameters.get("beta_slow") - if beta_slow is not None and not isinstance(beta_slow, float): - logger.warning(f"`rope_parameters`'s beta_slow field must be a float, got {beta_slow}") - - if (beta_fast or 32) < (beta_slow or 1): - logger.warning( - f"`rope_parameters`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " - f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" - ) - - # Double-check: `factor` should be the ratio between the pre-yarn and post-yarn context lengths. - # NOTE: we might get `implicit_factor == 1` if config's `original_max_position_embeddings` was - # inferred from `max_position_embeddings` during standardization - original_max_position_embeddings = self.rope_parameters["original_max_position_embeddings"] - implicit_factor = self.max_position_embeddings / original_max_position_embeddings - if implicit_factor != factor and implicit_factor != 1: - logger.warning_once( - f"The explicitly set RoPE scaling factor (config.rope_parameters['factor'] = {factor}) does not match " - "the ratio implicitly set by other parameters (implicit factor = " - "post-yarn context length / pre-yarn context length = " - "config.max_position_embeddings / config.rope_parameters['original_max_position_embeddings'] = " - f"{implicit_factor}). Using the explicit factor ({factor}) in YaRN. This may cause unexpected " - "behaviour in model usage, please correct the 'original_max_position_embeddings' fields in the model config." - ) - - def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): - required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"} - optional_keys = {"attention_factor", "factor"} - received_keys = set(rope_parameters.keys()) - rope_type = rope_parameters["rope_type"] - self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0) - head_dim = getattr(self, "head_dim", self.hidden_size // self.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) - - short_factor = rope_parameters.get("short_factor") - if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): - logger.warning(f"`rope_parameters`'s short_factor field must be a list of numbers, got {short_factor}") - if len(short_factor) != dim // 2: - logger.warning( - f"`rope_parameters`'s short_factor field must have length {dim // 2}, got {len(short_factor)}" - ) - - long_factor = rope_parameters.get("long_factor") - if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): - logger.warning(f"`rope_parameters`'s long_factor field must be a list of numbers, got {long_factor}") - if len(long_factor) != dim // 2: - logger.warning(f"`rope_parameters`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") - - factor = rope_parameters.get("factor") - original_max_position_embeddings = rope_parameters["original_max_position_embeddings"] - - # Handle Phi3 divergence: we prefer the use of `attention_factor` and/or `factor` over - # `original_max_position_embeddings` to compute internal variables. The latter is undesirable - if factor is None and original_max_position_embeddings is not None: - logger.warning_once( - "This model config has set a `rope_parameters['original_max_position_embeddings']` field, to be used together with " - "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_parameters`" - "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " - "as it is compatible with most model architectures." - ) - elif factor is None and original_max_position_embeddings is None: - logger.warning("Missing required keys in `rope_parameters`: 'factor'") - elif not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_parameters.get("attention_factor") - if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0.0): - logger.warning( - f"`rope_parameters`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - - def _validate_llama3_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None): - required_keys = { - "rope_type", - "factor", - "original_max_position_embeddings", - "low_freq_factor", - "high_freq_factor", - "rope_theta", - } - rope_type = rope_parameters["rope_type"] - received_keys = set(rope_parameters.keys()) - self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - factor = rope_parameters["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}") - - low_freq_factor = rope_parameters["low_freq_factor"] - high_freq_factor = rope_parameters["high_freq_factor"] - if low_freq_factor is None or not isinstance(low_freq_factor, float): - logger.warning(f"`rope_parameters`'s low_freq_factor field must be a float, got {low_freq_factor}") - if high_freq_factor is None or not isinstance(high_freq_factor, float): - logger.warning(f"`rope_parameters`'s high_freq_factor field must be a float, got {high_freq_factor}") - if high_freq_factor <= low_freq_factor: - logger.warning( - "`rope_parameters`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" - f"{high_freq_factor} and low_freq_factor={low_freq_factor}" - ) - - original_max_position_embeddings = rope_parameters["original_max_position_embeddings"] - if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): - logger.warning( - "`rope_parameters`'s original_max_position_embeddings field must be an integer, got " - f"{original_max_position_embeddings}" - ) - if original_max_position_embeddings >= self.max_position_embeddings: - logger.warning( - "`rope_parameters`'s original_max_position_embeddings field must be less than max_position_embeddings, got " - f"{original_max_position_embeddings} and max_position_embeddings={self.max_position_embeddings}" - ) - - @staticmethod - def _check_received_keys( - rope_type: str, - received_keys: set, - required_keys: set, - optional_keys: set | None = None, - ignore_keys: set | None = None, - ): - """Compare the received keys in `config.rope_parameters` against the expected and optional keys""" - # BC: "rope_type" was originally "type" -- let's check for "rope_type" when "type" is present - if "type" in received_keys: - received_keys -= {"type"} - required_keys.add("rope_type") - - optional_keys = optional_keys or set() - if "partial_rotary_factor" not in optional_keys: - optional_keys.add("partial_rotary_factor") - - # Some models need to store model-specific keys, and we don't want to throw warning at them - if ignore_keys is not None: - received_keys -= ignore_keys - - missing_keys = required_keys - received_keys - if missing_keys: - raise KeyError(f"Missing required keys in `rope_parameters` for 'rope_type'='{rope_type}': {missing_keys}") - - unused_keys = received_keys - required_keys - optional_keys - if unused_keys: - logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}") - - -def rope_config_validation(config: RotaryEmbeddingConfigMixin, ignore_keys: set | None = None): - """ - This is a deprecated function. - It has been kept for backward compatibility with custom code models. - """ - warnings.warn( - "`rope_config_validation` is deprecated and has been removed. " - "Its functionality has been moved to RotaryEmbeddingConfigMixin.validate_rope method. " - "PreTrainedConfig inherits this class, so please call self.validate_rope() instead. " - "Also, make sure to use the new rope_parameters syntax. " - "You can call self.standardize_rope_params() in the meantime.", - FutureWarning, - ) - config.standardize_rope_params() - config.validate_rope(ignore_keys=ignore_keys) From 1a215683f344add41da32a492716589c8e9a14f3 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Fri, 20 Mar 2026 13:14:55 -0700 Subject: [PATCH 41/51] fixes to pass cicd --- litgpt/config.py | 2 +- tests/test_lora.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/litgpt/config.py b/litgpt/config.py index 0e7dbff0c8..e7c83d0507 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -3191,7 +3191,7 @@ def check_indicator_and_length( name="DeepSeek-V3", hf_config=dict(org="deepseek-ai", name="DeepSeek-V3"), block_size=163840, - vocab_size=129280, + vocab_size=128000, padded_vocab_size=129280, n_layer=61, n_head=128, diff --git a/tests/test_lora.py b/tests/test_lora.py index 46ba75b384..7001cc8c1b 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -529,6 +529,9 @@ def test_lora_gpt_init_weights(): @pytest.mark.parametrize("name", [c["name"] for c in config_module.configs]) def test_base_model_can_be_lora_loaded(name): kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32} + config = config_module.Config.from_name(name, **kwargs) + if config.latent_attention is not None: + pytest.skip("LoRA does not support latent attention") base_model = BaseGPT.from_name(name, **kwargs) base_model_state_dict = base_model.state_dict() lora_model = LoRAGPT.from_name( From c758f7294d1b361e3d95688b3758ff99dc6aa3f1 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 21 Mar 2026 11:35:18 -0700 Subject: [PATCH 42/51] test adapter fix --- tests/test_adapter_v2.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 3a7d17d5e5..f6131eba10 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -139,6 +139,9 @@ def test_adapter_v2_gpt_init_weights(): @pytest.mark.parametrize("name", [c["name"] for c in config_module.configs]) def test_base_model_can_be_adapter_v2_loaded(name): kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32} + config = config_module.Config.from_name(name, **kwargs) + if config.latent_attention is not None: + pytest.skip("Adapter V2 does not support latent attention") base_model = BaseGPT.from_name(name, **kwargs) base_model_state_dict = base_model.state_dict() lora_model = AdapterV2GPT.from_name(name, **kwargs, adapter_start_layer=0) From 99d9ef4d072d8296dec09ae971ffd07b2e11b6e6 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 21 Mar 2026 11:47:55 -0700 Subject: [PATCH 43/51] prompt matching deepseek v2 (With R1) --- litgpt/prompts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index b5926068ef..7d774d389e 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -1,4 +1,4 @@ -# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +x# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import importlib import re from abc import abstractmethod @@ -474,7 +474,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle: return Llama3() if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name): return Llama3() - if re.search("R1", model_name): + if re.search(r"(R1|DeepSeek-V3)", model_name): return R1Base() if re.search("FreeWilly2", model_name): return FreeWilly2() From 2533589910146b4ab9bc954606370de604c65915 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 18:48:16 +0000 Subject: [PATCH 44/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/prompts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 7d774d389e..b716063654 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -1,4 +1,4 @@ -x# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +x # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import importlib import re from abc import abstractmethod From d6ec6c7c90c114a373f397f3812c18ab39db913c Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 21 Mar 2026 17:32:02 -0700 Subject: [PATCH 45/51] fix --- litgpt/prompts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litgpt/prompts.py b/litgpt/prompts.py index 7d774d389e..ca5ddb0505 100644 --- a/litgpt/prompts.py +++ b/litgpt/prompts.py @@ -1,4 +1,4 @@ -x# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import importlib import re from abc import abstractmethod From 17b1ec00db3f1112e9edc48b8f11e125f12549e8 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 22 Mar 2026 04:31:44 -0700 Subject: [PATCH 46/51] fix: properly skip deepseekv3 in test_lora and test_adapter_v2 --- tests/test_adapter_v2.py | 4 ++-- tests/test_lora.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index f6131eba10..14fbf15eba 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -139,8 +139,8 @@ def test_adapter_v2_gpt_init_weights(): @pytest.mark.parametrize("name", [c["name"] for c in config_module.configs]) def test_base_model_can_be_adapter_v2_loaded(name): kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32} - config = config_module.Config.from_name(name, **kwargs) - if config.latent_attention is not None: + raw_config = next(c for c in config_module.configs if c["name"] == name) + if raw_config.get("latent_attention") is not None: pytest.skip("Adapter V2 does not support latent attention") base_model = BaseGPT.from_name(name, **kwargs) base_model_state_dict = base_model.state_dict() diff --git a/tests/test_lora.py b/tests/test_lora.py index 7001cc8c1b..662df06ae3 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -529,8 +529,8 @@ def test_lora_gpt_init_weights(): @pytest.mark.parametrize("name", [c["name"] for c in config_module.configs]) def test_base_model_can_be_lora_loaded(name): kwargs = {"n_layer": 2, "n_head": 8, "n_query_groups": 4, "n_embd": 16, "padded_vocab_size": 32} - config = config_module.Config.from_name(name, **kwargs) - if config.latent_attention is not None: + raw_config = next(c for c in config_module.configs if c["name"] == name) + if raw_config.get("latent_attention") is not None: pytest.skip("LoRA does not support latent attention") base_model = BaseGPT.from_name(name, **kwargs) base_model_state_dict = base_model.state_dict() From 7b12a99b629a1b2882627c20d8b15bb02d90232a Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sun, 22 Mar 2026 05:19:50 -0700 Subject: [PATCH 47/51] fix: output dim robustness across diff transformers versions --- tests/test_multihead_latent_attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_multihead_latent_attention.py b/tests/test_multihead_latent_attention.py index e7a299fb48..b0fbb2a2a5 100644 --- a/tests/test_multihead_latent_attention.py +++ b/tests/test_multihead_latent_attention.py @@ -226,6 +226,8 @@ def test_deepseek_v3_block(batch_size, seq_len, device): # Run forward passes output_litgpt = block_litgpt(hidden_states, cos, sin) output_hf = block_hf(hidden_states, position_embeddings=(cos, sin), attention_mask=attention_mask) + if isinstance(output_hf, tuple): + output_hf = output_hf[0] assert torch.allclose(output_litgpt, output_hf, atol=1e-5, rtol=1e-4), ( f"Max diff: {(output_litgpt - output_hf).abs().max()}" From 878d8f160e768b627c6e818f2ffd48447755c534 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Mon, 30 Mar 2026 12:46:06 -0400 Subject: [PATCH 48/51] update new typings --- extensions/thunder/strategies/thunder_ddp.py | 4 ++-- litgpt/scripts/convert_hf_checkpoint.py | 16 ++++++++-------- litgpt/scripts/convert_lit_checkpoint.py | 6 +++--- litgpt/utils.py | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py index 43cad50c4f..d1a9cc5278 100644 --- a/extensions/thunder/strategies/thunder_ddp.py +++ b/extensions/thunder/strategies/thunder_ddp.py @@ -2,7 +2,7 @@ from contextlib import AbstractContextManager, nullcontext from datetime import timedelta -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import torch import torch.distributed @@ -42,7 +42,7 @@ def __init__( checkpoint_io: CheckpointIO | None = None, precision: Precision | None = None, jit: bool = True, - executors: tuple[Union["Executor", str], ...] | None = None, + executors: tuple["Executor" | str, ...] | None = None, process_group_backend: str | None = None, timeout: timedelta | None = default_pg_timeout, **kwargs: Any, diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 19ecea9fb6..f78612fbe9 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -725,14 +725,14 @@ def copy_weights_qwen_3( def copy_weights_deepseek_v3( config: Config, - qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], - state_dict: Dict[str, torch.Tensor], - hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], - saver: Optional[incremental_save] = None, - dtype: Optional[torch.dtype] = None, - pbar: Optional[tqdm] = None, - progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False, + qkv_weights: dict[int, list[NotYetLoadedTensor | None]], + state_dict: dict[str, torch.Tensor], + hf_weights: dict[str, torch.Tensor | NotYetLoadedTensor], + saver: incremental_save | None = None, + dtype: torch.dtype | None = None, + pbar: tqdm | None = None, + progress_per_file: float | None = None, + debug_mode: bool | None = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index ef6d3e167f..21222b7b54 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -518,10 +518,10 @@ def copy_weights_qwen_3( def copy_weights_deepseek_v3( config: Config, - state_dict: Dict[str, torch.Tensor], - lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + state_dict: dict[str, torch.Tensor], + lit_weights: dict[str, torch.Tensor | NotYetLoadedTensor], untie_weights: bool = False, - saver: Optional[incremental_save] = None, + saver: incremental_save | None = None, ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", diff --git a/litgpt/utils.py b/litgpt/utils.py index e38eee816e..897a29900c 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -380,7 +380,7 @@ def get_default_supported_precision(training: bool) -> str: return "bf16-mixed" if training else "bf16-true" -def _has_fp8_weights(state_dict: Dict[str, Any]) -> bool: +def _has_fp8_weights(state_dict: dict[str, Any]) -> bool: """Check if a state dict contains FP8 weight_scale_inv tensors.""" return any(k.endswith(".weight_scale_inv") for k in state_dict) From d76bdfe30c8a37dd198df28a019994828cba3275 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Sat, 4 Apr 2026 12:19:56 -0400 Subject: [PATCH 49/51] cicd fix --- extensions/thunder/strategies/thunder_ddp.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py index d1a9cc5278..9a0712088b 100644 --- a/extensions/thunder/strategies/thunder_ddp.py +++ b/extensions/thunder/strategies/thunder_ddp.py @@ -1,5 +1,7 @@ """Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually.""" +from __future__ import annotations + from contextlib import AbstractContextManager, nullcontext from datetime import timedelta from typing import TYPE_CHECKING, Any From 22e059c5bd50c1558941f23765340c0ee7ec1ecc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Apr 2026 16:20:16 +0000 Subject: [PATCH 50/51] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- extensions/thunder/strategies/thunder_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/extensions/thunder/strategies/thunder_ddp.py b/extensions/thunder/strategies/thunder_ddp.py index 9a0712088b..8276818e08 100644 --- a/extensions/thunder/strategies/thunder_ddp.py +++ b/extensions/thunder/strategies/thunder_ddp.py @@ -44,7 +44,7 @@ def __init__( checkpoint_io: CheckpointIO | None = None, precision: Precision | None = None, jit: bool = True, - executors: tuple["Executor" | str, ...] | None = None, + executors: tuple[Executor | str, ...] | None = None, process_group_backend: str | None = None, timeout: timedelta | None = default_pg_timeout, **kwargs: Any, From 6e04ed07499f5138b917c041e4f64595f34d7874 Mon Sep 17 00:00:00 2001 From: Yu Shi Jie Date: Thu, 4 Jun 2026 10:02:27 -0400 Subject: [PATCH 51/51] skip deepseekv3 if transformers<4.56.0: YaRN RoPE factor bug fixed in transformers 4.56.0 (prior versions override explicit factor with max_pos/original_max_pos) --- tests/test_model_deepseek_v3.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_model_deepseek_v3.py b/tests/test_model_deepseek_v3.py index 6089bbaa86..c898c3226b 100644 --- a/tests/test_model_deepseek_v3.py +++ b/tests/test_model_deepseek_v3.py @@ -3,6 +3,7 @@ import pytest import torch +import transformers from transformers.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM from litgpt import GPT, Config @@ -13,6 +14,10 @@ @torch.inference_mode() +@pytest.mark.skipif( + transformers.__version__ < "4.56.0", + reason="YaRN RoPE factor bug fixed in transformers 4.56.0 (prior versions override explicit factor with max_pos/original_max_pos)", +) @pytest.mark.parametrize("model_name", ["DeepSeek-V3"]) @pytest.mark.parametrize( ("device", "dtype"),