-
Notifications
You must be signed in to change notification settings - Fork 565
add auto_eager_graph_pass #1813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Any, Union | ||
|
||
import torch | ||
|
||
|
||
def get_compile_backend(backend_name: str) -> Union[str, callable]: | ||
# return the compile backends used in SimpleFSDP training | ||
# Step1: check if backend_name is inside available torch.compile backends | ||
# Step2: check if the backend_name has been registered as a customized backend | ||
available_torch_backend = torch._dynamo.list_backends(exclude_tags=()) | ||
if backend_name in available_torch_backend: | ||
return backend_name | ||
|
||
if backend_name == "aot_eager_autobucketing": | ||
# Perform auto optimization in aten fx-level and execute code in aot_eager backend | ||
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 | ||
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend | ||
from torch._inductor.fx_passes.overlap_scheduling import ( | ||
schedule_overlap_bucketing, | ||
) | ||
|
||
torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True | ||
torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False | ||
torch._inductor.config.allow_buffer_reuse = False | ||
|
||
def aten_autobucketing_reordering_pass( | ||
gm: torch.fx.GraphModule, example_inputs: Any | ||
) -> torch.fx.GraphModule: | ||
schedule_overlap_bucketing(gm) | ||
gm.recompile() | ||
return gm | ||
|
||
backend = aot_autograd_backend( | ||
fw_compiler=aten_autobucketing_reordering_pass, | ||
bw_compiler=aten_autobucketing_reordering_pass, | ||
keep_inference_input_mutations=True, | ||
) | ||
else: | ||
raise AssertionError(f"Unsupported customized backend: {backend_name}") | ||
|
||
return backend |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from dataclasses import dataclass, field | ||
|
||
|
||
@dataclass | ||
class Compile: | ||
model_backend_override: str | None = None | ||
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" | ||
|
||
|
||
@dataclass | ||
class JobConfig: | ||
compile: Compile = field(default_factory=Compile) |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we add a test for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can add this test. It is dependent on this PR: pytorch/pytorch#165063. I will add once the pytorch one is merged. |
Uh oh!
There was an error while loading. Please reload this page.