[WIP] Integrate autoparallel into torchtitan #1458


Open · wants to merge 1 commit into base: gh/IvanKobzarev/1/base

Conversation


@IvanKobzarev commented Jul 25, 2025

Stack from ghstack (oldest at bottom):

TODO

  • try converting model params into fake tensors (see the sketch after this list)
  • figure out init fn
  • integrate torchtitan configs for DP/TP to control autop
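
For the first item, a minimal sketch of what converting params into fake tensors could look like (an assumption about the approach, not code from this PR); `build_model_stub` stands in for torchtitan's real llama3 builder:

```python
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

def build_model_stub() -> nn.Module:
    # placeholder for torchtitan's llama3 model builder
    return nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096))

model = build_model_stub()
fake_mode = FakeTensorMode()

# FakeTensors carry only metadata (shape, dtype, device) and no storage, which
# is what a sharding planner needs in order to trace the model cheaply.
fake_params = {
    name: fake_mode.from_tensor(param.detach())
    for name, param in model.named_parameters()
}
```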

Hack an init_fn for llama3 and observe loss decreasing with autoparallel

"""
[rank0]:[titan] 2025-06-16 16:24:16,593 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-06-16 16:24:23,544 - root - INFO - step: 1 loss: 8.1880 memory: 4.88GiB(6.16%) tps: 28
[rank0]:[titan] 2025-06-16 16:24:23,545 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-16 16:24:23,842 - root - INFO - step: 2 loss: 8.1610 memory: 4.90GiB(6.20%) tps: 13,785
[rank0]:[titan] 2025-06-16 16:24:24,135 - root - INFO - step: 3 loss: 8.0871 memory: 4.90GiB(6.20%) tps: 14,006
[rank0]:[titan] 2025-06-16 16:24:24,433 - root - INFO - step: 4 loss: 7.9516 memory: 4.90GiB(6.20%) tps: 13,770
[rank0]:[titan] 2025-06-16 16:24:24,727 - root - INFO - step: 5 loss: 7.8552 memory: 4.90GiB(6.20%) tps: 13,959
[rank0]:[titan] 2025-06-16 16:24:25,023 - root - INFO - step: 6 loss: 7.7732 memory: 4.90GiB(6.20%) tps: 13,859
[rank0]:[titan] 2025-06-16 16:24:25,324 - root - INFO - step: 7 loss: 7.6987 memory: 4.90GiB(6.20%) tps: 13,664
[rank0]:[titan] 2025-06-16 16:24:25,617 - root - INFO - step: 8 loss: 7.6779 memory: 4.90GiB(6.20%) tps: 13,985
[rank0]:[titan] 2025-06-16 16:24:25,911 - root - INFO - step: 9 loss: 7.6043 memory: 4.90GiB(6.20%) tps: 13,962
[rank0]:[titan] 2025-06-16 16:24:26,207 - root - INFO - step: 10 loss: 7.5778 memory: 4.90GiB(6.20%) tps: 13,891
"""

Adopt new autoparallel API with meta-init model

This allows reverting many of the hacks in the original integration, which were
caused by not creating a model object in train.py (a model_fn builder was passed
to autoparallel instead).
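
A minimal sketch of the meta-init side of this change, assuming it boils down to building the module with storage-free parameters and handing the module itself (rather than a model_fn builder) to autoparallel; the autoparallel call is omitted since its exact signature lives in that repo:

```python
import torch
import torch.nn as nn

def build_llama3_stub() -> nn.Module:
    # placeholder for torchtitan's real llama3 constructor
    return nn.Sequential(nn.Linear(4096, 11008), nn.SiLU(), nn.Linear(11008, 4096))

# Parameters created under the meta device have shapes and dtypes but no data,
# so the full model object "exists" in train.py without allocating GPU memory.
with torch.device("meta"):
    model = build_llama3_stub()

assert all(p.device.type == "meta" for p in model.parameters())
# ... the meta-initialized `model` is then handed to autoparallel for planning.
```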

Fixes to align with latest autoparallel

Add inductor config knobs for comms optimizations to torchtitan
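
For reference, the kind of inductor knobs this refers to; whether torchtitan exposes exactly these flags is an assumption, but the flags themselves exist in torch._inductor.config on recent PyTorch:

```python
import torch._inductor.config as inductor_config

# Let inductor reorder compute and collectives so communication overlaps with
# computation, plus the list of passes it should run when doing so.
inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [
    "sink_waits",
    "raise_comms",
    "reorder_compute_for_overlap",
]
```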

Make inductor always run compile passes

Basically, this is an annoying workaround for debugging iteratively:

1. You run the model, it compiles, but something weird happens.
2. You enable some logging or tlparse and rerun, but inductor decides not to
   run your pass anymore because its results are cached.

Since (2) has confused me horribly on more than one occasion, I just disable
caching for now (see the sketch below).
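
One way to get that behavior on recent PyTorch (hedged: the PR may wire this up differently) is to force-disable inductor caching:

```python
import torch._inductor.config as inductor_config

# Make inductor recompile (and re-run custom passes) every time instead of
# reusing cached artifacts between runs.
inductor_config.force_disable_caches = True
# or equivalently, from the shell:
#   TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 ./run_train.sh ...
```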

Drop hacky llama3_init_fn and use autop init_weights feature

Relying on https://github.com/pytorch-labs/autoparallel/pull/20, this
lets us automatically apply a user's init_weights fn to the autoparallel
model.
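
For context, torchtitan models expose an init_weights() method along these lines (simplified and illustrative, not the real llama3 code); that method is what the autoparallel feature re-applies to the parallelized model once real parameters are materialized:

```python
import torch.nn as nn

class AttentionStub(nn.Module):
    def __init__(self, dim: int = 4096):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def init_weights(self) -> None:
        # in-place init that autoparallel can re-run on the sharded parameters
        nn.init.trunc_normal_(self.wq.weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=0.02)
```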

Verified this works with:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --training.dataset c4`

```
[rank0]:[titan] 2025-07-02 16:18:02,007 - root - INFO - Training starts at step 1.
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - step:  1  loss:  8.1848  memory:  1.09GiB(1.14%)  tps: 77  tflops: 0.01  mfu: 0.00%
[rank0]:[titan] 2025-07-02 16:18:08,224 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-07-02 16:18:08,310 - root - INFO - step:  2  loss:  8.1619  memory:  1.15GiB(1.21%)  tps: 48,138  tflops: 3.46  mfu: 0.35%
[rank0]:[titan] 2025-07-02 16:18:08,356 - root - INFO - step:  3  loss:  8.1140  memory:  1.15GiB(1.21%)  tps: 88,440  tflops: 6.36  mfu: 0.64%
[rank0]:[titan] 2025-07-02 16:18:08,406 - root - INFO - step:  4  loss:  8.0099  memory:  1.15GiB(1.21%)  tps: 82,626  tflops: 5.94  mfu: 0.60%
[rank0]:[titan] 2025-07-02 16:18:08,457 - root - INFO - step:  5  loss:  7.8928  memory:  1.15GiB(1.21%)  tps: 81,594  tflops: 5.87  mfu: 0.59%
[rank0]:[titan] 2025-07-02 16:18:08,508 - root - INFO - step:  6  loss:  7.7758  memory:  1.15GiB(1.21%)  tps: 79,607  tflops: 5.72  mfu: 0.58%
[rank0]:[titan] 2025-07-02 16:18:08,559 - root - INFO - step:  7  loss:  7.6221  memory:  1.15GiB(1.21%)  tps: 81,448  tflops: 5.86  mfu: 0.59%
[rank0]:[titan] 2025-07-02 16:18:08,611 - root - INFO - step:  8  loss:  7.5578  memory:  1.15GiB(1.21%)  tps: 79,732  tflops: 5.73  mfu: 0.58%
[rank0]:[titan] 2025-07-02 16:18:08,659 - root - INFO - step:  9  loss:  7.3851  memory:  1.15GiB(1.21%)  tps: 85,655  tflops: 6.16  mfu: 0.62%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - step: 10  loss:  7.3361  memory:  1.15GiB(1.21%)  tps: 81,855  tflops: 5.89  mfu: 0.60%
[rank0]:[titan] 2025-07-02 16:18:08,709 - root - INFO - Sleeping 2 seconds for other ranks to complete
```

fix lint


[ghstack-poisoned]