
Using DTensor to handle local num_heads change while TP is applied #3465


Open
wants to merge 4 commits into base: main

Conversation

wwwjn
Contributor

@wwwjn wwwjn commented Jul 16, 2025

Fixes #ISSUE_NUMBER. This PR brings the TP tutorial up to date with recent DTensor changes.

Description

After the DTensor enhancement, we are now able to use DTensor to handle the change of num_heads while TP is applied, instead of manually adjusting the tensor shape.
Corresponding changes in pytorch/examples: pytorch/examples#1373
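For context, a minimal sketch of what the updated tensor parallel plan looks like; the module names follow the tutorial's Llama example and are assumptions here, not an exact copy of the file:

```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Keeping the q/k/v projection outputs as DTensors (use_local_output=False)
# lets later view/reshape ops keep using the *global* num_heads; DTensor
# works out the local shard shape on its own.
layer_tp_plan = {
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    "attention.wo": RowwiseParallel(),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
}

# Applied per TransformerBlock, for example:
# parallelize_module(transformer_block, tp_mesh, layer_tp_plan)
```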

Checklist

  • The issue that is being fixed is referenced in the description (see above "Fixes #ISSUE_NUMBER")
  • Only one issue is addressed in this pull request
  • Labels from the issue that this PR is fixing are added to this pull request
  • No unnecessary issues are included in this pull request.


pytorch-bot bot commented Jul 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3465

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6dd3297 with merge base 755434d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the cla signed label Jul 16, 2025
@wwwjn wwwjn changed the title Using DTensor to handel local num_heads change while TP is applied Using DTensor to handle local num_heads change while TP is applied Jul 16, 2025
@wwwjn
Contributor Author

wwwjn commented Jul 16, 2025

cc @tianyu-l

@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension.
If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.

- For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/ ``wk``/ ``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+ For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.
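As a concrete illustration of the new wording, here is a hedged sketch of the relevant part of the attention forward; the names `self.wq`, `self.n_heads`, `self.n_kv_heads`, and `self.head_dim` follow the tutorial's Llama attention module and are assumptions, not an exact copy of the file:

```python
def forward(self, x):
    bs, seqlen, _ = x.shape
    # With use_local_output=False in the TP plan, wq/wk/wv return DTensors
    # sharded on the last (head) dimension instead of plain local tensors.
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
    # The views below keep using the *global* head counts; DTensor derives
    # the correct local shard shape, so no manual n_heads // tp_size is needed.
    xq = xq.view(bs, seqlen, self.n_heads, self.head_dim)
    xk = xk.view(bs, seqlen, self.n_kv_heads, self.head_dim)
    xv = xv.view(bs, seqlen, self.n_kv_heads, self.head_dim)
    # ... rest of the attention computation is unchanged
```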
Contributor

I think we should be able to use DTensor i.e. set use_local_output=False everywhere.
Maybe it's OK to keep a mixed usage of use_local_output so people are aware of this flexibility, but we should mention it here.
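To illustrate the flexibility the reviewer mentions, a small sketch contrasting the two options (tp_size here stands for the TP mesh size and is an assumption):

```python
from torch.distributed.tensor.parallel import ColwiseParallel

# 1) Default (use_local_output=True): wq returns a plain local tensor, so the
#    reshape must use the locally sharded head count.
plan_local = {"attention.wq": ColwiseParallel()}
# xq = xq.view(bs, seqlen, n_heads // tp_size, head_dim)

# 2) use_local_output=False: wq returns a DTensor, so the reshape can keep the
#    global head count and DTensor handles the sharded dimension.
plan_dtensor = {"attention.wq": ColwiseParallel(use_local_output=False)}
# xq = xq.view(bs, seqlen, n_heads, head_dim)
```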
