diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py new file mode 100644 index 000000000..3fff21fa1 --- /dev/null +++ b/viscy/transforms/__init__.py @@ -0,0 +1,31 @@ +from viscy.transforms._redef import ( + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, + ScaleIntensityRangePercentilesd, +) +from viscy.transforms._transforms import ( + BatchedZoom, + NormalizeSampled, + RandInvertIntensityd, + StackChannelsd, + TiledSpatialCropSamplesd, +) + +__all__ = [ + "BatchedZoom", + "NormalizeSampled", + "RandAdjustContrastd", + "RandAffined", + "RandGaussianNoised", + "RandGaussianSmoothd", + "RandInvertIntensityd", + "RandScaleIntensityd", + "RandWeightedCropd", + "ScaleIntensityRangePercentilesd", + "StackChannelsd", + "TiledSpatialCropSamplesd", +] diff --git a/viscy/transforms/_gaussian_blur.py b/viscy/transforms/_gaussian_blur.py new file mode 100644 index 000000000..85292e408 --- /dev/null +++ b/viscy/transforms/_gaussian_blur.py @@ -0,0 +1,73 @@ +"""3D version of `kornia.augmentation._2d.intensity.gaussian_blur`.""" + +from typing import Any, Iterable + +from kornia.augmentation import random_generator as rg +from kornia.augmentation._3d.intensity.base import IntensityAugmentationBase3D +from kornia.constants import BorderType +from kornia.filters import filter3d, get_gaussian_kernel3d +from monai.transforms import MapTransform, RandomizableTransform +from torch import Tensor + + +class RandomGaussianBlur(IntensityAugmentationBase3D): + def __init__( + self, + kernel_size: tuple[int, int, int] | int, + sigma: tuple[float, float, float] | Tensor, + border_type: str = "reflect", + same_on_batch: bool = False, + p: float = 0.5, + keepdim: bool = False, + ) -> None: + super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim) + + self.flags = { + "kernel_size": kernel_size, + "border_type": BorderType.get(border_type), + } + self._param_generator = rg.RandomGaussianBlurGenerator(sigma) + + def apply_transform( + self, + input: Tensor, + params: dict[str, Tensor], + flags: dict[str, Any], + transform: Tensor | None = None, + ) -> Tensor: + sigma = params["sigma"].unsqueeze(-1).expand(-1, 2) + kernel = get_gaussian_kernel3d( + kernel_size=self.flags["kernel_size"], sigma=sigma + ) + return filter3d(input, kernel, border_type=self.flags["border_type"]) + + +class BatchedRandGaussianBlurd(MapTransform, RandomizableTransform): + def __init__( + self, + keys: str | Iterable[str], + kernel_size: tuple[int, int] | int, + sigma: tuple[float, float], + border_type: str = "reflect", + same_on_batch: bool = False, + prob: float = 0.1, + allow_missing_keys: bool = False, + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys=allow_missing_keys) + RandomizableTransform.__init__(self, prob) + self.filter = RandomGaussianBlur( + kernel_size=kernel_size, + sigma=sigma, + border_type=border_type, + same_on_batch=same_on_batch, + p=prob, + ) + + def __call__(self, sample: dict[str, Tensor]) -> dict[str, Tensor]: + self.randomize(None) + if not self._do_transform: + return sample + for key in self.keys: + if key in sample: + sample[key] = -sample[key] + return sample diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py new file mode 100644 index 000000000..e41c27446 --- /dev/null +++ b/viscy/transforms/_redef.py @@ -0,0 +1,134 @@ +"""Redefine transforms from MONAI for jsonargparse.""" + +from typing import Sequence + +from monai.transforms import ( + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, + ScaleIntensityRangePercentilesd, +) +from numpy.typing import DTypeLike + + +class RandWeightedCropd(RandWeightedCropd): + def __init__( + self, + keys: Sequence[str] | str, + w_key: str, + spatial_size: Sequence[int], + num_samples: int = 1, + **kwargs, + ): + super().__init__( + keys=keys, + w_key=w_key, + spatial_size=spatial_size, + num_samples=num_samples, + **kwargs, + ) + + +class RandAffined(RandAffined): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + rotate_range: Sequence[float] | float, + shear_range: Sequence[float] | float, + scale_range: Sequence[float] | float, + **kwargs, + ): + super().__init__( + keys=keys, + prob=prob, + rotate_range=rotate_range, + shear_range=shear_range, + scale_range=scale_range, + **kwargs, + ) + + +class RandAdjustContrastd(RandAdjustContrastd): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + gamma: tuple[float, float] | float, + **kwargs, + ): + super().__init__(keys=keys, prob=prob, gamma=gamma, **kwargs) + + +class RandScaleIntensityd(RandScaleIntensityd): + def __init__( + self, + keys: Sequence[str] | str, + factors: tuple[float, float] | float, + prob: float, + **kwargs, + ): + super().__init__(keys=keys, factors=factors, prob=prob, **kwargs) + + +class RandGaussianNoised(RandGaussianNoised): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + mean: float, + std: float, + **kwargs, + ): + super().__init__(keys=keys, prob=prob, mean=mean, std=std, **kwargs) + + +class RandGaussianSmoothd(RandGaussianSmoothd): + def __init__( + self, + keys: Sequence[str] | str, + prob: float, + sigma_x: tuple[float, float] | float, + sigma_y: tuple[float, float] | float, + sigma_z: tuple[float, float] | float, + **kwargs, + ): + super().__init__( + keys=keys, + prob=prob, + sigma_x=sigma_x, + sigma_y=sigma_y, + sigma_z=sigma_z, + **kwargs, + ) + + +class ScaleIntensityRangePercentilesd(ScaleIntensityRangePercentilesd): + def __init__( + self, + keys: Sequence[str] | str, + lower: float, + upper: float, + b_min: float | None, + b_max: float | None, + clip: bool = False, + relative: bool = False, + channel_wise: bool = False, + dtype: DTypeLike | None = None, + allow_missing_keys: bool = False, + ): + super().__init__( + keys=keys, + lower=lower, + upper=upper, + b_min=b_min, + b_max=b_max, + clip=clip, + relative=relative, + channel_wise=channel_wise, + dtype=dtype, + allow_missing_keys=allow_missing_keys, + ) diff --git a/viscy/transforms.py b/viscy/transforms/_transforms.py similarity index 60% rename from viscy/transforms.py rename to viscy/transforms/_transforms.py index f4e8103bc..0b99a2ac4 100644 --- a/viscy/transforms.py +++ b/viscy/transforms/_transforms.py @@ -1,20 +1,9 @@ -"""Redefine transforms from MONAI for jsonargparse.""" - -from typing import Sequence, Union - import numpy as np import torch from monai.transforms import ( MapTransform, MultiSampleTrait, - RandAdjustContrastd, - RandAffined, - RandGaussianNoised, - RandGaussianSmoothd, RandomizableTransform, - RandScaleIntensityd, - RandWeightedCropd, - ScaleIntensityRangePercentilesd, Transform, ) from torch import Tensor @@ -23,149 +12,13 @@ from viscy.data.typing import ChannelMap, Sample -class RandWeightedCropd(RandWeightedCropd): - def __init__( - self, - keys: Union[Sequence[str], str], - w_key: str, - spatial_size: Sequence[int], - num_samples: int = 1, - **kwargs, - ): - super().__init__( - keys=keys, - w_key=w_key, - spatial_size=spatial_size, - num_samples=num_samples, - **kwargs, - ) - - -class RandAffined(RandAffined): - def __init__( - self, - keys: Union[Sequence[str], str], - prob: float, - rotate_range: Union[Sequence[float], float], - shear_range: Union[Sequence[float], float], - scale_range: Union[Sequence[float], float], - **kwargs, - ): - super().__init__( - keys=keys, - prob=prob, - rotate_range=rotate_range, - shear_range=shear_range, - scale_range=scale_range, - **kwargs, - ) - - -class RandAdjustContrastd(RandAdjustContrastd): - def __init__( - self, - keys: Union[Sequence[str], str], - prob: float, - gamma: Union[Sequence[float], float], - **kwargs, - ): - super().__init__( - keys=keys, - prob=prob, - gamma=gamma, - **kwargs, - ) - - -class RandScaleIntensityd(RandScaleIntensityd): - def __init__( - self, - keys: Union[Sequence[str], str], - factors: Union[Sequence[float], float], - prob: float, - **kwargs, - ): - super().__init__( - keys=keys, - factors=factors, - prob=prob, - **kwargs, - ) - - -class RandGaussianNoised(RandGaussianNoised): - def __init__( - self, - keys: Union[Sequence[str], str], - prob: float, - mean: Union[Sequence[float], float], - std: Union[Sequence[float], float], - **kwargs, - ): - super().__init__( - keys=keys, - prob=prob, - mean=mean, - std=std, - **kwargs, - ) - - -class RandGaussianSmoothd(RandGaussianSmoothd): - def __init__( - self, - keys: Union[Sequence[str], str], - prob: float, - sigma_x: Union[Sequence[float], float], - sigma_y: Union[Sequence[float], float], - sigma_z: Union[Sequence[float], float], - **kwargs, - ): - super().__init__( - keys=keys, - prob=prob, - sigma_x=sigma_x, - sigma_y=sigma_y, - sigma_z=sigma_z, - **kwargs, - ) - - -class ScaleIntensityRangePercentilesd(ScaleIntensityRangePercentilesd): - def __init__( - self, - keys: Union[Sequence[str], str], - lower: float, - upper: float, - b_min: float | None, - b_max: float | None, - clip: bool = False, - relative: bool = False, - channel_wise: bool = False, - dtype: Union[Sequence[str], str] = None, - allow_missing_keys: bool = False, - ): - super().__init__( - keys=keys, - lower=lower, - upper=upper, - b_min=b_min, - b_max=b_max, - clip=clip, - relative=relative, - channel_wise=channel_wise, - dtype=dtype, - allow_missing_keys=allow_missing_keys, - ) - - class NormalizeSampled(MapTransform): """ Normalize the sample. Parameters ---------- - keys : Union[str, Iterable[str]] + keys : str | Iterable[str] Keys to normalize. level : {'fov_statistics', 'dataset_statistics'} Level of normalization. @@ -179,7 +32,7 @@ class NormalizeSampled(MapTransform): def __init__( self, - keys: Union[str, Iterable[str]], + keys: str | Iterable[str], level: Literal["fov_statistics", "dataset_statistics"], subtrahend="mean", divisor="std", @@ -213,7 +66,7 @@ class RandInvertIntensityd(MapTransform, RandomizableTransform): def __init__( self, - keys: Union[str, Iterable[str]], + keys: str | Iterable[str], prob: float = 0.1, allow_missing_keys: bool = False, ) -> None: @@ -238,7 +91,7 @@ class TiledSpatialCropSamplesd(MapTransform, MultiSampleTrait): def __init__( self, - keys: Union[str, Iterable[str]], + keys: str | Iterable[str], roi_size: tuple[int, int, int], num_samples: int, ) -> None: