|
| 1 | +import argparse |
| 2 | +import math |
| 3 | +from typing import Any, Dict |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +import torchvision |
| 8 | + |
| 9 | +from .transformer_util import PositionalEncoding, PositionalEncodingImage, generate_square_subsequent_mask |
| 10 | + |
| 11 | +TF_DIM = 256 |
| 12 | +TF_FC_DIM = 1024 |
| 13 | +TF_DROPOUT = 0.4 |
| 14 | +TF_LAYERS = 4 |
| 15 | +TF_NHEAD = 4 |
| 16 | +RESNET_DIM = 512 # hard-coded |
| 17 | + |
| 18 | + |
| 19 | +class ResnetTransformer(nn.Module): |
| 20 | + """Process the line through a Resnet and process the resulting sequence with a Transformer decoder""" |
| 21 | + |
| 22 | + def __init__(self, data_config: Dict[str, Any], args: argparse.Namespace = None,) -> None: |
| 23 | + super().__init__() |
| 24 | + self.data_config = data_config |
| 25 | + self.input_dims = data_config["input_dims"] |
| 26 | + self.num_classes = len(data_config["mapping"]) |
| 27 | + inverse_mapping = {val: ind for ind, val in enumerate(data_config["mapping"])} |
| 28 | + self.start_token = inverse_mapping["<S>"] |
| 29 | + self.end_token = inverse_mapping["<E>"] |
| 30 | + self.padding_token = inverse_mapping["<P>"] |
| 31 | + self.unknown_token = inverse_mapping["<U>"] |
| 32 | + self.max_output_length = data_config["output_dims"][0] + 2 |
| 33 | + self.args = vars(args) if args is not None else {} |
| 34 | + |
| 35 | + self.dim = self.args.get("tf_dim", TF_DIM) |
| 36 | + tf_fc_dim = self.args.get("tf_fc_dim", TF_FC_DIM) |
| 37 | + tf_nhead = self.args.get("tf_nhead", TF_NHEAD) |
| 38 | + tf_dropout = self.args.get("tf_dropout", TF_DROPOUT) |
| 39 | + tf_layers = self.args.get("tf_layers", TF_LAYERS) |
| 40 | + |
| 41 | + # ## Encoder part - should output vector sequence of length self.dim per sample |
| 42 | + resnet = torchvision.models.resnet18(pretrained=False) |
| 43 | + self.resnet = torch.nn.Sequential(*(list(resnet.children())[:-2])) # Exclude AvgPool and Linear layers |
| 44 | + # Resnet will output (B, RESNET_DIM, _H, _W) logits where _H = input_H // 32, _W = input_W // 32 |
| 45 | + |
| 46 | + # self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=(2, 1), stride=(2, 1), padding=0) |
| 47 | + self.encoder_projection = nn.Conv2d(RESNET_DIM, self.dim, kernel_size=1) |
| 48 | + # encoder_projection will output (B, dim, _H, _W) logits |
| 49 | + |
| 50 | + if isinstance(self.input_dims, list): |
| 51 | + _, max_hight, max_width = max(self.input_dims) |
| 52 | + self.enc_pos_encoder = PositionalEncodingImage( |
| 53 | + d_model=self.dim, max_h=max_hight, max_w=max_width |
| 54 | + ) # Max (Ho, Wo) |
| 55 | + else: |
| 56 | + self.enc_pos_encoder = PositionalEncodingImage( |
| 57 | + d_model=self.dim, max_h=self.input_dims[1], max_w=self.input_dims[2] |
| 58 | + ) # Max (Ho, Wo) |
| 59 | + |
| 60 | + # ## Decoder part |
| 61 | + self.embedding = nn.Embedding(self.num_classes, self.dim) |
| 62 | + self.fc = nn.Linear(self.dim, self.num_classes) |
| 63 | + |
| 64 | + self.dec_pos_encoder = PositionalEncoding(d_model=self.dim, max_len=self.max_output_length) |
| 65 | + |
| 66 | + self.y_mask = generate_square_subsequent_mask(self.max_output_length) |
| 67 | + |
| 68 | + self.transformer_decoder = nn.TransformerDecoder( |
| 69 | + nn.TransformerDecoderLayer(d_model=self.dim, nhead=tf_nhead, dim_feedforward=tf_fc_dim, dropout=tf_dropout), |
| 70 | + num_layers=tf_layers, |
| 71 | + ) |
| 72 | + |
| 73 | + self.init_weights() # This is empirically important |
| 74 | + |
| 75 | + def init_weights(self): |
| 76 | + initrange = 0.1 |
| 77 | + self.embedding.weight.data.uniform_(-initrange, initrange) |
| 78 | + self.fc.bias.data.zero_() |
| 79 | + self.fc.weight.data.uniform_(-initrange, initrange) |
| 80 | + |
| 81 | + nn.init.kaiming_normal_(self.encoder_projection.weight.data, a=0, mode="fan_out", nonlinearity="relu") |
| 82 | + if self.encoder_projection.bias is not None: |
| 83 | + _fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out( # pylint: disable=protected-access |
| 84 | + self.encoder_projection.weight.data |
| 85 | + ) |
| 86 | + bound = 1 / math.sqrt(fan_out) |
| 87 | + nn.init.normal_(self.encoder_projection.bias, -bound, bound) |
| 88 | + |
| 89 | + def encode(self, x: torch.Tensor) -> torch.Tensor: |
| 90 | + """ |
| 91 | + Parameters |
| 92 | + ---------- |
| 93 | + x |
| 94 | + (B, H, W) image |
| 95 | +
|
| 96 | + Returns |
| 97 | + ------- |
| 98 | + torch.Tensor |
| 99 | + (Sx, B, E) logits |
| 100 | + """ |
| 101 | + _B, C, _H, _W = x.shape |
| 102 | + if C == 1: |
| 103 | + x = x.repeat(1, 3, 1, 1) |
| 104 | + x = self.resnet(x) # (B, RESNET_DIM, _H // 32, _W // 32), (B, 512, 18, 20) in the case of IAMParagraphs |
| 105 | + x = self.encoder_projection(x) # (B, E, _H // 32, _W // 32), (B, 256, 18, 20) in the case of IAMParagraphs |
| 106 | + |
| 107 | + # x = x * math.sqrt(self.dim) # (B, E, _H // 32, _W // 32) # This prevented any learning |
| 108 | + x = self.enc_pos_encoder(x) # (B, E, Ho, Wo); Ho = _H // 32, Wo = _W // 32 |
| 109 | + x = torch.flatten(x, start_dim=2) # (B, E, Ho * Wo) |
| 110 | + x = x.permute(2, 0, 1) # (Sx, B, E); Sx = Ho * Wo |
| 111 | + return x |
| 112 | + |
| 113 | + def decode(self, x, y): |
| 114 | + """ |
| 115 | + Parameters |
| 116 | + ---------- |
| 117 | + x |
| 118 | + (B, H, W) image |
| 119 | + y |
| 120 | + (B, Sy) with elements in [0, C-1] where C is num_classes |
| 121 | +
|
| 122 | + Returns |
| 123 | + ------- |
| 124 | + torch.Tensor |
| 125 | + (Sy, B, C) logits |
| 126 | + """ |
| 127 | + y_padding_mask = y == self.padding_token |
| 128 | + y = y.permute(1, 0) # (Sy, B) |
| 129 | + y = self.embedding(y) * math.sqrt(self.dim) # (Sy, B, E) |
| 130 | + y = self.dec_pos_encoder(y) # (Sy, B, E) |
| 131 | + Sy = y.shape[0] |
| 132 | + y_mask = self.y_mask[:Sy, :Sy].type_as(x) |
| 133 | + output = self.transformer_decoder( |
| 134 | + tgt=y, memory=x, tgt_mask=y_mask, tgt_key_padding_mask=y_padding_mask |
| 135 | + ) # (Sy, B, E) |
| 136 | + output = self.fc(output) # (Sy, B, C) |
| 137 | + return output |
| 138 | + |
| 139 | + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| 140 | + """ |
| 141 | + Parameters |
| 142 | + ---------- |
| 143 | + x |
| 144 | + (B, H, W) image |
| 145 | + y |
| 146 | + (B, Sy) with elements in [0, C-1] where C is num_classes |
| 147 | +
|
| 148 | + Returns |
| 149 | + ------- |
| 150 | + torch.Tensor |
| 151 | + (B, C, Sy) logits |
| 152 | + """ |
| 153 | + x = self.encode(x) # (Sx, B, E) |
| 154 | + output = self.decode(x, y) # (Sy, B, C) |
| 155 | + return output.permute(1, 2, 0) # (B, C, Sy) |
| 156 | + |
| 157 | + def predict(self, x: torch.Tensor) -> torch.Tensor: |
| 158 | + """ |
| 159 | + Parameters |
| 160 | + ---------- |
| 161 | + x |
| 162 | + (B, H, W) image |
| 163 | +
|
| 164 | + Returns |
| 165 | + ------- |
| 166 | + torch.Tensor |
| 167 | + (B, Sy) with elements in [0, C-1] where C is num_classes |
| 168 | + """ |
| 169 | + B = x.shape[0] |
| 170 | + S = self.max_output_length |
| 171 | + x = self.encode(x) # (Sx, B, E) |
| 172 | + |
| 173 | + output_tokens = (torch.ones((B, S)) * self.padding_token).type_as(x).long() # (B, S) |
| 174 | + output_tokens[:, 0] = self.start_token # Set start token |
| 175 | + for Sy in range(1, S): |
| 176 | + y = output_tokens[:, :Sy] # (B, Sy) |
| 177 | + output = self.decode(x, y) # (Sy, B, C) |
| 178 | + output = torch.argmax(output, dim=-1) # (Sy, B) |
| 179 | + output_tokens[:, Sy] = output[-1:] # Set the last output token |
| 180 | + |
| 181 | + # Early stopping of prediction loop to speed up prediction |
| 182 | + if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all(): |
| 183 | + break |
| 184 | + |
| 185 | + # Set all tokens after end token to be padding |
| 186 | + for Sy in range(1, S): |
| 187 | + ind = (output_tokens[:, Sy - 1] == self.end_token) | (output_tokens[:, Sy - 1] == self.padding_token) |
| 188 | + output_tokens[ind, Sy] = self.padding_token |
| 189 | + |
| 190 | + return output_tokens # (B, Sy) |
| 191 | + |
| 192 | + @staticmethod |
| 193 | + def add_to_argparse(parser): |
| 194 | + parser.add_argument("--tf_dim", type=int, default=TF_DIM) |
| 195 | + parser.add_argument("--tf_fc_dim", type=int, default=TF_DIM) |
| 196 | + parser.add_argument("--tf_dropout", type=float, default=TF_DROPOUT) |
| 197 | + parser.add_argument("--tf_layers", type=int, default=TF_LAYERS) |
| 198 | + parser.add_argument("--tf_nhead", type=int, default=TF_NHEAD) |
| 199 | + return parser |
0 commit comments