parambole
Collaborator

TL;DR

  • What: This PR integrates the Qwen3-Next architecture into MaxText. Key features include a hybrid attention mechanism combining a Gated Delta Net (a form of linear attention) with standard attention, and an ultra-sparse MoE block that includes a shared expert.

  • How: It adds several new JAX layers (Qwen3NextGatedDeltaNet, Qwen3NextSparseMoeBlock, Qwen3NextRMSNorm), a hybrid Qwen3NextDecoderLayer that alternates between attention types based on full_attention_interval, and a new test suite that validates each new component against a PyTorch reference.


Detailed Description

This pull request introduces an implementation of the Qwen3-Next architecture, as described in the official Qwen3-Next blog post.

Architectural Implementation (src/MaxText/layers/qwen3.py)

  • Hybrid Attention: The core of this architecture is its hybrid attention system. This PR implements:

    • Qwen3NextGatedDeltaNet: A full JAX implementation of the Gated Delta Net for efficient linear attention. The core recurrence is handled by the pure JAX jax_chunk_gated_delta_rule, a parallel scan algorithm.

    • Qwen3NextFullAttention: A placeholder for the model's standard attention layers.

  • Custom Normalization: A custom Qwen3NextRMSNorm layer is included where the learnable weight is initialized to zeros and applied as (1.0 + weight), matching the reference implementation's specific behavior.
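The (1.0 + weight) scaling described above can be illustrated with a minimal functional sketch. This is not the Qwen3NextRMSNorm class from the PR; the function names, the eps default, and the float32 upcast are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def init_rms_norm_weight(dim: int) -> jax.Array:
    """Zero-initialized weight, so the effective scale (1 + weight) starts at 1."""
    return jnp.zeros((dim,), dtype=jnp.float32)


def rms_norm(x: jax.Array, weight: jax.Array, eps: float = 1e-6) -> jax.Array:
    """RMS-normalize over the last axis, then scale by (1 + weight).

    eps and the float32 upcast are assumed defaults, not values from the PR.
    """
    x32 = x.astype(jnp.float32)
    mean_square = jnp.mean(jnp.square(x32), axis=-1, keepdims=True)
    normed = x32 * jax.lax.rsqrt(mean_square + eps)
    return (normed * (1.0 + weight)).astype(x.dtype)


x = jnp.array([[3.0, 4.0], [1.0, -1.0]])
y = rms_norm(x, init_rms_norm_weight(2))
```

With the weight at its zero initialization the layer reduces to plain RMS normalization, so each row of y has a root-mean-square close to 1.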


Decoder and Model Integration

  • Qwen3NextDecoderLayer: This new decoder layer acts as the hybrid controller. Based on its layer_idx and the full_attention_interval config parameter, it conditionally invokes either the Qwen3NextGatedDeltaNet or the Qwen3NextFullAttention layer before passing the result to the Qwen3NextSparseMoeBlock.

  • Scanned Layers: The implementation also introduces a Qwen3NextScannableBlock.

  • Configuration: A new model config qwen3-next-80b-a3b.yml is added, along with the necessary parameters in base.yml to control the new architectural features.
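The layer selection described for Qwen3NextDecoderLayer can be sketched as follows. The exact rule is not stated in this description, so the convention below (0-based layer_idx, with the last layer of every group of full_attention_interval layers using full attention) is an assumption, and the helper names and the example values 8 and 4 are hypothetical.

```python
def is_full_attention_layer(layer_idx: int, full_attention_interval: int) -> bool:
    """Assumed convention: every full_attention_interval-th layer uses full attention."""
    return (layer_idx + 1) % full_attention_interval == 0


def layer_types(num_layers: int, full_attention_interval: int) -> list[str]:
    """Return the attention type chosen for each decoder layer index."""
    return [
        "full_attention"
        if is_full_attention_layer(i, full_attention_interval)
        else "gated_delta_net"
        for i in range(num_layers)
    ]


# Illustrative values only: 8 layers with an interval of 4.
pattern = layer_types(8, 4)
```

Under this assumed convention, layers 3 and 7 use full attention and the remaining six use the Gated Delta Net.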


Testing and Validation

This PR includes a new test file, tests/check_qwen3_next_vs_reference.py.

  • It directly copies the reference PyTorch implementations for each new component.

  • It provides granular unit tests that validate the numerical output of each JAX module (Qwen3NextGatedDeltaNet, Qwen3NextSparseMoeBlock, etc.) against its corresponding PyTorch reference.
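The comparison pattern described above can be sketched roughly as below. This is not code from tests/check_qwen3_next_vs_reference.py: a NumPy function stands in for the PyTorch reference so the example has no torch dependency, and the helper name, the SiLU example, and the tolerances are all assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def assert_matches_reference(jax_fn, ref_fn, x, rtol=1e-5, atol=1e-5):
    """Run both implementations on the same input and compare the outputs.

    Returns the maximum absolute difference for reporting.
    """
    actual = np.asarray(jax_fn(jnp.asarray(x)))
    expected = np.asarray(ref_fn(x))
    np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol)
    return float(np.max(np.abs(actual - expected)))


def reference_silu(x):
    """NumPy stand-in for a reference implementation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8)).astype(np.float32)
max_err = assert_matches_reference(jax.nn.silu, reference_silu, x)
```

Feeding identical inputs (and, for parameterized modules, identical weights) to both sides is what makes a tolerance-based comparison meaningful.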

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@parambole parambole force-pushed the parambole/maxtext_qwen3_next_v1 branch from 5ebac43 to f1f22d2 Compare October 11, 2025 00:38
@parambole parambole force-pushed the parambole/maxtext_qwen3_next_v1 branch from 8553570 to 8d642ab Compare October 13, 2025 18:35
@parambole parambole requested a review from bvandermoon October 13, 2025 22:36
🤖 Hi @parambole, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

8 participants