Skip to content

Commit

Permalink
Minor refactor to LM datasets (#459)
Browse files Browse the repository at this point in the history
Updates to some LM datasets
  • Loading branch information
anwai98 authored Dec 31, 2024
1 parent 2b389c1 commit 0b2b01a
Show file tree
Hide file tree
Showing 20 changed files with 151 additions and 114 deletions.
3 changes: 2 additions & 1 deletion scripts/datasets/light_microscopy/check_bitdepth_nucseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def check_bitdepth_nucseg():
loader = get_bitdepth_nucseg_loader(
path=os.path.join(ROOT, "bitdepth_nucseg"),
patch_shape=(512, 512),
batch_size=1,
batch_size=2,
magnification=None,
download=True,
)

Expand Down
4 changes: 3 additions & 1 deletion scripts/datasets/light_microscopy/check_cellpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def check_cellpose():
split="train",
patch_shape=(512, 512),
batch_size=1,
choice=None,
choice="cyto",
download=True,
shuffle=True,
)
check_loader(loader, 8, instance_labels=True)

Expand Down
2 changes: 1 addition & 1 deletion scripts/datasets/light_microscopy/check_ctc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import sys

from torch_em.data.datasets.light_microscopy.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS
from torch_em.util.debug import check_loader
from torch_em.data.sampler import MinInstanceSampler
from torch_em.data.datasets.light_microscopy.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS

sys.path.append("..")

Expand Down
4 changes: 2 additions & 2 deletions scripts/datasets/light_microscopy/check_cvz_fluo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def check_cvz_fluo():
path=os.path.join(ROOT, "cvz"),
patch_shape=(512, 512),
batch_size=2,
stain_choice="cell",
stain_choice="dapi",
data_choice=None,
)

check_loader(loader, 8, instance_labels=True, plt=True, save_path="./test.png", rgb=True)
check_loader(loader, 8, instance_labels=True, rgb=True)


if __name__ == "__main__":
Expand Down
6 changes: 4 additions & 2 deletions scripts/datasets/light_microscopy/check_deepbacs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
def check_deepbacs():
from util import ROOT

loader = get_deepbacs_loader(os.path.join(ROOT, "deepbacs"), "test", bac_type="mixed", download=True,
patch_shape=(256, 256), batch_size=1, shuffle=True)
loader = get_deepbacs_loader(
os.path.join(ROOT, "deepbacs"), "test", bac_type="mixed",
download=True, patch_shape=(256, 256), batch_size=1, shuffle=True
)
check_loader(loader, 15, instance_labels=True)


Expand Down
19 changes: 12 additions & 7 deletions scripts/datasets/light_microscopy/check_dynamicnuclearnet.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
import os
import sys

from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_dynamicnuclearnet_loader


DYNAMICNUCLEARNET_ROOT = "/home/anwai/data/deepcell/"
sys.path.append("..")


# NOTE: the DynamicNuclearNet data cannot be downloaded automatically.
# you need to download it yourself from https://datasets.deepcell.org/data
def check_dynamicnuclearnet():
# set this path to where you have downloaded the dynamicnuclearnet data
from util import ROOT

loader = get_dynamicnuclearnet_loader(
DYNAMICNUCLEARNET_ROOT, "train",
patch_shape=(512, 512), batch_size=2, download=True
path=os.path.join(ROOT, "dynamicnuclearnet"),
split="train",
patch_shape=(512, 512),
batch_size=2,
download=True,
)
check_loader(loader, 10, instance_labels=True, rgb=False)
check_loader(loader, 8, instance_labels=True)


if __name__ == "__main__":
Expand Down
11 changes: 7 additions & 4 deletions scripts/datasets/light_microscopy/check_embedseg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import sys

from torch_em.data.datasets import get_embedseg_loader
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_embedseg_loader

sys.path.append("..")

Expand All @@ -17,12 +17,15 @@ def check_embedseg():
"Platynereis-Nuclei-CBG",
]

patch_shape = (32, 384, 384)
for name in names:
loader = get_embedseg_loader(
os.path.join(ROOT, "embedseg"), name=name, patch_shape=patch_shape, batch_size=1, download=True
path=os.path.join(ROOT, "embedseg"),
name=name,
patch_shape=(32, 384, 384),
batch_size=2,
download=True,
)
check_loader(loader, 2, instance_labels=True)
check_loader(loader, 8, instance_labels=True)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions scripts/datasets/light_microscopy/check_ifnuclei.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def check_ifnuclei():
patch_shape=(512, 512),
batch_size=1,
download=True,
shuffle=True,
)

check_loader(loader, 8, instance_labels=True)
Expand Down
19 changes: 10 additions & 9 deletions scripts/datasets/light_microscopy/check_orgasegment.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
import os
import sys

from torch_em.util.debug import check_loader
from torch_em.data.datasets.light_microscopy import get_orgasegment_loader


ROOT = "/media/anwai/ANWAI/data/orgasegment"
sys.path.append("..")


def check_orgasegment():
from util import ROOT

loader = get_orgasegment_loader(
path=ROOT,
split="val",
path=os.path.join(ROOT, "orgasegment"),
split="train",
patch_shape=(512, 512),
batch_size=1,
batch_size=2,
download=True,
)
check_loader(loader, 8, instance_labels=True)


def main():
check_orgasegment()


if __name__ == "__main__":
main()
check_orgasegment()
3 changes: 2 additions & 1 deletion scripts/datasets/light_microscopy/check_tissuenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import numpy as np
from torch_em.transform.raw import standardize, normalize_percentile

from torch_em.data.datasets import get_tissuenet_loader
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_tissuenet_loader


sys.path.append("..")

Expand Down
16 changes: 7 additions & 9 deletions scripts/datasets/light_microscopy/check_yeaz.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import sys

from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader
from torch_em.data.datasets import get_yeaz_loader

Expand All @@ -11,19 +12,16 @@
def check_yeaz():
from util import ROOT

choice = "phc" # choose from 'bf' / 'phc'
if choice == "bf":
patch_shape, ndim = (512, 512), 2
else:
patch_shape, ndim = (1, 512, 512), 3

loader = get_yeaz_loader(
path=os.path.join(ROOT, "yeaz"),
batch_size=2,
patch_shape=patch_shape,
choice=choice,
ndim=ndim,
patch_shape=(512, 512),
ndim=2,
choice="phc", # choose from 'bf' / 'phc'
split="val",
sampler=MinInstanceSampler(),
download=False,
shuffle=True,
)

check_loader(loader, 8, instance_labels=True)
Expand Down
2 changes: 2 additions & 0 deletions torch_em/data/datasets/light_microscopy/bitdepth_nucseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def get_bitdepth_nucseg_paths(
raw_paths = natsorted(glob(os.path.join(data_dir, magnification, "images_16bit", "*.tif")))
label_paths = natsorted(glob(os.path.join(data_dir, magnification, "label masks", "*.tif")))

assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0

return raw_paths, label_paths


Expand Down
46 changes: 23 additions & 23 deletions torch_em/data/datasets/light_microscopy/cellpose.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""This dataset contains annotation for cell segmentation in
fluorescene microscently-labeled microscopy images.
"""This dataset contains annotation for cell segmentation in fluorescene microscently-labeled microscopy images.
This dataset is from the following publications:
- https://doi.org/10.1038/s41592-020-01018-x
Expand Down Expand Up @@ -40,14 +39,8 @@ def get_cellpose_data(
download: Whether to download the data if it is not present.
Returns:
The filepath to the data.
The filepath to the folder where the data is manually downloaded.
"""
if download:
raise NotImplementedError(
"The dataset cannot be automatically downloaded. "
"Please see 'get_cellpose_data' in 'torch_em/data/datasets/light_microscopy/cellpose.py' for details."
)

per_choice_dir = os.path.join(path, choice) # path where the unzipped files will be stored
if choice == "cyto":
assert split in ["train", "test"], f"'{split}' is not a valid split in '{choice}'."
Expand All @@ -60,7 +53,14 @@ def get_cellpose_data(
else:
raise ValueError(f"'{choice}' is not a valid dataset choice.")

if not os.path.exists(data_dir):
if os.path.exists(data_dir):
return data_dir
else:
if not os.path.exists(zip_path) and download:
raise NotImplementedError(
"The dataset cannot be automatically downloaded. "
"Please see 'get_cellpose_data' in 'torch_em/data/datasets/light_microscopy/cellpose.py' for details."
)
util.unzip(zip_path=zip_path, dst=per_choice_dir, remove=False)

return data_dir
Expand All @@ -69,7 +69,7 @@ def get_cellpose_data(
def get_cellpose_paths(
path: Union[os.PathLike, str],
split: Literal['train', 'test'],
choice: Literal["cyto", "cyto2"],
choice: Optional[Literal["cyto", "cyto2"]] = None,
download: bool = False,
) -> Tuple[List[str], List[str]]:
"""Get paths to the CellPose data.
Expand All @@ -84,18 +84,20 @@ def get_cellpose_paths(
List of filepaths for the image data.
List of filepaths for the label data.
"""
data_dir = get_cellpose_data(path=path, split=split, choice=choice, download=download)
data_dir = get_cellpose_data(path, split, choice, download)

image_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
gt_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))

assert len(image_paths) == len(gt_paths) and len(image_paths) > 0

return image_paths, gt_paths


def get_cellpose_dataset(
path: Union[os.PathLike, str],
split: Literal["train", "test"],
patch_shape: Tuple[int, int],
split: Literal["train", "test"],
choice: Optional[Literal["cyto", "cyto2"]] = None,
download: bool = False,
**kwargs
Expand All @@ -104,8 +106,8 @@ def get_cellpose_dataset(
Args:
path: Filepath to a folder where the downloaded data will be saved.
split: The data split to use. Either 'train', or 'test'.
patch_shape: The patch shape to use for training.
split: The data split to use. Either 'train', or 'test'.
choice: The choice of dataset. Either 'cyto' or 'cyto2'.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
Expand All @@ -124,7 +126,7 @@ def get_cellpose_dataset(
image_paths, gt_paths = [], []
for per_choice in choice:
assert per_choice in AVAILABLE_CHOICES
per_image_paths, per_gt_paths = get_cellpose_paths(path, split, choice, download)
per_image_paths, per_gt_paths = get_cellpose_paths(path, split, per_choice, download)
image_paths.extend(per_image_paths)
gt_paths.extend(per_gt_paths)

Expand All @@ -148,9 +150,9 @@ def get_cellpose_dataset(

def get_cellpose_loader(
path: Union[os.PathLike, str],
split: Literal["train", "test"],
patch_shape: Tuple[int, int],
batch_size: int,
patch_shape: Tuple[int, int],
split: Literal["train", "test"],
choice: Optional[Literal["cyto", "cyto2"]] = None,
download: bool = False,
**kwargs
Expand All @@ -159,9 +161,9 @@ def get_cellpose_loader(
Args:
path: Filepath to a folder where the downloaded data will be saved.
split: The data split to use. Either 'train', or 'test'.
patch_shape: The patch shape to use for training.
batch_size: The batch size for training.
patch_shape: The patch shape to use for training.
split: The data split to use. Either 'train', or 'test'.
choice: The choice of dataset. Either 'cyto' or 'cyto2'.
download: Whether to download the data if it is not present.
kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
Expand All @@ -170,7 +172,5 @@ def get_cellpose_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_cellpose_dataset(
path=path, split=split, patch_shape=patch_shape, choice=choice, download=download, **ds_kwargs
)
return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
dataset = get_cellpose_dataset(path, patch_shape, split, choice, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
6 changes: 2 additions & 4 deletions torch_em/data/datasets/light_microscopy/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_ctc_segmentation_paths(
Filepath to the folder where image data is stored.
Filepath to the folder where label data is stored.
"""
data_path = get_ctc_segmentation_data(path, dataset_name, download, split)
data_path = get_ctc_segmentation_data(path, dataset_name, split, download)

if vol_id is None:
vol_ids = glob(os.path.join(data_path, "*_GT"))
Expand Down Expand Up @@ -222,7 +222,5 @@ def get_ctc_segmentation_loader(
The DataLoader.
"""
ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
dataset = get_ctc_segmentation_dataset(
path, dataset_name, patch_shape, split=split, vol_id=vol_id, download=download, **ds_kwargs,
)
dataset = get_ctc_segmentation_dataset(path, dataset_name, patch_shape, split, vol_id, download, **ds_kwargs)
return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
1 change: 0 additions & 1 deletion torch_em/data/datasets/light_microscopy/dsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def get_dsb_paths(
raw_paths = natsorted(glob(os.path.join(path, "full", "*", "images", f"{domain}_*.png")))
label_paths = natsorted(glob(os.path.join(path, "full", "*", "preprocessed_labels", f"{domain}_*.tif")))

print(len(raw_paths), len(label_paths))
assert len(raw_paths) == len(label_paths) and len(raw_paths) > 0

return raw_paths, label_paths
Expand Down
Loading

0 comments on commit 0b2b01a

Please sign in to comment.