
Commit 91c2e8e

fix simplefsdp gradient_divide_factor
1 parent 5d8e2d5 commit 91c2e8e

2 files changed: +80 -6

torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 5 additions & 4 deletions
@@ -125,18 +125,19 @@ def parallelize_deepseekv3(
             ):
                 experts_shard_dim = 1
 
+            # When EP is enabled, the shared experts' gradient reduction is done over
+            # dp_mod_ep_mesh instead of the whole dp_mesh.
+            # We pass a `fsdp_gradient_divide_factor` to scale gradients over dp_mesh,
+            # so they stay consistent with the data.
             transformer_block.moe.experts = data_parallel(
                 transformer_block.moe.experts,
                 dp_mod_ep_mesh,
                 dp_mode,
                 ac_mode=job_config.activation_checkpoint.mode,
                 mp_policy=mp_policy,
                 shard_dim=experts_shard_dim,
+                reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
             )
-            # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
-            # transformer_block.moe.experts.set_gradient_divide_factor(
-            #     parallel_dims.fsdp_gradient_divide_factor,
-            # )
 
     model = data_parallel(
         model,
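
A quick way to see why the factor is needed: the experts' gradients are now reduced over the smaller dp_mod_ep_mesh, while the loss is averaged over the full data-parallel world, so dividing by the sub-mesh size alone would not match the scaling used for the rest of the model. Below is a toy, single-process sketch of the sum-then-divide behavior the new placement implements; the tensor values and the factor of 8 are made up for illustration.

import torch

# Pretend 8 data-parallel ranks each produced a per-parameter gradient.
per_rank_grads = [torch.full((4,), float(i)) for i in range(8)]
reduction_divide_factor = 8.0  # e.g. the full dp world size, not the sub-mesh size

# reduce_op="sum" followed by a post-division (what the scaled placement does) ...
summed = torch.stack(per_rank_grads).sum(dim=0)
scaled = summed / reduction_divide_factor

# ... matches reduce_op="avg" when the factor equals the number of ranks reduced over,
# but also allows dividing by a different size than the mesh the collective ran on.
expected_avg = torch.stack(per_rank_grads).mean(dim=0)
assert torch.allclose(scaled, expected_avg)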

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 75 additions & 2 deletions
@@ -20,6 +20,7 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +50,69 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None
 
 
+class _ScaledPartial(Partial):
+    # A subclass of the Partial placement that lets the user perform gradient reduction
+    # with a custom factor (reduction_divide_factor) instead of the default world size.
+    def __init__(
+        self,
+        reduction_divide_factor: float,
+        reduce_dtype: Optional[torch.dtype] = None,
+    ):
+        self.reduction_divide_factor = reduction_divide_factor
+        self.reduce_dtype = reduce_dtype
+        super().__init__(reduce_op="sum")
+
+    def _get_reduction_divide_factors(
+        self,
+    ) -> tuple[Optional[float], Optional[float]]:
+        """
+        The logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        assert self.reduce_dtype in (
+            torch.float32,
+            torch.bfloat16,
+        ), "only fp32/bf16 reduce_dtype is supported"
+        pre_factor, post_factor = None, self.reduction_divide_factor
+        return pre_factor, post_factor
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # for all_reduce in DDP
+        (
+            pre_factor,
+            post_factor,
+        ) = self._get_reduction_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # for reduce_scatter in FSDP
+        (
+            pre_factor,
+            post_factor,
+        ) = self._get_reduction_divide_factors()
+
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
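
For reference, the `_div_if_needed` helper imported above is a small guard around an in-place division; a paraphrased sketch of its behavior is shown below (this approximates a private upstream PyTorch helper, and the exact implementation may differ between versions). Because `_get_reduction_divide_factors` always returns `(None, reduction_divide_factor)` here, the division happens once, after the reduction collective.

import torch

def _div_if_needed_sketch(tensor: torch.Tensor, div_factor: float) -> None:
    # Divide in place only when a meaningful factor is given;
    # a factor of 1.0 (or None) means no extra scaling is applied.
    if div_factor is not None and div_factor > 1:
        tensor.div_(div_factor)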
@@ -192,18 +256,25 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        reduction_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            _ScaledPartial(
+                reduction_divide_factor=reduction_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype,
+            )
+            if reduction_divide_factor is not None
+            else Partial(reduce_op="avg")
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"
 
     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
@@ -286,6 +357,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    reduction_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +420,7 @@
             mode,
             regional_ac,
             mp_policy=mp_policy,
+            reduction_divide_factor=reduction_divide_factor,
         ),
     )
     return model
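
Putting the two files together, a hypothetical call site might look like the sketch below. The mesh names, the "fully_shard" mode string, and the choice of factor are illustrative assumptions, not part of this commit; an explicit MixedPrecisionPolicy is passed because the new grad placement reads mp_policy.reduce_dtype before the `mp_policy or MixedPrecisionPolicy()` fallback is applied.

import torch

# Assumed to exist already: `experts` (an nn.Module), `dp_mod_ep_mesh`, and `dp_mesh`.
sharded_experts = data_parallel(
    experts,
    dp_mod_ep_mesh,
    "fully_shard",  # assumed non-replicate mode string
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,  # _ScaledPartial only accepts fp32/bf16 here
    ),
    shard_dim=1,
    # Divide by the full data-parallel size rather than dp_mod_ep_mesh's size,
    # so expert gradients are scaled the same way as the rest of the model.
    reduction_divide_factor=float(dp_mesh.size()),
)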
