Skip to content

Commit 10ef240

Browse files
authored
Merge pull request #10 from borisdayma/patch-1
fix: encoder/decoder attention
2 parents 1be20ee + caafeee commit 10ef240

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

vqgan_jax/modeling_flax_vqgan.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,12 @@ def setup(self):
225225
dtype=self.dtype)
226226

227227
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
228-
for res_block in self.block:
228+
for i, res_block in enumerate(self.block):
229229
hidden_states = res_block(hidden_states,
230230
temb,
231231
deterministic=deterministic)
232-
for attn_block in self.attn:
233-
hidden_states = attn_block(hidden_states)
232+
if self.attn:
233+
hidden_states = self.attn[i](hidden_states)
234234

235235
if self.upsample is not None:
236236
hidden_states = self.upsample(hidden_states)
@@ -273,12 +273,12 @@ def setup(self):
273273
dtype=self.dtype)
274274

275275
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
276-
for res_block in self.block:
276+
for i, res_block in enumerate(self.block):
277277
hidden_states = res_block(hidden_states,
278278
temb,
279279
deterministic=deterministic)
280-
for attn_block in self.attn:
281-
hidden_states = attn_block(hidden_states)
280+
if self.attn:
281+
hidden_states = self.attn[i](hidden_states)
282282

283283
if self.downsample is not None:
284284
hidden_states = self.downsample(hidden_states)

0 commit comments

Comments
 (0)