Skip to content

Commit 0cf3395

Browse files
Fix batch size above 1 giving bad output in chroma radiance. (Comfy-Org#10394)
1 parent 5b80add commit 0cf3395

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

comfy/ldm/chroma_radiance/model.py

Lines changed: 7 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -189,15 +189,15 @@ def forward_nerf(
189 189
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
190 190
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
191 191

192+
# Reshape for per-patch processing
193+
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
194+
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
195+
192 196
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
193 197
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
194 198
# the tile size.
195-
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
199+
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
196 200
else:
197-
# Reshape for per-patch processing
198-
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
199-
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
200-
201 201
# Get DCT-encoded pixel embeddings [pixel-dct]
202 202
img_dct = self.nerf_image_embedder(nerf_pixels)
203 203

@@ -240,17 +240,8 @@ def forward_tiled_nerf(
240 240
end = min(i + tile_size, num_patches)
241 241

242 242
# Slice the current tile from the input tensors
243-
nerf_hidden_tile = nerf_hidden[:, i:end, :]
244-
nerf_pixels_tile = nerf_pixels[:, i:end, :]
245-
246-
# Get the actual number of patches in this tile (can be smaller for the last tile)
247-
num_patches_tile = nerf_hidden_tile.shape[1]
248-
249-
# Reshape the tile for per-patch processing
250-
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
251-
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
252-
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
253-
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
243+
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
244+
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
254 245

255 246
# get DCT-encoded pixel embeddings [pixel-dct]
256 247
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)

0 commit comments

Comments
 (0)