Added LayoutLMv3 #2178
Status: Open. carrycooldude wants to merge 23 commits into keras-team:master from carrycooldude:feature/layoutlmv3-port.
Commits (23, all by carrycooldude):
ae79d15 added the files
737f03a Restructure LayoutLMv3 implementation to match KerasHub style
455a140 Refactor: Move LayoutLMv3 files to models directory and make code bac…
d92c8c4 refactor: Move LayoutLMv3 files to dedicated directory
0948f95 fix: Update LayoutLMv3 init files to follow correct format
3c02f78 fix: Update LayoutLMv3 backbone to follow project standards
4a79d9b refactor: remove unnecessary files and fix imports in LayoutLMv3 module
c2fed4c Add minimal stub for LayoutLMv3TransformerLayer
e828047 fix: resolve merge conflicts and complete rebase
063054d refactor(layoutlmv3): move usage examples to class docstrings and rem…
476c0fd style: apply code formatting and lint fixes via pre-commit
4439fad made some changes
ad3c758 resolve the conflict issue
885f2fe chore: update API directory and fix ruff line length in checkpoint co…
5019abb update models
e1fc266 made changes
a32555c chore: trigger CI
a885afa Update API files
ad004f7 changed
6fb0fdc chore: pre-commit fixes for layoutlmv3 __init__.py
5aaadab chore: commit api directory after pre-commit run
8c7e989 update models
5a371a5 update layoutlmv3
File (inferred path): keras_hub/src/models/layoutlmv3/__init__.py (new file, +7 lines)

```python
from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import (
    LayoutLMv3Backbone,
)
from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, LayoutLMv3Backbone)
```
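For context beyond the diff: registering the presets at import time is what lets `from_preset` resolve preset names against the `kaggle_handle` entries defined in `layoutlmv3_presets.py`. A minimal usage sketch, assuming the referenced Kaggle weights are actually published:

```python
import keras_hub

# "layoutlmv3_base" resolves through the preset table registered above,
# and weights are fetched from that preset's kaggle_handle.
backbone = keras_hub.models.LayoutLMv3Backbone.from_preset("layoutlmv3_base")
```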
File (inferred path): keras_hub/src/models/layoutlmv3/layoutlmv3_backbone.py (new file, +354 lines)
````python
from keras import layers
from keras import ops
from keras.saving import register_keras_serializable

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import backbone_presets
from keras_hub.src.models.layoutlmv3.layoutlmv3_transformer import (
    LayoutLMv3TransformerLayer,
)


@keras_hub_export("keras_hub.models.LayoutLMv3Backbone")
@register_keras_serializable(package="keras_hub")
class LayoutLMv3Backbone(Backbone):
    """LayoutLMv3 backbone model for document understanding tasks.

    This class implements the LayoutLMv3 model architecture for joint text
    and layout understanding in document AI tasks. It processes both text
    and image inputs while maintaining spatial relationships in documents.

    Example:
    ```python
    # Initialize backbone from preset.
    backbone = LayoutLMv3Backbone.from_preset("layoutlmv3_base")

    # Process document image and text.
    outputs = backbone({
        "input_ids": input_ids,  # Shape: (batch_size, seq_length)
        "bbox": bbox,  # Shape: (batch_size, seq_length, 4)
        "attention_mask": attention_mask,  # Shape: (batch_size, seq_length)
        "image": image,  # Shape: (batch_size, height, width, channels)
    })
    ```

    Args:
        vocab_size: int. Size of the vocabulary. Defaults to 30522.
        hidden_size: int. Size of the hidden layers. Defaults to 768.
        num_hidden_layers: int. Number of transformer layers. Defaults to 12.
        num_attention_heads: int. Number of attention heads. Defaults to 12.
        intermediate_size: int. Size of the intermediate layer. Defaults to
            3072.
        hidden_act: str. Activation function for the hidden layers. Defaults
            to "gelu".
        hidden_dropout_prob: float. Dropout probability for hidden layers.
            Defaults to 0.1.
        attention_probs_dropout_prob: float. Dropout probability for
            attention layers. Defaults to 0.1.
        max_position_embeddings: int. Maximum sequence length. Defaults to
            512.
        type_vocab_size: int. Size of the token type vocabulary. Defaults
            to 2.
        initializer_range: float. Range for weight initialization. Defaults
            to 0.02.
        layer_norm_eps: float. Epsilon for layer normalization. Defaults to
            1e-12.
        pad_token_id: int. ID of the padding token. Defaults to 0.
        position_embedding_type: str. Type of position embedding. Defaults
            to "absolute".
        use_cache: bool. Whether to use caching. Defaults to True.
        classifier_dropout: float. Dropout probability for the classifier.
            Defaults to None.
        patch_size: int. Size of image patches. Defaults to 16.
        num_channels: int. Number of image channels. Defaults to 3.
        qkv_bias: bool. Whether to use bias in the QKV projection. Defaults
            to True.
        use_abs_pos: bool. Whether to use absolute position embeddings.
            Defaults to True.
        use_rel_pos: bool. Whether to use relative position embeddings.
            Defaults to True.
        rel_pos_bins: int. Number of relative position bins. Defaults to 32.
        max_rel_pos: int. Maximum relative position. Defaults to 128.
        spatial_embedding_dim: int. Dimension of spatial embeddings.
            Defaults to 64.

    References:
        - [LayoutLMv3 Paper](https://arxiv.org/abs/2204.08387)
        - [LayoutLMv3 GitHub](https://github.com/microsoft/unilm/tree/master/layoutlmv3)
    """
    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_act: str = "gelu",
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 512,
        type_vocab_size: int = 2,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        position_embedding_type: str = "absolute",
        use_cache: bool = True,
        classifier_dropout: float = None,
        patch_size: int = 16,
        num_channels: int = 3,
        qkv_bias: bool = True,
        use_abs_pos: bool = True,
        use_rel_pos: bool = True,
        rel_pos_bins: int = 32,
        max_rel_pos: int = 128,
        spatial_embedding_dim: int = 64,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        # Store the remaining arguments so `get_config` can serialize them.
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.qkv_bias = qkv_bias
        self.use_abs_pos = use_abs_pos
        self.use_rel_pos = use_rel_pos
        self.rel_pos_bins = rel_pos_bins
        self.max_rel_pos = max_rel_pos
        self.spatial_embedding_dim = spatial_embedding_dim

        # Input layers
        self.input_ids = layers.Input(
            shape=(None,), dtype="int32", name="input_ids"
        )
        self.bbox = layers.Input(shape=(None, 4), dtype="int32", name="bbox")
        self.attention_mask = layers.Input(
            shape=(None,), dtype="int32", name="attention_mask"
        )
        self.image = layers.Input(
            shape=(None, None, None, num_channels),
            dtype="float32",
            name="image",
        )

        # Embeddings
        self.word_embeddings = layers.Embedding(
            vocab_size, hidden_size, name="embeddings.word_embeddings"
        )

        # Position embeddings
        self.x_position_embeddings = layers.Embedding(
            1024, spatial_embedding_dim, name="embeddings.x_position_embeddings"
        )
        self.y_position_embeddings = layers.Embedding(
            1024, spatial_embedding_dim, name="embeddings.y_position_embeddings"
        )
        self.h_position_embeddings = layers.Embedding(
            1024, spatial_embedding_dim, name="embeddings.h_position_embeddings"
        )
        self.w_position_embeddings = layers.Embedding(
            1024, spatial_embedding_dim, name="embeddings.w_position_embeddings"
        )
        self.token_type_embeddings = layers.Embedding(
            type_vocab_size,
            hidden_size,
            name="embeddings.token_type_embeddings",
        )

        # Layer normalization and dropout for the combined embeddings
        self.embeddings_LayerNorm = layers.LayerNormalization(
            epsilon=layer_norm_eps, name="embeddings.LayerNorm"
        )
        self.embeddings_dropout = layers.Dropout(
            hidden_dropout_prob, name="embeddings.dropout"
        )
        self.norm = layers.LayerNormalization(
            epsilon=layer_norm_eps, name="norm"
        )

        # Spatial embedding projections
        self.x_proj = layers.Dense(hidden_size, name="x_proj")
        self.y_proj = layers.Dense(hidden_size, name="y_proj")
        self.h_proj = layers.Dense(hidden_size, name="h_proj")
        self.w_proj = layers.Dense(hidden_size, name="w_proj")

        # Transformer encoder layers
        self.encoder_layers = [
            LayoutLMv3TransformerLayer(
                hidden_size=hidden_size,
                num_attention_heads=num_attention_heads,
                intermediate_size=intermediate_size,
                hidden_act=hidden_act,
                hidden_dropout_prob=hidden_dropout_prob,
                attention_probs_dropout_prob=attention_probs_dropout_prob,
                initializer_range=initializer_range,
                layer_norm_eps=layer_norm_eps,
                qkv_bias=qkv_bias,
                use_rel_pos=use_rel_pos,
                rel_pos_bins=rel_pos_bins,
                max_rel_pos=max_rel_pos,
                name=f"encoder.layer.{i}",
            )
            for i in range(num_hidden_layers)
        ]

        # Image processing
        self.patch_embed = layers.Conv2D(
            hidden_size,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            name="patch_embed.proj",
        )
        self.patch_embed_layer_norm = layers.LayerNormalization(
            epsilon=layer_norm_eps, name="LayerNorm"
        )

        # CLS token
        self.cls_token = self.add_weight(
            shape=(1, 1, hidden_size),
            initializer="random_normal",
            trainable=True,
            name="cls_token",
        )

        # Pooler
        self.pooler = layers.Dense(
            hidden_size, activation="tanh", name="pooler"
        )
    def call(self, inputs):
        """Process text and image inputs through the LayoutLMv3 model.

        Args:
            inputs: Dictionary containing:
                - input_ids: Int tensor of shape (batch_size,
                  sequence_length)
                - bbox: Int tensor of shape (batch_size, sequence_length, 4)
                - attention_mask: Int tensor of shape (batch_size,
                  sequence_length)
                - image: Float tensor of shape (batch_size, height, width,
                  channels)

        Returns:
            Dictionary containing:
                - sequence_output: Float tensor of shape (batch_size,
                  sequence_length, hidden_size)
                - pooled_output: Float tensor of shape (batch_size,
                  hidden_size)
                - hidden_states: List of tensors of shape (batch_size,
                  sequence_length, hidden_size)

        Example:
        ```python
        model = LayoutLMv3Backbone.from_preset("layoutlmv3_base")
        outputs = model({
            "input_ids": input_ids,
            "bbox": bbox,
            "attention_mask": attention_mask,
            "image": image,
        })
        ```
        """
        # Extract inputs
        input_ids = inputs["input_ids"]
        bbox = inputs["bbox"]
        attention_mask = inputs["attention_mask"]

        # Get word embeddings
        word_embeddings = self.word_embeddings(input_ids)

        # Get spatial embeddings from the bounding box coordinates
        x_embeddings = self.x_position_embeddings(bbox[..., 0])
        y_embeddings = self.y_position_embeddings(bbox[..., 1])
        h_embeddings = self.h_position_embeddings(bbox[..., 2])
        w_embeddings = self.w_position_embeddings(bbox[..., 3])

        # Project spatial embeddings to hidden size
        x_embeddings = self.x_proj(x_embeddings)
        y_embeddings = self.y_proj(y_embeddings)
        h_embeddings = self.h_proj(h_embeddings)
        w_embeddings = self.w_proj(w_embeddings)

        # Combine embeddings
        embeddings = (
            word_embeddings
            + x_embeddings
            + y_embeddings
            + h_embeddings
            + w_embeddings
        )

        # Add token type embeddings
        token_type_ids = ops.zeros_like(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = embeddings + token_type_embeddings

        # Apply layer normalization and dropout
        embeddings = self.embeddings_LayerNorm(embeddings)
        embeddings = self.embeddings_dropout(embeddings)

        # Process through transformer layers
        hidden_states = [embeddings]
        for layer in self.encoder_layers:
            hidden_state = layer(
                hidden_states[-1],
                attention_mask=attention_mask,
            )
            hidden_states.append(hidden_state)

        # Get sequence output
        sequence_output = hidden_states[-1]

        # Apply final layer normalization
        sequence_output = self.norm(sequence_output)

        # Get pooled output from the first ([CLS]) token
        pooled_output = self.pooler(sequence_output[:, 0])

        return {
            "sequence_output": sequence_output,
            "pooled_output": pooled_output,
            "hidden_states": hidden_states,
        }
    def get_config(self):
        """Get the model configuration.

        Returns:
            A dictionary containing the model configuration.
        """
        config = super().get_config()
        config.update(
            {
                "vocab_size": self.vocab_size,
                "hidden_size": self.hidden_size,
                "num_hidden_layers": self.num_hidden_layers,
                "num_attention_heads": self.num_attention_heads,
                "intermediate_size": self.intermediate_size,
                "hidden_act": self.hidden_act,
                "hidden_dropout_prob": self.hidden_dropout_prob,
                "attention_probs_dropout_prob": (
                    self.attention_probs_dropout_prob
                ),
                "max_position_embeddings": self.max_position_embeddings,
                "type_vocab_size": self.type_vocab_size,
                "initializer_range": self.initializer_range,
                "layer_norm_eps": self.layer_norm_eps,
                "pad_token_id": self.pad_token_id,
                "position_embedding_type": self.position_embedding_type,
                "use_cache": self.use_cache,
                "classifier_dropout": self.classifier_dropout,
                "patch_size": self.patch_size,
                "num_channels": self.num_channels,
                "qkv_bias": self.qkv_bias,
                "use_abs_pos": self.use_abs_pos,
                "use_rel_pos": self.use_rel_pos,
                "rel_pos_bins": self.rel_pos_bins,
                "max_rel_pos": self.max_rel_pos,
                "spatial_embedding_dim": self.spatial_embedding_dim,
            }
        )
        return config
````
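To make the input contract concrete, here is a smoke-test sketch using random stand-in data. It assumes the backbone is constructible with its default arguments and that the `LayoutLMv3TransformerLayer` stub preserves the hidden-state shape; note that `call` as written only consumes the text and layout inputs, even though the docstring also lists `image`:

```python
import numpy as np

from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import (
    LayoutLMv3Backbone,
)

batch_size, seq_length = 2, 128
backbone = LayoutLMv3Backbone()  # default config: hidden_size=768

outputs = backbone({
    # Random token ids within the default 30522-word vocabulary.
    "input_ids": np.random.randint(0, 30522, (batch_size, seq_length)),
    # Box coordinates must stay below 1024, the size of the spatial
    # embedding tables.
    "bbox": np.random.randint(0, 1024, (batch_size, seq_length, 4)),
    "attention_mask": np.ones((batch_size, seq_length), dtype="int32"),
    # Accepted per the docstring, but unused by `call` as written.
    "image": np.zeros((batch_size, 224, 224, 3), dtype="float32"),
})
print(outputs["sequence_output"].shape)  # expected: (2, 128, 768)
print(outputs["pooled_output"].shape)  # expected: (2, 768)
```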
Two empty files added.
File (inferred path): keras_hub/src/models/layoutlmv3/layoutlmv3_presets.py (new file, +28 lines)
```python
"""LayoutLMv3 model preset configurations."""

backbone_presets = {
    "layoutlmv3_base": {
        "metadata": {
            "description": (
                "12-layer LayoutLMv3 model with visual backbone. "
                "Trained on IIT-CDIP dataset for document understanding."
            ),
            "params": 113000000,
            "path": "layoutlmv3",
        },
        "kaggle_handle": "kaggle://keras/layoutlmv3/keras/layoutlmv3_base/1",
    },
    "layoutlmv3_large": {
        "metadata": {
            "description": (
                "24-layer LayoutLMv3 model with multimodal "
                "(text + layout + image) understanding capabilities. "
                "Trained on IIT-CDIP, RVL-CDIP, FUNSD, CORD, SROIE, "
                "and DocVQA datasets."
            ),
            "params": 340787200,
            "path": "layoutlmv3",
        },
        "kaggle_handle": "kaggle://keras/layoutlmv3/keras/layoutlmv3_large/3",
    },
}
```
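Each entry maps a preset name to its metadata plus the `kaggle_handle` the loader pulls weights from. Assuming the standard KerasHub `presets` class property, the registered names should then be discoverable on the class:

```python
from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import (
    LayoutLMv3Backbone,
)

# Both names defined above should appear once __init__.py has run.
print(list(LayoutLMv3Backbone.presets))
# ['layoutlmv3_base', 'layoutlmv3_large']
```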
Review comment: Remove type annotations from everywhere; we don't follow type annotations in Keras Hub.

Follow-up review comment: Type annotations still need to be removed.
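For illustration only, the change the reviewer is requesting would look like this on the constructor signature (KerasHub documents argument types in docstrings rather than in annotations); the names shown are from this PR:

```python
# Before (as in this PR):
def __init__(self, vocab_size: int = 30522, hidden_size: int = 768, **kwargs):
    ...

# After (KerasHub style, with types documented in the docstring instead):
def __init__(self, vocab_size=30522, hidden_size=768, **kwargs):
    ...
```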