Skip to content

Conversation

@MatthewGleeson
Copy link

@MatthewGleeson MatthewGleeson commented Oct 31, 2025

Summary by CodeRabbit

  • New Features

    • Added context-parallel training and inference support for distributed multi-GPU execution
    • Introduced distributed halo exchange utilities for optimized sequence processing across ranks
    • Added memory benchmarking tools for analyzing GPU memory usage across configurations
  • Documentation

    • Added comprehensive guides for context-parallel training setup and usage
  • Tests

    • Added validation suites for numerical correctness, forward/backward parity, and distributed stability
    • Included memory benchmarking and realistic training examples for context-parallel workflows

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 31, 2025

Walkthrough

This PR introduces comprehensive context-parallel (CP) support for the GatedDeltaNet model, enabling efficient distributed training across multiple GPU ranks. It adds CP-enabled layer and model implementations, distributed halo exchange operations with autograd support, and extensive test coverage including forward/backward parity validation, memory benchmarking, and training examples.

Changes

Cohort / File(s) Summary
Core CP Layer Implementation
fla/layers/gated_deltanet_cp.py
Implements GatedDeltaNet layer with context-parallel support, including gating mechanisms, short convolution branches, halo exchange for CP, dual kernel modes (chunk/fused_recurrent), and comprehensive state handling.
CP Model Architecture
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py
Adds GatedDeltaNetBlockCP, GatedDeltaNetPreTrainedModel, GatedDeltaNetModelCP, and GatedDeltaNetForCausalLMCP classes with residual connections, optional fusion paths, generation support, and CP parameter propagation.
CP Halo Exchange Operations
fla/ops/gated_delta_rule/cp_halo.py, fla/ops/gated_delta_rule/chunk_cp.py, fla/ops/gated_delta_rule/__init__.py
Implements distributed halo exchange utilities (HaloExchangeAndExtendFn, halo_exchange_and_extend), chunk gated delta rule CP variant (ChunkGatedDeltaRuleFunctionCP), and public exports.
Layer-Level CP Testing
tests/cp/test_cp_halo.py, tests/cp/test_cp_halo_unit.py, tests/cp/test_cp_stability.py
Unit and integration tests for halo exchange with fixed/variable-length sequences, autograd validation, and stability smoke tests across distributed ranks.
Model-Level CP Validation
tests/cp/test_gdn_cp_forward.py, tests/cp/test_gdn_cp_backward.py, tests/cp/test_numerical_correctness.py
Forward/backward parity tests between single-rank and multi-rank CP execution, embedding and parameter gradient validation, and end-to-end numerical correctness checks.
Operators Testing
tests/ops/cp/test_gated_delta_cp_fake.py, tests/ops/cp/test_gated_delta_cp_real.py
Fake (simulated) and real distributed tests for chunk gated delta rule CP with forward/backward pass validation across sharded partitions.
Benchmarking & Training
tests/cp/benchmark_memory.py, tests/cp/train_cp_real.py
Memory usage profiling across CP configurations and realistic multi-rank training example with custom data collation, checkpointing, and gradient accumulation.
Documentation
tests/cp/README.md, tests/ops/cp/README.md
CP concept overview, quick-start commands, key insights, and validation workflow documentation.

Sequence Diagram(s)

sequenceDiagram
    participant R0 as Rank 0
    participant R1 as Rank 1
    participant R2 as Rank 2
    
    Note over R0,R2: Forward Pass with Halo Exchange
    R0->>R0: Compute local<br/>chunk
    R0->>R1: Send tail h tokens<br/>(right halo)
    R1->>R1: Receive left halo<br/>from R0
    R1->>R1: Compute local<br/>chunk + halo
    R1->>R2: Send tail h tokens
    R2->>R2: Receive left halo<br/>from R1
    R2->>R2: Compute local<br/>chunk + halo
    
    Note over R0,R2: Backward Pass (Gradient Flow)
    R2->>R2: Compute local<br/>gradients
    R2->>R1: Send gradient<br/>baton (dh)
    R1->>R1: Receive dh,<br/>compute grads
    R1->>R0: Send gradient<br/>baton (dh)
    R0->>R0: Receive dh,<br/>compute grads
Loading
sequenceDiagram
    participant Model as GatedDeltaNetForCausalLMCP
    participant Block as GatedDeltaNetBlockCP
    participant Attn as GatedDeltaNet (CP)
    participant Halo as halo_exchange_and_extend_autograd
    
    Model->>Block: forward(input_ids, cp_rank, cp_size, cp_group)
    activate Block
    Block->>Attn: forward(hidden_states, cp_rank, cp_size, cp_group)
    activate Attn
    Attn->>Halo: halo_exchange_and_extend_autograd(q, k, v, h)
    activate Halo
    Halo-->>Attn: (q_ext, k_ext, v_ext, cu_seqlens_ext)
    deactivate Halo
    Attn->>Attn: chunk_gated_delta_rule_cp<br/>(with halo-extended inputs)
    Attn-->>Block: (attn_output)
    deactivate Attn
    Block-->>Model: (block_output)
    deactivate Block
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • fla/ops/gated_delta_rule/cp_halo.py: Distributed communication primitives (send/recv patterns), tensor concatenation logic for halo boundaries, and autograd backward pass synchronization across ranks—critical for correctness in multi-GPU scenarios.
  • fla/ops/gated_delta_rule/chunk_cp.py: ChunkGatedDeltaRuleFunctionCP forward/backward implementation with inter-rank state propagation, force_output_final_state behavior, and optional L2 norm corrections—complex gradient flow.
  • fla/layers/gated_deltanet_cp.py: Integration of halo exchange with short convolution paths, handling of cu_seqlens transformations, dual kernel modes with proper guards, and state management across CP boundaries.
  • fla/models/gated_deltanet/modeling_gated_deltanet_cp.py: Model architecture with residual connections, parameter propagation through CP-aware kwargs, weight initialization strategies, and generation interface compatibility.
  • tests/cp/test_gdn_cp_forward.py, test_gdn_cp_backward.py: Distributed test setup, reference generation/broadcasting, input sharding logic, and gradient aggregation with tolerance validation—ensure parity logic is sound.

Possibly related PRs

  • [L2Norm] Speedup by saving rstd #506: Touches L2-normalization integration and use_qk_l2norm_in_kernel kernel flag used by gated-delta chunk/fused_recurrent paths; the related PR changes rstd propagation for backward while this PR enables and uses per-kernel q/k L2-norm in CP variants.

Suggested reviewers

  • yzhangcs

Poem

🐰 Context parallel paths now align,
Halo halos dance in ranks of nine!
Tensors sharded, gradients flow,
DeltaNet scales—watch it grow! 🚀
Distributed dreams, no longer a fuss,
Context wins; trust us, trust us! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 32.18% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Title check ❓ Inconclusive The title 'Mattg context parallel halo' is too vague and generic, using non-descriptive terms without clearly conveying what was changed. While it mentions 'context parallel' and 'halo', it lacks specificity about the nature of the changes (implementation, feature, fix, etc.). Revise the title to be more descriptive and specific. For example: 'Add context-parallel halo exchange for GatedDeltaNet' or 'Implement context-parallel support with halo exchange for gated delta rule' would better convey the scope and purpose of the changes.
✅ Passed checks (1 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

📝 Customizable high-level summaries are now available in beta!

You can now customize how CodeRabbit generates the high-level summary in your pull requests — including its content, structure, tone, and formatting.

  • Provide your own instructions using the high_level_summary_instructions setting.
  • Format the summary however you like (bullet lists, tables, multi-section layouts, contributor stats, etc.).
  • Use high_level_summary_in_walkthrough to move the summary from the description to the walkthrough section.

Example instruction:

"Divide the high-level summary into five sections:

  1. 📝 Description — Summarize the main change in 50–60 words, explaining why this PR is needed, why this solution was chosen, and what was done.
  2. 📓 References — List relevant issues, discussions, documentation, or related PRs.
  3. 📦 Dependencies & Requirements — Mention any new/updated dependencies, environment variable changes, or configuration updates.
  4. 📊 Contributor Summary — Include a Markdown table showing contributions:
    | Contributor | Lines Added | Lines Removed | Files Changed |
  5. ✔️ Additional Notes — Add any extra reviewer context.
    Keep each section concise (under 200 words) and use bullet or numbered lists for clarity."

Note: This feature is currently in beta for Pro-tier users, and pricing will be announced later.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@sustcsonglin
Copy link
Collaborator

Hi, thanks for the PR! could you fix the following CI test error?

==================================== ERRORS ====================================
__________ ERROR collecting tests/ops/cp/test_gated_delta_cp_real.py ___________
ImportError while importing test module '/home/runner/work/flash-linear-attention/flash-linear-attention/tests/ops/cp/test_gated_delta_cp_real.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/opt/hostedtoolcache/Python/3.10.18/x64/lib/python3.10/importlib/init.py:126: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
tests/ops/cp/test_gated_delta_cp_real.py:5: in
from dtest import DTest
E ModuleNotFoundError: No module named 'dtest'

@sustcsonglin sustcsonglin force-pushed the main branch 2 times, most recently from 1700f8d to f4082b3 Compare November 11, 2025 16:49
@sustcsonglin sustcsonglin marked this pull request as ready for review November 17, 2025 20:13
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Nitpick comments (34)
tests/cp/test_cp_stability.py (1)

1-1: Consider making the file executable or removing the shebang.

The static analysis tool notes that the shebang is present but the file isn't executable.

To fix, either:

chmod +x tests/cp/test_cp_stability.py

Or remove the shebang if the file is only meant to be run via torchrun.

fla/ops/gated_delta_rule/cp_halo.py (2)

94-97: Fix code style: split multiple statements onto separate lines.

Line 97 contains three statements separated by semicolons, which violates PEP 8 style guidelines and is flagged by linters.

Apply this diff:

             seq_boundaries = cu_seqlens[1:-1]
             if seq_boundaries.numel() > 0 and bool((seq_boundaries == cp_shard_start_idx).any().item()):
-                q_halo.zero_(); k_halo.zero_(); v_halo.zero_()
+                q_halo.zero_()
+                k_halo.zero_()
+                v_halo.zero_()

202-208: Fix code style: split statement onto separate line.

Line 208 contains send_buf.zero_() which should be on its own line for consistency with the rest of the codebase.

Apply this diff:

                 if cu_seqlens_saved is not None:
                     seq_boundaries = cu_seqlens_saved[1:-1]
                     if seq_boundaries.numel() > 0 and bool((seq_boundaries == ctx.cp_shard_start_idx).any().item()):
-                        send_buf.zero_()
+                        send_buf.zero_()

Note: This is more about maintaining consistency with the fix suggested for line 97.

tests/ops/cp/README.md (1)

1-2: Documentation addresses the CI import error.

This README explains the ModuleNotFoundError: No module named 'dtest' mentioned in the PR comments. Users need to install the DTest package from the referenced repository to run these distributed tests.

Consider adding a note in the main repository README or CI documentation about this optional dependency requirement.

tests/cp/README.md (3)

24-32: Add language specifier to code block.

The code block should specify bash as the language for proper syntax highlighting and markdown compliance.

Apply this diff:

-```
+```bash

 torchrun --nproc_per_node=8 tests/cp/train_cp_real.py --cp_size 8

38-46: Add language specifier to code block.

The code block should specify bash as the language for proper syntax highlighting.

Apply this diff:

-```
+```bash

 # Single GPU baseline
 python tests/cp/benchmark_memory.py --cp_size 1

59-67: Add language specifier to code block.

The code block should specify bash as the language for proper syntax highlighting.

Apply this diff:

-```
+```bash

 # Step 1: Generate reference on single GPU
 python test_numerical_correctness.py --mode single --save_ref
tests/ops/cp/test_gated_delta_cp_fake.py (1)

22-28: Clarify cp_shard default and tidy unused temporaries

cp_shard currently does:

return tensor.tensor_split(cp_degree or self.cp_degree, dim=shard_dim)

but self.cp_degree is never defined on TestFakeCP. Right now all callsites pass cp_degree, so this doesn’t break, but it’s a latent bug if cp_shard is ever reused without that argument.

You can either (a) store a self.cp_degree attribute on the class, or (b) make cp_degree required and drop the fallback, e.g.:

-    def cp_shard(
-        self,
-        tensor: torch.Tensor,
-        cp_degree: int | None = None,
-        shard_dim: int | None = 1,
-    ) -> torch.Tensor:
-        return tensor.tensor_split(cp_degree or self.cp_degree, dim=shard_dim)
+    def cp_shard(
+        self,
+        tensor: torch.Tensor,
+        cp_degree: int,
+        shard_dim: int = 1,
+    ) -> tuple[torch.Tensor, ...]:
+        return tensor.tensor_split(cp_degree, dim=shard_dim)

Separately, the unpacked locals g_out, A, dh0_ref, and o_i are only used to reach later outputs. If you want Ruff to be clean in strict mode, consider prefixing the unused ones with _:

-        g_out, o_ref, A, _ = chunk_gated_delta_rule_fwd(...)
+        _g_out, o_ref, A, _ = chunk_gated_delta_rule_fwd(...)
...
-        dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd(...)
+        dq_ref, dk_ref, dv_ref, db_ref, dg_ref, _dh0_ref = chunk_gated_delta_rule_bwd(...)
...
-            g_out_i, o_i, A_i, initial_state = chunk_gated_delta_rule_fwd(...)
+            g_out_i, _o_i, A_i, initial_state = chunk_gated_delta_rule_fwd(...)

Also applies to: 50-56, 128-128, 148-148

fla/layers/gated_deltanet_cp.py (2)

10-21: Remove unused imports and local to keep the module lint‑clean

In this module:

  • torch.distributed as dist is imported but not used.
  • chunk_gated_delta_rule is imported from fla.ops.gated_delta_rule but never referenced (only fused_recurrent_gated_delta_rule and chunk_gated_delta_rule_cp are used).
  • cp_shard_start_idx = kwargs.get('cp_shard_start_idx', None) is assigned but never used; later calls re‑fetch cp_shard_start_idx from kwargs directly.

These don’t change behavior but will trip Flake8 (F401/F841). A minimal cleanup would be:

-import torch.distributed as dist
...
-from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
+from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
...
-        cp_shard_start_idx = kwargs.get('cp_shard_start_idx', None)
+        # cp_shard_start_idx is only threaded through to halo helpers via kwargs

Also applies to: 92-107, 227-232


92-107: Align type hints with None defaults for num_v_heads and layer_idx

__init__ currently declares:

num_v_heads: int = None,
...
layer_idx: int = None,

With Ruff’s RUF013 and standard typing conventions, these should be explicitly optional:

-        num_v_heads: int = None,
+        num_v_heads: Optional[int] = None,
...
-        layer_idx: int = None,
+        layer_idx: Optional[int] = None,

This keeps the signature honest for type checkers without changing runtime behavior.

fla/ops/gated_delta_rule/chunk_cp.py (1)

1-17: Trim unused imports and align a couple of type hints

Top of file, several heavy operators are imported but never used in the CP wrapper:

from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
from fla.ops.utils import chunk_local_cumsum, solve_tril

Given ChunkGatedDeltaRuleFunctionCP delegates entirely to chunk_gated_delta_rule_fwd / chunk_gated_delta_rule_bwd, these imports can be dropped to keep the module light and avoid Flake8 F401:

-from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
-from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
-from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o
-from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
-from fla.ops.gated_delta_rule.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
-from fla.ops.utils import chunk_local_cumsum, solve_tril
-from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
+from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard

A couple of small typing nits if you want Ruff’s RUF013 clean:

-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,

and similarly keep the local unpackings for T as _T if you want to silence “unused variable” warnings:

-        B, T, H, K = q.shape
-        B, T, H, V = v.shape
+        B, _T, H, K = q.shape
+        B, _T, H, V = v.shape
...
-        B, T, H, K = q.shape
-        B, T, H, V = v.shape
+        B, _T, H, K = q.shape
+        B, _T, H, V = v.shape

These are non-functional, but they keep the file tidy.

Also applies to: 42-43, 115-116, 171-186

tests/cp/test_gdn_cp_backward.py (1)

111-129: Minor cleanups for unused locals in CP backward test harness

The CP parity harness is logically sound, but a couple of locals are never used:

  • rank from setup is returned and then not used in cp.
  • In cp, B, T = input_ids.shape only T is used.

If you run Ruff in stricter modes, these will trigger RUF059. You can either prefix them with _ or drop them:

-def setup(cp_size: int):
+def setup(cp_size: int):
@@
-    rank = dist.get_rank()
+    rank = dist.get_rank()
@@
-    return rank, device, group
+    return rank, device, group
...
-def cp(cp_size: int):
+def cp(cp_size: int):
@@
-    rank, device, group = setup(cp_size)
+    _rank, device, group = setup(cp_size)
@@
-    B, T = input_ids.shape
+    _B, T = input_ids.shape

Purely cosmetic, but it keeps linters quiet.

Also applies to: 132-136, 148-151

tests/ops/cp/test_gated_delta_cp_real.py (2)

71-131: Use world_size argument for a sanity check (also fixes Ruff ARG002)

test_fwd takes world_size but never uses it, which triggers Ruff’s ARG002 and also misses an easy safety check. You can both satisfy the linter and assert the distributed setup matches the mark:

-    @pytest.mark.world_size([2, 4, 8])
-    def test_fwd(self, world_size: int) -> None:
+    @pytest.mark.world_size([2, 4, 8])
+    def test_fwd(self, world_size: int) -> None:
         q = torch.rand(self.B, self.T, self.H, self.D, dtype=self.dtype)
         ...
-        # ---- reference forward (no CP) ----
+        # ---- reference forward (no CP) ----
+        assert dist.get_world_size() == world_size

This keeps the fixture useful and avoids the unused‑arg warning.


132-233: Clean up minor duplication and unused values in test_bwd

The backward test logic looks sound (reference path vs CP path with baton gradients), but there are a few small nits you may want to tidy:

  • qn and kn are computed twice (Lines 145–147 and again 151–153); the second pair is redundant.
  • o_ref and dh0_ref from chunk_gated_delta_rule_fwd / chunk_gated_delta_rule_bwd aren’t used; similarly o_cp from the CP forward is unused. You can prefix them with _ to silence RUF059 or replace with _ in the unpacking:
-        g_out_ref, o_ref, A_ref, _ = chunk_gated_delta_rule_fwd(
+        g_out_ref, _o_ref, A_ref, _ = chunk_gated_delta_rule_fwd(
...
-        dq_ref, dk_ref, dv_ref, db_ref, dg_ref, dh0_ref = chunk_gated_delta_rule_bwd(
+        dq_ref, dk_ref, dv_ref, db_ref, dg_ref, _dh0_ref = chunk_gated_delta_rule_bwd(
...
-        g_out_cp, o_cp, A_cp, ht = chunk_gated_delta_rule_fwd(
+        g_out_cp, _o_cp, A_cp, ht = chunk_gated_delta_rule_fwd(

Also, the NOTE about some asserts failing at tighter tolerance for world_size=8 (Lines 227–228) is a bit concerning; if this is not fully understood yet, consider either tightening the tolerance back after investigation or documenting clearly why the relaxed 0.02 tolerance is acceptable.

tests/cp/test_cp_halo_unit.py (3)

25-33: Drop unused Tuple import

Tuple from typing is imported but never used, which flake8 flags (F401). You can safely remove it:

-import argparse
-import os
-from typing import Optional, Tuple
+import argparse
+import os
+from typing import Optional

No behavioral impact.


65-87: Tighten shape unpacking to reflect actual usage

In depthwise_causal_conv and global_baseline, you unpack B, T, D = x.shape, but some of these variables aren’t used, which Ruff flags. You can avoid the warnings and make the intent clearer by only binding what you need, e.g.:

-def depthwise_causal_conv(x: torch.Tensor, w: torch.Tensor, h: int) -> torch.Tensor:
-    ...
-    B, T, D = x.shape
+def depthwise_causal_conv(x: torch.Tensor, w: torch.Tensor, h: int) -> torch.Tensor:
+    ...
+    _, _, D = x.shape

and in global_baseline:

-    B, T, D = x.shape
+    B, _, D = x.shape

This keeps behavior identical while satisfying linters.


151-241: Address unused varlen argument and unneeded outputs from halo exchange

run_cp doesn’t actually use its varlen argument, and Ruff also points out unused unpacked outputs from halo_exchange_and_extend_autograd and an unused B from x_full.shape. You can clean this up as follows:

  • If varlen is not needed (since the reference already encodes varlen via cu_seqlens), drop it from the signature and callers, or rename to _varlen to indicate intentional discard.
  • Avoid unused outputs from halo exchange:
-        q_ext, k_ext, v_ext, cu_ext = halo_exchange_and_extend_autograd(
+        q_ext, *_ = halo_exchange_and_extend_autograd(
             q_in, k_in, v_in, h,
             cp_rank=rank, cp_size=cp_size, cp_group=group,
             cu_seqlens=cu_full, cp_shard_start_idx=s if cu_full is not None else None,
         )
  • Avoid binding unused B:
-    B, T, D = x_full.shape
+    _, T, D = x_full.shape

These are purely cosmetic/quality improvements; the core CP/halo logic looks correct.

tests/cp/benchmark_memory.py (2)

22-37: Consider asserting world_size == cp_size for multi-GPU runs

In setup_distributed, when cp_size > 1 you assume the default process group has exactly cp_size ranks but don’t enforce it. A mismatched --nproc_per_node vs --cp_size could cause confusing hangs inside the model. You can add a quick sanity check:

def setup_distributed(cp_size):
    """Setup distributed"""
    if cp_size > 1:
        dist.init_process_group(backend='nccl')
        rank = dist.get_rank()
+        world_size = dist.get_world_size()
+        if world_size != cp_size:
+            raise ValueError(f"Expected world_size == cp_size, got {world_size} vs {cp_size}")
        ...

This mirrors the stricter check you use in tests/cp/test_numerical_correctness.py and fails fast on misconfiguration.


132-209: Fix trivial f-string lint issues in printing

Several print calls use f-strings without any placeholders (Ruff/flake8 F541), e.g.:

  • Line 152: print(f"\nConfiguration:")
  • Line 180: print(f" ✓ Success")
  • Lines 205–207: print(f" python benchmark_memory.py --cp_size 1"), etc.

These can be plain strings:

-        print(f"\nConfiguration:")
+        print("\nConfiguration:")
...
-                print(f"  ✓ Success")
+                print("  ✓ Success")
...
-        print(f"  python benchmark_memory.py --cp_size 1")
+        print("  python benchmark_memory.py --cp_size 1")

Purely a lint/style fix; behavior is unchanged.

tests/cp/test_gdn_cp_forward.py (3)

18-27: Drop unused Optional import

Optional is imported from typing but never used (flake8 F401). You can safely remove it:

-import argparse
-import os
-from typing import Optional
+import argparse
+import os

No functional impact.


99-107: Distributed setup is reasonable but assumes world_size == cp_size

setup_dist initializes a process group and then creates a new group containing range(cp_size) without explicitly verifying that the world size matches cp_size. In practice you’ll run with --nproc_per_node == cp_size, but to harden this, consider asserting dist.get_world_size() == cp_size (similar to test_numerical_correctness). This is optional but would catch misconfigured launches early.


110-170: Clear CP parity logic; address minor lint (unused varlen and B)

The run_cp path correctly:

  • Broadcasts the reference dict from rank 0.
  • Rebuilds the model per conv size and loads the saved state.
  • Shards the sequence over ranks, runs CP forward, gathers logits, and compares to reference on rank 0 with a sane tolerance.

Two small cleanups:

  • varlen is accepted but never used; either drop it from the signature (and from main) or rename to _varlen to indicate it’s intentionally ignored.
  • B, T = input_ids.shape binds B but never uses it; you can write _, T = input_ids.shape to avoid RUF059.

Otherwise this function looks good.

tests/cp/test_numerical_correctness.py (2)

56-115: Single-GPU reference flow is well-structured; consider optional CPU fallback

The single-GPU test builds a compact config, creates dummy data via create_model_and_data, runs a forward pass with cp_rank=0, cp_size=1, and optionally saves a full reference bundle. That’s correct and consistent with the CP side. The only caveat is the hardcoded device = torch.device('cuda'); if you ever want to run these tests on CPU-only environments, you’d need a guard similar to other files ('cuda' if torch.cuda.is_available() else 'cpu'). Not required if these tests are GPU-only by design.


117-303: CP numerical correctness logic is solid; fix trivial lint issues

The test_context_parallel flow is robust:

  • Uses setup_distributed to ensure ranks/devices are aligned.
  • Loads the single-GPU reference on rank 0, broadcasts tensors to all ranks.
  • Constructs identical models on all ranks and synchronizes weights by broadcasting parameter tensors.
  • Shards the sequence over ranks, runs CP forward, gathers logits, and compares against the reference using loss diff, absolute/relative logits diff, and correlation, with clear thresholds.

A few small cleanups that will fix the static-analysis warnings:

  • Avoid reloading the reference twice on rank 0; you can reuse the first ref.
  • Replace f-strings without placeholders with plain strings, e.g.:
-    print(f"\nResults:")
+    print("\nResults:")
...
-        print(f"\n📊 Loss Comparison:")
+        print("\n📊 Loss Comparison:")
...
-        print(f"\n📊 Logits Comparison:")
+        print("\n📊 Logits Comparison:")
...
-            print(f"\n🎉 ALL TESTS PASSED!")
+            print("\n🎉 ALL TESTS PASSED!")

and similarly for other lines Ruff/flake8 flagged (170, 173, 287, 288, 293, 297, 298).

  • The continuation of the parser.add_argument('--mode', ...) line (around Line 308) can be re-indented to satisfy E128, e.g.:
parser.add_argument(
    '--mode', type=str, required=True, choices=['single', 'cp'],
    help='Test mode',
)

None of these change behavior but will keep the file lint‑clean.

tests/cp/train_cp_real.py (4)

53-82: Dataset is fine, but mode is currently unused.

RealisticTextDataset does what it promises (deterministic synthetic full sequences per rank), but the mode argument is not used anywhere, which is a small source of confusion.

If you don’t plan to support file-based loading here, consider dropping mode for now; otherwise, implement the mode == "file" branch (or at least raise NotImplementedError when mode != "random").

-        self.vocab_size = vocab_size
-        self.mode = mode
+        self.vocab_size = vocab_size
+        self.mode = mode
+
+        if self.mode != "random":
+            raise NotImplementedError("Only mode='random' is implemented in RealisticTextDataset.")

125-173: Fix unused batch_size binding in CPCollator (RUF059) and keep shard logic as-is.

The collator’s sharding logic looks sound, but batch_size from batch_size, seq_len = input_ids.shape is never used, which Ruff flags as RUF059.

Minimal fix:

-        batch_size, seq_len = input_ids.shape
+        _, seq_len = input_ids.shape

Everything else in this collator (padding to make seq_len divisible by cp_size, using -100 for label padding, and exposing chunk_start/chunk_end) looks consistent with the downstream CP model.


259-371: Fix unused local_rank and stray f-strings without placeholders (F541, RUF059).

In main:

  • local_rank is unpacked but never used outside logging of device, which already captures what we need.
  • Several print calls use f"..." without any placeholders, flagged by Flake8 F541.

Suggested adjustments:

-    rank, world_size, local_rank, device, cp_rank, cp_size, cp_group = setup_distributed_cp(args.cp_size)
+    rank, world_size, _local_rank, device, cp_rank, cp_size, cp_group = setup_distributed_cp(args.cp_size)
@@
-    if rank == 0:
-        print("\n" + "="*80)
-        print("REALISTIC CONTEXT PARALLELISM TRAINING")
-        print("="*80)
-        print(f"\n🔧 Setup:")
+    if rank == 0:
+        print("\n" + "=" * 80)
+        print("REALISTIC CONTEXT PARALLELISM TRAINING")
+        print("=" * 80)
+        print("\n🔧 Setup:")
@@
-    if rank == 0:
-        print(f"\n📐 Model Configuration:")
+    if rank == 0:
+        print("\n📐 Model Configuration:")
@@
-    if rank == 0:
-        print(f"\n📚 Dataset:")
+    if rank == 0:
+        print("\n📚 Dataset:")
@@
-    if rank == 0:
-        print(f"\n🚀 Training:")
+    if rank == 0:
+        print("\n🚀 Training:")

Similar change for the batch-info header:

-            if rank == 0 and step == start_step:
-                print(f"📦 Batch info:")
+            if rank == 0 and step == start_step:
+                print("📦 Batch info:")

1-12: Optional: drop shebang or make file executable to silence EXE001.

Ruff flags EXE001 because there’s a shebang but the file is not necessarily executable in the repo. If you don’t plan to run this via ./train_cp_real.py, you can safely remove the shebang:

-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-

or just drop the shebang entirely and keep the encoding line:

-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-

Pick whichever matches how you intend to invoke this script.

fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (6)

38-82: CP block wiring looks good; just confirm MLP supports **kwargs.

GatedDeltaNetBlockCP correctly:

  • Normalizes + residuals in a prenorm style.
  • Switches between Attention and GatedDeltaNet based on config.attn.
  • Threads **kwargs (CP params) into self.attn.

One thing to double-check: the tail call

hidden_states = self.mlp(hidden_states, **kwargs)

assumes GatedDeltaNetMLP’s forward accepts **kwargs. If its signature is still forward(self, hidden_states: torch.Tensor), this will raise TypeError as soon as you pass CP keywords like cp_rank.

Please verify the MLP signature and, if needed, either:

  • Extend GatedDeltaNetMLP.forward to accept **kwargs, or
  • Drop the kwargs here if they are not needed:
-        hidden_states = self.mlp(hidden_states, **kwargs)
+        hidden_states = self.mlp(hidden_states)

119-126: Avoid mutable class-level attributes for _no_split_modules (RUF012).

_no_split_modules is currently a list, which Ruff flags as a mutable class attribute. Since this is effectively constant metadata, a tuple is safer and silences RUF012.

-    _no_split_modules = ['GatedDeltaNetBlockCP']
+    _no_split_modules = ('GatedDeltaNetBlockCP',)

268-282: Convert _tied_weights_keys to a tuple to avoid mutable class attribute (RUF012).

Similar to _no_split_modules, _tied_weights_keys is defined as a list and flagged by Ruff. It’s logically constant, so a tuple is more appropriate:

-    _tied_weights_keys = ["lm_head.weight"]
+    _tied_weights_keys = ("lm_head.weight",)

373-383: Tidy tuple construction in the non-return_dict path (RUF005).

Ruff suggests avoiding tuple concatenation with + in favor of unpacking. You can simplify the non-return_dict branch as:

-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return (loss,) + output if loss is not None else output
+        if not return_dict:
+            output = (logits, *outputs[1:])
+            if loss is not None:
+                return (loss, *output)
+            return output

This is a minor style/readability improvement and eliminates the RUF005 warning.


189-217: Add stacklevel=2 to warnings.warn for clearer call sites.

Python's warnings.warn supports a stacklevel keyword argument, and using stacklevel=2 is the standard way for a wrapper to make the warning point to the wrapper's caller rather than the warn() call itself. Apply the suggested change to line 213-215:

         if output_attentions:
-            warnings.warn("`GatedDeltaNetModelCP` does not `output_attentions` now, setting it to `False`.")
+            warnings.warn(
+                "`GatedDeltaNetModelCP` does not `output_attentions` now, setting it to `False`.",
+                stacklevel=2,
+            )
             output_attentions = False

301-315: Apply the recommended exception re-raising patterns in generate method.

The current code uses raise exception in the else branch, which obscures the original traceback. The recommended pattern is to use plain raise (no arguments) to re-raise and preserve the original traceback, or use explicit chaining raise NewError(...) from original_exc to raise a new exception while attaching the original as the cause.

Suggested fix:

  • Replace raise exception with bare raise
  • Change raise AttributeError(...) to raise AttributeError(...) from err to preserve exception chain
  • Rename the caught variable from exception to err for clarity
    def generate(self, *args, **kwargs):
        try:
            return super().generate(*args, **kwargs)
-        except AttributeError as exception:
-            if 'past_key_values' in str(exception):
+        except AttributeError as err:
+            if 'past_key_values' in str(err):
                 raise AttributeError(
                     f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
                     f"which is not supported for {self.__class__.__name__}. "
                     f"Try another generation strategy instead. "
                     f"For the available generation strategies, check this doc: "
                     f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
+                ) from err
             else:
-                raise exception
+                raise
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 231e3d4 and a5a967c.

📒 Files selected for processing (17)
  • fla/layers/gated_deltanet_cp.py (1 hunks)
  • fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1 hunks)
  • fla/ops/gated_delta_rule/__init__.py (1 hunks)
  • fla/ops/gated_delta_rule/chunk_cp.py (1 hunks)
  • fla/ops/gated_delta_rule/cp_halo.py (1 hunks)
  • tests/cp/README.md (1 hunks)
  • tests/cp/benchmark_memory.py (1 hunks)
  • tests/cp/test_cp_halo.py (1 hunks)
  • tests/cp/test_cp_halo_unit.py (1 hunks)
  • tests/cp/test_cp_stability.py (1 hunks)
  • tests/cp/test_gdn_cp_backward.py (1 hunks)
  • tests/cp/test_gdn_cp_forward.py (1 hunks)
  • tests/cp/test_numerical_correctness.py (1 hunks)
  • tests/cp/train_cp_real.py (1 hunks)
  • tests/ops/cp/README.md (1 hunks)
  • tests/ops/cp/test_gated_delta_cp_fake.py (1 hunks)
  • tests/ops/cp/test_gated_delta_cp_real.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (15)
tests/cp/test_cp_stability.py (3)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
  • GatedDeltaNetForCausalLMCP (268-383)
tests/cp/test_gdn_cp_backward.py (1)
  • setup (111-129)
fla/layers/gated_deltanet_cp.py (7)
fla/layers/utils.py (2)
  • get_unpad_data (75-98)
  • pad_input (176-197)
fla/modules/fused_norm_gate.py (1)
  • FusedRMSNormGated (997-1058)
fla/modules/layernorm.py (1)
  • RMSNorm (1064-1111)
fla/modules/convolution.py (1)
  • ShortConvolution (794-1009)
fla/ops/gated_delta_rule/fused_recurrent.py (1)
  • fused_recurrent_gated_delta_rule (242-353)
fla/ops/gated_delta_rule/chunk_cp.py (2)
  • chunk_gated_delta_rule_cp (171-235)
  • forward (24-100)
fla/ops/gated_delta_rule/cp_halo.py (3)
  • halo_exchange_and_extend (12-115)
  • halo_exchange_and_extend_autograd (226-245)
  • forward (128-164)
tests/cp/test_numerical_correctness.py (2)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
  • GatedDeltaNetForCausalLMCP (268-383)
tests/cp/test_gdn_cp_backward.py (4)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
  • GatedDeltaNetForCausalLMCP (268-383)
fla/ops/gated_delta_rule/cp_halo.py (1)
  • backward (167-223)
tests/cp/test_cp_stability.py (1)
  • setup (20-29)
tests/ops/cp/test_gated_delta_cp_real.py (2)
fla/ops/gated_delta_rule/chunk.py (2)
  • chunk_gated_delta_rule_bwd (72-149)
  • chunk_gated_delta_rule_fwd (18-69)
fla/utils.py (1)
  • assert_close (82-93)
tests/cp/test_cp_halo.py (1)
fla/ops/gated_delta_rule/cp_halo.py (1)
  • halo_exchange_and_extend (12-115)
tests/cp/test_cp_halo_unit.py (2)
fla/ops/gated_delta_rule/cp_halo.py (1)
  • halo_exchange_and_extend_autograd (226-245)
tests/cp/test_gdn_cp_forward.py (2)
  • setup_dist (99-107)
  • run_cp (110-170)
tests/cp/train_cp_real.py (4)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
  • GatedDeltaNetForCausalLMCP (268-383)
fla/ops/gated_delta_rule/chunk_cp.py (1)
  • backward (105-168)
fla/ops/gated_delta_rule/cp_halo.py (1)
  • backward (167-223)
fla/ops/gated_delta_rule/cp_halo.py (1)
fla/layers/gated_deltanet_cp.py (1)
  • forward (206-369)
tests/cp/test_gdn_cp_forward.py (2)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
  • GatedDeltaNetForCausalLMCP (268-383)
fla/ops/gated_delta_rule/__init__.py (1)
fla/ops/gated_delta_rule/cp_halo.py (2)
  • halo_exchange_and_extend (12-115)
  • halo_exchange_and_extend_autograd (226-245)
fla/ops/gated_delta_rule/chunk_cp.py (7)
fla/modules/l2norm.py (3)
  • l2norm (266-271)
  • l2norm_bwd (195-239)
  • l2norm_fwd (149-192)
fla/ops/common/chunk_delta_h.py (2)
  • chunk_gated_delta_rule_bwd_dhu (480-529)
  • chunk_gated_delta_rule_fwd_h (431-477)
fla/ops/common/chunk_o.py (3)
  • chunk_bwd_dqkwg (624-685)
  • chunk_bwd_dv_local (577-621)
  • chunk_fwd_o (485-522)
fla/ops/common/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (222-309)
fla/ops/utils/cumsum.py (1)
  • chunk_local_cumsum (432-469)
fla/ops/utils/solve_tril.py (1)
  • solve_tril (337-384)
fla/ops/gated_delta_rule/chunk.py (2)
  • chunk_gated_delta_rule_fwd (18-69)
  • chunk_gated_delta_rule_bwd (72-149)
tests/cp/benchmark_memory.py (3)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (1)
  • GatedDeltaNetForCausalLMCP (268-383)
fla/ops/gated_delta_rule/chunk_cp.py (1)
  • backward (105-168)
fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (6)
fla/layers/gated_deltanet_cp.py (2)
  • GatedDeltaNet (38-369)
  • forward (206-369)
fla/models/gated_deltanet/configuration_gated_deltanet.py (1)
  • GatedDeltaNetConfig (9-103)
fla/models/utils.py (1)
  • FLAGenerationMixin (385-462)
fla/modules/layernorm.py (1)
  • RMSNorm (1064-1111)
fla/models/modeling_layers.py (1)
  • GradientCheckpointingLayer (11-71)
fla/ops/gated_delta_rule/chunk_cp.py (1)
  • forward (24-100)
tests/ops/cp/test_gated_delta_cp_fake.py (2)
fla/ops/gated_delta_rule/chunk.py (2)
  • chunk_gated_delta_rule_fwd (18-69)
  • chunk_gated_delta_rule_bwd (72-149)
fla/utils.py (1)
  • assert_close (82-93)
🪛 Flake8 (7.3.0)
fla/layers/gated_deltanet_cp.py

[error] 11-11: 'torch.distributed as dist' imported but unused

(F401)


[error] 18-18: 'fla.ops.gated_delta_rule.chunk_gated_delta_rule' imported but unused

(F401)


[error] 231-231: local variable 'cp_shard_start_idx' is assigned to but never used

(F841)

tests/cp/test_numerical_correctness.py

[error] 95-95: f-string is missing placeholders

(F541)


[error] 170-170: f-string is missing placeholders

(F541)


[error] 173-173: f-string is missing placeholders

(F541)


[error] 238-238: f-string is missing placeholders

(F541)


[error] 259-259: f-string is missing placeholders

(F541)


[error] 287-287: f-string is missing placeholders

(F541)


[error] 288-288: f-string is missing placeholders

(F541)


[error] 293-293: f-string is missing placeholders

(F541)


[error] 297-297: f-string is missing placeholders

(F541)


[error] 298-298: f-string is missing placeholders

(F541)


[error] 308-308: continuation line under-indented for visual indent

(E128)

tests/cp/test_cp_halo_unit.py

[error] 27-27: 'typing.Tuple' imported but unused

(F401)

tests/cp/train_cp_real.py

[error] 16-16: 'typing.Optional' imported but unused

(F401)


[error] 21-21: 'transformers.AutoTokenizer' imported but unused

(F401)


[error] 23-23: 'torch.distributed.nn.all_reduce' imported but unused

(F401)


[error] 192-192: local variable 'attention_mask' is assigned to but never used

(F841)


[error] 296-296: f-string is missing placeholders

(F541)


[error] 316-316: f-string is missing placeholders

(F541)


[error] 338-338: f-string is missing placeholders

(F541)


[error] 365-365: f-string is missing placeholders

(F541)


[error] 386-386: f-string is missing placeholders

(F541)

fla/ops/gated_delta_rule/cp_halo.py

[error] 97-97: multiple statements on one line (semicolon)

(E702)


[error] 97-97: multiple statements on one line (semicolon)

(E702)

tests/cp/test_gdn_cp_forward.py

[error] 20-20: 'typing.Optional' imported but unused

(F401)

fla/ops/gated_delta_rule/chunk_cp.py

[error] 8-8: 'fla.ops.common.chunk_delta_h.chunk_gated_delta_rule_bwd_dhu' imported but unused

(F401)


[error] 8-8: 'fla.ops.common.chunk_delta_h.chunk_gated_delta_rule_fwd_h' imported but unused

(F401)


[error] 9-9: 'fla.ops.common.chunk_o.chunk_bwd_dqkwg' imported but unused

(F401)


[error] 9-9: 'fla.ops.common.chunk_o.chunk_bwd_dv_local' imported but unused

(F401)


[error] 9-9: 'fla.ops.common.chunk_o.chunk_fwd_o' imported but unused

(F401)


[error] 10-10: 'fla.ops.common.chunk_scaled_dot_kkt.chunk_scaled_dot_kkt_fwd' imported but unused

(F401)


[error] 11-11: 'fla.ops.gated_delta_rule.wy_fast.prepare_wy_repr_bwd' imported but unused

(F401)


[error] 11-11: 'fla.ops.gated_delta_rule.wy_fast.recompute_w_u_fwd' imported but unused

(F401)


[error] 12-12: 'fla.ops.utils.chunk_local_cumsum' imported but unused

(F401)


[error] 12-12: 'fla.ops.utils.solve_tril' imported but unused

(F401)

tests/cp/benchmark_memory.py

[error] 152-152: f-string is missing placeholders

(F541)


[error] 180-180: f-string is missing placeholders

(F541)


[error] 205-205: f-string is missing placeholders

(F541)


[error] 206-206: f-string is missing placeholders

(F541)


[error] 207-207: f-string is missing placeholders

(F541)

🪛 markdownlint-cli2 (0.18.1)
tests/cp/README.md

24-24: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


38-38: Fenced code blocks should have a language specified

(MD040, fenced-code-language)


59-59: Fenced code blocks should have a language specified

(MD040, fenced-code-language)

🪛 Ruff (0.14.5)
tests/cp/test_cp_stability.py

1-1: Shebang is present but file is not executable

(EXE001)

fla/layers/gated_deltanet_cp.py

98-98: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


105-105: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


107-107: Unused method argument: kwargs

(ARG002)


133-136: Avoid specifying long messages outside the exception class

(TRY003)


138-140: Avoid specifying long messages outside the exception class

(TRY003)


143-146: Avoid specifying long messages outside the exception class

(TRY003)


195-195: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


212-212: Unused method argument: output_attentions

(ARG002)


231-231: Local variable cp_shard_start_idx is assigned to but never used

Remove assignment to unused variable cp_shard_start_idx

(F841)

tests/cp/test_numerical_correctness.py

1-1: Shebang is present but file is not executable

(EXE001)


35-35: Avoid specifying long messages outside the exception class

(TRY003)


95-95: f-string without any placeholders

Remove extraneous f prefix

(F541)


170-170: f-string without any placeholders

Remove extraneous f prefix

(F541)


173-173: f-string without any placeholders

Remove extraneous f prefix

(F541)


186-186: Comment contains ambiguous (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF003)


238-238: f-string without any placeholders

Remove extraneous f prefix

(F541)


259-259: f-string without any placeholders

Remove extraneous f prefix

(F541)


287-287: f-string without any placeholders

Remove extraneous f prefix

(F541)


288-288: f-string without any placeholders

Remove extraneous f prefix

(F541)


293-293: f-string without any placeholders

Remove extraneous f prefix

(F541)


297-297: f-string without any placeholders

Remove extraneous f prefix

(F541)


298-298: f-string without any placeholders

Remove extraneous f prefix

(F541)

tests/cp/test_gdn_cp_backward.py

1-1: Shebang is present but file is not executable

(EXE001)


134-134: Unpacked variable rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


148-148: Unpacked variable B is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


177-177: Avoid specifying long messages outside the exception class

(TRY003)


188-188: Avoid specifying long messages outside the exception class

(TRY003)


205-205: Avoid specifying long messages outside the exception class

(TRY003)

tests/ops/cp/test_gated_delta_cp_real.py

72-72: Unused method argument: world_size

(ARG002)


156-156: Unpacked variable o_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


169-169: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


195-195: Unpacked variable o_cp is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tests/cp/test_cp_halo.py

1-1: Shebang is present but file is not executable

(EXE001)


52-52: Unpacked variable k_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


52-52: Unpacked variable v_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tests/cp/test_cp_halo_unit.py

1-1: Shebang is present but file is not executable

(EXE001)


70-70: Unpacked variable B is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


70-70: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


84-84: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


151-151: Unused function argument: varlen

(ARG001)


168-168: Unpacked variable B is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


187-187: Unpacked variable k_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


187-187: Unpacked variable v_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


187-187: Unpacked variable cu_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tests/cp/train_cp_real.py

1-1: Shebang is present but file is not executable

(EXE001)


38-41: Avoid specifying long messages outside the exception class

(TRY003)


145-145: Unpacked variable batch_size is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


192-192: Local variable attention_mask is assigned to but never used

Remove assignment to unused variable attention_mask

(F841)


220-220: Consider moving this statement to an else block

(TRY300)


222-222: Do not catch blind exception: Exception

(BLE001)


290-290: Unpacked variable local_rank is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


296-296: f-string without any placeholders

Remove extraneous f prefix

(F541)


316-316: f-string without any placeholders

Remove extraneous f prefix

(F541)


338-338: f-string without any placeholders

Remove extraneous f prefix

(F541)


365-365: f-string without any placeholders

Remove extraneous f prefix

(F541)


386-386: f-string without any placeholders

Remove extraneous f prefix

(F541)

fla/ops/gated_delta_rule/cp_halo.py

97-97: Multiple statements on one line (semicolon)

(E702)


97-97: Multiple statements on one line (semicolon)

(E702)

tests/cp/test_gdn_cp_forward.py

1-1: Shebang is present but file is not executable

(EXE001)


110-110: Unused function argument: varlen

(ARG001)


133-133: Unpacked variable B is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

fla/ops/gated_delta_rule/chunk_cp.py

42-42: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


115-115: Unpacked variable T is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


177-177: PEP 484 prohibits implicit Optional

Convert to T | None

(RUF013)


195-195: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


200-200: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


208-211: Avoid specifying long messages outside the exception class

(TRY003)


213-216: Avoid specifying long messages outside the exception class

(TRY003)

tests/cp/benchmark_memory.py

1-1: Shebang is present but file is not executable

(EXE001)


152-152: f-string without any placeholders

Remove extraneous f prefix

(F541)


180-180: f-string without any placeholders

Remove extraneous f prefix

(F541)


205-205: f-string without any placeholders

Remove extraneous f prefix

(F541)


206-206: f-string without any placeholders

Remove extraneous f prefix

(F541)


207-207: f-string without any placeholders

Remove extraneous f prefix

(F541)

fla/models/gated_deltanet/modeling_gated_deltanet_cp.py

124-124: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


171-171: Avoid specifying long messages outside the exception class

(TRY003)


212-212: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


221-221: Avoid specifying long messages outside the exception class

(TRY003)


223-223: Avoid specifying long messages outside the exception class

(TRY003)


271-271: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


306-312: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


306-312: Avoid specifying long messages outside the exception class

(TRY003)


314-314: Use raise without specifying exception name

Remove exception name

(TRY201)


374-374: Consider (logits, *outputs[1:]) instead of concatenation

Replace with (logits, *outputs[1:])

(RUF005)


375-375: Consider (loss, *output) instead of concatenation

Replace with (loss, *output)

(RUF005)

tests/ops/cp/test_gated_delta_cp_fake.py

50-50: Unpacked variable g_out is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


50-50: Unpacked variable A is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


128-128: Unpacked variable dh0_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


148-148: Unpacked variable o_i is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (21)
tests/cp/test_cp_stability.py (1)

58-78: LGTM: Core test logic is solid.

The test correctly:

  • Partitions the sequence across CP ranks (lines 58-63)
  • Synchronizes with barriers before each iteration (line 71)
  • Passes CP parameters to the model (lines 74-75)
  • Validates stability across multiple iterations
fla/ops/gated_delta_rule/cp_halo.py (2)

127-164: LGTM: Forward pass correctly saves context and delegates.

The forward method properly:

  • Saves all necessary metadata for backward (lines 139-152)
  • Delegates to the underlying exchange function (lines 154-159)
  • Marks metadata tensors as non-differentiable (lines 162-163)

212-220: LGTM: Gradient accumulation correctly handles padding.

The backward pass properly:

  • Receives tail gradients from the next rank (lines 213-215)
  • Only accumulates the non-padded portion when T < h (lines 216-220)
  • Uses in-place addition to accumulate gradients correctly
fla/ops/gated_delta_rule/__init__.py (1)

3-10: LGTM: Clean API extension.

The new halo exchange functions are properly imported and exported, following the existing pattern in this module.

tests/cp/test_cp_halo.py (1)

22-23: Verify that halo_exchange_and_extend is re‑exported at the package level

This test imports:

from fla.ops.gated_delta_rule import halo_exchange_and_extend

while the implementation lives in fla.ops.gated_delta_rule.cp_halo. This assumes halo_exchange_and_extend is re-exported from fla.ops.gated_delta_rule.__init__. If that’s not the case, this import will fail at runtime, and you should either:

  • add the appropriate re-export in fla/ops/gated_delta_rule/__init__.py, or
  • import directly from fla.ops.gated_delta_rule.cp_halo.

Would you confirm in your environment whether from fla.ops.gated_delta_rule import halo_exchange_and_extend works, or whether you need to adjust the import?

fla/ops/gated_delta_rule/chunk_cp.py (1)

170-186: Confirm torch.compiler.disable is available for your minimum PyTorch version

The public wrapper is decorated with:

@torch.compiler.disable
def chunk_gated_delta_rule_cp(...):
    ...

Support for torch.compiler.disable is relatively recent; on older PyTorch versions, this may raise an AttributeError at import time. If you target a wide range of versions, you may need a guarded import or to fall back to torch._dynamo.disable / no decoration.

Could you verify against your target PyTorch version(s) that torch.compiler.disable exists and behaves as expected for this function?

tests/ops/cp/test_gated_delta_cp_real.py (2)

14-44: Baton send/recv pattern is logically consistent

The forward and gradient baton helpers implement a simple chain: rank 0 seeds h0, later ranks recv from rank-1; the last rank seeds dht, earlier ranks recv from rank+1. This avoids circular dependencies and should not deadlock under the current call order. No changes required here.


46-70: CP shard/unshard helpers look correct

cp_shard splitting along shard_dim=1 with cp_degree or self.world_size and cp_unshard via torch.cat match the intended time‑dimension sharding in these tests. No functional issues spotted.

tests/cp/test_cp_halo_unit.py (2)

105-149: run_single reference generation logic looks solid

The single‑rank baseline generation covers both h=1 and h=3, supports optional varlen via cu_seqlens, and properly detaches and moves tensors to CPU for saving. The gradient collection (gx, gw) and reset of .grad fields per h are correct. No changes needed.


243-260: CLI/main orchestration is fine

Argument parsing (mode, cp_size, save_ref, varlen, backward) and dispatch to run_single vs run_cp is straightforward and matches the documented usage examples in the module docstring. No functional issues here.

tests/cp/benchmark_memory.py (2)

46-129: Benchmark flow and memory accounting look reasonable

The benchmark_config function correctly isolates model/data allocations, captures memory usage after model creation, after data creation, after forward, and after backward, and computes deltas along those checkpoints. Error handling for RuntimeError and cleanup (del + torch.cuda.empty_cache()) are also appropriate for a benchmarking script. No functional issues stand out here.


161-213: Top-level loop and summary formatting are clear

The main loop over seq_lens, conditional printing on rank 0, aggregation into results, and final summary table are all straightforward and easy to read. The script should be convenient to run for quick CP memory sweeps. No further changes needed here.

tests/cp/test_gdn_cp_forward.py (3)

32-57: Config and data factories look good

create_config and create_data encode a compact config for quick tests and support both equal-length and varlen via cu_seqlens. The seeding in create_data ensures deterministic inputs for reproducible parity checks. No issues here.


59-97: run_single reference generation is consistent with CP usage

The single-rank run:

  • Covers two conv_size values (2 and 4).
  • Stores full model state_dict, loss, logits, and inputs/labels (plus optional cu_seqlens) per conv size.
  • Uses cp_rank=0, cp_size=1 when calling the CP model, aligning with the CP API.

This sets up exactly the information run_cp needs to reconstruct models and compare logits. No functional problems spotted.


173-188: CLI entrypoint matches documented usage

The main function exposes --mode, --cp_size, --save_ref, and --varlen exactly as described in the module docstring and ties into run_single/run_cp correctly. No issues.

tests/cp/test_numerical_correctness.py (2)

24-38: Distributed setup sanity check is good

setup_distributed initializes NCCL, binds CUDA device via LOCAL_RANK, and explicitly enforces world_size == cp_size, which is exactly what you want for these CP tests. No changes needed here.


305-321: CLI wiring is straightforward

The main entrypoint cleanly dispatches between single-reference generation and CP correctness testing based on --mode, and passes --cp_size/--save_ref through appropriately. No issues here.

tests/cp/train_cp_real.py (1)

29-50: Distributed CP setup looks correct for CP-only runs.

The setup_distributed_cp helper correctly enforces world_size == cp_size, sets the CUDA device from LOCAL_RANK, and builds a single CP group over all ranks, which matches the “CP-only” assumption.

No changes required here.

fla/models/gated_deltanet/modeling_gated_deltanet_cp.py (3)

130-172: Initialization logic is reasonable and matches GatedDeltaNet’s parameterization.

The custom _init_weights:

  • Re-initializes GatedDeltaNet.A_log and dt_bias in a way consistent with the layer’s constructor.
  • Uses standard normal init for Linear/Conv/Embedding.
  • Optionally applies a prenorm residual strategy (rescale or zero) to o_proj/down_proj.

The logic looks consistent with the underlying layer semantics and the config, and should integrate cleanly with post_init().

No changes needed here.


235-247: CP kwargs are correctly threaded through the encoder stack.

The loop over self.layers passes **kwargs down to each GatedDeltaNetBlockCP, which in turn forwards them to the CP-aware attention layer. This is what enables CP rank/size/group to reach the halo and chunk kernels.

This section looks consistent and doesn’t need changes.


316-372: Causal LM forward and CP param propagation are consistent with HF patterns.

GatedDeltaNetForCausalLMCP.forward:

  • Delegates to self.model while threading through **kwargs for CP.
  • Supports logits_to_keep with the standard “keep last K logits” semantics.
  • Handles fused vs non-fused cross-entropy and optional L2 warp in line with the config.
  • Shifts labels for next-token prediction consistent with other causal models in the codebase.

This section looks coherent and aligned with the rest of the modeling stack. No changes required here.

Comment on lines +50 to +66
# Handle initial state for CP
if cp_size > 1:
if cp_rank == 0:
h0_local = initial_state.to(torch.float32) if initial_state is not None else torch.zeros(
(B, H, K, V), dtype=torch.float32, device=device
)
else:
h0_local = torch.empty((B, H, K, V), dtype=torch.float32, device=device)
recv_req = dist.irecv(h0_local, src=cp_rank - 1, group=cp_group)
recv_req.wait()
else:
h0_local = initial_state.to(torch.float32) if initial_state is not None else torch.zeros(
(B, H, K, V), dtype=torch.float32, device=device
)

# Force output_final_state=True for CP to get the final state
force_output_final_state = True if cp_size > 1 else output_final_state
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix gradient propagation for initial_state in multi‑rank CP backward

In ChunkGatedDeltaRuleFunctionCP:

  • Forward only uses the user‑provided initial_state when cp_size == 1 or cp_rank == 0. For cp_rank > 0, h0_local is filled via dist.irecv and the initial_state argument is ignored.
  • Backward always computes dh0 (grad wrt the local initial state), sends it to the previous rank (correct for CP baton), and returns it as the gradient for the initial_state input whenever initial_state.requires_grad is True:
ctx.initial_state_requires_grad = initial_state is not None and initial_state.requires_grad
...
dh0_return = dh0 if ctx.initial_state_requires_grad else None
return ..., dh0_return, ...

For cp_size > 1 and cp_rank > 0 this is incorrect: the user-provided initial_state was never used in that shard’s forward, so its gradient contribution from this shard must be zero. The only rank that should contribute a gradient to the input initial_state is rank 0 (or the single-rank case), with all CP shards’ influence already folded in via the dht baton and dh0 on rank 0.

You can fix this by restricting initial_state_requires_grad to the ranks that actually consume the input:

-        ctx.output_final_state = output_final_state  # Store original value
-        ctx.initial_state_requires_grad = initial_state is not None and initial_state.requires_grad  # Track if gradient needed
+        ctx.output_final_state = output_final_state  # Store original value
+        # Only rank 0 (or single‑rank) actually consumes the user‑provided initial_state;
+        # later ranks receive their h0 via dist. Do not propagate grads for initial_state from other ranks.
+        ctx.initial_state_requires_grad = (
+            initial_state is not None
+            and initial_state.requires_grad
+            and (cp_size == 1 or cp_rank == 0)
+        )

The existing dh0_return = dh0 if ctx.initial_state_requires_grad else None in backward can remain as-is once this flag is corrected.

Also applies to: 87-97, 117-139, 140-168

🤖 Prompt for AI Agents
In fla/ops/gated_delta_rule/chunk_cp.py around lines 50-66 (and similarly at
87-97, 117-139, 140-168), the bug is that ctx.initial_state_requires_grad is set
from the presence/requires_grad of the user-provided initial_state on all ranks,
but ranks with cp_size>1 and cp_rank>0 never actually use the user initial_state
in forward (they receive h0 via dist.irecv), so they must not report gradients
for that input; fix by setting ctx.initial_state_requires_grad = (initial_state
is not None and initial_state.requires_grad) AND (cp_size == 1 or cp_rank == 0)
when you save ctx in forward on all relevant blocks, so that backward only
returns dh0 for the input initial_state on the rank(s) that consumed it (rank 0
or single-rank), leaving other ranks to return None as gradient for that
argument.

Comment on lines +33 to +78
def test_basic_halo(cp_size: int, h: int = 2):
"""Test that rank r receives the tail from rank r-1 as its left halo."""
rank, world_size = _init_dist('gloo')
assert world_size == cp_size, f"world_size({world_size}) != cp_size({cp_size})"

# Shapes
B = 1
T_local = 6
Dq, Dk, Dv = 3, 4, 5

# Create deterministic inputs per rank
torch.manual_seed(2025 + rank)
q_in = torch.randn(B, T_local, Dq)
k_in = torch.randn(B, T_local, Dk)
v_in = torch.randn(B, T_local, Dv)

# No varlen here
cu_seqlens = None

q_ext, k_ext, v_ext, cu_ext = halo_exchange_and_extend(
q_in, k_in, v_in, h,
cp_rank=rank, cp_size=cp_size, cp_group=dist.group.WORLD,
cu_seqlens=cu_seqlens, cp_shard_start_idx=None,
)

# Gather previous rank tails on ALL ranks to avoid deadlock
q_tails = [torch.zeros(B, h, Dq) for _ in range(world_size)]
local_tail = q_in[:, -h:, :].contiguous()
dist.all_gather(q_tails, local_tail)

# Rank 0 should have zero halo
if rank == 0:
assert torch.allclose(q_ext[:, :h], torch.zeros_like(q_ext[:, :h]))
assert cu_ext is None
# Rank > 0 should have halo equal to last h tokens of previous rank
else:
prev_tail = q_tails[rank - 1]
assert torch.allclose(q_ext[:, :h, :], prev_tail, atol=1e-6), (
f"Rank {rank}: halo mismatch vs prev rank tail"
)

if dist.is_initialized():
dist.barrier()


def test_varlen_zero_halo(cp_size: int, h: int = 2):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Prevent script-style CP halo tests from breaking pytest discovery

Both test_basic_halo and test_varlen_zero_halo take a required cp_size argument and are meant to be driven via the main() entrypoint with torchrun. Under normal pytest collection, these names will still be treated as tests, and cp_size will be interpreted as a fixture. Since there is no cp_size fixture, the test suite will fail during collection.

If these are intended as manual / torchrun utilities rather than automated pytest tests, consider renaming them so they are not auto‑collected, and adjust main() accordingly, for example:

-def test_basic_halo(cp_size: int, h: int = 2):
+def run_basic_halo(cp_size: int, h: int = 2):
     ...

-def test_varlen_zero_halo(cp_size: int, h: int = 2):
+def run_varlen_zero_halo(cp_size: int, h: int = 2):
     ...
@@
-    test_basic_halo(args.cp_size)
-    test_varlen_zero_halo(args.cp_size)
+    run_basic_halo(args.cp_size)
+    run_varlen_zero_halo(args.cp_size)

That keeps the torchrun UX the same while preventing accidental CI failures from pytest treating these as normal unit tests.

Also applies to: 78-120

🧰 Tools
🪛 Ruff (0.14.5)

52-52: Unpacked variable k_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


52-52: Unpacked variable v_ext is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🤖 Prompt for AI Agents
In tests/cp/test_cp_halo.py around lines 33 to 78 (and similarly lines 78 to
120), the functions test_basic_halo and test_varlen_zero_halo are defined as
pytest-discoverable tests but require a cp_size argument and are intended to be
run via the main()/torchrun entrypoint; rename them (e.g., to run_basic_halo and
run_varlen_zero_halo) or add a pytest skip/collection guard so pytest does not
treat them as tests during collection, and update the main() invocation to call
the new names so the torchrun UX is unchanged.

Comment on lines +20 to +29
def setup(cp_size: int):
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend)
rank = dist.get_rank()
local_rank = int(os.environ.get('LOCAL_RANK', 0))
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
group = dist.new_group(list(range(cp_size)))
return rank, device, group
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Verify group creation logic and LOCAL_RANK handling.

Two concerns:

  1. Line 24: Defaulting LOCAL_RANK to 0 can cause device collisions when multiple ranks run on the same node. In distributed scenarios without LOCAL_RANK set, this will fail or produce incorrect behavior.

  2. Line 28: dist.new_group(list(range(cp_size))) assumes the world consists of ranks 0 through cp_size-1. If world_size > cp_size, this creates a subgroup that might not match the intended CP group. Consider asserting world_size == cp_size earlier or clarifying the intended group membership.

Apply this diff to add an assertion and improve LOCAL_RANK handling:

 def setup(cp_size: int):
     backend = 'nccl' if torch.cuda.is_available() else 'gloo'
     dist.init_process_group(backend=backend)
+    world_size = dist.get_world_size()
+    assert world_size == cp_size, f"world_size ({world_size}) must equal cp_size ({cp_size})"
     rank = dist.get_rank()
-    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+    if torch.cuda.is_available():
+        local_rank = int(os.environ['LOCAL_RANK'])
+    else:
+        local_rank = 0
     device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else 'cpu')
     if torch.cuda.is_available():
         torch.cuda.set_device(local_rank)
-    group = dist.new_group(list(range(cp_size)))
+    group = dist.group.WORLD  # Use full world as CP group when world_size == cp_size
     return rank, device, group

🏁 Script executed:

#!/bin/bash
# Locate and examine the test file and setup function
fd -t f "test_cp_stability.py" --exec cat {} \;

Length of output: 2554


🏁 Script executed:

#!/bin/bash
# Search for all usages of the setup function to understand how cp_size is used
rg -n "setup\(" tests/ -A 2 -B 2 --type py

Length of output: 1221


🏁 Script executed:

#!/bin/bash
# Check how torch.distributed.new_group is used elsewhere in codebase
rg -n "new_group" --type py -C 3

Length of output: 2917


🏁 Script executed:

#!/bin/bash
# Verify LOCAL_RANK environment variable usage patterns
rg -n "LOCAL_RANK" --type py -C 2

Length of output: 2896


🏁 Script executed:

#!/bin/bash
# Check torch.distributed.group usage in codebase
rg -n "dist\.group\." --type py -C 2

Length of output: 1053


🏁 Script executed:

#!/bin/bash
# Look for world_size assertions or comparisons
rg -n "world_size" --type py -C 2

Length of output: 12271


🏁 Script executed:

#!/bin/bash
# Check if there are any comments or documentation about cp_size expectations
rg -n "cp_size" --type py -C 3

Length of output: 58292


🏁 Script executed:

#!/bin/bash
# Verify torch.cuda.set_device usage patterns
rg -n "set_device" --type py -C 2

Length of output: 2643


🏁 Script executed:

#!/bin/bash
# Check what dist module functions are commonly used after init_process_group
rg -n "init_process_group" --type py -A 5 -B 2

Length of output: 4793


🏁 Script executed:

#!/bin/bash
# Verify get_world_size usage
rg -n "get_world_size" --type py -C 2

Length of output: 1994


🏁 Script executed:

#!/bin/bash
# Check if os.environ with default vs mandatory key access for LOCAL_RANK
rg -n "os\.environ\.get.*LOCAL_RANK\|os\.environ\[.*LOCAL_RANK" --type py -C 2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look for any documentation files explaining the distributed setup
fd -t f -e md -e rst -e txt | xargs grep -l "LOCAL_RANK\|cp_size" 2>/dev/null || echo "No docs found"

Length of output: 18


🏁 Script executed:

#!/bin/bash
# Check PyTorch version requirements or constraints
fd -t f "requirements\|setup\|pyproject" | head -20

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look at how the test file structure - what tests use setup()
rg -n "def test_" tests/cp/test_cp_stability.py -A 3 || echo "No test functions found"

Length of output: 23


🏁 Script executed:

#!/bin/bash
# Check if there are any fixtures or initialization that set up environments
rg -n "pytest\|@fixture\|LOCAL_RANK" tests/cp/ -C 2

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Verify if the diff's suggestion (dist.group.WORLD) is the right approach
python3 << 'EOF'
# Check PyTorch distributed API availability
try:
    import torch.distributed as dist
    print("Available in dist.group:")
    print(dir(dist.group))
except Exception as e:
    print(f"Error: {e}")
EOF

Length of output: 30


🏁 Script executed:

#!/bin/bash
# Get full content of the test file for complete context
cat tests/cp/test_cp_stability.py 2>/dev/null || echo "File not found in tests/cp/"

Length of output: 2554


🏁 Script executed:

#!/bin/bash
# Search more broadly for this file
find . -name "*test_cp*" -type f 2>/dev/null | head -20

Length of output: 90


🏁 Script executed:

#!/bin/bash
# If file not found, search for similar patterns
fd -t f "stability" --type py | head -20

Length of output: 169


🏁 Script executed:

#!/bin/bash
# Check repository structure
ls -la 2>/dev/null | head -20

Length of output: 1188


🏁 Script executed:

#!/bin/bash
# Find test directories
fd -t d "tests" -o -t d "test" 2>/dev/null | head -10

Length of output: 0


🏁 Script executed:

#!/bin/bash
# List all Python files in top-level test areas
fd -t f "\.py$" tests/ 2>/dev/null | head -20 || fd -t f "\.py$" test/ 2>/dev/null | head -20

Length of output: 706


🏁 Script executed:

#!/bin/bash
# Search for any process group initialization pattern
rg -n "init_process_group\|new_group\|get_rank" --type py | head -30

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Look for distributed setup patterns
rg -n "def setup\(" --type py -A 15 | head -50

Length of output: 2338


🏁 Script executed:

#!/bin/bash
# Check if there is a conftest or any setup file
fd -t f "conftest.py" -o -t f "setup.py"

Length of output: 116


🏁 Script executed:

#!/bin/bash
# Broad search for cp_size pattern
rg "cp_size" --type py

Length of output: 11740


🏁 Script executed:

#!/bin/bash
# Search across all files for the exact function definition
rg -n "def setup\(cp_size" --type py -A 15

Length of output: 2338


🏁 Script executed:

#!/bin/bash
# Get current working directory context
pwd && ls -la

Length of output: 1567


🏁 Script executed:

#!/bin/bash
# Check if repo is cloned and list structure
find . -type d -name "tests" -o -type d -name "test" 2>/dev/null

Length of output: 7


🏁 Script executed:

#!/bin/bash
# Try to locate the file more flexibly
rg -l "test_cp_stability" 2>/dev/null || rg -l "def setup" tests/ --type py 2>/dev/null | head -5

Length of output: 29


🏁 Script executed:

#!/bin/bash
# Get more info
git log --oneline -1 2>/dev/null || echo "Not a git repo or git unavailable"

Length of output: 34


🏁 Script executed:

#!/bin/bash
# Check for PR details using GitHub CLI if available
gh pr view 2>/dev/null || echo "GitHub CLI not configured or not in PR context"

Length of output: 46


🏁 Script executed:

#!/bin/bash
# Fallback: search everywhere for setup function with cp_size
rg -n "def setup\(" . --type py 2>/dev/null | grep -i cp || rg -n "cp_size" . --type py 2>/dev/null | head -10

Length of output: 123


Add validation for world_size == cp_size and fix LOCAL_RANK handling for multi-node distributed training.

Two critical issues confirmed:

  1. Line 24: Defaulting LOCAL_RANK to 0 when CUDA is available risks device collisions. When running on multi-node clusters without LOCAL_RANK set, all ranks may attempt to use the same GPU. In contrast, other tests with distributed setup (test_gdn_cp_backward.py) require LOCAL_RANK as a mandatory environment variable when using GPUs.

  2. Missing validation: Unlike similar test files that validate world_size == cp_size, this setup function silently assumes all launched ranks belong to the CP group, which could mask misconfiguration.

Apply this diff:

 def setup(cp_size: int):
     backend = 'nccl' if torch.cuda.is_available() else 'gloo'
     dist.init_process_group(backend=backend)
+    world_size = dist.get_world_size()
+    assert world_size == cp_size, f"world_size ({world_size}) must equal cp_size ({cp_size})"
     rank = dist.get_rank()
-    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+    if torch.cuda.is_available():
+        local_rank = int(os.environ['LOCAL_RANK'])
+    else:
+        local_rank = 0
     device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else 'cpu')
     if torch.cuda.is_available():
         torch.cuda.set_device(local_rank)
-    group = dist.new_group(list(range(cp_size)))
+    group = dist.group.WORLD
     return rank, device, group
🤖 Prompt for AI Agents
In tests/cp/test_cp_stability.py around lines 20 to 29, validate that
dist.get_world_size() equals the provided cp_size and enforce proper LOCAL_RANK
handling for GPU runs: after init_process_group, call world_size =
dist.get_world_size() and raise a clear ValueError if world_size != cp_size; for
CUDA paths do not default LOCAL_RANK to 0 — require the LOCAL_RANK environment
variable to be set (use os.environ['LOCAL_RANK'] and raise an error if missing)
before parsing it to int and setting the device and
torch.cuda.set_device(local_rank).

Comment on lines +14 to +23
import argparse
import os
from typing import Optional

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
# Import differentiable all_reduce
from torch.distributed.nn import all_reduce
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Clean up unused imports and align with linting (F401).

Optional, AutoTokenizer, and all_reduce are imported but never used, which will trigger Flake8 (F401) and Ruff warnings. Since the script doesn’t currently rely on them, they should be removed.

You can apply something like:

-import argparse
-import os
-from typing import Optional
-
-import torch
-import torch.distributed as dist
-from torch.utils.data import DataLoader, Dataset
-from transformers import AutoTokenizer
-# Import differentiable all_reduce
-from torch.distributed.nn import all_reduce
+import argparse
+import os
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DataLoader, Dataset
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import argparse
import os
from typing import Optional
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
# Import differentiable all_reduce
from torch.distributed.nn import all_reduce
import argparse
import os
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 16-16: 'typing.Optional' imported but unused

(F401)


[error] 21-21: 'transformers.AutoTokenizer' imported but unused

(F401)


[error] 23-23: 'torch.distributed.nn.all_reduce' imported but unused

(F401)

🤖 Prompt for AI Agents
In tests/cp/train_cp_real.py around lines 14 to 23, there are unused imports:
Optional, AutoTokenizer, and all_reduce are imported but not used; remove these
three imports to satisfy linting (F401) and Ruff. Keep the necessary imports
(argparse, os, torch, torch.distributed as dist, DataLoader, Dataset) and update
any import lines accordingly so the file compiles with no unused import
warnings.

Comment on lines +176 to +225
def compute_loss_with_cp(
model: GatedDeltaNetForCausalLMCP,
batch: dict,
cp_rank: int,
cp_size: int,
cp_group,
device: torch.device
) -> torch.Tensor:
"""
Compute loss with context parallelism.

Important: Each rank computes loss only on its chunk.
NO aggregation needed - gradients flow naturally through autograd!
"""
# Move to device
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)

# Debug: Check tensor shapes and values
if torch.any(torch.isnan(input_ids.float())):
print(f"Rank {cp_rank}: NaN detected in input_ids")
return torch.tensor(0.0, device=device, requires_grad=True)

try:
# Forward pass (each rank processes its chunk)
# Disable cu_seqlens for CP to avoid conflicts
outputs = model(
input_ids=input_ids,
attention_mask=None, # Disable attention_mask for now to avoid cu_seqlens
labels=labels,
cp_rank=cp_rank,
cp_size=cp_size,
cp_group=cp_group,
cp_shard_start_idx=batch.get('chunk_start', 0), # important for halo varlen safety
)

loss = outputs.loss

# Check for NaN loss
if torch.isnan(loss):
print(f"Rank {cp_rank}: NaN loss detected")
return torch.tensor(0.0, device=device, requires_grad=True)

return loss

except Exception as e:
print(f"Rank {cp_rank}: Error in forward pass: {e}")
return torch.tensor(0.0, device=device, requires_grad=True)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Remove dead attention_mask assignment and reconsider broad Exception catch.

Two issues here:

  1. Dead local attention_mask (Flake8 F841)
    You assign attention_mask = batch['attention_mask'].to(device) but always pass attention_mask=None to the model. This is confusing and trips F841.

    Suggested change:

-    input_ids = batch['input_ids'].to(device)
-    attention_mask = batch['attention_mask'].to(device)
-    labels = batch['labels'].to(device)
+    input_ids = batch['input_ids'].to(device)
+    labels = batch['labels'].to(device)

(You can re-thread attention_mask later when CP + cu_seqlens interaction is resolved.)

  1. Overly broad except Exception (Ruff BLE001)
    Catching all exceptions and silently returning a zero loss will hide real bugs in the CP stack and can make CI/debugging very hard. At minimum, consider re-raising after logging:
-    except Exception as e:
-        print(f"Rank {cp_rank}: Error in forward pass: {e}")
-        return torch.tensor(0.0, device=device, requires_grad=True)
+    except Exception as e:
+        print(f"Rank {cp_rank}: Error in forward pass: {e}")
+        raise

If you still want “best-effort” training, narrow the exception type (e.g., RuntimeError) and document the trade-off.

🧰 Tools
🪛 Flake8 (7.3.0)

[error] 192-192: local variable 'attention_mask' is assigned to but never used

(F841)

🪛 Ruff (0.14.5)

192-192: Local variable attention_mask is assigned to but never used

Remove assignment to unused variable attention_mask

(F841)


220-220: Consider moving this statement to an else block

(TRY300)


222-222: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In tests/cp/train_cp_real.py around lines 176-225, remove the dead local
assignment attention_mask = batch['attention_mask'].to(device) (or actually pass
that variable into model if you intend to use it later) so you don't trigger
Flake8 F841, and replace the broad except Exception: block with a narrower
handler (e.g., except RuntimeError as e:) that logs the error and re-raises it
(or if you intentionally want best-effort behavior, document it and only return
a zero loss for that specific, well-justified exception), avoiding swallowing
all exceptions silently.

Comment on lines +373 to +421
# Training loop
model.train()
total_loss = 0.0
step = start_step
optimizer.zero_grad()

while step < args.num_steps:
for batch in dataloader:
if step >= args.num_steps:
break

# Debug: print chunk info (only rank 0, only first batch)
if rank == 0 and step == start_step:
print(f"📦 Batch info:")
print(f" input_ids shape: {batch['input_ids'].shape}")
print(f" Chunk range: [{batch['chunk_start']}:{batch['chunk_end']}]")
print(f" Full sequence length: {batch['full_seq_len']}\n")

# Forward and backward
loss = compute_loss_with_cp(
model=model,
batch=batch,
cp_rank=cp_rank,
cp_size=cp_size,
cp_group=cp_group,
device=device
)

# Scale loss for gradient accumulation
loss = loss / args.gradient_accumulation_steps
loss.backward()

total_loss += loss.item()

# Update weights
if (step + 1) % args.gradient_accumulation_steps == 0:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)

# Optimizer step
optimizer.step()
optimizer.zero_grad()

# Logging (rank 0 only)
if rank == 0 and (step + 1) % 10 == 0:
avg_loss = total_loss * args.gradient_accumulation_steps / 10
print(f"Step {step + 1}/{args.num_steps} | Loss: {avg_loss:.4f}")
total_loss = 0.0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Add gradient all-reduce across CP ranks – current training desynchronizes model replicas.

Right now, each rank computes loss.backward() on its own chunk and steps the optimizer independently. There is no gradient synchronization across cp_group, so parameters on different ranks will diverge, which breaks the “single logical model with context parallelism” assumption.

Given cp_size == world_size and you already have cp_group, you should all-reduce gradients once per optimizer step (after gradient accumulation, before clipping/step). For example:

-            # Update weights
-            if (step + 1) % args.gradient_accumulation_steps == 0:
-                # Gradient clipping
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
-                
-                # Optimizer step
-                optimizer.step()
-                optimizer.zero_grad()
+            # Update weights
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                # All-reduce gradients across CP ranks so training matches single-GPU behavior
+                if cp_size > 1:
+                    for param in model.parameters():
+                        if param.grad is not None:
+                            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
+                            param.grad.div_(cp_size)
+
+                # Gradient clipping
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
+
+                # Optimizer step
+                optimizer.step()
+                optimizer.zero_grad()

Also consider updating the docstring in compute_loss_with_cp (“NO aggregation needed”) to clarify that explicit gradient aggregation happens in the training loop, not magically through autograd.


🏁 Script executed:

#!/bin/bash
# Sanity-check: search for any other places where gradients are synchronized
# to ensure this is the only spot responsible for CP grad aggregation.
rg -n "all_reduce|DistributedDataParallel|DDP" --type=py

Length of output: 726


I need to verify the current state of the training loop and the compute_loss_with_cp function to confirm whether gradient synchronization is actually missing. Let me examine the code more closely.


---

🏁 Script executed:

```shell
#!/bin/bash
# Check the compute_loss_with_cp function and its docstring
rg -n -A30 "def compute_loss_with_cp" --type=py

Length of output: 1984


🏁 Script executed:

#!/bin/bash
# Check exact lines around the training loop in train_cp_real.py
sed -n '395,415p' tests/cp/train_cp_real.py

Length of output: 725


🏁 Script executed:

#!/bin/bash
# Check if all_reduce is used anywhere in train_cp_real.py training section
rg -n "all_reduce|dist\." tests/cp/train_cp_real.py | head -20

Length of output: 298


🏁 Script executed:

#!/bin/bash
# Find where cp_group and cp_size are defined in the training file
rg -n "cp_group|cp_size" tests/cp/train_cp_real.py | head -20

Length of output: 1184


Let me verify if there's any other mechanism that might handle parameter synchronization:


---

🏁 Script executed:

```shell
#!/bin/bash
# Check model initialization to see if parameters are synchronized
rg -n -B5 -A10 "class GatedDeltaNetForCausalLMCP" --type=py

Length of output: 1538


🏁 Script executed:

#!/bin/bash
# Check forward pass to see if it does anything special
sed -n '200,230p' tests/cp/train_cp_real.py

Length of output: 1061


🏁 Script executed:

#!/bin/bash
# Look for any broadcast or synchronization in model code
rg -n "broadcast|set_devices|all_gather" --type=py tests/cp/

Length of output: 1538


🏁 Script executed:

#!/bin/bash
# Check if GatedDeltaNetForCausalLMCP is a custom CP-aware model or just uses standard parameters
fd -e py -x grep -l "class GatedDeltaNetForCausalLMCP" {} \;

Length of output: 57


Now let me check the exact gradient synchronization pattern used in the test:


Add gradient all-reduce across CP ranks – current training desynchronizes model replicas.

Right now, each rank computes loss.backward() on its own chunk and steps the optimizer independently. There is no gradient synchronization across cp_group, so parameters on different ranks will diverge, which breaks the "single logical model with context parallelism" assumption.

Given cp_size == world_size and you already have cp_group, you should all-reduce gradients once per optimizer step (after gradient accumulation, before clipping/step). For example:

-            # Update weights
-            if (step + 1) % args.gradient_accumulation_steps == 0:
-                # Gradient clipping
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
-                
-                # Optimizer step
-                optimizer.step()
-                optimizer.zero_grad()
+            # Update weights
+            if (step + 1) % args.gradient_accumulation_steps == 0:
+                # All-reduce gradients across CP ranks so training matches single-GPU behavior
+                if cp_size > 1:
+                    for param in model.parameters():
+                        if param.grad is not None:
+                            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
+                            param.grad.div_(cp_size)
+
+                # Gradient clipping
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
+
+                # Optimizer step
+                optimizer.step()
+                optimizer.zero_grad()

Also update the docstring in compute_loss_with_cp (line 188) from "NO aggregation needed - gradients flow naturally through autograd!" to clarify that explicit gradient aggregation happens in the training loop, not magically through autograd.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Training loop
model.train()
total_loss = 0.0
step = start_step
optimizer.zero_grad()
while step < args.num_steps:
for batch in dataloader:
if step >= args.num_steps:
break
# Debug: print chunk info (only rank 0, only first batch)
if rank == 0 and step == start_step:
print(f"📦 Batch info:")
print(f" input_ids shape: {batch['input_ids'].shape}")
print(f" Chunk range: [{batch['chunk_start']}:{batch['chunk_end']}]")
print(f" Full sequence length: {batch['full_seq_len']}\n")
# Forward and backward
loss = compute_loss_with_cp(
model=model,
batch=batch,
cp_rank=cp_rank,
cp_size=cp_size,
cp_group=cp_group,
device=device
)
# Scale loss for gradient accumulation
loss = loss / args.gradient_accumulation_steps
loss.backward()
total_loss += loss.item()
# Update weights
if (step + 1) % args.gradient_accumulation_steps == 0:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
# Optimizer step
optimizer.step()
optimizer.zero_grad()
# Logging (rank 0 only)
if rank == 0 and (step + 1) % 10 == 0:
avg_loss = total_loss * args.gradient_accumulation_steps / 10
print(f"Step {step + 1}/{args.num_steps} | Loss: {avg_loss:.4f}")
total_loss = 0.0
# Training loop
model.train()
total_loss = 0.0
step = start_step
optimizer.zero_grad()
while step < args.num_steps:
for batch in dataloader:
if step >= args.num_steps:
break
# Debug: print chunk info (only rank 0, only first batch)
if rank == 0 and step == start_step:
print(f"📦 Batch info:")
print(f" input_ids shape: {batch['input_ids'].shape}")
print(f" Chunk range: [{batch['chunk_start']}:{batch['chunk_end']}]")
print(f" Full sequence length: {batch['full_seq_len']}\n")
# Forward and backward
loss = compute_loss_with_cp(
model=model,
batch=batch,
cp_rank=cp_rank,
cp_size=cp_size,
cp_group=cp_group,
device=device
)
# Scale loss for gradient accumulation
loss = loss / args.gradient_accumulation_steps
loss.backward()
total_loss += loss.item()
# Update weights
if (step + 1) % args.gradient_accumulation_steps == 0:
# All-reduce gradients across CP ranks so training matches single-GPU behavior
if cp_size > 1:
for param in model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=cp_group)
param.grad.div_(cp_size)
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.max_grad_norm)
# Optimizer step
optimizer.step()
optimizer.zero_grad()
# Logging (rank 0 only)
if rank == 0 and (step + 1) % 10 == 0:
avg_loss = total_loss * args.gradient_accumulation_steps / 10
print(f"Step {step + 1}/{args.num_steps} | Loss: {avg_loss:.4f}")
total_loss = 0.0
🧰 Tools
🪛 Flake8 (7.3.0)

[error] 386-386: f-string is missing placeholders

(F541)

🪛 Ruff (0.14.5)

386-386: f-string without any placeholders

Remove extraneous f prefix

(F541)

🤖 Prompt for AI Agents
In tests/cp/train_cp_real.py around lines 373 to 421, the training loop calls
loss.backward() per CP rank but never synchronizes gradients across cp_group,
causing model replicas to diverge; after gradient accumulation completes and
before gradient clipping and optimizer.step(), perform an all-reduce across
cp_group over each parameter gradient (torch.distributed.all_reduce with SUM)
and divide by cp_size (or use AVG) so gradients are synchronized across CP
ranks; then continue with clip_grad_norm_ and optimizer.step()/zero_grad(). Also
update the docstring in compute_loss_with_cp (around line 188) to replace "NO
aggregation needed - gradients flow naturally through autograd!" with a note
that explicit gradient aggregation is performed in the training loop via
distributed all-reduce across cp_group.

Comment on lines +1 to +6
import pytest
import torch
import torch.distributed as dist
import torch.nn.functional as F
from dtest import DTest

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix dtest import to unblock CI

CI is currently failing with ModuleNotFoundError: No module named 'dtest' when importing this module. To avoid hard‑depending on dtest and to make the tests fail gracefully when it’s absent, consider guarding the import and skipping the module if dtest is not installed:

-import pytest
-import torch
-import torch.distributed as dist
-import torch.nn.functional as F
-from dtest import DTest
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+try:
+    from dtest import DTest
+except ModuleNotFoundError:
+    pytest.skip("dtest not available; skipping CP GatedDeltaNet tests", allow_module_level=True)

If dtest is expected to be available in CI, alternatively ensure it is installed or fix the import path instead, but in any case the current unconditional import is a blocker.

🤖 Prompt for AI Agents
In tests/ops/cp/test_gated_delta_cp_real.py lines 1-6, the unconditional import
of dtest causes CI failures when that package is missing; replace it with a
guarded import so the module is skipped when dtest is unavailable. Use
pytest.importorskip("dtest") (or a try/except ImportError that calls pytest.skip
at import time) and then reference DTest from the imported module, so tests fail
gracefully when dtest is not installed.

@sustcsonglin
Copy link
Collaborator

i don't have the permission to push to the head branch, feel free to merge into the main

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants