
[not for land] float8 blockwise scaling training prototype using deep_gemm #2386


Open · vkuzo wants to merge 1 commit into main from 20250616_deepgemm_hack

Conversation

@vkuzo (Contributor) commented on Jun 16, 2025

Since this is a common community request, I did a test drive of how we could integrate deep_gemm into an e2e training workflow. deep_gemm (https://github.com/deepseek-ai/DeepGEMM) provides the following things:

  • fwd and bwd gemms for dense linear (tested in this PR; a call sketch follows this list)
  • fwd and bwd grouped gemms for MoE (not tested yet)
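
A hedged sketch of what the dense forward gemm call could look like. The function name `gemm_fp8_fp8_bf16_nt` and the (tensor, scale) tuple arguments follow DeepGEMM's documented fp8 NT gemm, but the exact signature and the required scale-tensor layouts may differ across versions; treat the details below as assumptions, not the code in this PR.

```python
# Sketch only: assumes deep_gemm is installed and a supported Hopper GPU is available.
import torch
import deep_gemm

m, k, n = 4096, 7168, 4096
device = "cuda"

# activations: float8 e4m3, one fp32 scale per 1x128 tile along k
x_fp8 = torch.randn(m, k, device=device).to(torch.float8_e4m3fn)
x_scales = torch.ones(m, k // 128, device=device, dtype=torch.float32)

# weights in NT layout ([n, k]): float8 e4m3, one fp32 scale per 128x128 block
w_fp8 = torch.randn(n, k, device=device).to(torch.float8_e4m3fn)
w_scales = torch.ones(n // 128, k // 128, device=device, dtype=torch.float32)

# output accumulated into bf16; real usage likely needs the scale tensors in the
# specific (e.g. TMA-aligned) layouts deep_gemm expects
out = torch.empty(m, n, device=device, dtype=torch.bfloat16)
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scales), (w_fp8, w_scales), out)
```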

What I saw:

  • gemm performance: good on my H100! (seeing ~60% of peak TOPs)
  • toy linear with correct numerics: done
  • e2e performance work: not started, and currently not planned. I did notice that 128x128 scaling of a single tensor is not torch.compile friendly as written: it results in 3 kernels per tensor (see the scaling sketch after this list).
  • trying this e2e in torchtitan: [not for land] testing out float8 128_1_128_128 blockwise scaling torchtitan#1317
    • if we use the 128_1_128_1 gemm, it currently crashes during the backward pass with an illegal memory access; this seems specific to the calculation of grad_weight: https://gist.github.com/vkuzo/6e9cacb226593f7e5f27ac5cd5e79fb1. For now, the workaround is to leave the gemm that calculates grad_weight in bf16. Something is likely off in how we are wrapping the 128_1_128_1 gemm.
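
For reference, a minimal pure-PyTorch sketch of the 128x128 blockwise scaling mentioned above (the step that currently compiles to 3 kernels per tensor). It is illustrative only, assumes dimensions divisible by 128, and is not the code from this PR.

```python
import torch

def scale_128x128_to_float8(x: torch.Tensor, block: int = 128):
    M, N = x.shape
    assert M % block == 0 and N % block == 0, "sketch assumes dims divisible by 128"
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # view as (M//128, 128, N//128, 128) so each 128x128 block is addressable
    blocks = x.view(M // block, block, N // block, block)
    # one scale per block, derived from the per-block absolute max
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).float().clamp(min=1e-12)
    scale = fp8_max / amax
    x_fp8 = (blocks.float() * scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # return the quantized tensor plus per-block dequantization scales
    # (the exact scale layout deep_gemm expects may differ)
    return x_fp8.view(M, N), scale.squeeze(1).squeeze(-1).reciprocal()
```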

If we were to integrate this, here is the path forward:

  1. either make deep_gemm's 128_1_128_1 gemm work properly, write our own, or just leave this matmul in bf16 (a sketch of the bf16 workaround follows this list)
  2. get a fast version of 128x128 scaling going with a handwritten kernel, and file a torch.compile issue so compile can catch up. This seems generally useful for other formats as well.
  3. optimize performance
  4. integrate as a recipe into the main float8 training path
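
To make item 1 concrete, below is a rough sketch of the bf16 grad_weight workaround as a torch.autograd.Function. The helper `fp8_blockwise_mm` is hypothetical and stubbed with a plain bf16 matmul so the example runs anywhere; in the prototype that slot would be the deep_gemm-backed blockwise-scaled gemm.

```python
import torch

def fp8_blockwise_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # hypothetical placeholder: a real version would quantize a and b to float8
    # with blockwise scales and dispatch to deep_gemm; stubbed as a bf16 matmul here
    return a.to(torch.bfloat16) @ b.to(torch.bfloat16)

class Float8BlockwiseLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        # output = x @ w.t(): 1x128-scaled activations against 128x128-scaled weights
        return fp8_blockwise_mm(x, w.t())

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        # grad_input = grad_out @ w: also a blockwise-scaled fp8 gemm
        grad_x = fp8_blockwise_mm(grad_out, w)
        # workaround from the description above: compute grad_weight in plain bf16
        # instead of the 128_1_128_1 gemm that currently hits an illegal memory access
        grad_w = grad_out.to(torch.bfloat16).t() @ x.to(torch.bfloat16)
        return grad_x.to(x.dtype), grad_w.to(w.dtype)

# toy usage
x = torch.randn(256, 512, dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(1024, 512, dtype=torch.bfloat16, requires_grad=True)
y = Float8BlockwiseLinear.apply(x, w)
y.sum().backward()
```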

@pytorch-bot (bot) commented on Jun 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2386

Note: Links to docs will display an error until the docs builds have been completed.

❌ 10 New Failures

As of commit c2115b5 with merge base 5bdc25d:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (authors need to sign the CLA before a PR can be reviewed) on Jun 16, 2025
@vkuzo changed the title from "20250616 deepgemm hack" to "[not for land] 20250616 deepgemm hack" on Jun 16, 2025
@vkuzo force-pushed the 20250616_deepgemm_hack branch 10 times, most recently from c4df31a to a2a31eb, on June 17, 2025
@vkuzo changed the title from "[not for land] 20250616 deepgemm hack" to "[not for land] try blockwise scaling using deep_gemm" on Jun 18, 2025
@vkuzo changed the title from "[not for land] try blockwise scaling using deep_gemm" to "[not for land] float8 blockwise scaling training prototype using deep_gemm" on Jun 18, 2025
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
@vkuzo force-pushed the 20250616_deepgemm_hack branch from a2a31eb to c2115b5 on June 18, 2025
vkuzo added a commit to pytorch/torchtitan that referenced this pull request Jun 18, 2025
Summary:

Test drive of pytorch/ao#2386, not for land

Test Plan:

```bash
with-proxy CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.converters float8 --model.print_after_conversion
```

Reviewers:

Subscribers:

Tasks:

Tags: