[cp][flex_attention] integration test trial #1160
base: gh/XilunWu/18/base
@@ -11,10 +11,13 @@
 from typing import Any, Generator, Iterable, Optional

 import torch
-from torch.distributed.elastic.multiprocessing.errors import record

 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.tensor.experimental._attention import (
+    FlexAttentionContiguousSharder,
+)

 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.metrics import (
@@ -133,7 +136,9 @@ def __init__(self, job_config: JobConfig):

         # build model (using meta init)
         model_cls = self.train_spec.cls
+        # NOTE (xilunwu): need to store model_args.use_flex_attn for train_step
         model_args = self.train_spec.config[job_config.model.flavor]
+        self.model_args = model_args
         # set the model args from training job configs
         model_args.update_from_config(job_config, tokenizer)

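As context for why model_args is stashed on self, here is a minimal sketch (an assumption, not torchtitan's actual class) of the kind of per-flavor model-args object involved: a config carrying a use_flex_attn flag that train_step later checks before building a BlockMask.

```python
# Hypothetical sketch of a torchtitan-style model-args container; field names
# other than use_flex_attn are illustrative assumptions.
from dataclasses import dataclass


@dataclass
class ModelArgsSketch:
    dim: int = 4096
    n_heads: int = 32
    use_flex_attn: bool = False  # gates the flex_attention / BlockMask path

    def update_from_config(self, job_config, tokenizer) -> None:
        # torchtitan fills fields such as vocab size and sequence length from
        # the job config and tokenizer; elided here.
        ...
```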
@@ -319,13 +324,29 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+
+        # TODO: move this into `create_context_parallel_ctx`
+        # init block_mask for flex_attention
+        block_mask = None
+        if self.model_args.use_flex_attn:
+            from torchtitan.models.attention import FlexAttention
+
+            mask_mod = FlexAttention._get_causal_mask_mod()
Review comment: I think
+            batch_dimension = 1
+            seq_len = inputs.shape[1]
+            block_mask = FlexAttention.compiled_create_block_mask(
+                mask_mod, batch_dimension, None, seq_len, seq_len
+            )

Review comment: We should either let flex attention provide this compiled_create_block_mask to minimize the dependency on users' code when parallelizing CP. cc @drisspg

Reply: Meaning that Flex provides the compiled partial with no mask_mod args?

Reply: For CP + flex_attention, this PR generates 3 compiled BlockMask objects for each mask_mod in:

(1) introduces a dependency in user code in order to adopt CP flex_attention. (2) is how we define the mask_mod in torchtitan and can be modified. Ideally (1) and (2) can be merged so that there is no redundancy and no user-code modification is needed to use CP.
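For reference, here is a hedged sketch of the pattern being discussed: a library-side helper that owns both the causal mask_mod and a torch.compile-wrapped create_block_mask, so user code would not need to import either. It uses only the public torch.nn.attention.flex_attention API; the helper names and defaults are assumptions, not torchtitan's or PyTorch's actual implementation.

```python
# Sketch only: assumes PyTorch >= 2.5 with torch.nn.attention.flex_attention.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def _causal_mask_mod(b, h, q_idx, kv_idx):
    # Standard causal rule: a query may attend to itself and earlier key positions.
    return q_idx >= kv_idx


# Compiling create_block_mask amortizes the cost of building the BlockMask;
# torchtitan's FlexAttention.compiled_create_block_mask is assumed to be a
# similar wrapper around this public API.
_compiled_create_block_mask = torch.compile(create_block_mask)


def make_causal_block_mask(batch: int, seq_len: int, device: str = "cuda"):
    # H=None broadcasts the same mask across all attention heads, matching the
    # `None` head argument used at the call site in this PR.
    return _compiled_create_block_mask(
        _causal_mask_mod, batch, None, seq_len, seq_len, device=device
    )
```

With a helper like this exposed by the attention library, a caller would only pass the resulting BlockMask to create_context_parallel_ctx (or to flex_attention itself), without touching mask_mod or the compile wrapper.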
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
                 cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+                block_mask=block_mask,
+                sharder=FlexAttentionContiguousSharder(),
             )
             if parallel_dims.cp_enabled
             else None
Review comment: You can just remove this block.
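To make the new block_mask and sharder arguments concrete, below is a hedged sketch of the idea behind contiguous sequence sharding for CP + flex_attention. The helper name and the rank/offset math are illustrative assumptions; the real logic lives in torch.distributed.tensor.experimental._attention (FlexAttentionContiguousSharder), not in this snippet.

```python
# Conceptual sketch (assumption, not the FlexAttentionContiguousSharder
# implementation): with contiguous sharding, CP rank r owns the token slice
# [r * L/N, (r+1) * L/N) of a length-L sequence. Its local queries still obey
# the causal rule expressed in *global* positions, which is why the BlockMask
# above is built over the full seq_len before being sharded per rank.
def local_causal_mask_mod(cp_rank: int, cp_world_size: int, seq_len: int):
    shard_len = seq_len // cp_world_size
    q_offset = cp_rank * shard_len  # global index of this rank's first query

    def mask_mod(b, h, q_idx, kv_idx):
        # q_idx is local to this rank's shard; shift it to a global position
        # before the usual causal comparison against the global kv_idx.
        return (q_idx + q_offset) >= kv_idx

    return mask_mod
```

Under this view, a sharder's job is to turn the single global BlockMask built in train_step into the per-rank masks implied by these offsets, so the model code itself never has to know about CP.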