PaliGemma #1636

Merged
35 commits, merged May 21, 2024
Commits (35)
cd7e2f8
Add VIT Encoder
divyashreepathihalli Mar 28, 2024
53119bd
Add MHAPooling layer + end to end ViT model
fchollet Mar 29, 2024
aa9bd89
Feature/pg gemma changes
VarunS1997 Apr 2, 2024
962aea7
Misc fixes
divyashreepathihalli Apr 2, 2024
f23974b
update vit model and add a test for verifying output shape.
divyashreepathihalli Apr 3, 2024
53865e0
Feature/pg gemma changes
VarunS1997 Apr 18, 2024
86540c0
Vit model weights conversion
divyashreepathihalli Apr 22, 2024
6160ebe
Update imports
divyashreepathihalli Apr 25, 2024
92e697d
Add vit attention
divyashreepathihalli Apr 27, 2024
0454f1f
add paligemma functional model
divyashreepathihalli May 3, 2024
56a95ce
Paligemma full model checkpoints conversion script
divyashreepathihalli May 4, 2024
9a46b6e
Fix ViT build issue
divyashreepathihalli May 4, 2024
2a2e6b1
Multi modal Refactor for PaliGemma
VarunS1997 May 6, 2024
b2a14d8
Export the public API surface
mattdangerw May 7, 2024
53266b1
update image size arg throughout PaliGemma
divyashreepathihalli May 7, 2024
597daaf
Update convert_paligemma_checkpoints.py
divyashreepathihalli May 7, 2024
87d6c5a
Renames for consistency
mattdangerw May 7, 2024
3ab245c
Update convert_pali_gemma_checkpoints.py
divyashreepathihalli May 7, 2024
671a161
More consistency improvements for PaliGemma
mattdangerw May 7, 2024
ffd5d98
Do the same scaling in our backbone we do for generate
mattdangerw May 7, 2024
b94b6d3
Update conversion and add cli arguments
grasskin May 7, 2024
d262569
Add presets
divyashreepathihalli May 7, 2024
caa7cef
Update pali_gemma_causal_lm_preprocesor.py to default text sequence l…
divyashreepathihalli May 8, 2024
c8f327a
Tokenizer fix
divyashreepathihalli May 8, 2024
bc3811b
Allow fit calls for pali_gemma
mattdangerw May 8, 2024
1afea65
Allow generate on unbatch input
mattdangerw May 8, 2024
cfef757
Remove the score function from pali gemma causal lm
mattdangerw May 8, 2024
0ba0920
Greedy sample by default for pali gemma
mattdangerw May 8, 2024
f4d31dd
Added docstrings for paligemma decoder and backbone
VarunS1997 May 10, 2024
e8bff89
Minor fixes for the vit
mattdangerw May 16, 2024
db0d9d1
Add a response_mask input
mattdangerw May 21, 2024
5267d5a
Update pali_gemma_presets.py path
divyashreepathihalli May 21, 2024
83ee31d
Add a tokenizer docstring for pali gemma
mattdangerw May 21, 2024
edc66e8
Update pali_gemma_presets.py
divyashreepathihalli May 21, 2024
569d89e
More consistent defaults for PaliGemma
mattdangerw May 21, 2024
12 changes: 12 additions & 0 deletions keras_nlp/api/models/__init__.py
@@ -159,6 +159,18 @@
 )
 from keras_nlp.src.models.opt.opt_preprocessor import OPTPreprocessor
 from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
+from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import (
+    PaliGemmaBackbone,
+)
+from keras_nlp.src.models.pali_gemma.pali_gemma_causal_lm import (
+    PaliGemmaCausalLM,
+)
+from keras_nlp.src.models.pali_gemma.pali_gemma_causal_lm_preprocessor import (
+    PaliGemmaCausalLMPreprocessor,
+)
+from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import (
+    PaliGemmaTokenizer,
+)
 from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
 from keras_nlp.src.models.phi3.phi3_causal_lm import Phi3CausalLM
 from keras_nlp.src.models.phi3.phi3_causal_lm_preprocessor import (
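With these exports, the full PaliGemma stack is reachable from `keras_nlp.models`. A minimal sketch of the public entry points (the preset name mirrors the backbone docstring below and is illustrative; the registered names live in `pali_gemma_presets.py`):

```python
import keras_nlp

# The four symbols exported above: tokenizer, preprocessor, backbone,
# and the high-level generative task model.
tokenizer = keras_nlp.models.PaliGemmaTokenizer.from_preset("pali_gemma_mix_224")
backbone = keras_nlp.models.PaliGemmaBackbone.from_preset("pali_gemma_mix_224")
causal_lm = keras_nlp.models.PaliGemmaCausalLM.from_preset("pali_gemma_mix_224")
```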
39 changes: 29 additions & 10 deletions keras_nlp/src/layers/preprocessing/multi_segment_packer.py
@@ -239,7 +239,12 @@ def _trim_inputs(self, inputs):
         else:
             raise ValueError("Unsupported truncate: %s" % self.truncate)
 
-    def _combine_inputs(self, segments):
+    def _combine_inputs(
+        self,
+        segments,
+        add_start_value=True,
+        add_end_value=True,
+    ):
         """Combine inputs with start and end values added."""
         dtype = segments[0].dtype
         batch_size = segments[0].nrows()
@@ -259,10 +264,12 @@ def _combine_inputs(self, segments):
         ones_sep_columns = tf.ones_like(sep_columns, dtype="int32")
         ones_end_columns = tf.ones_like(end_columns, dtype="int32")
 
-        segments_to_combine = [start_columns]
-        segment_ids_to_combine = [
-            tf.ones_like(start_columns, dtype="int32") * 0
-        ]
+        segments_to_combine = []
+        segment_ids_to_combine = []
+        if add_start_value:
+            segments_to_combine.append(start_columns)
+            start_segment = tf.zeros_like(start_columns, dtype="int32")
+            segment_ids_to_combine.append(start_segment)
 
         for i, seg in enumerate(segments):
             # Combine all segments.
@@ -273,8 +280,9 @@
 
             # Account for the sep/end tokens here.
             if i == len(segments) - 1:
-                segments_to_combine.append(end_columns)
-                segment_ids_to_combine.append(ones_end_columns * i)
+                if add_end_value:
+                    segments_to_combine.append(end_columns)
+                    segment_ids_to_combine.append(ones_end_columns * i)
             else:
                 segments_to_combine.append(sep_columns)
                 segment_ids_to_combine.append(ones_sep_columns * i)
@@ -283,13 +291,24 @@
         segment_ids = tf.concat(segment_ids_to_combine, 1)
         return token_ids, segment_ids
 
-    def call(self, inputs):
+    def call(
+        self,
+        inputs,
+        sequence_length=None,
+        add_start_value=True,
+        add_end_value=True,
+    ):
         inputs, unbatched = self._sanitize_inputs(inputs)
 
         segments = self._trim_inputs(inputs)
-        token_ids, segment_ids = self._combine_inputs(segments)
+        token_ids, segment_ids = self._combine_inputs(
+            segments,
+            add_start_value=add_start_value,
+            add_end_value=add_end_value,
+        )
         # Pad to dense tensor output.
-        shape = tf.cast([-1, self.sequence_length], "int64")
+        sequence_length = sequence_length or self.sequence_length
+        shape = tf.cast([-1, sequence_length], "int64")
         token_ids = token_ids.to_tensor(
             shape=shape, default_value=self.pad_value
         )
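The new call-time arguments matter for generation-style preprocessing: a prompt can be packed without the trailing end token so that generated output continues the sequence, and the packed length can be overridden per call. A minimal sketch, assuming the layer's existing constructor arguments and the tuple-of-segments calling convention; expected outputs are noted under the assumption of a pad value of 0:

```python
from keras_nlp.src.layers.preprocessing.multi_segment_packer import (
    MultiSegmentPacker,
)

packer = MultiSegmentPacker(
    sequence_length=8, start_value=101, end_value=102
)

# Default behavior: start and end values are both added.
token_ids, segment_ids = packer(([1, 2, 3],))
# token_ids -> [101, 1, 2, 3, 102, 0, 0, 0]

# Leave the sequence open for a generated response, and override the
# packed length for this call only.
token_ids, segment_ids = packer(
    ([1, 2, 3],),
    sequence_length=6,
    add_end_value=False,
)
# token_ids -> [101, 1, 2, 3, 0, 0]
```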
23 changes: 23 additions & 0 deletions keras_nlp/src/models/pali_gemma/__init__.py
@@ -0,0 +1,23 @@
# Copyright 2024 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import (
PaliGemmaBackbone,
)
from keras_nlp.src.models.pali_gemma.pali_gemma_presets import backbone_presets
from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (PaliGemmaBackbone, PaliGemmaTokenizer))
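`register_presets` hooks the preset table up to the `from_preset` constructors of both classes, so a single preset handle resolves matching weights and vocabulary. A sketch, assuming the `presets` property that other KerasNLP models expose:

```python
from keras_nlp.src.models.pali_gemma import PaliGemmaBackbone, PaliGemmaTokenizer

# After registration, the preset table is visible on both classes.
print(PaliGemmaBackbone.presets.keys())
print(PaliGemmaTokenizer.presets.keys())
```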
279 changes: 279 additions & 0 deletions keras_nlp/src/models/pali_gemma/pali_gemma_backbone.py
@@ -0,0 +1,279 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.backend import config
from keras_nlp.src.backend import keras
from keras_nlp.src.backend import ops
from keras_nlp.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_nlp.src.models.backbone import Backbone
from keras_nlp.src.models.gemma.rms_normalization import RMSNormalization
from keras_nlp.src.models.pali_gemma.pali_gemma_decoder_block import (
PaliGemmaDecoderBlock,
)
from keras_nlp.src.models.pali_gemma.pali_gemma_vit import PaliGemmaVit


@keras_nlp_export("keras_nlp.models.PaliGemmaBackbone")
class PaliGemmaBackbone(Backbone):
"""PaliGemma core network with hyperparameters.

This backbone implements the mixed-modality PaliGemma architecture. It
contains a Vision Transformer network, as well as a text token embedding
layer, followed by a backend-agnostic concatenation operation that
constructs a single sequence of mixed (visual and textual) embeddings.
The concatenated sequence is then passed through a series of mixed-modality
decoder blocks. Calling the model returns the final hidden state for every
position in this sequence.

For a higher-level object for text generation,
see `keras_nlp.models.PaliGemmaCausalLM`.

The default constructor gives a fully customizable, randomly initialized
PaliGemma model with any number of ViT layers, heads, and embedding
dimensions, along with equivalent configuration for the PaliGemma decoder
layers. To load preset architectures and weights, use the `from_preset`
constructor.

Args:
vocabulary_size: int. The size of the token vocabulary.
image_size: int. The resolution of the image in both width and height.
Note: input images must be square.
num_layers: int. The number of transformer mixed decoder layers.
num_query_heads: int. The number of heads for the query projections in
the mixed decoder attention layer.
num_key_value_heads: int. The number of heads for the key and value
projections in the mixed decoder attention layers.
hidden_dim: int. The size of the transformer hidden state at the end
of each mixed transformer layer.
intermediate_dim: int. The output dimension of the first Dense layer in
a two-layer feedforward network for each transformer decoder block.
head_dim: int. The size of each attention head in the mixed decoder.
vit_patch_size: int. The size of each square patch in the input image.
vit_num_heads: int. The number of attention heads for the vision (image)
transformer encoder.
vit_hidden_dim: int. The size of the transformer hidden state at the end
of each vision transformer layer.
vit_num_layers: int. The number of vision transformer layers.
vit_intermediate_dim: int. The output dimension of the first Dense layer
in a two-layer feedforward network for the vision transformer.
vit_pooling: string. The encoded vision embeddings are pooled using the
specified pooling setting. The accepted values are `"map"`, `"gap"`,
`"0"`, or `"none"`. Defaults to `"none"`.
vit_classifier_activation: activation function. The activation that
is used for final output classification in the vision transformer.
vit_name: string. The name used for vision transformer layers.
layer_norm_epsilon: float. The epsilon value used for every layer norm
in all transformer blocks.
dropout: float. Dropout probability for the Transformer decoder blocks.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for the model's computations and weights. Note that some computations,
such as softmax and layer normalization, will always be done in float32
precision regardless of dtype.

Example:
```python
input_data = {
    "token_ids": np.ones(shape=(1, 12), dtype="int32"),
    "images": np.random.uniform(size=(1, 224, 224, 3)),
    "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    "response_mask": np.zeros(shape=(1, 12), dtype="int32"),
}

# Pretrained PaliGemma decoder.
model = keras_nlp.models.PaliGemmaBackbone.from_preset("pali_gemma_mix_224")
model(input_data)

# Randomly initialized PaliGemma decoder with custom config.
model = keras_nlp.models.PaliGemmaBackbone(
vocabulary_size=50257,
image_size=224,
num_layers=12,
num_query_heads=12,
num_key_value_heads=1,
hidden_dim=768,
intermediate_dim=3072,
head_dim=64,
vit_patch_size=14,
vit_num_heads=8,
vit_hidden_dim=768,
vit_intermediate_dim=3072,
vit_num_layers=2,
)
model(input_data)
```
"""

def __init__(
self,
vocabulary_size,
image_size,
num_layers,
num_query_heads,
num_key_value_heads,
hidden_dim,
intermediate_dim,
head_dim,
vit_patch_size,
vit_num_heads,
vit_hidden_dim,
vit_num_layers,
vit_intermediate_dim=None, # TODO remove default
vit_pooling=None,
vit_classifier_activation=None,
vit_name=None,
layer_norm_epsilon=1e-6,
dropout=0,
dtype=None,
**kwargs,
):
if not config.keras_3():
raise ValueError(
"`PaliGemmaBackbone` requires Keras 3. Run "
"`pip install -U keras` to upgrade your Keras version, or see "
"https://keras.io/getting_started/ "
"for more info on Keras versions and installation."
)

# === Layers ===
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
tie_weights=True,
embeddings_initializer=keras.initializers.VarianceScaling(
scale=1.0,
mode="fan_in",
distribution="untruncated_normal",
seed=None,
),
dtype=dtype,
name="token_embedding",
)
# TODO: Remove this. Workaround for a previous serialization bug.
vit_intermediate_dim = vit_intermediate_dim or 4304
self.vit_encoder = PaliGemmaVit(
image_size=image_size,
patch_size=vit_patch_size,
num_heads=vit_num_heads,
hidden_dim=vit_hidden_dim,
num_layers=vit_num_layers,
intermediate_dim=vit_intermediate_dim,
pooling=vit_pooling,
num_classes=hidden_dim,
classifier_activation=vit_classifier_activation,
dtype=dtype,
name=vit_name,
)
self.transformer_layers = []
for i in range(num_layers):
layer = PaliGemmaDecoderBlock(
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
num_query_heads=num_query_heads,
head_dim=head_dim,
num_key_value_heads=num_key_value_heads,
dropout=dropout,
dtype=dtype,
name=f"decoder_block_{i}",
)
self.transformer_layers.append(layer)
self.layer_norm = RMSNormalization(
epsilon=layer_norm_epsilon,
dtype=dtype,
name="final_normalization",
)

# === Functional Model ===
image_input = self.vit_encoder.inputs[0]
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)
response_mask_input = keras.Input(
shape=(None,), dtype="int32", name="response_mask"
)
img_embeddings = self.vit_encoder(image_input)
text_embeddings = self.token_embedding(token_id_input)
text_embeddings = text_embeddings * ops.cast(
ops.sqrt(hidden_dim), text_embeddings.dtype
)
x = ops.concatenate((img_embeddings, text_embeddings), axis=1)
for transformer_layer in self.transformer_layers:
x = transformer_layer(
x,
padding_mask=padding_mask_input,
response_mask=response_mask_input,
)
sequence_output = self.layer_norm(x)
super().__init__(
inputs={
"images": image_input,
"token_ids": token_id_input,
"padding_mask": padding_mask_input,
"response_mask": response_mask_input,
},
outputs=sequence_output,
dtype=dtype,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.image_size = image_size
self.num_layers = num_layers
self.num_query_heads = num_query_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.head_dim = head_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.dropout = dropout
# VIT Params
self.vit_patch_size = vit_patch_size
self.vit_num_heads = vit_num_heads
self.vit_hidden_dim = vit_hidden_dim
self.vit_num_layers = vit_num_layers
self.vit_intermediate_dim = vit_intermediate_dim
self.vit_pooling = vit_pooling
self.vit_classifier_activation = vit_classifier_activation
self.vit_name = vit_name
# Keep the image_sequence_length as a backbone property for easy access.
self.image_sequence_length = self.vit_encoder.image_sequence_length

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"image_size": self.image_size,
"num_layers": self.num_layers,
"num_query_heads": self.num_query_heads,
"num_key_value_heads": self.num_key_value_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"head_dim": self.head_dim,
"layer_norm_epsilon": self.layer_norm_epsilon,
"dropout": self.dropout,
"vit_patch_size": self.vit_patch_size,
"vit_num_heads": self.vit_num_heads,
"vit_hidden_dim": self.vit_hidden_dim,
"vit_num_layers": self.vit_num_layers,
"vit_intermediate_dim": self.vit_intermediate_dim,
"vit_pooling": self.vit_pooling,
"vit_classifier_activation": self.vit_classifier_activation,
"vit_name": self.vit_name,
}
)
return config
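The functional graph above prepends the ViT output to the scaled text embeddings, so the decoder blocks see `image_sequence_length + num_text_tokens` positions. A back-of-the-envelope check of that geometry, using the numbers from the docstring example (the helper function is illustrative, not part of this PR):

```python
import math

def paligemma_sequence_length(image_size, patch_size, num_text_tokens):
    # The ViT embeds each non-overlapping square patch into one token,
    # and text tokens are concatenated after the image tokens.
    image_tokens = (image_size // patch_size) ** 2
    return image_tokens + num_text_tokens

# A 224x224 image with 14x14 patches yields 16 * 16 = 256 image tokens,
# so 12 text tokens produce a 268-position decoder input.
assert paligemma_sequence_length(224, 14, 12) == 268

# Text embeddings are multiplied by sqrt(hidden_dim) before the concat,
# mirroring `text_embeddings * ops.cast(ops.sqrt(hidden_dim), ...)`.
print(math.sqrt(768))  # scale factor for the docstring's hidden_dim=768
```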