Skip to content

Commit f31e223

Browse files
committed
enh: load ITK's .mat files with Affine's loaders
1 parent 8b31bf2 commit f31e223

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

nitransforms/io/itk.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,14 @@ def from_string(cls, string):
204204
parameters[:3, :3] = vals[:-3].reshape((3, 3))
205205
parameters[:3, 3] = vals[-3:]
206206
sa["parameters"] = parameters
207-
return tf
207+
208+
# Try to double-dip and see if there are more transforms
209+
try:
210+
cls.from_string("\n".join(lines[4:8]))
211+
except TransformFileError:
212+
return tf
213+
else:
214+
raise TransformFileError("More than one linear transform found.")
208215

209216

210217
class ITKLinearTransformArray(BaseLinearTransformList):

nitransforms/linear.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,18 +203,33 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
203203
"""Create an affine from a transform file."""
204204
fmtlist = [fmt] if fmt is not None else ("itk", "lta", "afni", "fsl")
205205

206+
is_array = cls != Affine
207+
208+
errors = []
206209
for potential_fmt in fmtlist:
210+
if (potential_fmt == "itk" and Path(filename).suffix == ".mat"):
211+
is_array = False
212+
cls = Affine
213+
207214
try:
208-
struct = get_linear_factory(potential_fmt).from_filename(filename)
209-
matrix = struct.to_ras(reference=reference, moving=moving)
210-
if cls == Affine:
211-
if np.shape(matrix)[0] != 1:
212-
raise TypeError("Cannot load transform array '%s'" % filename)
213-
matrix = matrix[0]
214-
return cls(matrix, reference=reference)
215-
except (TransformFileError, FileNotFoundError):
215+
struct = get_linear_factory(
216+
potential_fmt,
217+
is_array=is_array
218+
).from_filename(filename)
219+
except (TransformFileError, FileNotFoundError) as err:
220+
errors.append((potential_fmt, err))
216221
continue
217222

223+
matrix = struct.to_ras(reference=reference, moving=moving)
224+
225+
# Process matrix
226+
if not is_array and np.ndim(matrix) == 3:
227+
if np.shape(matrix)[0] != 1:
228+
raise TypeError("Cannot load transform array '%s'" % filename)
229+
matrix = matrix[0]
230+
231+
return cls(matrix, reference=reference)
232+
218233
raise TransformFileError(
219234
f"Could not open <{filename}> (formats tried: {', '.join(fmtlist)})."
220235
)
@@ -499,6 +514,8 @@ def load(filename, fmt=None, reference=None, moving=None):
499514
xfm = LinearTransformsMapping.from_filename(
500515
filename, fmt=fmt, reference=reference, moving=moving
501516
)
502-
if len(xfm) == 1:
503-
return xfm[0]
517+
518+
if isinstance(xfm, LinearTransformsMapping) and len(xfm) == 1:
519+
xfm = xfm[0]
520+
504521
return xfm

nitransforms/tests/test_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_linear_typeerrors1(matrix):
4444

4545
def test_linear_typeerrors2(data_path):
4646
"""Exercise errors in Affine creation."""
47-
with pytest.raises(TypeError):
47+
with pytest.raises(io.TransformFileError):
4848
nitl.Affine.from_filename(data_path / "itktflist.tfm", fmt="itk")
4949

5050

0 commit comments

Comments
 (0)