Skip to content

Commit caafeee

Browse files
committed
fix: encoder attn
1 parent 8acc6b7 commit caafeee

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

vqgan_jax/modeling_flax_vqgan.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -273,12 +273,12 @@ def setup(self):
273273
dtype=self.dtype)
274274

275275
def __call__(self, hidden_states, temb=None, deterministic: bool = True):
276-
for res_block in self.block:
276+
for i, res_block in enumerate(self.block):
277277
hidden_states = res_block(hidden_states,
278278
temb,
279279
deterministic=deterministic)
280-
for attn_block in self.attn:
281-
hidden_states = attn_block(hidden_states)
280+
if self.attn:
281+
hidden_states = self.attn[i](hidden_states)
282282

283283
if self.downsample is not None:
284284
hidden_states = self.downsample(hidden_states)

0 commit comments

Comments
 (0)