4 changes: 2 additions & 2 deletions docs/torchft.md
@@ -68,12 +68,12 @@ The `--training.global_batch_size` parameter refers to global batch size that wi

#### Replica Group 0
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0
```

#### Replica Group 1
```bash
CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
```

## Fault Tolerance Configuration Options
13 changes: 12 additions & 1 deletion torchtitan/components/checkpoint.py
@@ -193,8 +193,19 @@ def __init__(
self.load_only = checkpoint_config.load_only

self.ft_manager = (
ft_manager.manager if ft_manager and ft_manager.enabled else None
ft_manager.manager
if ft_manager
and ft_manager.enabled
and checkpoint_config.enable_ft_dataloader_checkpoints
else None
)

if ft_manager and ft_manager.enabled and not self.ft_manager:
logger.warning(
"Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. "
"This means replicas can retrain over the same data multiple times, which can result in overfitting."
)

if self.ft_manager:
optimizers.init_cache_state_dict()

36 changes: 36 additions & 0 deletions torchtitan/config/job_config.py
@@ -34,6 +34,20 @@ class Profiling:
profile_freq: int = 10
"""How often to collect profile traces, in iterations"""

profiler_active: int = 1
"""
The number of steps the profiler is active for in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

profiler_warmup: int = 3
"""
The number of warmup steps before the active step in each profiling cycle.

This is used to configure torch.profiler.schedule.
"""

enable_memory_snapshot: bool = False
"""Whether to dump memory snapshot"""

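For context, the new `profiler_active` and `profiler_warmup` fields combine with the existing `profile_freq` to build a `torch.profiler.schedule`, as the `profiling.py` hunk later in this diff does. A minimal standalone sketch of that relationship (example values only):

```python
import torch

# Example values; the defaults in this PR are profile_freq=10, warmup=3, active=1.
profile_freq, warmup, active = 10, 3, 1

# One profiling cycle spans profile_freq steps: wait, then warmup, then active.
wait = profile_freq - (active + warmup)
assert wait >= 0, "profile_freq must be large enough to cover warmup + active"

schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active)
```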
@@ -408,6 +422,28 @@ class Checkpoint:
enable: bool = False
"""Whether to enable checkpoint"""

enable_ft_dataloader_checkpoints: bool = True
"""
Warning: disabling this can cause fault tolerant replicas to train
over the same data multiple times. Disable it only if training over
the same data is acceptable.

Used to enable checkpointing of the dataloader index for fault tolerant training with torchft.

Fault tolerant training stores the data loader index in the checkpoints, so that training can
resume without going over the same batch twice.

If enabled, the data loader state is checkpointed. Otherwise, replicas
will train over the same data multiple times, which can result in
overfitting.

A failed replica will still recover other state, e.g. model
parameters, from the other replicas.

Note: if regular checkpointing is enabled, the data loader state is also checkpointed,
but when not using fault tolerance, the entire training starts from scratch.
"""

folder: str = "checkpoint"
"""
The folder to store the checkpoints.
17 changes: 13 additions & 4 deletions torchtitan/distributed/utils.py
@@ -238,7 +238,10 @@ def maybe_enable_amp(


def init_distributed(
comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = ""
comm_config: CommConfig,
enable_cpu_backend: bool = False,
base_folder: str = "",
replica_id: int | None = None,
):
def _warn_overwrite_env(env, val):
if env in os.environ:
@@ -279,9 +282,17 @@ def _get_distributed_backend(enable_cpu_backend):
os.makedirs(dump_dir, exist_ok=True)
_warn_overwrite_env(TRACE_FILE, f"{dump_dir}/{prefix}")

local_rank = os.environ.get("RANK")
world_size = os.environ.get("WORLD_SIZE")

global_rank = None
if local_rank is not None and replica_id is not None and world_size is not None:
global_rank = int(local_rank) + int(replica_id) * int(world_size)

torch.distributed.init_process_group(
backend=_get_distributed_backend(enable_cpu_backend),
timeout=timedelta(seconds=comm_config.init_timeout_seconds),
rank=global_rank,
)


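To make the rank arithmetic above concrete: inside each replica group, torchrun sets `RANK` (0 to NGPU-1) and `WORLD_SIZE` (NGPU), and `replica_id` offsets that into a global rank across groups. A small worked example assuming the two 4-GPU replica groups from `docs/torchft.md` (values are illustrative):

```python
# Illustrative values mirroring the 2 x 4-GPU layout in docs/torchft.md.
replica_id = 1   # --fault_tolerance.replica_id for the second replica group
local_rank = 2   # RANK set by torchrun within this replica group
world_size = 4   # WORLD_SIZE within this replica group (NGPU=4)

# Same arithmetic as init_distributed above: replica 0 owns global ranks 0-3,
# replica 1 owns global ranks 4-7.
global_rank = local_rank + replica_id * world_size
assert global_rank == 6
```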
@@ -432,9 +443,7 @@ def _clip_grad_norm_with_ep(
if math.isinf(norm_type):
total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
else:
total_norm = (
ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
)
total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
total_norm **= 1.0 / norm_type

if pp_mesh is not None:
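The reformatted expression above relies on the fact that a p-norm over two disjoint groups of gradients can be combined from the per-group norms, i.e. total = (ep_norm**p + non_ep_norm**p)**(1/p), with a plain max for the inf-norm case. A quick standalone check of that identity (random tensors, illustrative only):

```python
import torch

norm_type = 2.0
ep_grads = [torch.randn(8), torch.randn(4)]
non_ep_grads = [torch.randn(16)]

ep_norm = torch.linalg.vector_norm(torch.cat(ep_grads), norm_type)
non_ep_norm = torch.linalg.vector_norm(torch.cat(non_ep_grads), norm_type)

# Combine per-group norms the same way _clip_grad_norm_with_ep does.
total_norm = (ep_norm**norm_type + non_ep_norm**norm_type) ** (1.0 / norm_type)

# Computing the norm over all gradients at once gives the same value.
reference = torch.linalg.vector_norm(torch.cat(ep_grads + non_ep_grads), norm_type)
assert torch.allclose(total_norm, reference)
```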
11 changes: 6 additions & 5 deletions torchtitan/experiments/forge/example_train.py
@@ -277,12 +277,13 @@ def train(self):
self.checkpointer.load(step=job_config.checkpoint.load_step)
logger.info(f"Training starts at step {self.step + 1}.")

torch_profiler = maybe_enable_profiling(
job_config.profiling,
global_step=self.step,
base_folder=job_config.job.dump_folder,
)

with (
maybe_enable_profiling(
job_config.profiling,
global_step=self.step,
base_folder=job_config.job.dump_folder,
) as torch_profiler,
maybe_enable_memory_snapshot(
job_config.profiling,
global_step=self.step,
95 changes: 95 additions & 0 deletions torchtitan/models/llama3_ft/train_configs/debug_model.toml
@@ -0,0 +1,95 @@
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training"
print_args = false

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profiler_active = 10
profiler_warmup = 0
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
# test folder with tokenizer.json, for debug purpose only
hf_assets_path = "./tests/assets/tokenizer"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 100
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = false
folder = "checkpoint"
interval = 10
last_save_model_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "2" # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable = false
components = ["model", "loss"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

[validation]
enable = false
dataset = "c4_validation"
freq = 5
steps = 10

[comm]
train_timeout_seconds = 15

[fault_tolerance]
enable = true
sync_steps = 10
num_fragments = 2
semi_sync_method = "diloco"
process_group = "nccl"
process_group_timeout_ms = 10000

[experimental]
custom_args_module = "torchtitan.components.ft.config"
23 changes: 11 additions & 12 deletions torchtitan/tools/profiling.py
@@ -14,14 +14,10 @@
from torchtitan.config import Profiling as ProfilingConfig
from torchtitan.tools.logging import logger

# the number of warmup steps before the active step in each profiling cycle
WARMUP = 3

# how much memory allocation/free ops to record in memory snapshots
MEMORY_SNAPSHOT_MAX_ENTRIES = 100000


@contextlib.contextmanager
def maybe_enable_profiling(
profiling_config: ProfilingConfig,
*,
@@ -34,7 +30,11 @@ def maybe_enable_profiling(

if enable_profiling:
trace_dir = os.path.join(base_folder, profiling_config.save_traces_folder)
profile_freq = profiling_config.profile_freq
profile_freq, warmup, active = (
profiling_config.profile_freq,
profiling_config.profiler_warmup,
profiling_config.profiler_active,
)

rank = torch.distributed.get_rank()

@@ -58,7 +58,6 @@ def trace_handler(prof):
if not os.path.exists(trace_dir):
os.makedirs(trace_dir, exist_ok=True)

warmup, active = WARMUP, 1
wait = profile_freq - (active + warmup)
assert (
wait >= 0
@@ -68,20 +67,20 @@
gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
elif torch.xpu.is_available():
gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
with torch.profiler.profile(
torch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
gpu_device_profiled,
],
schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
on_trace_ready=trace_handler,
record_shapes=True,
) as torch_profiler:
torch_profiler.step_num = global_step
yield torch_profiler
)
torch_profiler.step_num = global_step
torch_profiler.start()
return torch_profiler
else:
torch_profiler = contextlib.nullcontext()
yield None
return None


@contextlib.contextmanager
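With this change `maybe_enable_profiling` is no longer a context manager: it starts and returns a `torch.profiler.profile` instance, or `None` when profiling is disabled, which is what the `example_train.py` hunk above switches to. A hedged sketch of how a caller drives it per iteration; the names `profiling_config`, `step`, `total_steps`, `dump_folder`, and `train_step` are assumptions standing in for the trainer's own state, and the exact train-loop wiring is not part of this diff:

```python
from torchtitan.tools.profiling import maybe_enable_profiling

torch_profiler = maybe_enable_profiling(
    profiling_config,            # a torchtitan.config Profiling instance
    global_step=step,
    base_folder=dump_folder,
)

while step < total_steps:
    train_step()                 # one training iteration (stand-in)
    step += 1
    if torch_profiler is not None:
        # Advances the wait/warmup/active schedule; trace_handler fires
        # after each active window.
        torch_profiler.step()
```

Since the profiler is now started explicitly rather than by a `with` block, presumably the caller (or process exit) is responsible for stopping it; that part is not shown in this diff.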