@@ -1073,10 +1073,10 @@ def get_fitting_dig(info, dig_kinds="auto", exclude_frontal=True, verbose=None):
10731073
10741074
10751075@verbose
1076- def _fit_sphere_to_headshape (info , dig_kinds , verbose = None ):
1076+ def _fit_sphere_to_headshape (info , dig_kinds , * , verbose = None ):
10771077 """Fit a sphere to the given head shape."""
10781078 hsp = get_fitting_dig (info , dig_kinds )
1079- radius , origin_head = _fit_sphere (np .array (hsp ), disp = False )
1079+ radius , origin_head = _fit_sphere (np .array (hsp ))
10801080 # compute origin in device coordinates
10811081 dev_head_t = info ["dev_head_t" ]
10821082 if dev_head_t is None :
@@ -1105,36 +1105,16 @@ def _fit_sphere_to_headshape(info, dig_kinds, verbose=None):
11051105 return radius , origin_head , origin_device
11061106
11071107
1108- def _fit_sphere (points , disp = "auto" ):
1108+ def _fit_sphere (points ):
11091109 """Fit a sphere to an arbitrary set of points."""
1110- if isinstance (disp , str ) and disp == "auto" :
1111- disp = True if logger .level <= 20 else False
1112- # initial guess for center and radius
1113- radii = (np .max (points , axis = 1 ) - np .min (points , axis = 1 )) / 2.0
1114- radius_init = radii .mean ()
1115- center_init = np .median (points , axis = 0 )
1116-
1117- # optimization
1118- x0 = np .concatenate ([center_init , [radius_init ]])
1119-
1120- def cost_fun (center_rad ):
1121- d = np .linalg .norm (points - center_rad [:3 ], axis = 1 ) - center_rad [3 ]
1122- d *= d
1123- return d .sum ()
1124-
1125- def constraint (center_rad ):
1126- return center_rad [3 ] # radius must be >= 0
1127-
1128- x_opt = fmin_cobyla (
1129- cost_fun ,
1130- x0 ,
1131- constraint ,
1132- rhobeg = radius_init ,
1133- rhoend = radius_init * 1e-6 ,
1134- disp = disp ,
1135- )
1136-
1137- origin , radius = x_opt [:3 ], x_opt [3 ]
1110+ # linear least-squares sphere fit, see for example
1111+ # https://stackoverflow.com/a/78909044
1112+ # TODO: At some point we should maybe reject outliers first...
1113+ A = np .c_ [2 * points , np .ones ((len (points ), 1 ))]
1114+ b = (points ** 2 ).sum (axis = 1 )
1115+ x , _ , _ , _ = np .linalg .lstsq (A , b , rcond = 1e-6 )
1116+ origin = x [:3 ]
1117+ radius = np .sqrt (x [0 ] ** 2 + x [1 ] ** 2 + x [2 ] ** 2 + x [3 ])
11381118 return radius , origin
11391119
11401120
0 commit comments