diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml
index 511daf5a..fd17071e 100644
--- a/examples/configs/fit_example.yml
+++ b/examples/configs/fit_example.yml
@@ -23,6 +23,7 @@ model:
     loss_function: null
     lr: 0.001
     schedule: Constant
+    ckpt_path: null
     log_batches_per_epoch: 8
     log_samples_per_batch: 1
 data:
diff --git a/tests/data/__init__.py b/tests/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/light/test_data.py b/tests/data/test_data.py
similarity index 100%
rename from tests/light/test_data.py
rename to tests/data/test_data.py
diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py
index c6013365..9ce182f5 100644
--- a/tests/light/test_engine.py
+++ b/tests/light/test_engine.py
@@ -3,8 +3,5 @@
 
 def test_fcmae_vsunet() -> None:
     model = FcmaeUNet(
-        architecture="fcmae",
-        model_config=dict(in_channels=3),
-        train_mask_ratio=0.6,
+        model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6
     )
-
diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py
index 36fb673e..4ed441b4 100644
--- a/tests/unet/test_fcmae.py
+++ b/tests/unet/test_fcmae.py
@@ -17,7 +17,7 @@ def test_generate_mask():
     w = 64
     s = 16
     m = 0.75
-    mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m)
+    mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m, device="cpu")
     assert mask.shape == (2, 1, w // s, w // s)
     assert mask.dtype == torch.bool
     ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0]
@@ -28,7 +28,7 @@ def test_masked_patchify():
     b, c, h, w = 2, 3, 4, 8
     x = torch.rand(b, c, h, w)
     mask_ratio = 0.75
-    mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio)
+    mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio, device=x.device)
     mask = upsample_mask(mask, x.shape)
     feat = masked_patchify(x, ~mask)
     assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c)
@@ -42,7 +42,7 @@ def test_unmasked_patchify_roundtrip():
 
 def test_masked_patchify_roundtrip():
     x = torch.rand(2, 3, 4, 8)
-    mask = generate_mask(x.shape, stride=2, mask_ratio=0.5)
+    mask = generate_mask(x.shape, stride=2, mask_ratio=0.5, device=x.device)
     mask = upsample_mask(mask, x.shape)
     y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask)
     assert torch.all((y == 0) ^ (x == y))
@@ -51,7 +51,7 @@ def test_masked_patchify_roundtrip():
 
 def test_masked_convnextv2_block() -> None:
     x = torch.rand(2, 3, 4, 5)
-    mask = generate_mask(x.shape, stride=1, mask_ratio=0.5)
+    mask = generate_mask(x.shape, stride=1, mask_ratio=0.5, device=x.device)
     block = MaskedConvNeXtV2Block(3, 3 * 2)
     unmasked_out = block(x)
     assert len(unmasked_out.unique()) == x.numel() * 2
@@ -65,7 +65,7 @@ def test_masked_convnextv2_block() -> None:
 
 def test_masked_convnextv2_stage():
     x = torch.rand(2, 3, 16, 16)
-    mask = generate_mask(x.shape, stride=4, mask_ratio=0.5)
+    mask = generate_mask(x.shape, stride=4, mask_ratio=0.5, device=x.device)
     stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2)
     out = stage(x)
     assert out.shape == (2, 3, 8, 8)
@@ -79,7 +79,7 @@ def test_adaptive_projection():
     )
     assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2)
     assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4)
-    mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6)
+    mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6, device="cpu")
     masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask)
     assert masked_out.shape == (1, 12, 4, 4)
     proj = MaskedAdaptiveProjection(
@@ -106,7 +106,7 @@ def test_masked_multiscale_encoder():
 
 def test_fcmae():
     x = torch.rand(2, 3, 5, 128, 128)
-    model = FullyConvolutionalMAE(3)
+    model = FullyConvolutionalMAE(3, 3)
     y, m = model(x)
     assert y.shape == x.shape
     assert m is None
diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py
new file mode 100644
index 00000000..0d65a36a
--- /dev/null
+++ b/viscy/data/ctmc_v1.py
@@ -0,0 +1,90 @@
+from pathlib import Path
+
+from iohub.ngff import open_ome_zarr
+from lightning.pytorch import LightningDataModule
+from monai.transforms import Compose, MapTransform
+from torch.utils.data import DataLoader
+
+from viscy.data.hcs import ChannelMap, SlidingWindowDataset
+
+
+class CTMCv1DataModule(LightningDataModule):
+    """
+    Autoregression data module for the CTMCv1 dataset.
+    Training and validation datasets are stored in separate HCS OME-Zarr stores.
+
+    :param str | Path train_data_path: path to the training HCS OME-Zarr store
+    :param str | Path val_data_path: path to the validation HCS OME-Zarr store
+    :param list[MapTransform] train_transforms:
+        transforms composed and applied to training samples
+    :param list[MapTransform] val_transforms:
+        transforms composed and applied to validation samples
+    :param int batch_size: batch size of both data loaders, defaults to 16
+    :param int num_workers: number of data loading workers, defaults to 8
+    :param str channel_name:
+        name of the channel used as both source and target, defaults to "DIC"
+    """
+
+    def __init__(
+        self,
+        train_data_path: str | Path,
+        val_data_path: str | Path,
+        train_transforms: list[MapTransform],
+        val_transforms: list[MapTransform],
+        batch_size: int = 16,
+        num_workers: int = 8,
+        channel_name: str = "DIC",
+    ) -> None:
+        super().__init__()
+        self.train_data_path = train_data_path
+        self.val_data_path = val_data_path
+        self.train_transforms = train_transforms
+        self.val_transforms = val_transforms
+        self.channel_map = ChannelMap(source=channel_name, target=channel_name)
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def setup(self, stage: str) -> None:
+        """Set up the datasets; only the "fit" stage is implemented."""
+        if stage != "fit":
+            raise NotImplementedError("Only fit stage is supported")
+        self._setup_fit()
+
+    def _setup_fit(self) -> None:
+        """Build training and validation datasets from all positions."""
+        train_plate = open_ome_zarr(self.train_data_path)
+        val_plate = open_ome_zarr(self.val_data_path)
+        train_positions = [p for _, p in train_plate.positions()]
+        val_positions = [p for _, p in val_plate.positions()]
+        self.train_dataset = SlidingWindowDataset(
+            train_positions,
+            channels=self.channel_map,
+            z_window_size=1,
+            transform=Compose(self.train_transforms),
+        )
+        self.val_dataset = SlidingWindowDataset(
+            val_positions,
+            channels=self.channel_map,
+            z_window_size=1,
+            transform=Compose(self.val_transforms),
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        """Return the shuffled training data loader."""
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=bool(self.num_workers),
+            shuffle=True,
+        )
+
+    def val_dataloader(self) -> DataLoader:
+        """Return the unshuffled validation data loader."""
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=bool(self.num_workers),
+            shuffle=False,
+        )
diff --git a/viscy/light/engine.py b/viscy/light/engine.py
index c6197a9a..4d18e9c4 100644
--- a/viscy/light/engine.py
+++ b/viscy/light/engine.py
@@ -98,6 +98,7 @@ class VSUNet(LightningModule):
     :param float lr: learning rate in training, defaults to 1e-3
     :param Literal['WarmupCosine', 'Constant'] schedule:
         learning rate scheduler, defaults to "Constant"
+    :param str ckpt_path: path to the checkpoint to load weights, defaults to None
     :param int log_batches_per_epoch:
         number of batches to log each training/validation epoch,
         has to be smaller than steps per epoch, defaults to 8
@@ -124,6 +125,7 @@ def __init__(
         lr: float = 1e-3,
         schedule: Literal["WarmupCosine", "Constant"] = "Constant",
         freeze_encoder: bool = False,
+        ckpt_path: str | None = None,
         log_batches_per_epoch: int = 8,
         log_samples_per_batch: int = 1,
         example_input_yx_shape: Sequence[int] = (256, 256),
@@ -164,6 +166,10 @@ def __init__(
         self.test_cellpose_diameter = test_cellpose_diameter
         self.test_evaluate_cellpose = test_evaluate_cellpose
         self.freeze_encoder = freeze_encoder
+        if ckpt_path is not None:
+            self.load_state_dict(
+                torch.load(ckpt_path)["state_dict"]
+            )  # loading only weights
 
     def forward(self, x: Tensor) -> Tensor:
         return self.model(x)