
Conversation

@ruisizhang123 (Contributor) commented on Oct 7, 2025

This PR adds an aten-level autobucketing pass to SimpleFSDP. It runs autobucketing with the aot_eager backend, without Inductor. The aten FX autobucketing pass can be found in pytorch/pytorch#163960.

Key updates are:

  1. Support a customized `aot_eager_autobucketing` backend that performs the autobucketing optimization.
  2. In SimpleFSDP, the model backend can be replaced with a user's customized passes via `compile.model_backend_override` (see the sketch below).
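
A minimal, hedged sketch of how these two pieces could fit together at the compile call site; `get_compile_backend` and the config field names mirror the description above, but the exact torchtitan code may differ:

```python
import torch

def get_compile_backend(backend_name):
    # In the PR this returns either a built-in backend name ("inductor",
    # "aot_eager", ...) or a custom callable backend; shown as a pass-through
    # here for brevity.
    return backend_name

def compile_model(model, compile_config):
    # The per-model override (e.g. "aot_eager_autobucketing") wins over the
    # global backend when it is set.
    backend = get_compile_backend(
        compile_config.model_backend_override or compile_config.backend
    )
    return torch.compile(model, backend=backend, fullgraph=True)
```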

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 7, 2025
@ruisizhang123 force-pushed the ruisi/aot_eager_pass branch 4 times, most recently from 7be92c7 to 56049a8 on October 9, 2025 22:02
if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
from torch._dynamo.backends.common import aot_autograd as auto_autograd_backend

nit: auto_autograd_backend -> aot_autograd_backend

@ruisizhang123 changed the title from "[wip] add auto_eager_graph_pass" to "add auto_eager_graph_pass" on Oct 9, 2025
@tianyu-l (Contributor) left a comment

Should we make this configurable, instead of always turning it on? The default could be "on", but I think researchers may want to be able to play with the non-optimized version as well.

Also, more documentation is needed in the code; the flags look mysterious.

@ruisizhang123 (Contributor, Author) replied:

> Should we make this configurable, instead of always turning it on? The default could be "on", but I think researchers may want to be able to play with the non-optimized version as well.
>
> Also, more documentation is needed in the code; the flags look mysterious.

yesss, it's not ready for review yet...

@ruisizhang123 force-pushed the ruisi/aot_eager_pass branch 3 times, most recently from 2f7415e to 0d70c22 on October 10, 2025 01:09
return model


def get_compile_backend(backend_name: str) -> Union[str, callable]:
Contributor

Given the complexity, I think we should start putting things into separate files. E.g., this function can go to compile_utils.py or backend.py. The original file can stay as simple_fsdp.py for now. WDYT?

Contributor Author

Yes, this makes sense to me.
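
A hedged sketch of the shape such a backend.py helper could take once split out; the `aot_eager_autobucketing` branch is illustrative, and the pass wiring follows the `schedule_overlap_bucketing` call that appears in the traceback later in this thread, not necessarily the final code:

```python
from typing import Callable, Union

def get_compile_backend(backend_name: str) -> Union[str, Callable]:
    if backend_name == "aot_eager_autobucketing":
        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
        from torch._inductor.fx_passes.overlap_scheduling import (
            schedule_overlap_bucketing,
        )

        def autobucketing_pass(gm, example_inputs):
            # Bucket and reorder collectives on the aten-level fw/bw graphs,
            # then hand the transformed graph back to AOTAutograd.
            schedule_overlap_bucketing(gm)
            gm.recompile()
            return gm

        return aot_autograd_backend(
            fw_compiler=autobucketing_pass, bw_compiler=autobucketing_pass
        )

    # Built-in torch.compile backend names are passed through unchanged.
    return backend_name
```

Since torch.compile accepts either a registered backend name or a callable, the caller does not need to care which of the two cases it gets back.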

Comment on lines 629 to 630
model_backend: str = "inductor"
loss_backend: str = "inductor"
Contributor

I somehow feel it's not worth changing the global config yet, as people who don't use SimpleFSDP won't have the motivation to separate them.

For now you could extend JobConfig following https://github.com/pytorch/torchtitan/blob/main/docs/extension.md#extending-jobconfig
and define a new config, e.g. compile.model_backend_override, which defaults to None, so that you could set

backend=compile_config.model_backend_override or compile_config.backend,
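
A hedged sketch of what that JobConfig extension could look like (class and field names are illustrative; the actual extension mechanism is the one described in docs/extension.md):

```python
from dataclasses import dataclass, field

@dataclass
class Compile:
    enable: bool = False
    backend: str = "inductor"
    # None means "use `backend` as-is"; set to e.g. "aot_eager_autobucketing"
    # to swap in the custom SimpleFSDP backend for the model only.
    model_backend_override: str | None = None

@dataclass
class JobConfig:
    compile: Compile = field(default_factory=Compile)
```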

"""Which components to compile"""
backend: str = "inductor"

simplefsdp_backend_override: str | None = None
Contributor Author

my bad, updated

@ruisizhang123 force-pushed the ruisi/aot_eager_pass branch 4 times, most recently from 24a1e8b to 9717183 on October 10, 2025 05:07
Contributor

Should we add a test for aot_eager_autobucketing?

Contributor Author

I can add this test. It depends on pytorch/pytorch#165063; I will add it once that PyTorch PR is merged.

@ruisizhang123 force-pushed the ruisi/aot_eager_pass branch 2 times, most recently from 0c8187d to b9d3cf7 on October 13, 2025 16:37
@tianyu-l (Contributor)

Could you rebase onto #1871 before merging?

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Oct 14, 2025:

When the autobucketing pass is registered as the aot_eager backend's `fw_compiler` and `bw_compiler`, this PR ensures that tensors are all-gathered on the "cpu"/"cuda" device instead of the "meta" device.

When we call `dist.all_gather_object`, it creates a new byte storage outside `no_dispatch` [here](https://github.com/pytorch/pytorch/blob/a2e2e1d8c026951baa345f0dd17668bd1718eda5/torch/distributed/distributed_c10d.py#L3303), which ends up on the meta device. Thus, I updated the code to use `unset_fake_temporarily`, which gathers real tensors from the other ranks.
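
A hedged illustration of that pattern (the helper name and the `unset_fake_temporarily` import path are assumptions based on this description, not the exact upstream diff):

```python
import torch.distributed as dist
from torch._subclasses.fake_tensor import unset_fake_temporarily

def gather_runtime_estimations(runtime_estimations, pg):
    gathered = [None] * dist.get_world_size(pg)
    # Temporarily exit fake-tensor tracing so dist.all_gather_object builds
    # real cpu/cuda byte storage rather than meta storage during compilation.
    with unset_fake_temporarily():
        dist.all_gather_object(gathered, runtime_estimations, pg)
    return gathered
```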

This is needed to unblock the aot_eager + autobucketing pass in this [PR](pytorch/torchtitan#1813).

Otherwise, I hit the following error:

```bash
  traceback : Traceback (most recent call last):
    File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper
      return f(*args, **kwargs)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train
      self.train_step(data_iterator)
      ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step
      loss = self.forward_backward_step(input_dict, labels)
    File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step
      pred = model_parts[0](inputs, **extra_inputs, **extra_args)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__
      return super().__call__(*args, **kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl
      return self._call_impl(*args, **kwargs)
             ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl
      return forward_call(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper
      raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler
      raise BackendCompilerFailed(
          self.compiler_fn, e, inspect.currentframe()
      ).with_traceback(e.__traceback__) from None
    File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler
      compiled_fn = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__
      compiled_gm = compiler_fn(gm, example_inputs)
    File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__
      return self.compiler_fn(model_, inputs_, **self.kwargs)
             ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__
      cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
    File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified
      compiled_fn, _ = aot_stage2_compile(
                       ~~~~~~~~~~~~~~~~~~^
          aot_state,
          ^^^^^^^^^^
      ...<4 lines>...
          inference_compiler,
          ^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile
      return aot_stage2_autograd(aot_state, aot_graph_capture)
    File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd
      compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
    File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass
      schedule_overlap_bucketing(gm)
      ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing
      ).run()
        ~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run
      self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks()
      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^
    File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks
      dist.all_gather_object(
      ~~~~~~~~~~~~~~~~~~~~~~^
          gathered_runtime_estimations, runtime_estimations, pg
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      )
      ^
    File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper
      return func(*args, **kwargs)
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object
      input_tensor, local_size = _object_to_tensor(obj, current_device, group)
                                 ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor
      byte_tensor = torch.ByteTensor(byte_storage).to(device)
                    ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^
  torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised:
  RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta".  This is no longer allowed; the devices must match.

  Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

```

Pull Request resolved: #165063
Approved by: https://github.com/eellison
@ruisizhang123 merged commit d0e2545 into main on Oct 14, 2025 (4 of 5 checks passed)
@ruisizhang123 deleted the ruisi/aot_eager_pass branch on October 14, 2025 21:16
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request on Oct 15, 2025.
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request on Oct 15, 2025.
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request on Oct 16, 2025.