Skip to content

Commit

Permalink
re-define cropping transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
ziw-liu committed Jan 27, 2025
1 parent e05d802 commit ca8095d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
4 changes: 4 additions & 0 deletions viscy/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from viscy.transforms._redef import (
CenterSpatialCropd,
RandAdjustContrastd,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandScaleIntensityd,
RandSpatialCropd,
RandWeightedCropd,
ScaleIntensityRangePercentilesd,
)
Expand All @@ -17,13 +19,15 @@

__all__ = [
"BatchedZoom",
"CenterSpatialCropd",
"NormalizeSampled",
"RandAdjustContrastd",
"RandAffined",
"RandGaussianNoised",
"RandGaussianSmoothd",
"RandInvertIntensityd",
"RandScaleIntensityd",
"RandSpatialCropd",
"RandWeightedCropd",
"ScaleIntensityRangePercentilesd",
"StackChannelsd",
Expand Down
34 changes: 31 additions & 3 deletions viscy/transforms/_redef.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from typing import Sequence

from monai.transforms import (
CenterSpatialCropd,
RandAdjustContrastd,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandScaleIntensityd,
RandSpatialCropd,
RandWeightedCropd,
ScaleIntensityRangePercentilesd,
)
Expand Down Expand Up @@ -37,9 +39,9 @@ def __init__(
self,
keys: Sequence[str] | str,
prob: float,
rotate_range: Sequence[float] | float,
shear_range: Sequence[float] | float,
scale_range: Sequence[float] | float,
rotate_range: Sequence[float | Sequence[float]] | float,
shear_range: Sequence[float | Sequence[float]] | float,
scale_range: Sequence[float | Sequence[float]] | float,
**kwargs,
):
super().__init__(
Expand Down Expand Up @@ -132,3 +134,29 @@ def __init__(
dtype=dtype,
allow_missing_keys=allow_missing_keys,
)


class RandSpatialCropd(RandSpatialCropd):
def __init__(
self,
keys: Sequence[str] | str,
roi_size: Sequence[int] | int,
random_center: bool = True,
**kwargs,
):
super().__init__(
keys=keys,
roi_size=roi_size,
random_center=random_center,
**kwargs,
)


class CenterSpatialCropd(CenterSpatialCropd):
def __init__(
self,
keys: Sequence[str] | str,
roi_size: Sequence[int] | int,
**kwargs,
):
super().__init__(keys=keys, roi_size=roi_size, **kwargs)

0 comments on commit ca8095d

Please sign in to comment.