Using DTensor to handle local num_heads change while TP is applied #3465
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/3465
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 6dd3297 with merge base 755434d. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
cc @tianyu-l
@@ -141,7 +141,7 @@ q/k/v projection and row-wise sharding for the ``wo`` linear projection. So we c
 This is almost the ``layer_tp_plan`` we need to apply Tensor Parallelism to the ``TransformerBlock``. However, one thing we should be aware is that when sharding the linear layer column-wise, the output of the linear layers would become sharded on the last tensor dimension, and the row-wise sharding linear layer directly accepts an input that shards on the last dimension.
 If there are any more tensor operations (such as view operations) between the column-wise linear and the row-wise linear, we would need to adjust the relevant shape related ops to sharded shape.
-For the Llama model, in the attention layer there are couple of view operations that are shape related. In particular, column-wise parallel for ``wq``/ ``wk``/ ``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension, so we would need to adjust the ``num_heads`` to local ``num_heads``.
+For the Llama model, in the attention layer, there are several view operations related to shape. Specifically, for column-wise parallelism in the ``wq``/``wk``/``wv`` linear layers, the activation tensor is sharded on the ``num_heads`` dimension. To manage the difference between global and local ``num_heads``, we should set ``use_local_output=False`` to ensure the output is a DTensor. Unlike a regular tensor, a DTensor is aware of the parallelism plans and will automatically handle changes in the ``num_heads`` dimension.
I think we should be able to use DTensor, i.e. set use_local_output=False, everywhere. Maybe it's OK to keep a mixed usage of use_local_output so people are aware of this flexibility, but we should mention it here.
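To make that flexibility concrete, here is a minimal sketch (not the exact tutorial code) of an attention sub-plan that keeps the q/k/v outputs as DTensors via use_local_output=False while using a plain RowwiseParallel for wo. The attention.wq/wk/wv/wo module paths and the tp_mesh / transformer_block names are assumptions matching the tutorial's Llama-style TransformerBlock.

```python
# A minimal sketch, not the exact tutorial code: module paths and the
# tp_mesh / transformer_block names are assumed from the tutorial's
# Llama-style TransformerBlock.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

layer_tp_plan = {
    # Column-wise shard the q/k/v projections. use_local_output=False keeps
    # the outputs as DTensors, so the later view over num_heads can keep
    # using the global num_heads and DTensor handles the sharded dimension
    # automatically.
    "attention.wq": ColwiseParallel(use_local_output=False),
    "attention.wk": ColwiseParallel(use_local_output=False),
    "attention.wv": ColwiseParallel(use_local_output=False),
    # Row-wise shard the output projection, which accepts the activation
    # sharded on the last dimension coming out of the attention computation.
    "attention.wo": RowwiseParallel(),
}

# Applying the plan to one TransformerBlock (transformer_block and tp_mesh
# are assumed to be defined in the surrounding training script):
# parallelize_module(transformer_block, tp_mesh, layer_tp_plan)
```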
Fixes #ISSUE_NUMBER. This PR is to make the TP tutorial up-to-date with DTensor changes.
Description
After the DTensor enhancement, we are now able to use DTensor to handle the change of num_heads, instead of manually handling the tensor shape, while TP is applied. Corresponding changes in pytorch/examples: pytorch/examples#1373
Checklist