Skip to content

Commit 39f8493

Browse files
committed
fix a bug, addressing #6 thanks to @MattMcPartlon
1 parent 2f3ea68 commit 39f8493

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,23 +194,19 @@ def forward(
194194
time = None
195195
):
196196
h = self.heads
197-
has_context = exists(context)
198197

199-
context = default(context, x)
200-
201-
if x.shape[-1] != self.norm.gamma.shape[-1]:
202-
print(context.shape, x.shape, self.norm.gamma.shape)
198+
if exists(context):
199+
context = self.norm_context(context)
203200

204201
x = self.norm(x)
205202

203+
context = default(context, x)
204+
206205
if exists(self.time_cond):
207206
assert exists(time)
208207
scale, shift = self.time_cond(time).chunk(2, dim = -1)
209208
x = (x * (scale + 1)) + shift
210209

211-
if has_context:
212-
context = self.norm_context(context)
213-
214210
qkv = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
215211
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
216212

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.4.6',
6+
version = '0.4.7',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)