
Commit b2ec380

More consistent defaults for PaliGemma

* More consistent defaults for PaliGemma

  In general, we do not copy the hyperparameters of a specific pre-trained
  model into the init args. Do the same here for consistency. Also, use the
  smallest test models possible, so our unit testing stays reasonably fast.

* Add basic and saved-model tests

---------

Co-authored-by: divyashreepathihalli <[email protected]>

1 parent f8845ad commit b2ec380

File tree

5 files changed: +147 −106


keras_nlp/src/models/pali_gemma/pali_gemma_backbone.py (+27 −24)

```diff
@@ -60,9 +60,6 @@ class PaliGemmaBackbone(Backbone):
         intermediate_dim: int. The output dimension of the first Dense layer in
             a two-layer feedforward network for each transformer decoder block.
         head_dim: int. The size of each attention head in the mixed decoder.
-        layer_norm_epsilon: float. The epsilon value user for every layer norm
-            in all transformer blocks.
-        dropout: float. Dropout probability for the Transformer decoder blocks.
         vit_patch_size: int. The size of each square patch in the input image.
         vit_num_heads: int. The number of attention heads for the vision(image)
             transformer encoder.
@@ -76,9 +73,10 @@ class PaliGemmaBackbone(Backbone):
             `"0"` or `"none"`. Defaults to `"none"`.
         vit_classifier_activation: activation function. The activation that
             is used for final output classification in the vision transformer.
-        vit_include_rescaling: bool. To be set to `True` if input image values
-            needs to be rescaled between 0-1.
         vit_name: string. The name used for vision transformer layers.
+        layer_norm_epsilon: float. The epsilon value user for every layer norm
+            in all transformer blocks.
+        dropout: float. Dropout probability for the Transformer decoder blocks.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -99,44 +97,46 @@ class PaliGemmaBackbone(Backbone):
     # Randomly initialized PaliGemma decoder with custom config.
     model = keras_nlp.models.PaliGemmaBackbone(
         vocabulary_size=50257,
+        images_size=224,
         num_layers=12,
         num_query_heads=12,
         num_key_value_heads=1,
         hidden_dim=768,
         intermediate_dim=3072,
         head_dim=64,
+        vit_patch_size=14,
+        vit_num_heads=8,
+        vit_hidden_dim=768,
+        vit_intermediate_dim=3072,
+        vit_num_layers=2,
     )
     model(input_data)
     ```
     """
 
     def __init__(
         self,
-        vocabulary_size=257152,
-        image_size=224,
-        num_layers=18,
-        num_query_heads=8,
-        num_key_value_heads=1,
-        hidden_dim=2048,
-        intermediate_dim=32768,
-        head_dim=256,
-        layer_norm_epsilon=1e-6,
-        dropout=0,
-        vit_patch_size=14,
-        vit_num_heads=16,
-        vit_hidden_dim=1152,
-        vit_num_layers=27,
-        vit_intermediate_dim=4304,
+        vocabulary_size,
+        image_size,
+        num_layers,
+        num_query_heads,
+        num_key_value_heads,
+        hidden_dim,
+        intermediate_dim,
+        head_dim,
+        vit_patch_size,
+        vit_num_heads,
+        vit_hidden_dim,
+        vit_num_layers,
+        vit_intermediate_dim=None,  # TODO remove default
         vit_pooling=None,
         vit_classifier_activation=None,
         vit_name=None,
+        layer_norm_epsilon=1e-6,
+        dropout=0,
         dtype=None,
         **kwargs,
     ):
-        # TODO: remove these from our uploaded models.
-        kwargs.pop("vit_num_classes", None)
-        kwargs.pop("vit_include_rescaling", None)
-
         if not config.keras_3():
             raise ValueError(
                 "`PaliGemmaBackbone` requires Keras 3. Run "
@@ -159,6 +159,8 @@ def __init__(
             dtype=dtype,
             name="token_embedding",
         )
+        # TODO Remove this. Work around for previous serialization bug.
+        vit_intermediate_dim = vit_intermediate_dim or 4304
         self.vit_encoder = PaliGemmaVit(
             image_size=image_size,
             patch_size=vit_patch_size,
@@ -268,6 +270,7 @@ def get_config(self):
                 "vit_num_heads": self.vit_num_heads,
                 "vit_hidden_dim": self.vit_hidden_dim,
                 "vit_num_layers": self.vit_num_layers,
+                "vit_intermediate_dim": self.vit_intermediate_dim,
                 "vit_pooling": self.vit_pooling,
                 "vit_classifier_activation": self.vit_classifier_activation,
                 "vit_name": self.vit_name,
```

keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py (+16 −21)

```diff
@@ -34,17 +34,12 @@ def setUp(self):
         self.batch_size = 2
         self.vocabulary_size = 256
         self.text_sequence_length = 64
-        self.image_size = 224
+        self.image_size = 16
         self.dummy_text = [
             "the quick brown fox" for _ in range(self.batch_size)
         ]
         self.dummy_images = np.random.uniform(
-            size=(
-                self.batch_size,
-                self.image_size,
-                self.image_size,
-                3,
-            )
+            size=(self.batch_size, self.image_size, self.image_size, 3)
         )
 
         proto = "gemma_test_vocab.spm"
@@ -56,19 +51,19 @@ def setUp(self):
         )
 
         self.backbone = PaliGemmaBackbone(
-            self.vocabulary_size,
-            image_size=224,
-            num_layers=27,
-            num_query_heads=16,
-            num_key_value_heads=16,
-            hidden_dim=256,
-            intermediate_dim=256,
-            head_dim=126,
-            vit_patch_size=14,
-            vit_num_heads=8,
-            vit_hidden_dim=16,
+            vocabulary_size=self.vocabulary_size,
+            image_size=self.image_size,
+            num_layers=2,
+            num_query_heads=2,
+            num_key_value_heads=1,
+            hidden_dim=8,
+            intermediate_dim=16,
+            head_dim=4,
+            vit_patch_size=4,
             vit_num_layers=2,
-            vit_intermediate_dim=8,
+            vit_num_heads=2,
+            vit_hidden_dim=8,
+            vit_intermediate_dim=16,
         )
         self.dummy_imgs = np.random.rand(
             self.batch_size, self.image_size, self.image_size, 3
@@ -99,7 +94,7 @@ def test_pali_gemma_backbone(self):
             (
                 self.batch_size,
                 self.text_sequence_length + self.backbone.image_sequence_length,
-                256,
+                8,
             ),
             output.shape,
         )
@@ -117,7 +112,7 @@ def test_pali_gemma_backbone_with_preprocessing(self):
             (
                 self.batch_size,
                 self.text_sequence_length + self.backbone.image_sequence_length,
-                256,
+                8,
             ),
             output.shape,
         )
```
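The tightened shape assertions follow directly from the smaller configuration: the trailing dimension is now `hidden_dim=8`, and the sequence axis grows by one position per image patch. A quick sketch of the arithmetic, assuming (as in a standard ViT without a class token) one embedding per non-overlapping square patch:

```python
# Values from the updated setUp() above.
batch_size = 2
text_sequence_length = 64
image_size = 16
vit_patch_size = 4
hidden_dim = 8

# One vision token per non-overlapping patch.
image_sequence_length = (image_size // vit_patch_size) ** 2  # 4 * 4 = 16

# The shape asserted by both tests: (2, 64 + 16, 8) == (2, 80, 8).
expected_shape = (
    batch_size,
    text_sequence_length + image_sequence_length,
    hidden_dim,
)
```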

keras_nlp/src/models/pali_gemma/pali_gemma_causal_lm_test.py (+56 −21)

```diff
@@ -35,18 +35,13 @@
 class PaliGemmaCausalLMTest(TestCase):
     def setUp(self):
         self.batch_size = 2
-        self.text_sequence_length = 64
-        self.image_size = 224
+        self.text_sequence_length = 16
+        self.image_size = 16
         self.dummy_text = [
             "the quick brown fox" for _ in range(self.batch_size)
         ]
         self.dummy_images = np.random.uniform(
-            size=(
-                self.batch_size,
-                self.image_size,
-                self.image_size,
-                3,
-            )
+            size=(self.batch_size, self.image_size, self.image_size, 3)
         )
 
         proto = "gemma_test_vocab.spm"
@@ -62,20 +57,60 @@ def setUp(self):
         )
 
         self.backbone = PaliGemmaBackbone(
-            self.vocabulary_size,
-            image_size=224,
-            num_layers=27,
-            num_query_heads=16,
-            num_key_value_heads=16,
-            hidden_dim=256,
-            intermediate_dim=256,
-            head_dim=126,
-            vit_patch_size=14,
-            vit_num_heads=8,
-            vit_hidden_dim=16,
+            vocabulary_size=self.vocabulary_size,
+            image_size=self.image_size,
+            num_layers=2,
+            num_query_heads=2,
+            num_key_value_heads=1,
+            hidden_dim=8,
+            intermediate_dim=16,
+            head_dim=4,
+            vit_patch_size=4,
             vit_num_layers=2,
-            vit_intermediate_dim=8,
-            vit_num_classes=512,
+            vit_num_heads=2,
+            vit_hidden_dim=8,
+            vit_intermediate_dim=16,
+        )
+        self.train_data = (
+            {
+                "images": self.dummy_images,
+                "prompts": self.dummy_text,
+                "responses": self.dummy_text,
+            },
+        )
+        self.init_kwargs = {
+            "preprocessor": self.preprocessor,
+            "backbone": self.backbone,
+        }
+
+    def test_causal_lm_basics(self):
+        self.run_task_test(
+            cls=PaliGemmaCausalLM,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 16, 11),
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        input_data = {
+            "token_ids": np.random.rand(
+                self.batch_size, self.text_sequence_length
+            ),
+            "images": self.dummy_images,
+            "padding_mask": np.ones(
+                (self.batch_size, self.text_sequence_length),
+                dtype="int32",
+            ),
+            "response_mask": np.zeros(
+                (self.batch_size, self.text_sequence_length),
+                dtype="int32",
+            ),
+        }
+        self.run_model_saving_test(
+            cls=PaliGemmaCausalLM,
+            init_kwargs=self.init_kwargs,
+            input_data=input_data,
         )
 
     def test_pali_gemma_causal_model(self):
```
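The new `test_causal_lm_basics` covers compile/fit/predict through the shared `run_task_test` helper. Its `expected_output_shape=(2, 16, 11)` reads as batch size × `text_sequence_length` × vocabulary size, where 11 is presumably the size of the small `gemma_test_vocab.spm` test vocabulary. A hedged sketch of the equivalent public-API flow; `backbone` and `preprocessor` stand in for the objects built in `setUp()` and are not redefined here:

```python
import numpy as np
import keras_nlp

# Assumes `backbone` and `preprocessor` were constructed as in setUp().
causal_lm = keras_nlp.models.PaliGemmaCausalLM(
    backbone=backbone,
    preprocessor=preprocessor,
)
# The same feature dict the test wraps in its train_data tuple; the
# preprocessor turns prompts/responses into token ids, masks, and labels.
causal_lm.fit(
    {
        "images": np.random.uniform(size=(2, 16, 16, 3)),
        "prompts": ["the quick brown fox"] * 2,
        "responses": ["the quick brown fox"] * 2,
    },
    batch_size=2,
)
```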

keras_nlp/src/models/pali_gemma/pali_gemma_vit.py (+21 −23)

```diff
@@ -18,9 +18,9 @@
 class PaliGemmaVitEmbeddings(keras.layers.Layer):
     def __init__(
         self,
+        image_size,
+        patch_size,
         hidden_dim,
-        image_size=224,
-        patch_size=14,
         num_channels=3,
         dtype=None,
         **kwargs,
@@ -286,12 +286,12 @@ def get_config(self):
 class PaliGemmaVitEncoder(keras.layers.Layer):
     def __init__(
         self,
+        patch_size,
+        image_size,
         hidden_dim,
         num_layers,
         num_heads,
         intermediate_dim,
-        patch_size,
-        image_size,
         dtype=None,
         **kwargs,
     ):
@@ -421,26 +421,24 @@ class PaliGemmaVit(keras.Model):
     """Vision Transformer (ViT) model for PaliGemma.
 
     Args:
+        image_size: int. The height/width of the image. Both height and width is
+            expected to be the same.
+        patch_size: int. The size of each square patch in the input image.
         num_heads: int. The number of attention heads for the vision(image)
             transformer encoder.
         hidden_dim: int. The size of the transformer hidden state at the end
             of each vision transformer layer.
         num_layers: int. The number of transformer layers.
         intermediate_dim: int. The output dimension of the first Dense layer in
             a two-layer feedforward network for transformer.
-        pooling: string. The encoded vision embeddings are pooled using the
-            specified polling setting. The accepted values are `"map"`, `"gap"`,
-            `"zero"` or `"none"`. Defaults to `"none"`.
         num_classes: int. The number of output classes. If this model is used
             as a image classifier, this value would correspond to the number of
             output classes.
-        image_size: int. The height/width of the image. Both height and width is
-            expected to be the same.
-        patch_size: int. The size of each square patch in the input image.
+        pooling: string. The encoded vision embeddings are pooled using the
+            specified polling setting. The accepted values are `"map"`, `"gap"`,
+            `"zero"` or `None`. Defaults to `None`.
         classifier_activation: activation fucntion. The activation that is used
             for final output classification
-        include_rescaling: bool. to be set to `True` if input image values needs
-            to be rescaled between 0-1.
         dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
             for the models computations and weights. Note that some
             computations, such as softmax and layer normalization will always
@@ -458,14 +456,14 @@ class PaliGemmaVit(keras.Model):
 
     def __init__(
         self,
-        num_heads=16,
-        hidden_dim=1152,
-        num_layers=27,
-        intermediate_dim=4304,
+        image_size,
+        patch_size,
+        num_heads,
+        hidden_dim,
+        num_layers,
+        intermediate_dim,
+        num_classes,
         pooling=None,
-        num_classes=2048,
-        image_size=None,
-        patch_size=14,
         classifier_activation=None,
         dtype=None,
         **kwargs,
@@ -475,10 +473,10 @@ def __init__(
             shape=(image_size, image_size, 3), name="images"
         )
         encoded = PaliGemmaVitEncoder(
-            hidden_dim,
-            num_layers,
-            num_heads,
-            intermediate_dim,
+            hidden_dim=hidden_dim,
+            num_layers=num_layers,
+            num_heads=num_heads,
+            intermediate_dim=intermediate_dim,
             patch_size=patch_size,
             image_size=image_size,
             dtype=dtype,
```
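Two changes land here: the ViT layers lose their model-specific defaults, and the internal `PaliGemmaVitEncoder` call switches to keyword arguments, which is what makes the parameter reordering safe; the old positional call would have silently fed `hidden_dim` into the new leading `patch_size` slot. A sketch of standalone construction under the new signature, with arbitrary small values (`PaliGemmaVit` lives under `keras_nlp.src` and is not public API):

```python
from keras_nlp.src.models.pali_gemma.pali_gemma_vit import PaliGemmaVit

# Every architecture argument is now explicit; only pooling,
# classifier_activation, and dtype keep generic defaults.
vit = PaliGemmaVit(
    image_size=16,
    patch_size=4,
    num_heads=2,
    hidden_dim=8,
    num_layers=2,
    intermediate_dim=16,
    num_classes=2,  # only meaningful when the ViT is used as a classifier
)
```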
