[WIP] PARSeq Model #2089

Open: sineeli wants to merge 82 commits into master (changes shown from 79 commits).

Commits (82)
528d3a4
Base for parseq model
sineeli Jan 31, 2025
3bf11cd
make vit compatible with different height and width sizes
sineeli Jan 31, 2025
a8fb177
correct vit conv scripts
sineeli Jan 31, 2025
6f4363a
make class token optional in backbone; by default it's included
sineeli Jan 31, 2025
d1cece0
add flags to adjust vit network
sineeli Jan 31, 2025
92b2745
add test case for without class_token
sineeli Jan 31, 2025
ed00b73
Merge branch 'master' into parseq
sineeli Feb 3, 2025
25f661c
decoder file
sineeli Feb 6, 2025
f97fab1
parseq tokenizer base
sineeli Feb 10, 2025
d424210
add api for parseq tokenizer
sineeli Feb 10, 2025
3f3ad0d
Add missing arg max_label_length.
sineeli Feb 10, 2025
bb4457e
nit
sineeli Feb 10, 2025
68829f8
Merge branch 'master' into parseq
sineeli Feb 10, 2025
1bde466
add missing normalization step using tf_text
sineeli Feb 11, 2025
e6c5379
add missing config for preprocessor
sineeli Feb 12, 2025
5b08c93
add default start, pad and end tokens
sineeli Feb 12, 2025
49260ef
nit
sineeli Feb 12, 2025
b4150ed
correct special token order
sineeli Feb 12, 2025
ed8b9d7
return padding mask as well
sineeli Feb 18, 2025
4e4511c
use proper keras ops
sineeli Feb 18, 2025
9222331
nit
sineeli Feb 18, 2025
78a07a0
add decoder for parseq
sineeli Mar 3, 2025
decc12c
Build unbuilt layers for model validation
sineeli Mar 14, 2025
7aa2b67
fix forward pass and decoder
sineeli Mar 14, 2025
82be527
add missing mlp forward pass
sineeli Mar 25, 2025
c0bf528
add generate preprocess and generate step
sineeli Mar 29, 2025
3a862bb
Merge remote-tracking branch 'origin/master' into parseq
sineeli Mar 29, 2025
b6991be
nit
sineeli Mar 29, 2025
40df2ea
add generate_step to parseq causal lm
sineeli Mar 30, 2025
9ce7c62
minor fixes for jax backend and config fix
sineeli Apr 1, 2025
b1cb2ca
update decoder layer with caching mechanism which is used for generat…
sineeli Apr 7, 2025
3cd87cd
modify generate step including cache
sineeli Apr 7, 2025
57a5054
restructure code to make jax backend compatible
sineeli Apr 8, 2025
3adad55
add postprocess step into preprocessor
sineeli Apr 8, 2025
b7be4dd
test only forward pass
sineeli Apr 8, 2025
103ee5c
nit
sineeli Apr 8, 2025
c9487ae
test build cache
sineeli Apr 8, 2025
d0b3906
test generate step only build cache
sineeli Apr 8, 2025
9dfecc1
correct class name
sineeli Apr 8, 2025
a7619c6
correct dropout
sineeli Apr 8, 2025
4cb3c65
remove slicing in forward pass
sineeli Apr 8, 2025
dd4f8aa
nit
sineeli Apr 8, 2025
c473f6d
use python style slicing
sineeli Apr 8, 2025
456ba1d
support jax for generate step
sineeli Apr 8, 2025
78f319a
Merge branch 'master' into parseq
sineeli Apr 10, 2025
ac30b4b
Merge branch 'master' into parseq
sineeli Apr 10, 2025
18de453
compute attention mask for permutation at decoder block level
sineeli Apr 18, 2025
6ebf0ea
correct syntax error
sineeli Apr 18, 2025
68a4026
nit
sineeli Apr 18, 2025
38a4fc1
Add method for generating attention masks during train & permutations…
sineeli Apr 24, 2025
d990c72
update end token after 2 perms
sineeli May 1, 2025
8d05f9c
Merge branch 'master' into parseq
sineeli May 2, 2025
a54e14a
Merge branch 'master' into parseq
sineeli May 5, 2025
fd3166e
minor bug fix
sineeli May 6, 2025
4ffbc53
add save assets and load assets methods
sineeli May 6, 2025
2b27b1c
fix conflict issue
sineeli May 7, 2025
d6dc3fb
nit
sineeli May 7, 2025
675d935
fix minor issues while loading preset
sineeli May 7, 2025
8f6d7fe
fix jax dynamic shape issues
sineeli May 8, 2025
032515d
try to fix jax backend concretization error
sineeli May 8, 2025
1f92e17
fix mask broadcast error
sineeli May 9, 2025
85e9df2
fix repeat for mismatched output length
sineeli May 9, 2025
09157f1
ignore permutation based training
sineeli May 13, 2025
a9e367a
fix dtype and add test case for parseq
sineeli May 13, 2025
eba3e69
Merge branch 'master' into parseq
sineeli May 13, 2025
0e7cbbd
fix input format and add causal lm testing
sineeli May 15, 2025
a87ae57
use numpy random images
sineeli May 15, 2025
7c1fe2c
fix jax backend issue when reduction set to "mean_with_sample_weight"
sineeli May 15, 2025
58917dd
remove redundant classes and use causal lm base classes itself.
sineeli May 15, 2025
3cf997c
nit
sineeli May 15, 2025
f3f3cef
fix decoder_head_dim usage
sineeli May 15, 2025
eb5d4ef
fix preprocessing issues
sineeli May 16, 2025
f5e21ed
Merge branch 'master' into parseq
sineeli May 16, 2025
b6b7a26
add checkpoint conversion script
sineeli May 16, 2025
8c6f14c
add missing flag
sineeli May 16, 2025
e89398b
validate conversion outputs
sineeli May 19, 2025
764a204
nit
sineeli May 19, 2025
180774d
fix training for permutation logic
sineeli May 20, 2025
4201d0b
Merge branch 'master' into parseq
sineeli May 30, 2025
751b0a8
add example usage for backbone and causal lm
sineeli Jun 19, 2025
3860843
nit
sineeli Jun 19, 2025
6f5f093
Merge remote-tracking branch 'upstream/master' into parseq
sineeli Jun 19, 2025
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
@@ -102,6 +102,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter as PaliGemmaImageConverter,
)
from keras_hub.src.models.parseq.parseq_image_converter import (
PARSeqImageConverter as PARSeqImageConverter,
)
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter as ResNetImageConverter,
)
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -418,6 +418,18 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer as PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_backbone import (
PARSeqBackbone as PARSeqBackbone,
)
from keras_hub.src.models.parseq.parseq_causal_lm import (
PARSeqCausalLM as PARSeqCausalLM,
)
from keras_hub.src.models.parseq.parseq_causal_lm_preprocessor import (
PARSeqCausalLMPreprocessor as PARSeqCausalLMPreprocessor,
)
from keras_hub.src.models.parseq.parseq_tokenizer import (
PARSeqTokenizer as PARSeqTokenizer,
)
from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone as Phi3Backbone
from keras_hub.src.models.phi3.phi3_causal_lm import (
Phi3CausalLM as Phi3CausalLM,
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -65,6 +65,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer as PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_tokenizer import (
PARSeqTokenizer as PARSeqTokenizer,
)
from keras_hub.src.models.phi3.phi3_tokenizer import (
Phi3Tokenizer as Phi3Tokenizer,
)
Empty file.
132 changes: 132 additions & 0 deletions keras_hub/src/models/parseq/parseq_backbone.py
@@ -0,0 +1,132 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder


@keras_hub_export("keras_hub.models.PARSeqBackbone")
class PARSeqBackbone(Backbone):
"""Scene Text Detection with PARSeq.

Performs OCR in natural scenes using the PARSeq model described in [Scene
Text Recognition with Permuted Autoregressive Sequence Models](
https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
iterative decoding by performing an autoregressive decoding phase, followed
by a refinement phase.

Args:
image_encoder: keras.Model. The image encoder model.
vocabulary_size: int. The size of the vocabulary.
max_label_length: int. The maximum length of the label sequence.
decoder_hidden_dim: int. The dimension of the decoder hidden layers.
num_decoder_layers: int. The number of decoder layers.
num_decoder_heads: int. The number of attention heads in the decoder.
decoder_mlp_dim: int. The dimension of the decoder MLP hidden layer.
dropout_rate: float. The dropout rate. Defaults to `0.1`.
attention_dropout: float. The dropout rate for the attention weights.
Defaults to `0.1`.
dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
for the model's computations and weights.
[Review comment, Collaborator]: Follow the same arg description we follow for other models for `dtype`.
**kwargs: Additional keyword arguments passed to the base
`keras.Model` constructor.
[Review comment, Collaborator]: Add an Examples section demonstrating sample usage of the backbone.

[Reply, Collaborator Author]: Adding it in the causal_lm file rather than here; it's more suitable there.
"""

def __init__(
self,
image_encoder,
vocabulary_size,
max_label_length,
decoder_hidden_dim,
num_decoder_layers,
num_decoder_heads,
decoder_mlp_dim,
dropout_rate=0.1,
attention_dropout=0.1,
dtype=None,
**kwargs,
):
# === Layers ===
self.image_encoder = image_encoder
self.decoder = PARSeqDecoder(
vocabulary_size=vocabulary_size,
max_label_length=max_label_length,
num_layers=num_decoder_layers,
num_heads=num_decoder_heads,
hidden_dim=decoder_hidden_dim,
mlp_dim=decoder_mlp_dim,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
name="decoder",
dtype=dtype,
)
self.head = keras.layers.Dense(
vocabulary_size - 2, # We don't predict <bos> nor <pad>
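# The head therefore covers the character set plus <eos> only.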
dtype=dtype,
)

# === Functional Model ===
image_input = self.image_encoder.input

token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

memory = self.image_encoder(image_input)
target_out = self.decoder(
token_id_input, memory, padding_mask=padding_mask_input
)
logits = self.head(target_out)

# === Config ===
self.vocabulary_size = vocabulary_size
self.max_label_length = max_label_length
self.decoder_hidden_dim = decoder_hidden_dim
self.num_decoder_layers = num_decoder_layers
self.num_decoder_heads = num_decoder_heads
self.decoder_mlp_dim = decoder_mlp_dim
self.dropout_rate = dropout_rate
self.attention_dropout = attention_dropout

super().__init__(
inputs={
"images": image_input,
"token_ids": token_id_input,
"padding_mask": padding_mask_input,
},
outputs=logits,
dtype=dtype,
**kwargs,
)

def get_config(self):
config = super().get_config()
config.update(
{
"image_encoder": keras.layers.serialize(self.image_encoder),
"vocabulary_size": self.vocabulary_size,
"max_label_length": self.max_label_length,
"decoder_hidden_dim": self.decoder_hidden_dim,
"num_decoder_layers": self.num_decoder_layers,
"num_decoder_heads": self.num_decoder_heads,
"decoder_mlp_dim": self.decoder_mlp_dim,
"dropout_rate": self.dropout_rate,
"attention_dropout": self.attention_dropout,
}
)

return config

@classmethod
def from_config(cls, config):
config.update(
{
"image_encoder": keras.layers.deserialize(
config["image_encoder"]
),
}
)

return super().from_config(config)
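For reference, the reviewer asked above for an Examples section; a minimal usage sketch follows, assembled from the toy configuration used in the test file below. The sizes and random inputs are illustrative assumptions, not part of this diff:

import keras

from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.vit.vit_backbone import ViTBackbone

# A small ViT image encoder; sizes mirror the test configuration below.
image_encoder = ViTBackbone(
    image_shape=(32, 128, 3),
    patch_size=(4, 8),
    num_layers=2,
    num_heads=2,
    hidden_dim=64,
    mlp_dim=256,
    use_class_token=False,
    name="image_encoder",
)

backbone = PARSeqBackbone(
    image_encoder=image_encoder,
    vocabulary_size=97,
    max_label_length=25,
    decoder_hidden_dim=64,
    num_decoder_layers=1,
    num_decoder_heads=2,
    decoder_mlp_dim=256,
)

# Forward pass on dummy inputs; logits cover vocabulary_size - 2 classes.
outputs = backbone(
    {
        "images": keras.random.normal(shape=(2, 32, 128, 3)),
        "token_ids": keras.ops.ones((2, 25), dtype="int32"),
        "padding_mask": keras.ops.ones((2, 25), dtype="int32"),
    }
)
print(outputs.shape)  # (2, 25, 95)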
107 changes: 107 additions & 0 deletions keras_hub/src/models/parseq/parseq_backbone_test.py
@@ -0,0 +1,107 @@
import keras
import pytest
from keras import ops

from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.tests.test_case import TestCase


class PARSeqBackboneTest(TestCase):
def setUp(self):
self.batch_size = 2
self.image_height = 32
self.image_width = 128
self.num_channels = 3

# Image encoder parameters
self.vit_patch_size = (4, 8)
self.vit_num_layers = 2
self.vit_num_heads = 2
self.vit_hidden_dim = 64
self.vit_mlp_dim = self.vit_hidden_dim * 4

# PARSeq Backbone parameters
self.vocabulary_size = 97
self.max_label_length = 25
self.decoder_hidden_dim = self.vit_hidden_dim
self.num_decoder_layers = 1
self.num_decoder_heads = 2
self.decoder_mlp_dim = self.decoder_hidden_dim * 4

# Instantiate the actual ViTBackbone to be used as the image_encoder
self.image_encoder = ViTBackbone(
image_shape=(
self.image_height,
self.image_width,
self.num_channels,
),
patch_size=self.vit_patch_size,
num_layers=self.vit_num_layers,
num_heads=self.vit_num_heads,
hidden_dim=self.vit_hidden_dim,
mlp_dim=self.vit_mlp_dim,
use_class_token=False,
name="image_encoder",
)

self.init_kwargs = {
"image_encoder": self.image_encoder,
"vocabulary_size": self.vocabulary_size,
"max_label_length": self.max_label_length,
"decoder_hidden_dim": self.decoder_hidden_dim,
"num_decoder_layers": self.num_decoder_layers,
"num_decoder_heads": self.num_decoder_heads,
"decoder_mlp_dim": self.decoder_mlp_dim,
"dropout_rate": 0.0,
"attention_dropout": 0.0,
}

# Dummy input data
dummy_images = keras.random.normal(
shape=(
self.batch_size,
self.image_height,
self.image_width,
self.num_channels,
),
)

dummy_token_ids = keras.random.randint(
minval=0,
maxval=self.vocabulary_size,
shape=(self.batch_size, self.max_label_length),
)
dummy_padding_mask = ops.ones(
shape=(self.batch_size, self.max_label_length), dtype="int32"
)

self.input_data = {
"images": dummy_images,
"token_ids": dummy_token_ids,
"padding_mask": dummy_padding_mask,
}

def test_backbone_basics(self):
expected_shape_full = (
self.batch_size,
self.max_label_length,
self.vocabulary_size - 2,
)

self.run_backbone_test(
cls=PARSeqBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=expected_shape_full,
# Skip quantization checks: image_encoder in init_kwargs is itself a backbone.
run_quantization_check=False,
)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
cls=PARSeqBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)