
Shape mismatch in dataloader when rasterization is not used. #854

Open
LucaMarconato opened this issue Jan 30, 2025 · 0 comments

Comments

@LucaMarconato
Member

Originally observed by @ilia-kats

When constructing an ImageTilesDataset without rasterization, a shape-mismatch bug can occur: the returned tiles are not all of the same size. The bug is triggered, for instance, when batching the tiles with a PyTorch DataLoader, because the collate function cannot stack tiles of different sizes and fails.
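To illustrate why collation fails, here is a minimal sketch that uses NumPy only (no spatialdata or torch). PyTorch's default collate function stacks the per-sample arrays into one batch, which, like `numpy.stack`, requires every array to have the same shape. The `pad_collate` helper is hypothetical: it zero-pads each tile to the largest height and width in the batch. It is one possible workaround, not the fix adopted in spatialdata, and it assumes tiles are channel-first `(c, y, x)` arrays with matching channel counts.

```python
import numpy as np


def pad_collate(tiles):
    """Hypothetical collate_fn: zero-pad (c, y, x) tiles to a common shape, then stack."""
    # Largest spatial extent in this batch; channel counts are assumed to match.
    max_y = max(t.shape[-2] for t in tiles)
    max_x = max(t.shape[-1] for t in tiles)
    padded = [
        # Pad only at the bottom/right of the y and x axes; np.pad defaults to zeros.
        np.pad(t, ((0, 0), (0, max_y - t.shape[-2]), (0, max_x - t.shape[-1])))
        for t in tiles
    ]
    return np.stack(padded)


# Two toy tiles that differ by one pixel in y, mimicking the mismatch.
tiles = [np.ones((3, 10, 10)), np.ones((3, 11, 10))]

try:
    np.stack(tiles)  # analogous to the default collate: shapes must match
except ValueError as err:
    print("stack failed:", err)

batch = pad_collate(tiles)
print(batch.shape)
```

With PyTorch, such a helper could be passed to `torch.utils.data.DataLoader` as `collate_fn=lambda b: torch.from_numpy(pad_collate(b))`. Padding keeps every pixel but changes the effective tile extent; cropping to the smallest tile, or rasterizing as in the commented-out lines of the reproduction, are alternatives.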

Below is code to reproduce the bug on one of the datasets from https://github.com/giovp/spatialdata-sandbox. Uncommenting the two rasterization lines shows that the bug only affects queries in the intrinsic coordinate system (i.e., "pixel space").

```python
import spatialdata as sd
import torch

sdata = sd.read_zarr("/Users/macbook/embl/projects/basel/spatialdata-sandbox/visium_2.1.0_1_io/data.zarr")
IMAGE_ELEMENT = "CytAssist_FFPE_Human_Colon_Post_Xenium_Rep1_hires_image"
SHAPES_ELEMENT = "CytAssist_FFPE_Human_Colon_Post_Xenium_Rep1"
COORDINATE_SYSTEM = "downscaled_hires"

# Alternative dataset:
# sdata = sd.read_zarr("/Users/macbook/Desktop/mousebrain.zarr")
# IMAGE_ELEMENT = "mousebrain_hires_image"
# SHAPES_ELEMENT = "mousebrain"
# COORDINATE_SYSTEM = "downscaled_hires"

dataset = sd.dataloader.ImageTilesDataset(
    sdata,
    {SHAPES_ELEMENT: IMAGE_ELEMENT},
    {SHAPES_ELEMENT: COORDINATE_SYSTEM},
    tile_scale=1.5,
    # Uncomment to rasterize the tiles; the bug then no longer occurs:
    # rasterize=True,
    # rasterize_kwargs={"target_width": 224},
)

# A single batch containing every tile, so the default collate must stack all of them.
dloader = torch.utils.data.DataLoader(
    [tile.images[IMAGE_ELEMENT].to_numpy() for tile in dataset],
    batch_size=sdata.shapes[SHAPES_ELEMENT].shape[0],
)

data = next(iter(dloader))  # fails without rasterization: the tiles differ in shape
print(data.shape)
```
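A quick way to confirm the mismatch before the collate step is to tally the tile shapes. The sketch below assumes each tile has already been converted to an array with a `.shape` attribute, as in the list comprehension passed to the DataLoader in the reproduction; `count_tile_shapes` is a hypothetical helper, and the toy arrays stand in for real tiles.

```python
from collections import Counter

import numpy as np


def count_tile_shapes(arrays):
    """Hypothetical helper: map each distinct array shape to the number of tiles with it."""
    return Counter(tuple(a.shape) for a in arrays)


# Toy stand-ins for tile arrays; with the real dataset one would pass
# [tile.images[IMAGE_ELEMENT].to_numpy() for tile in dataset].
arrays = [np.zeros((3, 42, 42)), np.zeros((3, 42, 42)), np.zeros((3, 43, 42))]
shapes = count_tile_shapes(arrays)
print(shapes)
if len(shapes) > 1:
    print("tiles differ in shape; default collation will fail")
```

More than one key in the result means the default collate function cannot build a single batch from these tiles.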