
enable tensor parallelism for MXLinear #2434


Merged
merged 21 commits into main on Jun 24, 2025

Conversation

@vkuzo (Contributor) commented Jun 24, 2025

Summary:

Enables TP for MXLinear. Specifically:

1. change the reshape logic from `x.reshape(-1, block_size)` to
   `x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)`
2. modify the rest of the code to adhere to (1)
3. cast the input tensor and `max_abs` to float32 before calculating the MX
   scale, to work around another bug in DTensor + view + int16 target type

(1) is necessary because the old reshape logic flattened dims, which did not
work if one of the flattened dims was sharded (see the sketch below).
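A minimal sketch of the difference, with a made-up activation shape and block size; this is illustrative only, not the MXLinear implementation:

```
import torch

block_size = 32
x = torch.randn(4, 8, 256)  # hypothetical (batch, seq, hidden) activation
orig_shape = x.shape

# Old logic: flattens the leading dims together with the blocked dim. Under TP,
# if one of those leading dims is sharded, the flattened view would have to mix
# elements across shards, which DTensor cannot represent.
x_old = x.reshape(-1, block_size)  # (256, 32)

# New logic: leading dims are kept intact; only the last dim is split into
# (num_blocks, block_size), so sharding on a leading dim is preserved.
x_new = x.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size)  # (4, 8, 8, 32)

# The per-block max-abs used for the MX scale is then a reduction over the last dim only.
max_abs = x_new.abs().amax(dim=-1)  # (4, 8, 8)
```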

Note that TP does not yet work with the custom dim1 triton kernel; a separate PR will fix that by adding a sharding strategy to the kernel.

I verified that performance for FSDP + mxfp8 + compile is not affected by this stack, with torchtitan Llama 3 8B on 8 B200 GPUs:

baseline (without this PR stack)

bf16 FSDP - tps 8.8k, peak_mem 35.0 GiB ([link](https://www.internalfb.com/phabricator/paste/view/P1850041288))
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml " ./run_train.sh --model.print_after_conversion --training.steps 50 --parallelism.tensor_parallel_degree=1 --training.compile

bf16 FSDP + tp - tps 8.2k, peak_mem 29.6 GiB ([link](https://www.internalfb.com/phabricator/paste/view/P1850041882))
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml " ./run_train.sh --model.print_after_conversion --training.steps 50 --parallelism.tensor_parallel_degree=2 --training.compile

mxfp8 FSDP - tps 10k, peak_mem 35.3 GiB ([link](https://www.internalfb.com/phabricator/paste/view/P1850040695))
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml " ./run_train.sh --model.print_after_conversion --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8" --parallelism.tensor_parallel_degree=1 --mx.use_fp8_dim1_cast_triton_kernel --training.compile

mxfp8 FSDP + TP - broken
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml " ./run_train.sh --model.print_after_conversion --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8" --parallelism.tensor_parallel_degree=2 --mx.use_fp8_dim1_cast_triton_kernel --training.compile

experiment (with this PR stack)

mxfp8 FSDP - tps 10k, peak_mem 35.3 GiB ([link](https://www.internalfb.com/phabricator/paste/view/P1850044437))
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml " ./run_train.sh --model.print_after_conversion --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8" --parallelism.tensor_parallel_degree=1 --mx.use_fp8_dim1_cast_triton_kernel --training.compile

mxfp8 FSDP + TP + dim1 triton kernel disabled - tps 7.9k, peak_mem 29.7 GiB ([link](https://www.internalfb.com/phabricator/paste/view/P1850045992))
with-proxy CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml " ./run_train.sh --model.print_after_conversion --training.steps 50 --model.converters mx --mx.recipe_name "mxfp8" --parallelism.tensor_parallel_degree=2 --mx.no-use_fp8_dim1_cast_triton_kernel --training.compile

Test Plan:

pytest test/prototype/mx_formats
./test/prototype/mx_formats/test_dtensor.sh

Reviewers:

Subscribers:

Tasks:

Tags:

vkuzo added 8 commits June 20, 2025 07:10
vkuzo commented Jun 24, 2025

pytorch-bot bot commented Jun 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2434

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 11 Pending

As of commit 1001602 with merge base 7d6bb6a:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the `CLA Signed` label Jun 24, 2025
vkuzo added a commit that referenced this pull request Jun 24, 2025
vkuzo added a commit that referenced this pull request Jun 24, 2025
vkuzo added 4 commits June 24, 2025 07:17
vkuzo added a commit that referenced this pull request Jun 24, 2025
vkuzo added 3 commits June 24, 2025 07:17
vkuzo added a commit that referenced this pull request Jun 24, 2025
vkuzo added 2 commits June 24, 2025 07:19
vkuzo added a commit that referenced this pull request Jun 24, 2025
@vkuzo added the `topic: improvement` label Jun 24, 2025
@@ -190,8 +190,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
 # TODO(future): enable compile support
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
-    input_shape = (2, 4)
-    grad_shape = (2, 8)
+    input_shape = (16, 4)
This was broken before; it was caught by enforcing that the inner dim is divisible by the block size.

# torchtitan but not in a unit test, so not enough info to file a good
# issue in pytorch/pytorch. For now, work around. In the future we should
# debug and fix this properly.
data_hp = data_hp.to(torch.float32)
Performance testing showed that, with compile on, keeping this in float32 does not regress performance.
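For context, a hedged sketch of what (3) amounts to; the function name and the exp2/log2 rounding below are simplified stand-ins, not the library's exact e8m0 scale code:

```
import torch

def mx_scale_fp32(data_hp: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # Work around the DTensor + view + int16 issue by doing the scale math in float32.
    data_hp = data_hp.to(torch.float32)
    blocks = data_hp.reshape(
        *data_hp.shape[:-1], data_hp.shape[-1] // block_size, block_size
    )
    max_abs = blocks.abs().amax(dim=-1).to(torch.float32)
    # Simplified power-of-two (e8m0-style) scale derived from the per-block max.
    return torch.exp2(torch.floor(torch.log2(max_abs)))

scale = mx_scale_fp32(torch.randn(16, 64, dtype=torch.bfloat16))  # -> shape (16, 2)
```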


     tp_out = tp_model(x_fp32_tp_input)
-    tp_out.sum().backward()
+    tp_out.backward(go_fp32_tp)
To make sure the gradient flowing into the last linear is contiguous.
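A hedged illustration of why this matters (tensor names are made up): autograd materializes the gradient of `sum()` as an expanded tensor of ones, which is not contiguous, whereas an explicitly passed grad_output is.

```
import torch

out = torch.randn(8, 16)

# Equivalent to the grad_output autograd feeds backward from out.sum():
grad_from_sum = torch.ones((), dtype=out.dtype).expand(out.shape)
print(grad_from_sum.is_contiguous())  # False: stride-0 expansion

# Equivalent to calling out.backward(go) with an explicit, materialized tensor:
go = torch.randn(out.shape)
print(go.is_contiguous())  # True
```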

vkuzo added 2 commits June 24, 2025 07:31
vkuzo added a commit that referenced this pull request Jun 24, 2025
@vkuzo changed the base branch from gh/vkuzo/90/head to main June 24, 2025 19:18
@vkuzo merged commit 32599be into main Jun 24, 2025
47 of 53 checks passed
Labels
CLA Signed · topic: improvement
3 participants