
Commit 630e1d2

rewrite
1 parent 90c66f8 commit 630e1d2


1 file changed: +3 -3 lines changed


intermediate_source/TP_tutorial.rst

Lines changed: 3 additions & 3 deletions
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension.
 If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.

-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for wq/ wk/ wv linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would set ``use_local_output=False`` to let the output to be a DTensor. Compared to normal plain tensor, DTensor has knowledge about the parallelism plans, and will handle the ``num_heads`` dimension change under the hood.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.

 Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module respectively), if necessary:

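For reference, the sketch below illustrates the kind of ``layer_tp_plan`` and ``parallelize_module`` call that the rewritten paragraph above describes. It is an illustration only, not part of this commit: the submodule paths (``attention.wq``, ``feed_forward.w1``, and so on) follow the tutorial's Llama naming and should be treated as assumptions:

    # Sketch only (assumed module paths, not part of this commit): apply the
    # tensor-parallel plan described above to one Llama-style TransformerBlock.
    import torch.nn as nn
    from torch.distributed.device_mesh import DeviceMesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    def apply_tp_to_block(transformer_block: nn.Module, tp_mesh: DeviceMesh) -> nn.Module:
        layer_tp_plan = {
            # Column-wise sharded q/k/v projections: keep the outputs as DTensors
            # (use_local_output=False) so later view ops on the sharded num_heads
            # dimension are resolved against the global shape under the hood.
            "attention.wq": ColwiseParallel(use_local_output=False),
            "attention.wk": ColwiseParallel(use_local_output=False),
            "attention.wv": ColwiseParallel(use_local_output=False),
            # Row-wise sharded output projection accepts an input that is
            # sharded on the last tensor dimension.
            "attention.wo": RowwiseParallel(),
            # Feed-forward uses the same column-wise / row-wise pairing.
            "feed_forward.w1": ColwiseParallel(),
            "feed_forward.w2": RowwiseParallel(),
            "feed_forward.w3": ColwiseParallel(),
        }
        # Distributes the block's parameters to DTensors and registers the
        # input/output communication hooks mentioned above.
        return parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_tp_plan,
        )

With the tutorial's model, this helper would be called once per block, for example ``for block in model.layers: apply_tp_to_block(block, tp_mesh)``.
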
@@ -328,7 +328,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w

     from torch.distributed.device_mesh import init_device_mesh
     from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-    from torch.distributed.fsdp import fully_shard
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

     # i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
     mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -342,7 +342,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
     # apply Tensor Parallel intra-host on tp_mesh
     model_tp = parallelize_module(model, tp_mesh, tp_plan)
     # apply FSDP inter-host on dp_mesh
-    model_2d = fully_shard(model_tp, mesh=dp_mesh, ...)
+    model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True, ...)


 This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.
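
As a usage sketch of how the edited lines fit together after this change (an illustration, not part of the commit: the ``...`` in the diff is left unfilled, only keyword arguments known to exist on ``FullyShardedDataParallel`` are shown, and the ``mesh_dim_names`` plus mesh slicing are assumptions about the surrounding tutorial code):

    # Sketch only: assemble the 2-D (FSDP + TP) parallel model after this commit.
    # `model` and `tp_plan` stand in for the tutorial's Llama model and TP plan.
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


    def build_2d_parallel_model(model: nn.Module, tp_plan: dict) -> nn.Module:
        # 2-D mesh is [dp, tp]: 64 GPUs doing 8-way DP x 8-way TP.
        mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
        tp_mesh = mesh_2d["tp"]  # intra-host dimension
        dp_mesh = mesh_2d["dp"]  # inter-host dimension

        # Apply Tensor Parallel intra-host on tp_mesh.
        model_tp = parallelize_module(model, tp_mesh, tp_plan)
        # Apply FSDP inter-host on dp_mesh. use_orig_params=True keeps the
        # original (DTensor) parameters exposed so the TP sharding composes
        # with FSDP's data-parallel sharding.
        model_2d = FSDP(model_tp, device_mesh=dp_mesh, use_orig_params=True)
        return model_2d

This keeps Tensor Parallel communication within each host while FSDP shards across hosts, matching the intent stated in the diff's final context line.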
