|
20 | 20 | Shard, |
21 | 21 | ) |
22 | 22 | from torch.distributed.device_mesh import _mesh_resources, DeviceMesh |
| 23 | +from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed |
23 | 24 | from torch.distributed.tensor._dtensor_spec import DTensorSpec |
24 | 25 | from torch.distributed.tensor._redistribute import redistribute_local_tensor |
25 | 26 | from torch.distributed.tensor.placement_types import _StridedShard, Placement |
@@ -49,6 +50,69 @@ class MixedPrecisionPolicy: |
49 | 50 | reduce_dtype: Optional[torch.dtype] = None |
50 | 51 |
51 | 52 |
| 53 | +class _ScaledPartial(Partial): |
| 54 | + # A subclass of the Partial placement that lets the user reduce gradients with a custom
| 55 | + # divide factor (reduction_divide_factor) instead of the default world size.
| 56 | + def __init__( |
| 57 | + self, |
| 58 | + reduction_divide_factor: float, |
| 59 | + reduce_dtype: Optional[torch.dtype] = None, |
| 60 | + ): |
| 61 | + self.reduction_divide_factor = reduction_divide_factor |
| 62 | + self.reduce_dtype = reduce_dtype |
| 63 | + super().__init__(reduce_op="sum") |
| 64 | + |
| 65 | + def _get_reduction_divide_factors( |
| 66 | + self, |
| 67 | + ) -> tuple[Optional[float], Optional[float]]: |
| 68 | + """
| 69 | + The pre-/post-divide factors follow the gradient-divide logic in
| 70 | + https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
| 71 | + """
| 72 | + assert self.reduce_dtype in ( |
| 73 | + torch.float32, |
| 74 | + torch.bfloat16, |
| 75 | + ), "reduce_dtype must be fp32 or bf16"
| 76 | + pre_factor, post_factor = None, self.reduction_divide_factor |
| 77 | + return pre_factor, post_factor |
| 78 | + |
| 79 | + def _reduce_value( |
| 80 | + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int |
| 81 | + ) -> torch.Tensor: |
| 82 | + # for all_reduce in DDP |
| 83 | + ( |
| 84 | + pre_factor, |
| 85 | + post_factor, |
| 86 | + ) = self._get_reduction_divide_factors() |
| 87 | + if pre_factor is not None: |
| 88 | + _div_if_needed(tensor, pre_factor) |
| 89 | + reduced = super()._reduce_value(tensor, mesh, mesh_dim) |
| 90 | + if post_factor is not None: |
| 91 | + _div_if_needed(reduced, post_factor) |
| 92 | + return reduced |
| 93 | + |
| 94 | + def _reduce_shard_value( |
| 95 | + self, |
| 96 | + tensor: torch.Tensor, |
| 97 | + mesh: DeviceMesh, |
| 98 | + mesh_dim: int, |
| 99 | + shard_spec: Placement, |
| 100 | + ) -> torch.Tensor: |
| 101 | + # for reduce_scatter in FSDP |
| 102 | + ( |
| 103 | + pre_factor, |
| 104 | + post_factor, |
| 105 | + ) = self._get_reduction_divide_factors() |
| 106 | + |
| 107 | + if pre_factor is not None: |
| 108 | + _div_if_needed(tensor, pre_factor) |
| 109 | + reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) |
| 110 | + |
| 111 | + if post_factor is not None: |
| 112 | + _div_if_needed(reduced, post_factor) |
| 113 | + return reduced |
| 114 | + |
| 115 | + |
52 | 116 | def _distribute_dtensor( |
53 | 117 | tensor: DTensor, |
54 | 118 | device_mesh: DeviceMesh, |
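For context on the `_get_reduction_divide_factors` contract in `_ScaledPartial` above: it returns `(None, reduction_divide_factor)`, i.e. no pre-divide for a fp32/bf16 reduce dtype and a single post-divide by the custom factor after the collective. The standalone sketch below imitates that flow on plain local tensors; the local `_div_if_needed` is a stand-in written out only for illustration (the real helper is imported from the FSDP collectives module at the top of this diff), and the gradient values and sizes are made up.

```python
from typing import Optional

import torch


def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
    # In-place divide, skipped when there is nothing to scale.
    if div_factor is not None and div_factor > 1:
        tensor.div_(div_factor)


# Made-up per-rank gradients that the data-parallel collective would combine.
world_size, reduction_divide_factor = 4, 8.0
grads = [torch.full((2,), float(rank + 1)) for rank in range(world_size)]

# _ScaledPartial path: pre_factor is None for fp32/bf16, the collective sums,
# and the sum is divided by the custom factor afterwards.
pre_factor, post_factor = None, reduction_divide_factor
if pre_factor is not None:
    for g in grads:
        _div_if_needed(g, pre_factor)
reduced = torch.stack(grads).sum(dim=0)  # what an all_reduce/reduce_scatter with "sum" produces
_div_if_needed(reduced, post_factor)

# Default Partial(reduce_op="avg") path, for comparison: divide by the mesh size.
averaged = torch.stack(grads).sum(dim=0) / world_size

print(reduced)   # tensor([1.2500, 1.2500])  -> sum 10 divided by 8
print(averaged)  # tensor([2.5000, 2.5000])  -> sum 10 divided by 4
```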
@@ -192,18 +256,25 @@ def __init__( |
192 | 256 | mode, |
193 | 257 | regional_ac, |
194 | 258 | mp_policy, |
| 259 | + reduction_divide_factor, |
195 | 260 | ): |
196 | 261 | super().__init__() |
197 | 262 | self.device_mesh = device_mesh |
198 | 263 | self.param_sharding = param_sharding |
199 | 264 | self.mode = mode |
200 | 265 | self.compute_placements = [Replicate()] * self.device_mesh.ndim |
201 | | - self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim |
| 266 | + self.grad_placements = [ |
| 267 | + _ScaledPartial( |
| 268 | + reduction_divide_factor=reduction_divide_factor, |
| 269 | + reduce_dtype=mp_policy.reduce_dtype, |
| 270 | + ) |
| 271 | + if reduction_divide_factor is not None |
| 272 | + else Partial(reduce_op="avg") |
| 273 | + ] * self.device_mesh.ndim |
202 | 274 | self.regional_ac = regional_ac |
203 | 275 | mp_policy = mp_policy or MixedPrecisionPolicy() |
204 | 276 | self.param_dtype = mp_policy.param_dtype |
205 | 277 | self.reduce_dtype = mp_policy.reduce_dtype |
206 | | - self.ep_mesh_name, self.tp_mesh_name = "ep", "tp" |
207 | 278 |
|
208 | 279 | def replicate_compute(self, x): |
209 | 280 | # data parallel runtime replicate parameters and do local compute |
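One way to sanity-check the constructor change above: the old `Partial(reduce_op="avg")` divides the summed gradient by the mesh size, so passing a `reduction_divide_factor` equal to the data-parallel world size reproduces the previous behavior, while any other factor rescales gradients by `world_size / factor`. A small local arithmetic check of that claim, with invented gradient values:

```python
import torch

world_size = 4
per_rank_grads = [torch.tensor([0.5, 1.0]) * (rank + 1) for rank in range(world_size)]
grad_sum = torch.stack(per_rank_grads).sum(dim=0)  # what reduce_op="sum" produces

default_avg = grad_sum / world_size          # old: Partial(reduce_op="avg")
scaled_same = grad_sum / float(world_size)   # _ScaledPartial(reduction_divide_factor=world_size)
scaled_custom = grad_sum / 16.0              # _ScaledPartial(reduction_divide_factor=16.0)

assert torch.allclose(default_avg, scaled_same)  # identical when factor == world size
print(scaled_custom / default_avg)               # tensor([0.2500, 0.2500]) == world_size / 16
```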
@@ -286,6 +357,7 @@ def data_parallel( |
286 | 357 | ac_mode: str = "none", |
287 | 358 | mp_policy: Optional[MixedPrecisionPolicy] = None, |
288 | 359 | shard_dim: int = 0, |
| 360 | + reduction_divide_factor: Optional[float] = None, |
289 | 361 | ): |
290 | 362 | if mode == "replicate": |
291 | 363 | param_sharding = (Replicate(),) |
@@ -348,6 +420,7 @@ def data_parallel( |
348 | 420 | mode, |
349 | 421 | regional_ac, |
350 | 422 | mp_policy=mp_policy, |
| 423 | + reduction_divide_factor=reduction_divide_factor, |
351 | 424 | ), |
352 | 425 | ) |
353 | 426 | return model |
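Finally, a hedged sketch of a call site that opts into the new `reduction_divide_factor` argument. Only the pieces visible in this diff (`mp_policy`, `reduction_divide_factor`, the `MixedPrecisionPolicy` fields) are taken from the patch; the import path, the `mode` keyword name, the `"fully_shard"` mode string, the mesh-dimension name, and the parallel degrees are placeholders for illustration.

```python
import torch

# from <module patched by this diff> import data_parallel, MixedPrecisionPolicy
# (the module path is not shown in the diff, so the import is left as a placeholder)

dp_degree, cp_degree = 4, 2          # assumed parallel degrees
model = data_parallel(
    model,
    world_mesh["dp_shard"],          # placeholder mesh / mesh-dim name
    mode="fully_shard",              # assumed mode string
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,  # _ScaledPartial only accepts fp32/bf16 here
    ),
    # Divide summed gradients by a caller-chosen denominator instead of the
    # default mesh-size average.
    reduction_divide_factor=float(dp_degree * cp_degree),
)
```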