add auto_eager_graph_pass #1813
Conversation
Force-pushed from 7be92c7 to 56049a8.
```python
if job_config.compile.enable and "model" in job_config.compile.components:
    torch._inductor.config.reorder_for_peak_memory = False
    model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
    from torch._dynamo.backends.common import aot_autograd as auto_autograd_backend
```
nit: auto_autograd_backend -> aot_autograd_backend
Should we make this configurable, instead of always turning it on? The default could be "on", but I think researchers may want to be able to play with the non-optimized version as well.
Also, more documentation is needed in the code; the flags look mysterious.
yesss, it's not ready for review yet...
Force-pushed from 56049a8 to 2fa5702.
Force-pushed from 2fa5702 to 7dd7471.
Force-pushed from 2f7415e to 0d70c22.
```python
    return model


def get_compile_backend(backend_name: str) -> Union[str, callable]:
```
Given the complexity, I think we should start putting things into separate files. E.g. this function can go to `compile_utils.py` or `backend.py`. The original file can stay as `simple_fsdp.py` for now. WDYT?
Yes, this makes sense to me.
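(For context, a minimal sketch of what such a separate `backend.py` could look like, assuming the `aot_eager_autobucketing` backend name from the PR description and the pass wiring visible in the traceback further below; this is illustrative, not the exact PR code.)

```python
# Illustrative sketch only -- not the PR's actual backend.py.
from typing import Callable, Union

import torch
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing


def get_compile_backend(backend_name: str) -> Union[str, Callable]:
    """Map a backend name to a torch.compile backend (string or callable)."""
    if backend_name == "aot_eager_autobucketing":
        # Run the aten-level autobucketing pass on the forward and backward
        # graphs, then execute them eagerly (no Inductor codegen).
        def aten_autobucketing_reordering_pass(
            gm: torch.fx.GraphModule, example_inputs
        ) -> torch.fx.GraphModule:
            schedule_overlap_bucketing(gm)
            # Recompile in case the pass mutated the graph.
            gm.recompile()
            return gm

        return aot_autograd_backend(
            fw_compiler=aten_autobucketing_reordering_pass,
            bw_compiler=aten_autobucketing_reordering_pass,
        )

    # Built-in backends ("inductor", "aot_eager", ...) pass through as strings.
    return backend_name
```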
torchtitan/config/job_config.py (outdated)
```python
    model_backend: str = "inductor"
    loss_backend: str = "inductor"
```
I somehow feel it's not worth changing the global config yet, as people who don't use SimpleFSDP won't have the motivation to separate them.
For now you could extend JobConfig following https://github.com/pytorch/torchtitan/blob/main/docs/extension.md#extending-jobconfig and define a new config, e.g. `compile.model_backend_override`, which defaults to `None`, so that you could set `backend=compile_config.model_backend_override or compile_config.backend`.
Force-pushed from 0d70c22 to 615abb3.
torchtitan/config/job_config.py (outdated)
"""Which components to compile""" | ||
backend: str = "inductor" | ||
|
||
simplefsdp_backend_override: str | None = None |
We don't need to modify `config/job_config.py`. Here's an example: https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/job_config.py
You'll have to do https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/train_configs/debug_model.toml#L53 for it to take effect for now.
my bad, updated
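(For reference, a minimal sketch of the experiment-local config extension being suggested above, loosely following the flux `job_config.py` pattern; the class and field names here are illustrative assumptions, not the merged code.)

```python
# Illustrative sketch of an experiment-local job_config.py extension.
from dataclasses import dataclass, field


@dataclass
class Compile:
    # When set, overrides compile.backend for the model part only.
    model_backend_override: str | None = None


@dataclass
class JobConfig:
    """Extension dataclass merged into the main JobConfig via the custom args module mechanism referenced above."""

    compile: Compile = field(default_factory=Compile)
```

The training code can then fall back to the global setting when no override is given, e.g. `backend = job_config.compile.model_backend_override or job_config.compile.backend`.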
Force-pushed from 24a1e8b to 9717183.
Should we add a test for `aot_eager_autobucketing`?
I can add this test. It depends on pytorch/pytorch#165063; I will add it once that PyTorch PR is merged.
Force-pushed from 9717183 to dc4486a.
Force-pushed from 0c8187d to b9d3cf7.
Could you rebase onto #1871 before merge?
) When the autobucketing pass is registered as aot_eager backend `fw_compiler` and `bw_compiler`, this pr ensures the tensors are all-gathers on "cpu/cuda" device instead of "meta" device. When we do `dist.all_gather_object`, it will create new bytestorage outside no_dispatch [here](https://github.com/pytorch/pytorch/blob/a2e2e1d8c026951baa345f0dd17668bd1718eda5/torch/distributed/distributed_c10d.py#L3303), which is on meta device. Thus, I updated the code to use `unset_fake_temporarily`, which would gather RealTensor from other ranks. It is needed to unblock the aot_eager+autobucketing pass in this [PR](pytorch/torchtitan#1813). Otherwise, I hit the error as follows: ```bash traceback : Traceback (most recent call last): File "/home/ruisizhang123/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 358, in wrapper return f(*args, **kwargs) File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 607, in train self.train_step(data_iterator) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^ File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 507, in train_step loss = self.forward_backward_step(input_dict, labels) File "/home/ruisizhang123/torchtitan/torchtitan/train.py", line 483, in forward_backward_step pred = model_parts[0](inputs, **extra_inputs, **extra_args) File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 418, in __call__ return super().__call__(*args, **kwargs) ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1784, in _wrapped_call_impl return self._call_impl(*args, **kwargs) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^ File "/home/ruisizhang123/pytorch/torch/nn/modules/module.py", line 1795, in _call_impl return forward_call(*args, **kwargs) File "/home/ruisizhang123/pytorch/torch/_dynamo/eval_frame.py", line 901, in compile_wrapper raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2359, in _call_user_compiler raise BackendCompilerFailed( self.compiler_fn, e, inspect.currentframe() ).with_traceback(e.__traceback__) from None File "/home/ruisizhang123/pytorch/torch/_dynamo/output_graph.py", line 2334, in _call_user_compiler compiled_fn = compiler_fn(gm, example_inputs) File "/home/ruisizhang123/pytorch/torch/_dynamo/repro/after_dynamo.py", line 156, in __call__ compiled_gm = compiler_fn(gm, example_inputs) File "/home/ruisizhang123/pytorch/torch/__init__.py", line 2441, in __call__ return self.compiler_fn(model_, inputs_, **self.kwargs) ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ruisizhang123/pytorch/torch/_dynamo/backends/common.py", line 117, in __call__ cg = aot_module_simplified(gm, example_inputs, **self.kwargs) File "/home/ruisizhang123/pytorch/torch/_functorch/aot_autograd.py", line 1100, in aot_module_simplified compiled_fn, _ = aot_stage2_compile( ~~~~~~~~~~~~~~~~~~^ aot_state, ^^^^^^^^^^ ...<4 lines>... 
inference_compiler, ^^^^^^^^^^^^^^^^^^^ ) ^ File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 257, in aot_stage2_compile return aot_stage2_autograd(aot_state, aot_graph_capture) File "/home/ruisizhang123/pytorch/torch/_functorch/_aot_autograd/graph_compile.py", line 1696, in aot_stage2_autograd compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args) File "/home/ruisizhang123/torchtitan/torchtitan/experiments/simple_fsdp/backend.py", line 35, in aten_autobucketing_reordering_pass schedule_overlap_bucketing(gm) ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^ File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 755, in schedule_overlap_bucketing ).run() ~~~^^ File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 358, in run self._align_compute_nodes_runtime_estimations_across_all_distributed_ranks() ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^ File "/home/ruisizhang123/pytorch/torch/_inductor/fx_passes/overlap_scheduling.py", line 337, in _align_compute_nodes_runtime_estimations_across_all_distributed_ranks dist.all_gather_object( ~~~~~~~~~~~~~~~~~~~~~~^ gathered_runtime_estimations, runtime_estimations, pg ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ ) ^ File "/home/ruisizhang123/pytorch/torch/distributed/c10d_logger.py", line 82, in wrapper return func(*args, **kwargs) File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3170, in all_gather_object input_tensor, local_size = _object_to_tensor(obj, current_device, group) ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/home/ruisizhang123/pytorch/torch/distributed/distributed_c10d.py", line 3079, in _object_to_tensor byte_tensor = torch.ByteTensor(byte_storage).to(device) ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^ torch._dynamo.exc.BackendCompilerFailed: backend='compiler_fn' raised: RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "meta". This is no longer allowed; the devices must match. Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo" ``` Pull Request resolved: #165063 Approved by: https://github.com/eellison
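(The gist of that fix, as described in the message above, is to step out of fake-tensor mode around the object collective so the intermediate byte tensors land on a real device. A rough sketch of the pattern, not the exact upstream diff:)

```python
# Sketch of the pattern described above (not the exact upstream change).
import torch.distributed as dist
from torch._subclasses.fake_tensor import unset_fake_temporarily


def gather_runtime_estimations(runtime_estimations, pg):
    gathered = [None for _ in range(dist.get_world_size(pg))]
    # dist.all_gather_object serializes the object into a byte tensor; under an
    # active FakeTensorMode that tensor would end up on the "meta" device and
    # fail, so temporarily disable fake-tensor dispatch around the collective.
    with unset_fake_temporarily():
        dist.all_gather_object(gathered, runtime_estimations, pg)
    return gathered
```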
Force-pushed from b9d3cf7 to 5c7dc96.
This PR adds the autobucketing pass at the aten level to SimpleFSDP. It runs autobucketing with the aot_eager backend, without Inductor. The aten FX autobucketing pass can be found in pytorch/pytorch#163960. Key updates:
1. Support a customized `aot_eager_autobucketing` backend to perform autobucketing optimization.
2. In SimpleFSDP, the model backend can be replaced by the user's customized passes via `compile.model_backend_override`.