
Commit 37e52a8

grasskin and mattdangerw committed

Add Gemma2 to Keras (#91)

Add Gemma2 building blocks and presets.

Co-authored-by: Matt Watson <[email protected]>

Parent: b58b56e

7 files changed: +257 -9 lines

keras_nlp/src/models/gemma/gemma_attention.py

Lines changed: 50 additions & 1 deletion
@@ -28,19 +28,28 @@ def __init__(
         num_query_heads,
         num_key_value_heads,
         kernel_initializer="glorot_uniform",
+        logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
+        query_head_dim_normalize=True,
         dropout=0,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.num_query_heads = num_query_heads
         self.num_key_value_heads = num_key_value_heads
         self.head_dim = head_dim
+        self.logit_soft_cap = logit_soft_cap
+        self.use_sliding_window_attention = use_sliding_window_attention
+        self.sliding_window_size = sliding_window_size
+        self.query_head_dim_normalize = query_head_dim_normalize
         self.dropout = dropout

         self._kernel_initializer = keras.initializers.get(
             clone_initializer(kernel_initializer)
         )
         self.num_key_value_groups = num_query_heads // num_key_value_heads
+        self.query_head_dim_normalize = query_head_dim_normalize

     def build(self, inputs_shape):
         self.hidden_dim = inputs_shape[-1]
@@ -114,7 +123,12 @@ def _compute_attention(
         attention_mask,
         training=False,
     ):
-        query_normalization = 1 / np.sqrt(self.head_dim)
+        if self.query_head_dim_normalize:
+            query_normalization = 1 / np.sqrt(self.head_dim)
+        else:
+            query_normalization = 1 / np.sqrt(
+                self.hidden_dim // self.num_query_heads
+            )

         q *= ops.cast(query_normalization, dtype=q.dtype)
         q_shape = ops.shape(q)
@@ -130,6 +144,38 @@ def _compute_attention(
         b, q_len, _, _, h = ops.shape(q)

         attention_logits = ops.einsum("btkgh,bskh->bkgts", q, k)
+
+        if self.logit_soft_cap is not None:
+            attention_logits = ops.divide(attention_logits, self.logit_soft_cap)
+            attention_logits = ops.multiply(
+                ops.tanh(attention_logits), self.logit_soft_cap
+            )
+
+        if self.use_sliding_window_attention:
+            all_ones = ops.ones_like(attention_mask)
+            if keras.config.backend() == "tensorflow":
+                import tensorflow as tf
+
+                sliding_window_size = ops.minimum(
+                    self.sliding_window_size - 1, q_len
+                )
+                sliding_window_size = ops.cast(
+                    sliding_window_size, dtype="int32"
+                )
+                sliding_mask = tf.linalg.band_part(
+                    all_ones, sliding_window_size - 1, sliding_window_size - 1
+                )
+                sliding_mask = ops.cast(sliding_mask, dtype="bool")
+                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
+                attention_mask = tf.math.logical_and(
+                    sliding_mask, bool_attention_mask
+                )
+            else:
+                sliding_mask = ops.triu(
+                    all_ones, -1 * self.sliding_window_size + 1
+                ) * ops.tril(all_ones, self.sliding_window_size - 1)
+                attention_mask = sliding_mask * attention_mask
+
         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
         attention_softmax = self.softmax(attention_logits, mask=attention_mask)
@@ -186,3 +232,6 @@ def call(
         if cache is not None:
             return attention_output, cache
         return attention_output
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
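
A note on the sliding-window branch above: outside the TensorFlow backend, the local attention band is built purely with ops.triu/ops.tril. Below is a minimal, illustrative sketch of that band construction (not the layer code itself; it uses a small 2D toy mask instead of the layer's batched attention mask, and a toy window size):

from keras import ops

seq_len = 8
sliding_window_size = 3  # toy value; the layer defaults to 4096

all_ones = ops.ones((seq_len, seq_len))
# Keep only entries within sliding_window_size - 1 positions of the diagonal.
sliding_mask = ops.triu(all_ones, -1 * sliding_window_size + 1) * ops.tril(
    all_ones, sliding_window_size - 1
)
print(ops.convert_to_numpy(sliding_mask))
# Row i is 1 only for columns i-2 .. i+2; multiplying this band into the
# causal/padding mask limits each query to a local window of keys.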

keras_nlp/src/models/gemma/gemma_backbone.py

Lines changed: 43 additions & 0 deletions
@@ -54,6 +54,21 @@ class GemmaBackbone(Backbone):
         layer_norm_epsilon: float. The epsilon value user for every layer norm
             in the transformer model.
         dropout: float. Dropout probability for the Transformer encoder.
+        query_head_dim_normalize: boolean. Whether to normalize attention with
+            head dimension or hidden_dim/num_query_heads. Gemma2 uses the
+            second option. Defaults to True.
+        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
+            block. Defaults to False.
+        use_post_attention_norm: boolean. Whether to normalize after the attention
+            block. Defaults to False.
+        attention_logit_soft_cap: None or int. Soft cap for the attention logits.
+            Defaults to None.
+        final_logit_soft_cap: None or int. Soft cap for the final logits.
+            Defaults to None.
+        use_sliding_window_attention boolean. Whether to use sliding local
+            window attention. Defaults to False.
+        sliding_window_size: int. Size of the sliding local window. Defaults to
+            4096.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -93,6 +108,13 @@ def __init__(
         hidden_dim,
         intermediate_dim,
         head_dim,
+        query_head_dim_normalize=True,
+        use_post_ffw_norm=False,
+        use_post_attention_norm=False,
+        attention_logit_soft_cap=None,
+        final_logit_soft_cap=None,
+        use_sliding_window_attention=False,
+        sliding_window_size=4096,
         layer_norm_epsilon=1e-6,
         dropout=0,
         dtype=None,
@@ -114,12 +136,19 @@ def __init__(
         )
         self.transformer_layers = []
         for i in range(num_layers):
+            sliding_window = use_sliding_window_attention and (i % 2 == 0)
             layer = GemmaDecoderBlock(
                 intermediate_dim=intermediate_dim,
                 hidden_dim=hidden_dim,
                 num_query_heads=num_query_heads,
                 head_dim=head_dim,
                 num_key_value_heads=num_key_value_heads,
+                query_head_dim_normalize=query_head_dim_normalize,
+                use_post_ffw_norm=use_post_ffw_norm,
+                use_post_attention_norm=use_post_attention_norm,
+                logit_soft_cap=attention_logit_soft_cap,
+                use_sliding_window_attention=sliding_window,
+                sliding_window_size=sliding_window_size,
                 dropout=dropout,
                 dtype=dtype,
                 name=f"decoder_block_{i}",
@@ -163,6 +192,13 @@ def __init__(
         self.head_dim = head_dim
         self.layer_norm_epsilon = layer_norm_epsilon
         self.dropout = dropout
+        self.query_head_dim_normalize = query_head_dim_normalize
+        self.use_post_ffw_norm = use_post_ffw_norm
+        self.use_post_attention_norm = use_post_attention_norm
+        self.attention_logit_soft_cap = attention_logit_soft_cap
+        self.final_logit_soft_cap = final_logit_soft_cap
+        self.sliding_window_size = sliding_window_size
+        self.use_sliding_window_attention = use_sliding_window_attention

     def get_config(self):
         config = super().get_config()
@@ -177,6 +213,13 @@ def get_config(self):
                 "head_dim": self.head_dim,
                 "layer_norm_epsilon": self.layer_norm_epsilon,
                 "dropout": self.dropout,
+                "query_head_dim_normalize": self.query_head_dim_normalize,
+                "use_post_ffw_norm": self.use_post_ffw_norm,
+                "use_post_attention_norm": self.use_post_attention_norm,
+                "final_logit_soft_cap": self.final_logit_soft_cap,
+                "attention_logit_soft_cap": self.attention_logit_soft_cap,
+                "sliding_window_size": self.sliding_window_size,
+                "use_sliding_window_attention": self.use_sliding_window_attention,
             }
         )
         return config
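
One detail in the constructor change above is easy to miss: when use_sliding_window_attention is enabled, only every other decoder block (i % 2 == 0) actually uses the local window, so local and global attention layers interleave. A small plain-Python illustration of the resulting pattern (not library code; num_layers is an arbitrary example value):

use_sliding_window_attention = True
num_layers = 6  # arbitrary example value

attention_pattern = [
    "local" if (use_sliding_window_attention and i % 2 == 0) else "global"
    for i in range(num_layers)
]
print(attention_pattern)
# ['local', 'global', 'local', 'global', 'local', 'global']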

keras_nlp/src/models/gemma/gemma_backbone_test.py

Lines changed: 50 additions & 8 deletions
@@ -22,13 +22,13 @@
 class GemmaBackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
-            "vocabulary_size": 256128,
+            "vocabulary_size": 20,
             "num_layers": 2,
-            "num_query_heads": 8,
-            "num_key_value_heads": 8,
-            "hidden_dim": 128,
-            "intermediate_dim": 256,
-            "head_dim": 128,
+            "num_query_heads": 4,
+            "num_key_value_heads": 1,
+            "hidden_dim": 16,
+            "intermediate_dim": 32,
+            "head_dim": 4,
             "layer_norm_epsilon": 1e-6,
         }
         self.input_data = {
@@ -41,7 +41,7 @@ def test_backbone_basics(self):
             cls=GemmaBackbone,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            expected_output_shape=(2, 5, 128),
+            expected_output_shape=(2, 5, 16),
         )

     @pytest.mark.large
@@ -82,7 +82,7 @@ def test_all_presets(self):

     def test_architecture_characteristics(self):
         model = GemmaBackbone(**self.init_kwargs)
-        self.assertEqual(model.count_params(), 33931904)
+        self.assertEqual(model.count_params(), 3216)
         self.assertEqual(len(model.layers), 6)

     def test_distribution(self):
@@ -169,3 +169,45 @@ def test_distribution_with_lora(self):
                )
            if "attention/value/lora_kernel_b" in w.path:
                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
+
+
+@pytest.mark.keras_3_only
+class Gemma2BackboneTest(TestCase):
+    def setUp(self):
+        self.init_kwargs = {
+            "vocabulary_size": 20,  # 256128
+            "num_layers": 2,  # 46
+            "num_query_heads": 4,  # 32
+            "num_key_value_heads": 2,  # 16
+            "hidden_dim": 16,  # 4608
+            "intermediate_dim": 32,  # 73728
+            "head_dim": 4,  # 128
+            "sliding_window_size": 5,  # 4096
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        self.input_data = {
+            "token_ids": ops.ones((2, 10), dtype="int32"),
+            "padding_mask": ops.ones((2, 10), dtype="int32"),
+        }
+
+    def test_backbone_basics(self):
+        self.run_backbone_test(
+            cls=GemmaBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(2, 10, 16),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=GemmaBackbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
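
For readers who want to poke at the new options outside the test harness, here is a hedged standalone sketch of what Gemma2BackboneTest exercises. It assumes the public keras_nlp.models.GemmaBackbone export plus the keyword arguments added in this commit; the forward pass returns a (batch, sequence, hidden_dim) tensor:

import keras_nlp
from keras import ops

# Tiny Gemma 2 style configuration mirroring the test's init_kwargs.
backbone = keras_nlp.models.GemmaBackbone(
    vocabulary_size=20,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=16,
    intermediate_dim=32,
    head_dim=4,
    sliding_window_size=5,
    attention_logit_soft_cap=50,
    final_logit_soft_cap=30,
    query_head_dim_normalize=False,
    use_post_ffw_norm=True,
    use_post_attention_norm=True,
    use_sliding_window_attention=True,
)
outputs = backbone(
    {
        "token_ids": ops.ones((2, 10), dtype="int32"),
        "padding_mask": ops.ones((2, 10), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 10, 16)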

keras_nlp/src/models/gemma/gemma_causal_lm.py

Lines changed: 8 additions & 0 deletions
@@ -223,9 +223,17 @@ def call_with_cache(
                 cache_update_index=cache_update_index,
             )
             caches.append(next_cache)
+
         cache = ops.stack(caches, axis=1)
         hidden_states = x = self.backbone.layer_norm(x)
         logits = self.backbone.token_embedding(x, reverse=True)
+
+        if self.backbone.final_logit_soft_cap is not None:
+            logits = ops.divide(logits, self.backbone.final_logit_soft_cap)
+            logits = ops.multiply(
+                ops.tanh(logits), self.backbone.final_logit_soft_cap
+            )
+
         return logits, hidden_states, cache

     def _build_cache(self, token_ids):
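
The block added above applies the same cap * tanh(x / cap) soft-capping that gemma_attention.py applies to attention logits, here driven by final_logit_soft_cap; it smoothly bounds logits to (-cap, cap). A minimal sketch with a hypothetical soft_cap helper (the library applies the transform inline rather than through such a helper):

from keras import ops

def soft_cap(x, cap):
    # Hypothetical helper for illustration only.
    return ops.multiply(ops.tanh(ops.divide(x, cap)), cap)

logits = ops.convert_to_tensor([-100.0, -10.0, 0.0, 10.0, 100.0])
print(ops.convert_to_numpy(soft_cap(logits, 30.0)))
# Values are squashed into roughly (-30, 30); small logits pass through
# almost unchanged, large ones saturate.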

keras_nlp/src/models/gemma/gemma_causal_lm_test.py

Lines changed: 15 additions & 0 deletions
@@ -264,3 +264,18 @@ def layer_intercept_fn_for_testing(x, i):
         # Assert shapes for info exfiltrated into the parent context.
         self.assertEqual(ops.shape(embedded_prompts), expected_embedded_shape)
         self.assertEqual(ops.shape(scores), expected_score_shape)
+
+
+class Gemma2CausalLMTest(TestCase):
+    @pytest.mark.large
+    def test_preset(self):
+        # Setup prompts, models, and associated expected shapes.
+        keras.config.set_floatx("bfloat16")
+        gemma_lm = GemmaCausalLM.from_preset(
+            "/usr/local/google/home/grasskin/gemma2/keras-nlp-private/gemma_9b_en"
+        )
+        # gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
+        gemma_lm.summary()
+        print(
+            gemma_lm.generate("what is the meaning of life?.", max_length=256)
+        )
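
The new test above loads from a hardcoded local checkout; the commented-out line shows the intended public entry point. A hedged usage sketch, assuming a published gemma2_9b_en preset as referenced in that comment:

import keras_nlp

# Assumes the gemma2_9b_en preset from the commented-out line has been published.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_9b_en")
print(gemma_lm.generate("what is the meaning of life?", max_length=256))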
