Skip to content

Commit 65107c4

Browse files
committed
SOFA measurement labels are read from the database when available. Makes API more consistent.
1 parent ccacfed commit 65107c4

File tree

7 files changed

+261
-110
lines changed

7 files changed

+261
-110
lines changed

examples/simulate_binaural_recording.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -76,16 +76,15 @@
7676
fs=fs,
7777
interp_order=args.interp_order,
7878
interp_n_points=args.interp_n_points,
79-
mic_labels=["left", "right"],
8079
)
8180

8281
orientation = DirectionVector(
8382
azimuth=azimuth_deg, colatitude=colatitude_deg, degrees=True
8483
)
8584

86-
_, dir_left = hrtf.get_microphone("left", orientation=orientation)
85+
dir_left = hrtf.get_mic_directivity("left", orientation=orientation)
8786

88-
_, dir_right = hrtf.get_microphone("right", orientation=orientation)
87+
dir_right = hrtf.get_mic_directivity("right", orientation=orientation)
8988

9089
room_dim = [6, 6, 2.4]
9190

@@ -149,8 +148,8 @@
149148

150149
room.add_source([1.5, 3.01, 1.044], signal=speech)
151150

152-
room.add_microphone([1.1, 3.01, 1.8], directivity=dir_left)
153-
room.add_microphone([1.1, 3.01, 1.8], directivity=dir_right)
151+
room.add_microphone([1.1, 3.01, 2.2], directivity=dir_left)
152+
room.add_microphone([1.1, 3.01, 2.2], directivity=dir_right)
154153

155154
room.simulate()
156155

examples/simulation_with_measured_directivity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ class DirectionVector
8282

8383
vec_54_73 = DirectionVector(azimuth=54, colatitude=73, degrees=True)
8484

85-
_, dir_obj_Dmic = akg.get_microphone("AKG_c414K", orientation=vec_54_73)
86-
_, dir_obj_Emic = eigenmike.get_microphone("EM_32_9", orientation=vec_54_73)
85+
dir_obj_Dmic = akg.get_mic_directivity("AKG_c414K", orientation=vec_54_73)
86+
dir_obj_Emic = eigenmike.get_mic_directivity("EM_32_9", orientation=vec_54_73)
8787

8888
dir_obj_Cmic = CardioidFamily(
8989
orientation=DirectionVector(azimuth=90, colatitude=123, degrees=True),

pyroomacoustics/data/sofa_files.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100
"type": "microphones",
101101
"url": "http://sofacoustics.org/data/database/mit/mit_kemar_large_pinna.sofa",
102102
"homepage": "https://sound.media.mit.edu/resources/KEMAR/",
103-
"license": "custom-attribution",
103+
"license": "This data is Copyright 1994 by the MIT Media Laboratory. It is provided free with no restrictions on use, provided the authors are cited when the data is used in any research or commercial application.",
104104
"contains": [
105105
"left",
106106
"right"
@@ -111,7 +111,7 @@
111111
"type": "microphones",
112112
"url": "http://sofacoustics.org/data/database/mit/mit_kemar_normal_pinna.sofa",
113113
"homepage": "https://sound.media.mit.edu/resources/KEMAR/",
114-
"license": "custom-attribution",
114+
"license": "This data is Copyright 1994 by the MIT Media Laboratory. It is provided free with no restrictions on use, provided the authors are cited when the data is used in any research or commercial application.",
115115
"contains": [
116116
"left",
117117
"right"

pyroomacoustics/datasets/sofa.py

+105-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import json
2+
import typing as tp
3+
from dataclasses import dataclass
24
from pathlib import Path
35

46
from .utils import AttrDict, download_multiple
@@ -7,6 +9,42 @@
79
DEFAULT_SOFA_PATH = _pra_data_folder / "sofa"
810
SOFA_INFO = _pra_data_folder / "sofa_files.json"
911

12+
_DIRPAT_FILES = [
13+
"Soundfield_ST450_CUBE",
14+
"AKG_c480_c414_CUBE",
15+
"Oktava_MK4012_CUBE",
16+
"LSPs_HATS_GuitarCabinets_Akustikmessplatz",
17+
]
18+
19+
20+
def is_dirpat(name):
21+
if isinstance(name, Path):
22+
name = name.stem
23+
return name in _DIRPAT_FILES
24+
25+
26+
def get_sofa_db():
27+
# we want to avoid loading the database multiple times
28+
global sofa_db
29+
try:
30+
return sofa_db
31+
except NameError:
32+
sofa_db = SOFADatabase()
33+
return sofa_db
34+
35+
36+
def resolve_sofa_path(path):
37+
path = Path(path)
38+
39+
if path.exists():
40+
return path
41+
42+
sofa_db = get_sofa_db()
43+
if path.stem in sofa_db:
44+
return Path(sofa_db[path.stem].path)
45+
46+
raise ValueError(f"SOFA file {path} could not be found")
47+
1048

1149
def get_sofa_db_info():
1250
with open(SOFA_INFO, "r") as f:
@@ -51,7 +89,66 @@ def download_sofa_files(path=None, overwrite=False, verbose=False, no_fail=False
5189
return list(files.keys())
5290

5391

92+
@dataclass
93+
class SOFAFileInfo:
94+
"""
95+
A class to store information about a SOFA file
96+
97+
Parameters
98+
----------
99+
path: Path
100+
The path to the SOFA file
101+
supported: bool
102+
Whether the SOFA file is supported by Pyroom Acoustics
103+
type: str
104+
The type of device (e.g., 'sources' or 'microphones')
105+
url: str
106+
The URL where the SOFA file can be downloaded
107+
homepage: str
108+
The URL of the SOFA file homepage
109+
license: str
110+
The license of the SOFA file
111+
contains: List[str]
112+
The labels of the sources/microphones contained in the SOFA file,
113+
or``None`` if the information is not available
114+
"""
115+
116+
path: Path
117+
supported: bool = True
118+
type: str = "unknown"
119+
url: str = "unknown"
120+
homepage: str = "unknown"
121+
license: str = "unknown"
122+
contains: tp.List[str] = None
123+
124+
54125
class SOFADatabase(dict):
126+
"""
127+
A small database of SOFA files containing source/microphone directional
128+
impulse responses
129+
130+
The database object is a dictionary-like object where the keys are the
131+
names of the SOFA files and the values are objects with the following
132+
attributes:
133+
134+
.. code-block:: python
135+
136+
db = SOFADatabase()
137+
138+
# type of device: 'sources' or 'microphones'
139+
db["Soundfield_ST450_CUBE"].type
140+
141+
# list of the labels of the sources/microphones
142+
db["Soundfield_ST450_CUBE"].contains
143+
144+
145+
Parameters
146+
----------
147+
download: bool, optional
148+
If set to `True`, the SOFA files are downloaded if they are not already
149+
present in the default folder
150+
"""
151+
55152
def __init__(self, download=True):
56153
super().__init__()
57154

@@ -63,29 +160,22 @@ def __init__(self, download=True):
63160
for name, info in get_sofa_db_info().items():
64161
path = self.root / f"{name}.sofa"
65162
if path.exists():
66-
dict.__setitem__(self, name, AttrDict(info))
67-
self[name]["path"] = path
163+
dict.__setitem__(self, name, SOFAFileInfo(path=path, **info))
68164

69165
for path in DEFAULT_SOFA_PATH.glob("*.sofa"):
70166
name = path.stem
71167
if name not in self:
72168
dict.__setitem__(
73169
self,
74170
name,
75-
AttrDict(
76-
{
77-
"path": path,
78-
"supported": "???",
79-
"type": "unknown",
80-
"url": "???",
81-
"homepage": "???",
82-
"license": "???",
83-
"contains": None,
84-
}
85-
),
171+
SOFAFileInfo(path=path),
86172
)
87173

88174
def list(self):
175+
"""
176+
Print a list of the available SOFA files and the labels of the
177+
different devices they contain
178+
"""
89179
for name, info in self.items():
90180
print(f"- {name} ({info.type})")
91181
if info.contains is not None:
@@ -94,10 +184,12 @@ def list(self):
94184

95185
@property
96186
def root(self):
187+
"""The path to the folder containing the SOFA files"""
97188
return DEFAULT_SOFA_PATH
98189

99190
@property
100191
def db_info_path(self):
192+
"""The path to the JSON file containing the SOFA files information"""
101193
return SOFA_INFO
102194

103195
def __setitem__(self, key, val):

pyroomacoustics/directivities/measured.py

+59-31
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..doa import Grid, GridSphere, cart2spher, fibonacci_spherical_sampling, spher2cart
2121
from ..utilities import requires_matplotlib, resample
2222
from .interp import spherical_interpolation
23-
from .sofa import get_sofa_db, open_sofa_file
23+
from .sofa import open_sofa_file
2424

2525

2626
class MeasuredDirectivity(Directivity):
@@ -205,23 +205,44 @@ def __init__(
205205
source_labels=None,
206206
):
207207
self.path = Path(path)
208-
self.mic_labels, self.source_labels = self._set_labels(
209-
self.path, mic_labels, source_labels
210-
)
211208

212209
if file_reader_callback is None:
210+
# default reader is for SOFA files
213211
file_reader_callback = open_sofa_file
214212

215213
(
216214
self.impulse_responses, # (n_sources, n_mics, taps)
217-
self.sources_loc, # (3, n_sources), spherical coordinates
218-
self.mics_loc, # (3, n_mics), cartesian coordinates
219215
self.fs,
216+
self.source_locs, # (3, n_sources), spherical coordinates
217+
self.mic_locs, # (3, n_mics), cartesian coordinates
218+
src_labels_file,
219+
mic_labels_file,
220220
) = file_reader_callback(
221221
path=self.path,
222222
fs=fs,
223223
)
224224

225+
if mic_labels is None:
226+
self.mic_labels = mic_labels_file
227+
else:
228+
if len(mic_labels) != self.mic_locs.shape[1]:
229+
breakpoint()
230+
raise ValueError(
231+
f"Number of labels provided ({len(mic_labels)}) does not match the "
232+
f"number of microphones ({self.mic_locs.shape[1]})"
233+
)
234+
self.mic_labels = mic_labels
235+
236+
if source_labels is None:
237+
self.source_labels = src_labels_file
238+
else:
239+
if len(source_labels) != self.source_locs.shape[1]:
240+
raise ValueError(
241+
f"Number of labels provided ({len(source_labels)}) does not match "
242+
f"the number of sources ({self.source_locs.shape[1]})"
243+
)
244+
self.source_labels = source_labels
245+
225246
self.interp_order = interp_order
226247
self.interp_n_points = interp_n_points
227248

@@ -233,16 +254,6 @@ def __init__(
233254
else:
234255
self.interp_grid = None
235256

236-
def _set_labels(self, path, mic_labels, src_labels):
237-
sofa_db = get_sofa_db()
238-
if path.stem in sofa_db:
239-
info = sofa_db[path.stem]
240-
if info.type == "microphones" and mic_labels is None:
241-
mic_labels = info.contains
242-
elif info.type == "sources" and src_labels is None:
243-
src_labels = info.contains
244-
return mic_labels, src_labels
245-
246257
def _interpolate(self, type, mid, grid, impulse_responses):
247258
if self.interp_order is None:
248259
return grid, impulse_responses
@@ -272,40 +283,57 @@ def _get_measurement_index(self, meas_id, labels):
272283

273284
raise ValueError(f"Measurement id {meas_id} not found")
274285

275-
def get_microphone(self, measurement_id, orientation, offset=None):
286+
def get_mic_position(self, measurement_id):
287+
mid = self._get_measurement_index(measurement_id, self.mic_labels)
288+
289+
if not (0 <= mid < self.mic_locs.shape[1]):
290+
raise ValueError(f"Microphone id {mid} not found")
291+
292+
return self.mic_locs[:, mid]
293+
294+
def get_source_position(self, measurement_id):
295+
mid = self._get_measurement_index(measurement_id, self.source_labels)
296+
297+
if not (0 <= mid < self.source_locs.shape[1]):
298+
raise ValueError(f"Source id {mid} not found")
299+
300+
# convert to cartesian since the sources are stored by
301+
# default in spherical coordinates
302+
pos = spher2cart(*self.source_locs[:, mid])
303+
304+
return pos
305+
306+
def get_mic_directivity(self, measurement_id, orientation):
276307
mid = self._get_measurement_index(measurement_id, self.mic_labels)
277308

309+
if not (0 <= mid < self.mic_locs.shape[1]):
310+
raise ValueError(f"Microphone id {mid} not found")
311+
278312
# select the measurements corresponding to the mic id
279313
ir = self.impulse_responses[:, mid, :]
280-
src_grid = GridSphere(spherical_points=self.sources_loc[:2])
281-
282-
mic_loc = self.mics_loc[:, mid]
283-
if offset is not None:
284-
mic_loc += offset
314+
src_grid = GridSphere(spherical_points=self.source_locs[:2])
285315

286316
# interpolate the IR
287317
grid, ir = self._interpolate("mic", mid, src_grid, ir)
288318

289319
dir_obj = MeasuredDirectivity(orientation, grid, ir, self.fs)
290-
return mic_loc, dir_obj
320+
return dir_obj
291321

292-
def get_source(self, measurement_id, orientation, offset=None):
322+
def get_source_directivity(self, measurement_id, orientation):
293323
mid = self._get_measurement_index(measurement_id, self.source_labels)
294324

325+
if not (0 <= mid < self.source_locs.shape[1]):
326+
raise ValueError(f"Source id {mid} not found")
327+
295328
# select the measurements corresponding to the mic id
296329
ir = self.impulse_responses[mid, :, :]
297330

298331
# here we need to swap the coordinate types
299-
mic_pos = np.array(cart2spher(self.mics_loc))
332+
mic_pos = np.array(cart2spher(self.mic_locs))
300333
mic_grid = GridSphere(spherical_points=mic_pos[:2])
301334

302-
# source location
303-
src_loc = spher2cart(*self.sources_loc[:, mid])
304-
if offset is not None:
305-
src_loc += offset
306-
307335
# interpolate the IR
308336
grid, ir = self._interpolate("source", mid, mic_grid, ir)
309337

310338
dir_obj = MeasuredDirectivity(orientation, grid, ir, self.fs)
311-
return src_loc, dir_obj
339+
return dir_obj

0 commit comments

Comments
 (0)