@@ -203,18 +203,33 @@ def from_filename(cls, filename, fmt=None, reference=None, moving=None):
203
203
"""Create an affine from a transform file."""
204
204
fmtlist = [fmt ] if fmt is not None else ("itk" , "lta" , "afni" , "fsl" )
205
205
206
+ is_array = cls != Affine
207
+
208
+ errors = []
206
209
for potential_fmt in fmtlist :
210
+ if (potential_fmt == "itk" and Path (filename ).suffix == ".mat" ):
211
+ is_array = False
212
+ cls = Affine
213
+
207
214
try :
208
- struct = get_linear_factory (potential_fmt ).from_filename (filename )
209
- matrix = struct .to_ras (reference = reference , moving = moving )
210
- if cls == Affine :
211
- if np .shape (matrix )[0 ] != 1 :
212
- raise TypeError ("Cannot load transform array '%s'" % filename )
213
- matrix = matrix [0 ]
214
- return cls (matrix , reference = reference )
215
- except (TransformFileError , FileNotFoundError ):
215
+ struct = get_linear_factory (
216
+ potential_fmt ,
217
+ is_array = is_array
218
+ ).from_filename (filename )
219
+ except (TransformFileError , FileNotFoundError ) as err :
220
+ errors .append ((potential_fmt , err ))
216
221
continue
217
222
223
+ matrix = struct .to_ras (reference = reference , moving = moving )
224
+
225
+ # Process matrix
226
+ if not is_array and np .ndim (matrix ) == 3 :
227
+ if np .shape (matrix )[0 ] != 1 :
228
+ raise TypeError ("Cannot load transform array '%s'" % filename )
229
+ matrix = matrix [0 ]
230
+
231
+ return cls (matrix , reference = reference )
232
+
218
233
raise TransformFileError (
219
234
f"Could not open <{ filename } > (formats tried: { ', ' .join (fmtlist )} )."
220
235
)
@@ -499,6 +514,8 @@ def load(filename, fmt=None, reference=None, moving=None):
499
514
xfm = LinearTransformsMapping .from_filename (
500
515
filename , fmt = fmt , reference = reference , moving = moving
501
516
)
502
- if len (xfm ) == 1 :
503
- return xfm [0 ]
517
+
518
+ if isinstance (xfm , LinearTransformsMapping ) and len (xfm ) == 1 :
519
+ xfm = xfm [0 ]
520
+
504
521
return xfm
0 commit comments