19 changes: 18 additions & 1 deletion torchtitan/experiments/simple_fsdp/README.md
@@ -10,7 +10,7 @@ pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu

This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations.

### Run SimpleFSDP Training on Llama 3
### Run SimpleFSDP Training on Llama3 & DeepSeek_v3

#### Training Llama3 models

@@ -42,6 +42,23 @@ Some of the features require the updates from PyTorch, with which we are working
|Expert Parallelism + Activation Checkpointing| 🚧 |
|Expert Parallelism + Pipeline Parallelism| 🚧 |


### Compiler Optimizations

SimpleFSDP relies on a compiler backend to perform optimizations (i.e., bucketing and reordering) for good training performance. Currently, the following optimization passes are supported:

1. No optimization: use the default torch.compile backends (e.g., "inductor", "aot_eager", "eager").

2. Auto optimization: perform auto-bucketing and reordering without user input. **Note: the most optimized training performance is not guaranteed.**
   - "aot_eager_autobucketing": perform auto-bucketing at the aten FX level and execute code with the aot_eager backend.

Users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:

```bash
--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
```
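
For illustration, the override is equivalent to resolving the backend name with the `get_compile_backend` helper added in this PR (see `backend.py` below) and passing the result to `torch.compile`. The following is a minimal sketch, assuming a PyTorch nightly that includes the referenced autobucketing pass; the `nn.Linear` is only a stand-in for the SimpleFSDP-wrapped model:

```python
import torch
import torch.nn as nn

from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

# Stand-in model; in torchtitan this would be the SimpleFSDP-wrapped Llama3/DeepSeek model.
model = nn.Linear(16, 16)

# "aot_eager_autobucketing" resolves to a custom aot_autograd backend that buckets and
# reorders collectives; stock names ("inductor", "aot_eager", "eager") pass through unchanged.
compiled = torch.compile(
    model,
    backend=get_compile_backend("aot_eager_autobucketing"),
    fullgraph=True,
)
out = compiled(torch.randn(4, 16))
```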

### Citation

If you find SimpleFSDP useful, please kindly consider citing the following paper:
47 changes: 47 additions & 0 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, Union

import torch


def get_compile_backend(backend_name: str) -> Union[str, Callable]:
    # Return the torch.compile backend used in SimpleFSDP training.
    # Step 1: check if backend_name is one of the available torch.compile backends.
    # Step 2: otherwise, check if backend_name has been registered as a customized backend.
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
if backend_name in available_torch_backend:
return backend_name

if backend_name == "aot_eager_autobucketing":
        # Perform auto optimization at the aten FX level and execute code with the aot_eager backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.fx_passes.overlap_scheduling import (
schedule_overlap_bucketing,
)

torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True
torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False
torch._inductor.config.allow_buffer_reuse = False

def aten_autobucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
schedule_overlap_bucketing(gm)
gm.recompile()
return gm

backend = aot_autograd_backend(
fw_compiler=aten_autobucketing_reordering_pass,
bw_compiler=aten_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
)
else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")

return backend
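
A quick usage sketch of the helper above: stock torch.compile backend names are returned unchanged, "aot_eager_autobucketing" returns a callable aot_autograd backend (assuming a PyTorch nightly that contains the autobucketing pass from pytorch/pytorch#163960), and any other name raises an AssertionError.

```python
from torchtitan.experiments.simple_fsdp.backend import get_compile_backend

# Stock backend names are passed through as strings.
assert get_compile_backend("aot_eager") == "aot_eager"

# The customized name resolves to a callable aot_autograd backend.
backend = get_compile_backend("aot_eager_autobucketing")
assert callable(backend)

# Unknown names are rejected.
try:
    get_compile_backend("not_a_backend")
except AssertionError as err:
    print(err)  # Unsupported customized backend: not_a_backend
```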
18 changes: 18 additions & 0 deletions torchtitan/experiments/simple_fsdp/job_config.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field


@dataclass
class Compile:
model_backend_override: str | None = None
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""


@dataclass
class JobConfig:
compile: Compile = field(default_factory=Compile)
11 changes: 10 additions & 1 deletion torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -14,6 +14,8 @@
from torchtitan.models.llama3.infra.parallelize import apply_tp
from torchtitan.tools.logging import logger

from ..backend import get_compile_backend

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy


@@ -123,6 +125,13 @@ def parallelize_llama(

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True)
backend = (
job_config.compile.model_backend_override or job_config.compile.backend
)
model = torch.compile(
model,
backend=get_compile_backend(backend),
fullgraph=True,
)

return model
33 changes: 33 additions & 0 deletions torchtitan/experiments/simple_fsdp/tests/integration_tests.py
Contributor: should we add a test for aot_eager_autobucketing?

Contributor Author: I can add this test. It is dependent on pytorch/pytorch#165063; I will add it once that PyTorch PR is merged.

@@ -23,18 +23,32 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"1D",
"1d",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.model_backend_override aot_eager_autobucketing",
],
],
"1D+aot_eager_autobucketing",
"1d_aot_eager_autobucketing",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"1D with selective op AC",
@@ -46,6 +60,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--activation_checkpoint.mode full",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"1D with full AC",
@@ -57,6 +72,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"2D",
@@ -70,6 +86,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--compile.enable",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.enable_async_tensor_parallel",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"2D async TP",
@@ -82,12 +99,14 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--checkpoint.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--checkpoint.enable",
"--training.steps 20",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
@@ -102,6 +121,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
[
"--model.name simple_fsdp.llama3",
@@ -111,6 +131,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.pipeline_parallel_degree 2",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"PP+DP+TP 3D test with save/load resume ckpt",
@@ -124,6 +145,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--compile.enable",
"--parallelism.data_parallel_shard_degree 1",
"--parallelism.data_parallel_replicate_degree 4",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
]
],
"DDP",
@@ -137,6 +159,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--compile.enable",
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.data_parallel_replicate_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
]
],
"HSDP",
@@ -151,6 +174,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.data_parallel_replicate_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
]
],
"HSDP+TP",
@@ -164,6 +188,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--compile.enable",
"--parallelism.data_parallel_replicate_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
]
],
"DDP+TP",
@@ -178,6 +203,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.data_parallel_replicate_degree 2",
"--parallelism.context_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
]
],
"HSDP+CP (with dp_shard)",
@@ -192,6 +218,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.data_parallel_shard_degree 2",
"--parallelism.tensor_parallel_degree 2",
"--parallelism.context_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
]
],
"FSDP+TP+CP",
@@ -205,6 +232,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--compile.enable",
"--checkpoint.enable",
"--training.steps 10",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
# Save at [dp:4] and load at [dp:2, tp:2]. Note that the dataloader should be
# excluded during loading to avoid errors caused by mismatched dp_degree.
@@ -215,6 +243,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
"--parallelism.tensor_parallel_degree 2",
"--training.steps 20",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
# load at [tp:4].
[
@@ -224,6 +253,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--checkpoint.exclude_from_loading lr_scheduler,dataloader,optimizer",
"--parallelism.tensor_parallel_degree 4",
"--training.steps 30",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"Optional checkpoint",
@@ -236,6 +266,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--model.name simple_fsdp.deepseek_v3",
"--parallelism.data_parallel_shard_degree 4",
"--parallelism.expert_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"FSDP+EP",
@@ -250,6 +281,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.tensor_parallel_degree 2",
"--parallelism.expert_parallel_degree 4",
"--parallelism.expert_tensor_parallel_degree 1",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"FSDP+TP+EP",
@@ -264,6 +296,7 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--parallelism.tensor_parallel_degree 2",
"--parallelism.expert_parallel_degree 2",
"--parallelism.expert_tensor_parallel_degree 2",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
],
],
"FSDP+TP+EP+ETP",