Skip to content

Commit f8650e4

Browse files
committed
enh: propose a new API for dense fields
1 parent 5da403b commit f8650e4

File tree

4 files changed

+69
-67
lines changed

4 files changed

+69
-67
lines changed

nitransforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919
from . import linear, manip, nonlinear
2020
from .linear import Affine, LinearTransformsMapping
21-
from .nonlinear import DisplacementsFieldTransform
21+
from .nonlinear import DenseFieldTransform
2222
from .manip import TransformChain
2323

2424
try:
@@ -42,7 +42,7 @@
4242
"nonlinear",
4343
"Affine",
4444
"LinearTransformsMapping",
45-
"DisplacementsFieldTransform",
45+
"DenseFieldTransform",
4646
"TransformChain",
4747
"__copyright__",
4848
"__packagename__",

nitransforms/manip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
TransformError,
1616
)
1717
from .linear import Affine
18-
from .nonlinear import DisplacementsFieldTransform
18+
from .nonlinear import DenseFieldTransform
1919

2020

2121
class TransformChain(TransformBase):
@@ -197,7 +197,7 @@ def from_filename(cls, filename, fmt="X5", reference=None, moving=None):
197197
if isinstance(xfmobj, itk.ITKLinearTransform):
198198
retval.insert(0, Affine(xfmobj.to_ras(), reference=reference))
199199
else:
200-
retval.insert(0, DisplacementsFieldTransform(xfmobj))
200+
retval.insert(0, DenseFieldTransform(xfmobj))
201201

202202
return TransformChain(retval)
203203

nitransforms/nonlinear.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,31 +20,43 @@
2020
ImageGrid,
2121
SpatialReference,
2222
_as_homogeneous,
23-
EQUALITY_TOL,
2423
)
2524

2625

27-
class DeformationFieldTransform(TransformBase):
28-
"""Represents a dense field of deformed locations (corresponding to each voxel)."""
26+
class DenseFieldTransform(TransformBase):
27+
"""Represents dense field (voxel-wise) transforms."""
2928

30-
__slots__ = ["_field"]
29+
__slots__ = ("_field", "_displacements")
3130

32-
def __init__(self, field, reference=None):
31+
def __init__(self, field=None, displacements=True, reference=None):
3332
"""
34-
Create a dense deformation field transform.
33+
Create a dense field transform.
34+
35+
Converting to a field of deformations is straightforward by just adding the corresponding
36+
displacement to the :math:`(x, y, z)` coordinates of each voxel.
37+
Numerically, deformation fields are less susceptible to rounding errors
38+
than displacements fields.
39+
SPM generally prefers deformations for that reason.
3540
3641
Example
3742
-------
38-
>>> DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
39-
<DeformationFieldTransform[3D] (57, 67, 56)>
43+
>>> DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
44+
<DenseFieldTransform[3D] (57, 67, 56)>
4045
4146
"""
47+
if field is None and reference is None:
48+
raise TransformError("DenseFieldTransforms require a spatial reference")
49+
4250
super().__init__()
4351

44-
field = _ensure_image(field)
45-
self._field = np.squeeze(
46-
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
47-
)
52+
if field is not None:
53+
field = _ensure_image(field)
54+
self._field = np.squeeze(
55+
np.asanyarray(field.dataobj) if hasattr(field, "dataobj") else field
56+
)
57+
else:
58+
self._field = np.zeros((*reference.shape, reference.ndim), dtype="float32")
59+
displacements = True
4860

4961
try:
5062
self.reference = ImageGrid(
@@ -64,6 +76,12 @@ def __init__(self, field, reference=None):
6476
"the number of dimensions (%d)" % (self._field.shape[-1], ndim)
6577
)
6678

79+
if displacements:
80+
self._displacements = self._field
81+
# Convert from displacements to deformations fields
82+
# (just add the origin to the displacements vector)
83+
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
84+
6785
def __repr__(self):
6886
"""Beautify the python representation."""
6987
return f"<{self.__class__.__name__}[{self._field.shape[-1]}D] {self._field.shape[:3]}>"
@@ -93,13 +111,23 @@ def map(self, x, inverse=False):
93111
94112
Examples
95113
--------
96-
>>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
114+
>>> xfm = DenseFieldTransform(
115+
... test_dir / "someones_displacement_field.nii.gz",
116+
... displacements=False,
117+
... )
97118
>>> xfm.map([-6.5, -36., -19.5]).tolist()
98119
[[0.0, -0.47516798973083496, 0.0]]
99120
100121
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
101122
[[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]
102123
124+
>>> xfm = DenseFieldTransform(
125+
... test_dir / "someones_displacement_field.nii.gz",
126+
... displacements=True,
127+
... )
128+
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
129+
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
130+
103131
"""
104132

105133
if inverse is True:
@@ -117,25 +145,33 @@ def __matmul__(self, b):
117145
118146
Examples
119147
--------
120-
>>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
121-
>>> xfm2 = xfm @ TransformBase()
122-
>>> xfm == xfm2
148+
>>> deff = DenseFieldTransform(
149+
... test_dir / "someones_displacement_field.nii.gz",
150+
... displacements=False,
151+
... )
152+
>>> deff2 = deff @ TransformBase()
153+
>>> deff == deff2
154+
True
155+
156+
>>> disp = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
157+
>>> disp2 = disp @ TransformBase()
158+
>>> disp == disp2
123159
True
124160
125161
"""
126162
retval = b.map(
127163
self._field.reshape((-1, self._field.shape[-1]))
128164
).reshape(self._field.shape)
129-
return DeformationFieldTransform(retval, reference=self.reference)
165+
return DenseFieldTransform(retval, displacements=False, reference=self.reference)
130166

131167
def __eq__(self, other):
132168
"""
133169
Overload equals operator.
134170
135171
Examples
136172
--------
137-
>>> xfm1 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
138-
>>> xfm2 = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
173+
>>> xfm1 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
174+
>>> xfm2 = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
139175
>>> xfm1 == xfm2
140176
True
141177
@@ -145,41 +181,6 @@ def __eq__(self, other):
145181
warnings.warn("Fields are equal, but references do not match.")
146182
return _eq
147183

148-
149-
class DisplacementsFieldTransform(DeformationFieldTransform):
150-
"""
151-
Represents a dense field of displacements (one vector per voxel).
152-
153-
Converting to a field of deformations is straightforward by just adding the corresponding
154-
displacement to the :math:`(x, y, z)` coordinates of each voxel.
155-
Numerically, deformation fields are less susceptible to rounding errors
156-
than displacements fields.
157-
SPM generally prefers deformations for that reason.
158-
159-
"""
160-
161-
__slots__ = ["_displacements"]
162-
163-
def __init__(self, field, reference=None):
164-
"""
165-
Create a transform supported by a field of voxel-wise displacements.
166-
167-
Example
168-
-------
169-
>>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
170-
>>> xfm
171-
<DisplacementsFieldTransform[3D] (57, 67, 56)>
172-
173-
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174-
[[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
175-
176-
"""
177-
super().__init__(field, reference=reference)
178-
self._displacements = self._field
179-
# Convert from displacements to deformations fields
180-
# (just add the origin to the displacements vector)
181-
self._field += self.reference.ndcoords.T.reshape(self._field.shape)
182-
183184
@classmethod
184185
def from_filename(cls, filename, fmt="X5"):
185186
_factory = {
@@ -193,7 +194,7 @@ def from_filename(cls, filename, fmt="X5"):
193194
return cls(_factory[fmt].from_filename(filename))
194195

195196

196-
load = DisplacementsFieldTransform.from_filename
197+
load = DenseFieldTransform.from_filename
197198

198199

199200
class BSplineFieldTransform(TransformBase):
@@ -239,8 +240,9 @@ def to_field(self, reference=None, dtype="float32"):
239240
# 1 x Nvox : (1 x K) @ (K x Nvox)
240241
field[:, d] = self._coeffs[..., d].reshape(-1) @ self._weights
241242

242-
return DisplacementsFieldTransform(
243-
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref)
243+
return DenseFieldTransform(
244+
field.astype(dtype).reshape(*_ref.shape, -1), reference=_ref
245+
)
244246

245247
def apply(
246248
self,

nitransforms/tests/test_nonlinear.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from nitransforms.io.base import TransformFileError
1313
from nitransforms.nonlinear import (
1414
BSplineFieldTransform,
15-
DisplacementsFieldTransform,
15+
DenseFieldTransform,
1616
load as nlload,
1717
)
1818
from ..io.itk import ITKDisplacementsField
@@ -45,7 +45,7 @@ def test_itk_disp_load(size):
4545
def test_displacements_bad_sizes(size):
4646
"""Checks field sizes."""
4747
with pytest.raises(TransformError):
48-
DisplacementsFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
48+
DenseFieldTransform(nb.Nifti1Image(np.zeros(size), np.eye(4), None))
4949

5050

5151
def test_itk_disp_load_intent():
@@ -59,15 +59,15 @@ def test_itk_disp_load_intent():
5959

6060

6161
def test_displacements_init():
62-
DisplacementsFieldTransform(
62+
DenseFieldTransform(
6363
np.zeros((10, 10, 10, 3)),
6464
reference=nb.Nifti1Image(np.zeros((10, 10, 10, 3)), np.eye(4), None),
6565
)
6666

6767
with pytest.raises(TransformError):
68-
DisplacementsFieldTransform(np.zeros((10, 10, 10, 3)))
68+
DenseFieldTransform(np.zeros((10, 10, 10, 3)))
6969
with pytest.raises(TransformError):
70-
DisplacementsFieldTransform(
70+
DenseFieldTransform(
7171
np.zeros((10, 10, 10, 3)),
7272
reference=np.zeros((10, 10, 10, 3)),
7373
)
@@ -237,7 +237,7 @@ def test_bspline(tmp_path, testdata_path):
237237
bs_name = testdata_path / "someones_bspline_coefficients.nii.gz"
238238

239239
bsplxfm = BSplineFieldTransform(bs_name, reference=img_name)
240-
dispxfm = DisplacementsFieldTransform(disp_name)
240+
dispxfm = DenseFieldTransform(disp_name)
241241

242242
out_disp = dispxfm.apply(img_name)
243243
out_bspl = bsplxfm.apply(img_name)

0 commit comments

Comments
 (0)