@@ -327,3 +327,119 @@ def get_codebook_entry(self, indices, shape):
         z_q = z_q.permute(0, 3, 1, 2).contiguous()
 
         return z_q
+
+class EmbeddingEMA(nn.Module):
+    def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+        super().__init__()
+        self.decay = decay
+        self.eps = eps
+        weight = torch.randn(num_tokens, codebook_dim)
+        self.weight = nn.Parameter(weight, requires_grad=False)
+        self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+        self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+        self.update = True
+
+    def forward(self, embed_id):
+        return F.embedding(embed_id, self.weight)
+
+    def cluster_size_ema_update(self, new_cluster_size):
+        self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+    def embed_avg_ema_update(self, new_embed_avg):
+        self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+    def weight_update(self, num_tokens):
+        n = self.cluster_size.sum()
+        smoothed_cluster_size = (
+            (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+        )
+        # normalize embedding average with smoothed cluster size
+        embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+        self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(nn.Module):
+    def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+                 remap=None, unknown_index="random"):
+        super().__init__()
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
+        self.beta = beta
+        self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+        self.remap = remap
+        if self.remap is not None:
+            self.register_buffer("used", torch.tensor(np.load(self.remap)))
+            self.re_embed = self.used.shape[0]
+            self.unknown_index = unknown_index  # "random" or "extra" or integer
+            if self.unknown_index == "extra":
+                self.unknown_index = self.re_embed
+                self.re_embed = self.re_embed + 1
+            print(f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
+                  f"Using {self.unknown_index} for unknown indices.")
+        else:
+            self.re_embed = n_embed
+
+    def remap_to_used(self, inds):
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        match = (inds[:, :, None] == used[None, None, ...]).long()
+        new = match.argmax(-1)
+        unknown = match.sum(2) < 1
+        if self.unknown_index == "random":
+            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+        else:
+            new[unknown] = self.unknown_index
+        return new.reshape(ishape)
+
+    def unmap_to_all(self, inds):
+        ishape = inds.shape
+        assert len(ishape) > 1
+        inds = inds.reshape(ishape[0], -1)
+        used = self.used.to(inds)
+        if self.re_embed > self.used.shape[0]:  # extra token
+            inds[inds >= self.used.shape[0]] = 0  # simply set to zero
+        back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+        return back.reshape(ishape)
+
+    def forward(self, z):
+        # reshape z -> (batch, height, width, channel) and flatten
+        # z, 'b c h w -> b h w c'
+        z = rearrange(z, 'b c h w -> b h w c')
+        z_flattened = z.reshape(-1, self.codebook_dim)
+
+        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
+        d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+            self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+            torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)  # 'n d -> d n'
+
+
+        encoding_indices = torch.argmin(d, dim=1)
+
+        z_q = self.embedding(encoding_indices).view(z.shape)
+        encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+        avg_probs = torch.mean(encodings, dim=0)
+        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+        if self.training and self.embedding.update:
+            # EMA cluster size
+            encodings_sum = encodings.sum(0)
+            self.embedding.cluster_size_ema_update(encodings_sum)
+            # EMA embedding average
+            embed_sum = encodings.transpose(0, 1) @ z_flattened
+            self.embedding.embed_avg_ema_update(embed_sum)
+            # normalize embed_avg and update weight
+            self.embedding.weight_update(self.num_tokens)
+
+        # compute loss for embedding
+        loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+        # preserve gradients
+        z_q = z + (z_q - z).detach()
+
+        # reshape back to match original input shape
+        # z_q, 'b h w c -> b c h w'
+        z_q = rearrange(z_q, 'b h w c -> b c h w')
+        return z_q, loss, (perplexity, encodings, encoding_indices)
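
For quick sanity-checking of this patch, here is a minimal sketch of how the added EMAVectorQuantizer could be exercised on its own. The import path, batch size, codebook size, and embedding width are illustrative assumptions, not values taken from this diff.

import torch

# Assumed import path for the quantizer module; adjust to where this file lives in the repo.
from taming.modules.vqvae.quantize import EMAVectorQuantizer

# Hypothetical hyperparameters: 1024-entry codebook, 256-dim codes, commitment weight 0.25.
quantizer = EMAVectorQuantizer(n_embed=1024, embedding_dim=256, beta=0.25)
quantizer.train()  # EMA codebook updates only run in training mode

# Fake encoder output: (batch, channels, height, width).
z = torch.randn(2, 256, 16, 16)
z_q, loss, (perplexity, encodings, encoding_indices) = quantizer(z)

print(z_q.shape)          # torch.Size([2, 256, 16, 16]), same layout as the input
print(loss.item())        # commitment loss scaled by beta
print(perplexity.item())  # rough measure of how many codebook entries are in use

Because the codebook is updated by the EMA calls inside forward rather than by gradients, the returned loss carries only the commitment term; no separate codebook loss needs to be backpropagated.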