
 from autoparallel.api import AutoParallel

-from torch.distributed import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 from torch.distributed.tensor.placement_types import Replicate, Shard

@@ -33,6 +32,7 @@ def parallelize_llama(
     the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
+
     def input_fn():
         global_batch_size = job_config.training.global_batch_size
         if global_batch_size < 0:
@@ -62,6 +62,52 @@ def input_fn():
         lambda bucket_idx: 1000 / parallel_dims.tp
     )

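+    # Experimental toggles for inductor's ATen FX overlap scheduler and its
+    # overlap-preserving bucketing.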
+    enable_overlap_scheduling = (
+        job_config.experimental.enable_inductor_aten_fx_overlap_scheduler
+    )
+    enable_overlap_scheduling_bucketing = (
+        job_config.experimental.enable_inductor_aten_fx_overlap_scheduler_bucketing
+    )
+    if enable_overlap_scheduling_bucketing:
+        assert (
+            enable_overlap_scheduling
+        ), "bucketing cannot be used without overlap scheduling"
+
+    if enable_overlap_scheduling:
+        from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler
+
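+        # Forward the bucketing toggle to inductor's test config for
+        # overlap-preserving bucketing.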
+        torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = (
+            enable_overlap_scheduling_bucketing
+        )
+
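+        # Post-grad custom pass that runs the overlap scheduler over the
+        # traced FX graph.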
+        def _overlap_bucketing_pass(graph):
+            overlap_scheduler = OverlapScheduler(graph.owning_module)
+            overlap_scheduler.run()
+
+        torch._inductor.config.post_grad_custom_post_pass = _overlap_bucketing_pass
+
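+    # AsyncTP: register the TP process group with symmetric memory and use
+    # Autoparallel's micro-pipeline TP pass instead of inductor's built-in one.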
+    enable_asynctp = job_config.experimental.enable_autoparallel_asynctp
+    if enable_asynctp:
+        from torch.distributed._symmetric_memory import enable_symm_mem_for_group
+
+        assert "tp" in world_mesh.mesh_dim_names
+        enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name)
+        torch._inductor.config._micro_pipeline_tp = False
+        # Disable inductor's AsyncTP passes in favor of Autoparallel's fork of them.
+        from autoparallel.asynctp import micro_pipeline_tp_pass
+
+        existing_post_grad_custom_post_pass = (
+            torch._inductor.config.post_grad_custom_post_pass
+        )
+
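+        # Chain the micro-pipeline TP pass after any previously registered
+        # post-grad pass (e.g. the overlap scheduling pass above).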
+        def _pass(graph):
+            if existing_post_grad_custom_post_pass is not None:
+                existing_post_grad_custom_post_pass(graph)
+
+            micro_pipeline_tp_pass(graph, None)
+
+        torch._inductor.config.post_grad_custom_post_pass = _pass
+
     # bail out
     # model = model_fn()
     # return model
@@ -101,7 +147,8 @@ def input_fn():
     )
     out_sharding = x_sharding
     loss_parallel_enabled = (
-        parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+        parallel_dims.tp_enabled
+        and not job_config.parallelism.disable_loss_parallel
     )
     if loss_parallel_enabled:
         out_sharding = tuple(