intermediate_source/TP_tutorial.rst (3 additions & 3 deletions)
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware of is that when sharding the linear layers column-wise, their outputs become sharded on the last tensor dimension, and the row-wise sharded linear layer directly accepts an input that is sharded on the last dimension.
If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape-related ops to the sharded shape.
-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for wq/ wk/ wv linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would set ``use_local_output=False`` to let the output to be a DTensor. Compared to normal plain tensor, DTensor has knowledge about the parallelism plans, and will handle the ``num_heads`` dimension change under the hood.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.
Finally, we need to call the ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside the ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module, respectively), if necessary:
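For reference, a minimal sketch of how such a plan could be applied is shown below; the submodule names (``attention.wq``, ``feed_forward.w1``, etc.), the ``model.layers`` container, and the ``tp_mesh`` device mesh are assumptions based on a Llama-style ``TransformerBlock``, not part of this diff::

    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )

    layer_tp_plan = {
        # Column-wise sharded q/k/v projections; keep the outputs as DTensors
        # (use_local_output=False) so the sharded num_heads dimension is
        # handled correctly through the later view/reshape operations.
        "attention.wq": ColwiseParallel(use_local_output=False),
        "attention.wk": ColwiseParallel(use_local_output=False),
        "attention.wv": ColwiseParallel(use_local_output=False),
        # Row-wise sharded output projection consumes the last-dim-sharded input.
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }

    for transformer_block in model.layers:
        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,  # 1-D tensor-parallel mesh, assumed to exist
            parallelize_plan=layer_tp_plan,
        )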
@@ -328,7 +328,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
-from torch.distributed.fsdp import fully_shard
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
# i.e. 2-D mesh is [dp, tp], training on 64 GPUs that performs 8 way DP and 8 way TP
mesh_2d = init_device_mesh("cuda", (8, 8))
@@ -342,7 +342,7 @@ This 2-D parallelism pattern can be easily expressed via a 2-D DeviceMesh, and w
This would allow us to easily apply Tensor Parallel within each host (intra-host) and apply FSDP across hosts (inter-hosts), with **0-code changes** to the Llama model.
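Putting the pieces together, a hedged sketch of this composition might look like the following; ``model``, ``model.layers``, ``layer_tp_plan``, and the mesh dimension names ``("dp", "tp")`` are assumptions carried over from the sketch above rather than part of this diff::

    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import parallelize_module

    # 2-D mesh over 64 GPUs: outer dim for data parallelism across hosts,
    # inner dim for tensor parallelism within a host (named dims assumed).
    mesh_2d = init_device_mesh("cuda", (8, 8), mesh_dim_names=("dp", "tp"))
    tp_mesh = mesh_2d["tp"]  # intra-host tensor parallel sub-mesh
    dp_mesh = mesh_2d["dp"]  # inter-host data parallel sub-mesh

    # Apply the per-block TP plan first, then shard the whole model with FSDP
    # over the data-parallel mesh.
    for transformer_block in model.layers:
        parallelize_module(transformer_block, tp_mesh, layer_tp_plan)

    model = FSDP(model, device_mesh=dp_mesh, use_orig_params=True)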