[cp][flex_attention] integration test trial #1160

Draft: wants to merge 3 commits into base: gh/XilunWu/18/base

Changes from all commits
20 changes: 19 additions & 1 deletion torchtitan/distributed/utils.py
@@ -9,13 +9,16 @@
import os
from collections.abc import Generator, Iterable
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.experimental._attention import FlexAttentionSharder
from torch.nn.attention.flex_attention import BlockMask

from torchtitan.tools.logging import logger
from torchtitan.tools.utils import device_module, device_type
@@ -154,22 +157,37 @@ def create_context_parallel_ctx(
cp_seq_dims: list[int],
cp_no_restore_buffers: set[torch.Tensor],
cp_rotate_method: str,
block_mask: Optional[BlockMask] = None,
sharder: Optional[FlexAttentionSharder] = None,
):
try:
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method
from torch.distributed.tensor.experimental._attention import (
_dispatch_mode,
_DispatchMode,
set_rotate_method,
)
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)

set_rotate_method(cp_rotate_method)
torch.distributed.tensor.experimental._attention._dispatch_mode = (
_DispatchMode.TORCH_DISPATCH
)
assert (
torch.distributed.tensor.experimental._attention._dispatch_mode
== _DispatchMode.TORCH_DISPATCH
)
return context_parallel(
cp_mesh,
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
block_mask=block_mask,
sharder=sharder,
)
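
For context, here is a rough caller-side sketch of how this helper could be used. It is hypothetical code: the buffer names, the loss function, and the "allgather" choice are illustrative and not taken from this diff; only the keyword arguments mirror the signature above.

```python
# Hypothetical usage sketch of create_context_parallel_ctx as modified above.
cp_ctx = create_context_parallel_ctx(
    cp_mesh=world_mesh["cp"],                # 1-D DeviceMesh for the CP dimension
    cp_buffers=[inputs, labels, freqs_cis],  # tensors to shard along their sequence dim
    cp_seq_dims=[1, 1, 0],                   # sequence dim of each buffer above
    cp_no_restore_buffers={inputs, labels},  # buffers that need not be restored on exit
    cp_rotate_method="allgather",            # KV rotation strategy (illustrative value)
    block_mask=block_mask,                   # only needed when the model uses FlexAttention
    sharder=sharder,                         # e.g. a FlexAttentionSharder implementation
)
with cp_ctx:  # inside, the listed buffers are sharded across the cp mesh
    loss = loss_fn(model(inputs), labels)
    loss.backward()
```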


3 changes: 3 additions & 0 deletions torchtitan/experiments/llama4/__init__.py
@@ -40,6 +40,9 @@
rope_theta=500000,
num_experts=16,
interleave_moe_layer_step=1,
use_flex_attn=True,
attn_mask_type="block_causal",
# attn_mask_type="causal",
),
"17bx128e": TransformerModelArgs(
dim=5120,
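The attn_mask_type = "block_causal" setting enabled above masks attention across packed-document boundaries. As a rough illustration (a hypothetical helper, not torchtitan's actual implementation; only the (b, h, q_idx, kv_idx) mask_mod signature comes from the public flex_attention API), a block-causal mask_mod could look like this:

```python
import torch

def make_block_causal_mask_mod(document_id: torch.Tensor):
    # document_id[b, i] is the index of the packed document that token i of sample b belongs to.
    def block_causal(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_document = document_id[b, q_idx] == document_id[b, kv_idx]
        # Causal attention, restricted to tokens from the same packed document.
        return causal & same_document
    return block_causal

# Example: one sample containing two packed documents of lengths 3 and 2.
doc_ids = torch.tensor([[0, 0, 0, 1, 1]])
mask_mod = make_block_causal_mask_mod(doc_ids)
```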
5 changes: 3 additions & 2 deletions torchtitan/experiments/llama4/model/args.py
@@ -55,7 +55,7 @@ class TransformerModelArgs(BaseModelArgs):
interleave_moe_layer_step: int = 2
# token-choice
top_k: int = 1
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
use_grouped_mm: bool = False # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
@@ -74,12 +74,13 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"FlexAttention is not compatible with selective AC yet. "
"See https://github.com/pytorch/pytorch/issues/147879"
)

"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just remove this block.

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise ValueError(
"FlexAttention is not compatible with CP yet. "
"We are still working on this."
)
"""

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
@@ -41,7 +41,7 @@ batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
compile = true
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
14 changes: 11 additions & 3 deletions torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
@@ -11,7 +11,7 @@ profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = false
enable_tensorboard = true
save_tb_folder = "tb"

[model]
@@ -27,23 +27,31 @@ eps = 1e-15

[lr_scheduler]
warmup_steps = 600
# warmup_steps = 20
lr_min = 0.1

[training]
batch_size = 8
# batch_size = 8
batch_size = 4
seq_len = 8192
# seq_len = 16384
# seq_len = 32768
# seq_len = 65536
max_norm = 1.0 # grad norm clipping
steps = 3000
# steps = 100
compile = false
# compile = true
dataset = "c4"
deterministic = true

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
context_parallel_degree = 4

[checkpoint]
enable_checkpoint = false
11 changes: 11 additions & 0 deletions torchtitan/models/llama3/__init__.py
@@ -47,6 +47,17 @@
multiple_of=1024,
rope_theta=500000,
),
"8B_flex_attn": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=1024,
rope_theta=500000,
use_flex_attn=True,
attn_mask_type="block_causal",
),
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
6 changes: 0 additions & 6 deletions torchtitan/models/llama3/model.py
@@ -51,12 +51,6 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
"See https://github.com/pytorch/pytorch/issues/147879"
)

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise ValueError(
"FlexAttention is not compatible with CP yet. "
"We are still working on this."
)

def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
nparams = sum(p.numel() for p in model.parameters())
nparams_embedding = sum(
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/debug_model.toml
@@ -43,7 +43,7 @@ batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
compile = false
compile = true
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
19 changes: 12 additions & 7 deletions torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -11,8 +11,9 @@ save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
log_freq = 50
enable_tensorboard = true
# enable_tensorboard = false
save_tb_folder = "tb"

[model]
@@ -27,22 +28,25 @@ lr = 3e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 200 # lr scheduler warm up
# warmup_steps = 200 # lr scheduler warm up
warmup_steps = 600

[training]
batch_size = 1
batch_size = 4
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 1000
compile = false
# steps = 1000
steps = 3000
compile = true
dataset = "c4"
deterministic = true

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 1
pipeline_parallel_degree = 1
context_parallel_degree = 1
context_parallel_degree = 4

[checkpoint]
enable_checkpoint = false
@@ -53,7 +57,8 @@ export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
# mode = "selective" # ["none", "selective", "full"]
mode = "full"
selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy

[float8]
23 changes: 22 additions & 1 deletion torchtitan/train.py
@@ -11,10 +11,13 @@
from typing import Any, Generator, Iterable, Optional

import torch
from torch.distributed.elastic.multiprocessing.errors import record

import torchtitan.components.ft as ft
import torchtitan.protocols.train_spec as train_spec_module
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.tensor.experimental._attention import (
FlexAttentionContiguousSharder,
)

from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.metrics import (
@@ -133,7 +136,9 @@ def __init__(self, job_config: JobConfig):

# build model (using meta init)
model_cls = self.train_spec.cls
# NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
model_args = self.train_spec.config[job_config.model.flavor]
self.model_args = model_args
# set the model args from training job configs
model_args.update_from_config(job_config, tokenizer)

@@ -319,13 +324,29 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
inputs = input_dict["input"]

# TODO: move this into `create_context_parallel_ctx`
# init block_mask for flex_attention
block_mask = None
if self.model_args.use_flex_attn:
from torchtitan.models.attention import FlexAttention

mask_mod = FlexAttention._get_causal_mask_mod()
batch_dimension = 1
seq_len = inputs.shape[1]
block_mask = FlexAttention.compiled_create_block_mask(
mask_mod, batch_dimension, None, seq_len, seq_len
)

Review comment (Contributor), on the mask_mod line above: I think mask_mod should be the input of context_parallel() and we can directly call compiled_create_block_mask. See the comment below.

Review comment (Contributor), on the compiled_create_block_mask call: We should let flex attention provide this compiled_create_block_mask, to minimize the dependency on users' code when parallelizing with CP. cc @drisspg

Reply (Contributor): Meaning that Flex provides the compiled partial with no mask_mod args?

Reply (Author): For CP + flex_attention, this PR generates three compiled BlockMask objects for each mask_mod:

1. QKV sharding -- this requires a compiled BlockMask built from the global batch input and the mask_mod in order to load balance.
2. Actual training. The first FlexAttention module in the model creates a compiled BlockMask from the sharded batch input and mask_mod. Applying this mask_mod to the sharded batch input is meaningless, so this BlockMask is not used in the actual CP flex_attention computation.
3. Actual training. When forward flex_attention is first called over the sharded batch input in the current step, a BlockMask is created from the sharded batch input and a remapped mask_mod that corresponds to the local region of the attention score (the Q_LEN by KV_LEN rectangle).

(1) introduces a dependency in user code in order to adopt CP flex_attention. (2) is how we define the mask_mod in torchtitan and can be modified. Ideally (1) and (2) could be merged so that there is neither redundancy nor any user-code modification needed to use CP.
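
To make the discussion above concrete, here is a minimal sketch of the underlying block-mask construction, assuming FlexAttention.compiled_create_block_mask is essentially a torch.compile-wrapped torch.nn.attention.flex_attention.create_block_mask and that the causal mask_mod has the standard signature. This is an assumption for illustration, not a statement about torchtitan internals.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask_mod(b, h, q_idx, kv_idx):
    # A query position may attend to key positions at or before it.
    return q_idx >= kv_idx

seq_len = 8192
# B=1 broadcasts the mask over the batch; H=None broadcasts it over attention heads.
block_mask = create_block_mask(
    causal_mask_mod, B=1, H=None, Q_LEN=seq_len, KV_LEN=seq_len,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
# Per the author's comment, CP consumes a mask like this one (built on the global,
# unsharded sequence) to load-balance Q/K/V shards; each rank's flex_attention call
# later rebuilds a mask for its local Q_LEN x KV_LEN rectangle from a remapped mask_mod.
```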

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
block_mask=block_mask,
sharder=FlexAttentionContiguousSharder(),
)
if parallel_dims.cp_enabled
else None