
Commit da7fad7

[autoparallel] Add experimental config to enable autoparallel_asynctp
stack-info: PR: #1772, branch: IvanKobzarev/stack/2
1 parent db22479

2 files changed (+54, -2 lines)

torchtitan/config/job_config.py (5 additions & 0 deletions)

@@ -739,6 +739,11 @@ class Experimental:
 
     enable_simplefsdp_passes: bool = False
 
+    enable_inductor_aten_fx_overlap_scheduler: bool = False
+    enable_inductor_aten_fx_overlap_scheduler_bucketing: bool = False
+    enable_autoparallel_asynctp: bool = False
+
+
 @dataclass
 class Validation:
     enable: bool = False
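
For context, a minimal sketch of how the three new flags could be flipped on a parsed config before parallelization runs. This assumes JobConfig can be constructed with its defaults and exposes the Experimental section as job_config.experimental; the TOML/CLI plumbing that normally populates it is not part of this commit.

# Hypothetical usage sketch; field names come from the diff above.
from torchtitan.config.job_config import JobConfig

job_config = JobConfig()  # assumed to build with all-default sections
# Overlap-scheduler bucketing requires the overlap scheduler itself.
job_config.experimental.enable_inductor_aten_fx_overlap_scheduler = True
job_config.experimental.enable_inductor_aten_fx_overlap_scheduler_bucketing = True
# Enable the Autoparallel AsyncTP (micro-pipeline TP) pass.
job_config.experimental.enable_autoparallel_asynctp = True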

torchtitan/experiments/auto_parallel/parallelize_llama.py (49 additions & 2 deletions)

@@ -10,7 +10,6 @@
 
 from autoparallel.api import AutoParallel
 
-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard
 
@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,52 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )
 
+    enable_overlap_scheduling = (
+        job_config.experimental.enable_inductor_aten_fx_overlap_scheduler
+    )
+    enable_overlap_scheduling_bucketing = (
+        job_config.experimental.enable_inductor_aten_fx_overlap_scheduler_bucketing
+    )
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing cannot be used without overlap scheduling"
+
+    if enable_overlap_scheduling:
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
+
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
+        def _overlap_bucketing_pass(graph):
+            overlap_scheduler = OverlapScheduler(graph.owning_module)
+            overlap_scheduler.run()
+
+        torch._inductor.config.post_grad_custom_post_pass = _overlap_bucketing_pass
+
+    enable_asynctp = job_config.experimental.enable_autoparallel_asynctp
+    if enable_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        assert "tp" in world_mesh.mesh_dim_names
+        enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
+        # Disable Inductor's built-in AsyncTP passes in favor of the
+        # Autoparallel fork of those passes.
+        torch._inductor.config._micro_pipeline_tp = False
+        from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            micro_pipeline_tp_pass(graph, None)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
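
The wrapping of torch._inductor.config.post_grad_custom_post_pass at the end of this hunk is what lets the two features coexist: the AsyncTP branch chains the Autoparallel micro_pipeline_tp_pass after whatever pass the overlap-scheduling branch may already have installed. A minimal standalone sketch of that chaining pattern follows; the helper and the dummy pass are illustrative names, not part of this commit, and only the Inductor config hook is real.

# Illustrative sketch of chaining Inductor post-grad custom passes.
# install_post_grad_pass and dummy_fx_pass are hypothetical names.
import torch._inductor.config as inductor_config


def install_post_grad_pass(new_pass):
    existing = inductor_config.post_grad_custom_post_pass

    def chained(graph):
        if existing is not None:
            existing(graph)  # run whatever pass was installed earlier
        new_pass(graph)  # then run the newly registered pass

    inductor_config.post_grad_custom_post_pass = chained


def dummy_fx_pass(graph):
    # Stand-in for OverlapScheduler(...).run() or micro_pipeline_tp_pass(...).
    pass


install_post_grad_pass(dummy_fx_pass)
install_post_grad_pass(dummy_fx_pass)  # second install composes instead of overwriting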
@@ -101,7 +147,8 @@ def input_fn():
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled
+            and not job_config.parallelism.disable_loss_parallel
         )
         if loss_parallel_enabled:
             out_sharding = tuple(