
Commit 116cfd6

update masked kernel transformer
1 parent 472d3f3 commit 116cfd6

2 files changed: +55 -13 lines changed


model/model.py

Lines changed: 47 additions & 7 deletions
@@ -199,14 +199,8 @@ def forward(self, x):
 class KernelTransformer(nn.Module):
     def __init__(self, in_channels, emb_size, patch_size, heads, num_classes, struct):
         super(KernelTransformer, self).__init__()
-        # self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size)
-        # self.num_patches = (emb_size // patch_size) ** 2
-        # self.pos_embed = PositionalEmbedding(emb_size, self.num_patches + 1)
-        # grid_size = 32 // patch_size
-        # self.pos_embed = PositionalEmbedding2D(emb_size, grid_size, grid_size)
-        # self.cls_token = nn.Parameter(torch.zeros(1, emb_size, 1, 1))
         self.blocks = nn.ModuleList([
-            PatchEmbedding2D(in_channels, emb_size, 2, permute=False),
+            PatchEmbedding2D(in_channels, emb_size, patch_size, permute=False),
             KernelTransformerStage(emb_size, struct[0], heads=4, kernel_size=4, stride=2),
             PatchEmbedding2D(emb_size, emb_size * 2, 2),
             KernelTransformerStage(emb_size * 2, struct[1], heads=8, kernel_size=4, stride=2),
@@ -227,6 +221,52 @@ def forward(self, x):
         return self.classifier(x)
 
 
+class MaskedKernelTransformer(nn.Module):
+    def __init__(self, in_channels, emb_size, patch_size, heads, num_classes, struct, mask_ratio=0.1):
+        super(MaskedKernelTransformer, self).__init__()
+        self.emb_size = emb_size
+        self.img_size = 32  # CIFAR-10 image size
+        self.num_patches = (self.img_size // patch_size) ** 2
+        self.mask_ratio = mask_ratio
+        self.blocks = nn.ModuleList([
+            PatchEmbedding2D(in_channels, emb_size, patch_size, permute=False),
+            KernelTransformerStage(emb_size, struct[0], heads=4, kernel_size=4, stride=2),
+            PatchEmbedding2D(emb_size, emb_size * 2, 2),
+            KernelTransformerStage(emb_size * 2, struct[1], heads=8, kernel_size=4, stride=2),
+            PatchEmbedding2D(emb_size * 2, emb_size * 4, 2),
+            KernelTransformerStage(emb_size * 4, struct[2], heads=16, kernel_size=4, stride=2),
+            PatchEmbedding2D(emb_size * 4, emb_size * 8, 1),
+            KernelTransformerStage(emb_size * 8, struct[3], heads=32, kernel_size=4, stride=2)
+        ])
+        self.classifier = nn.Sequential(
+            nn.LayerNorm(emb_size * 8),
+            nn.Linear(emb_size * 8, num_classes)
+        )
+
+    # randomly mask some patches
+    def generate_random_mask(self, ratio):
+        num_masked_patches = int(self.num_patches * ratio)
+        mask = torch.ones(self.num_patches)
+        mask_idx = torch.randperm(self.num_patches)[:num_masked_patches]
+        mask[mask_idx] = 0
+        return mask
+
+    def forward(self, x, masked=True):
+        # masked during training, unmasked during testing
+        applied = False
+        if masked:
+            mask = self.generate_random_mask(self.mask_ratio).to(x.device)
+            mask = mask.view(16, 16)[None, :, :, None]
+            mask = mask.expand(x.size(0), -1, -1, self.emb_size)
+        for blk in self.blocks:
+            x = blk(x)
+            if isinstance(blk, PatchEmbedding2D) and masked and not applied:
+                x = x * mask
+                applied = True
+        x = x.mean(dim=[1, 2])
+        return self.classifier(x)
+
+
 if __name__ == '__main__':
     from utils import count_parameters
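Note on the masking scheme: MaskedKernelTransformer draws one binary mask over the 16x16 patch grid per forward pass (CIFAR-10 with img_size=32 and patch_size=2 gives 256 patches) and multiplies it into the output of the first PatchEmbedding2D block, zeroing roughly mask_ratio of the patch embeddings for every image in the batch. A minimal self-contained sketch of that operation, assuming a channels-last (B, H, W, C) layout as implied by mask.expand(x.size(0), -1, -1, self.emb_size); the helper name and toy shapes below are illustrative, not part of the commit:

import torch

def random_patch_mask(num_patches, ratio):
    # flat {0, 1} mask with ~ratio of its entries zeroed (masked out)
    mask = torch.ones(num_patches)
    drop = torch.randperm(num_patches)[:int(num_patches * ratio)]
    mask[drop] = 0
    return mask

# toy setup: 32x32 image, patch_size=2 -> 16x16 = 256 patches, emb_size=96
B, grid, emb_size, ratio = 4, 16, 96, 0.1
x = torch.randn(B, grid, grid, emb_size)           # channels-last patch embeddings
mask = random_patch_mask(grid * grid, ratio)       # shape (256,)
mask = mask.view(grid, grid)[None, :, :, None]     # shape (1, 16, 16, 1)
mask = mask.expand(B, -1, -1, emb_size)            # broadcast over batch and channels
x_masked = x * mask                                # masked patches become all-zero vectors

print(x_masked.shape)                                            # torch.Size([4, 16, 16, 96])
print((x_masked.abs().sum(dim=-1) == 0).float().mean().item())   # ~0.1, i.e. ~mask_ratio

Because the mask is generated once per call, the same patches are hidden for every image in a batch; the later stages then have to classify from the surviving ~90% of patches, which acts as a form of structured dropout on the patch grid.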

train.py

Lines changed: 8 additions & 6 deletions
@@ -2,7 +2,7 @@
 import torch.nn as nn
 import torchvision
 import torchvision.transforms as transforms
-from model.model import KernelTransformer
+from model.model import KernelTransformer, MaskedKernelTransformer
 from csv_logger import log_csv
 from tqdm import tqdm
 
@@ -44,11 +44,13 @@ def save_checkpoint(state, filename='checkpoint/checkpoint.pth.tar'):
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 print(' - Training device currently set to:', device)
 
-model = KernelTransformer(in_channels=3, emb_size=96, patch_size=2,
-                          heads=8, num_classes=10, struct=(2, 2, 6, 2)).to(device)
+model = MaskedKernelTransformer(in_channels=3, emb_size=96, patch_size=2,
+                                heads=8, num_classes=10, struct=(2, 2, 6, 2), mask_ratio=0.1).to(device)
+# model = KernelTransformer(in_channels=3, emb_size=96, patch_size=2,
+#                           heads=8, num_classes=10, struct=(2, 2, 6, 2)).to(device)
 model = nn.DataParallel(model)
 criterion = torch.nn.CrossEntropyLoss()
-num_epochs = 400
+num_epochs = 300
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
@@ -77,7 +79,7 @@ def save_checkpoint(state, filename='checkpoint/checkpoint.pth.tar'):
         labels = labels.to(device)
 
         optimizer.zero_grad()
-        outputs = model(images)
+        outputs = model(images, masked=True)
         # calculate training accuracy
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
@@ -106,7 +108,7 @@ def save_checkpoint(state, filename='checkpoint/checkpoint.pth.tar'):
         images = images.to(device)
         labels = labels.to(device)
 
-        outputs = model(images)
+        outputs = model(images, masked=False)
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()
