Mattg context parallel halo #626
base: main
Conversation
Fake CP backward
distributed cp (no optimizations)
- DTests
- bwd passing
- bwd actually passing for world_size=8, higher tol
- rm old dist test
- add test for final state fwd
Walkthrough

This PR introduces comprehensive context-parallel (CP) support for the GatedDeltaNet model, enabling efficient distributed training across multiple GPU ranks. It adds CP-enabled layer and model implementations, distributed halo exchange operations with autograd support, and extensive test coverage including forward/backward parity validation, memory benchmarking, and training examples.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant R0 as Rank 0
    participant R1 as Rank 1
    participant R2 as Rank 2
    Note over R0,R2: Forward Pass with Halo Exchange
    R0->>R0: Compute local<br/>chunk
    R0->>R1: Send tail h tokens<br/>(right halo)
    R1->>R1: Receive left halo<br/>from R0
    R1->>R1: Compute local<br/>chunk + halo
    R1->>R2: Send tail h tokens
    R2->>R2: Receive left halo<br/>from R1
    R2->>R2: Compute local<br/>chunk + halo
    Note over R0,R2: Backward Pass (Gradient Flow)
    R2->>R2: Compute local<br/>gradients
    R2->>R1: Send gradient<br/>baton (dh)
    R1->>R1: Receive dh,<br/>compute grads
    R1->>R0: Send gradient<br/>baton (dh)
    R0->>R0: Receive dh,<br/>compute grads
```
```mermaid
sequenceDiagram
    participant Model as GatedDeltaNetForCausalLMCP
    participant Block as GatedDeltaNetBlockCP
    participant Attn as GatedDeltaNet (CP)
    participant Halo as halo_exchange_and_extend_autograd
    Model->>Block: forward(input_ids, cp_rank, cp_size, cp_group)
    activate Block
    Block->>Attn: forward(hidden_states, cp_rank, cp_size, cp_group)
    activate Attn
    Attn->>Halo: halo_exchange_and_extend_autograd(q, k, v, h)
    activate Halo
    Halo-->>Attn: (q_ext, k_ext, v_ext, cu_seqlens_ext)
    deactivate Halo
    Attn->>Attn: chunk_gated_delta_rule_cp<br/>(with halo-extended inputs)
    Attn-->>Block: (attn_output)
    deactivate Attn
    Block-->>Model: (block_output)
    deactivate Block
```
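For orientation, here is a minimal sketch of the forward halo exchange pattern shown in the first diagram, written with plain `torch.distributed` point-to-point ops. The function name, shapes, and zero left halo on rank 0 are illustrative assumptions, not the PR's actual `halo_exchange_and_extend` implementation.

```python
import torch
import torch.distributed as dist


def halo_exchange_forward_sketch(x: torch.Tensor, h: int, rank: int, world_size: int, group=None) -> torch.Tensor:
    """Prepend the previous rank's last h tokens to this rank's local chunk.

    x: [B, T_local, D] shard of the full sequence held by this rank.
    Rank 0 keeps a zero left halo; every other rank receives its halo from rank - 1.
    """
    B, _, D = x.shape
    left_halo = torch.zeros(B, h, D, dtype=x.dtype, device=x.device)
    send_buf = x[:, -h:].contiguous()  # tail h tokens forwarded to the next rank

    reqs = []
    if rank + 1 < world_size:
        reqs.append(dist.isend(send_buf, dst=rank + 1, group=group))
    if rank > 0:
        reqs.append(dist.irecv(left_halo, src=rank - 1, group=group))
    for req in reqs:
        req.wait()

    # Extended chunk: [B, h + T_local, D]
    return torch.cat([left_halo, x], dim=1)
```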
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Hi, thanks for the PR! Could you fix the following CI test error?
1700f8d to f4082b3
Actionable comments posted: 7
🧹 Nitpick comments (34)
tests/cp/test_cp_stability.py (1)
1-1: Consider making the file executable or removing the shebang.

The static analysis tool notes that the shebang is present but the file isn't executable.
To fix, either run `chmod +x tests/cp/test_cp_stability.py`, or remove the shebang if the file is only meant to be run via `torchrun`.

fla/ops/gated_delta_rule/cp_halo.py (2)
94-97: Fix code style: split multiple statements onto separate lines.

Line 97 contains three statements separated by semicolons, which violates PEP 8 style guidelines and is flagged by linters.
Apply this diff:
```diff
 seq_boundaries = cu_seqlens[1:-1]
 if seq_boundaries.numel() > 0 and bool((seq_boundaries == cp_shard_start_idx).any().item()):
-    q_halo.zero_(); k_halo.zero_(); v_halo.zero_()
+    q_halo.zero_()
+    k_halo.zero_()
+    v_halo.zero_()
```
202-208: Fix code style: split statement onto separate line.

Line 208 contains `send_buf.zero_()`, which should be on its own line for consistency with the rest of the codebase.

Apply this diff:

```diff
 if cu_seqlens_saved is not None:
     seq_boundaries = cu_seqlens_saved[1:-1]
     if seq_boundaries.numel() > 0 and bool((seq_boundaries == ctx.cp_shard_start_idx).any().item()):
-        send_buf.zero_()
+        send_buf.zero_()
```

Note: This is more about maintaining consistency with the fix suggested for line 97.
tests/ops/cp/README.md (1)
1-2: Documentation addresses the CI import error.

This README explains the `ModuleNotFoundError: No module named 'dtest'` mentioned in the PR comments. Users need to install the DTest package from the referenced repository to run these distributed tests.

Consider adding a note in the main repository README or CI documentation about this optional dependency requirement.
tests/cp/README.md (3)
24-32: Add language specifier to code block.

The code block should specify `bash` as the language for proper syntax highlighting and markdown compliance. Add `bash` to the opening fence of the block containing `torchrun --nproc_per_node=8 tests/cp/train_cp_real.py --cp_size 8`.

38-46: Add language specifier to code block.

The code block should specify `bash` as the language for proper syntax highlighting. Add `bash` to the opening fence of the block containing `# Single GPU baseline` / `python tests/cp/benchmark_memory.py --cp_size 1`.

59-67: Add language specifier to code block.

The code block should specify `bash` as the language for proper syntax highlighting. Add `bash` to the opening fence of the block containing `# Step 1: Generate reference on single GPU` / `python test_numerical_correctness.py --mode single --save_ref`.

tests/ops/cp/test_gated_delta_cp_fake.py (1)
22-28: Clarify `cp_shard` default and tidy unused temporaries

`cp_shard` currently does `return tensor.tensor_split(cp_degree or self.cp_degree, dim=shard_dim)`, but `self.cp_degree` is never defined on `TestFakeCP`. Right now all callsites pass `cp_degree`, so this doesn't break, but it's a latent bug if `cp_shard` is ever reused without that argument.

You can either (a) store a `self.cp_degree` attribute on the class, or (b) make `cp_degree` required and drop the fallback, e.g.:

```diff
-    def cp_shard(
-        self,
-        tensor: torch.Tensor,
-        cp_degree: int | None = None,
-        shard_dim: int | None = 1,
-    ) -> torch.Tensor:
-        return tensor.tensor_split(cp_degree or self.cp_degree, dim=shard_dim)
+    def cp_shard(
+        self,
+        tensor: torch.Tensor,
+        cp_degree: int,
+        shard_dim: int = 1,
+    ) -> tuple[torch.Tensor, ...]:
+        return tensor.tensor_split(cp_degree, dim=shard_dim)
```

Separately, the unpacked locals `g_out`, `A`, `dh0_ref`, and `o_i` are only used to reach later outputs. If you want Ruff to be clean in strict mode, consider prefixing the unused ones with `_`:

```diff
-    g_out, o_ref, A, _ = chunk_gated_delta_rule_fwd(...)
+    _g_out, o_ref, A, _ = chunk_gated_delta_rule_fwd(...)
     ...
-    dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd(...)
+    dq_ref, dk_ref, dv_ref, db_ref, dg_ref, _dh0_ref = chunk_gated_delta_rule_bwd(...)
     ...
-    g_out_i, o_i, A_i, initial_state = chunk_gated_delta_rule_fwd(...)
+    g_out_i, _o_i, A_i, initial_state = chunk_gated_delta_rule_fwd(...)
```

Also applies to: 50-56, 128-128, 148-148
fla/layers/gated_deltanet_cp.py (2)
10-21: Remove unused imports and local to keep the module lint-clean

In this module:
torch.distributed as distis imported but not used.chunk_gated_delta_ruleis imported fromfla.ops.gated_delta_rulebut never referenced (onlyfused_recurrent_gated_delta_ruleandchunk_gated_delta_rule_cpare used).cp_shard_start_idx = kwargs.get('cp_shard_start_idx', None)is assigned but never used; later calls re‑fetchcp_shard_start_idxfromkwargsdirectly.These don’t change behavior but will trip Flake8 (F401/F841). A minimal cleanup would be:
-import torch.distributed as dist ... -from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule ... - cp_shard_start_idx = kwargs.get('cp_shard_start_idx', None) + # cp_shard_start_idx is only threaded through to halo helpers via kwargsAlso applies to: 92-107, 227-232
92-107: Align type hints withNonedefaults fornum_v_headsandlayer_idx
__init__currently declares:num_v_heads: int = None, ... layer_idx: int = None,With Ruff’s RUF013 and standard typing conventions, these should be explicitly optional:
- num_v_heads: int = None, + num_v_heads: Optional[int] = None, ... - layer_idx: int = None, + layer_idx: Optional[int] = None,This keeps the signature honest for type checkers without changing runtime behavior.
fla/ops/gated_delta_rule/chunk_cp.py (1)
1-17: Trim unused imports and align a couple of type hints

At the top of the file, several heavy operators are imported but never used in the CP wrapper:
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd from fla.ops.utils import chunk_local_cumsum, solve_trilGiven
ChunkGatedDeltaRuleFunctionCPdelegates entirely tochunk_gated_delta_rule_fwd/chunk_gated_delta_rule_bwd, these imports can be dropped to keep the module light and avoid Flake8 F401:-from fla.modules.l2norm import l2norm_bwd, l2norm_fwd -from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h -from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o -from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd -from fla.ops.utils import chunk_local_cumsum, solve_tril -from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guardA couple of small typing nits if you want Ruff’s RUF013 clean:
- scale: float = None, - initial_state: torch.Tensor = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None,and similarly keep the local unpackings for
Tas_Tif you want to silence “unused variable” warnings:- B, T, H, K = q.shape - B, T, H, V = v.shape + B, _T, H, K = q.shape + B, _T, H, V = v.shape ... - B, T, H, K = q.shape - B, T, H, V = v.shape + B, _T, H, K = q.shape + B, _T, H, V = v.shapeThese are non-functional, but they keep the file tidy.
Also applies to: 42-43, 115-116, 171-186
tests/cp/test_gdn_cp_backward.py (1)
111-129: Minor cleanups for unused locals in CP backward test harness

The CP parity harness is logically sound, but a couple of locals are never used:
rankfromsetupis returned and then not used incp.- In
cp,B, T = input_ids.shapeonlyTis used.If you run Ruff in stricter modes, these will trigger RUF059. You can either prefix them with
_or drop them:-def setup(cp_size: int): +def setup(cp_size: int): @@ - rank = dist.get_rank() + rank = dist.get_rank() @@ - return rank, device, group + return rank, device, group ... -def cp(cp_size: int): +def cp(cp_size: int): @@ - rank, device, group = setup(cp_size) + _rank, device, group = setup(cp_size) @@ - B, T = input_ids.shape + _B, T = input_ids.shapePurely cosmetic, but it keeps linters quiet.
Also applies to: 132-136, 148-151
tests/ops/cp/test_gated_delta_cp_real.py (2)
71-131: Use the `world_size` argument for a sanity check (also fixes Ruff ARG002)
test_fwdtakesworld_sizebut never uses it, which triggers Ruff’s ARG002 and also misses an easy safety check. You can both satisfy the linter and assert the distributed setup matches the mark:- @pytest.mark.world_size([2, 4, 8]) - def test_fwd(self, world_size: int) -> None: + @pytest.mark.world_size([2, 4, 8]) + def test_fwd(self, world_size: int) -> None: q = torch.rand(self.B, self.T, self.H, self.D, dtype=self.dtype) ... - # ---- reference forward (no CP) ---- + # ---- reference forward (no CP) ---- + assert dist.get_world_size() == world_sizeThis keeps the fixture useful and avoids the unused‑arg warning.
132-233: Clean up minor duplication and unused values in `test_bwd`

The backward test logic looks sound (reference path vs CP path with baton gradients), but there are a few small nits you may want to tidy:
qnandknare computed twice (Lines 145–147 and again 151–153); the second pair is redundant.o_refanddh0_reffromchunk_gated_delta_rule_fwd/chunk_gated_delta_rule_bwdaren’t used; similarlyo_cpfrom the CP forward is unused. You can prefix them with_to silence RUF059 or replace with_in the unpacking:- g_out_ref, o_ref, A_ref, _ = chunk_gated_delta_rule_fwd( + g_out_ref, _o_ref, A_ref, _ = chunk_gated_delta_rule_fwd( ... - dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd( + dq_ref, dk_ref, dv_ref, db_ref, dg_ref, _dh0_ref = chunk_gated_delta_rule_bwd( ... - g_out_cp, o_cp, A_cp, ht = chunk_gated_delta_rule_fwd( + g_out_cp, _o_cp, A_cp, ht = chunk_gated_delta_rule_fwd(Also, the NOTE about some asserts failing at tighter tolerance for
world_size=8(Lines 227–228) is a bit concerning; if this is not fully understood yet, consider either tightening the tolerance back after investigation or documenting clearly why the relaxed0.02tolerance is acceptable.tests/cp/test_cp_halo_unit.py (3)
25-33: Drop unused `Tuple` import

`Tuple` from `typing` is imported but never used, which flake8 flags (F401). You can safely remove it:

```diff
-import argparse
-import os
-from typing import Optional, Tuple
+import argparse
+import os
+from typing import Optional
```

No behavioral impact.
65-87: Tighten shape unpacking to reflect actual usageIn
depthwise_causal_convandglobal_baseline, you unpackB, T, D = x.shape, but some of these variables aren’t used, which Ruff flags. You can avoid the warnings and make the intent clearer by only binding what you need, e.g.:-def depthwise_causal_conv(x: torch.Tensor, w: torch.Tensor, h: int) -> torch.Tensor: - ... - B, T, D = x.shape +def depthwise_causal_conv(x: torch.Tensor, w: torch.Tensor, h: int) -> torch.Tensor: + ... + _, _, D = x.shapeand in
global_baseline:- B, T, D = x.shape + B, _, D = x.shapeThis keeps behavior identical while satisfying linters.
151-241: Address unusedvarlenargument and unneeded outputs from halo exchange
run_cpdoesn’t actually use itsvarlenargument, and Ruff also points out unused unpacked outputs fromhalo_exchange_and_extend_autogradand an unusedBfromx_full.shape. You can clean this up as follows:
- If
varlenis not needed (since the reference already encodes varlen viacu_seqlens), drop it from the signature and callers, or rename to_varlento indicate intentional discard.- Avoid unused outputs from halo exchange:
- q_ext, k_ext, v_ext, cu_ext = halo_exchange_and_extend_autograd( + q_ext, *_ = halo_exchange_and_extend_autograd( q_in, k_in, v_in, h, cp_rank=rank, cp_size=cp_size, cp_group=group, cu_seqlens=cu_full, cp_shard_start_idx=s if cu_full is not None else None, )
- Avoid binding unused
B:- B, T, D = x_full.shape + _, T, D = x_full.shapeThese are purely cosmetic/quality improvements; the core CP/halo logic looks correct.
tests/cp/benchmark_memory.py (2)
22-37: Consider asserting `world_size == cp_size` for multi-GPU runs

In `setup_distributed`, when `cp_size > 1` you assume the default process group has exactly `cp_size` ranks but don't enforce it. A mismatched `--nproc_per_node` vs `--cp_size` could cause confusing hangs inside the model. You can add a quick sanity check:

```diff
 def setup_distributed(cp_size):
     """Setup distributed"""
     if cp_size > 1:
         dist.init_process_group(backend='nccl')
         rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        if world_size != cp_size:
+            raise ValueError(f"Expected world_size == cp_size, got {world_size} vs {cp_size}")
         ...
```
tests/cp/test_numerical_correctness.pyand fails fast on misconfiguration.
132-209: Fix trivial f-string lint issues in printingSeveral print calls use f-strings without any placeholders (Ruff/flake8 F541), e.g.:
- Line 152:
print(f"\nConfiguration:")- Line 180:
print(f" ✓ Success")- Lines 205–207:
print(f" python benchmark_memory.py --cp_size 1"), etc.These can be plain strings:
- print(f"\nConfiguration:") + print("\nConfiguration:") ... - print(f" ✓ Success") + print(" ✓ Success") ... - print(f" python benchmark_memory.py --cp_size 1") + print(" python benchmark_memory.py --cp_size 1")Purely a lint/style fix; behavior is unchanged.
tests/cp/test_gdn_cp_forward.py (3)
18-27: Drop unused `Optional` import

`Optional` is imported from `typing` but never used (flake8 F401). You can safely remove it:

```diff
-import argparse
-import os
-from typing import Optional
+import argparse
+import os
```

No functional impact.
99-107: Distributed setup is reasonable but assumesworld_size == cp_size
setup_distinitializes a process group and then creates a new group containingrange(cp_size)without explicitly verifying that the world size matchescp_size. In practice you’ll run with--nproc_per_node == cp_size, but to harden this, consider assertingdist.get_world_size() == cp_size(similar totest_numerical_correctness). This is optional but would catch misconfigured launches early.
110-170: Clear CP parity logic; address minor lint (unused `varlen` and `B`)

The `run_cp` path correctly:
- Broadcasts the reference dict from rank 0.
- Rebuilds the model per conv size and loads the saved state.
- Shards the sequence over ranks, runs CP forward, gathers logits, and compares to reference on rank 0 with a sane tolerance (a minimal sketch of this pattern follows below).
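For reference, a minimal sketch of that shard-and-gather comparison, with illustrative names (`model`, `input_ids`, `ref_logits`) and the assumption that the sequence length divides evenly by the CP world size; it is not the test's actual code:

```python
import torch
import torch.distributed as dist


def cp_logits_parity_check(model, input_ids, ref_logits, group=None, atol=2e-2):
    """Shard the full sequence over CP ranks, run the CP model, gather logits,
    and compare against a single-GPU reference on rank 0.

    Assumes input_ids.shape[1] is divisible by the CP world size so every rank
    holds an equally sized shard (required for all_gather with empty_like).
    The tolerance is illustrative.
    """
    rank, cp_size = dist.get_rank(), dist.get_world_size()
    local_ids = input_ids.tensor_split(cp_size, dim=1)[rank]

    out = model(local_ids, cp_rank=rank, cp_size=cp_size, cp_group=group)
    local_logits = out.logits.contiguous()

    gathered = [torch.empty_like(local_logits) for _ in range(cp_size)]
    dist.all_gather(gathered, local_logits, group=group)

    if rank == 0:
        full_logits = torch.cat(gathered, dim=1)  # reassemble the full sequence
        max_diff = (full_logits - ref_logits).abs().max().item()
        assert max_diff < atol, f"CP vs reference logits differ by {max_diff}"
```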
Two small cleanups:
- `varlen` is accepted but never used; either drop it from the signature (and from `main`) or rename it to `_varlen` to indicate it's intentionally ignored.
- `B, T = input_ids.shape` binds `B` but never uses it; you can write `_, T = input_ids.shape` to avoid RUF059.

Otherwise this function looks good.
tests/cp/test_numerical_correctness.py (2)
56-115: Single-GPU reference flow is well-structured; consider optional CPU fallback

The single-GPU test builds a compact config, creates dummy data via `create_model_and_data`, runs a forward pass with `cp_rank=0, cp_size=1`, and optionally saves a full reference bundle. That's correct and consistent with the CP side. The only caveat is the hardcoded `device = torch.device('cuda')`; if you ever want to run these tests on CPU-only environments, you'd need a guard similar to other files (`'cuda' if torch.cuda.is_available() else 'cpu'`). Not required if these tests are GPU-only by design.
117-303: CP numerical correctness logic is solid; fix trivial lint issues

The `test_context_parallel` flow is robust:

- Uses `setup_distributed` to ensure ranks/devices are aligned.
- Loads the single-GPU reference on rank 0, broadcasts tensors to all ranks.
- Constructs identical models on all ranks and synchronizes weights by broadcasting parameter tensors.
- Shards the sequence over ranks, runs CP forward, gathers logits, and compares against the reference using loss diff, absolute/relative logits diff, and correlation, with clear thresholds.
- Constructs identical models on all ranks and synchronizes weights by broadcasting parameter tensors.
- Shards the sequence over ranks, runs CP forward, gathers logits, and compares against the reference using loss diff, absolute/relative logits diff, and correlation, with clear thresholds.
A few small cleanups that will fix the static-analysis warnings:
- Avoid reloading the reference twice on rank 0; you can reuse the first
ref.- Replace f-strings without placeholders with plain strings, e.g.:
- print(f"\nResults:") + print("\nResults:") ... - print(f"\n📊 Loss Comparison:") + print("\n📊 Loss Comparison:") ... - print(f"\n📊 Logits Comparison:") + print("\n📊 Logits Comparison:") ... - print(f"\n🎉 ALL TESTS PASSED!") + print("\n🎉 ALL TESTS PASSED!")and similarly for other lines Ruff/flake8 flagged (170, 173, 287, 288, 293, 297, 298).
- The continuation of the
parser.add_argument('--mode', ...)line (around Line 308) can be re-indented to satisfy E128, e.g.:parser.add_argument( '--mode', type=str, required=True, choices=['single', 'cp'], help='Test mode', )None of these change behavior but will keep the file lint‑clean.
tests/cp/train_cp_real.py (4)
53-82: Dataset is fine, but `mode` is currently unused.

`RealisticTextDataset` does what it promises (deterministic synthetic full sequences per rank), but the `mode` argument is not used anywhere, which is a small source of confusion.

If you don't plan to support file-based loading here, consider dropping `mode` for now; otherwise, implement the `mode == "file"` branch (or at least raise `NotImplementedError` when `mode != "random"`).
125-173: Fix unused `batch_size` binding in `CPCollator` (RUF059) and keep shard logic as-is.

The collator's sharding logic looks sound, but `batch_size` from `batch_size, seq_len = input_ids.shape` is never used, which Ruff flags as RUF059.

Minimal fix:

```diff
- batch_size, seq_len = input_ids.shape
+ _, seq_len = input_ids.shape
```

Everything else in this collator (padding to make `seq_len` divisible by `cp_size`, using `-100` for label padding, and exposing `chunk_start`/`chunk_end`) looks consistent with the downstream CP model.
259-371: Fix unused `local_rank` and stray f-strings without placeholders (F541, RUF059).

In `main`:

- `local_rank` is unpacked but never used outside logging of `device`, which already captures what we need.
f"..."without any placeholders, flagged by Flake8 F541.Suggested adjustments:
- rank, world_size, local_rank, device, cp_rank, cp_size, cp_group = setup_distributed_cp(args.cp_size) + rank, world_size, _local_rank, device, cp_rank, cp_size, cp_group = setup_distributed_cp(args.cp_size) @@ - if rank == 0: - print("\n" + "="*80) - print("REALISTIC CONTEXT PARALLELISM TRAINING") - print("="*80) - print(f"\n🔧 Setup:") + if rank == 0: + print("\n" + "=" * 80) + print("REALISTIC CONTEXT PARALLELISM TRAINING") + print("=" * 80) + print("\n🔧 Setup:") @@ - if rank == 0: - print(f"\n📐 Model Configuration:") + if rank == 0: + print("\n📐 Model Configuration:") @@ - if rank == 0: - print(f"\n📚 Dataset:") + if rank == 0: + print("\n📚 Dataset:") @@ - if rank == 0: - print(f"\n🚀 Training:") + if rank == 0: + print("\n🚀 Training:")Similar change for the batch-info header:
- if rank == 0 and step == start_step: - print(f"📦 Batch info:") + if rank == 0 and step == start_step: + print("📦 Batch info:")
1-12: Optional: drop shebang or make file executable to silence EXE001.Ruff flags EXE001 because there’s a shebang but the file is not necessarily executable in the repo. If you don’t plan to run this via
./train_cp_real.py, you can safely remove the shebang:-#!/usr/bin/env python -# -*- coding: utf-8 -*- +#!/usr/bin/env python +# -*- coding: utf-8 -*-or just drop the shebang entirely and keep the encoding line:
-#!/usr/bin/env python -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*-Pick whichever matches how you intend to invoke this script.
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (6)
38-82: CP block wiring looks good; just confirm MLP supports `**kwargs`.

`GatedDeltaNetBlockCP` correctly:

- Normalizes + residuals in a prenorm style.
- Switches between `Attention` and `GatedDeltaNet` based on `config.attn`.
- Threads `**kwargs` (CP params) into `self.attn`.

One thing to double-check: the tail call `hidden_states = self.mlp(hidden_states, **kwargs)` assumes `GatedDeltaNetMLP`'s `forward` accepts `**kwargs`. If its signature is still `forward(self, hidden_states: torch.Tensor)`, this will raise `TypeError` as soon as you pass CP keywords like `cp_rank`. Please verify the MLP signature and, if needed, either:

- Extend `GatedDeltaNetMLP.forward` to accept `**kwargs`, or
- Drop the kwargs here if they are not needed:

```diff
- hidden_states = self.mlp(hidden_states, **kwargs)
+ hidden_states = self.mlp(hidden_states)
```
119-126: Avoid mutable class-level attributes for `_no_split_modules` (RUF012).

`_no_split_modules` is currently a list, which Ruff flags as a mutable class attribute. Since this is effectively constant metadata, a tuple is safer and silences RUF012.

```diff
- _no_split_modules = ['GatedDeltaNetBlockCP']
+ _no_split_modules = ('GatedDeltaNetBlockCP',)
```
268-282: Convert `_tied_weights_keys` to a tuple to avoid mutable class attribute (RUF012).

Similar to `_no_split_modules`, `_tied_weights_keys` is defined as a list and flagged by Ruff. It's logically constant, so a tuple is more appropriate:

```diff
- _tied_weights_keys = ["lm_head.weight"]
+ _tied_weights_keys = ("lm_head.weight",)
```
373-383: Tidy tuple construction in the non-`return_dict` path (RUF005).

Ruff suggests avoiding tuple concatenation with `+` in favor of unpacking. You can simplify the non-`return_dict` branch as:

```diff
- if not return_dict:
-     output = (logits,) + outputs[1:]
-     return (loss,) + output if loss is not None else output
+ if not return_dict:
+     output = (logits, *outputs[1:])
+     if loss is not None:
+         return (loss, *output)
+     return output
```

This is a minor style/readability improvement and eliminates the RUF005 warning.
189-217: Add `stacklevel=2` to `warnings.warn` for clearer call sites.

Python's `warnings.warn` supports a `stacklevel` keyword argument, and using `stacklevel=2` is the standard way for a wrapper to make the warning point to the wrapper's caller rather than the warn() call itself. Apply the suggested change to lines 213-215:

```diff
 if output_attentions:
-    warnings.warn("`GatedDeltaNetModelCP` does not `output_attentions` now, setting it to `False`.")
+    warnings.warn(
+        "`GatedDeltaNetModelCP` does not `output_attentions` now, setting it to `False`.",
+        stacklevel=2,
+    )
     output_attentions = False
```
301-315: Apply the recommended exception re-raising patterns in the `generate` method.

The current code uses `raise exception` in the `else` branch, which obscures the original traceback. The recommended pattern is to use plain `raise` (no arguments) to re-raise and preserve the original traceback, or use explicit chaining `raise NewError(...) from original_exc` to raise a new exception while attaching the original as the cause.

Suggested fix:

- Replace `raise exception` with bare `raise`
- Change `raise AttributeError(...)` to `raise AttributeError(...) from err` to preserve the exception chain
- Rename the caught variable from `exception` to `err` for clarity

```diff
 def generate(self, *args, **kwargs):
     try:
         return super().generate(*args, **kwargs)
-    except AttributeError as exception:
-        if 'past_key_values' in str(exception):
+    except AttributeError as err:
+        if 'past_key_values' in str(err):
             raise AttributeError(
                 f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                 f"which is not supported for {self.__class__.__name__}. "
                 f"Try another generation strategy instead. "
                 f"For the available generation strategies, check this doc: "
                 f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-            )
+            ) from err
         else:
-            raise exception
+            raise
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (17)
- fla/layers/gated_deltanet_cp.py (1 hunks)
- fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1 hunks)
- fla/ops/gated_delta_rule/__init__.py (1 hunks)
- fla/ops/gated_delta_rule/chunk_cp.py (1 hunks)
- fla/ops/gated_delta_rule/cp_halo.py (1 hunks)
- tests/cp/README.md (1 hunks)
- tests/cp/benchmark_memory.py (1 hunks)
- tests/cp/test_cp_halo.py (1 hunks)
- tests/cp/test_cp_halo_unit.py (1 hunks)
- tests/cp/test_cp_stability.py (1 hunks)
- tests/cp/test_gdn_cp_backward.py (1 hunks)
- tests/cp/test_gdn_cp_forward.py (1 hunks)
- tests/cp/test_numerical_correctness.py (1 hunks)
- tests/cp/train_cp_real.py (1 hunks)
- tests/ops/cp/README.md (1 hunks)
- tests/ops/cp/test_gated_delta_cp_fake.py (1 hunks)
- tests/ops/cp/test_gated_delta_cp_real.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (15)
tests/cp/test_cp_stability.py (3)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
GatedDeltaNetForCausalLMCP(268-383)tests/cp/test_gdn_cp_backward.py (1)
setup(111-129)
fla/layers/gated_deltanet_cp.py (7)
fla/layers/utils.py (2)
get_unpad_data(75-98)pad_input(176-197)fla/modules/fused_norm_gate.py (1)
FusedRMSNormGated(997-1058)fla/modules/layernorm.py (1)
RMSNorm(1064-1111)fla/modules/convolution.py (1)
ShortConvolution(794-1009)fla/ops/gated_delta_rule/fused_recurrent.py (1)
fused_recurrent_gated_delta_rule(242-353)fla/ops/gated_delta_rule/chunk_cp.py (2)
chunk_gated_delta_rule_cp(171-235)forward(24-100)fla/ops/gated_delta_rule/cp_halo.py (3)
halo_exchange_and_extend(12-115)halo_exchange_and_extend_autograd(226-245)forward(128-164)
tests/cp/test_numerical_correctness.py (2)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
GatedDeltaNetForCausalLMCP(268-383)
tests/cp/test_gdn_cp_backward.py (4)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
GatedDeltaNetForCausalLMCP(268-383)fla/ops/gated_delta_rule/cp_halo.py (1)
backward(167-223)tests/cp/test_cp_stability.py (1)
setup(20-29)
tests/ops/cp/test_gated_delta_cp_real.py (2)
fla/ops/gated_delta_rule/chunk.py (2)
chunk_gated_delta_rule_bwd(72-149)chunk_gated_delta_rule_fwd(18-69)fla/utils.py (1)
assert_close(82-93)
tests/cp/test_cp_halo.py (1)
fla/ops/gated_delta_rule/cp_halo.py (1)
halo_exchange_and_extend(12-115)
tests/cp/test_cp_halo_unit.py (2)
fla/ops/gated_delta_rule/cp_halo.py (1)
halo_exchange_and_extend_autograd(226-245)tests/cp/test_gdn_cp_forward.py (2)
setup_dist(99-107)run_cp(110-170)
tests/cp/train_cp_real.py (4)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
GatedDeltaNetForCausalLMCP(268-383)fla/ops/gated_delta_rule/chunk_cp.py (1)
backward(105-168)fla/ops/gated_delta_rule/cp_halo.py (1)
backward(167-223)
fla/ops/gated_delta_rule/cp_halo.py (1)
fla/layers/gated_deltanet_cp.py (1)
forward(206-369)
tests/cp/test_gdn_cp_forward.py (2)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
GatedDeltaNetForCausalLMCP(268-383)
fla/ops/gated_delta_rule/__init__.py (1)
fla/ops/gated_delta_rule/cp_halo.py (2)
halo_exchange_and_extend(12-115)halo_exchange_and_extend_autograd(226-245)
fla/ops/gated_delta_rule/chunk_cp.py (7)
fla/modules/l2norm.py (3)
l2norm(266-271)l2norm_bwd(195-239)l2norm_fwd(149-192)fla/ops/common/chunk_delta_h.py (2)
chunk_gated_delta_rule_bwd_dhu(480-529)chunk_gated_delta_rule_fwd_h(431-477)fla/ops/common/chunk_o.py (3)
chunk_bwd_dqkwg(624-685)chunk_bwd_dv_local(577-621)chunk_fwd_o(485-522)fla/ops/common/chunk_scaled_dot_kkt.py (1)
chunk_scaled_dot_kkt_fwd(222-309)fla/ops/utils/cumsum.py (1)
chunk_local_cumsum(432-469)fla/ops/utils/solve_tril.py (1)
solve_tril(337-384)fla/ops/gated_delta_rule/chunk.py (2)
chunk_gated_delta_rule_fwd(18-69)chunk_gated_delta_rule_bwd(72-149)
tests/cp/benchmark_memory.py (3)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
GatedDeltaNetForCausalLMCP(268-383)fla/ops/gated_delta_rule/chunk_cp.py (1)
backward(105-168)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (6)
fla/layers/gated_deltanet_cp.py (2)
GatedDeltaNet(38-369)forward(206-369)fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
GatedDeltaNetConfig(9-103)fla/models/utils.py (1)
FLAGenerationMixin(385-462)fla/modules/layernorm.py (1)
RMSNorm(1064-1111)fla/models/modeling_layers.py (1)
GradientCheckpointingLayer(11-71)fla/ops/gated_delta_rule/chunk_cp.py (1)
forward(24-100)
tests/ops/cp/test_gated_delta_cp_fake.py (2)
fla/ops/gated_delta_rule/chunk.py (2)
chunk_gated_delta_rule_fwd(18-69)chunk_gated_delta_rule_bwd(72-149)fla/utils.py (1)
assert_close(82-93)
🪛 Flake8 (7.3.0)
fla/layers/gated_deltanet_cp.py
[error] 11-11: 'torch.distributed as dist' imported but unused
(F401)
[error] 18-18: 'fla.ops.gated_delta_rule.chunk_gated_delta_rule' imported but unused
(F401)
[error] 231-231: local variable 'cp_shard_start_idx' is assigned to but never used
(F841)
tests/cp/test_numerical_correctness.py
[error] 95-95: f-string is missing placeholders
(F541)
[error] 170-170: f-string is missing placeholders
(F541)
[error] 173-173: f-string is missing placeholders
(F541)
[error] 238-238: f-string is missing placeholders
(F541)
[error] 259-259: f-string is missing placeholders
(F541)
[error] 287-287: f-string is missing placeholders
(F541)
[error] 288-288: f-string is missing placeholders
(F541)
[error] 293-293: f-string is missing placeholders
(F541)
[error] 297-297: f-string is missing placeholders
(F541)
[error] 298-298: f-string is missing placeholders
(F541)
[error] 308-308: continuation line under-indented for visual indent
(E128)
tests/cp/test_cp_halo_unit.py
[error] 27-27: 'typing.Tuple' imported but unused
(F401)
tests/cp/train_cp_real.py
[error] 16-16: 'typing.Optional' imported but unused
(F401)
[error] 21-21: 'transformers.AutoTokenizer' imported but unused
(F401)
[error] 23-23: 'torch.distributed.nn.all_reduce' imported but unused
(F401)
[error] 192-192: local variable 'attention_mask' is assigned to but never used
(F841)
[error] 296-296: f-string is missing placeholders
(F541)
[error] 316-316: f-string is missing placeholders
(F541)
[error] 338-338: f-string is missing placeholders
(F541)
[error] 365-365: f-string is missing placeholders
(F541)
[error] 386-386: f-string is missing placeholders
(F541)
fla/ops/gated_delta_rule/cp_halo.py
[error] 97-97: multiple statements on one line (semicolon)
(E702)
[error] 97-97: multiple statements on one line (semicolon)
(E702)
tests/cp/test_gdn_cp_forward.py
[error] 20-20: 'typing.Optional' imported but unused
(F401)
fla/ops/gated_delta_rule/chunk_cp.py
[error] 8-8: 'fla.ops.common.chunk_delta_h.chunk_gated_delta_rule_bwd_dhu' imported but unused
(F401)
[error] 8-8: 'fla.ops.common.chunk_delta_h.chunk_gated_delta_rule_fwd_h' imported but unused
(F401)
[error] 9-9: 'fla.ops.common.chunk_o.chunk_bwd_dqkwg' imported but unused
(F401)
[error] 9-9: 'fla.ops.common.chunk_o.chunk_bwd_dv_local' imported but unused
(F401)
[error] 9-9: 'fla.ops.common.chunk_o.chunk_fwd_o' imported but unused
(F401)
[error] 10-10: 'fla.ops.common.chunk_scaled_dot_kkt.chunk_scaled_dot_kkt_fwd' imported but unused
(F401)
[error] 11-11: 'fla.ops.gated_delta_rule.wy_fast.prepare_wy_repr_bwd' imported but unused
(F401)
[error] 11-11: 'fla.ops.gated_delta_rule.wy_fast.recompute_w_u_fwd' imported but unused
(F401)
[error] 12-12: 'fla.ops.utils.chunk_local_cumsum' imported but unused
(F401)
[error] 12-12: 'fla.ops.utils.solve_tril' imported but unused
(F401)
tests/cp/benchmark_memory.py
[error] 152-152: f-string is missing placeholders
(F541)
[error] 180-180: f-string is missing placeholders
(F541)
[error] 205-205: f-string is missing placeholders
(F541)
[error] 206-206: f-string is missing placeholders
(F541)
[error] 207-207: f-string is missing placeholders
(F541)
🪛 markdownlint-cli2 (0.18.1)
tests/cp/README.md
24-24: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
38-38: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
59-59: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🪛 Ruff (0.14.5)
tests/cp/test_cp_stability.py
1-1: Shebang is present but file is not executable
(EXE001)
fla/layers/gated_deltanet_cp.py
98-98: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
105-105: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
107-107: Unused method argument: kwargs
(ARG002)
133-136: Avoid specifying long messages outside the exception class
(TRY003)
138-140: Avoid specifying long messages outside the exception class
(TRY003)
143-146: Avoid specifying long messages outside the exception class
(TRY003)
195-195: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
212-212: Unused method argument: output_attentions
(ARG002)
231-231: Local variable cp_shard_start_idx is assigned to but never used
Remove assignment to unused variable cp_shard_start_idx
(F841)
tests/cp/test_numerical_correctness.py
1-1: Shebang is present but file is not executable
(EXE001)
35-35: Avoid specifying long messages outside the exception class
(TRY003)
95-95: f-string without any placeholders
Remove extraneous f prefix
(F541)
170-170: f-string without any placeholders
Remove extraneous f prefix
(F541)
173-173: f-string without any placeholders
Remove extraneous f prefix
(F541)
186-186: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF003)
238-238: f-string without any placeholders
Remove extraneous f prefix
(F541)
259-259: f-string without any placeholders
Remove extraneous f prefix
(F541)
287-287: f-string without any placeholders
Remove extraneous f prefix
(F541)
288-288: f-string without any placeholders
Remove extraneous f prefix
(F541)
293-293: f-string without any placeholders
Remove extraneous f prefix
(F541)
297-297: f-string without any placeholders
Remove extraneous f prefix
(F541)
298-298: f-string without any placeholders
Remove extraneous f prefix
(F541)
tests/cp/test_gdn_cp_backward.py
1-1: Shebang is present but file is not executable
(EXE001)
134-134: Unpacked variable rank is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
148-148: Unpacked variable B is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
177-177: Avoid specifying long messages outside the exception class
(TRY003)
188-188: Avoid specifying long messages outside the exception class
(TRY003)
205-205: Avoid specifying long messages outside the exception class
(TRY003)
tests/ops/cp/test_gated_delta_cp_real.py
72-72: Unused method argument: world_size
(ARG002)
156-156: Unpacked variable o_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
169-169: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
195-195: Unpacked variable o_cp is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tests/cp/test_cp_halo.py
1-1: Shebang is present but file is not executable
(EXE001)
52-52: Unpacked variable k_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
52-52: Unpacked variable v_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tests/cp/test_cp_halo_unit.py
1-1: Shebang is present but file is not executable
(EXE001)
70-70: Unpacked variable B is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
70-70: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
84-84: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
151-151: Unused function argument: varlen
(ARG001)
168-168: Unpacked variable B is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
187-187: Unpacked variable k_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
187-187: Unpacked variable v_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
187-187: Unpacked variable cu_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tests/cp/train_cp_real.py
1-1: Shebang is present but file is not executable
(EXE001)
38-41: Avoid specifying long messages outside the exception class
(TRY003)
145-145: Unpacked variable batch_size is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
192-192: Local variable attention_mask is assigned to but never used
Remove assignment to unused variable attention_mask
(F841)
220-220: Consider moving this statement to an else block
(TRY300)
222-222: Do not catch blind exception: Exception
(BLE001)
290-290: Unpacked variable local_rank is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
296-296: f-string without any placeholders
Remove extraneous f prefix
(F541)
316-316: f-string without any placeholders
Remove extraneous f prefix
(F541)
338-338: f-string without any placeholders
Remove extraneous f prefix
(F541)
365-365: f-string without any placeholders
Remove extraneous f prefix
(F541)
386-386: f-string without any placeholders
Remove extraneous f prefix
(F541)
fla/ops/gated_delta_rule/cp_halo.py
97-97: Multiple statements on one line (semicolon)
(E702)
97-97: Multiple statements on one line (semicolon)
(E702)
tests/cp/test_gdn_cp_forward.py
1-1: Shebang is present but file is not executable
(EXE001)
110-110: Unused function argument: varlen
(ARG001)
133-133: Unpacked variable B is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
fla/ops/gated_delta_rule/chunk_cp.py
42-42: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
115-115: Unpacked variable T is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
177-177: PEP 484 prohibits implicit Optional
Convert to T | None
(RUF013)
195-195: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
200-200: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
208-211: Avoid specifying long messages outside the exception class
(TRY003)
213-216: Avoid specifying long messages outside the exception class
(TRY003)
tests/cp/benchmark_memory.py
1-1: Shebang is present but file is not executable
(EXE001)
152-152: f-string without any placeholders
Remove extraneous f prefix
(F541)
180-180: f-string without any placeholders
Remove extraneous f prefix
(F541)
205-205: f-string without any placeholders
Remove extraneous f prefix
(F541)
206-206: f-string without any placeholders
Remove extraneous f prefix
(F541)
207-207: f-string without any placeholders
Remove extraneous f prefix
(F541)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py
124-124: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
171-171: Avoid specifying long messages outside the exception class
(TRY003)
212-212: No explicit stacklevel keyword argument found
Set stacklevel=2
(B028)
221-221: Avoid specifying long messages outside the exception class
(TRY003)
223-223: Avoid specifying long messages outside the exception class
(TRY003)
271-271: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
306-312: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
306-312: Avoid specifying long messages outside the exception class
(TRY003)
314-314: Use raise without specifying exception name
Remove exception name
(TRY201)
374-374: Consider (logits, *outputs[1:]) instead of concatenation
Replace with (logits, *outputs[1:])
(RUF005)
375-375: Consider (loss, *output) instead of concatenation
Replace with (loss, *output)
(RUF005)
tests/ops/cp/test_gated_delta_cp_fake.py
50-50: Unpacked variable g_out is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
50-50: Unpacked variable A is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
128-128: Unpacked variable dh0_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
148-148: Unpacked variable o_i is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (21)
tests/cp/test_cp_stability.py (1)
58-78: LGTM: Core test logic is solid.

The test correctly:
- Partitions the sequence across CP ranks (lines 58-63)
- Synchronizes with barriers before each iteration (line 71)
- Passes CP parameters to the model (lines 74-75)
- Validates stability across multiple iterations
fla/ops/gated_delta_rule/cp_halo.py (2)
127-164: LGTM: Forward pass correctly saves context and delegates.

The forward method properly:
- Saves all necessary metadata for backward (lines 139-152)
- Delegates to the underlying exchange function (lines 154-159)
- Marks metadata tensors as non-differentiable (lines 162-163)
212-220: LGTM: Gradient accumulation correctly handles padding.

The backward pass properly:
- Receives tail gradients from the next rank (lines 213-215)
- Only accumulates the non-padded portion when `T < h` (lines 216-220)
- Uses in-place addition to accumulate gradients correctly (a simplified sketch of this pattern follows below)
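A simplified, illustrative sketch of that accumulation pattern; the buffer layout and padding convention here are assumptions, not the kernel's actual implementation:

```python
import torch
import torch.distributed as dist


def accumulate_halo_grad_sketch(dx: torch.Tensor, h: int, rank: int, world_size: int, group=None) -> torch.Tensor:
    """Receive the gradient of the tail h tokens from the next rank and add it
    into this rank's local gradient in place.

    dx: [B, T, D] gradient of the local (un-extended) chunk.
    When T < h the received buffer is partially padding, so only the valid
    trailing min(T, h) positions are accumulated.
    """
    B, T, D = dx.shape
    if rank + 1 < world_size:
        recv_buf = torch.zeros(B, h, D, dtype=dx.dtype, device=dx.device)
        dist.recv(recv_buf, src=rank + 1, group=group)
        n = min(T, h)
        dx[:, -n:].add_(recv_buf[:, -n:])  # accumulate only the non-padded portion
    return dx
```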
fla/ops/gated_delta_rule/__init__.py (1)
3-10: LGTM: Clean API extension.

The new halo exchange functions are properly imported and exported, following the existing pattern in this module.
tests/cp/test_cp_halo.py (1)
22-23: Verify that `halo_exchange_and_extend` is re-exported at the package level

This test imports:
from fla.ops.gated_delta_rule import halo_exchange_and_extendwhile the implementation lives in
fla.ops.gated_delta_rule.cp_halo. This assumeshalo_exchange_and_extendis re-exported fromfla.ops.gated_delta_rule.__init__. If that’s not the case, this import will fail at runtime, and you should either:
- add the appropriate re-export in
fla/ops/gated_delta_rule/__init__.py, or- import directly from
fla.ops.gated_delta_rule.cp_halo.Would you confirm in your environment whether
from fla.ops.gated_delta_rule import halo_exchange_and_extendworks, or whether you need to adjust the import?fla/ops/gated_delta_rule/chunk_cp.py (1)
170-186: Confirm `torch.compiler.disable` is available for your minimum PyTorch version

The public wrapper is decorated with:
```python
@torch.compiler.disable
def chunk_gated_delta_rule_cp(...):
    ...
```
torch.compiler.disableis relatively recent; on older PyTorch versions, this may raise anAttributeErrorat import time. If you target a wide range of versions, you may need a guarded import or to fall back totorch._dynamo.disable/ no decoration.Could you verify against your target PyTorch version(s) that
torch.compiler.disableexists and behaves as expected for this function?tests/ops/cp/test_gated_delta_cp_real.py (2)
14-44: Baton send/recv pattern is logically consistent

The forward and gradient baton helpers implement a simple chain: rank 0 seeds `h0`, later ranks `recv` from rank-1; the last rank seeds `dht`, earlier ranks `recv` from rank+1. This avoids circular dependencies and should not deadlock under the current call order. No changes required here.
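A minimal sketch of that baton chain using blocking `dist.recv`; names, shapes, and the default device are illustrative, and the matching sends happen after each rank finishes its local forward/backward:

```python
import torch
import torch.distributed as dist


def recv_h0_baton(shape, rank, group=None, dtype=torch.float32, device="cuda"):
    """Forward baton: rank 0 seeds the initial state, later ranks receive it from rank - 1."""
    if rank == 0:
        return torch.zeros(shape, dtype=dtype, device=device)
    h0 = torch.empty(shape, dtype=dtype, device=device)
    dist.recv(h0, src=rank - 1, group=group)
    return h0


def recv_dht_baton(shape, rank, world_size, group=None, dtype=torch.float32, device="cuda"):
    """Backward baton: the last rank seeds the gradient of the final state,
    earlier ranks receive it from rank + 1 before running their local backward."""
    if rank == world_size - 1:
        return torch.zeros(shape, dtype=dtype, device=device)
    dht = torch.empty(shape, dtype=dtype, device=device)
    dist.recv(dht, src=rank + 1, group=group)
    return dht
```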
h0, later ranksrecvfrom rank-1; the last rank seedsdht, earlier ranksrecvfrom rank+1. This avoids circular dependencies and should not deadlock under the current call order. No changes required here.
46-70: CP shard/unshard helpers look correct
cp_shardsplitting alongshard_dim=1withcp_degree or self.world_sizeandcp_unshardviatorch.catmatch the intended time‑dimension sharding in these tests. No functional issues spotted.tests/cp/test_cp_halo_unit.py (2)
105-149: `run_single` reference generation logic looks solid

The single-rank baseline generation covers both `h=1` and `h=3`, supports optional varlen via `cu_seqlens`, and properly detaches and moves tensors to CPU for saving. The gradient collection (`gx`, `gw`) and reset of `.grad` fields per `h` are correct. No changes needed.
243-260: CLI/main orchestration is fine

Argument parsing (`mode`, `cp_size`, `save_ref`, `varlen`, `backward`) and dispatch to `run_single` vs `run_cp` is straightforward and matches the documented usage examples in the module docstring. No functional issues here.

tests/cp/benchmark_memory.py (2)
46-129: Benchmark flow and memory accounting look reasonable

The `benchmark_config` function correctly isolates model/data allocations, captures memory usage after model creation, after data creation, after forward, and after backward, and computes deltas along those checkpoints. Error handling for `RuntimeError` and cleanup (`del` + `torch.cuda.empty_cache()`) are also appropriate for a benchmarking script. No functional issues stand out here.
161-213: Top-level loop and summary formatting are clear

The main loop over `seq_lens`, conditional printing on rank 0, aggregation into `results`, and the final summary table are all straightforward and easy to read. The script should be convenient to run for quick CP memory sweeps. No further changes needed here.

tests/cp/test_gdn_cp_forward.py (3)
32-57: Config and data factories look good
`create_config` and `create_data` encode a compact config for quick tests and support both equal-length and varlen via `cu_seqlens`. The seeding in `create_data` ensures deterministic inputs for reproducible parity checks. No issues here.
59-97: `run_single` reference generation is consistent with CP usage

The single-rank run:

- Covers two `conv_size` values (2 and 4).
- Stores full model `state_dict`, loss, logits, and inputs/labels (plus optional `cu_seqlens`) per conv size.
- Uses `cp_rank=0, cp_size=1` when calling the CP model, aligning with the CP API.

This sets up exactly the information `run_cp` needs to reconstruct models and compare logits. No functional problems spotted.
173-188: CLI entrypoint matches documented usage

The `main` function exposes `--mode`, `--cp_size`, `--save_ref`, and `--varlen` exactly as described in the module docstring and ties into `run_single`/`run_cp` correctly. No issues.
24-38: Distributed setup sanity check is good
`setup_distributed` initializes NCCL, binds the CUDA device via `LOCAL_RANK`, and explicitly enforces `world_size == cp_size`, which is exactly what you want for these CP tests. No changes needed here.
305-321: CLI wiring is straightforward

The `main` entrypoint cleanly dispatches between single-reference generation and CP correctness testing based on `--mode`, and passes `--cp_size`/`--save_ref` through appropriately. No issues here.

tests/cp/train_cp_real.py (1)
29-50: Distributed CP setup looks correct for CP-only runs.

The `setup_distributed_cp` helper correctly enforces `world_size == cp_size`, sets the CUDA device from `LOCAL_RANK`, and builds a single CP group over all ranks, which matches the "CP-only" assumption. No changes required here.
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (3)
130-172: Initialization logic is reasonable and matches GatedDeltaNet's parameterization.

The custom `_init_weights`:

- Re-initializes `GatedDeltaNet.A_log` and `dt_bias` in a way consistent with the layer's constructor.
- Uses standard normal init for Linear/Conv/Embedding.
- Optionally applies a prenorm residual strategy (`rescale` or `zero`) to `o_proj`/`down_proj`.

The logic looks consistent with the underlying layer semantics and the config, and should integrate cleanly with `post_init()`.
235-247: CP kwargs are correctly threaded through the encoder stack.The loop over
self.layerspasses**kwargsdown to eachGatedDeltaNetBlockCP, which in turn forwards them to the CP-aware attention layer. This is what enables CP rank/size/group to reach the halo and chunk kernels.This section looks consistent and doesn’t need changes.
316-372: Causal LM forward and CP param propagation are consistent with HF patterns.
`GatedDeltaNetForCausalLMCP.forward`:

- Delegates to `self.model` while threading through `**kwargs` for CP.
- Supports `logits_to_keep` with the standard "keep last K logits" semantics.
- Handles fused vs non-fused cross-entropy and optional L2 warp in line with the config.
- Shifts labels for next-token prediction consistent with other causal models in the codebase (see the sketch after this comment).
This section looks coherent and aligned with the rest of the modeling stack. No changes required here.
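As a reminder of the label-shifting convention referenced above, a generic sketch (not the model's actual loss code, which may use the fused cross-entropy path):

```python
import torch
import torch.nn.functional as F


def shifted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Standard next-token objective: position t's logits predict token t+1.
    Ignored positions are marked with -100, following the HF convention."""
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```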
Code context (fla/ops/gated_delta_rule/chunk_cp.py, around lines 50-66):

```python
# Handle initial state for CP
if cp_size > 1:
    if cp_rank == 0:
        h0_local = initial_state.to(torch.float32) if initial_state is not None else torch.zeros(
            (B, H, K, V), dtype=torch.float32, device=device
        )
    else:
        h0_local = torch.empty((B, H, K, V), dtype=torch.float32, device=device)
        recv_req = dist.irecv(h0_local, src=cp_rank - 1, group=cp_group)
        recv_req.wait()
else:
    h0_local = initial_state.to(torch.float32) if initial_state is not None else torch.zeros(
        (B, H, K, V), dtype=torch.float32, device=device
    )

# Force output_final_state=True for CP to get the final state
force_output_final_state = True if cp_size > 1 else output_final_state
```
Fix gradient propagation for initial_state in multi‑rank CP backward
In ChunkGatedDeltaRuleFunctionCP:
- Forward only uses the user-provided `initial_state` when `cp_size == 1` or `cp_rank == 0`. For `cp_rank > 0`, `h0_local` is filled via `dist.irecv` and the `initial_state` argument is ignored.
- Backward always computes `dh0` (grad wrt the local initial state), sends it to the previous rank (correct for the CP baton), and returns it as the gradient for the `initial_state` input whenever `initial_state.requires_grad` is True:
ctx.initial_state_requires_grad = initial_state is not None and initial_state.requires_grad
...
dh0_return = dh0 if ctx.initial_state_requires_grad else None
return ..., dh0_return, ...For cp_size > 1 and cp_rank > 0 this is incorrect: the user-provided initial_state was never used in that shard’s forward, so its gradient contribution from this shard must be zero. The only rank that should contribute a gradient to the input initial_state is rank 0 (or the single-rank case), with all CP shards’ influence already folded in via the dht baton and dh0 on rank 0.
You can fix this by restricting initial_state_requires_grad to the ranks that actually consume the input:
- ctx.output_final_state = output_final_state # Store original value
- ctx.initial_state_requires_grad = initial_state is not None and initial_state.requires_grad # Track if gradient needed
+ ctx.output_final_state = output_final_state # Store original value
+ # Only rank 0 (or single‑rank) actually consumes the user‑provided initial_state;
+ # later ranks receive their h0 via dist. Do not propagate grads for initial_state from other ranks.
+ ctx.initial_state_requires_grad = (
+ initial_state is not None
+ and initial_state.requires_grad
+ and (cp_size == 1 or cp_rank == 0)
+ )The existing dh0_return = dh0 if ctx.initial_state_requires_grad else None in backward can remain as-is once this flag is corrected.
Also applies to: 87-97, 117-139, 140-168
🤖 Prompt for AI Agents
In fla/ops/gated_delta_rule/chunk_cp.py around lines 50-66 (and similarly at
87-97, 117-139, 140-168), the bug is that ctx.initial_state_requires_grad is set
from the presence/requires_grad of the user-provided initial_state on all ranks,
but ranks with cp_size>1 and cp_rank>0 never actually use the user initial_state
in forward (they receive h0 via dist.irecv), so they must not report gradients
for that input; fix by setting ctx.initial_state_requires_grad = (initial_state
is not None and initial_state.requires_grad) AND (cp_size == 1 or cp_rank == 0)
when you save ctx in forward on all relevant blocks, so that backward only
returns dh0 for the input initial_state on the rank(s) that consumed it (rank 0
or single-rank), leaving other ranks to return None as gradient for that
argument.
Code context (tests/cp/test_cp_halo.py, lines 33-78):

```python
def test_basic_halo(cp_size: int, h: int = 2):
    """Test that rank r receives the tail from rank r-1 as its left halo."""
    rank, world_size = _init_dist('gloo')
    assert world_size == cp_size, f"world_size({world_size}) != cp_size({cp_size})"

    # Shapes
    B = 1
    T_local = 6
    Dq, Dk, Dv = 3, 4, 5

    # Create deterministic inputs per rank
    torch.manual_seed(2025 + rank)
    q_in = torch.randn(B, T_local, Dq)
    k_in = torch.randn(B, T_local, Dk)
    v_in = torch.randn(B, T_local, Dv)

    # No varlen here
    cu_seqlens = None

    q_ext, k_ext, v_ext, cu_ext = halo_exchange_and_extend(
        q_in, k_in, v_in, h,
        cp_rank=rank, cp_size=cp_size, cp_group=dist.group.WORLD,
        cu_seqlens=cu_seqlens, cp_shard_start_idx=None,
    )

    # Gather previous rank tails on ALL ranks to avoid deadlock
    q_tails = [torch.zeros(B, h, Dq) for _ in range(world_size)]
    local_tail = q_in[:, -h:, :].contiguous()
    dist.all_gather(q_tails, local_tail)

    # Rank 0 should have zero halo
    if rank == 0:
        assert torch.allclose(q_ext[:, :h], torch.zeros_like(q_ext[:, :h]))
        assert cu_ext is None
    # Rank > 0 should have halo equal to last h tokens of previous rank
    else:
        prev_tail = q_tails[rank - 1]
        assert torch.allclose(q_ext[:, :h, :], prev_tail, atol=1e-6), (
            f"Rank {rank}: halo mismatch vs prev rank tail"
        )

    if dist.is_initialized():
        dist.barrier()


def test_varlen_zero_halo(cp_size: int, h: int = 2):
```
Prevent script-style CP halo tests from breaking pytest discovery
Both test_basic_halo and test_varlen_zero_halo take a required cp_size argument and are meant to be driven via the main() entrypoint with torchrun. Under normal pytest collection, these names will still be treated as tests, and cp_size will be interpreted as a fixture. Since there is no cp_size fixture, the test suite will fail during collection.
If these are intended as manual / torchrun utilities rather than automated pytest tests, consider renaming them so they are not auto‑collected, and adjust main() accordingly, for example:
-def test_basic_halo(cp_size: int, h: int = 2):
+def run_basic_halo(cp_size: int, h: int = 2):
...
-def test_varlen_zero_halo(cp_size: int, h: int = 2):
+def run_varlen_zero_halo(cp_size: int, h: int = 2):
...
@@
- test_basic_halo(args.cp_size)
- test_varlen_zero_halo(args.cp_size)
+ run_basic_halo(args.cp_size)
+ run_varlen_zero_halo(args.cp_size)
That keeps the torchrun UX the same while preventing accidental CI failures from pytest treating these as normal unit tests.
Also applies to: 78-120
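If the functions are meant to stay torchrun-only drivers, another option (also mentioned in the prompt below) is a module-level collection guard instead of a rename; a sketch that assumes the torchrun launcher, which exports WORLD_SIZE per process — the cp_size argument still has to come from main()/argparse as it does now:

```python
import os

import pytest

# Skip the whole module under plain `pytest` collection; it only makes sense
# when launched via torchrun, which sets WORLD_SIZE / RANK / LOCAL_RANK.
if "WORLD_SIZE" not in os.environ:
    pytest.skip(
        "CP halo tests must be launched with torchrun, not collected by plain pytest",
        allow_module_level=True,
    )
```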
🧰 Tools
🪛 Ruff (0.14.5)
52-52: Unpacked variable k_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
52-52: Unpacked variable v_ext is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In tests/cp/test_cp_halo.py around lines 33 to 78 (and similarly lines 78 to
120), the functions test_basic_halo and test_varlen_zero_halo are defined as
pytest-discoverable tests but require a cp_size argument and are intended to be
run via the main()/torchrun entrypoint; rename them (e.g., to run_basic_halo and
run_varlen_zero_halo) or add a pytest skip/collection guard so pytest does not
treat them as tests during collection, and update the main() invocation to call
the new names so the torchrun UX is unchanged.
def setup(cp_size: int):
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend)
    rank = dist.get_rank()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    group = dist.new_group(list(range(cp_size)))
    return rank, device, group
🧩 Analysis chain
Verify group creation logic and LOCAL_RANK handling.
Two concerns:
- Line 24: Defaulting LOCAL_RANK to 0 can cause device collisions when multiple ranks run on the same node. In distributed scenarios without LOCAL_RANK set, this will fail or produce incorrect behavior.
- Line 28: dist.new_group(list(range(cp_size))) assumes the world consists of ranks 0 through cp_size-1. If world_size > cp_size, this creates a subgroup that might not match the intended CP group. Consider asserting world_size == cp_size earlier or clarifying the intended group membership.
def setup(cp_size: int):
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend)
+ world_size = dist.get_world_size()
+ assert world_size == cp_size, f"world_size ({world_size}) must equal cp_size ({cp_size})"
rank = dist.get_rank()
- local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ if torch.cuda.is_available():
+ local_rank = int(os.environ['LOCAL_RANK'])
+ else:
+ local_rank = 0
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
- group = dist.new_group(list(range(cp_size)))
+ group = dist.group.WORLD # Use full world as CP group when world_size == cp_size
return rank, device, group🏁 Script executed:
#!/bin/bash
# Locate and examine the test file and setup function
fd -t f "test_cp_stability.py" --exec cat {} \;Length of output: 2554
🏁 Script executed:
#!/bin/bash
# Search for all usages of the setup function to understand how cp_size is used
rg -n "setup\(" tests/ -A 2 -B 2 --type pyLength of output: 1221
🏁 Script executed:
#!/bin/bash
# Check how torch.distributed.new_group is used elsewhere in codebase
rg -n "new_group" --type py -C 3Length of output: 2917
🏁 Script executed:
#!/bin/bash
# Verify LOCAL_RANK environment variable usage patterns
rg -n "LOCAL_RANK" --type py -C 2Length of output: 2896
🏁 Script executed:
#!/bin/bash
# Check torch.distributed.group usage in codebase
rg -n "dist\.group\." --type py -C 2Length of output: 1053
🏁 Script executed:
#!/bin/bash
# Look for world_size assertions or comparisons
rg -n "world_size" --type py -C 2Length of output: 12271
🏁 Script executed:
#!/bin/bash
# Check if there are any comments or documentation about cp_size expectations
rg -n "cp_size" --type py -C 3Length of output: 58292
🏁 Script executed:
#!/bin/bash
# Verify torch.cuda.set_device usage patterns
rg -n "set_device" --type py -C 2Length of output: 2643
🏁 Script executed:
#!/bin/bash
# Check what dist module functions are commonly used after init_process_group
rg -n "init_process_group" --type py -A 5 -B 2Length of output: 4793
🏁 Script executed:
#!/bin/bash
# Verify get_world_size usage
rg -n "get_world_size" --type py -C 2Length of output: 1994
🏁 Script executed:
#!/bin/bash
# Check if os.environ with default vs mandatory key access for LOCAL_RANK
rg -n "os\.environ\.get.*LOCAL_RANK\|os\.environ\[.*LOCAL_RANK" --type py -C 2Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look for any documentation files explaining the distributed setup
fd -t f -e md -e rst -e txt | xargs grep -l "LOCAL_RANK\|cp_size" 2>/dev/null || echo "No docs found"Length of output: 18
🏁 Script executed:
#!/bin/bash
# Check PyTorch version requirements or constraints
fd -t f "requirements\|setup\|pyproject" | head -20Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look at how the test file structure - what tests use setup()
rg -n "def test_" tests/cp/test_cp_stability.py -A 3 || echo "No test functions found"Length of output: 23
🏁 Script executed:
#!/bin/bash
# Check if there are any fixtures or initialization that set up environments
rg -n "pytest\|@fixture\|LOCAL_RANK" tests/cp/ -C 2Length of output: 0
🏁 Script executed:
#!/bin/bash
# Verify if the diff's suggestion (dist.group.WORLD) is the right approach
python3 << 'EOF'
# Check PyTorch distributed API availability
try:
import torch.distributed as dist
print("Available in dist.group:")
print(dir(dist.group))
except Exception as e:
print(f"Error: {e}")
EOFLength of output: 30
🏁 Script executed:
#!/bin/bash
# Get full content of the test file for complete context
cat tests/cp/test_cp_stability.py 2>/dev/null || echo "File not found in tests/cp/"Length of output: 2554
🏁 Script executed:
#!/bin/bash
# Search more broadly for this file
find . -name "*test_cp*" -type f 2>/dev/null | head -20Length of output: 90
🏁 Script executed:
#!/bin/bash
# If file not found, search for similar patterns
fd -t f "stability" --type py | head -20Length of output: 169
🏁 Script executed:
#!/bin/bash
# Check repository structure
ls -la 2>/dev/null | head -20Length of output: 1188
🏁 Script executed:
#!/bin/bash
# Find test directories
fd -t d "tests" -o -t d "test" 2>/dev/null | head -10Length of output: 0
🏁 Script executed:
#!/bin/bash
# List all Python files in top-level test areas
fd -t f "\.py$" tests/ 2>/dev/null | head -20 || fd -t f "\.py$" test/ 2>/dev/null | head -20Length of output: 706
🏁 Script executed:
#!/bin/bash
# Search for any process group initialization pattern
rg -n "init_process_group\|new_group\|get_rank" --type py | head -30Length of output: 0
🏁 Script executed:
#!/bin/bash
# Look for distributed setup patterns
rg -n "def setup\(" --type py -A 15 | head -50Length of output: 2338
🏁 Script executed:
#!/bin/bash
# Check if there is a conftest or any setup file
fd -t f "conftest.py" -o -t f "setup.py"Length of output: 116
🏁 Script executed:
#!/bin/bash
# Broad search for cp_size pattern
rg "cp_size" --type pyLength of output: 11740
🏁 Script executed:
#!/bin/bash
# Search across all files for the exact function definition
rg -n "def setup\(cp_size" --type py -A 15Length of output: 2338
🏁 Script executed:
#!/bin/bash
# Get current working directory context
pwd && ls -laLength of output: 1567
🏁 Script executed:
#!/bin/bash
# Check if repo is cloned and list structure
find . -type d -name "tests" -o -type d -name "test" 2>/dev/nullLength of output: 7
🏁 Script executed:
#!/bin/bash
# Try to locate the file more flexibly
rg -l "test_cp_stability" 2>/dev/null || rg -l "def setup" tests/ --type py 2>/dev/null | head -5Length of output: 29
🏁 Script executed:
#!/bin/bash
# Get more info
git log --oneline -1 2>/dev/null || echo "Not a git repo or git unavailable"Length of output: 34
🏁 Script executed:
#!/bin/bash
# Check for PR details using GitHub CLI if available
gh pr view 2>/dev/null || echo "GitHub CLI not configured or not in PR context"Length of output: 46
🏁 Script executed:
#!/bin/bash
# Fallback: search everywhere for setup function with cp_size
rg -n "def setup\(" . --type py 2>/dev/null | grep -i cp || rg -n "cp_size" . --type py 2>/dev/null | head -10Length of output: 123
Add validation for world_size == cp_size and fix LOCAL_RANK handling for multi-node distributed training.
Two critical issues confirmed:
- Line 24: Defaulting LOCAL_RANK to 0 when CUDA is available risks device collisions. When running on multi-node clusters without LOCAL_RANK set, all ranks may attempt to use the same GPU. In contrast, other tests with distributed setup (test_gdn_cp_backward.py) require LOCAL_RANK as a mandatory environment variable when using GPUs.
- Missing validation: Unlike similar test files that validate world_size == cp_size, this setup function silently assumes all launched ranks belong to the CP group, which could mask misconfiguration.
def setup(cp_size: int):
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend)
+ world_size = dist.get_world_size()
+ assert world_size == cp_size, f"world_size ({world_size}) must equal cp_size ({cp_size})"
rank = dist.get_rank()
- local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ if torch.cuda.is_available():
+ local_rank = int(os.environ['LOCAL_RANK'])
+ else:
+ local_rank = 0
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
- group = dist.new_group(list(range(cp_size)))
+ group = dist.group.WORLD
return rank, device, group🤖 Prompt for AI Agents
In tests/cp/test_cp_stability.py around lines 20 to 29, validate that
dist.get_world_size() equals the provided cp_size and enforce proper LOCAL_RANK
handling for GPU runs: after init_process_group, call world_size =
dist.get_world_size() and raise a clear ValueError if world_size != cp_size; for
CUDA paths do not default LOCAL_RANK to 0 — require the LOCAL_RANK environment
variable to be set (use os.environ['LOCAL_RANK'] and raise an error if missing)
before parsing it to int and setting the device and
torch.cuda.set_device(local_rank).
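For reference, a strict device-resolution helper in this spirit (a sketch only; torchrun and most launchers export LOCAL_RANK per process) fails fast instead of letting every rank land on cuda:0:

```python
import os

import torch


def resolve_local_device() -> torch.device:
    """Pick this process's device, refusing to guess when LOCAL_RANK is missing."""
    if not torch.cuda.is_available():
        return torch.device("cpu")
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None:
        raise RuntimeError(
            "LOCAL_RANK is not set; launch with torchrun so each rank gets its own GPU."
        )
    device = torch.device(f"cuda:{int(local_rank)}")
    torch.cuda.set_device(device)
    return device
```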
import argparse
import os
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
# Import differentiable all_reduce
from torch.distributed.nn import all_reduce
Clean up unused imports and align with linting (F401).
Optional, AutoTokenizer, and all_reduce are imported but never used, which will trigger Flake8 (F401) and Ruff warnings. Since the script doesn’t currently rely on them, they should be removed.
You can apply something like:
-import argparse
-import os
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader, Dataset
-from transformers import AutoTokenizer
-# Import differentiable all_reduce
-from torch.distributed.nn import all_reduce
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader, Dataset📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| import argparse | |
| import os | |
| from typing import Optional | |
| import torch | |
| import torch.distributed as dist | |
| from torch.utils.data import DataLoader, Dataset | |
| from transformers import AutoTokenizer | |
| # Import differentiable all_reduce | |
| from torch.distributed.nn import all_reduce | |
| import argparse | |
| import os | |
| import torch | |
| import torch.distributed as dist | |
| from torch.utils.data import DataLoader, Dataset |
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 16-16: 'typing.Optional' imported but unused
(F401)
[error] 21-21: 'transformers.AutoTokenizer' imported but unused
(F401)
[error] 23-23: 'torch.distributed.nn.all_reduce' imported but unused
(F401)
🤖 Prompt for AI Agents
In tests/cp/train_cp_real.py around lines 14 to 23, there are unused imports:
Optional, AutoTokenizer, and all_reduce are imported but not used; remove these
three imports to satisfy linting (F401) and Ruff. Keep the necessary imports
(argparse, os, torch, torch.distributed as dist, DataLoader, Dataset) and update
any import lines accordingly so the file compiles with no unused import
warnings.
def compute_loss_with_cp(
    model: GatedDeltaNetForCausalLMCP,
    batch: dict,
    cp_rank: int,
    cp_size: int,
    cp_group,
    device: torch.device
) -> torch.Tensor:
    """
    Compute loss with context parallelism.

    Important: Each rank computes loss only on its chunk.
    NO aggregation needed - gradients flow naturally through autograd!
    """
    # Move to device
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    # Debug: Check tensor shapes and values
    if torch.any(torch.isnan(input_ids.float())):
        print(f"Rank {cp_rank}: NaN detected in input_ids")
        return torch.tensor(0.0, device=device, requires_grad=True)

    try:
        # Forward pass (each rank processes its chunk)
        # Disable cu_seqlens for CP to avoid conflicts
        outputs = model(
            input_ids=input_ids,
            attention_mask=None,  # Disable attention_mask for now to avoid cu_seqlens
            labels=labels,
            cp_rank=cp_rank,
            cp_size=cp_size,
            cp_group=cp_group,
            cp_shard_start_idx=batch.get('chunk_start', 0),  # important for halo varlen safety
        )

        loss = outputs.loss

        # Check for NaN loss
        if torch.isnan(loss):
            print(f"Rank {cp_rank}: NaN loss detected")
            return torch.tensor(0.0, device=device, requires_grad=True)

        return loss

    except Exception as e:
        print(f"Rank {cp_rank}: Error in forward pass: {e}")
        return torch.tensor(0.0, device=device, requires_grad=True)
Remove dead attention_mask assignment and reconsider broad Exception catch.
Two issues here:
- Dead local attention_mask (Flake8 F841)
  You assign attention_mask = batch['attention_mask'].to(device) but always pass attention_mask=None to the model. This is confusing and trips F841. Suggested change:
-    input_ids = batch['input_ids'].to(device)
-    attention_mask = batch['attention_mask'].to(device)
-    labels = batch['labels'].to(device)
+    input_ids = batch['input_ids'].to(device)
+    labels = batch['labels'].to(device)
  (You can re-thread attention_mask later when CP + cu_seqlens interaction is resolved.)
- Overly broad except Exception (Ruff BLE001)
  Catching all exceptions and silently returning a zero loss will hide real bugs in the CP stack and can make CI/debugging very hard. At minimum, consider re-raising after logging:
-    except Exception as e:
-        print(f"Rank {cp_rank}: Error in forward pass: {e}")
-        return torch.tensor(0.0, device=device, requires_grad=True)
+    except Exception as e:
+        print(f"Rank {cp_rank}: Error in forward pass: {e}")
+        raise
If you still want "best-effort" training, narrow the exception type (e.g., RuntimeError) and document the trade-off.
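One way to package that best-effort trade-off is a small helper (purely illustrative — the name and placement are not part of the current file) that swallows only RuntimeError and lets everything else propagate:

```python
import torch


def best_effort_loss(forward_fn, cp_rank: int, device: torch.device) -> torch.Tensor:
    """Run forward_fn(); on RuntimeError only (e.g. a transient OOM), log loudly and
    fall back to a zero loss for this rank. Any other exception still raises."""
    try:
        return forward_fn()
    except RuntimeError as e:
        print(f"Rank {cp_rank}: RuntimeError in forward pass, skipping step: {e}")
        return torch.tensor(0.0, device=device, requires_grad=True)


# Usage inside compute_loss_with_cp would look roughly like:
# loss = best_effort_loss(lambda: model(input_ids=input_ids, labels=labels, ...).loss,
#                         cp_rank, device)
```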
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 192-192: local variable 'attention_mask' is assigned to but never used
(F841)
🪛 Ruff (0.14.5)
192-192: Local variable attention_mask is assigned to but never used
Remove assignment to unused variable attention_mask
(F841)
220-220: Consider moving this statement to an else block
(TRY300)
222-222: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
In tests/cp/train_cp_real.py around lines 176-225, remove the dead local
assignment attention_mask = batch['attention_mask'].to(device) (or actually pass
that variable into model if you intend to use it later) so you don't trigger
Flake8 F841, and replace the broad except Exception: block with a narrower
handler (e.g., except RuntimeError as e:) that logs the error and re-raises it
(or if you intentionally want best-effort behavior, document it and only return
a zero loss for that specific, well-justified exception), avoiding swallowing
all exceptions silently.
# Training loop
model.train()
total_loss = 0.0
step = start_step
optimizer.zero_grad()

while step < args.num_steps:
    for batch in dataloader:
        if step >= args.num_steps:
            break

        # Debug: print chunk info (only rank 0, only first batch)
        if rank == 0 and step == start_step:
            print(f"📦 Batch info:")
            print(f"   input_ids shape: {batch['input_ids'].shape}")
            print(f"   Chunk range: [{batch['chunk_start']}:{batch['chunk_end']}]")
            print(f"   Full sequence length: {batch['full_seq_len']}\n")

        # Forward and backward
        loss = compute_loss_with_cp(
            model=model,
            batch=batch,
            cp_rank=cp_rank,
            cp_size=cp_size,
            cp_group=cp_group,
            device=device
        )

        # Scale loss for gradient accumulation
        loss = loss / args.gradient_accumulation_steps
        loss.backward()

        total_loss += loss.item()

        # Update weights
        if (step + 1) % args.gradient_accumulation_steps == 0:
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)

            # Optimizer step
            optimizer.step()
            optimizer.zero_grad()

        # Logging (rank 0 only)
        if rank == 0 and (step + 1) % 10 == 0:
            avg_loss = total_loss * args.gradient_accumulation_steps / 10
            print(f"Step {step + 1}/{args.num_steps} | Loss: {avg_loss:.4f}")
            total_loss = 0.0
🧩 Analysis chain
(Verification: rg searches for all_reduce, DistributedDataParallel, broadcast, and all_gather across tests/cp/ and the model code found no other mechanism that synchronizes gradients or parameters across cp_group.)
Add gradient all-reduce across CP ranks – current training desynchronizes model replicas.
Right now, each rank computes loss.backward() on its own chunk and steps the optimizer independently. There is no gradient synchronization across cp_group, so parameters on different ranks will diverge, which breaks the "single logical model with context parallelism" assumption.
Given cp_size == world_size and you already have cp_group, you should all-reduce gradients once per optimizer step (after gradient accumulation, before clipping/step). For example:
- # Update weights
- if (step + 1) % args.gradient_accumulation_steps == 0:
- # Gradient clipping
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
-
- # Optimizer step
- optimizer.step()
- optimizer.zero_grad()
+ # Update weights
+ if (step + 1) % args.gradient_accumulation_steps == 0:
+ # All-reduce gradients across CP ranks so training matches single-GPU behavior
+ if cp_size > 1:
+ for param in model.parameters():
+ if param.grad is not None:
+ dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
+ param.grad.div_(cp_size)
+
+ # Gradient clipping
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
+
+ # Optimizer step
+ optimizer.step()
+ optimizer.zero_grad()
Also update the docstring in compute_loss_with_cp (line 188) from "NO aggregation needed - gradients flow naturally through autograd!" to clarify that explicit gradient aggregation happens in the training loop, not magically through autograd.
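If you would rather keep the loop body short, the same synchronization can live in a helper that runs once per optimizer step (a sketch only; the helper name is illustrative, and ReduceOp.AVG assumes an NCCL process group — other backends fall back to SUM plus divide):

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def sync_cp_gradients(model: torch.nn.Module, cp_group, cp_size: int) -> None:
    """Average gradients across CP ranks.

    Call exactly once per optimizer step: after the last backward of the
    gradient-accumulation window, before clipping and optimizer.step().
    """
    if cp_size <= 1:
        return
    use_avg = dist.get_backend(cp_group) == "nccl"
    for param in model.parameters():
        if param.grad is None:
            continue
        if use_avg:
            dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=cp_group)
        else:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
            param.grad.div_(cp_size)
```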
🧰 Tools
🪛 Flake8 (7.3.0)
[error] 386-386: f-string is missing placeholders
(F541)
🪛 Ruff (0.14.5)
386-386: f-string without any placeholders
Remove extraneous f prefix
(F541)
🤖 Prompt for AI Agents
In tests/cp/train_cp_real.py around lines 373 to 421, the training loop calls
loss.backward() per CP rank but never synchronizes gradients across cp_group,
causing model replicas to diverge; after gradient accumulation completes and
before gradient clipping and optimizer.step(), perform an all-reduce across
cp_group over each parameter gradient (torch.distributed.all_reduce with SUM)
and divide by cp_size (or use AVG) so gradients are synchronized across CP
ranks; then continue with clip_grad_norm_ and optimizer.step()/zero_grad(). Also
update the docstring in compute_loss_with_cp (around line 188) to replace "NO
aggregation needed - gradients flow naturally through autograd!" with a note
that explicit gradient aggregation is performed in the training loop via
distributed all-reduce across cp_group.
import pytest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from dtest import DTest
Fix dtest import to unblock CI
CI is currently failing with ModuleNotFoundError: No module named 'dtest' when importing this module. To avoid hard‑depending on dtest and to make the tests fail gracefully when it’s absent, consider guarding the import and skipping the module if dtest is not installed:
-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from dtest import DTest
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+try:
+ from dtest import DTest
+except ModuleNotFoundError:
+ pytest.skip("dtest not available; skipping CP GatedDeltaNet tests", allow_module_level=True)
If dtest is expected to be available in CI, alternatively ensure it is installed or fix the import path instead, but in any case the current unconditional import is a blocker.
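The pytest.importorskip route mentioned in the prompt below is equally compact; a minimal sketch:

```python
import pytest

# Skips this module at collection time when dtest is not installed;
# otherwise DTest is available exactly as with the direct import.
dtest = pytest.importorskip("dtest")
DTest = dtest.DTest
```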
🤖 Prompt for AI Agents
In tests/ops/cp/test_gated_delta_cp_real.py lines 1-6, the unconditional import
of dtest causes CI failures when that package is missing; replace it with a
guarded import so the module is skipped when dtest is unavailable. Use
pytest.importorskip("dtest") (or a try/except ImportError that calls pytest.skip
at import time) and then reference DTest from the imported module, so tests fail
gracefully when dtest is not installed.
I don't have permission to push to the head branch, feel free to merge into main.
Summary by CodeRabbit
New Features
Documentation
Tests