[WIP] Support chunk-parallel via combined chunk and fused_chunk
#505
base: main
Conversation
Walkthrough
Sequence Diagram(s)
sequenceDiagram
participant Caller
participant ChunkSimpleGLAFunction
participant chunk_simple_gla_fwd
participant fused_chunk_fwd
participant chunk_fwd_o
Caller->>ChunkSimpleGLAFunction: forward(q, k, v, ..., split_size=128)
ChunkSimpleGLAFunction->>chunk_simple_gla_fwd: call with split_size=128
alt split_size is not None
chunk_simple_gla_fwd->>fused_chunk_fwd: call with prepared cu_seqlens
else split_size is None
chunk_simple_gla_fwd->>chunk_fwd_o: call as before
end
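For orientation, the dispatch the diagram describes can be sketched in plain Python. Everything suffixed `_stub` below is a hypothetical placeholder made up for this sketch; the real implementations are Triton-backed kernels in fla/ops/simple_gla/chunk.py.

```python
# Plain-Python stand-ins sketching the dispatch shown in the diagram above.
from typing import Optional


def _prepare_split_cu_seqlens_stub(batch, seqlen, split_size, cu_seqlens):
    # Placeholder: the real prepare_split_cu_seqlens presumably rebuilds
    # cu_seqlens so each sequence is split into split_size-long segments.
    return cu_seqlens


def _fused_chunk_fwd_stub(q, cu_seqlens, chunk_size):
    # Placeholder for the fused-kernel forward path.
    return q


def _chunk_fwd_o_stub(q, h, cu_seqlens, chunk_size):
    # Placeholder for the original per-chunk output path.
    return q


def chunk_simple_gla_fwd_sketch(q, h=None, cu_seqlens=None,
                                chunk_size: int = 64,
                                split_size: Optional[int] = None):
    if split_size is not None:
        # New path added by this PR: re-split sequences, then call the fused kernel.
        cu_seqlens = _prepare_split_cu_seqlens_stub(len(q), len(q[0]), split_size, cu_seqlens)
        return _fused_chunk_fwd_stub(q, cu_seqlens, chunk_size)
    # Existing path: per-chunk output computed from precomputed hidden states h.
    return _chunk_fwd_o_stub(q, h, cu_seqlens, chunk_size)
```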
Bug
The fused_chunk_fwd path, triggered by the split_size parameter, introduces several issues:
- Forward-Backward Mismatch: The forward pass modifies cu_seqlens when split_size is used, but the original cu_seqlens is saved to the autograd context, causing inconsistency in the backward pass. Additionally, the backward pass does not receive or handle the split_size parameter, leading to an asymmetry in computation paths.
- Incorrect initial_state: fused_chunk_fwd is incorrectly passed h (the intermediate hidden states from chunk_fwd_h) as its initial_state instead of the true initial_state parameter.
- Discarded Final State: The second return value of fused_chunk_fwd (potentially an updated final state) is discarded, which may lead to an inconsistent final state (ht) being returned compared to the non-fused path.

A hedged sketch of a possible adjustment follows the code excerpt below.
fla/ops/simple_gla/chunk.py, lines 45 to 57 in acd6b31
if split_size is not None:
    cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
    o, _ = fused_chunk_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=h,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
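A minimal sketch of how the fused branch might address the second and third points: pass the true initial_state and keep the returned state. This is illustrative only, assumes initial_state is in scope inside chunk_simple_gla_fwd, and assumes fused_chunk_fwd's second return value is the final state; the cu_seqlens/split_size asymmetry with the backward pass (the first point) is not addressed here.

```python
# Illustrative sketch, not the PR's code. Assumes `initial_state` is the function's
# own argument and that fused_chunk_fwd's second return value is the final state.
if split_size is not None:
    cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
    o, ht = fused_chunk_fwd(          # keep the final state instead of discarding it
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=initial_state,  # the true initial state, not the intermediate h
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
```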
fla/ops/simple_gla/chunk.py, lines 145 to 181 in acd6b31
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
    ctx,
    q,
    k,
    v,
    g,
    g_gamma,
    scale,
    initial_state,
    output_final_state,
    cu_seqlens
):
    T = q.shape[1]
    chunk_size = min(64, max(16, triton.next_power_of_2(T)))
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) if g is not None else None
    o, ht = chunk_simple_gla_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        split_size=128
    )
    ctx.save_for_backward(q, k, v, g, g_gamma, initial_state)
    ctx.chunk_size = chunk_size
    ctx.scale = scale
    ctx.cu_seqlens = cu_seqlens
    return o.to(q.dtype), ht
Bug: Hardcoded Split Size Breaks Autograd Function Compatibility
The split_size parameter is hardcoded to 128 in ChunkSimpleGLAFunction.forward, where it calls chunk_simple_gla_fwd. This forces the use of fused_chunk_fwd for all calls, rendering the chunk_fwd_o code path unreachable. This change breaks backward compatibility and alters the computation behavior for existing users. The split_size should be configurable or default to None to maintain compatibility.
fla/ops/simple_gla/chunk.py, lines 174 to 175 in acd6b31
chunk_size=chunk_size,
split_size=128
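One way to keep the original path reachable, sketched under the assumption that ChunkSimpleGLAFunction.forward (and the public wrapper above it) could grow a split_size keyword defaulting to None; this is an illustration, not the PR's code.

```python
# Illustrative only: thread a configurable split_size through instead of hardcoding 128.
# `split_size` here is a hypothetical extra argument of ChunkSimpleGLAFunction.forward,
# defaulting to None so the original chunk_fwd_o path remains the default behavior.
o, ht = chunk_simple_gla_fwd(
    q=q,
    k=k,
    v=v,
    g=g,
    g_gamma=g_gamma,
    scale=scale,
    initial_state=initial_state,
    output_final_state=output_final_state,
    cu_seqlens=cu_seqlens,
    chunk_size=chunk_size,
    split_size=split_size
)
```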
Actionable comments posted: 1
📒 Files selected for processing (1)
fla/ops/simple_gla/chunk.py (4 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (2)
- GitHub Check: Cursor BugBot
- GitHub Check: Test H100 (PyTorch 2.7) / test-ops
🔇 Additional comments (4)
fla/ops/simple_gla/chunk.py (4)
12-12: LGTM! Clean import additions for fused chunk functionality. The imports for fused_chunk_fwd and prepare_split_cu_seqlens are correctly added to support the new fused forward path implementation. Also applies to: 14-14
28-29: LGTM! Backward-compatible function signature extension. The addition of chunk_size and split_size parameters maintains backward compatibility with appropriate default values.
174-175: No functional impact from default fused path. Existing tests in tests/ops/test_simple_gla.py cover all chunk_simple_gla code paths and pass without any explicit split_size usage, confirming that changing the default to 128 does not break correctness.
Recommendations (non-blocking):
- Document the performance difference introduced by the fused path default in the changelog or user guide.
- If desired, expose split_size as an optional argument in a future API version to allow users to fall back to the previous behavior.
42-43: Confirmed: chunk_fwd_h supports both chunk_size and split_size. The chunk_fwd_h imported in fla/ops/simple_gla/chunk.py comes from fla/ops/common/chunk_h.py, whose signature includes:
def chunk_fwd_h(
    …,
    chunk_size: int = 64,
    split_size: Optional[int] = None,
    …
) -> Tuple[torch.Tensor, torch.Tensor]:
    …
No changes required.
if split_size is not None:
    cu_seqlens = prepare_split_cu_seqlens(*q.shape[:2], split_size, cu_seqlens, device=q.device)
    o, _ = fused_chunk_fwd(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        scale=scale,
        initial_state=h,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
else:
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v,
        g=g,
        g_gamma=g_gamma,
        h=h,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size
    )
Potential forward/backward mismatch in gradient computation.
The conditional logic correctly implements the fused vs. original forward paths. However, there's a potential issue: the backward method in ChunkSimpleGLAFunction (lines 186-206) only uses chunk_simple_gla_bwd, which doesn't account for the fused forward path. This mismatch could cause gradient computation errors when split_size is provided.
Consider one of these solutions:
- Update the backward method to handle both forward paths
- Ensure chunk_simple_gla_bwd is compatible with both forward implementations
- Add a flag to track which forward path was used (a toy sketch of this option follows the prompt below)
#!/bin/bash
# Description: Check if there's a corresponding fused backward implementation
# Expected: Look for fused_chunk_bwd or similar backward functions
rg -A 5 "fused.*bwd|fused.*backward" --type py🤖 Prompt for AI Agents
In fla/ops/simple_gla/chunk.py around lines 45 to 69, the forward method
conditionally calls either fused_chunk_fwd or chunk_fwd_o based on split_size,
but the backward method only calls chunk_simple_gla_bwd, which does not handle
the fused forward path. To fix this, modify the backward method to detect which
forward path was used (e.g., by adding a flag or checking inputs) and call the
corresponding backward function for fused_chunk_fwd or chunk_fwd_o accordingly,
ensuring correct gradient computation for both cases.
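As a toy illustration of the third option (tracking which forward path ran), detached from the FLA code itself: ToyDualPathFunction and everything in it are made up for this sketch, and a real fix would still need a backward implementation that numerically matches the fused forward.

```python
import torch


class ToyDualPathFunction(torch.autograd.Function):
    """Toy illustration (not FLA code) of recording which forward path ran
    so that backward stays consistent with it."""

    @staticmethod
    def forward(ctx, x, split_size=None):
        ctx.used_fused_path = split_size is not None  # non-tensor flags go directly on ctx
        ctx.save_for_backward(x)
        # Both branches compute the same toy result (x * 2) here; in the real
        # kernel the fused and non-fused paths would differ in implementation.
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        if ctx.used_fused_path:
            # Dispatch to a backward that matches the fused forward.
            grad_x = grad_out * 2
        else:
            # Backward matching the original (non-fused) forward.
            grad_x = grad_out * 2
        return grad_x, None  # no gradient for split_size


# Usage: the saved flag keeps forward and backward on matching paths.
x = torch.randn(4, requires_grad=True)
ToyDualPathFunction.apply(x, 128).sum().backward()
```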
@yzhangcs This PR is not ready for merging, as we have not observed any acceleration yet.
a4e1a1c to 90a0fea (Compare)
1700f8d to f4082b3 (Compare)