20
20
from ..doa import Grid , GridSphere , cart2spher , fibonacci_spherical_sampling , spher2cart
21
21
from ..utilities import requires_matplotlib , resample
22
22
from .interp import spherical_interpolation
23
- from .sofa import get_sofa_db , open_sofa_file
23
+ from .sofa import open_sofa_file
24
24
25
25
26
26
class MeasuredDirectivity (Directivity ):
@@ -205,23 +205,44 @@ def __init__(
205
205
source_labels = None ,
206
206
):
207
207
self .path = Path (path )
208
- self .mic_labels , self .source_labels = self ._set_labels (
209
- self .path , mic_labels , source_labels
210
- )
211
208
212
209
if file_reader_callback is None :
210
+ # default reader is for SOFA files
213
211
file_reader_callback = open_sofa_file
214
212
215
213
(
216
214
self .impulse_responses , # (n_sources, n_mics, taps)
217
- self .sources_loc , # (3, n_sources), spherical coordinates
218
- self .mics_loc , # (3, n_mics), cartesian coordinates
219
215
self .fs ,
216
+ self .source_locs , # (3, n_sources), spherical coordinates
217
+ self .mic_locs , # (3, n_mics), cartesian coordinates
218
+ src_labels_file ,
219
+ mic_labels_file ,
220
220
) = file_reader_callback (
221
221
path = self .path ,
222
222
fs = fs ,
223
223
)
224
224
225
+ if mic_labels is None :
226
+ self .mic_labels = mic_labels_file
227
+ else :
228
+ if len (mic_labels ) != self .mic_locs .shape [1 ]:
229
+ breakpoint ()
230
+ raise ValueError (
231
+ f"Number of labels provided ({ len (mic_labels )} ) does not match the "
232
+ f"number of microphones ({ self .mic_locs .shape [1 ]} )"
233
+ )
234
+ self .mic_labels = mic_labels
235
+
236
+ if source_labels is None :
237
+ self .source_labels = src_labels_file
238
+ else :
239
+ if len (source_labels ) != self .source_locs .shape [1 ]:
240
+ raise ValueError (
241
+ f"Number of labels provided ({ len (source_labels )} ) does not match "
242
+ f"the number of sources ({ self .source_locs .shape [1 ]} )"
243
+ )
244
+ self .source_labels = source_labels
245
+
225
246
self .interp_order = interp_order
226
247
self .interp_n_points = interp_n_points
227
248
@@ -233,16 +254,6 @@ def __init__(
233
254
else :
234
255
self .interp_grid = None
235
256
236
- def _set_labels (self , path , mic_labels , src_labels ):
237
- sofa_db = get_sofa_db ()
238
- if path .stem in sofa_db :
239
- info = sofa_db [path .stem ]
240
- if info .type == "microphones" and mic_labels is None :
241
- mic_labels = info .contains
242
- elif info .type == "sources" and src_labels is None :
243
- src_labels = info .contains
244
- return mic_labels , src_labels
245
-
246
257
def _interpolate (self , type , mid , grid , impulse_responses ):
247
258
if self .interp_order is None :
248
259
return grid , impulse_responses
@@ -272,40 +283,57 @@ def _get_measurement_index(self, meas_id, labels):
272
283
273
284
raise ValueError (f"Measurement id { meas_id } not found" )
274
285
275
- def get_microphone (self , measurement_id , orientation , offset = None ):
286
+ def get_mic_position (self , measurement_id ):
287
+ mid = self ._get_measurement_index (measurement_id , self .mic_labels )
288
+
289
+ if not (0 <= mid < self .mic_locs .shape [1 ]):
290
+ raise ValueError (f"Microphone id { mid } not found" )
291
+
292
+ return self .mic_locs [:, mid ]
293
+
294
+ def get_source_position (self , measurement_id ):
295
+ mid = self ._get_measurement_index (measurement_id , self .source_labels )
296
+
297
+ if not (0 <= mid < self .source_locs .shape [1 ]):
298
+ raise ValueError (f"Source id { mid } not found" )
299
+
300
+ # convert to cartesian since the sources are stored by
301
+ # default in spherical coordinates
302
+ pos = spher2cart (* self .source_locs [:, mid ])
303
+
304
+ return pos
305
+
306
+ def get_mic_directivity (self , measurement_id , orientation ):
276
307
mid = self ._get_measurement_index (measurement_id , self .mic_labels )
277
308
309
+ if not (0 <= mid < self .mic_locs .shape [1 ]):
310
+ raise ValueError (f"Microphone id { mid } not found" )
311
+
278
312
# select the measurements corresponding to the mic id
279
313
ir = self .impulse_responses [:, mid , :]
280
- src_grid = GridSphere (spherical_points = self .sources_loc [:2 ])
281
-
282
- mic_loc = self .mics_loc [:, mid ]
283
- if offset is not None :
284
- mic_loc += offset
314
+ src_grid = GridSphere (spherical_points = self .source_locs [:2 ])
285
315
286
316
# interpolate the IR
287
317
grid , ir = self ._interpolate ("mic" , mid , src_grid , ir )
288
318
289
319
dir_obj = MeasuredDirectivity (orientation , grid , ir , self .fs )
290
- return mic_loc , dir_obj
320
+ return dir_obj
291
321
292
- def get_source (self , measurement_id , orientation , offset = None ):
322
+ def get_source_directivity (self , measurement_id , orientation ):
293
323
mid = self ._get_measurement_index (measurement_id , self .source_labels )
294
324
325
+ if not (0 <= mid < self .source_locs .shape [1 ]):
326
+ raise ValueError (f"Source id { mid } not found" )
327
+
295
328
# select the measurements corresponding to the mic id
296
329
ir = self .impulse_responses [mid , :, :]
297
330
298
331
# here we need to swap the coordinate types
299
- mic_pos = np .array (cart2spher (self .mics_loc ))
332
+ mic_pos = np .array (cart2spher (self .mic_locs ))
300
333
mic_grid = GridSphere (spherical_points = mic_pos [:2 ])
301
334
302
- # source location
303
- src_loc = spher2cart (* self .sources_loc [:, mid ])
304
- if offset is not None :
305
- src_loc += offset
306
-
307
335
# interpolate the IR
308
336
grid , ir = self ._interpolate ("source" , mid , mic_grid , ir )
309
337
310
338
dir_obj = MeasuredDirectivity (orientation , grid , ir , self .fs )
311
- return src_loc , dir_obj
339
+ return dir_obj
0 commit comments