
Added LayoutLMv3 #2178


Open · wants to merge 23 commits into master
Commits (23)
ae79d15  added the files (carrycooldude, Mar 30, 2025)
737f03a  Restructure LayoutLMv3 implementation to match KerasHub style (carrycooldude, Apr 25, 2025)
455a140  Refactor: Move LayoutLMv3 files to models directory and make code bac… (carrycooldude, Apr 27, 2025)
d92c8c4  refactor: Move LayoutLMv3 files to dedicated directory (carrycooldude, Apr 27, 2025)
0948f95  fix: Update LayoutLMv3 init files to follow correct format (carrycooldude, Apr 30, 2025)
3c02f78  fix: Update LayoutLMv3 backbone to follow project standards (carrycooldude, Apr 30, 2025)
4a79d9b  refactor: remove unnecessary files and fix imports in LayoutLMv3 module (carrycooldude, May 26, 2025)
c2fed4c  Add minimal stub for LayoutLMv3TransformerLayer (carrycooldude, May 29, 2025)
e828047  fix: resolve merge conflicts and complete rebase (carrycooldude, May 30, 2025)
063054d  refactor(layoutlmv3): move usage examples to class docstrings and rem… (carrycooldude, Jul 4, 2025)
476c0fd  style: apply code formatting and lint fixes via pre-commit (carrycooldude, Jul 4, 2025)
4439fad  made some changes (carrycooldude, Jul 7, 2025)
ad3c758  resolve the conflict issue (carrycooldude, Jul 7, 2025)
885f2fe  chore: update API directory and fix ruff line length in checkpoint co… (carrycooldude, Jul 7, 2025)
5019abb  update models (carrycooldude, Jul 7, 2025)
e1fc266  made changes (carrycooldude, Jul 7, 2025)
a32555c  chore: trigger CI (carrycooldude, Jul 7, 2025)
a885afa  Update API files (carrycooldude, Jul 7, 2025)
ad004f7  changed (carrycooldude, Jul 7, 2025)
6fb0fdc  chore: pre-commit fixes for layoutlmv3 __init__.py (carrycooldude, Jul 7, 2025)
5aaadab  chore: commit api directory after pre-commit run (carrycooldude, Jul 8, 2025)
8c7e989  update models (carrycooldude, Jul 8, 2025)
5a371a5  update layoutlmv3 (carrycooldude, Jul 9, 2025)
7 changes: 7 additions & 0 deletions keras_hub/src/models/layoutlmv3/__init__.py
@@ -0,0 +1,7 @@
from keras_hub.src.models.layoutlmv3.layoutlmv3_backbone import (
LayoutLMv3Backbone,
)
from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, LayoutLMv3Backbone)
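
For context while reviewing, a minimal sketch (not part of the diff) of what this registration enables. It assumes the branch is installed locally and that the preset weights referenced in layoutlmv3_presets.py have actually been uploaded:

```python
# Sketch only: assumes this PR's branch is installed and the Kaggle presets
# referenced in layoutlmv3_presets.py exist.
from keras_hub.models import LayoutLMv3Backbone

# register_presets() makes the presets discoverable on the class...
print(list(LayoutLMv3Backbone.presets.keys()))
# ['layoutlmv3_base', 'layoutlmv3_large']

# ...and loadable by name through the standard KerasHub entry point.
backbone = LayoutLMv3Backbone.from_preset("layoutlmv3_base")
```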
354 changes: 354 additions & 0 deletions keras_hub/src/models/layoutlmv3/layoutlmv3_backbone.py
@@ -0,0 +1,354 @@
from keras import backend
from keras import layers
from keras.saving import register_keras_serializable

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone

from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import backbone_presets
from keras_hub.src.models.layoutlmv3.layoutlmv3_transformer import (
    LayoutLMv3TransformerLayer,
)


@keras_hub_export("keras_hub.models.LayoutLMv3Backbone")
@register_keras_serializable(package="keras_hub")
class LayoutLMv3Backbone(Backbone):
"""LayoutLMv3 backbone model for document understanding tasks.

This class implements the LayoutLMv3 model architecture for joint text and
layout understanding in document AI tasks. It processes both text and image
inputs while maintaining spatial relationships in documents.

Example:
```python
# Initialize backbone from preset
backbone = LayoutLMv3Backbone.from_preset("layoutlmv3_base")

# Process document image and text
outputs = backbone({
"input_ids": input_ids, # Shape: (batch_size, seq_length)
"bbox": bbox, # Shape: (batch_size, seq_length, 4)
"attention_mask": attention_mask, # Shape: (batch_size, seq_length)
"image": image # Shape: (batch_size, height, width, channels)
})
```

Args:
vocab_size: int. Size of the vocabulary. Defaults to 30522.
hidden_size: int. Size of the hidden layers. Defaults to 768.
num_hidden_layers: int. Number of transformer layers. Defaults to 12.
num_attention_heads: int. Number of attention heads. Defaults to 12.
intermediate_size: int. Size of the intermediate layer. Defaults to
3072.
hidden_act: str. Activation function for the hidden layers. Defaults to
"gelu".
hidden_dropout_prob: float. Dropout probability for hidden layers.
Defaults to 0.1.
attention_probs_dropout_prob: float. Dropout probability for attention
layers. Defaults to 0.1.
max_position_embeddings: int. Maximum sequence length. Defaults to 512.
type_vocab_size: int. Size of the token type vocabulary. Defaults to 2.
initializer_range: float. Range for weight initialization. Defaults to
0.02.
layer_norm_eps: float. Epsilon for layer normalization. Defaults to
1e-12.
pad_token_id: int. ID of the padding token. Defaults to 0.
position_embedding_type: str. Type of position embedding. Defaults to
"absolute".
use_cache: bool. Whether to use caching. Defaults to True.
classifier_dropout: float. Dropout probability for classifier. Defaults
to None.
patch_size: int. Size of image patches. Defaults to 16.
num_channels: int. Number of image channels. Defaults to 3.
qkv_bias: bool. Whether to use bias in QKV projection. Defaults to
True.
use_abs_pos: bool. Whether to use absolute position embeddings.
Defaults to True.
use_rel_pos: bool. Whether to use relative position embeddings.
Defaults to True.
rel_pos_bins: int. Number of relative position bins. Defaults to 32.
max_rel_pos: int. Maximum relative position. Defaults to 128.
spatial_embedding_dim: int. Dimension of spatial embeddings. Defaults
to 64.

References:
- [LayoutLMv3 Paper](https://arxiv.org/abs/2204.08387)
- [LayoutLMv3 GitHub](https://github.com/microsoft/unilm/tree/master/layoutlmv3)
"""

def __init__(
self,
vocab_size: int = 30522,
Review comment (Collaborator): Remove type annotations everywhere; we don't follow type annotations in KerasHub.

Review comment (Collaborator): Type annotations still need to be removed.

hidden_size: int = 768,
num_hidden_layers: int = 12,
num_attention_heads: int = 12,
intermediate_size: int = 3072,
hidden_act: str = "gelu",
hidden_dropout_prob: float = 0.1,
attention_probs_dropout_prob: float = 0.1,
max_position_embeddings: int = 512,
type_vocab_size: int = 2,
initializer_range: float = 0.02,
layer_norm_eps: float = 1e-12,
pad_token_id: int = 0,
position_embedding_type: str = "absolute",
use_cache: bool = True,
classifier_dropout: float = None,
patch_size: int = 16,
num_channels: int = 3,
qkv_bias: bool = True,
use_abs_pos: bool = True,
use_rel_pos: bool = True,
rel_pos_bins: int = 32,
max_rel_pos: int = 128,
spatial_embedding_dim: int = 64,
**kwargs,
):
super().__init__(**kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.pad_token_id = pad_token_id
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout

# Input layers
self.input_ids = layers.Input(
shape=(None,), dtype="int32", name="input_ids"
)
self.bbox = layers.Input(shape=(None, 4), dtype="int32", name="bbox")
self.attention_mask = layers.Input(
shape=(None,), dtype="int32", name="attention_mask"
)
self.image = layers.Input(
shape=(None, None, None, num_channels),
dtype="float32",
name="image",
)

# Embeddings
self.word_embeddings = layers.Embedding(
vocab_size, hidden_size, name="embeddings.word_embeddings"
)

# Position embeddings
self.x_position_embeddings = layers.Embedding(
1024, spatial_embedding_dim, name="embeddings.x_position_embeddings"
)
self.y_position_embeddings = layers.Embedding(
1024, spatial_embedding_dim, name="embeddings.y_position_embeddings"
)
self.h_position_embeddings = layers.Embedding(
1024, spatial_embedding_dim, name="embeddings.h_position_embeddings"
)
self.w_position_embeddings = layers.Embedding(
1024, spatial_embedding_dim, name="embeddings.w_position_embeddings"
)
self.token_type_embeddings = layers.Embedding(
type_vocab_size,
hidden_size,
name="embeddings.token_type_embeddings",
)

# Layer normalization
self.embeddings_LayerNorm = layers.LayerNormalization(
epsilon=layer_norm_eps, name="embeddings.LayerNorm"
)
self.norm = layers.LayerNormalization(
epsilon=layer_norm_eps, name="norm"
)

# Spatial embedding projections
self.x_proj = layers.Dense(hidden_size, name="x_proj")
self.y_proj = layers.Dense(hidden_size, name="y_proj")
self.h_proj = layers.Dense(hidden_size, name="h_proj")
self.w_proj = layers.Dense(hidden_size, name="w_proj")

# Transformer encoder layers
self.encoder_layers = [
LayoutLMv3TransformerLayer(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
qkv_bias=qkv_bias,
use_rel_pos=use_rel_pos,
rel_pos_bins=rel_pos_bins,
max_rel_pos=max_rel_pos,
name=f"encoder.layer.{i}",
)
for i in range(num_hidden_layers)
]

# Image processing
self.patch_embed = layers.Conv2D(
hidden_size,
kernel_size=(patch_size, patch_size),
strides=(patch_size, patch_size),
name="patch_embed.proj",
)
self.patch_embed_layer_norm = layers.LayerNormalization(
epsilon=layer_norm_eps, name="LayerNorm"
)

# CLS token
self.cls_token = self.add_weight(
shape=(1, 1, hidden_size),
initializer="random_normal",
trainable=True,
name="cls_token",
)

# Pooler
self.pooler = layers.Dense(
hidden_size, activation="tanh", name="pooler"
)

def call(self, inputs):
"""Process text and image inputs through the LayoutLMv3 model.

Args:
inputs: Dictionary containing:
- input_ids: Int tensor of shape (batch_size, sequence_length)
- bbox: Int tensor of shape (batch_size, sequence_length, 4)
- attention_mask: Int tensor of shape (batch_size,
sequence_length)
- image: Float tensor of shape (batch_size, height, width,
channels)

Returns:
Dictionary containing:
- sequence_output: Float tensor of shape (batch_size,
sequence_length, hidden_size)
- pooled_output: Float tensor of shape (batch_size,
hidden_size)
- hidden_states: List of tensors of shape (batch_size,
sequence_length, hidden_size)

Example:
```python
model = LayoutLMv3Backbone.from_preset("layoutlmv3_base")
outputs = model({
"input_ids": input_ids,
"bbox": bbox,
"attention_mask": attention_mask,
"image": image
})
```
"""
# Extract inputs
input_ids = inputs["input_ids"]
bbox = inputs["bbox"]
attention_mask = inputs["attention_mask"]

# Get word embeddings
word_embeddings = self.word_embeddings(input_ids)

# Get spatial embeddings
x_embeddings = self.x_position_embeddings(bbox[..., 0])
y_embeddings = self.y_position_embeddings(bbox[..., 1])
h_embeddings = self.h_position_embeddings(bbox[..., 2])
w_embeddings = self.w_position_embeddings(bbox[..., 3])

# Project spatial embeddings to hidden size
x_embeddings = self.x_proj(x_embeddings)
y_embeddings = self.y_proj(y_embeddings)
h_embeddings = self.h_proj(h_embeddings)
w_embeddings = self.w_proj(w_embeddings)

# Combine embeddings
embeddings = (
word_embeddings
+ x_embeddings
+ y_embeddings
+ h_embeddings
+ w_embeddings
)

# Add token type embeddings
token_type_ids = backend.zeros_like(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = embeddings + token_type_embeddings

# Apply layer normalization
embeddings = self.embeddings_LayerNorm(embeddings)

# Apply dropout
embeddings = self.embeddings_dropout(embeddings)

# Process through transformer layers
hidden_states = [embeddings]
for layer in self.transformer_layers:

Review comment (critical): The code iterates over self.transformer_layers, which is not defined. Use self.encoder_layers instead.

Suggested change:
-        for layer in self.transformer_layers:
+        for layer in self.encoder_layers:

hidden_state = layer(
hidden_states[-1],
attention_mask=attention_mask,
)
hidden_states.append(hidden_state)

# Get sequence output
sequence_output = hidden_states[-1]

# Apply final layer normalization
sequence_output = self.norm(sequence_output)

# Get pooled output
pooled_output = self.pooler(sequence_output[:, 0])

return {
"sequence_output": sequence_output,
"pooled_output": pooled_output,
"hidden_states": hidden_states,
}

def get_config(self):
"""Get the model configuration.

Returns:
A dictionary containing the model configuration.
"""
config = super().get_config()
config.update(
{
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"num_hidden_layers": self.num_hidden_layers,
"num_attention_heads": self.num_attention_heads,
"intermediate_size": self.intermediate_size,
"hidden_act": self.hidden_act,
"hidden_dropout_prob": self.hidden_dropout_prob,
"attention_probs_dropout_prob": (
self.attention_probs_dropout_prob
),
"max_position_embeddings": self.max_position_embeddings,
"type_vocab_size": self.type_vocab_size,
"initializer_range": self.initializer_range,
"layer_norm_eps": self.layer_norm_eps,
"pad_token_id": self.pad_token_id,
"position_embedding_type": self.position_embedding_type,
"use_cache": self.use_cache,
"classifier_dropout": self.classifier_dropout,
"patch_size": self.patch_size,
"num_channels": self.num_channels,
"qkv_bias": self.qkv_bias,
"use_abs_pos": self.use_abs_pos,
"use_rel_pos": self.use_rel_pos,
"rel_pos_bins": self.rel_pos_bins,
"max_rel_pos": self.max_rel_pos,
"spatial_embedding_dim": self.spatial_embedding_dim,
}
)
return config
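
A rough smoke-test sketch for the input/output contract documented in the class and `call()` docstrings above. It assumes the `encoder_layers` rename flagged in review is applied, and that an embeddings dropout layer is defined (`call()` references `self.embeddings_dropout`, which `__init__` does not currently create):

```python
import numpy as np

from keras_hub.models import LayoutLMv3Backbone

batch_size, seq_length = 2, 128

inputs = {
    # Token ids in [0, vocab_size).
    "input_ids": np.random.randint(0, 30522, size=(batch_size, seq_length), dtype="int32"),
    # Bucketed box coordinates in [0, 1024), matching the 1024-entry spatial embeddings.
    "bbox": np.random.randint(0, 1024, size=(batch_size, seq_length, 4), dtype="int32"),
    "attention_mask": np.ones((batch_size, seq_length), dtype="int32"),
    # Page image; 224x224 RGB pairs naturally with the 16x16 patch embedding.
    "image": np.random.uniform(size=(batch_size, 224, 224, 3)).astype("float32"),
}

# Deliberately small config so the check runs quickly; other arguments keep their defaults.
backbone = LayoutLMv3Backbone(num_hidden_layers=2)
outputs = backbone(inputs)
print(outputs["sequence_output"].shape)  # expected: (2, 128, 768)
print(outputs["pooled_output"].shape)    # expected: (2, 768)
```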
Empty file.
28 changes: 28 additions & 0 deletions keras_hub/src/models/layoutlmv3/layoutlmv3_presets.py
@@ -0,0 +1,28 @@
"""LayoutLMv3 model preset configurations."""

backbone_presets = {
"layoutlmv3_base": {
"metadata": {
"description": (
"12-layer LayoutLMv3 model with visual backbone. "
"Trained on IIT-CDIP dataset for document understanding."
),
"params": 113000000,
"path": "layoutlmv3",
},
"kaggle_handle": "kaggle://keras/layoutlmv3/keras/layoutlmv3_base/1",
},
"layoutlmv3_large": {
"metadata": {
"description": (
"24-layer LayoutLMv3 model with multimodal "
"(text + layout + image) understanding capabilities. "
"Trained on IIT-CDIP, RVL-CDIP, FUNSD, CORD, SROIE, "
"and DocVQA datasets."
),
"params": 340787200,
"path": "layoutlmv3",
},
"kaggle_handle": "kaggle://keras/layoutlmv3/keras/layoutlmv3_large/3",
},
}
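
A small sketch for sanity-checking the preset metadata above; the parameter counts and Kaggle handles come straight from this file and assume the corresponding weights are uploaded under keras/layoutlmv3:

```python
from keras_hub.src.models.layoutlmv3.layoutlmv3_presets import backbone_presets

# Print one summary line per registered preset.
for name, preset in backbone_presets.items():
    meta = preset["metadata"]
    print(f"{name}: {meta['params'] / 1e6:.0f}M params -> {preset['kaggle_handle']}")
# layoutlmv3_base: 113M params -> kaggle://keras/layoutlmv3/keras/layoutlmv3_base/1
# layoutlmv3_large: 341M params -> kaggle://keras/layoutlmv3/keras/layoutlmv3_large/3
```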