
Commit b58b56e

fix GemmaBackbone.get_layout_map + test (#1669)
* Fix the default GemmaBackbone.get_layout_map() to use fixed regexes, as per keras-team/keras#19496 (comment).
* Also fix the forgotten ffw_gating_2 entry in GemmaBackbone.get_layout_map. The sharding spec ("batch", "model") is the one that gives the best training performance; ("model", "batch") and (None, None) are slower (the first by 40%, the second by 2%). Fix the test too, including the typo ffw_linearl => ffw_linear.
* Change the test_architecture_characteristics test to follow the 4 -> 8 heads change needed for the test to work on TPUs. Also fix formatting.
* Update gemma_backbone_test.py with better test messages.

Co-authored-by: Matt Watson <[email protected]>
Parent commit: e0efbc8
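For context, here is a minimal sketch of how the layout map returned by GemmaBackbone.get_layout_map() is typically consumed, mirroring the new test added below. The import path, the preset name, and the 1 x N mesh shape are assumptions for illustration; the ("batch", "model") axis names are the ones the layout map's sharding specs refer to.

```python
import keras
from keras_nlp.models import GemmaBackbone  # assumed import path

# Build a 1 x N device mesh whose axis names match the layout map's specs.
devices = keras.distribution.list_devices()
device_mesh = keras.distribution.DeviceMesh(
    shape=(1, len(devices)),
    axis_names=("batch", "model"),
    devices=devices,
)

# With the fixed regexes, only the base kernels are sharded; anything the
# patterns do not match (e.g. LoRA kernels) keeps the default layout.
layout_map = GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)

with distribution.scope():
    # Preset name is illustrative; any Gemma backbone works the same way.
    model = GemmaBackbone.from_preset("gemma_2b_en")
```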

2 files changed (+42, -8 lines)

keras_nlp/src/models/gemma/gemma_backbone.py

Lines changed: 5 additions & 4 deletions

@@ -255,17 +255,18 @@ def get_layout_map(
         # See https://arxiv.org/abs/2403.08295
         layout_map = keras.distribution.LayoutMap(device_mesh)
         layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
-        layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = (
+        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
             model_dim,
             data_dim,
             None,
         )
-        layout_map["decoder_block.*attention_output.*kernel"] = (
+        layout_map["decoder_block.*attention_output.kernel"] = (
             model_dim,
             None,
             data_dim,
         )
-        layout_map["decoder_block.*ffw_gating.*kernel"] = (data_dim, model_dim)
-        layout_map["decoder_block.*ffw_linear.*kernel"] = (model_dim, data_dim)
+        layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
+        layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
 
         return layout_map
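A note on why the patterns were tightened: LayoutMap keys are treated as regular expressions matched against weight variable paths (the behavior discussed in keras-team/keras#19496), so the old greedy ".*kernel" tails also captured paths they were never meant to shard, such as the LoRA kernels exercised in the new test below. The sketch that follows is illustrative only; the paths are hypothetical stand-ins modeled on the names asserted in the test, and re.search is used purely to show which paths each pattern can match.

```python
import re

# Hypothetical weight paths, modeled on the names asserted in the new test.
paths = [
    "decoder_block_0/attention/query/kernel",
    "decoder_block_0/attention/query/lora_kernel_a",  # added by enable_lora()
    "decoder_block_0/ffw_gating/kernel",
    "decoder_block_0/ffw_gating_2/kernel",
]

old_attn = r"decoder_block.*attention.*(query|key|value).*kernel"
new_attn = r"decoder_block.*attention.*(query|key|value).kernel"

for path in paths[:2]:
    # The old pattern matches the LoRA kernel as well; the new one matches
    # only the base query kernel, leaving LoRA weights unsharded.
    print(path, bool(re.search(old_attn, path)), bool(re.search(new_attn, path)))

# "ffw_gating.kernel" no longer swallows "ffw_gating_2/kernel", which is why
# get_layout_map() now lists ffw_gating_2 explicitly.
print(bool(re.search(r"decoder_block.*ffw_gating.kernel", paths[3])))    # False
print(bool(re.search(r"decoder_block.*ffw_gating_2.kernel", paths[3])))  # True
```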

keras_nlp/src/models/gemma/gemma_backbone_test.py

Lines changed: 37 additions & 4 deletions

@@ -24,8 +24,8 @@ def setUp(self):
         self.init_kwargs = {
             "vocabulary_size": 256128,
             "num_layers": 2,
-            "num_query_heads": 4,
-            "num_key_value_heads": 4,
+            "num_query_heads": 8,
+            "num_key_value_heads": 8,
             "hidden_dim": 128,
             "intermediate_dim": 256,
             "head_dim": 128,
@@ -82,7 +82,7 @@ def test_all_presets(self):
 
     def test_architecture_characteristics(self):
         model = GemmaBackbone(**self.init_kwargs)
-        self.assertEqual(model.count_params(), 33407616)
+        self.assertEqual(model.count_params(), 33931904)
         self.assertEqual(len(model.layers), 6)
 
     def test_distribution(self):
@@ -132,7 +132,40 @@ def test_distribution(self):
                 self.assertEqual(
                     tuple(w.value.sharding.spec), ("batch", "model")
                 )
-            if "ffw_linearl" in w.path:
+            if "ffw_linear" in w.path:
                 self.assertEqual(
                     tuple(w.value.sharding.spec), ("model", "batch")
                 )
+
+    def test_distribution_with_lora(self):
+        if keras.backend.backend() != "jax":
+            self.skipTest("`ModelParallel` testing requires the Jax backend.")
+        devices = keras.distribution.list_devices("CPU")
+        if len(devices) == 1:
+            # Need more than 1 device for distribution testing.
+            self.skipTest("`ModelParallel` testing requires multiple devices.")
+        device_mesh = keras.distribution.DeviceMesh(
+            shape=(1, len(devices)),
+            axis_names=("batch", "model"),
+            devices=devices,
+        )
+
+        layout_map = GemmaBackbone.get_layout_map(device_mesh)
+        distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
+        with distribution.scope():
+            model = GemmaBackbone(**self.init_kwargs)
+            model.enable_lora(rank=4)
+
+        for w in model.weights:
+            if "attention/query/lora_kernel_a" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, None)
+                )
+            if "attention/query/lora_kernel_b" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
+            if "attention/value/lora_kernel_a" in w.path:
+                self.assertEqual(
+                    tuple(w.value.sharding.spec), (None, None, None)
+                )
+            if "attention/value/lora_kernel_b" in w.path:
+                self.assertEqual(tuple(w.value.sharding.spec), (None, None))
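As a back-of-the-envelope check on the updated count_params() assertion, the new total can be reproduced by hand under a few assumptions about the backbone: bias-free projections, a single tied token embedding, two RMSNorm scale vectors per decoder block plus one final norm, and intermediate_dim=256 being the concatenated width of the two gating projections (so each gating projection and the linear projection is 128 wide). This is a sanity sketch under those assumptions, not a statement of the exact implementation.

```python
# Sanity sketch: reproduce the asserted parameter counts from self.init_kwargs,
# under the assumptions stated above.
vocab, hidden, head_dim, inter, layers = 256128, 128, 128, 256, 2

def total_params(num_query_heads, num_key_value_heads):
    embedding = vocab * hidden  # tied token embedding
    # query + output projections scale with query heads; key + value with KV heads
    attention = (2 * num_query_heads + 2 * num_key_value_heads) * head_dim * hidden
    ffw = 3 * hidden * (inter // 2)  # ffw_gating, ffw_gating_2, ffw_linear
    norms = 2 * hidden               # pre-attention + pre-ffw norm scales
    return embedding + layers * (attention + ffw + norms) + hidden  # + final norm

print(total_params(8, 8))  # 33931904, the new assertion
print(total_params(4, 4))  # 33407616, the old assertion
```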
