20
20
ImageGrid ,
21
21
SpatialReference ,
22
22
_as_homogeneous ,
23
- EQUALITY_TOL ,
24
23
)
25
24
26
25
27
- class DeformationFieldTransform (TransformBase ):
28
- """Represents a dense field of deformed locations (corresponding to each voxel) ."""
26
+ class DenseFieldTransform (TransformBase ):
27
+ """Represents dense field ( voxel-wise) transforms ."""
29
28
30
- __slots__ = [ "_field" ]
29
+ __slots__ = ( "_field" , "_displacements" )
31
30
32
- def __init__ (self , field , reference = None ):
31
+ def __init__ (self , field = None , displacements = True , reference = None ):
33
32
"""
34
- Create a dense deformation field transform.
33
+ Create a dense field transform.
34
+
35
+ Converting to a field of deformations is straightforward by just adding the corresponding
36
+ displacement to the :math:`(x, y, z)` coordinates of each voxel.
37
+ Numerically, deformation fields are less susceptible to rounding errors
38
+ than displacements fields.
39
+ SPM generally prefers deformations for that reason.
35
40
36
41
Example
37
42
-------
38
- >>> DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
39
- <DeformationFieldTransform [3D] (57, 67, 56)>
43
+ >>> DenseFieldTransform (test_dir / "someones_displacement_field.nii.gz")
44
+ <DenseFieldTransform [3D] (57, 67, 56)>
40
45
41
46
"""
47
+ if field is None and reference is None :
48
+ raise TransformError ("DenseFieldTransforms require a spatial reference" )
49
+
42
50
super ().__init__ ()
43
51
44
- field = _ensure_image (field )
45
- self ._field = np .squeeze (
46
- np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
47
- )
52
+ if field is not None :
53
+ field = _ensure_image (field )
54
+ self ._field = np .squeeze (
55
+ np .asanyarray (field .dataobj ) if hasattr (field , "dataobj" ) else field
56
+ )
57
+ else :
58
+ self ._field = np .zeros ((* reference .shape , reference .ndim ), dtype = "float32" )
59
+ displacements = True
48
60
49
61
try :
50
62
self .reference = ImageGrid (
@@ -64,6 +76,12 @@ def __init__(self, field, reference=None):
64
76
"the number of dimensions (%d)" % (self ._field .shape [- 1 ], ndim )
65
77
)
66
78
79
+ if displacements :
80
+ self ._displacements = self ._field
81
+ # Convert from displacements to deformations fields
82
+ # (just add the origin to the displacements vector)
83
+ self ._field += self .reference .ndcoords .T .reshape (self ._field .shape )
84
+
67
85
def __repr__ (self ):
68
86
"""Beautify the python representation."""
69
87
return f"<{ self .__class__ .__name__ } [{ self ._field .shape [- 1 ]} D] { self ._field .shape [:3 ]} >"
@@ -93,13 +111,23 @@ def map(self, x, inverse=False):
93
111
94
112
Examples
95
113
--------
96
- >>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
114
+ >>> xfm = DenseFieldTransform(
115
+ ... test_dir / "someones_displacement_field.nii.gz",
116
+ ... displacements=False,
117
+ ... )
97
118
>>> xfm.map([-6.5, -36., -19.5]).tolist()
98
119
[[0.0, -0.47516798973083496, 0.0]]
99
120
100
121
>>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
101
122
[[0.0, -0.47516798973083496, 0.0], [0.0, -0.538356602191925, 0.0]]
102
123
124
+ >>> xfm = DenseFieldTransform(
125
+ ... test_dir / "someones_displacement_field.nii.gz",
126
+ ... displacements=True,
127
+ ... )
128
+ >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
129
+ [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
130
+
103
131
"""
104
132
105
133
if inverse is True :
@@ -117,25 +145,33 @@ def __matmul__(self, b):
117
145
118
146
Examples
119
147
--------
120
- >>> xfm = DeformationFieldTransform(test_dir / "someones_displacement_field.nii.gz")
121
- >>> xfm2 = xfm @ TransformBase()
122
- >>> xfm == xfm2
148
+ >>> deff = DenseFieldTransform(
149
+ ... test_dir / "someones_displacement_field.nii.gz",
150
+ ... displacements=False,
151
+ ... )
152
+ >>> deff2 = deff @ TransformBase()
153
+ >>> deff == deff2
154
+ True
155
+
156
+ >>> disp = DenseFieldTransform(test_dir / "someones_displacement_field.nii.gz")
157
+ >>> disp2 = disp @ TransformBase()
158
+ >>> disp == disp2
123
159
True
124
160
125
161
"""
126
162
retval = b .map (
127
163
self ._field .reshape ((- 1 , self ._field .shape [- 1 ]))
128
164
).reshape (self ._field .shape )
129
- return DeformationFieldTransform (retval , reference = self .reference )
165
+ return DenseFieldTransform (retval , displacements = False , reference = self .reference )
130
166
131
167
def __eq__ (self , other ):
132
168
"""
133
169
Overload equals operator.
134
170
135
171
Examples
136
172
--------
137
- >>> xfm1 = DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
138
- >>> xfm2 = DeformationFieldTransform (test_dir / "someones_displacement_field.nii.gz")
173
+ >>> xfm1 = DenseFieldTransform (test_dir / "someones_displacement_field.nii.gz")
174
+ >>> xfm2 = DenseFieldTransform (test_dir / "someones_displacement_field.nii.gz")
139
175
>>> xfm1 == xfm2
140
176
True
141
177
@@ -145,41 +181,6 @@ def __eq__(self, other):
145
181
warnings .warn ("Fields are equal, but references do not match." )
146
182
return _eq
147
183
148
-
149
- class DisplacementsFieldTransform (DeformationFieldTransform ):
150
- """
151
- Represents a dense field of displacements (one vector per voxel).
152
-
153
- Converting to a field of deformations is straightforward by just adding the corresponding
154
- displacement to the :math:`(x, y, z)` coordinates of each voxel.
155
- Numerically, deformation fields are less susceptible to rounding errors
156
- than displacements fields.
157
- SPM generally prefers deformations for that reason.
158
-
159
- """
160
-
161
- __slots__ = ["_displacements" ]
162
-
163
- def __init__ (self , field , reference = None ):
164
- """
165
- Create a transform supported by a field of voxel-wise displacements.
166
-
167
- Example
168
- -------
169
- >>> xfm = DisplacementsFieldTransform(test_dir / "someones_displacement_field.nii.gz")
170
- >>> xfm
171
- <DisplacementsFieldTransform[3D] (57, 67, 56)>
172
-
173
- >>> xfm.map([[-6.5, -36., -19.5], [-1., -41.5, -11.25]]).tolist()
174
- [[-6.5, -36.47516632080078, -19.5], [-1.0, -42.03835678100586, -11.25]]
175
-
176
- """
177
- super ().__init__ (field , reference = reference )
178
- self ._displacements = self ._field
179
- # Convert from displacements to deformations fields
180
- # (just add the origin to the displacements vector)
181
- self ._field += self .reference .ndcoords .T .reshape (self ._field .shape )
182
-
183
184
@classmethod
184
185
def from_filename (cls , filename , fmt = "X5" ):
185
186
_factory = {
@@ -193,7 +194,7 @@ def from_filename(cls, filename, fmt="X5"):
193
194
return cls (_factory [fmt ].from_filename (filename ))
194
195
195
196
196
- load = DisplacementsFieldTransform .from_filename
197
+ load = DenseFieldTransform .from_filename
197
198
198
199
199
200
class BSplineFieldTransform (TransformBase ):
@@ -239,8 +240,9 @@ def to_field(self, reference=None, dtype="float32"):
239
240
# 1 x Nvox : (1 x K) @ (K x Nvox)
240
241
field [:, d ] = self ._coeffs [..., d ].reshape (- 1 ) @ self ._weights
241
242
242
- return DisplacementsFieldTransform (
243
- field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref )
243
+ return DenseFieldTransform (
244
+ field .astype (dtype ).reshape (* _ref .shape , - 1 ), reference = _ref
245
+ )
244
246
245
247
def apply (
246
248
self ,
0 commit comments