|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch.nn as nn |
| 9 | +from torch.distributed.device_mesh import DeviceMesh |
| 10 | + |
| 11 | +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP |
| 12 | +from torchtitan.distributed import ParallelDims |
| 13 | +from torchtitan.experiments.llama4.infra.parallelize import apply_moe_ep_tp |
| 14 | +from torchtitan.models.deepseek_v3.infra.parallelize import apply_non_moe_tp |
| 15 | +from torchtitan.models.llama3.infra.parallelize import apply_ac |
| 16 | +from torchtitan.tools.logging import logger |
| 17 | + |
| 18 | +from .simple_fsdp import data_parallel, MixedPrecisionPolicy |
| 19 | + |
| 20 | +# Adapted from llama4/infra/parallelize.py |
| 21 | +def parallelize_deepseekv3( |
| 22 | + model: nn.Module, |
| 23 | + parallel_dims: ParallelDims, |
| 24 | + job_config: JobConfig, |
| 25 | +): |
| 26 | + world_mesh = parallel_dims.world_mesh |
| 27 | + # TODO: TP currently cannot handle uneven seq_len because we set |
| 28 | + # `use_local_output=True` to use plain Tensors for legacy reasons. |
| 29 | + # Need to revisit this. |
| 30 | + assert ( |
| 31 | + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 |
| 32 | + ), f""" |
| 33 | + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree |
| 34 | + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). |
| 35 | + """ |
| 36 | + |
| 37 | + if ( |
| 38 | + job_config.parallelism.context_parallel_degree > 1 |
| 39 | + and model.model_args.use_flex_attn |
| 40 | + ): |
| 41 | + raise NotImplementedError("CP support for FlexAttention is still in progress.") |
| 42 | + |
| 43 | + if parallel_dims.tp_enabled: |
| 44 | + if job_config.parallelism.enable_async_tensor_parallel: |
| 45 | + # TODO(jianiw): This branch needs to be tested and enabled |
| 46 | + raise NotImplementedError( |
| 47 | + "Currently, async TP is not tested for deepseekv3. \ |
| 48 | + torch.compile is not supported yet, which is required for async TP." |
| 49 | + ) |
| 50 | + |
| 51 | + enable_float8_linear = "float8" in job_config.model.converters |
| 52 | + float8_is_rowwise = job_config.float8.recipe_name in ( |
| 53 | + "rowwise", |
| 54 | + "rowwise_with_gw_hp", |
| 55 | + ) |
| 56 | + |
| 57 | + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise |
| 58 | + if enable_float8_tensorwise_tp: |
| 59 | + # TODO(jianiw): This branch needs to be tested and enabled |
| 60 | + raise NotImplementedError( |
| 61 | + "Currently, float8 tensorwise TP is not tested for deepseekv3" |
| 62 | + ) |
| 63 | + |
| 64 | + apply_non_moe_tp( |
| 65 | + model, |
| 66 | + world_mesh["tp"], |
| 67 | + loss_parallel=not job_config.parallelism.disable_loss_parallel, |
| 68 | + enable_float8_tensorwise_tp=False, |
| 69 | + enable_async_tp=False, |
| 70 | + ) |
| 71 | + |
| 72 | + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: |
| 73 | + apply_moe_ep_tp( |
| 74 | + model, |
| 75 | + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, |
| 76 | + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, |
| 77 | + ep_tp_mesh=( |
| 78 | + world_mesh["ep", "tp"] |
| 79 | + if parallel_dims.tp_enabled and parallel_dims.ep_enabled |
| 80 | + else None |
| 81 | + ), |
| 82 | + ) |
| 83 | + |
| 84 | + if job_config.activation_checkpoint.mode != "none": |
| 85 | + apply_ac(model, job_config.activation_checkpoint) |
| 86 | + |
| 87 | + mp_policy = MixedPrecisionPolicy( |
| 88 | + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], |
| 89 | + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], |
| 90 | + ) |
| 91 | + |
| 92 | + # apply data parallel |
| 93 | + dp_mesh: DeviceMesh | None = None |
| 94 | + if ( |
| 95 | + parallel_dims.fsdp_enabled |
| 96 | + or parallel_dims.ep_enabled |
| 97 | + or parallel_dims.dp_replicate_enabled |
| 98 | + ): |
| 99 | + if parallel_dims.dp_replicate_enabled: |
| 100 | + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: |
| 101 | + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") |
| 102 | + dp_mode = "hybrid_shard" |
| 103 | + else: |
| 104 | + dp_mesh_dim_names = ("dp_replicate",) |
| 105 | + dp_mode = "replicate" |
| 106 | + else: |
| 107 | + dp_mesh_dim_names = ("dp_shard_cp",) |
| 108 | + dp_mode = "fully_shard" |
| 109 | + |
| 110 | + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] |
| 111 | + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP |
| 112 | + dp_mod_ep_mesh_dim_names = [] |
| 113 | + ep_modules = [] |
| 114 | + ep_shared_experts = [] |
| 115 | + if parallel_dims.ep_enabled: |
| 116 | + if parallel_dims.dp_replicate_enabled: |
| 117 | + dp_mod_ep_mesh_dim_names.append("dp_replicate") |
| 118 | + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") |
| 119 | + for _, transformer_block in model.layers.items(): |
| 120 | + if transformer_block.moe_enabled: |
| 121 | + ep_modules.append(transformer_block.moe.experts) |
| 122 | + ep_shared_experts.append(transformer_block.moe.shared_expert) |
| 123 | + |
| 124 | + if parallel_dims.tp_enabled and parallel_dims.ep_enabled: |
| 125 | + tp_ep_mesh = world_mesh["ep", "tp"] |
| 126 | + else: |
| 127 | + tp_ep_mesh = None |
| 128 | + |
| 129 | + model = data_parallel( |
| 130 | + model, |
| 131 | + dp_mesh, |
| 132 | + dp_mode, |
| 133 | + ac_mode=job_config.activation_checkpoint.mode, |
| 134 | + mp_policy=mp_policy, |
| 135 | + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, |
| 136 | + tp_ep_mesh=tp_ep_mesh, |
| 137 | + dp_mod_ep_mesh=world_mesh[tuple(dp_mod_ep_mesh_dim_names)] |
| 138 | + if parallel_dims.ep_enabled |
| 139 | + else None, |
| 140 | + ep_modules=ep_modules, |
| 141 | + ep_shared_experts=ep_shared_experts, |
| 142 | + ) |
| 143 | + if parallel_dims.dp_replicate_enabled: |
| 144 | + logger.info("Applied HSDP to the model") |
| 145 | + else: |
| 146 | + logger.info("Applied FSDP to the model") |
| 147 | + |
| 148 | + if job_config.training.compile: |
| 149 | + torch._inductor.config.reorder_for_peak_memory = False |
| 150 | + torch._dynamo.config.capture_scalar_outputs = True |
| 151 | + model = torch.compile(model, fullgraph=True) |
| 152 | + |
| 153 | + return model |
0 commit comments