Dense Qwen3 support (0.6b, 4b, 8b) #1858

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions MaxText/common_types.py
@@ -77,6 +77,7 @@ class DecoderBlockType(enum.Enum):
GEMMA = "gemma"
GEMMA2 = "gemma2"
GEMMA3 = "gemma3"
QWEN3 = "qwen3"
GPT3 = "gpt3"
SIMPLE = "simple"
SIMPLE_MLP = "simple_mlp"
41 changes: 41 additions & 0 deletions MaxText/configs/models/qwen3-0.6b.yml
@@ -0,0 +1,41 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-0.6b

base_emb_dim: 1024
base_num_query_heads: 16
base_num_kv_heads: 8
base_mlp_dim: 3072
base_num_decoder_layers: 28
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True # Qwen3 models use QK Normalization

logits_via_embedding: True # from "tie_word_embeddings": true
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"

dtype: "bfloat16"
weight_dtype: "bfloat16"

41 changes: 41 additions & 0 deletions MaxText/configs/models/qwen3-4b.yml
@@ -0,0 +1,41 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-4b

base_emb_dim: 2560
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 9728
base_num_decoder_layers: 36
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True # Qwen3 models use QK Normalization

logits_via_embedding: True # from "tie_word_embeddings": true
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"

dtype: "bfloat16"
weight_dtype: "bfloat16"

40 changes: 40 additions & 0 deletions MaxText/configs/models/qwen3-8b.yml
@@ -0,0 +1,40 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# model config for qwen3-8b

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_mlp_dim: 12288
base_num_decoder_layers: 36
head_dim: 128
mlp_activations: ["silu", "linear"] # "hidden_act": "silu" implies SwiGLU
vocab_size: 151936

decoder_block: "qwen3"

normalization_layer_epsilon: 1.0e-6
rope_max_timescale: 1000000

use_qk_norm: True # Qwen3 models use QK Normalization

logits_via_embedding: False # different from smaller variants, "tie_word_embeddings": false
normalize_embedding_logits: False
enable_dropout: False # deterministic for testing

tokenizer_type: "huggingface"

dtype: "bfloat16"
weight_dtype: "bfloat16"
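
As a quick sanity check on the three configs above, the listed dimensions can be combined into a rough parameter-count estimate. The snippet below is an illustrative sketch only, built from the values in these yml files; it assumes a standard SwiGLU MLP (three projection matrices), counts a separate LM head only when logits_via_embedding is False, and ignores the small RMSNorm and QK-norm scale vectors.

# Rough parameter-count sketch from the yml values above (illustrative only).
CONFIGS = {
    # name: (emb, q_heads, kv_heads, head_dim, mlp, layers, vocab, tied_embeddings)
    "qwen3-0.6b": (1024, 16, 8, 128, 3072, 28, 151936, True),
    "qwen3-4b": (2560, 32, 8, 128, 9728, 36, 151936, True),
    "qwen3-8b": (4096, 32, 8, 128, 12288, 36, 151936, False),
}

for name, (emb, q, kv, hd, mlp, layers, vocab, tied) in CONFIGS.items():
  attn = emb * q * hd + 2 * emb * kv * hd + q * hd * emb  # Q, K, V, O projections
  ffn = 3 * emb * mlp                                     # gate, up, down (SwiGLU)
  params = layers * (attn + ffn) + vocab * emb            # decoder stack + embedding
  if not tied:
    params += vocab * emb                                 # untied LM head (8b only)
  print(f"{name}: ~{params / 1e9:.2f}B parameters")

This lands near the nominal sizes: roughly 0.60B, 4.02B, and 8.19B.
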
6 changes: 5 additions & 1 deletion MaxText/layers/models.py
@@ -347,6 +347,10 @@ def get_decoder_layers(self):
from MaxText.layers import gpt3 # pylint: disable=import-outside-toplevel

return [gpt3.Gpt3DecoderLayer]
elif self.config.decoder_block == DecoderBlockType.QWEN3:
from MaxText.layers import qwen3 # pylint: disable=import-outside-toplevel

return [qwen3.Qwen3DecoderLayer]
elif self.config.decoder_block == DecoderBlockType.SIMPLE:
from MaxText.layers import simple_layer # pylint: disable=import-outside-toplevel

@@ -376,6 +380,7 @@ def get_norm_layer(self):
DecoderBlockType.GEMMA,
DecoderBlockType.GEMMA2,
DecoderBlockType.GEMMA3,
DecoderBlockType.QWEN3,
DecoderBlockType.SIMPLE,
DecoderBlockType.SIMPLE_MLP,
DecoderBlockType.LLAMA4,
@@ -804,4 +809,3 @@ def __call__(
image_embeddings=image_embeddings,
)
return logits

159 changes: 159 additions & 0 deletions MaxText/layers/qwen3.py
@@ -0,0 +1,159 @@
"""
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Qwen3 model decoder layer."""
# pylint: disable=arguments-differ
# pylint: disable=no-name-in-module

from typing import Optional

from jax.ad_checkpoint import checkpoint_name
from jax.sharding import Mesh
import jax.numpy as jnp

from flax import linen as nn

from MaxText.common_types import Config
from MaxText.layers import attentions
from MaxText.layers import initializers
from MaxText.layers import linears
from MaxText.layers import moe
from MaxText.layers import quantizations
from MaxText.layers.normalizations import RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.inference import page_manager


class Qwen3DecoderLayer(nn.Module):
"""Qwen3 Transformer decoder layer."""

config: Config
mesh: Mesh
quant: Optional[Quant] = None

@nn.compact
def __call__(
self,
inputs: jnp.ndarray,
decoder_segment_ids: Optional[jnp.ndarray],
decoder_positions: Optional[jnp.ndarray],
deterministic: bool,
model_mode: str,
previous_chunk=None,
page_state: Optional[page_manager.PageState] = None,
slot: Optional[int] = None,
):
cfg = self.config
mesh = self.mesh

inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
inputs_checkpoint = checkpoint_name(inputs, "decoder_layer_input")

# Corresponds to Qwen3's `input_layernorm`
lnx = RMSNorm(
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="pre_self_attention_layer_norm",
epsilon=cfg.normalization_layer_epsilon,
kernel_axes=("norm",),
)(inputs_checkpoint)
lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed"))

# Self-attention block
attention_layer = attentions.Attention(
config=cfg,
num_query_heads=cfg.num_query_heads,
num_kv_heads=cfg.num_kv_heads,
head_dim=cfg.head_dim,
max_target_length=cfg.max_target_length,
max_prefill_predict_length=cfg.max_prefill_predict_length,
attention_kernel=cfg.attention,
mesh=mesh,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
use_qk_norm=cfg.use_qk_norm,
query_pre_attn_scalar=(cfg.head_dim**-0.5), # Qwen3 specific scaling
)

attention_output = attention_layer(
lnx, # inputs_q
lnx, # inputs_kv
decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
)
attention_output = nn.with_logical_constraint(
attention_output, ("activation_batch", "activation_length", "activation_embed")
)

# Residual connection after attention
residual_after_attention = inputs_checkpoint + attention_output

# Post Attention LayerNorm (corresponds to Qwen3's `post_attention_layernorm`)
mlp_input = RMSNorm(
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="post_self_attention_layer_norm", # Standard MaxText naming
epsilon=cfg.normalization_layer_epsilon,
kernel_axes=("norm",),
)(residual_after_attention)
mlp_input = nn.with_logical_constraint(mlp_input, ("activation_batch", "activation_length", "activation_embed"))

# MLP block
if cfg.num_experts is None or cfg.num_experts <= 1: # Dense MLP
mlp_output = linears.MlpBlock(
intermediate_dim=cfg.mlp_dim,
activations=cfg.mlp_activations,
intermediate_dropout_rate=cfg.dropout_rate,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="mlp",
config=cfg,
quant=self.quant,
)(mlp_input, deterministic=deterministic)
else: # Mixture of Experts MLP
mlp_output, _ = moe.RoutedMoE(
config=cfg,
num_experts=cfg.num_experts,
num_experts_per_tok=cfg.num_experts_per_tok,
mesh=self.mesh,
kernel_init=initializers.nd_dense_init(1.0, "fan_in", "truncated_normal"),
kernel_axes=("embed", None),
intermediate_dim=cfg.mlp_dim,
dtype=cfg.dtype,
weight_dtype=cfg.weight_dtype,
name="moe_block",
quant=self.quant,
)(mlp_input)

mlp_output = nn.with_logical_constraint(mlp_output, ("activation_batch", "activation_length", "activation_embed"))

# Final residual connection
layer_output = residual_after_attention + mlp_output
layer_output = nn.with_logical_constraint(
layer_output,
("activation_batch", "activation_length", "activation_embed"),
)

if cfg.scan_layers:
return layer_output, None
else:
return layer_output
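
To make the dataflow of Qwen3DecoderLayer easier to see at a glance, here is a minimal jax.numpy sketch of the same pre-norm residual structure (RMSNorm → attention → residual, RMSNorm → MLP → residual). The attn_fn/mlp_fn stand-ins are hypothetical placeholders, not MaxText's Attention/MlpBlock modules; only the ordering of operations mirrors the layer above.

import jax.numpy as jnp


def rms_norm(x, scale, eps=1e-6):
  """RMSNorm; eps matches normalization_layer_epsilon in the configs above."""
  variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return (x / jnp.sqrt(variance + eps)) * scale


def decoder_block(x, attn_fn, mlp_fn, pre_attn_scale, post_attn_scale):
  # Pre-norm self-attention sublayer with residual connection.
  h = x + attn_fn(rms_norm(x, pre_attn_scale))
  # Pre-norm MLP sublayer with residual connection.
  return h + mlp_fn(rms_norm(h, post_attn_scale))


# Toy usage with identity sublayers, just to show how shapes flow through.
x = jnp.ones((2, 4, 1024))  # (batch, seq, base_emb_dim for qwen3-0.6b)
y = decoder_block(
    x,
    attn_fn=lambda t: t,
    mlp_fn=lambda t: t,
    pre_attn_scale=jnp.ones(1024),
    post_attn_scale=jnp.ones(1024),
)
print(y.shape)  # (2, 4, 1024)
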
3 changes: 3 additions & 0 deletions MaxText/pyconfig.py
@@ -295,6 +295,9 @@ def validate_model_name(s: str) -> bool:
"gemma3-4b",
"gemma3-12b",
"gemma3-27b",
"qwen3-0.6b",
"qwen3-4b",
"qwen3-8b",
"gpt3-175b",
"gpt3-22b",
"gpt3-6b",
51 changes: 51 additions & 0 deletions MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -109,6 +109,54 @@
query_pre_attn_scalar=144,
)

qwen3_0_6b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=1024,
intermediate_size=3072,
num_hidden_layers=28,
num_attention_heads=16,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=40960,
rms_norm_eps=1.0e-6,
rope_theta=1000000.0,
tie_word_embeddings=True,
torch_dtype="bfloat16",
)

qwen3_4b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=2560,
intermediate_size=9728,
num_hidden_layers=36,
num_attention_heads=32,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=40960,
rms_norm_eps=1.0e-6,
rope_theta=1000000.0,
tie_word_embeddings=True,
torch_dtype="bfloat16",
)

qwen3_8b_config = transformers.Qwen3Config(
vocab_size=151936,
hidden_size=4096,
intermediate_size=12288,
num_hidden_layers=36,
num_attention_heads=32,
num_key_value_heads=8,
head_dim=128,
hidden_act="silu",
max_position_embeddings=40960,
rms_norm_eps=1.0e-6,
rope_theta=1000000.0,
tie_word_embeddings=False,
torch_dtype="bfloat16",
)


HF_MODEL_CONFIGS = {
"gemma2-2b": gemma2_2b_config,
@@ -117,4 +165,7 @@
"gemma3-4b": gemma3text_4b_config,
"gemma3-12b": gemma3text_12b_config,
"gemma3-27b": gemma3text_27b_config,
"qwen3-0.6b": qwen3_0_6b_config,
"qwen3-4b": qwen3_4b_config,
"qwen3-8b": qwen3_8b_config,
}
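
A hypothetical spot-check (not part of this PR) of one of the configs registered above against the corresponding MaxText yml values; it assumes a transformers version that ships Qwen3Config and treats the yml numbers earlier in this diff as the reference.

import transformers

# HF config as registered above for qwen3-8b.
hf_cfg = transformers.Qwen3Config(
    vocab_size=151936,
    hidden_size=4096,
    intermediate_size=12288,
    num_hidden_layers=36,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    tie_word_embeddings=False,
)

# Values copied from MaxText/configs/models/qwen3-8b.yml.
maxtext_yml = {
    "vocab_size": 151936,
    "base_emb_dim": 4096,
    "base_mlp_dim": 12288,
    "base_num_decoder_layers": 36,
    "base_num_query_heads": 32,
    "base_num_kv_heads": 8,
    "head_dim": 128,
    "logits_via_embedding": False,
}

assert hf_cfg.vocab_size == maxtext_yml["vocab_size"]
assert hf_cfg.hidden_size == maxtext_yml["base_emb_dim"]
assert hf_cfg.intermediate_size == maxtext_yml["base_mlp_dim"]
assert hf_cfg.num_hidden_layers == maxtext_yml["base_num_decoder_layers"]
assert hf_cfg.num_attention_heads == maxtext_yml["base_num_query_heads"]
assert hf_cfg.num_key_value_heads == maxtext_yml["base_num_kv_heads"]
assert hf_cfg.head_dim == maxtext_yml["head_dim"]
# tie_word_embeddings in HF corresponds to logits_via_embedding in MaxText.
assert hf_cfg.tie_word_embeddings == maxtext_yml["logits_via_embedding"]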