enable tensor parallelism for MXLinear #2434
Conversation
Stack from ghstack (oldest at bottom):
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
 # TODO(future): enable compile support
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
-    input_shape = (2, 4)
-    grad_shape = (2, 8)
+    input_shape = (16, 4)
this was broken before, caught by enforcing that inner dim is divisible by block size
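For context, a small self-contained sketch of the constraint the new blocked layout enforces on test shapes; the `block_size` value here is assumed for illustration only (the MX spec default is 32), and the variable names are not the exact ones from the test:

```python
import torch

block_size = 4          # assumed for illustration; the MX default block size is 32
input_shape = (16, 4)   # new test shape: the inner dim is a multiple of block_size

x = torch.randn(*input_shape)
# the blocked layout only splits the last dim, so it must divide evenly
assert x.shape[-1] % block_size == 0
num_blocks = x.shape[-1] // block_size
x_blocked = x.reshape(*input_shape[:-1], num_blocks, block_size)  # (16, 1, 4)
```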
# torchtitan but not in a unit test, so not enough info to file a good
# issue in pytorch/pytorch. For now, work around. In the future we should
# debug and fix this properly.
data_hp = data_hp.to(torch.float32)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
performance testing showed that with compile on, having this in float32 does not regress performance
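A minimal sketch of the workaround, assuming a simplified power-of-two scale calculation; `data_hp`, `max_abs`, and `block_size` follow the names used in the PR, while the function itself is illustrative rather than the actual `to_mx` implementation in torchao:

```python
import torch

def compute_mx_scale_sketch(data_hp: torch.Tensor, block_size: int) -> torch.Tensor:
    # new layout: keep leading dims, split only the last dim into blocks
    orig_shape = data_hp.shape
    data_hp = data_hp.reshape(
        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
    )
    # work around a DTensor + view + int16 issue by doing the scale math in
    # float32; with torch.compile this did not regress performance
    data_hp = data_hp.to(torch.float32)
    max_abs = torch.amax(torch.abs(data_hp), dim=-1).to(torch.float32)
    # simplified per-block power-of-two scale; the real code clamps and
    # encodes this as an E8M0 shared exponent
    return torch.exp2(torch.floor(torch.log2(max_abs)))

scale = compute_mx_scale_sketch(torch.randn(16, 64), block_size=32)  # shape (16, 2)
```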
 tp_out = tp_model(x_fp32_tp_input)
-tp_out.sum().backward()
+tp_out.backward(go_fp32_tp)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to make sure grad flowing into the last linear is contiguous
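A rough standalone illustration of this test change, using plain tensors instead of the test's DTensor/TP setup; `go_fp32_tp` is the name from the diff, everything else is made up for the example:

```python
import torch

# stand-in for the TP model's output
x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(32, 8)
tp_out = x @ w

# instead of tp_out.sum().backward(), pass an explicit contiguous grad_output
# so the gradient flowing into the last linear is contiguous
go = torch.ones_like(tp_out).contiguous()  # stands in for go_fp32_tp
tp_out.backward(go)
assert x.grad is not None and x.grad.is_contiguous()
```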
Summary:
Enables TP for MXLinear. Specifically:
1. change the reshape logic from `x.reshape(-1, block_size)` to `x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)`
2. modify the rest of the code to adhere to (1)
3. cast the input tensor and max_abs to float32 before calculating the MX scale, in order to work around another bug in DTensor + view + int16 target type

(1) is necessary because the old reshape logic would flatten dims, which did not work if one of those flattened dims was sharded.
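As a rough illustration of why (1) matters, here is a sketch with plain tensors and made-up shapes; in the real code these are DTensors whose leading dims may be sharded across TP ranks:

```python
import torch

block_size = 32
x = torch.randn(4, 6, 64)   # e.g. (batch, seq, hidden)
orig_shape = x.shape

# old: flattens all leading dims together with the block dim; per the PR,
# this did not work when one of the flattened dims was sharded
x_old = x.reshape(-1, block_size)  # (48, 32)

# new: leading dims are preserved and only the last dim is split into
# (num_blocks, block_size), so sharding on batch/seq dims is left untouched
x_new = x.reshape(
    *orig_shape[:-1], orig_shape[-1] // block_size, block_size
)  # (4, 6, 2, 32)
```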
Note that TP does not yet work with the custom dim1 triton kernel; we'll need a separate PR to fix that by adding a sharding strategy to the kernel.
I verified that performance for FSDP + mxfp8 + compile is not affected by this stack, with torchtitan Llama 3 8B on 8 B200 GPUs.
Test Plan:
`pytest test/prototype/mx_formats`
`./test/prototype/mx_formats/test_dtensor.sh`