Commit 943b85c

add dual patchnorm as an option
1 parent 2a49baa commit 943b85c

File tree

3 files changed: +21 −3 lines changed


README.md

Lines changed: 12 additions & 0 deletions

````diff
@@ -129,3 +129,15 @@ sampled_images.shape # (4, 3, 128, 128)
     volume = {abs/2202.00512}
 }
 ```
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2302.01327,
+    doi = {10.48550/ARXIV.2302.01327},
+    url = {https://arxiv.org/abs/2302.01327},
+    author = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
+    title = {Dual PatchNorm},
+    publisher = {arXiv},
+    year = {2023},
+    copyright = {Creative Commons Attribution 4.0 International}
+}
+```
````

rin_pytorch/rin_pytorch.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -64,6 +64,9 @@ def convert_image_to(img_type, image):
         return image.convert(img_type)
     return image
 
+def Sequential(*mods):
+    return nn.Sequential(*filter(exists, mods))
+
 # use layernorm without bias, more stable
 
 class LayerNorm(nn.Module):
@@ -347,6 +350,7 @@ def __init__(
         num_latents = 256, # they still had to use a fair amount of latents for good results (256), in line with the Perceiver line of papers from Deepmind
         learned_sinusoidal_dim = 16,
         latent_token_time_cond = False, # whether to use 1 latent token as time conditioning, or do it the adaptive layernorm way (which is highly effective as shown by some other papers "Paella" - Dominic Rampas et al.)
+        dual_patchnorm = True,
         **attn_kwargs
     ):
         super().__init__()
@@ -378,9 +382,11 @@ def __init__(
 
         # pixels to patch and back
 
-        self.to_patches = nn.Sequential(
+        self.to_patches = Sequential(
             Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
-            nn.Linear(pixel_patch_dim * 2, dim)
+            nn.LayerNorm(pixel_patch_dim * 2) if dual_patchnorm else None,
+            nn.Linear(pixel_patch_dim * 2, dim),
+            nn.LayerNorm(dim) if dual_patchnorm else None,
         )
 
         self.axial_pos_emb = nn.Parameter(torch.randn(2, patch_height_width, dim) * 0.02)
```
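
For clarity, below is a minimal, self-contained sketch of the dual-patchnorm patch embedding this diff produces: a `LayerNorm` before and after the linear patch projection, with the `Sequential` helper silently dropping `None` entries when the option is disabled. The hyperparameters, the `exists` helper, and the random input are illustrative assumptions, not part of the commit; the factor of 2 on `pixel_patch_dim` mirrors the existing code (presumably because the image is concatenated with a self-conditioning prediction before patching).

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

def exists(val):
    return val is not None

# Sequential that drops None entries, so optional layers can be toggled inline
def Sequential(*mods):
    return nn.Sequential(*filter(exists, mods))

# illustrative hyperparameters (not taken from the commit)
channels, patch_size, dim = 3, 8, 256
dual_patchnorm = True

pixel_patch_dim = channels * patch_size ** 2

# dual patchnorm: LayerNorm before and after the patch projection
to_patches = Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1 = patch_size, p2 = patch_size),
    nn.LayerNorm(pixel_patch_dim * 2) if dual_patchnorm else None,  # pre-projection norm
    nn.Linear(pixel_patch_dim * 2, dim),
    nn.LayerNorm(dim) if dual_patchnorm else None,                  # post-projection norm
)

# doubled channel count stands in for the image concatenated with its self-conditioning
x = torch.randn(1, channels * 2, 64, 64)
print(to_patches(x).shape)  # torch.Size([1, 64, 256])
```

With `dual_patchnorm = False`, the two `LayerNorm` entries become `None` and are filtered out, recovering the previous rearrange-then-linear embedding; note the commit defaults the flag to `True`.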

setup.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.5.2',
+  version = '0.5.3',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
```
