From 0b2b01ac8fabac49db6e0ad2e158f7ef1711dc81 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 31 Dec 2024 11:40:51 +0100 Subject: [PATCH] Minor refactor to LM datasets (#459) Updates to some LM datasets --- .../light_microscopy/check_bitdepth_nucseg.py | 3 +- .../light_microscopy/check_cellpose.py | 4 +- .../datasets/light_microscopy/check_ctc.py | 2 +- .../light_microscopy/check_cvz_fluo.py | 4 +- .../light_microscopy/check_deepbacs.py | 6 +- .../check_dynamicnuclearnet.py | 19 ++++--- .../light_microscopy/check_embedseg.py | 11 ++-- .../light_microscopy/check_ifnuclei.py | 1 + .../light_microscopy/check_orgasegment.py | 19 ++++--- .../light_microscopy/check_tissuenet.py | 3 +- .../datasets/light_microscopy/check_yeaz.py | 16 +++--- .../light_microscopy/bitdepth_nucseg.py | 2 + .../datasets/light_microscopy/cellpose.py | 46 +++++++-------- .../data/datasets/light_microscopy/ctc.py | 6 +- .../data/datasets/light_microscopy/dsb.py | 1 - .../light_microscopy/dynamicnuclearnet.py | 22 ++++--- .../datasets/light_microscopy/ifnuclei.py | 2 + .../datasets/light_microscopy/omnipose.py | 8 +-- .../datasets/light_microscopy/orgasegment.py | 33 ++++------- .../data/datasets/light_microscopy/yeaz.py | 57 +++++++++++++++---- 20 files changed, 151 insertions(+), 114 deletions(-) diff --git a/scripts/datasets/light_microscopy/check_bitdepth_nucseg.py b/scripts/datasets/light_microscopy/check_bitdepth_nucseg.py index c09d28bb..9cbbfe03 100644 --- a/scripts/datasets/light_microscopy/check_bitdepth_nucseg.py +++ b/scripts/datasets/light_microscopy/check_bitdepth_nucseg.py @@ -13,7 +13,8 @@ def check_bitdepth_nucseg(): loader = get_bitdepth_nucseg_loader( path=os.path.join(ROOT, "bitdepth_nucseg"), patch_shape=(512, 512), - batch_size=1, + batch_size=2, + magnification=None, download=True, ) diff --git a/scripts/datasets/light_microscopy/check_cellpose.py b/scripts/datasets/light_microscopy/check_cellpose.py index 2001cca9..982eb173 100644 --- a/scripts/datasets/light_microscopy/check_cellpose.py +++ b/scripts/datasets/light_microscopy/check_cellpose.py @@ -15,7 +15,9 @@ def check_cellpose(): split="train", patch_shape=(512, 512), batch_size=1, - choice=None, + choice="cyto", + download=True, + shuffle=True, ) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/light_microscopy/check_ctc.py b/scripts/datasets/light_microscopy/check_ctc.py index c94a91b9..d11bde20 100644 --- a/scripts/datasets/light_microscopy/check_ctc.py +++ b/scripts/datasets/light_microscopy/check_ctc.py @@ -1,9 +1,9 @@ import os import sys -from torch_em.data.datasets.light_microscopy.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS from torch_em.util.debug import check_loader from torch_em.data.sampler import MinInstanceSampler +from torch_em.data.datasets.light_microscopy.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS sys.path.append("..") diff --git a/scripts/datasets/light_microscopy/check_cvz_fluo.py b/scripts/datasets/light_microscopy/check_cvz_fluo.py index c1960041..750c5004 100644 --- a/scripts/datasets/light_microscopy/check_cvz_fluo.py +++ b/scripts/datasets/light_microscopy/check_cvz_fluo.py @@ -15,11 +15,11 @@ def check_cvz_fluo(): path=os.path.join(ROOT, "cvz"), patch_shape=(512, 512), batch_size=2, - stain_choice="cell", + stain_choice="dapi", data_choice=None, ) - check_loader(loader, 8, instance_labels=True, plt=True, save_path="./test.png", rgb=True) + check_loader(loader, 8, instance_labels=True, rgb=True) if __name__ == "__main__": diff --git a/scripts/datasets/light_microscopy/check_deepbacs.py b/scripts/datasets/light_microscopy/check_deepbacs.py index fd5bab8a..89254233 100644 --- a/scripts/datasets/light_microscopy/check_deepbacs.py +++ b/scripts/datasets/light_microscopy/check_deepbacs.py @@ -10,8 +10,10 @@ def check_deepbacs(): from util import ROOT - loader = get_deepbacs_loader(os.path.join(ROOT, "deepbacs"), "test", bac_type="mixed", download=True, - patch_shape=(256, 256), batch_size=1, shuffle=True) + loader = get_deepbacs_loader( + os.path.join(ROOT, "deepbacs"), "test", bac_type="mixed", + download=True, patch_shape=(256, 256), batch_size=1, shuffle=True + ) check_loader(loader, 15, instance_labels=True) diff --git a/scripts/datasets/light_microscopy/check_dynamicnuclearnet.py b/scripts/datasets/light_microscopy/check_dynamicnuclearnet.py index 1b4c1256..8018f266 100644 --- a/scripts/datasets/light_microscopy/check_dynamicnuclearnet.py +++ b/scripts/datasets/light_microscopy/check_dynamicnuclearnet.py @@ -1,19 +1,24 @@ +import os +import sys + from torch_em.util.debug import check_loader from torch_em.data.datasets import get_dynamicnuclearnet_loader -DYNAMICNUCLEARNET_ROOT = "/home/anwai/data/deepcell/" +sys.path.append("..") -# NOTE: the DynamicNuclearNet data cannot be downloaded automatically. -# you need to download it yourself from https://datasets.deepcell.org/data def check_dynamicnuclearnet(): - # set this path to where you have downloaded the dynamicnuclearnet data + from util import ROOT + loader = get_dynamicnuclearnet_loader( - DYNAMICNUCLEARNET_ROOT, "train", - patch_shape=(512, 512), batch_size=2, download=True + path=os.path.join(ROOT, "dynamicnuclearnet"), + split="train", + patch_shape=(512, 512), + batch_size=2, + download=True, ) - check_loader(loader, 10, instance_labels=True, rgb=False) + check_loader(loader, 8, instance_labels=True) if __name__ == "__main__": diff --git a/scripts/datasets/light_microscopy/check_embedseg.py b/scripts/datasets/light_microscopy/check_embedseg.py index b90aeb71..a8eb5723 100644 --- a/scripts/datasets/light_microscopy/check_embedseg.py +++ b/scripts/datasets/light_microscopy/check_embedseg.py @@ -1,8 +1,8 @@ import os import sys -from torch_em.data.datasets import get_embedseg_loader from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_embedseg_loader sys.path.append("..") @@ -17,12 +17,15 @@ def check_embedseg(): "Platynereis-Nuclei-CBG", ] - patch_shape = (32, 384, 384) for name in names: loader = get_embedseg_loader( - os.path.join(ROOT, "embedseg"), name=name, patch_shape=patch_shape, batch_size=1, download=True + path=os.path.join(ROOT, "embedseg"), + name=name, + patch_shape=(32, 384, 384), + batch_size=2, + download=True, ) - check_loader(loader, 2, instance_labels=True) + check_loader(loader, 8, instance_labels=True) if __name__ == "__main__": diff --git a/scripts/datasets/light_microscopy/check_ifnuclei.py b/scripts/datasets/light_microscopy/check_ifnuclei.py index e41a772c..8b0c68e5 100644 --- a/scripts/datasets/light_microscopy/check_ifnuclei.py +++ b/scripts/datasets/light_microscopy/check_ifnuclei.py @@ -15,6 +15,7 @@ def check_ifnuclei(): patch_shape=(512, 512), batch_size=1, download=True, + shuffle=True, ) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/light_microscopy/check_orgasegment.py b/scripts/datasets/light_microscopy/check_orgasegment.py index c49d45a3..7dd43613 100644 --- a/scripts/datasets/light_microscopy/check_orgasegment.py +++ b/scripts/datasets/light_microscopy/check_orgasegment.py @@ -1,24 +1,25 @@ +import os +import sys + from torch_em.util.debug import check_loader from torch_em.data.datasets.light_microscopy import get_orgasegment_loader -ROOT = "/media/anwai/ANWAI/data/orgasegment" +sys.path.append("..") def check_orgasegment(): + from util import ROOT + loader = get_orgasegment_loader( - path=ROOT, - split="val", + path=os.path.join(ROOT, "orgasegment"), + split="train", patch_shape=(512, 512), - batch_size=1, + batch_size=2, download=True, ) check_loader(loader, 8, instance_labels=True) -def main(): - check_orgasegment() - - if __name__ == "__main__": - main() + check_orgasegment() diff --git a/scripts/datasets/light_microscopy/check_tissuenet.py b/scripts/datasets/light_microscopy/check_tissuenet.py index 2f70465b..b15e8d6e 100644 --- a/scripts/datasets/light_microscopy/check_tissuenet.py +++ b/scripts/datasets/light_microscopy/check_tissuenet.py @@ -4,8 +4,9 @@ import numpy as np from torch_em.transform.raw import standardize, normalize_percentile -from torch_em.data.datasets import get_tissuenet_loader from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_tissuenet_loader + sys.path.append("..") diff --git a/scripts/datasets/light_microscopy/check_yeaz.py b/scripts/datasets/light_microscopy/check_yeaz.py index fd7ae775..31b97986 100644 --- a/scripts/datasets/light_microscopy/check_yeaz.py +++ b/scripts/datasets/light_microscopy/check_yeaz.py @@ -1,6 +1,7 @@ import os import sys +from torch_em.data import MinInstanceSampler from torch_em.util.debug import check_loader from torch_em.data.datasets import get_yeaz_loader @@ -11,19 +12,16 @@ def check_yeaz(): from util import ROOT - choice = "phc" # choose from 'bf' / 'phc' - if choice == "bf": - patch_shape, ndim = (512, 512), 2 - else: - patch_shape, ndim = (1, 512, 512), 3 - loader = get_yeaz_loader( path=os.path.join(ROOT, "yeaz"), batch_size=2, - patch_shape=patch_shape, - choice=choice, - ndim=ndim, + patch_shape=(512, 512), + ndim=2, + choice="phc", # choose from 'bf' / 'phc' + split="val", + sampler=MinInstanceSampler(), download=False, + shuffle=True, ) check_loader(loader, 8, instance_labels=True) diff --git a/torch_em/data/datasets/light_microscopy/bitdepth_nucseg.py b/torch_em/data/datasets/light_microscopy/bitdepth_nucseg.py index d273aa72..6459f4a2 100644 --- a/torch_em/data/datasets/light_microscopy/bitdepth_nucseg.py +++ b/torch_em/data/datasets/light_microscopy/bitdepth_nucseg.py @@ -93,6 +93,8 @@ def get_bitdepth_nucseg_paths( raw_paths = natsorted(glob(os.path.join(data_dir, magnification, "images_16bit", "*.tif"))) label_paths = natsorted(glob(os.path.join(data_dir, magnification, "label masks", "*.tif"))) + assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0 + return raw_paths, label_paths diff --git a/torch_em/data/datasets/light_microscopy/cellpose.py b/torch_em/data/datasets/light_microscopy/cellpose.py index dc1ca440..cde982ae 100644 --- a/torch_em/data/datasets/light_microscopy/cellpose.py +++ b/torch_em/data/datasets/light_microscopy/cellpose.py @@ -1,5 +1,4 @@ -"""This dataset contains annotation for cell segmentation in -fluorescene microscently-labeled microscopy images. +"""This dataset contains annotation for cell segmentation in fluorescene microscently-labeled microscopy images. This dataset is from the following publications: - https://doi.org/10.1038/s41592-020-01018-x @@ -40,14 +39,8 @@ def get_cellpose_data( download: Whether to download the data if it is not present. Returns: - The filepath to the data. + The filepath to the folder where the data is manually downloaded. """ - if download: - raise NotImplementedError( - "The dataset cannot be automatically downloaded. " - "Please see 'get_cellpose_data' in 'torch_em/data/datasets/light_microscopy/cellpose.py' for details." - ) - per_choice_dir = os.path.join(path, choice) # path where the unzipped files will be stored if choice == "cyto": assert split in ["train", "test"], f"'{split}' is not a valid split in '{choice}'." @@ -60,7 +53,14 @@ def get_cellpose_data( else: raise ValueError(f"'{choice}' is not a valid dataset choice.") - if not os.path.exists(data_dir): + if os.path.exists(data_dir): + return data_dir + else: + if not os.path.exists(zip_path) and download: + raise NotImplementedError( + "The dataset cannot be automatically downloaded. " + "Please see 'get_cellpose_data' in 'torch_em/data/datasets/light_microscopy/cellpose.py' for details." + ) util.unzip(zip_path=zip_path, dst=per_choice_dir, remove=False) return data_dir @@ -69,7 +69,7 @@ def get_cellpose_data( def get_cellpose_paths( path: Union[os.PathLike, str], split: Literal['train', 'test'], - choice: Literal["cyto", "cyto2"], + choice: Optional[Literal["cyto", "cyto2"]] = None, download: bool = False, ) -> Tuple[List[str], List[str]]: """Get paths to the CellPose data. @@ -84,18 +84,20 @@ def get_cellpose_paths( List of filepaths for the image data. List of filepaths for the label data. """ - data_dir = get_cellpose_data(path=path, split=split, choice=choice, download=download) + data_dir = get_cellpose_data(path, split, choice, download) image_paths = natsorted(glob(os.path.join(data_dir, "*_img.png"))) gt_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png"))) + assert len(image_paths) == len(gt_paths) and len(image_paths) > 0 + return image_paths, gt_paths def get_cellpose_dataset( path: Union[os.PathLike, str], - split: Literal["train", "test"], patch_shape: Tuple[int, int], + split: Literal["train", "test"], choice: Optional[Literal["cyto", "cyto2"]] = None, download: bool = False, **kwargs @@ -104,8 +106,8 @@ def get_cellpose_dataset( Args: path: Filepath to a folder where the downloaded data will be saved. - split: The data split to use. Either 'train', or 'test'. patch_shape: The patch shape to use for training. + split: The data split to use. Either 'train', or 'test'. choice: The choice of dataset. Either 'cyto' or 'cyto2'. download: Whether to download the data if it is not present. kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. @@ -124,7 +126,7 @@ def get_cellpose_dataset( image_paths, gt_paths = [], [] for per_choice in choice: assert per_choice in AVAILABLE_CHOICES - per_image_paths, per_gt_paths = get_cellpose_paths(path, split, choice, download) + per_image_paths, per_gt_paths = get_cellpose_paths(path, split, per_choice, download) image_paths.extend(per_image_paths) gt_paths.extend(per_gt_paths) @@ -148,9 +150,9 @@ def get_cellpose_dataset( def get_cellpose_loader( path: Union[os.PathLike, str], - split: Literal["train", "test"], - patch_shape: Tuple[int, int], batch_size: int, + patch_shape: Tuple[int, int], + split: Literal["train", "test"], choice: Optional[Literal["cyto", "cyto2"]] = None, download: bool = False, **kwargs @@ -159,9 +161,9 @@ def get_cellpose_loader( Args: path: Filepath to a folder where the downloaded data will be saved. - split: The data split to use. Either 'train', or 'test'. - patch_shape: The patch shape to use for training. batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + split: The data split to use. Either 'train', or 'test'. choice: The choice of dataset. Either 'cyto' or 'cyto2'. download: Whether to download the data if it is not present. kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. @@ -170,7 +172,5 @@ def get_cellpose_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_cellpose_dataset( - path=path, split=split, patch_shape=patch_shape, choice=choice, download=download, **ds_kwargs - ) - return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + dataset = get_cellpose_dataset(path, patch_shape, split, choice, download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/ctc.py b/torch_em/data/datasets/light_microscopy/ctc.py index d6ed6e16..bec6628b 100644 --- a/torch_em/data/datasets/light_microscopy/ctc.py +++ b/torch_em/data/datasets/light_microscopy/ctc.py @@ -141,7 +141,7 @@ def get_ctc_segmentation_paths( Filepath to the folder where image data is stored. Filepath to the folder where label data is stored. """ - data_path = get_ctc_segmentation_data(path, dataset_name, download, split) + data_path = get_ctc_segmentation_data(path, dataset_name, split, download) if vol_id is None: vol_ids = glob(os.path.join(data_path, "*_GT")) @@ -222,7 +222,5 @@ def get_ctc_segmentation_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_ctc_segmentation_dataset( - path, dataset_name, patch_shape, split=split, vol_id=vol_id, download=download, **ds_kwargs, - ) + dataset = get_ctc_segmentation_dataset(path, dataset_name, patch_shape, split, vol_id, download, **ds_kwargs) return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dsb.py b/torch_em/data/datasets/light_microscopy/dsb.py index a776c448..f775d55f 100644 --- a/torch_em/data/datasets/light_microscopy/dsb.py +++ b/torch_em/data/datasets/light_microscopy/dsb.py @@ -152,7 +152,6 @@ def get_dsb_paths( raw_paths = natsorted(glob(os.path.join(path, "full", "*", "images", f"{domain}_*.png"))) label_paths = natsorted(glob(os.path.join(path, "full", "*", "preprocessed_labels", f"{domain}_*.tif"))) - print(len(raw_paths), len(label_paths)) assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0 return raw_paths, label_paths diff --git a/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py b/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py index e88fa085..8fd1031b 100644 --- a/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py +++ b/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py @@ -11,7 +11,7 @@ import os from tqdm import tqdm from glob import glob -from typing import Tuple, Union, List +from typing import Tuple, Union, Literal, List import numpy as np import pandas as pd @@ -58,9 +58,7 @@ def _create_dataset(path, zip_path): def get_dynamicnuclearnet_data( - path: Union[os.PathLike, str], - split: str, - download: bool = False, + path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False, ) -> str: """Download the DynamicNuclearNet dataset. @@ -76,7 +74,7 @@ def get_dynamicnuclearnet_data( The path where inputs are stored per split. """ splits = ["train", "val", "test"] - assert split in splits + assert split in splits, f"'{split}' is not a valid split." # check if the dataset exists already zip_path = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0.zip") @@ -115,8 +113,8 @@ def get_dynamicnuclearnet_paths(path: Union[os.PathLike, str], split: str, downl def get_dynamicnuclearnet_dataset( path: Union[os.PathLike, str], - split: str, patch_shape: Tuple[int, int], + split: Literal['train', 'val', 'test'], download: bool = False, **kwargs ) -> Dataset: @@ -124,8 +122,8 @@ def get_dynamicnuclearnet_dataset( Args: path: Filepath to a folder where the downloaded data will be saved. - split: The split to use for the dataset. Either 'train', 'val' or 'test'. patch_shape: The patch shape to use for training. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. download: Whether to download the data if it is not present. kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. @@ -148,9 +146,9 @@ def get_dynamicnuclearnet_dataset( def get_dynamicnuclearnet_loader( path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], batch_size: int, + patch_shape: Tuple[int, int], + split: Literal['train', 'val', 'test'], download: bool = False, **kwargs ) -> DataLoader: @@ -158,9 +156,9 @@ def get_dynamicnuclearnet_loader( Args: path: Filepath to a folder where the downloaded data will be saved. - split: The split to use for the dataset. Either 'train', 'val' or 'test'. - patch_shape: The patch shape to use for training. batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. download: Whether to download the data if it is not present. kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. @@ -168,5 +166,5 @@ def get_dynamicnuclearnet_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_dynamicnuclearnet_dataset(path, split, patch_shape, download, **ds_kwargs) + dataset = get_dynamicnuclearnet_dataset(path, patch_shape, split, download, **ds_kwargs) return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/ifnuclei.py b/torch_em/data/datasets/light_microscopy/ifnuclei.py index 7249e33e..9a044342 100644 --- a/torch_em/data/datasets/light_microscopy/ifnuclei.py +++ b/torch_em/data/datasets/light_microscopy/ifnuclei.py @@ -55,6 +55,8 @@ def get_ifnuclei_paths(path: Union[os.PathLike, str], download: bool = False) -> raw_paths = natsorted(glob(os.path.join(path, "rawimages", "*.tif"))) label_paths = natsorted(glob(os.path.join(path, "groundtruth", "*"))) + assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0 + return raw_paths, label_paths diff --git a/torch_em/data/datasets/light_microscopy/omnipose.py b/torch_em/data/datasets/light_microscopy/omnipose.py index ecb49ac9..d000a2a8 100644 --- a/torch_em/data/datasets/light_microscopy/omnipose.py +++ b/torch_em/data/datasets/light_microscopy/omnipose.py @@ -53,7 +53,7 @@ def get_omnipose_data(path: Union[os.PathLike, str], download: bool = False) -> def get_omnipose_paths( path: Union[os.PathLike, str], - split: str, + split: Literal["train", "test"], data_choice: Optional[Union[str, List[str]]] = None, download: bool = False ) -> Tuple[List[str], List[str]]: @@ -70,7 +70,7 @@ def get_omnipose_paths( List of filepaths for the image data. List of filepaths for the label data. """ - data_dir = get_omnipose_data(path=path, download=download) + data_dir = get_omnipose_data(path, download) if split not in ["train", "test"]: raise ValueError(f"'{split}' is not a valid split.") @@ -167,7 +167,5 @@ def get_omnipose_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_omnipose_dataset( - path=path, patch_shape=patch_shape, split=split, data_choice=data_choice, download=download, **ds_kwargs - ) + dataset = get_omnipose_dataset(path, patch_shape, split, data_choice, download, **ds_kwargs) return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/orgasegment.py b/torch_em/data/datasets/light_microscopy/orgasegment.py index f6b3f2d7..50158d2d 100644 --- a/torch_em/data/datasets/light_microscopy/orgasegment.py +++ b/torch_em/data/datasets/light_microscopy/orgasegment.py @@ -22,9 +22,7 @@ def get_orgasegment_data( - path: Union[os.PathLike, str], - split: Literal["train", "val", "eval"], - download: bool = False + path: Union[os.PathLike, str], split: Literal["train", "val", "eval"], download: bool = False ) -> str: """Download the OrgaSegment dataset for organoid segmentation. @@ -55,9 +53,7 @@ def get_orgasegment_data( def get_orgasegment_paths( - path: Union[os.PathLike, str], - split: Literal["train", "val", "eval"], - download: bool = False + path: Union[os.PathLike, str], split: Literal["train", "val", "eval"], download: bool = False ) -> Tuple[List[str], List[str]]: """Get paths for the OrgaSegment data. @@ -80,8 +76,8 @@ def get_orgasegment_paths( def get_orgasegment_dataset( path: Union[os.PathLike, str], - split: Literal["train", "val", "eval"], patch_shape: Tuple[int, int], + split: Literal["train", "val", "eval"], boundaries: bool = False, binary: bool = False, download: bool = False, @@ -91,8 +87,8 @@ def get_orgasegment_dataset( Args: path: Filepath to a folder where the downloaded data will be saved. - split: The split to download. Either 'train', 'val or 'eval'. patch_shape: The patch shape to use for training. + split: The split to download. Either 'train', 'val or 'eval'. boundaries: Whether to compute boundaries as the target. binary: Whether to use a binary segmentation target. download: Whether to download the data if it is not present. @@ -103,7 +99,7 @@ def get_orgasegment_dataset( """ assert split in ["train", "val", "eval"] - image_paths, label_paths = get_orgasegment_paths(path=path, split=split, download=download) + image_paths, label_paths = get_orgasegment_paths(path, split, download) kwargs, _ = util.add_instance_label_transform(kwargs, add_binary_target=True, binary=binary, boundaries=boundaries) @@ -120,9 +116,9 @@ def get_orgasegment_dataset( def get_orgasegment_loader( path: Union[os.PathLike, str], - split: Literal["train", "val", "eval"], - patch_shape: Tuple[int, int], batch_size: int, + patch_shape: Tuple[int, int], + split: Literal["train", "val", "eval"], boundaries: bool = False, binary: bool = False, download: bool = False, @@ -132,8 +128,9 @@ def get_orgasegment_loader( Args: path: Filepath to a folder where the downloaded data will be saved. - split: The split to download. Either 'train', 'val or 'eval'. + batch_size: The batch size for training. patch_shape: The patch shape to use for training. + split: The split to download. Either 'train', 'val or 'eval'. boundaries: Whether to compute boundaries as the target. binary: Whether to use a binary segmentation target. download: Whether to download the data if it is not present. @@ -143,13 +140,5 @@ def get_orgasegment_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_orgasegment_dataset( - path=path, - split=split, - patch_shape=patch_shape, - boundaries=boundaries, - binary=binary, - download=download, - **ds_kwargs - ) - return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + dataset = get_orgasegment_dataset(path, patch_shape, split, boundaries, binary, download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/yeaz.py b/torch_em/data/datasets/light_microscopy/yeaz.py index 6399737b..d519be73 100644 --- a/torch_em/data/datasets/light_microscopy/yeaz.py +++ b/torch_em/data/datasets/light_microscopy/yeaz.py @@ -15,6 +15,9 @@ from natsort import natsorted from typing import Union, Tuple, Literal, List +import json +from sklearn.model_selection import train_test_split + from torch.utils.data import Dataset, DataLoader import torch_em @@ -28,9 +31,7 @@ } -def get_yeaz_data( - path: Union[os.PathLike, str], choice: Literal['bf, phc'], download: bool = False -) -> str: +def get_yeaz_data(path: Union[os.PathLike, str], choice: Literal['bf, phc'], download: bool = False) -> str: """Obtain the YeaZ dataset. NOTE: Please download the dataset manually. @@ -49,14 +50,15 @@ def get_yeaz_data( if os.path.exists(data_dir): return data_dir + os.makedirs(path, exist_ok=True) + tar_path = os.path.join( path, "gold-standard-PhC-plus-2.tar.gz" if choice == "phc" else "gold-standard-BF-V-1.tar.gz" ) - if not os.path.exists(tar_path) and not download: + if not os.path.exists(tar_path) or download: raise NotImplementedError( - "Automatic download is not supported at the moment. " - f"Please download the data manually from '{URL[choice]}'." + f"Automatic download is not supported. Please download the data manually from '{URL[choice]}'." ) util.unzip_tarfile(tar_path=tar_path, dst=path, remove=False) @@ -64,14 +66,42 @@ def get_yeaz_data( return data_dir +def _create_data_splits(path, data_dir, choice, split, raw_paths): + json_file = os.path.join(path, f"yeaz_{choice}_splits.json") + if os.path.exists(json_file): + with open(json_file, "r") as f: + data = json.load(f) + else: + # Get the filenames + names = [os.path.basename(p) for p in raw_paths] + + # Create train / val / test splits + train_split, test_split = train_test_split(names, test_size=0.2) + train_split, val_split = train_test_split(train_split, test_size=0.15) + data = {"train": train_split, "val": val_split, "test": test_split} + + # Write the filenames with splits to a json file. + with open(json_file, "w") as f: + json.dump(data, f, indent=4) + + _raw_paths = [os.path.join(data_dir, name) for name in data[split]] + _label_paths = [p.replace("_im.tif", "_mask.tif") for p in _raw_paths] + + return _raw_paths, _label_paths + + def get_yeaz_paths( - path: Union[os.PathLike, str], choice: Literal['bf, phc'], download: bool = False + path: Union[os.PathLike, str], + choice: Literal['bf, phc'], + split: Literal['train', 'val', 'test'], + download: bool = False ) -> Tuple[List[str], List[str]]: """Get the YeaZ data. Args: path: Filepath to a folder where the data is expected to be downloaded for further processing. choice: The choice of modality for dataset. + split: The choice of data split. download: Whether to download the data if it is not present. Not implemented for this data. Returns: @@ -81,7 +111,9 @@ def get_yeaz_paths( data_dir = get_yeaz_data(path, choice, download) raw_paths = natsorted(glob(os.path.join(data_dir, "*_im.tif"))) - label_paths = natsorted(glob(os.path.join(data_dir, "*_mask.tif"))) + + # Get the raw and label paths. + raw_paths, label_paths = _create_data_splits(path, data_dir, choice, split, raw_paths) assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0 @@ -92,6 +124,7 @@ def get_yeaz_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], choice: Literal['bf, phc'], + split: Literal['train', 'val', 'test'], download: bool = False, **kwargs ) -> Dataset: @@ -101,13 +134,14 @@ def get_yeaz_dataset( path: Filepath to a folder where the data is expected to be downloaded for further processing. patch_shape: The patch shape to use for training. choice: The choice of modality for dataset. + split: The choice of data split. download: Whether to download the data if it is not present. Not implemented for this data. kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. Returns: The segmentation dataset. """ - raw_paths, label_paths = get_yeaz_paths(path, choice, download) + raw_paths, label_paths = get_yeaz_paths(path, choice, split, download) return torch_em.default_segmentation_dataset( raw_paths=raw_paths, @@ -124,6 +158,7 @@ def get_yeaz_loader( batch_size: int, patch_shape: Tuple[int, int], choice: Literal['bf, phc'], + split: Literal['train', 'val', 'test'], download: bool = False, **kwargs ) -> DataLoader: @@ -131,8 +166,10 @@ def get_yeaz_loader( Args: path: Filepath to a folder where the data is expected to be downloaded for further processing. + batch_size: The batch size for training. patch_shape: The patch shape to use for training. choice: The choice of modality for dataset. + split: The choice of data split. download: Whether to download the data if it is not present. Not implemented for this data. kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. @@ -140,5 +177,5 @@ def get_yeaz_loader( The DataLoader. """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_yeaz_dataset(path, patch_shape, choice, download, **ds_kwargs) + dataset = get_yeaz_dataset(path, patch_shape, choice, split, download, **ds_kwargs) return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)