Changes from all commits (45 commits)
7ce8eed
Second version of debug/deterministic configs.
githubsgi Oct 7, 2025
e06e1e9
Review related updates.
githubsgi Oct 10, 2025
b23e5bb
Review 2 related changes.
githubsgi Oct 10, 2025
bcb0894
[DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804)
wwwjn Oct 8, 2025
93cd4c6
Disable FlexAttention max-autotune when deterministic is used (#1808)
fegin Oct 8, 2025
44acac8
Fix num of layers for deepseek-v3 (#1845)
wwwjn Oct 9, 2025
7029661
[VLM] Add token-imbalance loss (#1803)
lkhphuc Oct 9, 2025
594fe9c
refactor TrainSpec to remove the name field (#1850)
tianyu-l Oct 10, 2025
ee49f75
Refactor attention and make attention mask an argument to the model (…
fegin Oct 10, 2025
5ba3488
minor refactor over EP (#1854)
tianyu-l Oct 12, 2025
e11ea4b
Graduate qwen3 from experiment to core (#1860)
wwwjn Oct 13, 2025
a92059b
Review related updates.
githubsgi Oct 10, 2025
e53255a
Rebasing and adding MATH attention kernel.
githubsgi Oct 13, 2025
f30caf6
Indent issue fix.
githubsgi Oct 13, 2025
f4cbf9d
Removing ipex.
githubsgi Oct 13, 2025
00c3165
Review updates.
githubsgi Oct 14, 2025
db187ff
Fixing linter error.
githubsgi Oct 14, 2025
315ea57
graduate llama4 to core (#1865)
tianyu-l Oct 14, 2025
64b77de
consolidate experiments/deepseek_v3 (#1869)
tianyu-l Oct 14, 2025
fa50840
add auto_eager_graph_pass (#1813)
ruisizhang123 Oct 14, 2025
00dbd5a
Second version of debug/deterministic configs.
githubsgi Oct 7, 2025
8bdb11d
Review related updates.
githubsgi Oct 10, 2025
a6a1bab
Review 2 related changes.
githubsgi Oct 15, 2025
2e8585c
[DSV3] Offload dequantization process to DCP QuantizedHFReader (#1804)
wwwjn Oct 8, 2025
2795956
Disable FlexAttention max-autotune when deterministic is used (#1808)
fegin Oct 8, 2025
1432c09
[VLM] Add token-imbalance loss (#1803)
lkhphuc Oct 9, 2025
1b88f57
refactor TrainSpec to remove the name field (#1850)
tianyu-l Oct 10, 2025
139926b
Refactor attention and make attention mask an argument to the model (…
fegin Oct 10, 2025
087dc88
add script to train with ft (#1812)
tushar00jain Oct 10, 2025
5032db6
Indent issue fix.
githubsgi Oct 13, 2025
409da11
Post rebase changes.
githubsgi Oct 15, 2025
2c35b95
Second version of debug/deterministic configs.
githubsgi Oct 7, 2025
27a942f
Review related updates.
githubsgi Oct 10, 2025
93e6d5e
Review 2 related changes.
githubsgi Oct 10, 2025
ff832d2
add script to train with ft (#1812)
tushar00jain Oct 10, 2025
6bb6254
minor refactor over EP (#1854)
tianyu-l Oct 12, 2025
f5dbc0f
[vlm] Add light-weight CI for experimental models (#1848)
wwwjn Oct 12, 2025
1eb5f8e
add owners and CI status for experiments (#1859)
tianyu-l Oct 13, 2025
aba26b4
TorchTitan e2e test on torchcomms device mesh (#1847)
mori360 Oct 14, 2025
09db1fe
graduate llama4 to core (#1865)
tianyu-l Oct 14, 2025
d0b1987
move PP API to model agnostic file (#1868)
tianyu-l Oct 14, 2025
6648707
[refactor] graduate custom_config_module and unify args/config naming…
tianyu-l Oct 14, 2025
e57adb7
Rebase misses.
githubsgi Oct 15, 2025
d68127a
Rebase mistakes.
githubsgi Oct 15, 2025
a251fd4
Linter error fixes.
githubsgi Oct 16, 2025
17 changes: 15 additions & 2 deletions docs/debugging.md
@@ -70,7 +70,7 @@ When debugging issues with multi-dimensional parallelism (combinations of FSDP,
Set consistent random seeds across all parallelism dimensions:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.seed 42
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.seed 42
```

**Seed behavior with parallelism:**
@@ -84,7 +84,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
Enable deterministic algorithms to ensure bit-for-bit reproducibility across runs:

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.deterministic
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic
```

**What it does:**
@@ -93,6 +93,19 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_tr
- Sets deterministic workspace configuration for CuBLAS operations
- **Note:** This will significantly reduce training performance but ensures exact reproducibility

Use `--debug.deterministic_warn_only` to only warn about (rather than error out on) kernels that lack a deterministic implementation.
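
A minimal sketch combining both flags (config file path as in the examples above):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --debug.deterministic --debug.deterministic_warn_only
```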

### Activation Checkpointing Debugging

The following debug configs are available for activation checkpointing (AC); an example invocation is shown below.

`ac_preserve_rng_state` - set to true if deterministic output compared to non-checkpointed passes is required. This stashes and restores the RNG state during each checkpoint, which may be slower.

`ac_determinism_check` - a string specifying the determinism function used to verify recomputed tensors; valid values include `"default"` and `"none"`.

`ac_debug` - capture AC debug information for diagnosing checkpoint determinism issues. Will be slower.

See https://docs.pytorch.org/docs/stable/checkpoint.html for details.
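
For example, a debug run exercising these options might look like the following (a sketch; flag names mirror the fields of the `Debug` config added in this PR):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \
  --debug.ac_preserve_rng_state \
  --debug.ac_determinism_check "default" \
  --debug.ac_debug
```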

### Seed-Checkpoint-based Reproducibility

62 changes: 34 additions & 28 deletions tests/unit_tests/test_activation_checkpoint.py
@@ -11,6 +11,7 @@
from torch.utils.flop_counter import FlopCounterMode

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.config.job_config import JobConfig
from torchtitan.distributed.activation_checkpoint import apply_ac


@@ -74,15 +75,16 @@ def get_bw_flops(model_fn):
# 2. SAC
# Per-op SAC's policy is to save every other mm
model_selective_ac = ToyModule()
ac_config_no_force = ACConfig(
job_config = JobConfig()
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
early_stop=False,
)
apply_ac(
model_selective_ac,
ac_config_no_force,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -92,15 +94,15 @@ def get_bw_flops(model_fn):
# 3. Per-op SAC with force recompute "moe.router.gate"
# This leads to two mms being recomputed since they share the same shape!
model_with_force_first = ToyModule()
ac_config_with_force_first = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
early_stop=False,
)
apply_ac(
model_with_force_first,
ac_config_with_force_first,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -109,15 +111,15 @@ def get_bw_flops(model_fn):

# 4. Per-op SAC with force recompute "output"
model_with_force_last = ToyModule()
ac_config_with_force_last = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
early_stop=False,
)
apply_ac(
model_with_force_last,
ac_config_with_force_last,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -126,13 +128,13 @@ def get_bw_flops(model_fn):

# 5. Full AC
model_with_full_ac = ToyModule()
ac_config_full_ac = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="full",
early_stop=False,
)
apply_ac(
model_with_full_ac,
ac_config_full_ac,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -168,14 +170,14 @@ def get_act_mem(model_fn):
# 2. SAC
# Per-op SAC's policy is to save every other mm
model_selective_ac = ToyModule().cuda()
ac_config_no_force = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(
model_selective_ac,
ac_config_no_force,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -185,14 +187,14 @@ def get_act_mem(model_fn):
# 3. Per-op SAC with force recompute "moe.router.gate"
# This leads to two mms being recomputed since they share the same shape!
model_with_force_first = ToyModule().cuda()
ac_config_with_force_first = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(
model_with_force_first,
ac_config_with_force_first,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -201,14 +203,14 @@ def get_act_mem(model_fn):

# 4. Per-op SAC with force recompute "output"
model_with_force_last = ToyModule().cuda()
ac_config_with_force_last = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
)
apply_ac(
model_with_force_last,
ac_config_with_force_last,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -217,12 +219,12 @@ def get_act_mem(model_fn):

# 5. Full AC
model_with_full_ac = ToyModule().cuda()
ac_config_full_ac = ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="full",
)
apply_ac(
model_with_full_ac,
ac_config_full_ac,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
@@ -243,40 +245,44 @@ def test_correctness(self):

model_selective_ac = ToyModule()
model_selective_ac.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_selective_ac,
ACConfig(
job_config = JobConfig()
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
),
)
apply_ac(
model_selective_ac,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)
model_force_first = ToyModule()
model_force_first.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_force_first,
ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
),
)
apply_ac(
model_force_first,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
)

model_force_last = ToyModule()
model_force_last.load_state_dict(model_no_ac.state_dict())
apply_ac(
model_force_last,
ACConfig(
job_config.activation_checkpoint = ACConfig(
mode="selective",
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
),
)
apply_ac(
model_force_last,
job_config,
model_compile_enabled=False,
use_flex_attn=False,
op_sac_save_list=_op_sac_save_list,
2 changes: 2 additions & 0 deletions torchtitan/config/__init__.py
@@ -28,6 +28,7 @@
Quantize,
Training,
Validation,
Debug,
)
from .manager import ConfigManager

@@ -49,4 +50,5 @@
"Profiling",
"Training",
"Validation",
"Debug"
]
33 changes: 24 additions & 9 deletions torchtitan/config/job_config.py
@@ -253,15 +253,6 @@ class Training:
many temporary files.
"""

seed: int | None = None
"""Choose the base RNG seed used for training"""

deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

debug_moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""


@dataclass
class Parallelism:
@@ -880,6 +871,29 @@ def __post_init__(self):
), "validation steps must be positive or -1"


@dataclass
class Debug:
seed: int | None = None
"""Choose the base RNG seed used for training"""

deterministic: bool = False
"""Use deterministic algorithms wherever possible, may be slower"""

deterministic_warn_only: bool = False
"""Only warns about ops without deterministic implementations rather than erroring out """

ac_preserve_rng_state: bool = False
"""If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

ac_determinism_check: str = "default"
"""A string specifying the determinism function. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

ac_debug: bool = False
""" Capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

moe_force_load_balance: bool = False
"""If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only."""

@dataclass
class JobConfig:
"""
@@ -905,6 +919,7 @@ class JobConfig:
fault_tolerance: FaultTolerance = field(default_factory=FaultTolerance)
experimental: Experimental = field(default_factory=Experimental)
validation: Validation = field(default_factory=Validation)
debug: Debug = field(default_factory=Debug)

def to_dict(self) -> dict[str, Any]:
return asdict(self)
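
For reference, a minimal sketch of how the new options could appear in a training TOML (the `[debug]` table mirrors the `Debug` dataclass above; values are illustrative):

```toml
[debug]
seed = 42
deterministic = true
deterministic_warn_only = true
ac_preserve_rng_state = true
ac_determinism_check = "default"
ac_debug = false
moe_force_load_balance = false
```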