@@ -199,14 +199,8 @@ def forward(self, x):
 class KernelTransformer(nn.Module):
     def __init__(self, in_channels, emb_size, patch_size, heads, num_classes, struct):
         super(KernelTransformer, self).__init__()
-        # self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_size)
-        # self.num_patches = (emb_size // patch_size) ** 2
-        # self.pos_embed = PositionalEmbedding(emb_size, self.num_patches + 1)
-        # grid_size = 32 // patch_size
-        # self.pos_embed = PositionalEmbedding2D(emb_size, grid_size, grid_size)
-        # self.cls_token = nn.Parameter(torch.zeros(1, emb_size, 1, 1))
         self.blocks = nn.ModuleList([
-            PatchEmbedding2D(in_channels, emb_size, 2, permute=False),
+            PatchEmbedding2D(in_channels, emb_size, patch_size, permute=False),
             KernelTransformerStage(emb_size, struct[0], heads=4, kernel_size=4, stride=2),
             PatchEmbedding2D(emb_size, emb_size * 2, 2),
             KernelTransformerStage(emb_size * 2, struct[1], heads=8, kernel_size=4, stride=2),
@@ -227,6 +221,52 @@ def forward(self, x):
         return self.classifier(x)
 
 
+class MaskedKernelTransformer(nn.Module):
+    def __init__(self, in_channels, emb_size, patch_size, heads, num_classes, struct, mask_ratio=0.1):
+        super(MaskedKernelTransformer, self).__init__()
+        self.emb_size = emb_size
+        self.img_size = 32  # CIFAR-10 image size
+        self.num_patches = (self.img_size // patch_size) ** 2
+        self.mask_ratio = mask_ratio
+        self.blocks = nn.ModuleList([
+            PatchEmbedding2D(in_channels, emb_size, patch_size, permute=False),
+            KernelTransformerStage(emb_size, struct[0], heads=4, kernel_size=4, stride=2),
+            PatchEmbedding2D(emb_size, emb_size * 2, 2),
+            KernelTransformerStage(emb_size * 2, struct[1], heads=8, kernel_size=4, stride=2),
+            PatchEmbedding2D(emb_size * 2, emb_size * 4, 2),
+            KernelTransformerStage(emb_size * 4, struct[2], heads=16, kernel_size=4, stride=2),
+            PatchEmbedding2D(emb_size * 4, emb_size * 8, 1),
+            KernelTransformerStage(emb_size * 8, struct[3], heads=32, kernel_size=4, stride=2)
+        ])
+        self.classifier = nn.Sequential(
+            nn.LayerNorm(emb_size * 8),
+            nn.Linear(emb_size * 8, num_classes)
+        )
+
+    # randomly mask some patches
+    def generate_random_mask(self, ratio):
+        num_masked_patches = int(self.num_patches * ratio)
+        mask = torch.ones(self.num_patches)
+        mask_idx = torch.randperm(self.num_patches)[:num_masked_patches]
+        mask[mask_idx] = 0
+        return mask
+
+    def forward(self, x, masked=True):
+        # masked during training, unmasked during testing
+        applied = False
+        if masked:
+            mask = self.generate_random_mask(self.mask_ratio).to(x.device)
+            mask = mask.view(16, 16)[None, :, :, None]
+            mask = mask.expand(x.size(0), -1, -1, self.emb_size)
+        for blk in self.blocks:
+            x = blk(x)
+            if isinstance(blk, PatchEmbedding2D) and masked and not applied:
+                x = x * mask
+                applied = True
+        x = x.mean(dim=[1, 2])
+        return self.classifier(x)
+
+
 if __name__ == '__main__':
     from utils import count_parameters
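
Below the diff, a minimal usage sketch of the new MaskedKernelTransformer (not part of the commit). It assumes 32x32 CIFAR-10 inputs with patch_size=2, since forward() reshapes the patch mask to a hard-coded 16x16 grid; the emb_size, struct, and mask_ratio values are illustrative only, and the exact intermediate shapes depend on the PatchEmbedding2D and KernelTransformerStage implementations, which are not shown in this hunk.

import torch

# Hypothetical configuration, not taken from the repository.
model = MaskedKernelTransformer(
    in_channels=3,        # RGB CIFAR-10 images
    emb_size=64,          # base embedding width, doubled at each later stage
    patch_size=2,         # must yield a 16x16 patch grid for mask.view(16, 16)
    heads=4,              # accepted by __init__; the stages set their own head counts
    num_classes=10,
    struct=[2, 2, 6, 2],  # blocks per stage (illustrative)
    mask_ratio=0.1,
)

x = torch.randn(8, 3, 32, 32)         # one CIFAR-10-sized batch
logits_train = model(x, masked=True)  # training: ~10% of patches zeroed after the first embedding
logits_eval = model(x, masked=False)  # evaluation: no masking applied
print(logits_train.shape)             # expected: (8, num_classes)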