Skip to content

Commit

Permalink
Add Log1pFilter, ScaleFilter, and ZScoreFilter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589919755
  • Loading branch information
jan-matthis authored and copybara-github committed Dec 11, 2023
1 parent 2e4f359 commit de25382
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 10 deletions.
50 changes: 50 additions & 0 deletions connectomics/volume/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import jax
import numpy as np
import scipy.ndimage
import scipy.stats
import skimage.feature
import tensorstore as ts
from typing_extensions import Protocol
Expand Down Expand Up @@ -515,6 +516,21 @@ def __init__(self,
**filter_args)


@gin.register
class Log1pFilter(Filter):
"""Applies log(1+x)."""

def __init__(self,
min_chunksize: Optional[Sequence[int]] = None,
context_spec: Optional[MutableJsonSpec] = None,
**filter_args):
super().__init__(
filter_fun=np.log1p,
context_spec=context_spec,
min_chunksize=min_chunksize,
**filter_args)


@gin.register
class MedianFilter(Filter):
"""Runs median filter over image."""
Expand Down Expand Up @@ -631,6 +647,25 @@ def __init__(self,
**filter_args)


def _scale(data: np.ndarray, factor: float) -> np.ndarray:
return data * factor


@gin.register
class ScaleFilter(Filter):
"""Scales data."""

def __init__(self,
min_chunksize: Optional[Sequence[int]] = None,
context_spec: Optional[MutableJsonSpec] = None,
**filter_args):
super().__init__(
filter_fun=_scale,
context_spec=context_spec,
min_chunksize=min_chunksize,
**filter_args)


def _threshold(
data: np.ndarray, threshold: Union[int, float], indices: bool = False
) -> np.ndarray:
Expand Down Expand Up @@ -696,6 +731,21 @@ def __init__(self,
**filter_args)


@gin.register
class ZScoreFilter(Filter):
"""Applies z-scoring based on calculated mean and standard deviation."""

def __init__(self,
min_chunksize: Optional[Sequence[int]] = None,
context_spec: Optional[MutableJsonSpec] = None,
**filter_args):
super().__init__(
filter_fun=scipy.stats.zscore,
context_spec=context_spec,
min_chunksize=min_chunksize,
**filter_args)


@gin.register
class Interpolation(Decorator):
"""Interpolates input TensorStore."""
Expand Down
44 changes: 34 additions & 10 deletions connectomics/volume/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ def test_clahe_filter_with_overlap(self):
np.array(data), mode='equalize_adapthist', cast_float64=True,
**filter_args))

def test_clip_filter(self):
filter_args = {'a_min': 0.5, 'a_max': None}
dec = decorators.ClipFilter(
min_chunksize=self._data.shape, **filter_args)
vc = dec.decorate(self._data)
res = vc[...].read().result()
np.testing.assert_equal(
res, np.clip(np.array(self._data), **filter_args))
self.assertTrue(np.any(np.not_equal(res, self._data)))

def test_gaussian_filter(self):
filter_args = {'sigma': [1.] * self._data.ndim}
dec = decorators.GaussianFilter(
Expand All @@ -157,6 +167,12 @@ def test_label_filter(self):
vc[...].read().result(),
decorators._label_filter(np.array(self._data)))

def test_log1p_filter(self):
dec = decorators.Log1pFilter(
min_chunksize=self._data.shape)
vc = dec.decorate(self._data)
np.testing.assert_equal(vc[...].read().result(), np.log1p(self._data[...]))

def test_median_filter(self):
filter_args = {'size': [3] * self._data.ndim}
dec = decorators.MedianFilter(min_chunksize=self._data.shape, **filter_args)
Expand Down Expand Up @@ -234,24 +250,23 @@ def test_percentile_filter(self):
vc[...].read().result(),
scipy.ndimage.percentile_filter(self._data[...], **filter_args))

def test_threshold_filter(self):
filter_args = {'threshold': 0.25}
dec = decorators.ThresholdFilter(
def test_scale_filter(self):
filter_args = {'factor': 0.5}
dec = decorators.ScaleFilter(
min_chunksize=self._data.shape, **filter_args)
vc = dec.decorate(self._data)
res = vc[...].read().result()
np.testing.assert_equal(
res, decorators._threshold(np.array(self._data), **filter_args))
vc[...].read().result(),
filter_args['factor'] * self._data[...].read().result())

def test_clip_filter(self):
filter_args = {'a_min': 0.5, 'a_max': None}
dec = decorators.ClipFilter(
def test_threshold_filter(self):
filter_args = {'threshold': 0.25}
dec = decorators.ThresholdFilter(
min_chunksize=self._data.shape, **filter_args)
vc = dec.decorate(self._data)
res = vc[...].read().result()
np.testing.assert_equal(
res, np.clip(np.array(self._data), **filter_args))
self.assertTrue(np.any(np.not_equal(res, self._data)))
res, decorators._threshold(np.array(self._data), **filter_args))

def test_standardize_filter(self):
filter_args = {'mean': 5, 'std': 3}
Expand All @@ -262,6 +277,15 @@ def test_standardize_filter(self):
res_true = (np.array(self._data) - 5) / 3
np.testing.assert_equal(res_true, res)

def test_zscore_filter(self):
filter_args = {'axis': None}
dec = decorators.ZScoreFilter(
min_chunksize=self._data.shape, **filter_args)
vc = dec.decorate(self._data)
np.testing.assert_equal(
vc[...].read().result(),
scipy.stats.zscore(self._data[...], **filter_args))

def test_max_projection(self):
for projection_dim in (0, 1):
dec = decorators.MaxProjection(projection_dim=projection_dim)
Expand Down

0 comments on commit de25382

Please sign in to comment.