@@ -189,15 +189,15 @@ def forward_nerf(
189189 nerf_pixels = nn .functional .unfold (img_orig , kernel_size = patch_size , stride = patch_size )
190190 nerf_pixels = nerf_pixels .transpose (1 , 2 ) # -> [B, NumPatches, C * P * P]
191191
192+ # Reshape for per-patch processing
193+ nerf_hidden = img_out .reshape (B * num_patches , params .hidden_size )
194+ nerf_pixels = nerf_pixels .reshape (B * num_patches , C , patch_size ** 2 ).transpose (1 , 2 )
195+
192196 if params .nerf_tile_size > 0 and num_patches > params .nerf_tile_size :
193197 # Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
194198 # the tile size.
195- img_dct = self .forward_tiled_nerf (img_out , nerf_pixels , B , C , num_patches , patch_size , params )
199+ img_dct = self .forward_tiled_nerf (nerf_hidden , nerf_pixels , B , C , num_patches , patch_size , params )
196200 else :
197- # Reshape for per-patch processing
198- nerf_hidden = img_out .reshape (B * num_patches , params .hidden_size )
199- nerf_pixels = nerf_pixels .reshape (B * num_patches , C , patch_size ** 2 ).transpose (1 , 2 )
200-
201201 # Get DCT-encoded pixel embeddings [pixel-dct]
202202 img_dct = self .nerf_image_embedder (nerf_pixels )
203203
@@ -240,17 +240,8 @@ def forward_tiled_nerf(
240240 end = min (i + tile_size , num_patches )
241241
242242 # Slice the current tile from the input tensors
243- nerf_hidden_tile = nerf_hidden [:, i :end , :]
244- nerf_pixels_tile = nerf_pixels [:, i :end , :]
245-
246- # Get the actual number of patches in this tile (can be smaller for the last tile)
247- num_patches_tile = nerf_hidden_tile .shape [1 ]
248-
249- # Reshape the tile for per-patch processing
250- # [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
251- nerf_hidden_tile = nerf_hidden_tile .reshape (batch * num_patches_tile , params .hidden_size )
252- # [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
253- nerf_pixels_tile = nerf_pixels_tile .reshape (batch * num_patches_tile , channels , patch_size ** 2 ).transpose (1 , 2 )
243+ nerf_hidden_tile = nerf_hidden [i * batch :end * batch ]
244+ nerf_pixels_tile = nerf_pixels [i * batch :end * batch ]
254245
255246 # get DCT-encoded pixel embeddings [pixel-dct]
256247 img_dct_tile = self .nerf_image_embedder (nerf_pixels_tile )
0 commit comments