Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/configs/fit_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ data:
batch_size: 32
num_workers: 16
yx_patch_size: [256, 256]
normalizations:
pyramid_resolution: "0"
normalizations:
- class_path: viscy.transforms.NormalizeSampled
init_args:
keys: [source]
Expand Down Expand Up @@ -92,3 +93,4 @@ data:
sigma_y: [0.25, 1.5]
sigma_x: [0.25, 1.5]
caching: false

1 change: 1 addition & 0 deletions examples/configs/predict_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ predict:
- 256
- 256
caching: false
pyramid_resolution: "0"
return_predictions: false
ckpt_path: null
2 changes: 2 additions & 0 deletions examples/configs/test_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,7 @@ data:
- 256
caching: false
ground_truth_masks: null
pyramid_resolution: "0"
ckpt_path: null
verbose: true
a
18 changes: 16 additions & 2 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,22 @@ class SlidingWindowDataset(Dataset):
:param ChannelMap channels: source and target channel names,
e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
:param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
:param str array_key:
Name of the image arrays (multiscales level), by default "0"
:param DictTransform | None transform:
a callable that transforms data, defaults to None
:param bool load_normalization_metadata:
whether to load normalization metadata, defaults to True
"""

def __init__(
self,
positions: list[Position],
channels: ChannelMap,
z_window_size: int,
array_key: str = "0",
transform: DictTransform | None = None,
load_normalization_metadata: bool = True,
) -> None:
super().__init__()
self.positions = positions
Expand All @@ -128,7 +134,9 @@ def __init__(
)
self.z_window_size = z_window_size
self.transform = transform
self.array_key = array_key
self._get_windows()
self.load_normalization_metadata = load_normalization_metadata

def _get_windows(self) -> None:
"""Count the sliding windows along T and Z,
Expand All @@ -138,7 +146,7 @@ def _get_windows(self) -> None:
self.window_arrays = []
self.window_norm_meta: list[NormMeta | None] = []
for fov in self.positions:
img_arr: ImageArray = fov["0"]
img_arr: ImageArray = fov[str(self.array_key)]
ts = img_arr.frames
zs = img_arr.slices - self.z_window_size + 1
if zs < 1:
Expand Down Expand Up @@ -225,10 +233,11 @@ def __getitem__(self, index: int) -> Sample:
sample = {
"index": sample_index,
"source": self._stack_channels(sample_images, "source"),
"norm_meta": norm_meta,
}
if self.target_ch_idx is not None:
sample["target"] = self._stack_channels(sample_images, "target")
if self.load_normalization_metadata:
sample["norm_meta"] = norm_meta
return sample


Expand Down Expand Up @@ -326,6 +335,8 @@ class HCSDataModule(LightningDataModule):
prefetch_factor : int or None, optional
Number of samples loaded in advance by each worker during fitting,
defaults to None (2 per PyTorch default).
array_key : str, optional
Name of the image arrays (multiscales level), by default "0"
"""

def __init__(
Expand All @@ -345,6 +356,7 @@ def __init__(
ground_truth_masks: Path | None = None,
persistent_workers=False,
prefetch_factor=None,
array_key: str = "0",
):
super().__init__()
self.data_path = Path(data_path)
Expand All @@ -363,6 +375,7 @@ def __init__(
self.prepare_data_per_node = True
self.persistent_workers = persistent_workers
self.prefetch_factor = prefetch_factor
self.array_key = array_key

@property
def cache_path(self):
Expand Down Expand Up @@ -419,6 +432,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]:
return {
"channels": {"source": self.source_channel},
"z_window_size": self.z_window_size,
"array_key": self.array_key,
}

def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
Expand Down