Skip to content

Commit

Permalink
Merge branch 'fcmae' into normalization_roi
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Feb 24, 2024
2 parents 61f9a9f + 78aed97 commit 154ad31
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 11 deletions.
1 change: 1 addition & 0 deletions examples/configs/fit_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ model:
loss_function: null
lr: 0.001
schedule: Constant
ckpt_path: null
log_batches_per_epoch: 8
log_samples_per_batch: 1
data:
Expand Down
Empty file added tests/data/__init__.py
Empty file.
File renamed without changes.
5 changes: 1 addition & 4 deletions tests/light/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,5 @@

def test_fcmae_vsunet() -> None:
model = FcmaeUNet(
architecture="fcmae",
model_config=dict(in_channels=3),
train_mask_ratio=0.6,
model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6
)

14 changes: 7 additions & 7 deletions tests/unet/test_fcmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_generate_mask():
w = 64
s = 16
m = 0.75
mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m)
mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m, device="cpu")
assert mask.shape == (2, 1, w // s, w // s)
assert mask.dtype == torch.bool
ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0]
Expand All @@ -28,7 +28,7 @@ def test_masked_patchify():
b, c, h, w = 2, 3, 4, 8
x = torch.rand(b, c, h, w)
mask_ratio = 0.75
mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio)
mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio, device=x.device)
mask = upsample_mask(mask, x.shape)
feat = masked_patchify(x, ~mask)
assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c)
Expand All @@ -42,7 +42,7 @@ def test_unmasked_patchify_roundtrip():

def test_masked_patchify_roundtrip():
x = torch.rand(2, 3, 4, 8)
mask = generate_mask(x.shape, stride=2, mask_ratio=0.5)
mask = generate_mask(x.shape, stride=2, mask_ratio=0.5, device=x.device)
mask = upsample_mask(mask, x.shape)
y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask)
assert torch.all((y == 0) ^ (x == y))
Expand All @@ -51,7 +51,7 @@ def test_masked_patchify_roundtrip():

def test_masked_convnextv2_block() -> None:
x = torch.rand(2, 3, 4, 5)
mask = generate_mask(x.shape, stride=1, mask_ratio=0.5)
mask = generate_mask(x.shape, stride=1, mask_ratio=0.5, device=x.device)
block = MaskedConvNeXtV2Block(3, 3 * 2)
unmasked_out = block(x)
assert len(unmasked_out.unique()) == x.numel() * 2
Expand All @@ -65,7 +65,7 @@ def test_masked_convnextv2_block() -> None:

def test_masked_convnextv2_stage():
x = torch.rand(2, 3, 16, 16)
mask = generate_mask(x.shape, stride=4, mask_ratio=0.5)
mask = generate_mask(x.shape, stride=4, mask_ratio=0.5, device=x.device)
stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2)
out = stage(x)
assert out.shape == (2, 3, 8, 8)
Expand All @@ -79,7 +79,7 @@ def test_adaptive_projection():
)
assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2)
assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4)
mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6)
mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6, device="cpu")
masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask)
assert masked_out.shape == (1, 12, 4, 4)
proj = MaskedAdaptiveProjection(
Expand All @@ -106,7 +106,7 @@ def test_masked_multiscale_encoder():

def test_fcmae():
x = torch.rand(2, 3, 5, 128, 128)
model = FullyConvolutionalMAE(3)
model = FullyConvolutionalMAE(3, 3)
y, m = model(x)
assert y.shape == x.shape
assert m is None
Expand Down
75 changes: 75 additions & 0 deletions viscy/data/ctmc_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from pathlib import Path

from iohub.ngff import open_ome_zarr
from lightning.pytorch import LightningDataModule
from monai.transforms import Compose, MapTransform
from torch.utils.data import DataLoader

from viscy.data.hcs import ChannelMap, SlidingWindowDataset


class CTMCv1DataModule(LightningDataModule):
"""
Autoregression data module for the CTMCv1 dataset.
Training and validation datasets are stored in separate HCS OME-Zarr stores.
"""

def __init__(
self,
train_data_path: str | Path,
val_data_path: str | Path,
train_transforms: list[MapTransform],
val_transforms: list[MapTransform],
batch_size: int = 16,
num_workers: int = 8,
channel_name: str = "DIC",
) -> None:
super().__init__()
self.train_data_path = train_data_path
self.val_data_path = val_data_path
self.train_transforms = train_transforms
self.val_transforms = val_transforms
self.channel_map = ChannelMap(source=channel_name, target=channel_name)
self.batch_size = batch_size
self.num_workers = num_workers

def setup(self, stage: str) -> None:
if stage != "fit":
raise NotImplementedError("Only fit stage is supported")
self._setup_fit()

def _setup_fit(self) -> None:
train_plate = open_ome_zarr(self.train_data_path)
val_plate = open_ome_zarr(self.val_data_path)
train_positions = [p for _, p in train_plate.positions()]
val_positions = [p for _, p in val_plate.positions()]
self.train_dataset = SlidingWindowDataset(
train_positions,
channels=self.channel_map,
z_window_size=1,
transform=Compose(self.train_transform),
)
self.val_dataset = SlidingWindowDataset(
val_positions,
channels=self.channel_map,
z_window_size=1,
transform=Compose(self.val_transform),
)

def train_dataloader(self) -> DataLoader:
return DataLoader(
self.train_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=bool(self.num_workers),
shuffle=True,
)

def val_dataloader(self) -> DataLoader:
return DataLoader(
self.val_dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
persistent_workers=bool(self.num_workers),
shuffle=False,
)
6 changes: 6 additions & 0 deletions viscy/light/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class VSUNet(LightningModule):
:param float lr: learning rate in training, defaults to 1e-3
:param Literal['WarmupCosine', 'Constant'] schedule:
learning rate scheduler, defaults to "Constant"
:param str chkpt_path: path to the checkpoint to load weights, defaults to None
:param int log_batches_per_epoch:
number of batches to log each training/validation epoch,
has to be smaller than steps per epoch, defaults to 8
Expand All @@ -124,6 +125,7 @@ def __init__(
lr: float = 1e-3,
schedule: Literal["WarmupCosine", "Constant"] = "Constant",
freeze_encoder: bool = False,
ckpt_path: str = None,
log_batches_per_epoch: int = 8,
log_samples_per_batch: int = 1,
example_input_yx_shape: Sequence[int] = (256, 256),
Expand Down Expand Up @@ -164,6 +166,10 @@ def __init__(
self.test_cellpose_diameter = test_cellpose_diameter
self.test_evaluate_cellpose = test_evaluate_cellpose
self.freeze_encoder = freeze_encoder
if ckpt_path is not None:
self.load_state_dict(
torch.load(ckpt_path)["state_dict"]
) # loading only weights

def forward(self, x: Tensor) -> Tensor:
return self.model(x)
Expand Down

0 comments on commit 154ad31

Please sign in to comment.