-
Notifications
You must be signed in to change notification settings - Fork 289
Add Qwen3 Moe #2260
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kanpuriyanawab
wants to merge
13
commits into
keras-team:master
Choose a base branch
from
kanpuriyanawab:qwen3_moe
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add Qwen3 Moe #2260
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
e528378
qwen3 moe init
kanpuriyanawab 84043a3
bug fixes
kanpuriyanawab 750412c
update
kanpuriyanawab 9b3d779
Merge branch 'keras-team:master' into qwen3_moe
kanpuriyanawab 6b74171
address comments
kanpuriyanawab 730a9c4
address comments
kanpuriyanawab 5f90d10
update output matching script
kanpuriyanawab cda9cfc
fix test
kanpuriyanawab 1b21c7c
update qwen3 causal lm
kanpuriyanawab 6214a1b
Merge branch 'master' into qwen3_moe
kanpuriyanawab e9e6ca5
output matching + bug fixes
kanpuriyanawab 6fd2de8
uncomment flag in conversion script
kanpuriyanawab 992bfc6
api gen
kanpuriyanawab File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,361 @@ | ||
import math | ||
|
||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding | ||
from keras_hub.src.models.qwen3_moe.qwen3_moe_layernorm import Qwen3MoeLayerNorm | ||
from keras_hub.src.utils.keras_utils import clone_initializer | ||
from keras_hub.src.utils.keras_utils import fused_attention_op_available | ||
|
||
|
||
class Qwen3MoeAttention(keras.layers.Layer): | ||
"""A multi-head attention layer for Qwen3Moe models | ||
This attention implementation supports grouped-query attention (GQA) where | ||
the number of key-value heads can be less than the number of query heads. | ||
Args: | ||
num_query_heads: Number of query heads. | ||
num_key_value_heads: Number of key/value heads (for GQA). | ||
rope_max_wavelength: Maximum wavelength for RoPE (Rotary Position | ||
Embedding). | ||
rope_scaling_factor: Scaling factor for RoPE, used for extending | ||
context length. | ||
kernel_initializer: Initializer for the kernel weights. | ||
dropout: Dropout rate for attention weights. | ||
sliding_window_size: Size of the sliding window for attention. | ||
**kwargs: Additional keyword arguments to pass to the Layer. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
num_query_heads, | ||
num_key_value_heads, | ||
head_dim, | ||
rope_max_wavelength=10000, | ||
rope_scaling_factor=1, | ||
kernel_initializer="glorot_uniform", | ||
dropout=0.0, | ||
layer_norm_epsilon=1e-6, | ||
sliding_window_size=None, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
**kwargs, | ||
) | ||
self.num_query_heads = num_query_heads | ||
self.num_key_value_heads = num_key_value_heads | ||
self.head_dim = head_dim | ||
self.dropout = dropout | ||
|
||
self.layer_norm_epsilon = layer_norm_epsilon | ||
|
||
self.num_key_value_groups = num_query_heads // num_key_value_heads | ||
self.rope_max_wavelength = rope_max_wavelength | ||
|
||
self.kernel_initializer = keras.initializers.get( | ||
clone_initializer(kernel_initializer) | ||
) | ||
|
||
self.rope_scaling_factor = rope_scaling_factor | ||
self.sliding_window_size = sliding_window_size | ||
|
||
def build(self, inputs_shape): | ||
# Einsum variables: | ||
# b = batch size | ||
# q = query length | ||
# k = key/value length | ||
# m = model dim | ||
# u = num query heads | ||
# v = num key/value heads | ||
# h = head dim | ||
hidden_dim = inputs_shape[-1] | ||
if not self.head_dim: | ||
self.head_dim = hidden_dim // self.num_query_heads | ||
|
||
self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) | ||
self._query_dense = keras.layers.EinsumDense( | ||
equation="bqm,muh->bquh", | ||
output_shape=(None, self.num_query_heads, self.head_dim), | ||
kernel_initializer=self.kernel_initializer, | ||
dtype=self.dtype_policy, | ||
name="query", | ||
) | ||
self._query_dense.build(inputs_shape) | ||
|
||
self._query_dense_layer_norm = Qwen3MoeLayerNorm( | ||
epsilon=self.layer_norm_epsilon, | ||
dtype=self.dtype_policy, | ||
head_dim=self.head_dim, | ||
name="query_dense_layernorm", | ||
) | ||
self._query_dense_layer_norm.build(inputs_shape) | ||
|
||
self._key_dense = keras.layers.EinsumDense( | ||
equation="bkm,mvh->bkvh", | ||
output_shape=( | ||
None, | ||
self.num_key_value_heads, | ||
self.head_dim, | ||
), | ||
kernel_initializer=self.kernel_initializer, | ||
dtype=self.dtype_policy, | ||
name="key", | ||
) | ||
self._key_dense.build(inputs_shape) | ||
|
||
self._key_dense_layer_norm = Qwen3MoeLayerNorm( | ||
epsilon=self.layer_norm_epsilon, | ||
dtype=self.dtype_policy, | ||
head_dim=self.head_dim, | ||
name="key_dense_layernorm", | ||
) | ||
self._key_dense_layer_norm.build(inputs_shape) | ||
|
||
self._value_dense = keras.layers.EinsumDense( | ||
equation="bkm,mvh->bkvh", | ||
output_shape=( | ||
None, | ||
self.num_key_value_heads, | ||
self.head_dim, | ||
), | ||
kernel_initializer=self.kernel_initializer, | ||
dtype=self.dtype_policy, | ||
name="value", | ||
) | ||
self._value_dense.build(inputs_shape) | ||
|
||
self._softmax = keras.layers.Softmax( | ||
axis=-1, | ||
dtype="float32", | ||
name="attention_softmax", | ||
) | ||
|
||
self._dropout_layer = keras.layers.Dropout( | ||
rate=self.dropout, | ||
dtype=self.dtype_policy, | ||
) | ||
|
||
self._output_dense = keras.layers.EinsumDense( | ||
equation="bquh,uhm->bqm", | ||
output_shape=(None, hidden_dim), | ||
kernel_initializer=self.kernel_initializer, | ||
dtype=self.dtype_policy, | ||
name="attention_output", | ||
) | ||
self._output_dense.build( | ||
(None, None, self.num_query_heads, self.head_dim) | ||
) | ||
|
||
self.rotary_embedding_layer = RotaryEmbedding( | ||
max_wavelength=self.rope_max_wavelength, | ||
scaling_factor=self.rope_scaling_factor, | ||
dtype=self.dtype_policy, | ||
) | ||
|
||
self._dot_product_equation = "bquh,bkuh->buqk" | ||
self._combine_equation = "buqk,bkuh->bquh" | ||
|
||
self.built = True | ||
|
||
def call( | ||
self, | ||
hidden_states, | ||
attention_mask=None, | ||
cache=None, | ||
cache_update_index=None, | ||
training=None, | ||
): | ||
"""Applies attention mechanism to the input hidden states. | ||
Args: | ||
hidden_states: Input tensor of shape [batch_size, seq_length, | ||
hidden_size]. | ||
attention_mask: Mask tensor of shape [batch_size, seq_length, | ||
seq_length]. | ||
cache: Optional cached key and value tensors. | ||
cache_update_index: Index at which to update the cache. | ||
training: Boolean indicating whether in training mode. | ||
Returns: | ||
attention_output: Output tensor after applying attention. | ||
cache: Updated cache tensors (if cache is provided). | ||
""" | ||
start_index = ( | ||
cache_update_index if cache_update_index is not None else 0 | ||
) | ||
|
||
query = self._query_dense(hidden_states) | ||
query = self._query_dense_layer_norm(query) | ||
|
||
# Compute RoPE for queries | ||
query = self.rotary_embedding_layer(query, start_index=start_index) | ||
|
||
def _compute_key_value(x): | ||
key = self._key_dense(x) | ||
key = self._key_dense_layer_norm(key) | ||
key = self.rotary_embedding_layer(key, start_index=start_index) | ||
|
||
value = self._value_dense(x) | ||
|
||
return key, value | ||
|
||
if cache is not None: | ||
key_cache = cache[:, 0, ...] | ||
value_cache = cache[:, 1, ...] | ||
if cache_update_index is None: | ||
key = key_cache | ||
value = value_cache | ||
else: | ||
key_update, value_update = _compute_key_value(hidden_states) | ||
start = [0, cache_update_index, 0, 0] | ||
key = ops.slice_update(key_cache, start, key_update) | ||
value = ops.slice_update(value_cache, start, value_update) | ||
cache = ops.stack((key, value), axis=1) | ||
else: | ||
if cache_update_index is not None: | ||
raise ValueError( | ||
"`cache_update_index` should not be set if `cache` is " | ||
f"`None`. Received: cache={cache}, " | ||
f"cache_update_index={cache_update_index}" | ||
) | ||
key, value = _compute_key_value(hidden_states) | ||
|
||
# [batch_shape, seq_len, num_key_value_heads, head_dim] | ||
# -> [batch_shape, seq_len, num_heads, head_dim] | ||
key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) | ||
value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) | ||
|
||
attention_output = self._compute_attention( | ||
query, | ||
key, | ||
value, | ||
attention_mask, | ||
cache_update_index=cache_update_index, | ||
) | ||
|
||
attention_output = self._dropout_layer( | ||
attention_output, training=training | ||
) | ||
|
||
attention_output = self._output_dense(attention_output) | ||
|
||
if cache is not None: | ||
return attention_output, cache | ||
return attention_output | ||
|
||
def _masked_softmax(self, attention_scores, attention_mask=None): | ||
"""Applies softmax with optional masking. | ||
Args: | ||
attention_scores: Attention score tensor. | ||
attention_mask: Optional mask tensor. | ||
Returns: | ||
Masked softmax attention weights. | ||
""" | ||
if attention_mask is not None: | ||
return self._softmax( | ||
attention_scores, attention_mask[:, None, :, :] | ||
) | ||
return self._softmax(attention_scores) | ||
|
||
def _compute_attention( | ||
self, query, key, value, attention_mask=None, cache_update_index=None | ||
): | ||
"""Computes attention using query, key, and value tensors. | ||
Uses Flash Attention when available for better performance. | ||
Args: | ||
query: Query tensor. | ||
key: Key tensor. | ||
value: Value tensor. | ||
attention_mask: Optional mask tensor. | ||
cache_update_index: Index for sliding window computation. | ||
Returns: | ||
attention_output: Output tensor after applying attention. | ||
""" | ||
if fused_attention_op_available(): | ||
# Use `dot_product_attention` with Flash Attention support if | ||
# available. | ||
if attention_mask is not None: | ||
attention_mask = ops.expand_dims(attention_mask, axis=1) | ||
attention_mask = ops.cast(attention_mask, dtype="bool") | ||
attention_output = ops.dot_product_attention( | ||
query, | ||
key, | ||
value, | ||
mask=attention_mask, | ||
scale=self._inv_norm_factor, | ||
) | ||
return attention_output | ||
|
||
attention_scores = ops.einsum(self._dot_product_equation, query, key) | ||
|
||
attention_scores = ops.multiply( | ||
attention_scores, | ||
ops.cast(self._inv_norm_factor, self.compute_dtype), | ||
) | ||
if self.sliding_window_size: | ||
attention_mask = self._mask_sliding_window( | ||
attention_mask, | ||
cache_update_index=cache_update_index | ||
if cache_update_index is not None | ||
else 0, | ||
) | ||
attention_scores = self._masked_softmax( | ||
attention_scores, attention_mask | ||
) | ||
attention_scores = ops.cast(attention_scores, self.compute_dtype) | ||
attention_output = ops.einsum( | ||
self._combine_equation, attention_scores, value | ||
) | ||
|
||
return attention_output | ||
|
||
def _mask_sliding_window( | ||
self, | ||
attention_mask, | ||
cache_update_index=0, | ||
): | ||
"""Creates and combines a sliding window mask with the attention mask. | ||
Args: | ||
attention_mask: Original attention mask. | ||
cache_update_index: Starting index for the sliding window. | ||
Returns: | ||
Combined attention mask with sliding window constraints. | ||
""" | ||
_, query_len, key_len = ops.shape(attention_mask) | ||
# Compute the sliding window for square attention. | ||
all_ones = ops.ones((key_len, key_len), "bool") | ||
if keras.config.backend() == "tensorflow": | ||
# TODO: trui/tril has issues with dynamic shape on the tensorflow | ||
# backend. We should fix, but use `band_part` for now. | ||
import tensorflow as tf | ||
|
||
band_size = ops.minimum(key_len, self.sliding_window_size - 1) | ||
band_size = ops.cast(band_size, "int32") | ||
sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size) | ||
else: | ||
sliding_mask = ops.triu( | ||
all_ones, -1 * self.sliding_window_size + 1 | ||
) * ops.tril(all_ones, self.sliding_window_size - 1) | ||
# Slice the window for short queries during generation. | ||
start = (cache_update_index, 0) | ||
sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len)) | ||
sliding_mask = ops.expand_dims(sliding_mask, 0) | ||
return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool")) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_query_heads": self.num_query_heads, | ||
"num_key_value_heads": self.num_key_value_heads, | ||
"rope_max_wavelength": self.rope_max_wavelength, | ||
"rope_scaling_factor": self.rope_scaling_factor, | ||
"kernel_initializer": keras.initializers.serialize( | ||
self.kernel_initializer | ||
), | ||
"dropout": self.dropout, | ||
"sliding_window_size": self.sliding_window_size, | ||
"layer_index": self.layer_index, | ||
"head_dim": self.head_dim, | ||
"layer_norm_epsilon": self.layer_norm_epsilon, | ||
} | ||
) | ||
return config |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.