
[WIP] expert parallel dp2ep #1324


Draft · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion run_train.sh
@@ -12,7 +12,7 @@ set -ex
# LOG_RANK=0,1 NGPU=4 ./run_train.sh
NGPU=${NGPU:-"8"}
export LOG_RANK=${LOG_RANK:-0}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/llama4/train_configs/debug_model.toml"}

overrides=""
if [ $# -ne 0 ]; then
25 changes: 13 additions & 12 deletions torchtitan/components/checkpoint.py
@@ -41,6 +41,10 @@
LR_SCHEDULER = "lr_scheduler"
DATALOADER = "dataloader"
TRAIN_STATE = "train_state"
+# For now, we will manually pop the freqs_cis buffer, as we made this permanent
+# temporarily and we don't want to include it in the exported state_dict.
+# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
+excluded_parameters_for_model_only = {"freqs_cis"}


class AsyncMode(str, enum.Enum):
@@ -53,7 +57,10 @@ class ModelWrapper(Stateful):
def __init__(self, model: nn.Module | list[nn.Module]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model
self.cache_state_dict = {
-k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+k: v
+for sd in map(get_model_state_dict, self.model)
+for k, v in sd.items()
+if k not in excluded_parameters_for_model_only
}

def state_dict(self) -> dict[str, Any]:
@@ -69,7 +76,10 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
# `set_model_state_dict()` does change the keys of the input state_dict,
# we will need to reinitialize the cache_state_dict.
self.cache_state_dict = {
-k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
+k: v
+for sd in map(get_model_state_dict, self.model)
+for k, v in sd.items()
+if k not in excluded_parameters_for_model_only
}


@@ -81,12 +91,6 @@ class SaveDone:
pass


-# For now, we will manually pop the freqs_cis buffer, as we made this permanent
-# temporarily and we don't want to include it in the exported state_dict.
-# Context: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/model.py#L404
-excluded_parameters_for_model_only = {"freqs_cis"}


@torch.no_grad()
def save_with_gc(state, checkpoint_id):
dcp.save(state, checkpoint_id=checkpoint_id)
@@ -568,10 +572,7 @@ def _states_to_load(self, model_only: bool) -> dict[str, Any]:
"""
# For the first step, we will only load the model weights.
if model_only:
-sd = self.states[MODEL].state_dict()
-for k in excluded_parameters_for_model_only:
-sd.pop(k, None)
-return sd
+return {MODEL: self.states[MODEL]}

for exclude_key in self.exclude_from_loading:
if exclude_key not in self.states:
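Note: with this change, the `freqs_cis` exclusion happens once inside `ModelWrapper` (its `cache_state_dict` comprehension skips keys in `excluded_parameters_for_model_only`), so `_states_to_load` no longer pops keys itself and can simply hand back the wrapped `MODEL` state. Below is a minimal, self-contained sketch of the same filtering pattern; `TinyModel`, `EXCLUDED`, and the use of plain `nn.Module.state_dict()` (instead of `get_model_state_dict` and torchtitan's `ModelWrapper`) are illustrative assumptions, not part of this PR.

```python
import torch
import torch.nn as nn

# Mirrors excluded_parameters_for_model_only from the diff.
EXCLUDED = {"freqs_cis"}

class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # A persistent buffer we do not want in the exported checkpoint.
        self.register_buffer("freqs_cis", torch.randn(4))

def filtered_state_dict(models: list[nn.Module]) -> dict[str, torch.Tensor]:
    # Same comprehension shape as ModelWrapper.cache_state_dict:
    # merge per-module state dicts and drop the excluded keys.
    return {
        k: v
        for sd in (m.state_dict() for m in models)
        for k, v in sd.items()
        if k not in EXCLUDED
    }

sd = filtered_state_dict([TinyModel()])
assert "freqs_cis" not in sd
print(sorted(sd))  # ['linear.bias', 'linear.weight']
```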
8 changes: 8 additions & 0 deletions torchtitan/config_manager.py
@@ -363,6 +363,14 @@ class Parallelism:
The default value is 'allgather'.
"""

expert_parallel_degree: int = 1
"""
Expert parallelism degree. 1 means disabled.
Currently, only "dp2ep" is supported, with the following constraints:
context_parallel_degree <= expert_parallel_degree <= data_parallel_shard_degree * context_parallel_degree
Note that this is still an experimental feature.
"""


@dataclass
class Checkpoint:
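Note: the docstring above states the dp2ep placement rule, and `ParallelDims._validate` (next file) enforces it as divisibility checks. A small sketch of that arithmetic, so the constraint is easy to eyeball; `check_ep_degree` is a hypothetical helper for illustration, not an API added by this PR.

```python
def check_ep_degree(dp_shard: int, cp: int, ep: int) -> None:
    """cp <= ep <= dp_shard * cp, mirroring the asserts in ParallelDims._validate:
    EP borrows all of cp and an integer slice of dp_shard."""
    if ep == 1:
        return  # expert parallelism disabled
    assert ep % cp == 0, f"ep({ep}) must be a multiple of cp({cp})"
    assert (dp_shard * cp) % ep == 0, f"dp_shard({dp_shard}) * cp({cp}) must be divisible by ep({ep})"

# Example: dp_shard=8, cp=2, ep=4 -> 2 <= 4 <= 16, 4 % 2 == 0, 16 % 4 == 0: OK.
check_ep_degree(dp_shard=8, cp=2, ep=4)
```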
81 changes: 79 additions & 2 deletions torchtitan/distributed/parallel_dims.py
@@ -23,21 +23,23 @@ class ParallelDims:
cp: int
tp: int
pp: int
ep: int
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()

def _validate(self):
-dp_replicate, dp_shard, cp, tp, pp = (
+dp_replicate, dp_shard, cp, tp, pp, ep = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
self.ep,
)
-for d in (dp_replicate, cp, tp, pp):
+for d in (dp_replicate, cp, tp, pp, ep):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."
@@ -50,7 +52,78 @@ def _validate(self):
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

if ep > 1:
# EP would borrow all cp and some dp_shard degree
assert ep % cp == 0 and (dp_shard * cp) % ep == 0

def _build_mesh_with_ep(self, device_type):
# With ep, dp_shard and ep are derived submeshes:
# dp_shard = dp_shard_mod_ep * dp_shard_in_ep
# ep = dp_shard_in_ep * cp
dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
dp_shard_in_ep = self.ep // self.cp

dims = []
names = []
for d, name in zip(
[
self.pp,
self.dp_replicate,
dp_shard_mod_ep,
dp_shard_in_ep,
self.cp,
self.tp,
],
["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
):
# dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping
# helps the MoE layers do mixed precision training
if d > 1 or name == "dp_shard_mod_ep":
dims.append(d)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

# Create all the submesh here to ensure all required process groups are
# initialized:
# Mesh for data loading (no communication on this mesh)
dp_mesh_dim_names = []
# Mesh for param sharding
dp_shard_cp_mesh_dim_names = []
# Mesh for loss all-reduce
dp_cp_mesh_dim_names = []
# Mesh for ep
ep_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_cp_mesh_dim_names.append("dp_replicate")
# dp_shard_mod_ep is always needed, even if it's 1
dp_mesh_dim_names.append("dp_shard_mod_ep")
dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
if "dp_shard_in_ep" in names:
dp_mesh_dim_names.append("dp_shard_in_ep")
dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
dp_cp_mesh_dim_names.append("dp_shard_in_ep")
ep_mesh_dim_names.append("dp_shard_in_ep")
if self.cp_enabled:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")
ep_mesh_dim_names.append("cp")

mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")

return mesh

def build_mesh(self, device_type: str) -> DeviceMesh:
if self.ep > 1:
return self._build_mesh_with_ep(device_type)

dims = []
names = []
for d, name in zip(
@@ -143,3 +216,7 @@ def loss_parallel_enabled(self):
@cached_property
def non_data_parallel_size(self):
return self.cp * self.tp * self.pp

@property
def ep_enabled(self):
return self.ep > 1
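Note: the core of `_build_mesh_with_ep` is how it splits `dp_shard * cp` into an FSDP-only factor (`dp_shard_mod_ep`) and the factor folded into EP (`dp_shard_in_ep`). The sketch below reproduces just that size bookkeeping (no `DeviceMesh` is created); the helper name and the 16-GPU layout are illustrative assumptions.

```python
def ep_mesh_dims(pp: int, dp_replicate: int, dp_shard: int, cp: int, tp: int, ep: int):
    """Mirror the dimension arithmetic of ParallelDims._build_mesh_with_ep."""
    dp_shard_mod_ep = dp_shard * cp // ep  # FSDP degree left for MoE params
    dp_shard_in_ep = ep // cp              # dp_shard ranks borrowed by EP
    dims, names = [], []
    for d, name in zip(
        [pp, dp_replicate, dp_shard_mod_ep, dp_shard_in_ep, cp, tp],
        ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
    ):
        # dp_shard_mod_ep is kept even at degree 1 so MoE layers still get an
        # FSDP wrapper for mixed precision, as the comment in the diff explains.
        if d > 1 or name == "dp_shard_mod_ep":
            dims.append(d)
            names.append(name)
    return dims, names

# 16 GPUs with dp_shard=8, cp=2, ep=4 (tp = pp = dp_replicate = 1):
#   dp_shard_mod_ep = 8 * 2 // 4 = 4,  dp_shard_in_ep = 4 // 2 = 2
#   -> dims [4, 2, 2], names ['dp_shard_mod_ep', 'dp_shard_in_ep', 'cp']
# The flattened submeshes then have sizes:
#   dp = 4 * 2 = 8, dp_shard_cp = dp_cp = 4 * 2 * 2 = 16, ep = 2 * 2 = 4
print(ep_mesh_dims(pp=1, dp_replicate=1, dp_shard=8, cp=2, tp=1, ep=4))
```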