Skip to content
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
RandGaussianSharpen,
RandGaussianSmooth,
RandHistogramShift,
RandRicianNoise,
RandScaleIntensity,
RandShiftIntensity,
RandStdShiftIntensity,
Expand Down Expand Up @@ -123,6 +124,9 @@
RandHistogramShiftd,
RandHistogramShiftD,
RandHistogramShiftDict,
RandRicianNoised,
RandRicianNoiseD,
RandRicianNoiseDict,
RandScaleIntensityd,
RandScaleIntensityD,
RandScaleIntensityDict,
Expand Down
85 changes: 84 additions & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.utils import rescale_array
from monai.utils import PT_BEFORE_1_7, InvalidPyTorchVersionError, dtype_torch_to_numpy, ensure_tuple_size
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
dtype_torch_to_numpy,
ensure_tuple_rep,
ensure_tuple_size,
)

__all__ = [
"RandGaussianNoise",
"RandRicianNoise",
"ShiftIntensity",
"RandShiftIntensity",
"StdShiftIntensity",
Expand Down Expand Up @@ -85,6 +92,82 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor,
return img + self._noise.astype(dtype)


class RandRicianNoise(RandomizableTransform):
"""
Add Rician noise to image.
Rician noise in MRI is the result of performing a magnitude operation on complex
data with Gaussian noise of the same variance in both channels, as described in `Noise in Magnitude Magnetic Resonance Images
<https://doi.org/10.1002/cmr.a.20124>`_. This transform is adapted from
`DIPY<https://github.com/dipy/dipy>`_. See also: `The rician distribution of noisy mri data
<https://doi.org/10.1002/mrm.1910340618>`_.

Args:
prob: Probability to add Rician noise.
mean: Mean or "centre" of the Gaussian distributions sampled to make up
the Rician noise.
std: Standard deviation (spread) of the Gaussian distributions sampled
to make up the Rician noise.
channel_wise: If True, treats each channel of the image separately.
relative: If True, the spread of the sampled Gaussian distributions will
be std times the standard deviation of the image or channel's intensity
histogram.
sample_std: If True, sample the spread of the Gaussian distributions
uniformly from 0 to std.
"""

def __init__(
self,
prob: float = 0.1,
mean: Union[Sequence[float], float] = 0.0,
std: Union[Sequence[float], float] = 1.0,
channel_wise: bool = False,
relative: bool = False,
sample_std: bool = True,
) -> None:
RandomizableTransform.__init__(self, prob)
self.prob = prob
self.mean = mean
self.std = std
self.channel_wise = channel_wise
self.relative = relative
self.sample_std = sample_std
self._noise1 = None
self._noise2 = None

def _add_noise(self, img: Union[torch.Tensor, np.ndarray], mean: float, std: float):
im_shape = img.shape
_std = self.R.uniform(0, std) if self.sample_std else std
self._noise1 = self.R.normal(mean, _std, size=im_shape)
self._noise2 = self.R.normal(mean, _std, size=im_shape)
if self._noise1 is None or self._noise2 is None:
raise AssertionError
dtype = dtype_torch_to_numpy(img.dtype) if isinstance(img, torch.Tensor) else img.dtype
return np.sqrt((img + self._noise1.astype(dtype)) ** 2 + self._noise2.astype(dtype) ** 2)

def __call__(self, img: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
"""
Apply the transform to `img`.
"""
super().randomize(None)
if not self._do_transform:
return img
if self.channel_wise:
_mean = ensure_tuple_rep(self.mean, len(img))
_std = ensure_tuple_rep(self.std, len(img))
for i, d in enumerate(img):
img[i] = self._add_noise(d, mean=_mean[i], std=_std[i] * d.std() if self.relative else _std[i])
else:
if not isinstance(self.mean, (int, float)):
raise AssertionError("If channel_wise is False, mean must be a float or int number.")
if not isinstance(self.std, (int, float)):
raise AssertionError("If channel_wise is False, std must be a float or int number.")
std = self.std * img.std() if self.relative else self.std
if not isinstance(std, (int, float)):
raise AssertionError
img = self._add_noise(img, mean=self.mean, std=std)
return img


class ShiftIntensity(Transform):
"""
Shift intensity uniformly for the entire image with specified `offset`.
Expand Down
62 changes: 62 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
MaskIntensity,
NormalizeIntensity,
RandBiasField,
RandRicianNoise,
ScaleIntensity,
ScaleIntensityRange,
ScaleIntensityRangePercentiles,
Expand All @@ -41,6 +42,7 @@

__all__ = [
"RandGaussianNoised",
"RandRicianNoised",
"ShiftIntensityd",
"RandShiftIntensityd",
"ScaleIntensityd",
Expand Down Expand Up @@ -152,6 +154,65 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
return d


class RandRicianNoised(RandomizableTransform, MapTransform):
"""
Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`.
Add Rician noise to image. This transform assumes all the expected fields have same shape.

Args:
keys: Keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
global_prob: Probability to add Rician noise to the dictionary.
prob: Probability to add Rician noise to each item in the dictionary,
once asserted that noise will be added to the dictionary at all.
mean: Mean or "centre" of the Gaussian distributions sampled to make up
the Rician noise.
std: Standard deviation (spread) of the Gaussian distributions sampled
to make up the Rician noise.
channel_wise: If True, treats each channel of the image separately.
relative: If True, the spread of the sampled Gaussian distributions will
be std times the standard deviation of the image or channel's intensity
histogram.
sample_std: If True, sample the spread of the Gaussian distributions
uniformly from 0 to std.
allow_missing_keys: Don't raise exception if key is missing.
"""

def __init__(
self,
keys: KeysCollection,
global_prob: float = 0.1,
prob: float = 1.0,
mean: Union[Sequence[float], float] = 0.0,
std: Union[Sequence[float], float] = 1.0,
channel_wise: bool = False,
relative: bool = False,
sample_std: bool = True,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
RandomizableTransform.__init__(self, global_prob)
self.rand_rician_noise = RandRicianNoise(
prob,
mean,
std,
channel_wise,
relative,
sample_std,
)

def __call__(
self, data: Mapping[Hashable, Union[torch.Tensor, np.ndarray]]
) -> Dict[Hashable, Union[torch.Tensor, np.ndarray]]:
d = dict(data)
super().randomize(None)
if not self._do_transform:
return d
for key in self.key_iterator(d):
d[key] = self.rand_rician_noise(d[key])
return d


class ShiftIntensityd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ShiftIntensity`.
Expand Down Expand Up @@ -958,6 +1019,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda


RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised
RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised
ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd
RandShiftIntensityD = RandShiftIntensityDict = RandShiftIntensityd
StdShiftIntensityD = StdShiftIntensityDict = StdShiftIntensityd
Expand Down
56 changes: 56 additions & 0 deletions tests/test_rand_rician_noise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import RandRicianNoise
from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D


class TestRandRicianNoise(NumpyImageTestCase2D):
@parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)])
def test_correct_results(self, _, mean, std):
seed = 0
rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)
rician_fn.set_random_state(seed)
noised = rician_fn(self.imt)
np.random.seed(seed)
np.random.random()
_std = np.random.uniform(0, std)
expected = np.sqrt(
(self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
+ np.random.normal(mean, _std, size=self.imt.shape) ** 2
)
np.testing.assert_allclose(expected, noised, atol=1e-5)


class TestRandRicianNoiseTorch(TorchImageTestCase2D):
@parameterized.expand([("test_zero_mean", 0, 0.1), ("test_non_zero_mean", 1, 0.5)])
def test_correct_results(self, _, mean, std):
seed = 0
rician_fn = RandRicianNoise(prob=1.0, mean=mean, std=std)
rician_fn.set_random_state(seed)
noised = rician_fn(self.imt)
np.random.seed(seed)
np.random.random()
_std = np.random.uniform(0, std)
expected = np.sqrt(
(self.imt + np.random.normal(mean, _std, size=self.imt.shape)) ** 2
+ np.random.normal(mean, _std, size=self.imt.shape) ** 2
)
np.testing.assert_allclose(expected, noised, atol=1e-5)


if __name__ == "__main__":
unittest.main()
60 changes: 60 additions & 0 deletions tests/test_rand_rician_noised.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import RandRicianNoised
from tests.utils import NumpyImageTestCase2D, TorchImageTestCase2D

TEST_CASE_0 = ["test_zero_mean", ["img1", "img2"], 0, 0.1]
TEST_CASE_1 = ["test_non_zero_mean", ["img1", "img2"], 1, 0.5]
TEST_CASES = [TEST_CASE_0, TEST_CASE_1]

seed = 0


def test_numpy_or_torch(keys, mean, std, imt):
rician_fn = RandRicianNoised(keys=keys, global_prob=1.0, prob=1.0, mean=mean, std=std)
rician_fn.set_random_state(seed)
rician_fn.rand_rician_noise.set_random_state(seed)
noised = rician_fn({k: imt for k in keys})
np.random.seed(seed)
np.random.random()
np.random.seed(seed)
for k in keys:
np.random.random()
_std = np.random.uniform(0, std)
expected = np.sqrt(
(imt + np.random.normal(mean, _std, size=imt.shape)) ** 2
+ np.random.normal(mean, _std, size=imt.shape) ** 2
)
np.testing.assert_allclose(expected, noised[k], atol=1e-5, rtol=1e-5)


# Test with numpy
class TestRandRicianNoisedNumpy(NumpyImageTestCase2D):
@parameterized.expand(TEST_CASES)
def test_correct_results(self, _, keys, mean, std):
test_numpy_or_torch(keys, mean, std, self.imt)


# Test with torch
class TestRandRicianNoisedTorch(TorchImageTestCase2D):
@parameterized.expand(TEST_CASES)
def test_correct_results(self, _, keys, mean, std):
test_numpy_or_torch(keys, mean, std, self.imt)


if __name__ == "__main__":
unittest.main()