Skip to content

Commit b72466a

Browse files
committed
Add DualPipeV
1 parent 7354848 commit b72466a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
PipelineScheduleMulti,
2020
PipelineScheduleSingle,
2121
ScheduleZBVZeroBubble,
22+
ScheduleDualPipeV,
2223
)
2324

2425
from torchtitan.config import JobConfig
@@ -335,7 +336,7 @@ def _build_stage_from_modules(
335336
models = []
336337

337338
schedule_class = get_schedule_class(pp_schedule)
338-
style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
339+
style = "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
339340

340341
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
341342
module_names = module_names_per_stage[stage_idx]

0 commit comments

Comments
 (0)