Memory-mapped caching for image translation training #218

Merged · 86 commits · Mar 27, 2025
Changes from all commits

Commits
e19ee14
caching dataloader
edyoshikun Sep 12, 2024
d31978d
caching data module
edyoshikun Sep 26, 2024
041d738
black
edyoshikun Sep 27, 2024
7f76174
ruff
edyoshikun Sep 27, 2024
85ea791
Bump torch to 2.4.1 (#174)
edyoshikun Sep 28, 2024
1838581
adding timeout to ram_dataloader
edyoshikun Sep 28, 2024
f5c01a3
bandaid to cached dataloader
edyoshikun Oct 4, 2024
26a06b8
fixing the dataloader using torch collate_fn
edyoshikun Oct 4, 2024
f2ff43c
replacing dictionary with single array
edyoshikun Oct 17, 2024
5fb96d7
loading prior to epoch 0
edyoshikun Oct 18, 2024
848cd63
Revert "replacing dictionary with single array"
edyoshikun Oct 19, 2024
f7e57ae
using multiprocessing manager
edyoshikun Oct 19, 2024
c4797b4
add sharded distributed sampler
ziw-liu Oct 21, 2024
2c31e7d
add example script for ddp caching
ziw-liu Oct 21, 2024
5300b4a
format and lint
ziw-liu Oct 21, 2024
8a8b4b0
adding the custom distrb sampler to hcs_ram.py
edyoshikun Oct 22, 2024
49764fa
adding sampler to val train dataloader
edyoshikun Oct 22, 2024
1fe5491
fix divisibility of the last shard
ziw-liu Oct 22, 2024
0b005cf
hcs_ram format and lint
ziw-liu Oct 22, 2024
023ca88
data module that only crops and does not collate
ziw-liu Oct 23, 2024
f7ce0ba
wip: execute transforms on the GPU
ziw-liu Oct 23, 2024
daa6860
path for if not ddp
edyoshikun Oct 24, 2024
55499de
fix randomness in inversion transform
ziw-liu Oct 29, 2024
4280677
add option to pop the normalization metadata
ziw-liu Oct 29, 2024
1561802
move gpu transform definition back to data module
ziw-liu Oct 30, 2024
2e37217
add tiled crop transform for validation
ziw-liu Oct 30, 2024
7edf36e
add stack channel transform for gpu augmentation
ziw-liu Oct 30, 2024
eda5d1b
fix typing
ziw-liu Oct 30, 2024
550101d
collate before sending to gpu
ziw-liu Oct 30, 2024
92e3722
inherit gpu transforms for livecell dataset
ziw-liu Oct 30, 2024
c185377
update fcmae engine to apply per-dataset augmentations
ziw-liu Oct 30, 2024
2ca134b
format and lint hcs_ram
ziw-liu Oct 30, 2024
70fcf1c
Merge branch 'simple-cache' into gpu-transform
ziw-liu Oct 31, 2024
be0e94f
fix abc type hint
ziw-liu Oct 31, 2024
92c4b0a
update docstring style
ziw-liu Oct 31, 2024
f7b585c
disable grad for validation transforms
ziw-liu Oct 31, 2024
42c49f5
improve sample image logging in fcmae
ziw-liu Oct 31, 2024
4bf1088
fix dataset length when batch size is larger than the dataset
ziw-liu Oct 31, 2024
3276950
fix docstring
ziw-liu Oct 31, 2024
14a16ed
add option to disable normalization metadata
ziw-liu Oct 31, 2024
6719305
inherit gpu transform for ctmc
ziw-liu Oct 31, 2024
fad3d4e
remove duplicate method override
ziw-liu Oct 31, 2024
07c1021
update docstring for ctmc
ziw-liu Nov 1, 2024
949c445
Merge pull request #196 from mehta-lab/gpu-transform
ziw-liu Nov 8, 2024
d331c1f
Merge branch 'main' into simple-cache
ziw-liu Nov 8, 2024
7d79473
allow skipping caching for large datasets
ziw-liu Nov 13, 2024
736d4c5
Merge branch 'main' into simple-cache
ziw-liu Nov 13, 2024
e548d52
make the fcmae module compatible with image translation
ziw-liu Nov 14, 2024
084717f
remove prototype implementation
ziw-liu Nov 19, 2024
fdc377a
fix import path
ziw-liu Nov 19, 2024
96313fa
Arbitrary prediction time transforms (#209)
ziw-liu Dec 2, 2024
6e1818b
add docstrings
ziw-liu Dec 2, 2024
cbc59f6
wip: segmentation module
ziw-liu Dec 4, 2024
9126083
Merge branch 'main' into simple-cache
ziw-liu Dec 4, 2024
970e861
avoid casting
ziw-liu Dec 4, 2024
9b9a1b1
Merge branch 'simple-cache' into segmentation-module
ziw-liu Dec 4, 2024
976cc1c
update import path from iohub
ziw-liu Dec 7, 2024
ac0643b
make integer array in fixture
ziw-liu Dec 7, 2024
c6ac8ef
labels fixture
ziw-liu Dec 7, 2024
f86e3d5
test segmentation metrics modules
ziw-liu Dec 7, 2024
30b2660
less strings
ziw-liu Dec 7, 2024
bea9413
test non-empty
ziw-liu Dec 7, 2024
a23efe9
select which wells to include in fit
ziw-liu Dec 11, 2024
50f4d9a
Merge branch 'main' into segmentation-module
ziw-liu Jan 2, 2025
16704ac
make well selection a mixin
ziw-liu Jan 4, 2025
fb56710
wip: mmap cache data module
ziw-liu Jan 4, 2025
c9a7b3e
support exclusion of FOVs
ziw-liu Jan 14, 2025
d2cd340
wip: precompute normalization
ziw-liu Jan 14, 2025
2e12f00
add augmentations benchmark
ziw-liu Jan 14, 2025
dd00a36
fix cpu threads default
ziw-liu Jan 14, 2025
6a88ec4
fix probability (affects cpu results)
ziw-liu Jan 14, 2025
a2f8823
disable metadata tracking
ziw-liu Jan 14, 2025
4d07b55
fix non-distributed initialization
ziw-liu Jan 22, 2025
bc49671
refactor transforms into submodules
ziw-liu Jan 22, 2025
a84e0d4
Merge branch 'main' into mmap-cached
ziw-liu Mar 25, 2025
87c6d6c
do not import type hints at runtime
ziw-liu Mar 25, 2025
2184723
update docstring
ziw-liu Mar 25, 2025
4194b5a
backwards compatible import path
ziw-liu Mar 25, 2025
e03eb49
fix annotations
ziw-liu Mar 25, 2025
08823c7
fix style
ziw-liu Mar 25, 2025
6cafdb7
fix dice score import
ziw-liu Mar 26, 2025
391db50
fix dice score parameters
ziw-liu Mar 26, 2025
85c7651
apply formatting to exercise
ziw-liu Mar 26, 2025
dfd9390
fix labels data type
ziw-liu Mar 26, 2025
b93af2a
fix labels input shape
ziw-liu Mar 26, 2025
ce73044
Merge branch 'main' into mmap-cached
ziw-liu Mar 26, 2025
454 changes: 270 additions & 184 deletions examples/virtual_staining/dlmbl_exercise/exercise.ipynb

Large diffs are not rendered by default.

2,036 changes: 346 additions & 1,690 deletions examples/virtual_staining/dlmbl_exercise/solution.ipynb

Large diffs are not rendered by default.

609 changes: 370 additions & 239 deletions examples/virtual_staining/dlmbl_exercise/solution.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions viscy/cli.py
@@ -21,6 +21,7 @@ def subcommands() -> dict[str, set[str]]:
subcommand_base_args = {"model"}
subcommands["preprocess"] = subcommand_base_args
subcommands["export"] = subcommand_base_args
subcommands["precompute"] = subcommand_base_args
return subcommands

def add_arguments_to_parser(self, parser) -> None:
@@ -50,8 +51,8 @@ def main() -> None:
Set default random seed to 42.
"""
_setup_environment()
require_model = "preprocess" not in sys.argv
require_data = {"preprocess", "export"}.isdisjoint(sys.argv)
require_model = {"preprocess", "precompute"}.isdisjoint(sys.argv)
require_data = {"preprocess", "precompute", "export"}.isdisjoint(sys.argv)
_ = VisCyCLI(
model_class=LightningModule,
datamodule_class=LightningDataModule if require_data else None,
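
For context, the gating above can be read as a small standalone function. A minimal sketch (not part of the diff; the argv values are hypothetical) showing that the new precompute subcommand, like preprocess, requires neither a model nor a data module, while export still requires a model:

def _requirements(argv: list[str]) -> tuple[bool, bool]:
    # Mirrors the logic in viscy/cli.py: a subcommand name present in argv
    # lifts the corresponding requirement.
    require_model = {"preprocess", "precompute"}.isdisjoint(argv)
    require_data = {"preprocess", "precompute", "export"}.isdisjoint(argv)
    return require_model, require_data

print(_requirements(["precompute", "-c", "precompute.yml"]))  # (False, False)
print(_requirements(["export", "-c", "export.yml"]))          # (True, False)
print(_requirements(["fit", "-c", "fit.yml"]))                # (True, True)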
26 changes: 25 additions & 1 deletion viscy/data/gpu_aug.py
@@ -19,6 +19,7 @@
from viscy.data.distributed import ShardedDistributedSampler
from viscy.data.hcs import _ensure_channel_list, _read_norm_meta
from viscy.data.typing import DictTransform, NormMeta
from viscy.preprocessing.precompute import _filter_fovs, _filter_wells

if TYPE_CHECKING:
from multiprocessing.managers import DictProxy
@@ -36,6 +37,7 @@ class GPUTransformDataModule(ABC, LightningDataModule):
batch_size: int
num_workers: int
pin_memory: bool
prefetch_factor: int | None

def _maybe_sampler(
self, dataset: Dataset, shuffle: bool
@@ -59,6 +61,7 @@ def train_dataloader(self) -> DataLoader:
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=list_data_collate,
prefetch_factor=self.prefetch_factor,
)

def val_dataloader(self) -> DataLoader:
@@ -74,6 +77,7 @@ def val_dataloader(self) -> DataLoader:
pin_memory=self.pin_memory,
drop_last=False,
collate_fn=list_data_collate,
prefetch_factor=self.prefetch_factor,
)

@property
@@ -169,7 +173,23 @@ def __getitem__(self, idx: int) -> dict[str, Tensor]:
return sample


class CachedOmeZarrDataModule(GPUTransformDataModule):
class SelectWell:
_include_wells: list[str] | None
_exclude_fovs: list[str] | None

def _filter_fit_fovs(self, plate: Plate) -> list[Position]:
positions = []
for well in _filter_wells(plate, include_wells=self._include_wells):
for fov in _filter_fovs(well, exclude_fovs=self._exclude_fovs):
positions.append(fov)
if len(positions) < 2:
raise ValueError(
"At least 2 FOVs are required for training and validation."
)
return positions


class CachedOmeZarrDataModule(GPUTransformDataModule, SelectWell):
"""Data module for cached OME-Zarr arrays.

Parameters
@@ -199,6 +219,8 @@ class CachedOmeZarrDataModule(GPUTransformDataModule):
Skip caching for this dataset, by default False
include_wells : list[str], optional
List of well names to include in the dataset, by default None (all)
exclude_fovs : list[str], optional
List of FOV names to exclude from the dataset, by default None (include all)
"""

def __init__(
@@ -215,6 +237,7 @@ def __init__(
pin_memory: bool = True,
skip_cache: bool = False,
include_wells: list[str] | None = None,
exclude_fovs: list[str] | None = None,
):
super().__init__()
self.data_path = data_path
@@ -229,6 +252,7 @@ def __init__(
self.pin_memory = pin_memory
self.skip_cache = skip_cache
self._include_wells = include_wells
self._exclude_fovs = exclude_fovs

@property
def train_cpu_transforms(self) -> Compose:
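
The prefetch_factor added to both dataloaders above is forwarded straight to torch.utils.data.DataLoader, which only accepts an integer value when worker processes are used; this is why MmappedDataModule below falls back to None when num_workers == 0. A standalone sketch of that constraint, using a dummy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.zeros(8, 1))
# Each of the 2 workers keeps 4 batches loaded in advance.
multi = DataLoader(ds, batch_size=2, num_workers=2, prefetch_factor=4)
# Single-process loading: prefetch_factor must stay None.
single = DataLoader(ds, batch_size=2, num_workers=0, prefetch_factor=None)
# DataLoader(ds, batch_size=2, num_workers=0, prefetch_factor=4) raises a ValueError.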
264 changes: 264 additions & 0 deletions viscy/data/mmap_cache.py
@@ -0,0 +1,264 @@
from __future__ import annotations

import os
import tempfile
from logging import getLogger
from pathlib import Path
from typing import TYPE_CHECKING, Literal

import numpy as np
import torch
from iohub.ngff import Plate, Position, open_ome_zarr
from monai.data.meta_obj import set_track_meta
from monai.transforms.compose import Compose
from tensordict.memmap import MemoryMappedTensor
from torch import Tensor
from torch.multiprocessing import Manager
from torch.utils.data import Dataset

from viscy.data.gpu_aug import GPUTransformDataModule, SelectWell
from viscy.data.hcs import _ensure_channel_list, _read_norm_meta
from viscy.data.typing import DictTransform, NormMeta

if TYPE_CHECKING:
from multiprocessing.managers import DictProxy

_logger = getLogger("lightning.pytorch")

_CacheMetadata = tuple[Position, int, NormMeta | None]


class MmappedDataset(Dataset):
def __init__(
self,
positions: list[Position],
channel_names: list[str],
cache_map: DictProxy,
buffer: MemoryMappedTensor,
preprocess_transforms: Compose | None = None,
cpu_transform: Compose | None = None,
array_key: str = "0",
load_normalization_metadata: bool = True,
):
key = 0
self._metadata_map: dict[int, _CacheMetadata] = {}
for position in positions:
img = position[array_key]
norm_meta = _read_norm_meta(position)
for time_idx in range(img.frames):
cache_map[key] = None
self._metadata_map[key] = (position, time_idx, norm_meta)
key += 1
self.channels = {ch: position.get_channel_index(ch) for ch in channel_names}
self.array_key = array_key
self._buffer = buffer
self._cache_map = cache_map
self.preprocess_transforms = preprocess_transforms
self.cpu_transform = cpu_transform
self.load_normalization_metadata = load_normalization_metadata

def __len__(self) -> int:
return len(self._metadata_map)

def _split_channels(self, volume: Tensor) -> dict[str, Tensor]:
return {name: img[None] for name, img in zip(self.channels.keys(), volume)}

def _preprocess_volume(self, volume: Tensor, norm_meta) -> Tensor:
if self.preprocess_transforms:
orig_shape = volume.shape
sample = self._split_channels(volume)
if self.load_normalization_metadata:
sample["norm_meta"] = norm_meta
sample = self.preprocess_transforms(sample)
volume = torch.cat([sample[name] for name in self.channels.keys()], dim=0)
assert volume.shape == orig_shape, (volume.shape, orig_shape, sample.keys())
return volume

def __getitem__(self, idx: int) -> dict[str, Tensor]:
position, time_idx, norm_meta = self._metadata_map[idx]
if not self._cache_map[idx]:
_logger.debug(f"Loading volume for index {idx}")
volume = torch.from_numpy(
position[self.array_key]
.oindex[time_idx, list(self.channels.values())]
.astype(np.float32)
)
volume = self._preprocess_volume(volume, norm_meta)
_logger.debug(f"Caching for index {idx}")
self._cache_map[idx] = True
self._buffer[idx] = volume
else:
_logger.debug(f"Using cached volume for index {idx}")
volume = self._buffer[idx]
sample = self._split_channels(volume)
if self.cpu_transform:
sample = self.cpu_transform(sample)
if not isinstance(sample, list):
sample = [sample]
return sample


class MmappedDataModule(GPUTransformDataModule, SelectWell):
"""Data module for cached OME-Zarr arrays.

Parameters
----------
data_path : Path
Path to the HCS OME-Zarr dataset.
channels : str | list[str]
Channel names to load.
batch_size : int
Batch size for training and validation.
num_workers : int
Number of workers for data-loaders.
split_ratio : float
Fraction of the FOVs used for the training split.
The rest will be used for validation.
preprocess_transforms : list[DictTransform]
Transforms applied once per volume before it is written to the cache
(e.g. normalization using precomputed metadata).
train_cpu_transforms : list[DictTransform]
Transforms to be applied on the CPU during training.
val_cpu_transforms : list[DictTransform]
Transforms to be applied on the CPU during validation.
train_gpu_transforms : list[DictTransform]
Transforms to be applied on the GPU during training.
val_gpu_transforms : list[DictTransform]
Transforms to be applied on the GPU during validation.
pin_memory : bool, optional
Use page-locked memory in data-loaders, by default True
prefetch_factor : int | None, optional
Number of batches loaded in advance by each dataloader worker, by default None
array_key : str, optional
Name of the image arrays (multiscales level), by default "0"
scratch_dir : Path | None, optional
Path to the scratch directory,
by default None (use OS temporary data directory)
include_wells : list[str] | None, optional
Include only a subset of wells, by default None (include all wells)
exclude_fovs : list[str] | None, optional
Exclude FOVs, by default None (do not exclude any FOVs)
"""

def __init__(
self,
data_path: Path,
channels: str | list[str],
batch_size: int,
num_workers: int,
split_ratio: float,
preprocess_transforms: list[DictTransform],
train_cpu_transforms: list[DictTransform],
val_cpu_transforms: list[DictTransform],
train_gpu_transforms: list[DictTransform],
val_gpu_transforms: list[DictTransform],
pin_memory: bool = True,
prefetch_factor: int | None = None,
array_key: str = "0",
scratch_dir: Path | None = None,
include_wells: list[str] | None = None,
exclude_fovs: list[str] | None = None,
):
super().__init__()
self.data_path = Path(data_path)
self.channels = _ensure_channel_list(channels)
self.batch_size = batch_size
self.num_workers = num_workers
self.split_ratio = split_ratio
self._preprocessing_transforms = Compose(preprocess_transforms)
self._train_cpu_transforms = Compose(train_cpu_transforms)
self._val_cpu_transforms = Compose(val_cpu_transforms)
self._train_gpu_transforms = Compose(train_gpu_transforms)
self._val_gpu_transforms = Compose(val_gpu_transforms)
self.pin_memory = pin_memory
self.array_key = array_key
self.scratch_dir = scratch_dir
self._include_wells = include_wells
self._exclude_fovs = exclude_fovs
self.prepare_data_per_node = True
self.prefetch_factor = prefetch_factor if self.num_workers > 0 else None

@property
def preprocessing_transforms(self) -> Compose:
return self._preprocessing_transforms

@property
def train_cpu_transforms(self) -> Compose:
return self._train_cpu_transforms

@property
def train_gpu_transforms(self) -> Compose:
return self._train_gpu_transforms

@property
def val_cpu_transforms(self) -> Compose:
return self._val_cpu_transforms

@property
def val_gpu_transforms(self) -> Compose:
return self._val_gpu_transforms

@property
def cache_dir(self) -> Path:
scratch_dir = self.scratch_dir or Path(tempfile.gettempdir())
cache_dir = Path(
scratch_dir,
os.getenv("SLURM_JOB_ID", "viscy_cache"),
str(
torch.distributed.get_rank()
if torch.distributed.is_initialized()
else 0
),
self.data_path.name,
)
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir

def _set_fit_global_state(self, num_positions: int) -> list[int]:
# disable metadata tracking in MONAI for performance
set_track_meta(False)
# shuffle positions, randomness is handled globally
return torch.randperm(num_positions).tolist()

def _buffer_shape(self, arr_shape, fovs) -> tuple[int, ...]:
return (len(fovs) * arr_shape[0], len(self.channels), *arr_shape[2:])

def setup(self, stage: Literal["fit", "validate"]) -> None:
if stage not in ("fit", "validate"):
raise NotImplementedError("Only fit and validate stages are supported.")
plate: Plate = open_ome_zarr(self.data_path, mode="r", layout="hcs")
positions = self._filter_fit_fovs(plate)
arr_shape = positions[0][self.array_key].shape
shuffled_indices = self._set_fit_global_state(len(positions))
num_train_fovs = int(len(positions) * self.split_ratio)
train_fovs = [positions[i] for i in shuffled_indices[:num_train_fovs]]
val_fovs = [positions[i] for i in shuffled_indices[num_train_fovs:]]
_logger.debug(f"Training FOVs: {[p.zgroup.name for p in train_fovs]}")
_logger.debug(f"Validation FOVs: {[p.zgroup.name for p in val_fovs]}")
train_buffer = MemoryMappedTensor.empty(
self._buffer_shape(arr_shape, train_fovs),
dtype=torch.float32,
filename=self.cache_dir / "train.mmap",
)
val_buffer = MemoryMappedTensor.empty(
self._buffer_shape(arr_shape, val_fovs),
dtype=torch.float32,
filename=self.cache_dir / "val.mmap",
)
cache_map_train = Manager().dict()
self.train_dataset = MmappedDataset(
positions=train_fovs,
channel_names=self.channels,
cache_map=cache_map_train,
buffer=train_buffer,
preprocess_transforms=self.preprocessing_transforms,
cpu_transform=self.train_cpu_transforms,
array_key=self.array_key,
)
cache_map_val = Manager().dict()
self.val_dataset = MmappedDataset(
positions=val_fovs,
channel_names=self.channels,
cache_map=cache_map_val,
buffer=val_buffer,
preprocess_transforms=self.preprocessing_transforms,
cpu_transform=self.val_cpu_transforms,
array_key=self.array_key,
)
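
The heart of the new module is the write-through cache in MmappedDataset.__getitem__: on the first epoch each (FOV, timepoint) volume is read from OME-Zarr, preprocessed, and written into a MemoryMappedTensor whose cache state is tracked in a multiprocessing Manager dict shared by all dataloader workers; later epochs read straight from the memory-mapped file. A reduced, self-contained sketch of the same pattern (the shapes and the random "volume" are placeholders for the real Zarr reads):

import tempfile
from pathlib import Path

import torch
from tensordict.memmap import MemoryMappedTensor
from torch.multiprocessing import Manager

cache_dir = Path(tempfile.mkdtemp())
# Real buffer shape is (FOVs * timepoints, channels, Z, Y, X); here 4 samples of (2, 8, 8).
buffer = MemoryMappedTensor.empty(
    (4, 2, 8, 8), dtype=torch.float32, filename=cache_dir / "train.mmap"
)
cache_map = Manager().dict({i: None for i in range(4)})  # shared across workers

def load_volume(idx: int) -> torch.Tensor:
    if not cache_map[idx]:
        volume = torch.rand(2, 8, 8)  # stand-in for reading and preprocessing the Zarr array
        buffer[idx] = volume          # persist to the memory-mapped file
        cache_map[idx] = True
        return volume
    return buffer[idx]                # cache hit: served from the mmap-backed buffer

first = load_volume(0)   # slow path: loads and caches
second = load_volume(0)  # fast path: returns the cached tensor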
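
Finally, a hedged usage sketch (not taken from the PR) of wiring MmappedDataModule together with generic MONAI dictionary transforms. The dataset path, channel names, well names, and scratch directory are hypothetical; in practice this configuration would normally live in a Lightning CLI YAML file, and the GPU transforms are exposed as properties for the training engine to apply (see the gpu-transform commits above) rather than run inside the dataloader:

from pathlib import Path

from monai.transforms import NormalizeIntensityd, RandFlipd

from viscy.data.mmap_cache import MmappedDataModule

channels = ["Phase3D", "Nuclei"]  # hypothetical channel names
dm = MmappedDataModule(
    data_path="plate.zarr",       # hypothetical HCS OME-Zarr plate
    channels=channels,
    batch_size=16,
    num_workers=8,
    split_ratio=0.8,
    preprocess_transforms=[NormalizeIntensityd(keys=channels)],
    train_cpu_transforms=[],
    val_cpu_transforms=[],
    train_gpu_transforms=[RandFlipd(keys=channels, prob=0.5, spatial_axis=-1)],
    val_gpu_transforms=[],
    prefetch_factor=4,
    scratch_dir=Path("/hpc/scratch"),  # hypothetical fast local storage
    include_wells=["A/1", "B/2"],      # hypothetical well names
)
dm.setup("fit")
train_loader = dm.train_dataloader()  # yields CPU-transformed batches from the mmap cache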