Commit 53fb9aa

[cp][flex_attention] integration test trial
ghstack-source-id: bb8df2a
Pull Request resolved: #1160
Pull Request resolved: #1228

1 parent 0b44d4c · commit 53fb9aa

File tree: 10 files changed, +86 -22 lines
torchtitan/distributed/utils.py

Lines changed: 19 additions & 2 deletions
@@ -9,14 +9,17 @@
 import os
 from collections.abc import Generator, Iterable
 from datetime import timedelta
+from typing import Optional
 
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
 from torch import distributed as dist
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor import DTensor
+from torch.distributed.tensor.experimental._attention import _FlexAttentionSharder
 from torch.nn.attention import SDPBackend
+from torch.nn.attention.flex_attention import BlockMask
 
 from torchtitan.config_manager import TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
@@ -158,22 +161,35 @@ def create_context_parallel_ctx(
     cp_seq_dims: list[int],
     cp_no_restore_buffers: set[torch.Tensor],
     cp_rotate_method: str,
+    sharder: Optional[_FlexAttentionSharder] = None,
 ):
     try:
         from torch.distributed.tensor.experimental import context_parallel
-        from torch.distributed.tensor.experimental._attention import set_rotate_method
+        from torch.distributed.tensor.experimental._attention import (
+            _DispatchMode,
+            _set_dispatch_mode,
+            set_rotate_method,
+        )
     except ImportError:
         print(
             f"PyTorch version {torch.__version__} does not include the experimental "
             "Context Parallel API. Please update to a newer version."
         )
 
     set_rotate_method(cp_rotate_method)
+    """
+    _set_dispatch_mode("torch_dispatch")
+    assert (
+        torch.distributed.tensor.experimental._attention._dispatch_mode
+        == _DispatchMode.TORCH_DISPATCH
+    )
+    """
     return context_parallel(
         cp_mesh,
         buffers=cp_buffers,
         buffer_seq_dims=cp_seq_dims,
         no_restore_buffers=cp_no_restore_buffers,
+        sharder=sharder,
     )
@@ -194,8 +210,9 @@ def context(cp_context: Generator[None, None, None] | None = None):
             if cp_context is not None:
                 if SDPBackend.MATH in ScaledDotProductAttention.backends:
                     ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
+                # TODO: add logic for flex-attention
                 assert (
-                    ScaledDotProductAttention.backends
+                    ScaledDotProductAttention.backends or True
                 ), "No valid SDPA backends with CP."
                 stack.enter_context(cp_context)
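
Note (not part of the diff): a minimal usage sketch of the updated helper. It assumes a PyTorch build that ships the experimental _FlexAttentionSharder API imported above; with sharder=None the helper behaves exactly as before.

    # Hedged sketch: `world_mesh` and `model.freqs_cis` follow the torchtitan
    # conventions used elsewhere in this commit; not code from the commit itself.
    from torchtitan.distributed import utils as dist_utils

    def forward_with_cp(world_mesh, model, inputs, labels, sharder=None):
        cp_ctx = dist_utils.create_context_parallel_ctx(
            cp_mesh=world_mesh["cp"],
            cp_buffers=[inputs, labels, model.freqs_cis],
            cp_seq_dims=[1, 1, 0],
            cp_no_restore_buffers={inputs, labels},
            cp_rotate_method="allgather",
            sharder=sharder,  # None keeps the pre-existing SDPA-based CP path
        )
        with cp_ctx:  # context_parallel(...) returns a context manager
            return model(inputs)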

torchtitan/experiments/llama4/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@
         rope_theta=500000,
         num_experts=16,
         interleave_moe_layer_step=1,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+        # attn_mask_type="causal",
     ),
     "17bx128e": TransformerModelArgs(
         dim=5120,
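
Note (not part of the diff): a hedged sketch, using PyTorch's public flex_attention API, of what the two attn_mask_type values mean. "causal" lets a token attend to all earlier positions; "block_causal" additionally confines attention to the packed document a token belongs to. The doc_ids tensor is a stand-in for whatever document boundaries the dataloader provides.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask

    S = 2048
    # stand-in document ids for a packed sequence (illustration only)
    doc_ids = torch.zeros(S, dtype=torch.int32)

    def causal(b, h, q_idx, kv_idx):
        # plain causal: attend to the current and earlier positions
        return q_idx >= kv_idx

    def block_causal(b, h, q_idx, kv_idx):
        # causal AND restricted to the same packed document
        return (q_idx >= kv_idx) & (doc_ids[q_idx] == doc_ids[kv_idx])

    # B=None / H=None broadcast the mask over batch and heads
    mask = create_block_mask(block_causal, None, None, S, S, device="cpu")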

torchtitan/experiments/llama4/model/args.py

Lines changed: 3 additions & 2 deletions
@@ -55,7 +55,7 @@ class TransformerModelArgs(BaseModelArgs):
     interleave_moe_layer_step: int = 2
     # token-choice
     top_k: int = 1
-    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
+    use_grouped_mm: bool = False  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
 
     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "FlexAttention is not compatible with selective AC yet. "
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
-
+        """
         if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
             raise ValueError(
                 "FlexAttention is not compatible with CP yet. "
                 "We are still working on this."
             )
+        """
 
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ local_batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
-compile = false
+compile = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]

torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml

Lines changed: 9 additions & 2 deletions
@@ -11,7 +11,7 @@ profile_freq = 100
 
 [metrics]
 log_freq = 10
-enable_tensorboard = false
+enable_tensorboard = true
 save_tb_folder = "tb"
 
 [model]
@@ -27,23 +27,30 @@ eps = 1e-15
 
 [lr_scheduler]
 warmup_steps = 600
+# warmup_steps = 20
 lr_min = 0.1
 
 [training]
 local_batch_size = 8
 seq_len = 8192
+# seq_len = 16384
+# seq_len = 32768
+# seq_len = 65536
 max_norm = 1.0  # grad norm clipping
 steps = 3000
+# steps = 100
 compile = false
+# compile = true
 dataset = "c4"
+deterministic = true
 
 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 tensor_parallel_degree = 8
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
-context_parallel_degree = 1
+context_parallel_degree = 4
 
 [checkpoint]
 enable_checkpoint = false
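
Note (not part of the diff): a quick sanity check on the degrees above, assuming a hypothetical 64-GPU job. With data_parallel_shard_degree = -1, the shard degree is whatever is left of the world size after the replicate, TP, CP and PP degrees are taken out.

    world_size = 64  # assumption for illustration
    dp_replicate, tp, cp, pp = 1, 8, 4, 1
    dp_shard = world_size // (dp_replicate * tp * cp * pp)  # -1 resolves to 2 here
    assert dp_replicate * dp_shard * tp * cp * pp == world_size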

torchtitan/models/llama3/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -48,6 +48,17 @@
         multiple_of=1024,
         rope_theta=500000,
     ),
+    "8B_flex_attn": TransformerModelArgs(
+        dim=4096,
+        n_layers=32,
+        n_heads=32,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=1024,
+        rope_theta=500000,
+        use_flex_attn=True,
+        attn_mask_type="block_causal",
+    ),
     "70B": TransformerModelArgs(
         dim=8192,
         n_layers=80,
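
Note (not part of the diff): the new flavor is selected by name, matching the `self.train_spec.config[job_config.model.flavor]` lookup in train.py below. A hedged sketch; the `llama3_configs` registry name is assumed from this file's layout.

    from torchtitan.models.llama3 import llama3_configs  # registry name assumed

    model_args = llama3_configs["8B_flex_attn"]
    assert model_args.use_flex_attn and model_args.attn_mask_type == "block_causal"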

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 6 deletions
@@ -48,12 +48,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
                 "See https://github.com/pytorch/pytorch/issues/147879"
             )
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise ValueError(
-                "FlexAttention is not compatible with CP yet. "
-                "We are still working on this."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         nparams = sum(p.numel() for p in model.parameters())
         nparams_embedding = sum(

torchtitan/models/llama3/train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ local_batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
-compile = false
+compile = true
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]

torchtitan/models/llama3/train_configs/llama3_8b.toml

Lines changed: 15 additions & 7 deletions
@@ -6,13 +6,17 @@ dump_folder = "./outputs"
 description = "Llama 3 8B training"
 
 [profiling]
-enable_profiling = true
+enable_profiling = false
 save_traces_folder = "profile_trace"
-profile_freq = 100
+# profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
 
 [metrics]
 log_freq = 10
 enable_tensorboard = true
+# enable_tensorboard = false
 save_tb_folder = "tb"
 
 [model]
@@ -27,15 +31,18 @@ lr = 3e-4
 eps = 1e-8
 
 [lr_scheduler]
-warmup_steps = 200  # lr scheduler warm up
+# warmup_steps = 200  # lr scheduler warm up
+warmup_steps = 600
 
 [training]
 local_batch_size = 1
-seq_len = 8192
+seq_len = 32768
 max_norm = 1.0  # grad norm clipping
-steps = 1000
-compile = false
+# steps = 1000
+steps = 3000
+compile = true
 dataset = "c4"
+deterministic = true
 
 [parallelism]
 data_parallel_replicate_degree = 1
@@ -53,7 +60,8 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective"  # ["none", "selective", "full"]
+# mode = "selective"  # ["none", "selective", "full"]
+mode = "full"
 selective_ac_option = "op"  # "int" = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]

torchtitan/train.py

Lines changed: 24 additions & 1 deletion
@@ -11,10 +11,14 @@
 from typing import Any, Generator, Iterable, Optional
 
 import torch
-from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    _FlexAttentionSequentialSharder,
+)
+
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderStopIteration
 from torchtitan.components.loss import rescale_accumulated_loss
@@ -139,7 +143,9 @@ def __init__(self, job_config: JobConfig):
 
         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)
@@ -367,13 +373,30 @@ def forward_backward_step(
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=self.world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                sharder=_FlexAttentionSequentialSharder(
+                    mesh=world_mesh["cp"], block_mask=block_mask
+                ),
             )
             if parallel_dims.cp_enabled
             else None
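
Note (not part of the diff): the block-mask construction above relies on torchtitan's FlexAttention wrapper, whose helpers are internal. The same pattern can be reproduced with PyTorch's public flex_attention API; a minimal sketch, assuming a CUDA device is available.

    import torch
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # compile the mask construction once, mirroring the wrapper's compiled helper
    compiled_create_block_mask = torch.compile(create_block_mask)

    def causal(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    B, H, S, D = 1, 8, 2048, 64
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

    # B=1 mirrors the hunk's batch_dimension; H=None broadcasts over heads
    block_mask = compiled_create_block_mask(causal, 1, None, S, S, device="cuda")
    out = flex_attention(q, k, v, block_mask=block_mask)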
