Open
66 commits
2d245cc
conversion for deepseekv3 models between hf and lit; support for load…
ysjprojects Dec 29, 2025
18c851f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 29, 2025
44b6292
Merge branch 'main' into deepseek-models
ysjprojects Feb 8, 2026
edd45a9
fix n_query_group == n_head for deepseekv3
ysjprojects Feb 8, 2026
d17ff83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2026
063d71b
minor fixes
ysjprojects Feb 8, 2026
e3f25fe
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Feb 8, 2026
c7a852a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2026
ea60758
fix import in test_model_deepseek_v3
ysjprojects Feb 8, 2026
a25d0bb
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Feb 8, 2026
1b12177
fix(deepseek-v3): correct weight mappings and FP8Linear patching
ysjprojects Feb 8, 2026
a8b47ec
fix: replace_module -> set_submodule
ysjprojects Feb 8, 2026
8a392fb
fix
ysjprojects Feb 8, 2026
7a47981
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2026
7d61199
.
ysjprojects Feb 8, 2026
e22a879
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2026
cc4a3f3
apply_rope_interleave in multihead latent attn
ysjprojects Feb 15, 2026
078f939
test: for deepseek v3 block
ysjprojects Feb 15, 2026
8d4f5f7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
abbfa2b
milestone: deepseek v3 assertClose now at 96%
ysjprojects Feb 15, 2026
243dd0d
deepseek v3 passes completely without rope scaling
ysjprojects Feb 15, 2026
34368bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
84a72ff
mscale
ysjprojects Feb 15, 2026
d921c6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
d92c421
fix
ysjprojects Feb 15, 2026
06b651d
fix conflict
ysjprojects Feb 15, 2026
4b2db64
yarn rope
ysjprojects Feb 15, 2026
a4d515d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
d514423
test: test_yarn (with deepseekv3 block)
ysjprojects Feb 15, 2026
c7edaad
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Feb 15, 2026
0a96159
test
ysjprojects Feb 15, 2026
2d91d5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
05902fe
fix
ysjprojects Feb 15, 2026
17f01be
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Feb 15, 2026
73492a5
debug
ysjprojects Feb 15, 2026
7500b94
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
523d815
debug
ysjprojects Feb 15, 2026
8f3c689
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Feb 15, 2026
968eea7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
90b7dd4
fix test_yarn
ysjprojects Feb 15, 2026
8d44e86
fix
ysjprojects Feb 15, 2026
7c29e71
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2026
c96b79c
rm debug
ysjprojects Feb 15, 2026
75e5d53
Merge branch 'main' into deepseek-models
ysjprojects Mar 20, 2026
943a55e
fix deepseekv3 test rope params
ysjprojects Mar 20, 2026
d45e69a
deepseek v3 support
ysjprojects Mar 20, 2026
647c2c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2026
c246bfe
feat(marketing): add deepseek models to markdowns
ysjprojects Mar 20, 2026
70cfc93
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Mar 20, 2026
7e6829a
clean up
ysjprojects Mar 20, 2026
1a21568
fixes to pass cicd
ysjprojects Mar 20, 2026
6a83e73
Merge branch 'main' into deepseek-models
ysjprojects Mar 21, 2026
c758f72
test adapter fix
ysjprojects Mar 21, 2026
adfacc1
Merge branch 'deepseek-models' of github.com:Lightning-AI/litgpt into…
ysjprojects Mar 21, 2026
99d9ef4
prompt matching deepseek v2 (With R1)
ysjprojects Mar 21, 2026
2533589
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2026
d6ec6c7
fix
ysjprojects Mar 22, 2026
8511c7e
fix
ysjprojects Mar 22, 2026
17b1ec0
fix: properly skip deepseekv3 in test_lora and test_adapter_v2
ysjprojects Mar 22, 2026
7b12a99
fix: output dim robustness across diff transformers versions
ysjprojects Mar 22, 2026
cd46a75
Merge branch 'main' into deepseek-models
ysjprojects Mar 30, 2026
878d8f1
update new typings
ysjprojects Mar 30, 2026
d76bdfe
cicd fix
ysjprojects Apr 4, 2026
22e059c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2026
454c973
Merge branch 'main' into deepseek-models
ysjprojects Jun 4, 2026
6e04ed0
skip deepseekv3 if transformers<4.56.0: YaRN RoPE factor bug fixed in…
ysjprojects Jun 4, 2026
2 changes: 2 additions & 0 deletions README.md
@@ -115,6 +115,7 @@ Every model is written from scratch to maximize performance and remove layers of
| Phi 4 | 14B | Microsoft Research | [Abdin et al. 2024](https://arxiv.org/abs/2412.08905) |
| Qwen2.5 | 0.5B, 1.5B, 3B, 7B, 14B, 32B, 72B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwen2.5/) |
| Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
| DeepSeek-V3 | 671B | DeepSeek AI | [DeepSeek AI 2024](https://huggingface.co/deepseek-ai/DeepSeek-V3) |
| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
| ... | ... | ... | ... |

@@ -129,6 +130,7 @@ Every model is written from scratch to maximize performance and remove layers of
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| DeepSeek-V3 | 671B | DeepSeek AI | [DeepSeek AI 2024](https://huggingface.co/deepseek-ai/DeepSeek-V3) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
6 changes: 4 additions & 2 deletions extensions/thunder/strategies/thunder_ddp.py
@@ -1,8 +1,10 @@
"""Fabric Strategy to support Thunder DDP: To be upstreamed into Fabric eventually."""

from __future__ import annotations

from contextlib import AbstractContextManager, nullcontext
from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Union
+from typing import TYPE_CHECKING, Any

import torch
import torch.distributed
@@ -42,7 +44,7 @@ def __init__(
checkpoint_io: CheckpointIO | None = None,
precision: Precision | None = None,
jit: bool = True,
-executors: tuple[Union["Executor", str], ...] | None = None,
+executors: tuple[Executor | str, ...] | None = None,
process_group_backend: str | None = None,
timeout: timedelta | None = default_pg_timeout,
**kwargs: Any,
6 changes: 6 additions & 0 deletions litgpt/api.py
@@ -23,6 +23,7 @@
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style, save_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
_has_fp8_weights,
auto_download_checkpoint,
check_file_size_on_cpu_and_warn,
check_nvlink_connectivity,
@@ -31,6 +32,7 @@
extend_checkpoint_dir,
get_default_supported_precision,
load_checkpoint,
patch_linear_for_fp8,
save_config,
)

@@ -397,6 +399,8 @@ def distribute(
state_dict = torch.load(
str(self.checkpoint_dir / "lit_model.pth"), mmap=True, map_location="cpu", weights_only=False
)
if _has_fp8_weights(state_dict):
patch_linear_for_fp8(model)
model.load_state_dict(state_dict, assign=True)
model = fabric.setup_module(model, move_to_device=False)

@@ -421,6 +425,8 @@ def distribute(
map_location="cpu",
weights_only=False,
)
if _has_fp8_weights(state_dict):
patch_linear_for_fp8(model)
model.load_state_dict(state_dict, assign=True)

# cannot use `.setup_module` because it will wrap with DDP
47 changes: 47 additions & 0 deletions litgpt/config.py
@@ -83,6 +83,7 @@ class Config:
rope_base: int = 10000
rotary_percentage: float = 0.25
rope_condense_ratio: int = 1
rope_interleave: bool = False
rope_adjustments: dict | None = None
rope_interleave: bool = False
# Transformer block (MLP)
@@ -3191,4 +3192,50 @@ def check_indicator_and_length(

configs.extend(r1_distill_llama)

deepseek_v3 = [
# https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
dict(
name="DeepSeek-V3",
hf_config=dict(org="deepseek-ai", name="DeepSeek-V3"),
block_size=163840,
vocab_size=128000,
padded_vocab_size=129280,
n_layer=61,
n_head=128,
n_embd=7168,
n_query_groups=128,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm",
norm_eps=1e-6,
mlp_class_name="LLaMAMoE",
intermediate_size=18432,
rope_base=10000,
rope_interleave=True,
latent_attention=dict(
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
),
n_expert=256,
n_shared_expert=1,
n_expert_per_token=8,
n_expert_groups=8,
n_topk_groups=4,
n_topk_scores_per_group=2, # hardcoded in DeepseekV3ForCausalLM
moe_intermediate_size=2048,
first_k_dense_replace=3,
norm_topk_prob=True,
routed_scaling_factor=2.5,
rope_adjustments=dict(
factor=40.0, beta_slow=1.0, beta_fast=32.0, original_max_seq_len=4096, mscale=1.0, mscale_all_dim=1.0
),
),
]

configs.extend(deepseek_v3)

name_to_config = {config["name"]: config for config in configs}
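Note on the `latent_attention` values: the sketch below derives the per-head dimensions and projection shapes implied by the `DeepSeek-V3` config entry above. The shape formulas and the reading of `first_k_dense_replace` as "the first k layers use a dense MLP" are assumptions based on the public DeepSeek-V3 architecture, not something this diff states, and the variable names are illustrative only.

```python
# Values copied from the DeepSeek-V3 config entry in this diff.
n_embd = 7168
n_head = 128
n_layer = 61
first_k_dense_replace = 3
latent_attention = dict(
    q_lora_rank=1536,
    kv_lora_rank=512,
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
)

la = latent_attention
# Each query/key head concatenates a non-rotary part and a rotary part.
qk_head_dim = la["qk_nope_head_dim"] + la["qk_rope_head_dim"]

# Assumed (out_features, in_features) shapes of the MLA projections.
shapes = {
    "q_a_proj": (la["q_lora_rank"], n_embd),
    "q_b_proj": (n_head * qk_head_dim, la["q_lora_rank"]),
    # Compressed KV latent plus one shared rotary key.
    "kv_a_proj_with_mqa": (la["kv_lora_rank"] + la["qk_rope_head_dim"], n_embd),
    "kv_b_proj": (n_head * (la["qk_nope_head_dim"] + la["v_head_dim"]), la["kv_lora_rank"]),
    "proj": (n_embd, n_head * la["v_head_dim"]),
}

# Assumption: layers with index < first_k_dense_replace are dense, the rest are MoE.
n_moe_layers = n_layer - first_k_dense_replace

for name, shape in shapes.items():
    print(f"{name}: {shape}")
print(f"qk_head_dim={qk_head_dim}, MoE layers={n_moe_layers}")
```

Under these assumptions only `qk_rope_head_dim` (64) of each 192-wide query/key head receives rotary embeddings, which is the width the RoPE cache would need to cover.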
4 changes: 4 additions & 0 deletions litgpt/generate/sequentially.py
@@ -26,9 +26,11 @@
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
_has_fp8_weights,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
patch_linear_for_fp8,
)


@@ -240,6 +242,8 @@ def main(

t0 = time.perf_counter()
state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu")
if _has_fp8_weights(state_dict):
patch_linear_for_fp8(model)
# TODO: this assumes that the model fits on CPU. Use lazy_load and make the materialization checkpoint aware
model.load_state_dict(state_dict, assign=True)
print(f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)
4 changes: 4 additions & 0 deletions litgpt/generate/tp.py
@@ -23,10 +23,12 @@
from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
from litgpt.tokenizer import Tokenizer
from litgpt.utils import (
_has_fp8_weights,
check_nvlink_connectivity,
check_valid_checkpoint_dir,
extend_checkpoint_dir,
get_default_supported_precision,
patch_linear_for_fp8,
)


@@ -205,6 +207,8 @@ def main(
if fabric.global_rank == rank:
t0 = time.perf_counter()
state_dict = torch.load(str(checkpoint_path), mmap=True, map_location="cpu")
if _has_fp8_weights(state_dict):
patch_linear_for_fp8(model)
model.load_state_dict(state_dict, assign=True)
print(f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr)

10 changes: 3 additions & 7 deletions litgpt/model.py
@@ -1006,7 +1006,7 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos):
high_dim = math.ceil(high_dim)

low_dim = max(low_dim, 0)
-high_dim = min(high_dim, n_elem // 2 - 1)
+high_dim = min(high_dim, n_elem - 1)

# Create linear ramp factor for blending
dim_range = torch.arange(n_elem // 2, device=device, dtype=torch.float32)
@@ -1017,12 +1017,8 @@ def find_correction_dim(num_rotations, dim, base_val, max_pos):
ramp_func = torch.clamp(linear_func, 0.0, 1.0)

# Blend extrapolation and interpolation frequencies
-# ramp_func = 0 -> use interpolation (scaled), ramp_func = 1 -> use extrapolation (unscaled)
-theta_extrapolation_factor = ramp_func
-theta = (
-theta_interpolation * (1 - theta_extrapolation_factor)
-+ theta_extrapolation * theta_extrapolation_factor
-)
+# ramp_func = 0 -> use extrapolation (unscaled), ramp_func = 1 -> use interpolation (scaled)
+theta = theta_interpolation * ramp_func + theta_extrapolation * (1 - ramp_func)
elif "original_max_seq_len" in extra_config:
# Llama3-style RoPE scaling
orig_context_len = extra_config["original_max_seq_len"]
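Note on the YaRN blend: the corrected direction of the ramp can be checked with a small pure-Python sketch. It mirrors the names used in the hunk above (`find_correction_dim`, `ramp_func`, `theta_interpolation`, `theta_extrapolation`) but is not the litgpt implementation. The body of `find_correction_dim`, the use of `beta_fast`/`beta_slow` for the low and high bounds, and the choice of `n_elem=64` (the config's `qk_rope_head_dim`) are assumptions based on the usual YaRN formulation.

```python
import math


def find_correction_dim(num_rotations, dim, base_val, max_pos):
    # Assumed YaRN formula: the dimension index whose frequency completes
    # `num_rotations` full rotations over `max_pos` positions.
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base_val))


def yarn_theta(n_elem, base, factor, beta_fast, beta_slow, original_max_seq_len):
    half = n_elem // 2
    # Unscaled frequencies (extrapolation) and frequencies divided by the factor (interpolation).
    theta_extrapolation = [1.0 / (base ** (2 * i / n_elem)) for i in range(half)]
    theta_interpolation = [t / factor for t in theta_extrapolation]

    low_dim = math.floor(find_correction_dim(beta_fast, n_elem, base, original_max_seq_len))
    high_dim = math.ceil(find_correction_dim(beta_slow, n_elem, base, original_max_seq_len))
    low_dim = max(low_dim, 0)
    high_dim = min(high_dim, n_elem - 1)  # the bound used after this change
    if low_dim == high_dim:
        high_dim += 0.001  # avoid division by zero in the ramp

    # ramp_func = 0 -> extrapolation (unscaled), ramp_func = 1 -> interpolation (scaled)
    ramp_func = [min(max((i - low_dim) / (high_dim - low_dim), 0.0), 1.0) for i in range(half)]
    return [
        ti * r + te * (1 - r)
        for ti, te, r in zip(theta_interpolation, theta_extrapolation, ramp_func)
    ]


# Values from the DeepSeek-V3 `rope_adjustments` entry; n_elem=64 is an assumption.
theta = yarn_theta(
    n_elem=64, base=10000, factor=40.0, beta_fast=32.0, beta_slow=1.0, original_max_seq_len=4096
)
print(theta[0], theta[-1])
```

With this direction, the high-frequency (low-index) dimensions keep their original frequencies, and the low-frequency (high-index) dimensions are divided by `factor`.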
2 changes: 1 addition & 1 deletion litgpt/prompts.py
@@ -470,7 +470,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return Llama3()
if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
return Llama3()
-if re.search("R1", model_name):
+if re.search(r"(R1|DeepSeek-V3)", model_name):
return R1Base()
if re.search("FreeWilly2", model_name):
return FreeWilly2()
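Note on the new pattern: `re.search` is unanchored, so the alternation matches any model name that contains either substring. A quick check follows. Only `DeepSeek-V3` comes from this diff; the other names in the list are illustrative examples.

```python
import re

# Pattern copied from the changed line in model_name_to_prompt_style.
PATTERN = r"(R1|DeepSeek-V3)"

# Illustrative model names; only "DeepSeek-V3" appears in this diff.
names = ["DeepSeek-V3", "R1-Distill-Llama-8B", "Llama-3.1-8B", "Qwen2.5-7B"]

matches = {name: bool(re.search(PATTERN, name)) for name in names}
for name, hit in matches.items():
    print(f"{name}: {'R1Base' if hit else 'falls through to later checks'}")
```

Because the match is a substring search, any future model whose name happens to contain `R1` would also be routed to `R1Base`.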
85 changes: 85 additions & 0 deletions litgpt/scripts/convert_hf_checkpoint.py
@@ -723,6 +723,87 @@ def copy_weights_qwen_3(
pbar.update(progress_per_file)


def copy_weights_deepseek_v3(
config: Config,
qkv_weights: dict[int, list[NotYetLoadedTensor | None]],
state_dict: dict[str, torch.Tensor],
hf_weights: dict[str, torch.Tensor | NotYetLoadedTensor],
saver: incremental_save | None = None,
dtype: torch.dtype | None = None,
pbar: tqdm | None = None,
progress_per_file: float | None = None,
debug_mode: bool | None = False,
) -> None:
weight_map = {
"model.embed_tokens.weight": "transformer.wte.weight",
"model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
"model.layers.{}.self_attn.q_a_proj.weight": "transformer.h.{}.attn.q_a_proj.weight",
"model.layers.{}.self_attn.q_a_proj.weight_scale_inv": "transformer.h.{}.attn.q_a_proj.weight_scale_inv",
"model.layers.{}.self_attn.q_a_layernorm.weight": "transformer.h.{}.attn.q_a_norm.weight",
"model.layers.{}.self_attn.q_b_proj.weight": "transformer.h.{}.attn.q_b_proj.weight",
"model.layers.{}.self_attn.q_b_proj.weight_scale_inv": "transformer.h.{}.attn.q_b_proj.weight_scale_inv",
"model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "transformer.h.{}.attn.kv_a_proj_with_mqa.weight",
"model.layers.{}.self_attn.kv_a_proj_with_mqa.weight_scale_inv": "transformer.h.{}.attn.kv_a_proj_with_mqa.weight_scale_inv",
"model.layers.{}.self_attn.kv_a_layernorm.weight": "transformer.h.{}.attn.kv_a_norm.weight",
"model.layers.{}.self_attn.kv_b_proj.weight": "transformer.h.{}.attn.kv_b_proj.weight",
"model.layers.{}.self_attn.kv_b_proj.weight_scale_inv": "transformer.h.{}.attn.kv_b_proj.weight_scale_inv",
"model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
"model.layers.{}.self_attn.o_proj.weight_scale_inv": "transformer.h.{}.attn.proj.weight_scale_inv",
"model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
"model.norm.weight": "transformer.ln_f.weight",
"lm_head.weight": "lm_head.weight",
}
if (
config.mlp_class_name == "LLaMAMoE"
): # Deepseek V3 has both MoE and MLP layers (specified with `first_k_dense_replace`), but we treat it as a LLaMAMoE type
weight_map.update(
{
"model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
"model.layers.{}.mlp.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.fc_1.weight_scale_inv",
"model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
"model.layers.{}.mlp.up_proj.weight_scale_inv": "transformer.h.{}.mlp.fc_2.weight_scale_inv",
"model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
"model.layers.{}.mlp.down_proj.weight_scale_inv": "transformer.h.{}.mlp.proj.weight_scale_inv",
"model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
"model.layers.{}.mlp.experts.{}.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.fc_1.weight_scale_inv",
"model.layers.{}.mlp.experts.{}.up_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight",
"model.layers.{}.mlp.experts.{}.up_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.fc_2.weight_scale_inv",
"model.layers.{}.mlp.experts.{}.down_proj.weight": "transformer.h.{}.mlp.experts.{}.proj.weight",
"model.layers.{}.mlp.experts.{}.down_proj.weight_scale_inv": "transformer.h.{}.mlp.experts.{}.proj.weight_scale_inv",
"model.layers.{}.mlp.shared_experts.gate_proj.weight": "transformer.h.{}.mlp.shared_experts.fc_1.weight",
"model.layers.{}.mlp.shared_experts.gate_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.fc_1.weight_scale_inv",
"model.layers.{}.mlp.shared_experts.up_proj.weight": "transformer.h.{}.mlp.shared_experts.fc_2.weight",
"model.layers.{}.mlp.shared_experts.up_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.fc_2.weight_scale_inv",
"model.layers.{}.mlp.shared_experts.down_proj.weight": "transformer.h.{}.mlp.shared_experts.proj.weight",
"model.layers.{}.mlp.shared_experts.down_proj.weight_scale_inv": "transformer.h.{}.mlp.shared_experts.proj.weight_scale_inv",
"model.layers.{}.mlp.gate.weight": "transformer.h.{}.mlp.gate.weight",
"model.layers.{}.mlp.gate.e_score_correction_bias": "transformer.h.{}.mlp.gate.e_score_correction_bias",
}
)
else:
raise NotImplementedError

if progress_per_file is not None:
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))

for from_name, param in hf_weights.items():
name_template, *ids = layer_template(from_name, num_matches=2)
to_name = weight_map[name_template]
param = load_param(param, from_name, dtype, verbose=debug_mode)
if to_name is None:
continue
to_name = to_name.format(*ids)
if saver is not None:
param = saver.store_early(param)
state_dict[to_name] = param

if progress_per_file is not None:
pbar.update(progress_per_file)

if "lm_head.weight" not in state_dict:
state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]


def qkv_reassemble(
param: torch.Tensor | NotYetLoadedTensor, config: Config
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -822,6 +903,10 @@ def convert_hf_checkpoint(
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_qwen_3, config, qkv_weights)
elif model_name.lower().startswith("deepseek"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
copy_fn = partial(copy_weights_deepseek_v3, config, qkv_weights)
elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
# holder to reconstitute the split q, k, v
qkv_weights = {}
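Note on the name mapping in `copy_weights_deepseek_v3`: the loop turns each Hugging Face parameter name into a `{}` template, looks the template up in `weight_map`, and then fills the layer and expert indices back in. The sketch below reproduces that flow on a few keys copied from the diff. `layer_template_sketch` and `convert_name` are hypothetical stand-ins written for this example; they are not litgpt's `layer_template` and may differ from it in edge cases.

```python
import re


def layer_template_sketch(name, num_matches=1):
    """Hypothetical stand-in: replace up to `num_matches` dotted integers with `{}`."""
    ids = []

    def repl(match):
        ids.append(int(match.group(1)))
        return ".{}."

    template = re.sub(r"\.(\d+)\.", repl, name, count=num_matches)
    return (template, *ids)


# A small subset of the weight_map from the diff.
weight_map = {
    "model.embed_tokens.weight": "transformer.wte.weight",
    "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
    "model.layers.{}.mlp.experts.{}.gate_proj.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight",
    "model.layers.{}.mlp.gate.e_score_correction_bias": "transformer.h.{}.mlp.gate.e_score_correction_bias",
}


def convert_name(from_name):
    # Two matches are requested because expert weights carry a layer and an expert index.
    name_template, *ids = layer_template_sketch(from_name, num_matches=2)
    return weight_map[name_template].format(*ids)


for hf_name in (
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.o_proj.weight",
    "model.layers.3.mlp.experts.7.gate_proj.weight",
):
    print(hf_name, "->", convert_name(hf_name))
```

A name whose template is missing from `weight_map` raises `KeyError` in this sketch, so an incomplete map fails immediately rather than silently dropping a parameter.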