Using DTensor to handle local num_heads change while TP is applied #3465

Open · wants to merge 4 commits into base: main
23 changes: 9 additions & 14 deletions intermediate_source/TP_tutorial.rst
@@ -128,9 +128,9 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
     layer_tp_plan = {
         # by default ColwiseParallel input layouts is replicated
         # and RowwiseParallel output layouts is replicated
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(),
         "feed_forward.w1": ColwiseParallel(),
         "feed_forward.w2": RowwiseParallel(),
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware of is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharded linear layer directly accepts an input that is sharded on the last dimension.
 If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.

-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/ ``wk``/ ``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+For the Llama model, in the attention layer, there are several shape-related view operations. Specifically, with column-wise parallelism on the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To handle the difference between the global and local ``num_heads``, we set ``use_local_output=False`` so that these layers return DTensors. Unlike a regular tensor, a DTensor knows how it is sharded under the parallelism plan and automatically resolves the ``num_heads`` dimension to its local size.
Contributor:
I think we should be able to use DTensor i.e. set use_local_output=False everywhere.
Maybe it's OK to keep a mixed usage of use_local_output so people are aware of this flexibility, but we should mention it here.


Finally, we need to call ``parallelize_module`` API to make the plan for each ``TransformerBlock`` effective. Under the hood, it distributes the model parameters inside ``Attention`` and ``FeedForward`` layers to DTensors, and registers communication hooks for model inputs and outputs (before and after each module respectively), if necessary:
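
As a side note on the ``use_local_output=False`` change above, the following is a minimal, hedged sketch (not part of this diff) of a single column-wise-parallel projection. The mesh size, tensor shapes, and variable names are illustrative assumptions; the point is that the projection returns a DTensor sharded on the head dimension, so the subsequent ``view`` can keep using the global ``num_heads`` without dividing by the TP degree.

```python
# Illustrative sketch; run with e.g.: torchrun --nproc_per_node=4 sketch.py (4 GPUs assumed)
import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (4,))  # also initializes the process group

bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64  # illustrative global sizes
wq = torch.nn.Linear(n_heads * head_dim, n_heads * head_dim, bias=False).cuda()

# Keep the output as a DTensor instead of converting it back to a local tensor.
parallelize_module(wq, tp_mesh, ColwiseParallel(use_local_output=False))

x = torch.randn(bsz, seqlen, n_heads * head_dim, device="cuda")
xq = wq(x)  # DTensor, sharded on the last (hidden) dimension

# The view uses the *global* n_heads; DTensor resolves the local shard
# (n_heads // tp_mesh.size() heads per rank) under the hood.
xq = xq.view(bsz, seqlen, n_heads, head_dim)
print(xq.shape, xq.to_local().shape)  # global shape vs. per-rank local shape
```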

@@ -150,11 +150,6 @@ Finally, we need to call ``parallelize_module`` API to make the plan for each ``
     for layer_id, transformer_block in enumerate(model.layers):
         layer_tp_plan = {...}  # i.e. the plan we just generated

-        # Adjust attention module to use the local number of heads
-        attn_layer = transformer_block.attention
-        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
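
As a quick sanity check after this loop, here is a hedged sketch (assuming the Llama-style attribute names ``attention.wq`` and ``n_heads`` used in this tutorial, with ``model`` and ``tp_mesh`` taken from the surrounding code): the projection weights are now DTensors sharded across the TP mesh, while the module's ``n_heads`` attribute keeps its global value.

```python
from torch.distributed.tensor import DTensor

blk = model.layers[0]                  # any TransformerBlock from the tutorial model
wq_weight = blk.attention.wq.weight
assert isinstance(wq_weight, DTensor)  # parameters were distributed to DTensors
print(wq_weight.placements)            # expected: (Shard(dim=0),) for ColwiseParallel
print(wq_weight.to_local().shape)      # per-rank shard: out_features // tp_mesh.size() rows
print(blk.attention.n_heads)           # still the global number of heads
```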
@@ -219,12 +214,12 @@ Next let's adjust the ``layer_tp_plan`` to enable sequence parallel on the ``RMS
         # to represent the input/output tensors sharded on the sequence dimension
         "attention_norm": SequenceParallel(),
         "attention": PrepareModuleInput(
-            input_layouts=(Shard(1),),
-            desired_input_layouts=(Replicate(),),
+            input_layouts=(Shard(1), Replicate()),
+            desired_input_layouts=(Replicate(), Replicate()),
         ),
-        "attention.wq": ColwiseParallel(),
-        "attention.wk": ColwiseParallel(),
-        "attention.wv": ColwiseParallel(),
+        "attention.wq": ColwiseParallel(use_local_output=False),
+        "attention.wk": ColwiseParallel(use_local_output=False),
+        "attention.wv": ColwiseParallel(use_local_output=False),
         "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
         "ffn_norm": SequenceParallel(),
         "feed_forward": PrepareModuleInput(
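
One hedged reading of the two-entry layouts above: in the tutorial's Llama-style model the attention forward presumably takes a second positional input besides the hidden states (the rotary-embedding frequencies), and once the q/k/v projections produce DTensors, that input also needs to be a DTensor before the two can be combined, hence the explicit ``Replicate()`` annotation. A small sketch of the prepared-input style under that assumption (names here are illustrative):

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# Sketch for a module whose forward is assumed to take (hidden_states, freqs_cis).
# Both inputs stay as DTensors after preparation (PrepareModuleInput keeps DTensor
# inputs by default), so they can interact with the DTensor outputs of the
# use_local_output=False projections.
attention_input_style = PrepareModuleInput(
    input_layouts=(Shard(1), Replicate()),              # how the inputs arrive
    desired_input_layouts=(Replicate(), Replicate()),   # what the attention module expects
)
```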