Commit 193ea36

Add Mixtral (#2196)
* mistral init commit
* wip mixtral
* mixtral wip
* checkpoint conversion wip
* mixtral weight matching complete
* batched moe impl
* output matching with batched moe complete
* update
* flash attention fixes
* bug fixes
* bug fix
* address comments
* api gen
* update
* update
* chore: address feedback

---------

Co-authored-by: Anshuman Mishra <[email protected]>
Co-authored-by: Divyashree Sreepathihalli <[email protected]>
1 parent 634590b commit 193ea36

17 files changed (+2032 -0 lines)

keras_hub/api/models/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -348,6 +348,18 @@
 from keras_hub.src.models.mit.mit_image_classifier_preprocessor import (
     MiTImageClassifierPreprocessor as MiTImageClassifierPreprocessor,
 )
+from keras_hub.src.models.mixtral.mixtral_backbone import (
+    MixtralBackbone as MixtralBackbone,
+)
+from keras_hub.src.models.mixtral.mixtral_causal_lm import (
+    MixtralCausalLM as MixtralCausalLM,
+)
+from keras_hub.src.models.mixtral.mixtral_causal_lm_preprocessor import (
+    MixtralCausalLMPreprocessor as MixtralCausalLMPreprocessor,
+)
+from keras_hub.src.models.mixtral.mixtral_tokenizer import (
+    MixtralTokenizer as MixtralTokenizer,
+)
 from keras_hub.src.models.mobilenet.mobilenet_backbone import (
     MobileNetBackbone as MobileNetBackbone,
 )
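
This hunk re-exports the new Mixtral classes through the public keras_hub.models namespace. A minimal usage sketch, assuming the top-level re-export behaves as it does for the other model families (the preset name in the commented lines is hypothetical, shown only to illustrate the usual task workflow):

import keras_hub

# The four new symbols become importable from the top-level API:
backbone_cls = keras_hub.models.MixtralBackbone
causal_lm_cls = keras_hub.models.MixtralCausalLM
preprocessor_cls = keras_hub.models.MixtralCausalLMPreprocessor
tokenizer_cls = keras_hub.models.MixtralTokenizer

# Hypothetical preset name, for illustration only:
# causal_lm = keras_hub.models.MixtralCausalLM.from_preset("mixtral_8x7b_en")
# causal_lm.generate("The capital of France is", max_length=30)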

keras_hub/api/tokenizers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -55,6 +55,9 @@
 from keras_hub.src.models.mistral.mistral_tokenizer import (
     MistralTokenizer as MistralTokenizer,
 )
+from keras_hub.src.models.mixtral.mixtral_tokenizer import (
+    MixtralTokenizer as MixtralTokenizer,
+)
 from keras_hub.src.models.opt.opt_tokenizer import OPTTokenizer as OPTTokenizer
 from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
     PaliGemmaTokenizer as PaliGemmaTokenizer,
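
The tokenizer is also re-exported under keras_hub.tokenizers. A short sketch, assuming MixtralTokenizer follows the MistralTokenizer pattern of wrapping a SentencePiece proto (the proto path is a placeholder):

import keras_hub

# Both public paths point at the same class after this change.
assert (
    keras_hub.tokenizers.MixtralTokenizer
    is keras_hub.models.MixtralTokenizer
)

# Assumed constructor, mirroring MistralTokenizer; the proto path is a
# placeholder for a real SentencePiece model file.
tokenizer = keras_hub.tokenizers.MixtralTokenizer(proto="path/to/mixtral.spm")
token_ids = tokenizer("The quick brown fox")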

keras_hub/src/models/mixtral/mixtral_attention.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
import inspect
import math

import keras
from keras import ops

from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.utils.keras_utils import clone_initializer
from keras_hub.src.utils.keras_utils import fused_attention_op_available
from keras_hub.src.utils.keras_utils import gpu_supports_fused_attention_op
from keras_hub.src.utils.keras_utils import running_on_gpu
from keras_hub.src.utils.keras_utils import running_on_tpu

class CachedMixtralAttention(keras.layers.Layer):
    """A cached grouped query attention layer with sliding window."""

    def __init__(
        self,
        num_query_heads,
        num_key_value_heads,
        rope_max_wavelength=10000,
        rope_scaling_factor=1.0,
        kernel_initializer="glorot_uniform",
        sliding_window=512,
        dropout=0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._num_query_heads = num_query_heads
        self._num_key_value_heads = num_key_value_heads
        self._sliding_window = sliding_window
        self._dropout = dropout

        self._num_key_value_groups = num_query_heads // num_key_value_heads
        self._rope_max_wavelength = rope_max_wavelength

        self._kernel_initializer = keras.initializers.get(
            clone_initializer(kernel_initializer)
        )

        self._rope_scaling_factor = rope_scaling_factor
        # Mixtral does not use attention logit soft-capping; define the
        # attribute so the fused-attention helpers below can check it.
        self.logit_soft_cap = None

    def build(self, inputs_shape):
        # Einsum variables:
        # b = batch size
        # q = query length
        # k = key/value length
        # m = model dim
        # u = num query heads
        # v = num key/value heads
        # h = head dim
        self._hidden_dim = inputs_shape[-1]
        self._head_dim = self._hidden_dim // self._num_query_heads
        self._inv_norm_factor = 1.0 / math.sqrt(self._head_dim)

        self.query_dense = keras.layers.EinsumDense(
            equation="bqm,muh->bquh",
            output_shape=(None, self._num_query_heads, self._head_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="query",
        )
        self.query_dense.build(inputs_shape)

        self.key_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self._num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="key",
        )
        self.key_dense.build(inputs_shape)

        self.value_dense = keras.layers.EinsumDense(
            equation="bkm,mvh->bkvh",
            output_shape=(
                None,
                self._num_key_value_heads,
                self._head_dim,
            ),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="value",
        )
        self.value_dense.build(inputs_shape)

        self._softmax = keras.layers.Softmax(
            axis=-1,
            dtype="float32",
            name="attention_softmax",
        )

        self._dropout_layer = keras.layers.Dropout(
            rate=self._dropout,
            dtype=self.dtype_policy,
        )

        self._output_dense = keras.layers.EinsumDense(
            equation="bquh,uhm->bqm",
            output_shape=(None, self._hidden_dim),
            kernel_initializer=self._kernel_initializer,
            dtype=self.dtype_policy,
            name="attention_output",
        )
        self._output_dense.build(
            (None, None, self._num_query_heads, self._head_dim)
        )

        self.rotary_embedding_layer = RotaryEmbedding(
            max_wavelength=self._rope_max_wavelength,
            scaling_factor=self._rope_scaling_factor,
            dtype=self.dtype_policy,
        )

        self._dot_product_equation = "bquh,bkuh->buqk"
        self._combine_equation = "buqk,bkuh->bquh"

        self.built = True

    def call(
        self,
        hidden_states,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        start_index = (
            cache_update_index if cache_update_index is not None else 0
        )

        query = self.query_dense(hidden_states)

        # Compute RoPE for queries
        query = self.rotary_embedding_layer(query, start_index=start_index)

        def _compute_key_value(x):
            key, value = self.key_dense(x), self.value_dense(x)
            # Compute RoPE for keys
            key = self.rotary_embedding_layer(key, start_index=start_index)
            return key, value

        if cache is not None:
            # Cache layout: [batch, 2, max_seq_len, num_key_value_heads,
            # head_dim], where index 0 holds keys and index 1 holds values.
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                key = key_cache
                value = value_cache
            else:
                key_update, value_update = _compute_key_value(hidden_states)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            if cache_update_index is not None:
                raise ValueError(
                    "`cache_update_index` should not be set if `cache` is "
                    f"`None`. Received: cache={cache}, "
                    f"cache_update_index={cache_update_index}"
                )
            key, value = _compute_key_value(hidden_states)

        # [batch_shape, seq_len, num_key_value_heads, head_dim]
        # -> [batch_shape, seq_len, num_heads, head_dim]
        key = ops.repeat(key, repeats=self._num_key_value_groups, axis=2)
        value = ops.repeat(value, repeats=self._num_key_value_groups, axis=2)

        attention_output = self._compute_attention(
            query, key, value, attention_mask
        )

        attention_output = self._dropout_layer(
            attention_output, training=training
        )

        attention_output = self._output_dense(attention_output)

        if cache is not None:
            return attention_output, cache
        return attention_output

    def _masked_softmax(self, attention_scores, attention_mask=None):
        if attention_mask is not None:
            return self._softmax(
                attention_scores, attention_mask[:, None, :, :]
            )
        return self._softmax(attention_scores)

    def _use_fused_attention_op(self):
        if not fused_attention_op_available():
            return False
        if self._dropout > 0.0:
            return False
        if running_on_gpu():
            # GPU never supports softcap in the fused op.
            if self.logit_soft_cap is not None:
                return False
            return gpu_supports_fused_attention_op()
        elif running_on_tpu():
            # TPU supports softcap on Keras >= 3.10.
            sig = inspect.signature(ops.dot_product_attention)
            return "attn_logits_soft_cap" in sig.parameters
        else:
            return False

    def _compute_attention(self, query, key, value, attention_mask=None):
        if self._use_fused_attention_op():
            if attention_mask is not None:
                attention_mask = ops.expand_dims(attention_mask, axis=1)
                attention_mask = ops.cast(attention_mask, dtype="bool")

            if self.logit_soft_cap:
                kwargs = {"attn_logits_soft_cap": self.logit_soft_cap}
            else:
                kwargs = {}

            attention_output = ops.dot_product_attention(
                query,
                key,
                value,
                mask=attention_mask,
                scale=self._inv_norm_factor,
                **kwargs,
            )
            return attention_output

        attention_scores = ops.einsum(self._dot_product_equation, query, key)
        attention_scores = ops.multiply(
            attention_scores,
            ops.cast(self._inv_norm_factor, self.compute_dtype),
        )
        attention_scores = self._masked_softmax(
            attention_scores, attention_mask
        )
        attention_scores = ops.cast(attention_scores, self.compute_dtype)
        attention_output = ops.einsum(
            self._combine_equation, attention_scores, value
        )

        return attention_output

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_query_heads": self._num_query_heads,
                "num_key_value_heads": self._num_key_value_heads,
                "rope_max_wavelength": self._rope_max_wavelength,
                "rope_scaling_factor": self._rope_scaling_factor,
                "kernel_initializer": keras.initializers.serialize(
                    self._kernel_initializer
                ),
                "sliding_window": self._sliding_window,
                "dropout": self._dropout,
            }
        )
        return config
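
For context, a minimal sketch of how this new layer could be exercised on its own. The module path is inferred from the file added above, and the sizes are arbitrary illustration values (hidden_dim must be divisible by num_query_heads, and num_query_heads by num_key_value_heads):

import numpy as np

from keras_hub.src.models.mixtral.mixtral_attention import (
    CachedMixtralAttention,
)

# Small, arbitrary configuration purely for illustration.
layer = CachedMixtralAttention(num_query_heads=8, num_key_value_heads=2)

# [batch, sequence_length, hidden_dim] inputs; the layer builds its dense
# projections from the last dimension on first call.
hidden_states = np.random.uniform(size=(1, 16, 64)).astype("float32")
# [batch, query_len, key_len] mask; all ones means fully visible.
attention_mask = np.ones((1, 16, 16), dtype="int32")

output = layer(hidden_states, attention_mask=attention_mask)
print(output.shape)  # (1, 16, 64)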
