[WIP] PARSeq Model #2089
File: keras_hub/src/models/parseq/parseq_backbone.py
@@ -0,0 +1,132 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder


@keras_hub_export("keras_hub.models.PARSeqBackbone")
class PARSeqBackbone(Backbone):
    """Scene Text Recognition with PARSeq.

    Performs OCR in natural scenes using the PARSeq model described in [Scene
    Text Recognition with Permuted Autoregressive Sequence Models](
    https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
    iterative decoding by performing an autoregressive decoding phase, followed
    by a refinement phase.

    Args:
        image_encoder: keras.Model. The image encoder model.
        vocabulary_size: int. The size of the vocabulary.
        max_label_length: int. The maximum length of the label sequence.
        decoder_hidden_dim: int. The dimension of the decoder hidden layers.
        num_decoder_layers: int. The number of decoder layers.
        num_decoder_heads: int. The number of attention heads in the decoder.
        decoder_mlp_dim: int. The dimension of the decoder MLP hidden layer.
        dropout_rate: float. The dropout rate. Defaults to `0.1`.
        attention_dropout: float. The dropout rate for the attention weights.
            Defaults to `0.1`.
        dtype: str. The dtype used for layers.
        **kwargs: Additional keyword arguments passed to the base
            `keras.Model` constructor.
    """

Review comment: Follow the same arg description we use for other models for
`dtype`.

Review comment: Add an Examples section demonstrating sample usage of the
backbone.
Author reply: Adding it in the causal_lm file rather than here; it's more
suitable there.
    def __init__(
        self,
        image_encoder,
        vocabulary_size,
        max_label_length,
        decoder_hidden_dim,
        num_decoder_layers,
        num_decoder_heads,
        decoder_mlp_dim,
        dropout_rate=0.1,
        attention_dropout=0.1,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.image_encoder = image_encoder
        self.decoder = PARSeqDecoder(
            vocabulary_size=vocabulary_size,
            max_label_length=max_label_length,
            num_layers=num_decoder_layers,
            num_heads=num_decoder_heads,
            hidden_dim=decoder_hidden_dim,
            mlp_dim=decoder_mlp_dim,
            dropout_rate=dropout_rate,
            attention_dropout=attention_dropout,
            name="decoder",
            dtype=dtype,
        )
        self.head = keras.layers.Dense(
            vocabulary_size - 2,  # We don't predict <bos> or <pad>.
            dtype=dtype,
        )

        # === Functional Model ===
        image_input = self.image_encoder.input

        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        memory = self.image_encoder(image_input)
        target_out = self.decoder(
            token_id_input, memory, padding_mask=padding_mask_input
        )
        logits = self.head(target_out)

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.max_label_length = max_label_length
        self.decoder_hidden_dim = decoder_hidden_dim
        self.num_decoder_layers = num_decoder_layers
        self.num_decoder_heads = num_decoder_heads
        self.decoder_mlp_dim = decoder_mlp_dim
        self.dropout_rate = dropout_rate
        self.attention_dropout = attention_dropout

        super().__init__(
            inputs={
                "images": image_input,
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=logits,
            dtype=dtype,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.layers.serialize(self.image_encoder),
                "vocabulary_size": self.vocabulary_size,
                "max_label_length": self.max_label_length,
                "decoder_hidden_dim": self.decoder_hidden_dim,
                "num_decoder_layers": self.num_decoder_layers,
                "num_decoder_heads": self.num_decoder_heads,
                "decoder_mlp_dim": self.decoder_mlp_dim,
                "dropout_rate": self.dropout_rate,
                "attention_dropout": self.attention_dropout,
            }
        )

        return config

    @classmethod
    def from_config(cls, config):
        config.update(
            {
                "image_encoder": keras.layers.deserialize(
                    config["image_encoder"]
                ),
            }
        )

        return super().from_config(config)
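For illustration, a minimal usage sketch of the backbone, mirroring the test
file below. The ViT encoder configuration and all hyperparameter values here
are illustrative choices, not prescribed by this PR:

    import keras
    from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
    from keras_hub.src.models.vit.vit_backbone import ViTBackbone

    # A small ViT encoder over 32x128 RGB text crops.
    image_encoder = ViTBackbone(
        image_shape=(32, 128, 3),
        patch_size=(4, 8),
        num_layers=2,
        num_heads=2,
        hidden_dim=64,
        mlp_dim=256,
        use_class_token=False,
    )
    backbone = PARSeqBackbone(
        image_encoder=image_encoder,
        vocabulary_size=97,
        max_label_length=25,
        decoder_hidden_dim=64,
        num_decoder_layers=1,
        num_decoder_heads=2,
        decoder_mlp_dim=256,
    )

    # Forward pass on dummy inputs. The logits cover vocabulary_size - 2
    # classes, since <bos> and <pad> are never predicted.
    images = keras.random.normal(shape=(2, 32, 128, 3))
    token_ids = keras.random.randint(shape=(2, 25), minval=0, maxval=97)
    padding_mask = keras.ops.ones((2, 25), dtype="int32")
    logits = backbone(
        {
            "images": images,
            "token_ids": token_ids,
            "padding_mask": padding_mask,
        }
    )
    # logits.shape == (2, 25, 95)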
File: keras_hub/src/models/parseq/parseq_backbone_test.py
@@ -0,0 +1,107 @@
import keras
import pytest
from keras import ops

from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.vit.vit_backbone import ViTBackbone
from keras_hub.src.tests.test_case import TestCase


class PARSeqBackboneTest(TestCase):
    def setUp(self):
        self.batch_size = 2
        self.image_height = 32
        self.image_width = 128
        self.num_channels = 3

        # Image encoder parameters.
        self.vit_patch_size = (4, 8)
        self.vit_num_layers = 2
        self.vit_num_heads = 2
        self.vit_hidden_dim = 64
        self.vit_mlp_dim = self.vit_hidden_dim * 4

        # PARSeq backbone parameters.
        self.vocabulary_size = 97
        self.max_label_length = 25
        self.decoder_hidden_dim = self.vit_hidden_dim
        self.num_decoder_layers = 1
        self.num_decoder_heads = 2
        self.decoder_mlp_dim = self.decoder_hidden_dim * 4

        # Instantiate an actual ViTBackbone to use as the image_encoder.
        self.image_encoder = ViTBackbone(
            image_shape=(
                self.image_height,
                self.image_width,
                self.num_channels,
            ),
            patch_size=self.vit_patch_size,
            num_layers=self.vit_num_layers,
            num_heads=self.vit_num_heads,
            hidden_dim=self.vit_hidden_dim,
            mlp_dim=self.vit_mlp_dim,
            use_class_token=False,
            name="image_encoder",
        )

        self.init_kwargs = {
            "image_encoder": self.image_encoder,
            "vocabulary_size": self.vocabulary_size,
            "max_label_length": self.max_label_length,
            "decoder_hidden_dim": self.decoder_hidden_dim,
            "num_decoder_layers": self.num_decoder_layers,
            "num_decoder_heads": self.num_decoder_heads,
            "decoder_mlp_dim": self.decoder_mlp_dim,
            "dropout_rate": 0.0,
            "attention_dropout": 0.0,
        }

        # Dummy input data.
        dummy_images = keras.random.normal(
            shape=(
                self.batch_size,
                self.image_height,
                self.image_width,
                self.num_channels,
            ),
        )
        dummy_token_ids = keras.random.randint(
            minval=0,
            maxval=self.vocabulary_size,
            shape=(self.batch_size, self.max_label_length),
        )
        dummy_padding_mask = ops.ones(
            shape=(self.batch_size, self.max_label_length), dtype="int32"
        )

        self.input_data = {
            "images": dummy_images,
            "token_ids": dummy_token_ids,
            "padding_mask": dummy_padding_mask,
        }

    def test_backbone_basics(self):
        expected_shape_full = (
            self.batch_size,
            self.max_label_length,
            self.vocabulary_size - 2,
        )

        self.run_backbone_test(
            cls=PARSeqBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            # `image_encoder` in init_kwargs is itself a backbone, which the
            # quantization check does not handle.
            run_quantization_check=False,
            expected_output_shape=expected_shape_full,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=PARSeqBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )