
Conversation

@yzhangcs yzhangcs (Member) commented Jul 2, 2025

Summary by CodeRabbit

  • New Features
    • Improved performance for certain operations by enabling a fused processing path with a default split size.
  • Refactor
    • Updated internal logic to support an optional split size parameter for enhanced flexibility.

@coderabbitai coderabbitai bot (Contributor) commented Jul 2, 2025

Walkthrough

The chunk_simple_gla_fwd function was updated to accept an optional split_size parameter, enabling a fused forward path when specified. The ChunkSimpleGLAFunction.forward method now defaults to using this fused path with split_size=128. Supporting imports were added, with no changes to backward logic.

Changes

File(s): fla/ops/simple_gla/chunk.py
Change Summary: Added a split_size parameter to chunk_simple_gla_fwd, enabled the fused path, updated the forward method, and added supporting imports.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant ChunkSimpleGLAFunction
    participant chunk_simple_gla_fwd
    participant fused_chunk_fwd
    participant chunk_fwd_o

    Caller->>ChunkSimpleGLAFunction: forward(q, k, v, ...)
    ChunkSimpleGLAFunction->>chunk_simple_gla_fwd: call with split_size=128
    alt split_size is not None
        chunk_simple_gla_fwd->>fused_chunk_fwd: call with prepared cu_seqlens
    else split_size is None
        chunk_simple_gla_fwd->>chunk_fwd_o: call as before
    end
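
For intuition, the sketch below shows what split-level cu_seqlens look like for an equal-length batch. This is an illustrative, self-contained re-implementation, not fla's prepare_split_cu_seqlens (which also handles variable-length inputs); the function name is made up for the example.

import torch

def split_cu_seqlens_sketch(batch_size: int, seq_len: int, split_size: int) -> torch.Tensor:
    # Build cumulative boundaries at split_size granularity for a batch of
    # equal-length sequences laid out back-to-back (varlen-style).
    bounds = []
    for b in range(batch_size):
        start = b * seq_len
        bounds.extend(range(start, start + seq_len, split_size))
    bounds.append(batch_size * seq_len)
    return torch.tensor(bounds, dtype=torch.int32)

# Two sequences of length 256 with split_size=128 yield
# tensor([  0, 128, 256, 384, 512], dtype=torch.int32)
print(split_cu_seqlens_sketch(2, 256, 128))

Presumably the fused kernel then processes each split as an independent segment seeded by the precomputed chunk-level states, which is what enables the chunk-parallel execution this PR targets.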


Poem

A hop and a split in the GLA chunk’s path,
Now fused and swift, it computes in a flash.
With split_size in tow, the tensors align,
Forward we leap, in code so fine.
🐇✨
— A rabbit, in awe of the fused design

Warning

There were issues while running some tools. Please review the errors and either fix the tool's configuration or disable the tool if it's a critical failure.

🔧 Pylint (3.3.7)
fla/ops/simple_gla/chunk.py

@cursor cursor bot left a comment

Bug

The new fused_chunk_fwd path, triggered by the split_size parameter, introduces several issues:

  • Forward-Backward Mismatch: The forward pass modifies cu_seqlens when split_size is used, but the original cu_seqlens is saved to the autograd context, causing inconsistency in the backward pass. Additionally, the backward pass does not receive or handle the split_size parameter, leading to an asymmetry in computation paths.
  • Incorrect initial_state: fused_chunk_fwd is incorrectly passed h (intermediate hidden states from chunk_fwd_h) as its initial_state instead of the true initial_state parameter.
  • Discarded Final State: The second return value of fused_chunk_fwd (potentially an updated final state) is discarded, which may lead to an inconsistent final state (ht) being returned compared to the non-fused path.

fla/ops/simple_gla/chunk.py#L45-L57

if split_size is not None:
    cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
    o, _ = fused_chunk_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=h,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

fla/ops/simple_gla/chunk.py#L145-L181

@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
    ctx,
    q,
    k,
    v,
    g,
    g_gamma,
    scale,
    initial_state,
    output_final_state,
    cu_seqlens
):
    T = q.shape[1]
    chunk_size = min(64, max(16, triton.next_power_of_2(T)))
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) if g is not None else None
    o, ht = chunk_simple_gla_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        split_size=128
    )
    ctx.save_for_backward(q, k, v, g, g_gamma, initial_state)
    ctx.chunk_size = chunk_size
    ctx.scale = scale
    ctx.cu_seqlens = cu_seqlens
    return o.to(q.dtype), ht

Bug: Hardcoded Split Size Breaks Autograd Function Compatibility

The split_size parameter is hardcoded to 128 in ChunkSimpleGLAFunction.forward. This forces the use of fused_chunk_fwd for all calls, rendering the chunk_fwd_o code path unreachable. This change breaks backward compatibility and alters the computation behavior for existing users. The split_size should be configurable or default to None to maintain compatibility.

fla/ops/simple_gla/chunk.py#L174-L175

chunk_size=chunk_size,
split_size=128
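
A minimal sketch of what the suggested fix might look like, mirroring the forward() excerpt quoted in the previous comment; this is an assumed change for illustration, not the authors' implementation:

@staticmethod
@input_guard
@autocast_custom_fwd
def forward(ctx, q, k, v, g, g_gamma, scale, initial_state,
            output_final_state, cu_seqlens, split_size=None):
    # split_size is now a regular argument defaulting to None, so the original
    # chunk_fwd_o path stays reachable and callers can opt into the fused path.
    T = q.shape[1]
    chunk_size = min(64, max(16, triton.next_power_of_2(T)))
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) if g is not None else None
    o, ht = chunk_simple_gla_fwd(
        q=q, k=k, v=v, g=g, g_gamma=g_gamma, scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        split_size=split_size,  # None restores the pre-PR behavior; e.g. 128 enables the fused path
    )
    ctx.save_for_backward(q, k, v, g, g_gamma, initial_state)
    ctx.chunk_size = chunk_size
    ctx.scale = scale
    ctx.cu_seqlens = cu_seqlens
    return o.to(q.dtype), ht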


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 18eab2d and acd6b31.

📒 Files selected for processing (1)
  • fla/ops/simple_gla/chunk.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
  • GitHub Check: Cursor BugBot
  • GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (4)
fla/ops/simple_gla/chunk.py (4)

12-12: LGTM! Clean import additions for fused chunk functionality.

The imports for fused_chunk_fwd and prepare_split_cu_seqlens are correctly added to support the new fused forward path implementation.

Also applies to: 14-14


28-29: LGTM! Backward-compatible function signature extension.

The addition of chunk_size and split_size parameters maintains backward compatibility with appropriate default values.
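
For reference, the extended signature presumably looks roughly like the sketch below. It is reconstructed from the call site in ChunkSimpleGLAFunction.forward quoted elsewhere in this thread; the type annotations and default values are assumptions rather than the actual code.

from typing import Optional

import torch

def chunk_simple_gla_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    g_gamma: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,              # assumed default
    split_size: Optional[int] = None,  # new: None keeps the original chunk_fwd_o path
):
    ...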


174-175: No functional impact from default fused path

Existing tests in tests/ops/test_simple_gla.py cover all chunk_simple_gla code paths and pass without any explicit split_size usage, confirming that changing the default to 128 does not break correctness.

Recommendations (non‐blocking):

  • Document the performance difference introduced by the fused path default in the changelog or user guide.
  • If desired, expose split_size as an optional argument in a future API version to allow users to fall back to the previous behavior.

42-43: Confirmed: chunk_fwd_h supports both chunk_size and split_size.

The chunk_fwd_h imported in fla/ops/simple_gla/chunk.py comes from fla/ops/common/chunk_h.py, whose signature includes:

def chunk_fwd_h(
    …,
    chunk_size: int = 64,
    split_size: Optional[int] = None,
    …
) -> Tuple[torch.Tensor, torch.Tensor]:
    …

No changes required.

Comment on lines +45 to +69
if split_size is not None:
    cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
    o, _ = fused_chunk_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=h,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
else:
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )

⚠️ Potential issue

Potential backward compatibility issue with gradient computation.

The conditional logic correctly implements the fused vs. original forward paths. However, there's a potential issue: the backward method in ChunkSimpleGLAFunction (lines 186-206) only uses chunk_simple_gla_bwd, which doesn't account for the fused forward path. This mismatch could cause gradient computation errors when split_size is provided.

Consider one of these solutions:

  1. Update the backward method to handle both forward paths
  2. Ensure chunk_simple_gla_bwd is compatible with both forward implementations
  3. Add a flag to track which forward path was used (a rough sketch follows the shell snippet below)
#!/bin/bash
# Description: Check if there's a corresponding fused backward implementation
# Expected: Look for fused_chunk_bwd or similar backward functions

rg -A 5 "fused.*bwd|fused.*backward" --type py
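
A rough sketch of option 3, assuming the dispatch decision is stored on the autograd context; this is illustrative only, not the authors' code:

# in ChunkSimpleGLAFunction.forward, alongside the other ctx attributes:
ctx.split_size = split_size   # None => the original chunk_fwd_o path was used

# in ChunkSimpleGLAFunction.backward, before calling chunk_simple_gla_bwd:
if ctx.split_size is not None:
    # The fused forward was taken; chunk_simple_gla_bwd is only safe to reuse here
    # if the two forward paths are numerically equivalent.
    raise NotImplementedError("no dedicated backward for the fused forward path yet")
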
🤖 Prompt for AI Agents
In fla/ops/simple_gla/chunk.py around lines 45 to 69, the forward method
conditionally calls either fused_chunk_fwd or chunk_fwd_o based on split_size,
but the backward method only calls chunk_simple_gla_bwd, which does not handle
the fused forward path. To fix this, modify the backward method to detect which
forward path was used (e.g., by adding a flag or checking inputs) and call the
corresponding backward function for fused_chunk_fwd or chunk_fwd_o accordingly,
ensuring correct gradient computation for both cases.

@yzhangcs yzhangcs (Member, Author) commented Jul 2, 2025

@yzhangcs This PR is not ready for merging, as we have not observed any acceleration yet.

@yzhangcs yzhangcs changed the title from "Support chunk-parallel via combined chunk and fused_chunk" to "[WIP] Support chunk-parallel via combined chunk and fused_chunk" on Jul 2, 2025
@yzhangcs yzhangcs linked an issue Jul 4, 2025 that may be closed by this pull request
@zhiyuan1i zhiyuan1i force-pushed the main branch 3 times, most recently from a4e1a1c to 90a0fea on October 27, 2025 06:43
@sustcsonglin sustcsonglin force-pushed the main branch 2 times, most recently from 1700f8d to f4082b3 on November 11, 2025 16:49


Development

Successfully merging this pull request may close these issues.

[RFC] Unifying chunk and fused_chunk mode
