
Commit 8acc6b7

fix: decoder attention

Attention layers in the decoder were not applied at the correct step.

1 parent 1be20ee

File tree: 1 file changed (+3, -3)


vqgan_jax/modeling_flax_vqgan.py (+3, -3)
```diff
@@ -225,12 +225,12 @@ def setup(self):
                               dtype=self.dtype)

     def __call__(self, hidden_states, temb=None, deterministic: bool = True):
-        for res_block in self.block:
+        for i, res_block in enumerate(self.block):
             hidden_states = res_block(hidden_states,
                                       temb,
                                       deterministic=deterministic)
-        for attn_block in self.attn:
-            hidden_states = attn_block(hidden_states)
+            if self.attn:
+                hidden_states = self.attn[i](hidden_states)

         if self.upsample is not None:
             hidden_states = self.upsample(hidden_states)
```
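The ordering bug can be sketched with plain Python stand-ins for the Flax modules (the function and layer names below are hypothetical, not from the repository): before the fix, every res block ran first and every attention block ran afterwards in a second loop; after the fix, the i-th attention block runs immediately after the i-th res block.

```python
# Hypothetical, minimal sketch of the ordering bug this commit fixes
# (plain Python callables standing in for the real Flax modules).

def forward_buggy(x, res_blocks, attn_blocks):
    # Before the fix: all res blocks run, then all attention blocks run,
    # so attention only ever sees the output of the final res block.
    for res_block in res_blocks:
        x = res_block(x)
    for attn_block in attn_blocks:
        x = attn_block(x)
    return x

def forward_fixed(x, res_blocks, attn_blocks):
    # After the fix: attention (when present) is interleaved, one
    # attention block right after its matching res block.
    for i, res_block in enumerate(res_blocks):
        x = res_block(x)
        if attn_blocks:
            x = attn_blocks[i](x)
    return x

def make_tracer(name, trace):
    # Identity "layer" that records its name so call order is observable.
    def layer(x):
        trace.append(name)
        return x
    return layer

trace_buggy, trace_fixed = [], []
forward_buggy(0,
              [make_tracer(n, trace_buggy) for n in ("res0", "res1")],
              [make_tracer(n, trace_buggy) for n in ("attn0", "attn1")])
forward_fixed(0,
              [make_tracer(n, trace_fixed) for n in ("res0", "res1")],
              [make_tracer(n, trace_fixed) for n in ("attn0", "attn1")])

print(trace_buggy)  # ['res0', 'res1', 'attn0', 'attn1']
print(trace_fixed)  # ['res0', 'attn0', 'res1', 'attn1']
```

The interleaved order is what the surrounding architecture expects: each residual stage is refined by its own attention layer before the next stage (and any upsampling) runs.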
