Skip to content

Commit

Permalink
Minor fix to argument positions in MitoEM dataset (#478)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 authored Jan 24, 2025
1 parent 5b069cc commit e0b2356
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions torch_em/data/datasets/electron_microscopy/mitoem.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,6 @@ def get_mitoem_data(path: Union[os.PathLike, str], samples: Sequence[str], split
splits: The data splits to download. The available splits are 'train', 'val' and 'test'.
download: Whether to download the data if it is not present.
"""
if isinstance(splits, str):
splits = [splits]
assert len(set(splits) - {"train", "val"}) == 0, f"{splits}"
assert len(set(samples) - {"human", "rat"}) == 0, f"{samples}"
os.makedirs(path, exist_ok=True)
Expand Down Expand Up @@ -181,8 +179,15 @@ def get_mitoem_paths(
Returns:
The filepaths for the stored data.
"""
if isinstance(splits, str):
splits = [splits]

if isinstance(samples, str):
samples = [samples]

get_mitoem_data(path, samples, splits, download)
data_paths = [os.path.join(path, f"{sample}_{split}.n5") for split in splits for sample in samples]

return data_paths


Expand Down Expand Up @@ -215,7 +220,7 @@ def get_mitoem_dataset(
"""
assert len(patch_shape) == 3

data_paths = get_mitoem_paths(path, samples, splits, download)
data_paths = get_mitoem_paths(path, splits, samples, download)

kwargs, _ = util.add_instance_label_transform(
kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
Expand Down Expand Up @@ -261,8 +266,5 @@ def get_mitoem_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_mitoem_dataset(
path, splits, patch_shape, samples=samples, download=download,
offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
)
dataset = get_mitoem_dataset(path, splits, patch_shape, samples, download, offsets, boundaries, binary, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)

0 comments on commit e0b2356

Please sign in to comment.