
[DSV3] Forward and backward pass for single GPU #1320

Merged · 4 commits · Jun 23, 2025

1 change: 1 addition & 0 deletions torchtitan/models/__init__.py
@@ -7,4 +7,5 @@

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.deepseek_v3 # noqa: F401
import torchtitan.models.llama3 # noqa: F401
125 changes: 125 additions & 0 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .infra.parallelize import parallelize_deepseekv3
from .model.args import DeepSeekV3ModelArgs
from .model.model import DeepSeekV3Model

__all__ = [
"parallelize_deepseekv3",
"DeepseekV3ModelArgs",
"DeepseekV3Model",
"deepseekv3_configs",
]


deepseekv3_configs = {
"debugmodel": DeepSeekV3ModelArgs(
vocab_size=102400,
dim=256,
inter_dim=10944,
moe_inter_dim=1408,
n_layers=3,
n_dense_layers=1,
n_heads=16,
n_routed_experts=8,
n_shared_experts=2,
n_activated_experts=3,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
),
"16B": DeepSeekV3ModelArgs(
vocab_size=102400,
dim=2048,
inter_dim=10944,
moe_inter_dim=1408,
n_layers=27,
n_dense_layers=1,
n_heads=16,
n_routed_experts=64,
n_shared_experts=2,
n_activated_experts=6,
route_scale=1.0,
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
),
"236B": DeepSeekV3ModelArgs(
vocab_size=102400,
dim=5120,
inter_dim=12288,
moe_inter_dim=1536,
n_layers=60,
n_dense_layers=1,
n_heads=128,
n_routed_experts=160,
n_shared_experts=2,
n_activated_experts=6,
n_expert_groups=8,
n_limited_groups=3,
route_scale=16.0,
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
),
"671B": DeepSeekV3ModelArgs(
vocab_size=129280,
dim=7168,
inter_dim=18432,
moe_inter_dim=2048,
n_layers=61,
n_dense_layers=3,
n_heads=128,
n_routed_experts=256,
n_shared_experts=1,
n_activated_experts=8,
n_expert_groups=8,
n_limited_groups=4,
route_scale=2.5,
score_func="sigmoid",
q_lora_rank=1536,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
dtype="fp8",
),
}


register_train_spec(
TrainSpec(
name="deepseek_v3",
cls=DeepSeekV3Model,
config=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=None,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
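
As a sanity check of how the pieces above fit together, here is a minimal usage sketch, assuming only the names defined in this file and that freqs_cis is set up in the model's __init__ (the random token tensor is illustrative and not part of the PR):

# Hypothetical usage sketch, not part of this PR: build the smallest flavor
# on CPU and run one forward pass on random token ids.
import torch

from torchtitan.models.deepseek_v3 import DeepSeekV3Model, deepseekv3_configs

args = deepseekv3_configs["debugmodel"]
model = DeepSeekV3Model(args)

# Use the model's own max_seq_len so the precomputed freqs_cis lines up.
tokens = torch.randint(0, args.vocab_size, (2, model.max_seq_len))
logits = model(tokens)  # expected shape: (batch_size, seq_len, vocab_size)
print(logits.shape)
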
23 changes: 23 additions & 0 deletions torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config_manager import JobConfig
from torchtitan.distributed import ParallelDims


def parallelize_deepseekv3(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
    # TODO: Add support for parallelizing the model; this is a placeholder function for now
return model
torchtitan/models/deepseek_v3/model/args.py
@@ -15,11 +15,12 @@
from torchtitan.components.tokenizer import Tokenizer
from torchtitan.config_manager import JobConfig
from torchtitan.protocols.train_spec import BaseModelArgs
from torchtitan.tools.logging import logger


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepseekV3ModelArgs(BaseModelArgs):
class DeepSeekV3ModelArgs(BaseModelArgs):
"""
Data class for defining model arguments and hyperparameters.

@@ -53,7 +54,6 @@ class DeepseekV3ModelArgs(BaseModelArgs):
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
"""

max_batch_size: int = 8
@@ -95,12 +95,63 @@

def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"""
TODO: Placeholder for now
        Update the model config from the given job config.
"""
pass
self.vocab_size = tokenizer.n_words
self.max_seq_len = job_config.training.seq_len

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
"""
TODO: Placeholder for now
        Adapted from the llama4 implementation.
"""
return 0, 0
nparams_embedding = 0
nparams_moe_router = 0
nparams_shared_expert = 0
nparams_experts = 0
nparams_dense = 0

for name, p in model.named_parameters():
if "embedding" in name:
nparams_embedding += p.numel()
nparams_dense += p.numel()
elif "moe.shared_expert" in name:
nparams_shared_expert += p.numel()
elif "moe.router" in name:
nparams_moe_router += p.numel()
elif "moe.experts" in name:
nparams_experts += p.numel()
else:
nparams_dense += p.numel()

nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
nparams = nparams_dense + nparams_sparse
nparams_sparse_active = (
nparams_moe_router
+ nparams_shared_expert
+ nparams_experts * self.n_activated_experts // self.n_routed_experts
)

logger.info(
f"Total parameter count: dense {nparams_dense:,}, "
f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}"
)

l, h, q, t = (
self.n_layers,
self.n_heads,
self.dim // self.n_heads,
seq_len,
)
# Reasoning behind the factor of 12 for the self-attention part of the formula:
# 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
# 2. the flash attention does 1 more matmul recomputation in the backward
# but recomputation should not be counted in calculating MFU (+0)
# 3. each matmul performs 1 multiplication and 1 addition (*2)
# 4. we follow the convention and do not account for sparsity in causal attention
num_flops_per_token = (
6 * (nparams_dense - nparams_embedding + nparams_sparse_active)
+ 12 * l * h * q * t
)

return nparams, num_flops_per_token
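
For a concrete feel for the attention term in the formula above, plugging the 16B flavor into 12 * l * h * q * t with an assumed sequence length of 2048 (the real seq_len comes from the job config, so 2048 is only an example):

# Back-of-the-envelope check of the attention FLOPs term for the "16B" flavor;
# the sequence length of 2048 is an assumption for illustration.
l, h, q, t = 27, 16, 2048 // 16, 2048  # n_layers, n_heads, head dim (dim // n_heads), seq_len
attn_flops_per_token = 12 * l * h * q * t
print(f"{attn_flops_per_token:,}")  # 1,358,954,496, i.e. roughly 1.36 GFLOPs per token
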
torchtitan/models/deepseek_v3/model/model.py
@@ -13,17 +13,17 @@
from torchtitan.models.attention import build_attention
from torchtitan.protocols.train_spec import ModelProtocol

from .args import DeepseekV3ModelArgs
from .args import DeepSeekV3ModelArgs
from .moe import MoE


# Adopted from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
def precompute_freqs_cis(args: DeepseekV3ModelArgs) -> torch.Tensor:
# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor:
"""
Precomputes frequency-based complex exponential values for rotary positional embeddings.

Args:
args (DeepseekV3ModelArgs): Model arguments containing positional embedding parameters.
args (DeepSeekV3ModelArgs): Model arguments containing positional embedding parameters.

Returns:
torch.Tensor: Precomputed complex exponential values for positional embeddings.
@@ -98,16 +98,21 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor:
# Basic RoPE frequency calculation
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# YaRN scaling for extended context
# YaRN scaling for extended context. YaRN is used to extend the context length after pre-training.
if seqlen > args.original_seq_len:
low, high = find_correction_range(
beta_fast, beta_slow, dim, base, args.original_seq_len
)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth

# Create position indices
t = torch.arange(seqlen)

# Outer product: [positions] × [frequencies]
freqs = torch.outer(t, freqs)

# Convert to complex exponentials: e^(i*freq*pos)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
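
To make the last two steps concrete, here is a tiny standalone sketch (not from the PR) of what torch.outer and torch.polar produce for a toy head dimension; the resulting entries are unit-magnitude complex exponentials e^(i*pos*freq):

# Toy example of the outer-product and polar-conversion steps above, with dim=8 and seqlen=4.
import torch

dim, seqlen, base = 8, 4, 10000.0
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))  # (dim // 2,)
angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), freqs)        # (seqlen, dim // 2)
freqs_cis = torch.polar(torch.ones_like(angles), angles)                      # complex64, magnitude 1
print(freqs_cis.shape)  # torch.Size([4, 4])
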

@@ -135,7 +140,7 @@ class Attention(nn.Module):
    Multi-head Latent Attention (MLA) module.
"""

def __init__(self, model_args: DeepseekV3ModelArgs):
def __init__(self, model_args: DeepSeekV3ModelArgs):
super().__init__()
self.dim = model_args.dim
self.n_heads = model_args.n_heads
@@ -264,13 +269,13 @@ class TransformerBlock(nn.Module):
Transformer block with attention and feed-forward layers.
"""

def __init__(self, layer_id: int, model_args: DeepseekV3ModelArgs):
def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs):

super().__init__()
self.attention = Attention(model_args)
self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.ffn = (
self.moe_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
self.moe = (
FeedForward(model_args.dim, model_args.inter_dim)
if layer_id < model_args.n_dense_layers
else MoE(model_args)
@@ -288,16 +293,16 @@ def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):
torch.Tensor: Output tensor with the same shape as the input.
"""
x = x + self.attention(self.attention_norm(x), freqs_cis)
x = x + self.ffn(self.ffn_norm(x))
x = x + self.moe(self.moe_norm(x))
return x


class Transformer(nn.Module, ModelProtocol):
class DeepSeekV3Model(nn.Module, ModelProtocol):
"""
Deepseek-V3 Transformer model with attention and feed-forward layers.
DeepSeek-V3 Transformer model with attention and feed-forward layers.
"""

def __init__(self, model_args: DeepseekV3ModelArgs):
def __init__(self, model_args: DeepSeekV3ModelArgs):
super().__init__()
self.max_seq_len = model_args.max_seq_len
self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
@@ -327,10 +332,11 @@ def forward(self, tokens: torch.Tensor):
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
h = self.tok_embeddings(tokens)

for layer in self.layers:
h = layer(h, self.freqs_cis)
h = self.norm(h)[:, -1]
output = self.output(h)
h = self.norm(h)
        output = self.output(h)  # (batch_size, seq_len, vocab_size)
return output

def init_weights(self, buffer_device: torch.device | None = None) -> None: