add support for simplefsdp+ep #1529

Open · wants to merge 1 commit into main
2 changes: 1 addition & 1 deletion .github/workflows/integration_test_8gpu_simple_fsdp.yaml
@@ -47,4 +47,4 @@ jobs:
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126

mkdir artifacts-to-be-uploaded
python -m torchtitan.experiments.simple_fsdp.tests.integration_tests artifacts-to-be-uploaded --ngpu 8
python -m torchtitan.experiments.simple_fsdp.tests.llama3_integration_tests artifacts-to-be-uploaded --ngpu 8
73 changes: 38 additions & 35 deletions torchtitan/distributed/expert_parallel.py
@@ -29,7 +29,12 @@
class _A2A(torch.autograd.Function):
@staticmethod
def forward(ctx, x, out_splits, in_splits, group):
if isinstance(out_splits, torch.Tensor):
out_splits = out_splits.tolist()
if isinstance(in_splits, torch.Tensor):
in_splits = in_splits.tolist()
T_out = int(sum(out_splits))

y = x.new_empty((T_out,) + tuple(x.shape[1:])) # allocate by output splits
dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
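For readers of this hunk: the split sizes may now arrive as tensors, and `dist.all_to_all_single` takes host-side integer split lists, so they are converted with `.tolist()` (a device-to-host sync when the tensor lives on GPU). A minimal sketch of that pattern, not the PR code itself, assuming an already-initialized process group:

```python
# Minimal sketch (not the PR code): convert tensor splits to host-side int
# lists before calling all_to_all_single. Assumes dist.init_process_group()
# has already been called by the surrounding program.
import torch
import torch.distributed as dist


def all_to_all_by_splits(x, out_splits, in_splits, group=None):
    if isinstance(out_splits, torch.Tensor):
        out_splits = out_splits.tolist()  # device-to-host sync if on GPU
    if isinstance(in_splits, torch.Tensor):
        in_splits = in_splits.tolist()  # device-to-host sync if on GPU
    # Allocate the output by the total number of rows this rank will receive.
    y = x.new_empty((int(sum(out_splits)),) + tuple(x.shape[1:]))
    dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
    return y
```

Later in the PR the compile path also sets `torch._dynamo.config.capture_scalar_outputs = True`, presumably so these scalar extractions can be traced under `torch.compile` instead of graph-breaking.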

@@ -171,7 +176,6 @@ def __init__(self):
def _token_dispatch(self, mod, inputs, device_mesh):
# annotate module input placements/sharding with input_layouts
routed_input, num_tokens_per_expert = inputs
ep_size = device_mesh.shape[0]

# generate the input splits and output splits for all-to-all
with torch.no_grad():
@@ -183,20 +187,15 @@ def _token_dispatch(self, mod, inputs, device_mesh):
num_tokens_per_expert,
group=device_mesh.get_group(),
)
input_splits = (
num_tokens_per_expert.view(ep_size, -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
# NOTE: this would incur a device-to-host sync
self.input_splits = (
num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist()
)
output_splits = (
num_tokens_per_expert_group.view(ep_size, -1)
self.output_splits = (
num_tokens_per_expert_group.view(device_mesh.shape[0], -1)
.sum(dim=1)
.to(torch.device("cpu"), non_blocking=True)
.tolist()
)
# NOTE: this would incur a device-to-host sync
torch.cuda.current_stream().synchronize()
self.input_splits = input_splits.tolist()
self.output_splits = output_splits.tolist()

# perform all-to-all
routed_input = all_to_all_single_autograd(
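For context, a small worked example (values made up, not from the PR) of how the dispatch above collapses per-expert token counts into per-EP-rank all-to-all splits; `.tolist()` is where the device-to-host sync noted in the diff occurs:

```python
# Illustrative values only: 4 EP ranks, 2 local experts per rank.
import torch

ep_size = 4
num_tokens_per_expert = torch.tensor([3, 1, 2, 2, 0, 5, 4, 1])

# Sum per-expert counts within each destination rank to get the input splits.
input_splits = num_tokens_per_expert.view(ep_size, -1).sum(dim=1).tolist()
print(input_splits)  # [4, 4, 5, 5] -> rows sent to EP ranks 0..3
```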
@@ -321,41 +320,45 @@ def wrapper(
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
global TOKEN_GROUP_ALIGN_SIZE_M
if isinstance(w1, DTensor):
w1 = w1.to_local()
w2 = w2.to_local()
w3 = w3.to_local()

from torchtitan.experiments.kernels.moe.indices import generate_permute_indices

experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
TOKEN_GROUP_ALIGN_SIZE_M,
if num_tokens_per_expert is not None:
from torchtitan.experiments.kernels.moe.indices import (
generate_permute_indices,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]
experts_per_ep_rank = w1.shape[0]
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

with torch.no_grad():
(
permuted_indices,
num_tokens_per_expert,
_, # offsets,
) = generate_permute_indices(
num_tokens_per_expert,
experts_per_ep_rank,
num_ep_ranks,
x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
TOKEN_GROUP_ALIGN_SIZE_M,
)

x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
input_shape = x.shape
x = x[permuted_indices, :]

out = func(w1, w2, w3, x, num_tokens_per_expert)

out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]
if num_tokens_per_expert is not None:
out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out
out = out_unpermuted[:-1]

return out
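To make the permute / un-permute bookkeeping concrete (it is unchanged by this PR, just now gated on `num_tokens_per_expert` being provided), here is a self-contained toy that is not the kernel path: pad one sentinel row, gather rows by a made-up permutation, run a stand-in compute, then scatter back and drop the pad row.

```python
# Toy illustration only; the permutation here is made up.
import torch

x = torch.arange(12, dtype=torch.float32).reshape(4, 3)
x = torch.vstack((x, x.new_zeros((x.shape[-1]))))   # append sentinel/pad row -> (5, 3)
input_shape = x.shape
permuted_indices = torch.tensor([2, 0, 3, 1, 4])     # pad slot maps to the pad row

h = x[permuted_indices, :]                           # gather rows in expert order
out = h * 2.0                                        # stand-in for the experts' compute

out_unpermuted = out.new_empty(input_shape)
out_unpermuted[permuted_indices, :] = out            # scatter back to original order
out = out_unpermuted[:-1]                            # drop the pad row
assert torch.equal(out, 2.0 * torch.arange(12, dtype=torch.float32).reshape(4, 3))
```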

9 changes: 9 additions & 0 deletions torchtitan/experiments/simple_fsdp/README.md
@@ -12,10 +12,18 @@ This folder includes an experimental frontend implementation for [SimpleFSDP: Si

### Enable SimpleFSDP Training

#### Training Llama3 models

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile
```

#### Training DeepSeek_v3 models

```bash
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_simple_fsdp --training.compile
```

### Composability Support

Some of the features require updates from PyTorch; we are working on providing composability support for the following features:
@@ -28,6 +36,7 @@
|Tensor Parallelism| ✅ |
|Context Parallelism| ✅ |
|Pipeline Parallelism| ✅ |
|Expert Parallelism| ✅ |
|Distributed Checkpointing| ✅ |
|Float8 Training| 🚧 |

28 changes: 25 additions & 3 deletions torchtitan/experiments/simple_fsdp/__init__.py
@@ -8,14 +8,20 @@

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.components.optimizer import (
build_optimizers,
build_optimizers_with_moe_load_balancing,
)
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.deepseek_v3 import deepseekv3_configs
from torchtitan.models.llama3 import llama3_configs, pipeline_llama
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
from .deepseek_v3_model import SimpleFSDPDeepSeekV3Model
from .deepseek_v3_parallelize import parallelize_deepseekv3

from .model import SimpleFSDPTransformer
from .parallelize import parallelize_llama
from .llama3_model import SimpleFSDPTransformer
from .llama3_parallelize import parallelize_llama

register_train_spec(
TrainSpec(
@@ -31,3 +37,19 @@
build_loss_fn=build_cross_entropy_loss,
)
)


register_train_spec(
TrainSpec(
name="deepseekv3_simple_fsdp",
model_cls=SimpleFSDPDeepSeekV3Model,
model_args=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
18 changes: 18 additions & 0 deletions torchtitan/experiments/simple_fsdp/deepseek_v3_model.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.deepseek_v3 import DeepSeekV3Model, DeepSeekV3ModelArgs
from .simple_fsdp import disable_data_parallel


class SimpleFSDPDeepSeekV3Model(DeepSeekV3Model):
def __init__(self, model_args: DeepSeekV3ModelArgs):
super().__init__(model_args)
self.init_weights()

def init_weights(self, *args, **kwargs):
with disable_data_parallel():
super().init_weights(*args, **kwargs)
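The subclass wraps weight initialization in `disable_data_parallel()` from `simple_fsdp` so init runs without the data-parallel parametrization. As a rough, purely illustrative sketch of that kind of guard (not the actual `simple_fsdp` implementation), such a helper is typically a context manager that flips a flag consulted by the parametrization:

```python
# Illustrative pattern only; the real flag and its consumers live in
# torchtitan/experiments/simple_fsdp/simple_fsdp.py.
from contextlib import contextmanager

_data_parallel_enabled = True


@contextmanager
def disable_data_parallel():
    """Temporarily mark data-parallel behavior as disabled, e.g. during init_weights."""
    global _data_parallel_enabled
    prev, _data_parallel_enabled = _data_parallel_enabled, False
    try:
        yield
    finally:
        _data_parallel_enabled = prev
```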
156 changes: 156 additions & 0 deletions torchtitan/experiments/simple_fsdp/deepseek_v3_parallelize.py
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp
from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp
from torchtitan.models.llama3.infra.parallelize import apply_ac
from torchtitan.tools.logging import logger

from .simple_fsdp import data_parallel, MixedPrecisionPolicy

# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
model: nn.Module,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
world_mesh = parallel_dims.world_mesh
# TODO: TP currently cannot handle uneven seq_len because we set
# `use_local_output=True` to use plain Tensors for legacy reasons.
# Need to revisit this.
assert (
job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
), f"""
Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

if (
job_config.parallelism.context_parallel_degree > 1
and model.model_args.use_flex_attn
):
raise NotImplementedError("CP support for FlexAttention is still in progress.")

if parallel_dims.tp_enabled:
if job_config.parallelism.enable_async_tensor_parallel:
# TODO(jianiw): This branch needs to be tested and enabled
raise NotImplementedError(
"Currently, async TP is not tested for deepseekv3. \
torch.compile is not supported yet, which is required for async TP."
)

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
"rowwise_with_gw_hp",
)

enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
if enable_float8_tensorwise_tp:
# TODO(jianiw): This branch needs to be tested and enabled
raise NotImplementedError(
"Currently, float8 tensorwise TP is not tested for deepseekv3"
)

apply_non_moe_tp(
model,
world_mesh["tp"],
loss_parallel=not job_config.parallelism.disable_loss_parallel,
enable_float8_tensorwise_tp=False,
enable_async_tp=False,
)

if parallel_dims.tp_enabled or parallel_dims.ep_enabled:
apply_moe_ep_tp(
model,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None,
ep_tp_mesh=(
world_mesh["ep", "tp"]
if parallel_dims.tp_enabled and parallel_dims.ep_enabled
else None
),
etp_enabled=parallel_dims.etp_enabled,
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

mp_policy = MixedPrecisionPolicy(
param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
)

# apply data parallel
dp_mesh: DeviceMesh | None = None
if (
parallel_dims.fsdp_enabled
or parallel_dims.ep_enabled
or parallel_dims.dp_replicate_enabled
):
if parallel_dims.dp_replicate_enabled:
if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled:
dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
dp_mode = "hybrid_shard"
else:
dp_mesh_dim_names = ("dp_replicate",)
dp_mode = "replicate"
else:
dp_mesh_dim_names = ("dp_shard_cp",)
dp_mode = "fully_shard"

dp_mesh = world_mesh[tuple(dp_mesh_dim_names)]
# the mesh dim names of which the MoE params are sharded on via FSDP/HSDP
dp_mod_ep_mesh_dim_names = []
ep_modules = []
ep_shared_experts = []
if parallel_dims.ep_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mod_ep_mesh_dim_names.append("dp_replicate")
dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep")
for _, transformer_block in model.layers.items():
if transformer_block.moe_enabled:
ep_modules.append(transformer_block.moe.experts)
ep_shared_experts.append(transformer_block.moe.shared_experts)

if not parallel_dims.tp_enabled and parallel_dims.ep_enabled:
tp_ep_mesh = world_mesh["ep"]
elif parallel_dims.tp_enabled and parallel_dims.ep_enabled:
tp_ep_mesh = world_mesh["ep", "tp"]
else:
tp_ep_mesh = None

model = data_parallel(
model,
dp_mesh,
dp_mode,
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
tp_ep_mesh=tp_ep_mesh,
dp_mod_ep_mesh=world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
if parallel_dims.ep_enabled
else None,
ep_modules=ep_modules,
ep_shared_experts=ep_shared_experts,
)
if parallel_dims.dp_replicate_enabled:
logger.info("Applied HSDP to the model")
else:
logger.info("Applied FSDP to the model")

if job_config.training.compile:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
model = torch.compile(model, fullgraph=True)

return model
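To summarize the data-parallel branch in this file, here is a small helper distilled for illustration (not part of the PR) that maps the parallelism flags to the mesh dim names and mode the code selects:

```python
# Distilled from the branch above for illustration; not part of the PR.
def pick_dp_layout(dp_replicate: bool, dp_shard_or_cp: bool) -> tuple[tuple[str, ...], str]:
    if dp_replicate:
        if dp_shard_or_cp:
            return ("dp_replicate", "dp_shard_cp"), "hybrid_shard"  # HSDP
        return ("dp_replicate",), "replicate"                       # DDP-style
    return ("dp_shard_cp",), "fully_shard"                          # FSDP


assert pick_dp_layout(True, True) == (("dp_replicate", "dp_shard_cp"), "hybrid_shard")
assert pick_dp_layout(True, False) == (("dp_replicate",), "replicate")
assert pick_dp_layout(False, False) == (("dp_shard_cp",), "fully_shard")
```

The separate `dp_mod_ep_mesh` passed alongside this is, per the comment in the diff, the mesh on which the MoE expert parameters are sharded via FSDP/HSDP when expert parallelism is enabled.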