Skip to content

Commit a69433f

Browse files
committed
Add ResnetTransformer
1 parent c495117 commit a69433f

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed

im2latex/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .cnn_lstm import CNNLSTM
2+
from .resnet_transformer import ResnetTransformer

im2latex/models/resnet_transformer.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import argparse
2+
import math
3+
from typing import Any, Dict
4+
5+
import torch
6+
import torch.nn as nn
7+
import torchvision
8+
9+
from .transformer_util import PositionalEncoding, PositionalEncodingImage, generate_square_subsequent_mask
10+
11+
TF_DIM = 256
12+
TF_FC_DIM = 1024
13+
TF_DROPOUT = 0.4
14+
TF_LAYERS = 4
15+
TF_NHEAD = 4
16+
RESNET_DIM = 512 # hard-coded
17+
18+
19+
class ResnetTransformer(nn.Module):
20+
"""Process the line through a Resnet and process the resulting sequence with a Transformer decoder"""
21+
22+
def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None,) -> None:
23+
super().__init__()
24+
self.data_config = data_config
25+
self.input_dims = data_config["input_dims"]
26+
self.num_classes = len(data_config["mapping"])
27+
inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])}
28+
self.start_token = inverse_mapping["<S>"]
29+
self.end_token = inverse_mapping["<E>"]
30+
self.padding_token = inverse_mapping["<P>"]
31+
self.unknown_token = inverse_mapping["<U>"]
32+
self.max_output_length = data_config["output_dims"][0] + 2
33+
self.args = vars(args) if args is not None else {}
34+
35+
self.dim = self.args.get("tf_dim", TF_DIM)
36+
tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM)
37+
tf_nhead = self.args.get("tf_nhead", TF_NHEAD)
38+
tf_dropout = self.args.get("tf_dropout", TF_DROPOUT)
39+
tf_layers = self.args.get("tf_layers", TF_LAYERS)
40+
41+
# ## Encoder part - should output vector sequence of length self.dim per sample
42+
resnet = torchvision.models.resnet18(pretrained=False)
43+
self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers
44+
# Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32
45+
46+
# self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=(2, 1), stride=(2, 1), padding=0)
47+
self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1)
48+
# encoder_projection will output (B, dim, _H, _W) logits
49+
50+
if isinstance(self.input_dims, list):
51+
_, max_hight, max_width = max(self.input_dims)
52+
self.enc_pos_encoder = PositionalEncodingImage(
53+
d_model=self.dim, max_h=max_hight, max_w=max_width
54+
) # Max (Ho, Wo)
55+
else:
56+
self.enc_pos_encoder = PositionalEncodingImage(
57+
d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2]
58+
) # Max (Ho, Wo)
59+
60+
# ## Decoder part
61+
self.embedding = nn.Embedding(self.num_classes, self.dim)
62+
self.fc = nn.Linear(self.dim, self.num_classes)
63+
64+
self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length)
65+
66+
self.y_mask = generate_square_subsequent_mask(self.max_output_length)
67+
68+
self.transformer_decoder = nn.TransformerDecoder(
69+
nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout),
70+
num_layers=tf_layers,
71+
)
72+
73+
self.init_weights() # This is empirically important
74+
75+
def init_weights(self):
76+
initrange = 0.1
77+
self.embedding.weight.data.uniform_(-initrange, initrange)
78+
self.fc.bias.data.zero_()
79+
self.fc.weight.data.uniform_(-initrange, initrange)
80+
81+
nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu")
82+
if self.encoder_projection.bias is not None:
83+
_fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out( # pylint: disable=protected-access
84+
self.encoder_projection.weight.data
85+
)
86+
bound = 1 / math.sqrt(fan_out)
87+
nn.init.normal_(self.encoder_projection.bias, -bound, bound)
88+
89+
def encode(self, x: torch.Tensor) -> torch.Tensor:
90+
"""
91+
Parameters
92+
----------
93+
x
94+
(B, H, W) image
95+
96+
Returns
97+
-------
98+
torch.Tensor
99+
(Sx, B, E) logits
100+
"""
101+
_B, C, _H, _W = x.shape
102+
if C == 1:
103+
x = x.repeat(1, 3, 1, 1)
104+
x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs
105+
x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs
106+
107+
# x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning
108+
x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32
109+
x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo)
110+
x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo
111+
return x
112+
113+
def decode(self, x, y):
114+
"""
115+
Parameters
116+
----------
117+
x
118+
(B, H, W) image
119+
y
120+
(B, Sy) with elements in [0, C-1] where C is num_classes
121+
122+
Returns
123+
-------
124+
torch.Tensor
125+
(Sy, B, C) logits
126+
"""
127+
y_padding_mask = y == self.padding_token
128+
y = y.permute(1, 0) # (Sy, B)
129+
y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E)
130+
y = self.dec_pos_encoder(y) # (Sy, B, E)
131+
Sy = y.shape[0]
132+
y_mask = self.y_mask[:Sy, :Sy].type_as(x)
133+
output = self.transformer_decoder(
134+
tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask
135+
) # (Sy, B, E)
136+
output = self.fc(output) # (Sy, B, C)
137+
return output
138+
139+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
140+
"""
141+
Parameters
142+
----------
143+
x
144+
(B, H, W) image
145+
y
146+
(B, Sy) with elements in [0, C-1] where C is num_classes
147+
148+
Returns
149+
-------
150+
torch.Tensor
151+
(B, C, Sy) logits
152+
"""
153+
x = self.encode(x) # (Sx, B, E)
154+
output = self.decode(x, y) # (Sy, B, C)
155+
return output.permute(1, 2, 0) # (B, C, Sy)
156+
157+
def predict(self, x: torch.Tensor) -> torch.Tensor:
158+
"""
159+
Parameters
160+
----------
161+
x
162+
(B, H, W) image
163+
164+
Returns
165+
-------
166+
torch.Tensor
167+
(B, Sy) with elements in [0, C-1] where C is num_classes
168+
"""
169+
B = x.shape[0]
170+
S = self.max_output_length
171+
x = self.encode(x) # (Sx, B, E)
172+
173+
output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S)
174+
output_tokens[:, 0] = self.start_token # Set start token
175+
for Sy in range(1, S):
176+
y = output_tokens[:, :Sy] # (B, Sy)
177+
output = self.decode(x, y) # (Sy, B, C)
178+
output = torch.argmax(output, dim=-1) # (Sy, B)
179+
output_tokens[:, Sy] = output[-1:] # Set the last output token
180+
181+
# Early stopping of prediction loop to speed up prediction
182+
if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():
183+
break
184+
185+
# Set all tokens after end token to be padding
186+
for Sy in range(1, S):
187+
ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token)
188+
output_tokens[ind, Sy] = self.padding_token
189+
190+
return output_tokens # (B, Sy)
191+
192+
@staticmethod
193+
def add_to_argparse(parser):
194+
parser.add_argument("--tf_dim", type=int, default=TF_DIM)
195+
parser.add_argument("--tf_fc_dim", type=int, default=TF_DIM)
196+
parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT)
197+
parser.add_argument("--tf_layers", type=int, default=TF_LAYERS)
198+
parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD)
199+
return parser

im2latex/models/transformer_util.py

+76
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Position Encoding and other utilities for Tranformers"""
2+
import math
3+
4+
import torch
5+
import torch.nn as nn
6+
from torch import Tensor
7+
8+
9+
# Hide lines below until Lab 7
10+
class PositionalEncodingImage(nn.Module):
11+
"""
12+
Module used to add 2-D positional encodings to the feature-map produced by the encoder.
13+
14+
Following https://arxiv.org/abs/2103.06450 by Sumeet Singh.
15+
"""
16+
17+
def __init__(self, d_model: int, max_h: int = 2000, max_w: int = 2000) -> None:
18+
super().__init__()
19+
self.d_model = d_model
20+
assert d_model % 2 == 0, f"Embedding depth {d_model} is not even"
21+
pe = self.make_pe(d_model=d_model, max_h=max_h, max_w=max_w) # (d_model, max_h, max_w)
22+
self.register_buffer("pe", pe)
23+
24+
@staticmethod
25+
def make_pe(d_model: int, max_h: int, max_w: int) -> torch.Tensor:
26+
pe_h = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_h) # (max_h, 1 d_model // 2)
27+
pe_h = pe_h.permute(2, 0, 1).expand(-1, -1, max_w) # (d_model // 2, max_h, max_w)
28+
29+
pe_w = PositionalEncoding.make_pe(d_model=d_model // 2, max_len=max_w) # (max_w, 1, d_model // 2)
30+
pe_w = pe_w.permute(2, 1, 0).expand(-1, max_h, -1) # (d_model // 2, max_h, max_w)
31+
32+
pe = torch.cat([pe_h, pe_w], dim=0) # (d_model, max_h, max_w)
33+
return pe
34+
35+
def forward(self, x: Tensor) -> Tensor:
36+
"""pytorch.nn.module.forward"""
37+
# x.shape = (B, d_model, H, W)
38+
assert x.shape[1] == self.pe.shape[0] # type: ignore
39+
x = x + self.pe[:, : x.size(2), : x.size(3)] # type: ignore
40+
return x
41+
42+
43+
# Hide lines above until Lab 7
44+
45+
46+
class PositionalEncoding(torch.nn.Module):
47+
"""Classic Attention-is-all-you-need positional encoding."""
48+
49+
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000) -> None:
50+
super().__init__()
51+
self.dropout = torch.nn.Dropout(p=dropout)
52+
pe = self.make_pe(d_model=d_model, max_len=max_len) # (max_len, 1, d_model)
53+
self.register_buffer("pe", pe)
54+
55+
@staticmethod
56+
def make_pe(d_model: int, max_len: int) -> torch.Tensor:
57+
pe = torch.zeros(max_len, d_model)
58+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
59+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
60+
pe[:, 0::2] = torch.sin(position * div_term)
61+
pe[:, 1::2] = torch.cos(position * div_term)
62+
pe = pe.unsqueeze(1)
63+
return pe
64+
65+
def forward(self, x: torch.Tensor) -> torch.Tensor:
66+
# x.shape = (S, B, d_model)
67+
assert x.shape[2] == self.pe.shape[2] # type: ignore
68+
x = x + self.pe[: x.size(0)] # type: ignore
69+
return self.dropout(x)
70+
71+
72+
def generate_square_subsequent_mask(size: int) -> torch.Tensor:
73+
"""Generate a triangular (size, size) mask."""
74+
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
75+
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
76+
return mask

0 commit comments

Comments
 (0)