@@ -225,12 +225,12 @@ def setup(self):
225
225
dtype = self .dtype )
226
226
227
227
def __call__ (self , hidden_states , temb = None , deterministic : bool = True ):
228
- for res_block in self .block :
228
+ for i , res_block in enumerate ( self .block ) :
229
229
hidden_states = res_block (hidden_states ,
230
230
temb ,
231
231
deterministic = deterministic )
232
- for attn_block in self .attn :
233
- hidden_states = attn_block (hidden_states )
232
+ if self .attn :
233
+ hidden_states = self . attn [ i ] (hidden_states )
234
234
235
235
if self .upsample is not None :
236
236
hidden_states = self .upsample (hidden_states )
@@ -273,12 +273,12 @@ def setup(self):
273
273
dtype = self .dtype )
274
274
275
275
def __call__ (self , hidden_states , temb = None , deterministic : bool = True ):
276
- for res_block in self .block :
276
+ for i , res_block in enumerate ( self .block ) :
277
277
hidden_states = res_block (hidden_states ,
278
278
temb ,
279
279
deterministic = deterministic )
280
- for attn_block in self .attn :
281
- hidden_states = attn_block (hidden_states )
280
+ if self .attn :
281
+ hidden_states = self . attn [ i ] (hidden_states )
282
282
283
283
if self .downsample is not None :
284
284
hidden_states = self .downsample (hidden_states )
0 commit comments