[WIP] PARSeq Model #2089
base: master
File: keras_hub/src/models/parseq/parseq_backbone.py (new)
@@ -0,0 +1,62 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.PARSeqBackbone")
class PARSeqBackbone(Backbone):
    """Scene Text Recognition with PARSeq.

    Performs OCR in natural scenes using the PARSeq model described in [Scene
    Text Recognition with Permuted Autoregressive Sequence Models](
    https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
    iterative decoding by performing an autoregressive decoding phase, followed
    by a refinement phase.
    """

    def __init__(
        self,
        image_encoder,
        decode_autoregressive=True,
        alphabet_size=97,
        max_text_length=25,
        num_decoder_layers=1,
        num_decoder_heads=12,
        dropout_rate=0.1,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        self.image_encoder = image_encoder

        # === Functional model ===
        image_input = self.image_encoder.input
        output = self.image_encoder(image_input)

        # === Config ===
        self.decode_autoregressive = decode_autoregressive
        self.alphabet_size = alphabet_size
        self.max_text_length = max_text_length
        self.num_decoder_layers = num_decoder_layers
        self.num_decoder_heads = num_decoder_heads
        self.dropout_rate = dropout_rate

        super().__init__(
            inputs=image_input,
            outputs=output,
            dtype=dtype,
            **kwargs,
        )

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                # Key matches the `image_encoder` constructor argument so
                # that `from_config` can rebuild the model.
                "image_encoder": keras.layers.serialize(self.image_encoder),
                "alphabet_size": self.alphabet_size,
                "max_text_length": self.max_text_length,
                "num_decoder_layers": self.num_decoder_layers,
                "num_decoder_heads": self.num_decoder_heads,
                "dropout_rate": self.dropout_rate,
            }
        )
        return config
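For reference, a minimal usage sketch of the backbone as it stands (the decoder is not wired into the functional graph yet, so the forward pass simply returns the image encoder's features). The toy image_encoder below is a hypothetical stand-in for the ViT encoder the final model would pair with:

import keras
import numpy as np

from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone

# Hypothetical stand-in encoder: any functional model with a defined `.input`
# works, since the backbone wires `image_encoder.input` through
# `image_encoder(...)`.
inputs = keras.Input(shape=(32, 128, 3))
x = keras.layers.Conv2D(384, 16, strides=16)(inputs)  # patchify-style stem
x = keras.layers.Reshape((-1, 384))(x)  # (batch, num_patches, hidden_dim)
image_encoder = keras.Model(inputs, x)

backbone = PARSeqBackbone(
    image_encoder=image_encoder,
    alphabet_size=97,
    max_text_length=25,
)
features = backbone(np.zeros((1, 32, 128, 3), dtype="float32"))
print(features.shape)  # (1, 16, 384) for this toy encoder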
File: keras_hub/src/models/parseq/parseq_image_converter.py (new)
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone


@keras_hub_export("keras_hub.layers.PARSeqImageConverter")
class PARSeqImageConverter(ImageConverter):
    backbone_cls = PARSeqBackbone
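A quick usage sketch, assuming the standard ImageConverter constructor arguments (image_size, scale) and the 32x128 input resolution PARSeq models typically use; both are assumptions, not part of this diff:

import numpy as np

from keras_hub.src.models.parseq.parseq_image_converter import (
    PARSeqImageConverter,
)

# Resize to the (height, width) the backbone expects and rescale to [0, 1].
converter = PARSeqImageConverter(image_size=(32, 128), scale=1.0 / 255)
batch = np.random.uniform(0, 255, size=(2, 64, 256, 3)).astype("float32")
resized = converter(batch)  # shape: (2, 32, 128, 3)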
File: keras_hub/src/models/parseq/parseq_preprocessor.py (new)
@@ -0,0 +1,14 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.parseq.parseq_image_converter import (
    PARSeqImageConverter,
)
from keras_hub.src.models.text_recognition_preprocessor import (
    TextRecognitionPreprocessor,
)


@keras_hub_export("keras_hub.models.PARSeqPreprocessor")
class PARSeqPreprocessor(TextRecognitionPreprocessor):
    backbone_cls = PARSeqBackbone
    image_converter_cls = PARSeqImageConverter
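The preprocessor only binds the backbone, tokenizer, and image converter classes together. TextRecognitionPreprocessor itself is not part of this hunk, so the sketch below assumes it follows the usual keras_hub pattern of accepting tokenizer and image_converter arguments; treat the signature as hypothetical:

# Assumed constructor arguments; the real signature lives in
# TextRecognitionPreprocessor, which is defined outside this hunk.
preprocessor = PARSeqPreprocessor(
    tokenizer=PARSeqTokenizer(),
    image_converter=PARSeqImageConverter(image_size=(32, 128), scale=1.0 / 255),
)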
File: keras_hub/src/models/parseq/parseq_tokenizer.py (new)
@@ -0,0 +1,142 @@
import re

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers import tokenizer
from keras_hub.src.utils.tensor_utils import is_int_dtype
from keras_hub.src.utils.tensor_utils import is_string_dtype
from keras_hub.src.utils.tensor_utils import preprocessing_function

PARSEQ_VOCAB = (
    "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"
    "\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
)

try:
    import tensorflow as tf
    import tensorflow_text as tf_text
except ImportError:
    tf = None
    tf_text = None


@keras_hub_export(
    [
        "keras_hub.tokenizers.PARSeqTokenizer",
        "keras_hub.models.PARSeqTokenizer",
    ]
)
class PARSeqTokenizer(tokenizer.Tokenizer):
Review comment (on class PARSeqTokenizer): Please add a doc-string here, with examples. Makes it easier to review when we have examples :P
Review comment: Let's add unit tests as well.
Author reply: Yes, will add them.
    def __init__(
        self,
        vocabulary=PARSEQ_VOCAB,
        remove_whitespace=True,
        normalize_unicode=True,
        max_label_length=25,
        dtype="int32",
        **kwargs,
    ):
        if not is_int_dtype(dtype) and not is_string_dtype(dtype):
            raise ValueError(
                "Output dtype must be an integer type or a string. "
                f"Received: dtype={dtype}"
            )
        super().__init__(dtype=dtype, **kwargs)
        self.vocabulary = vocabulary

        self.target_charset = tf.convert_to_tensor(vocabulary, dtype=tf.string)

        self.lowercase_only = self.target_charset == tf.strings.lower(
            self.target_charset
        )
        self.uppercase_only = self.target_charset == tf.strings.upper(
            self.target_charset
        )
        escaped_charset = re.escape(vocabulary)  # Escape for safe regex
        self.unsupported_regex = f"[^{escaped_charset}]"
        self._itos = ("[E]",) + tuple(vocabulary) + ("[B]", "[P]")
        self._stoi = {s: i for i, s in enumerate(self._itos)}

        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.max_label_length = max_label_length

        self._add_special_token("[B]", "start_token")
        self._add_special_token("[E]", "end_token")
        self._add_special_token("[P]", "pad_token")
        # Create lookup tables.
        self.char_to_id = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=list(self._stoi.keys()),
                values=list(self._stoi.values()),
                key_dtype=tf.string,
                value_dtype=tf.int32,
            ),
            default_value=0,
        )
        self.id_to_char = tf.lookup.StaticHashTable(
            initializer=tf.lookup.KeyValueTensorInitializer(
                keys=list(self._stoi.values()),
                values=list(self._stoi.keys()),
                key_dtype=tf.int32,
                value_dtype=tf.string,
            ),
            default_value=self.pad_token,
        )

Review comment (on id_to_char): Do we need this? We aren't using it anywhere.
Author reply: In case the user wants to bulk-convert token ids back to characters, it will be helpful.

Review comment (on the lookup table defaults): The defaults don't match. EOS is the 0th token, and pad is the last.
Author reply: I noticed the same in the original code, but it seems they use EOS -> 0 and BOS -> len(vocabulary); while padding, they place BOS first and then EOS at the end.
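For reference, the index layout implied by _itos with the default 94-character PARSEQ_VOCAB is:

# _itos = ("[E]",) + tuple(vocabulary) + ("[B]", "[P]")
# 0      -> "[E]"  (end token; also the `char_to_id` default_value)
# 1..94  -> the vocabulary characters
# 95     -> "[B]"  (start token)
# 96     -> "[P]"  (pad token; the `id_to_char` default_value)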
    def id_to_token(self, id):
        if id >= self.vocabulary_size() or id < 0:
            raise ValueError(
                f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
                f"Received: {id}"
            )
        return self._itos[id]

    def token_to_id(self, token):
        return self._stoi[token]

    def _preprocess(self, label):
        """Preprocesses a label, keeping only supported ASCII characters."""
        if self.remove_whitespace:
            label = tf.strings.regex_replace(label, r"\s+", "")

        if self.normalize_unicode:
            label = tf_text.normalize_utf8(label, normalization_form="NFKD")
            label = tf.strings.regex_replace(label, r"[^!-~]", "")

        if self.lowercase_only:
            label = tf.strings.lower(label)
        elif self.uppercase_only:
            label = tf.strings.upper(label)

        label = tf.strings.regex_replace(label, self.unsupported_regex, "")
        label = tf.strings.substr(label, 0, self.max_label_length)
Review comment (on the truncation): Why are we truncating the input to 25 characters?
Author reply: In the original implementation, any label longer than 25 characters is simply dropped while preparing the dataset. Here I truncate it instead, and we can add the start and end tokens afterwards.
        return label

    @preprocessing_function
    def tokenize(self, inputs):
        self._check_vocabulary()
        inputs = tf.convert_to_tensor(inputs)
        unbatched = inputs.shape.rank == 0
        if unbatched:
            inputs = tf.expand_dims(inputs, 0)

        inputs = tf.map_fn(self._preprocess, inputs, dtype=tf.string)

        if tf.size(inputs) > 0:
            chars = tf.strings.unicode_split(inputs, "UTF-8")
            token_ids = self.char_to_id.lookup(chars)
            token_ids = tf.cast(token_ids, dtype=tf.int32)
        else:
            token_ids = tf.ragged.constant([], dtype=tf.int32)

        return token_ids

    def vocabulary_size(self):
        """Get the integer size of the tokenizer vocabulary."""
        self._check_vocabulary()
        return len(self.vocabulary)
    def _check_vocabulary(self):
        if self.vocabulary is None:
            raise ValueError(
                "No vocabulary has been set for PARSeqTokenizer. Make sure "
                "to pass a `vocabulary` argument when creating the layer."
            )
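A small usage sketch of the tokenizer defined above (TensorFlow and tensorflow_text must be installed, since the preprocessing is implemented with tf ops); the exact output container depends on the preprocessing_function wrapper:

tokenizer = PARSeqTokenizer()  # default 94-character vocabulary

# Whitespace is stripped, NFKD normalization plus the ASCII filter drop
# accents and other non-ASCII marks, unsupported characters are removed,
# and labels are truncated to `max_label_length` (25) characters.
token_ids = tokenizer.tokenize(["Hello 123!"])

print(tokenizer.id_to_token(1))  # "0", the first vocabulary character
print(tokenizer.token_to_id("[B]"))  # 95 with the default vocabulary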
File: keras_hub/src/models/text_recognition.py (new)
@@ -0,0 +1,65 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.TextRecognition")
class TextRecognition(Task):
    """Base class for all TextRecognition tasks.

    `TextRecognition` tasks wrap a `keras_hub.models.Backbone` and
    a `keras_hub.models.Preprocessor` to create a model that can be used for
    recognizing text in images.

    All `TextRecognition` tasks include a `from_preset()` constructor which can
    be used to load a pre-trained config and weights.
    """

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        """Configures the `TextRecognition` task for training.

        The `TextRecognition` task extends the default compilation signature
        of `keras.Model.compile` with defaults for `optimizer`, `loss`, and
        `metrics`. To override these defaults, pass any value
        to these arguments during compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses the default
                optimizer for the given model and task. See
                `keras.Model.compile` and `keras.optimizers` for more info on
                possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, where a
                `keras.losses.SparseCategoricalCrossentropy` loss will be
                applied for the text recognition task. See
                `keras.Model.compile` and `keras.losses` for more info on
                possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to `"auto"`,
                where a `keras.metrics.SparseCategoricalAccuracy` will be
                applied to track the accuracy of the model during training.
                See `keras.Model.compile` and `keras.metrics` for
                more info on possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            optimizer = keras.optimizers.Adam(1e-4)
        if loss == "auto":
            loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        if metrics == "auto":
            metrics = [keras.metrics.SparseCategoricalAccuracy()]
        super().compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            **kwargs,
        )
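A short sketch of the compile defaults in use. Here `model` stands for an instance of a concrete TextRecognition subclass (for example, a future PARSeq task); the instance is hypothetical, since this base class is not meant to be instantiated directly:

import keras

# "auto" defaults: Adam(1e-4), sparse categorical cross-entropy, and
# sparse categorical accuracy.
model.compile()

# Any of the defaults can be overridden explicitly.
model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=5e-5),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)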