-
Notifications
You must be signed in to change notification settings - Fork 546
[Draft] TopK Fusion to JAX #2385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Ming Huang <[email protected]>
4512e87 to
59047c1
Compare
for more information, see https://pre-commit.ci
Greptile OverviewGreptile SummaryThis PR adds fused TopK router functionality to JAX, enabling efficient expert routing in Mixture-of-Experts models. The implementation provides both forward and backward passes with support for sigmoid and softmax score functions, grouped topk, and optional expert bias. Key changes:
Architecture: Confidence Score: 4/5
Important Files ChangedFile Analysis
Sequence DiagramsequenceDiagram
participant User
participant JAX API as router.py
participant Primitive as cpp_extensions/router.py
participant FFI as csrc/router.cpp
participant CUDA as nvte_fused_topk
User->>JAX API: fused_topk_with_score_function(logits, expert_bias, ...)
JAX API->>JAX API: map_score_function("sigmoid"→0)
JAX API->>JAX API: validate expert_bias with score_function
JAX API->>Primitive: _fused_topk_with_score_function()
Note over Primitive: Forward Pass
Primitive->>Primitive: FusedTopkWithScoreFunctionFwdPrimitive.bind()
Primitive->>FFI: FusedTopkWithScoreFunctionForwardFFI()
FFI->>FFI: Create TensorWrappers
FFI->>FFI: Handle optional expert_bias
FFI->>CUDA: nvte_fused_topk_with_score_function_forward()
CUDA-->>FFI: Return (probs, routing_map, intermediate_output)
FFI-->>Primitive: Return outputs
Primitive-->>JAX API: (probs, routing_map, intermediate_output)
JAX API-->>User: Return results
Note over User,CUDA: Backward Pass (during backprop)
User->>JAX API: Gradient computation triggered
JAX API->>Primitive: _fused_topk_bwd_rule(ctx, grads)
Primitive->>Primitive: FusedTopkWithScoreFunctionBwdPrimitive.bind()
Primitive->>FFI: FusedTopkWithScoreFunctionBackwardFFI()
FFI->>FFI: Create TensorWrappers for routing_map, intermediate_output
FFI->>CUDA: nvte_fused_topk_with_score_function_backward()
CUDA-->>FFI: Return grad_logits
FFI-->>Primitive: Return grad_logits
Primitive-->>JAX API: (grad_logits, None)
JAX API-->>User: Propagate gradients
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
6 files reviewed, 1 comment
Description
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: