diff --git a/viscy/transforms/__init__.py b/viscy/transforms/__init__.py index 3fff21fa1..1afb6a81c 100644 --- a/viscy/transforms/__init__.py +++ b/viscy/transforms/__init__.py @@ -1,9 +1,11 @@ from viscy.transforms._redef import ( + CenterSpatialCropd, RandAdjustContrastd, RandAffined, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, + RandSpatialCropd, RandWeightedCropd, ScaleIntensityRangePercentilesd, ) @@ -17,6 +19,7 @@ __all__ = [ "BatchedZoom", + "CenterSpatialCropd", "NormalizeSampled", "RandAdjustContrastd", "RandAffined", @@ -24,6 +27,7 @@ "RandGaussianSmoothd", "RandInvertIntensityd", "RandScaleIntensityd", + "RandSpatialCropd", "RandWeightedCropd", "ScaleIntensityRangePercentilesd", "StackChannelsd", diff --git a/viscy/transforms/_redef.py b/viscy/transforms/_redef.py index e41c27446..a9929d495 100644 --- a/viscy/transforms/_redef.py +++ b/viscy/transforms/_redef.py @@ -3,11 +3,13 @@ from typing import Sequence from monai.transforms import ( + CenterSpatialCropd, RandAdjustContrastd, RandAffined, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, + RandSpatialCropd, RandWeightedCropd, ScaleIntensityRangePercentilesd, ) @@ -37,9 +39,9 @@ def __init__( self, keys: Sequence[str] | str, prob: float, - rotate_range: Sequence[float] | float, - shear_range: Sequence[float] | float, - scale_range: Sequence[float] | float, + rotate_range: Sequence[float | Sequence[float]] | float, + shear_range: Sequence[float | Sequence[float]] | float, + scale_range: Sequence[float | Sequence[float]] | float, **kwargs, ): super().__init__( @@ -132,3 +134,29 @@ def __init__( dtype=dtype, allow_missing_keys=allow_missing_keys, ) + + +class RandSpatialCropd(RandSpatialCropd): + def __init__( + self, + keys: Sequence[str] | str, + roi_size: Sequence[int] | int, + random_center: bool = True, + **kwargs, + ): + super().__init__( + keys=keys, + roi_size=roi_size, + random_center=random_center, + **kwargs, + ) + + +class CenterSpatialCropd(CenterSpatialCropd): + def __init__( + self, + keys: Sequence[str] | str, + roi_size: Sequence[int] | int, + **kwargs, + ): + super().__init__(keys=keys, roi_size=roi_size, **kwargs)