
Commit 34caeca (1 parent: 2168826)

also bring in the qk norm used in simple diffusion (and in a lot of vision models at Brain) for potential stability, but make it an option, as still not sure if it hurts eval
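For readers unfamiliar with the technique: qk norm RMS-normalizes the per-head query and key vectors (with a learned per-head scale) right before the attention weights are computed, which bounds the magnitude of the attention logits and can help training stability. The sketch below mirrors the MultiHeadedRMSNorm module added in this commit; the surrounding attention code, tensor names, and toy shapes are illustrative only, not the library's actual forward pass.

```python
# a minimal, self-contained sketch of qk norm, mirroring the MultiHeadedRMSNorm
# added in this commit; the attention below and the toy shapes are illustrative
# only, not the library's actual forward pass

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class MultiHeadedRMSNorm(nn.Module):
    def __init__(self, dim, heads = 1):
        super().__init__()
        self.scale = dim ** 0.5                               # restore magnitude after unit-normalizing
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))  # learned per-head scale

    def forward(self, x):
        # x: (batch, heads, seq, dim_head); gamma broadcasts over batch and seq
        return F.normalize(x, dim = -1) * self.scale * self.gamma

heads, dim_head = 4, 32
q_norm = MultiHeadedRMSNorm(dim_head, heads)
k_norm = MultiHeadedRMSNorm(dim_head, heads)

q, k, v = (torch.randn(2, heads, 128, dim_head) for _ in range(3))

# normalize queries and keys before computing attention scores
q, k = q_norm(q), k_norm(k)

attn = (q @ k.transpose(-2, -1) / dim_head ** 0.5).softmax(dim = -1)
out  = rearrange(attn @ v, 'b h n d -> b n (h d)')  # (2, 128, heads * dim_head)
```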

3 files changed: 39 additions, 2 deletions

README.md (8 additions, 0 deletions)

@@ -158,3 +158,11 @@ sampled_images.shape # (4, 3, 128, 128)
     year   = {2022}
 }
 ```
+
+```bibtex
+@inproceedings{Hoogeboom2023simpleDE,
+    title  = {simple diffusion: End-to-end diffusion for high resolution images},
+    author = {Emiel Hoogeboom and Jonathan Heek and Tim Salimans},
+    year   = {2023}
+}
+```

rin_pytorch/rin_pytorch.py (30 additions, 1 deletion)

@@ -81,6 +81,15 @@ def __init__(self, dim):
     def forward(self, x):
         return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
+class MultiHeadedRMSNorm(nn.Module):
+    def __init__(self, dim, heads = 1):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))
+
+    def forward(self, x):
+        return F.normalize(x, dim = -1) * self.scale * self.gamma
+
 # positional embeds
 
 class LearnedSinusoidalPosEmb(nn.Module):
@@ -104,6 +113,7 @@ def __init__(
         heads = 4,
         dim_head = 32,
         norm = False,
+        qk_norm = False,
         time_cond_dim = None
     ):
         super().__init__()
@@ -127,6 +137,11 @@ def __init__(
 
         self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias = False)
 
+        self.qk_norm = qk_norm
+        if qk_norm:
+            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
+            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)
+
         self.to_out = nn.Sequential(
             nn.Linear(hidden_dim, dim, bias = False),
             LayerNorm(dim)
@@ -148,6 +163,10 @@ def forward(
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
+        if self.qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         q = q.softmax(dim = -1)
         k = k.softmax(dim = -2)
 
@@ -169,7 +188,8 @@ def __init__(
         norm = False,
         norm_context = False,
         time_cond_dim = None,
-        flash = False
+        flash = False,
+        qk_norm = False
     ):
         super().__init__()
         hidden_dim = dim_head * heads
@@ -197,6 +217,11 @@ def __init__(
         self.to_kv = nn.Linear(dim_context, hidden_dim * 2, bias = False)
         self.to_out = nn.Linear(hidden_dim, dim, bias = False)
 
+        self.qk_norm = qk_norm
+        if qk_norm:
+            self.q_norm = MultiHeadedRMSNorm(dim_head, heads)
+            self.k_norm = MultiHeadedRMSNorm(dim_head, heads)
+
         self.attend = Attend(flash = flash)
 
     def forward(
@@ -222,6 +247,10 @@ def forward(
         qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
+        if self.qk_norm:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
+
         out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
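Since qk_norm defaults to False in both attention variants, existing configurations behave exactly as before; per the commit message it is kept optional because its effect on eval is still unverified. A hypothetical way to opt in, assuming the two modules touched by this diff are named LinearAttention and Attention and can be imported directly from rin_pytorch.rin_pytorch (only constructor arguments visible in this diff are used):

```python
# hypothetical opt-in; the module names and import path are assumptions,
# only the constructor kwargs shown in this commit's diff are used
from rin_pytorch.rin_pytorch import LinearAttention, Attention

linear_attn = LinearAttention(dim = 256, heads = 4, dim_head = 32, qk_norm = True)
cross_attn  = Attention(dim = 256, heads = 4, dim_head = 32, flash = True, qk_norm = True)
```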

setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'RIN-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.7.8',
+  version = '0.7.9',
   license='MIT',
   description = 'RIN - Recurrent Interface Network - Pytorch',
   author = 'Phil Wang',
