Conversation

@mingxu1067 mingxu1067 commented Nov 14, 2025

Description

  • Adding TopK fusion to JAX for both forward and backward.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Adding TopK fusion to JAX for both forward and backward.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@mingxu1067 mingxu1067 requested a review from phu0ngng November 14, 2025 16:16
@mingxu1067 mingxu1067 changed the title from "TopK Fusion to JAX" to "[Draft] TopK Fusion to JAX" on Nov 14, 2025
@mingxu1067 mingxu1067 force-pushed the mingh/router_fusion_to_jax branch from 4512e87 to 59047c1 on November 14, 2025 16:18
greptile-apps bot commented Nov 14, 2025

Greptile Overview

Greptile Summary

This PR adds a fused TopK router to JAX, enabling efficient expert routing in Mixture-of-Experts models. The implementation provides both forward and backward passes, with support for sigmoid and softmax score functions, grouped TopK, and an optional expert bias.

Key changes:

  • Added transformer_engine/jax/router.py with public API and custom VJP implementation
  • Created JAX primitives in cpp_extensions/router.py with complete sharding rules for distributed training
  • Implemented C++ FFI bindings in csrc/extensions/router.cpp that interface with existing CUDA kernels
  • Properly handles optional expert_bias parameter (only supported with sigmoid score function)
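To make the bullet points above concrete, here is a rough, unfused NumPy sketch of what a sigmoid-scored TopK router with an optional selection bias computes. This is an illustration of the semantics described in the summary, not the kernel's exact contract (for example, whether returned probabilities are renormalized over the selected experts is not shown); the function name is hypothetical.

```python
import numpy as np

def topk_router_reference(logits, k, expert_bias=None):
    """Hypothetical unfused reference for a sigmoid-scored TopK router."""
    scores = 1.0 / (1.0 + np.exp(-logits))  # sigmoid score function
    # Per the PR, expert_bias is only supported with sigmoid; it influences
    # which experts are selected, not the probabilities that are returned.
    selection_scores = scores + expert_bias if expert_bias is not None else scores
    # Indices of the k largest selection scores per token.
    topk_idx = np.argsort(-selection_scores, axis=-1)[..., :k]
    routing_map = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(routing_map, topk_idx, True, axis=-1)
    # Zero out the scores of unselected experts.
    probs = np.where(routing_map, scores, 0.0)
    return probs, routing_map
```

The fused kernel performs this scoring, biasing, and masking in a single pass instead of materializing each intermediate.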

Architecture:
The implementation follows TransformerEngine's established layering: high-level JAX API → JAX primitives with sharding rules → C++ FFI bindings → existing CUDA kernels. The forward pass returns probabilities, the routing map, and intermediate outputs needed for gradient computation; the backward pass uses the saved intermediates to compute gradients efficiently.
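The forward/backward wiring described above can be sketched with `jax.custom_vjp`. The sketch below uses plain JAX functions (`_fwd_primitive`, `_bwd_primitive`) as stand-ins for the real TransformerEngine primitive bindings, so it runs anywhere; all names are illustrative, not the PR's actual API.

```python
import jax
import jax.numpy as jnp

def _fwd_primitive(logits, k=2):
    # Stand-in for the fused forward primitive: sigmoid scores, top-k mask.
    scores = jax.nn.sigmoid(logits)
    _, idx = jax.lax.top_k(scores, k)
    routing_map = jax.nn.one_hot(idx, logits.shape[-1]).sum(axis=-2) > 0
    probs = jnp.where(routing_map, scores, 0.0)
    intermediate = scores  # saved for the backward pass
    return probs, routing_map, intermediate

def _bwd_primitive(grad_probs, routing_map, scores):
    # probs = mask * sigmoid(logits), so dprobs/dlogits = mask * s * (1 - s).
    return jnp.where(routing_map, grad_probs * scores * (1.0 - scores), 0.0)

@jax.custom_vjp
def fused_topk(logits):
    probs, _, _ = _fwd_primitive(logits)
    return probs

def _fused_topk_fwd(logits):
    probs, routing_map, intermediate = _fwd_primitive(logits)
    return probs, (routing_map, intermediate)  # residuals for backward

def _fused_topk_bwd(residuals, grad_probs):
    routing_map, intermediate = residuals
    return (_bwd_primitive(grad_probs, routing_map, intermediate),)

fused_topk.defvjp(_fused_topk_fwd, _fused_topk_bwd)
```

The key point mirrored from the PR: the forward function returns the routing map and intermediate outputs as residuals, and the backward rule consumes them instead of recomputing the scoring.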

Confidence Score: 4/5

  • This PR is safe to merge with minor considerations
  • The implementation follows established patterns in the codebase and properly integrates with existing infrastructure. The C++ bindings correctly handle optional parameters and tensor conversions. However, the PR is marked as [Draft] and lacks tests, which prevents a perfect score. The code quality is high with proper error handling and validation.
  • transformer_engine/jax/router.py and transformer_engine/jax/cpp_extensions/router.py - verify gradient correctness when tests are added
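The gradient-correctness check requested above can use JAX's built-in finite-difference checker once tests exist. The sketch checks a differentiable stand-in so it is runnable here; real tests would import the fused op from this PR instead. Note that TopK routing is only piecewise differentiable, so finite differences are valid only away from top-k ties, and test logits should be well separated.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

def routed_probs(logits, k=2):
    # Stand-in: sum of the k largest softmax probabilities per token.
    scores = jax.nn.softmax(logits, axis=-1)
    top_vals, _ = jax.lax.top_k(scores, k)
    return top_vals.sum(axis=-1)

# Well-separated logits keep the top-k selection stable under perturbation.
logits = jnp.array([[2.0, -1.0, 0.5, 1.0]])
check_grads(routed_probs, (logits,), order=1, modes=["rev"], atol=1e-2, rtol=1e-2)
```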

Important Files Changed

File Analysis

Filename | Score | Overview
--- | --- | ---
transformer_engine/jax/router.py | 4/5 | Public API for the fused TopK router, with custom VJP implementation and parameter validation
transformer_engine/jax/cpp_extensions/router.py | 4/5 | JAX primitive definitions for the forward and backward passes, with sharding rules
transformer_engine/jax/csrc/extensions/router.cpp | 5/5 | C++ FFI bindings that handle the optional expert_bias and call the underlying CUDA kernels
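On the sharding rules mentioned for cpp_extensions/router.py: TopK routing is independent per token, so the token (batch/sequence) axis of the logits can be partitioned across devices while the expert axis stays replicated, and a primitive's sharding rule would declare the same layout for its outputs. A hypothetical illustration with `jax.sharding` (not the PR's actual rule definitions):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D device mesh over a data-parallel axis (works on a single device too).
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

logits = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)  # (tokens, experts)
# Shard tokens across "dp"; replicate the expert dimension.
logits = jax.device_put(logits, NamedSharding(mesh, P("dp", None)))

# A per-token op (here softmax as a stand-in for routing) preserves the layout.
probs = jax.nn.softmax(logits, axis=-1)
```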

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX API as router.py
    participant Primitive as cpp_extensions/router.py
    participant FFI as csrc/router.cpp
    participant CUDA as nvte_fused_topk

    User->>JAX API: fused_topk_with_score_function(logits, expert_bias, ...)
    JAX API->>JAX API: map_score_function("sigmoid"→0)
    JAX API->>JAX API: validate expert_bias with score_function
    JAX API->>Primitive: _fused_topk_with_score_function()
    
    Note over Primitive: Forward Pass
    Primitive->>Primitive: FusedTopkWithScoreFunctionFwdPrimitive.bind()
    Primitive->>FFI: FusedTopkWithScoreFunctionForwardFFI()
    FFI->>FFI: Create TensorWrappers
    FFI->>FFI: Handle optional expert_bias
    FFI->>CUDA: nvte_fused_topk_with_score_function_forward()
    CUDA-->>FFI: Return (probs, routing_map, intermediate_output)
    FFI-->>Primitive: Return outputs
    Primitive-->>JAX API: (probs, routing_map, intermediate_output)
    JAX API-->>User: Return results
    
    Note over User,CUDA: Backward Pass (during backprop)
    User->>JAX API: Gradient computation triggered
    JAX API->>Primitive: _fused_topk_bwd_rule(ctx, grads)
    Primitive->>Primitive: FusedTopkWithScoreFunctionBwdPrimitive.bind()
    Primitive->>FFI: FusedTopkWithScoreFunctionBackwardFFI()
    FFI->>FFI: Create TensorWrappers for routing_map, intermediate_output
    FFI->>CUDA: nvte_fused_topk_with_score_function_backward()
    CUDA-->>FFI: Return grad_logits
    FFI-->>Primitive: Return grad_logits
    Primitive-->>JAX API: (grad_logits, None)
    JAX API-->>User: Propagate gradients


@greptile-apps greptile-apps bot left a comment


6 files reviewed, 1 comment

