Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for recorded source/mic directivities from DIRPAT database (#259) #302

Merged
merged 66 commits into from
Oct 29, 2024
Merged
Changes from 1 commit
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
54ac420
Support for recorded source/mic directivities from DIRPAT database (#…
prerak23 Feb 2, 2023
63d13f8
merge with master
fakufaku Feb 6, 2023
4274ec2
remove deprecated np.float type alias from numpy. run black
fakufaku Feb 6, 2023
8853437
isort
fakufaku Feb 6, 2023
02af651
isort
fakufaku Feb 8, 2023
76f8b96
Adds download functions for the SOFA files from DIRPAT. Adds tests fo…
fakufaku Feb 9, 2023
88e1c5b
Adds workaround to CI for windows and py3.7
fakufaku Feb 9, 2023
a2e1080
Changes the EM32 sofa file url to official url
fakufaku Feb 10, 2023
2c04a8c
open_sofa_interpolate.py: vectorizes fibonnaci computations, use sphe…
fakufaku Feb 13, 2023
910e5b5
Fixes the weights in the weighted pseudo-inverse used in the interpol…
fakufaku Feb 13, 2023
1a45b7c
Moves SOFA open function out of the DIRPATInterpolate class.
fakufaku Feb 14, 2023
70a562e
Separates the spherical interpolation from the DIRPAT and SOFA classe…
fakufaku Feb 14, 2023
b662934
Merge branch 'master' into dev/dirpat
fakufaku Feb 14, 2023
683b498
Vectorizes the spherical harmonics computation function for spherical…
fakufaku Feb 14, 2023
7507118
Moves the regular grid detection function to doa.utils
fakufaku Feb 14, 2023
9f91a55
Changes the DIRPAT terminology to something more general to support m…
fakufaku Feb 15, 2023
079bb1e
black/isort
fakufaku Feb 15, 2023
5958ce0
black/sort
fakufaku Feb 15, 2023
b6ff1cb
Adds support sofe simple hrir sofa files
fakufaku Jun 3, 2023
f558bfc
Adds two SOFA files for MIT Kemar HRTF. Adds an example script for bi…
fakufaku Jun 3, 2023
f76ca8f
merge master
fakufaku Jun 5, 2023
bdb8309
Modularize the RIR building routines for ISM and RT
fakufaku Jun 10, 2023
32269ae
Fixing tests for sofa files
fakufaku Jun 12, 2023
92e4007
Adds simulation sub-package in list in setup.py
fakufaku Jun 12, 2023
9b38cca
Adds init file for simulation sub-package
fakufaku Jun 12, 2023
88a9c54
fix sofa test
fakufaku Jun 12, 2023
4922a0e
ray tracing: adds air absorption, linear interpolation of histogram
fakufaku Jun 14, 2023
5ffe18c
Fixes issue when histogram is all zeros
fakufaku Jun 14, 2023
0ab2ad9
Switch to using the origin pyroomacoustics fibonacci spherical grid f…
fakufaku Jun 16, 2023
46e30d1
moves is_dirpat check inside the open sofa file function
fakufaku Jul 4, 2023
b0d2b1d
modularize the element location parsing routine for sofa files
fakufaku Sep 29, 2023
b4abb9b
adds sofa database as attrdict
fakufaku Apr 23, 2024
3984eaa
black
fakufaku May 7, 2024
20e4106
merge master
fakufaku May 7, 2024
1c1ddc9
Improved SOFA api
fakufaku May 25, 2024
cf6802e
example
fakufaku May 25, 2024
0edebfb
Fixes the SOFA tests. Due to change of sampling algorithm (scipy.sign…
fakufaku May 25, 2024
bc016ad
lint
fakufaku May 26, 2024
e96f4ee
merge
fakufaku May 26, 2024
b6aa302
relocates the from directivities to simulation/ism
fakufaku May 26, 2024
9fb7ba7
fixes source_angle_shoebox test
fakufaku May 26, 2024
7874a01
refactors directivities sub-module and sofa sub-module into a single …
fakufaku May 26, 2024
ccacfed
fixes sofa db to include files that are present, but not in the json …
fakufaku May 26, 2024
65107c4
SOFA measurement labels are read from the database when available. Ma…
fakufaku May 26, 2024
26f4a32
removes unnecessary imports and lint
fakufaku May 26, 2024
b1a0357
Merge branch 'master' into dev/dirpat
fakufaku May 26, 2024
120ac43
CHANGELOG formatting for branch
fakufaku May 26, 2024
1d09fb8
merge
fakufaku May 27, 2024
f866eb8
Fixes the cardioid family formula in the corresponding object. Adds a…
fakufaku May 27, 2024
b3917a0
lint
fakufaku May 27, 2024
d8d79ea
changes the interface for cardioid family. Removes the enum, uses sep…
fakufaku May 27, 2024
d392bb6
lint
fakufaku May 27, 2024
d48a899
Changes the KDTree in MeasuredDirectivity from operating on spherical…
fakufaku May 27, 2024
e845041
Adds directivities (analytic and measured) to the documentation.
fakufaku May 28, 2024
f89e815
adds soxr to requirements for doc building
fakufaku May 28, 2024
8d24c10
Doc
fakufaku May 28, 2024
d7371a8
Disallow directivities for 2D rooms.
fakufaku May 28, 2024
5e1334f
Fixes directivities examples
fakufaku May 28, 2024
200e875
lint
fakufaku May 28, 2024
b6d9a6c
Fixes some docstrings
fakufaku May 28, 2024
c4c9e27
Reflecting Prerak's review (PR #302).
fakufaku Oct 27, 2024
0fd67c4
Merge with master
fakufaku Oct 27, 2024
6f5f88e
Re-fixes issue #353 in the dirpat branch.
fakufaku Oct 27, 2024
5fa9c3b
lint
fakufaku Oct 27, 2024
2c48ac2
Fixes build for win and py 3.8.
fakufaku Oct 28, 2024
325b856
Edits CHANGELOG. Adds more details about the SOFA files included incl…
fakufaku Oct 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Separates the spherical interpolation from the DIRPAT and SOFA classe…
…s. Modifies slightly doa.GridSphere object so that it can be used to hold simple spherical grids without overhead.
fakufaku committed Feb 14, 2023
commit 70a562e555bad33253968a5e17c0bc89bec5c603
21 changes: 18 additions & 3 deletions examples/simulation_with_measured_directivity.py
Original file line number Diff line number Diff line change
@@ -63,6 +63,7 @@ class DirectionVector

import os

import matplotlib.pyplot as plt
import numpy as np
from scipy import signal
from scipy.fft import fft, fftfreq
@@ -104,6 +105,7 @@ class DirectionVector
path=path_Eigenmic_file,
DIRPAT_pattern_enum="EM_32_9",
fs=16000,
no_points_on_fibo_sphere=0,
)

dir_obj_Cmic = CardioidFamily(
@@ -170,8 +172,8 @@ class DirectionVector
room = pra.ShoeBox(
room_dim,
fs=16000,
max_order=2,
materials=pra.Material(0.99),
max_order=20,
materials=pra.Material(0.5),
air_absorption=True,
ray_tracing=False,
min_phase=False,
@@ -190,7 +192,20 @@ class DirectionVector
room.compute_rir()
room.plot_rir(FD=True)

dir_mic.obj_open_sofa_inter.plot_new(freq_bin=50, depth=True)
# print(dir_mic.obj_open_sofa_inter.freq_angles_fft.shape)
# dir_mic.obj_open_sofa_inter.interpolate = False

fig = plt.figure()
for idx, fb in enumerate(range(44)):
if idx >= 5 * 10:
break
ax = fig.add_subplot(5, 10, idx + 1, projection="3d")
dir_mic.obj_open_sofa_inter.plot_new(freq_bin=fb, ax=ax, depth=True)
# ax.set_xticks([])
# ax.set_yticks([])
# ax.set_zticks([])
ax.set_title(idx)
plt.show()


rir_1_0 = room.rir[0][0]
31 changes: 18 additions & 13 deletions pyroomacoustics/directivities.py
Original file line number Diff line number Diff line change
@@ -288,8 +288,7 @@ def plot_response(


class DIRPATRir(Directivity):

"""
r"""
Open specific DIRPAT files and interpolate the FIR filters on the fibonacci sphere also frequency independent
cardioid patterns can be called from this function
This class inherits the base class Directivity and all its functions.
@@ -387,8 +386,10 @@ def __init__(

if no_points_on_fibo_sphere == 0:
self.interpolate = False
interp_order = None
else:
self.interpolate = True
interp_order = 12

self.fs = fs

@@ -401,12 +402,17 @@ def __init__(
fs=self.fs,
DIRPAT_pattern_enum=DIRPAT_pattern_enum,
source=self.source,
interpolate=self.interpolate,
no_of_points_fibo_sphere=self.points_on_fibo,
azimuth_simulation=self._orientation.get_azimuth(degrees=False),
colatitude_simulation=self._orientation.get_colatitude(degrees=False),
interp_order=interp_order,
interp_n_points=self.points_on_fibo,
)

self.obj_open_sofa_inter.change_orientation(
azimuth_change=self._orientation.get_azimuth(degrees=False),
colatitude_change=self._orientation.get_colatitude(degrees=False),
degrees=False,
)
self.filter_len_ir = self.obj_open_sofa_inter.samples_size_ir

self.filter_len_ir = self.obj_open_sofa_inter.impulse_responses.shape[-1]
# self.obj_open_sofa_inter.plot(freq_bin=30) For plotting directivity pattern on the sphere for a specific frequency bin

def set_orientation(self, azimuth, colatitude):
@@ -445,13 +451,12 @@ def get_response(
False : Azimuth and colatitude are provided in radians.
"""
if degrees:
azimuth = np.deg2rad(azimuth)
colatitude = np.deg2rad(colatitude)

indexs = self.obj_open_sofa_inter.cal_index_knn(
azimuth, colatitude
) # Using NN search calculates all the indexes for list of azimuth and colitude of all the image sources
return self.obj_open_sofa_inter.neareast_neighbour(
indexs
) # Returns filter response for the given indexes from the sphere (interpolation (True) : fibonacci sphere , interpolation (False) : original grid )
return self.obj_open_sofa_inter.nearest_neighbour(azimuth, colatitude)
# Returns filter response for the given indexes from the sphere (interpolation (True) : fibonacci sphere , interpolation (False) : original grid )


def cardioid_func(x, direction, coef, gain=1.0, normalize=True, magnitude=False):
54 changes: 43 additions & 11 deletions pyroomacoustics/doa/grid.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
import scipy.spatial as sp # import ConvexHull, SphericalVoronoi

from .detect_peaks import detect_peaks
from .utils import great_circ_dist, fibonnaci_spherical_sampling
from .utils import cart2spher, fibonnaci_spherical_sampling, great_circ_dist, spher2cart


class Grid:
@@ -37,6 +37,9 @@ def __init__(self, n_points):
self.dim = 0
self.values = None

def __len__(self):
return self.cartesian.shape[1]

@abstractmethod
def apply(self, func, spherical=False):
return NotImplemented
@@ -153,35 +156,53 @@ class GridSphere(Grid):
The number of points to sample
spherical_points: ndarray, optional
A 2 x n_points array of spherical coordinates with azimuth in
the top row and colatitude in the second row. Overrides n_points.
the top row and colatitude in the second row. Overrides ``n_points``
and ``cartesian_points``.
cartesian_points: ndarray, optional
A 3 x n_points array of Cartesian coordinates with x, y, z coordinates
in the rows. The vectors are normalized to unit-norm in the constructor.
Overrides ``n_points``.
precompute_neighbors: bool, optional
If `True`, the convex hull algorithm is used to find all
the neighbors of the grid points. This is used for the peak finding
algorithm.
References
----------
http://lgdv.cs.fau.de/uploads/publications/spherical_fibonacci_mapping.pdf
http://stackoverflow.com/questions/9600801/evenly-distributing-n-points-on-a-sphere
"""

def __init__(self, n_points=1000, spherical_points=None):
def __init__(
self,
n_points=1000,
spherical_points=None,
cartesian_points=None,
precompute_neighbors=False,
):
if spherical_points is not None:
if spherical_points.ndim != 2 or spherical_points.shape[0] != 2:
raise ValueError("spherical_points must be a 2D array with two rows.")

n_points = spherical_points.shape[1]
elif cartesian_points is not None:
if cartesian_points.ndim != 2 or cartesian_points.shape[0] != 3:
raise ValueError("cartesian_points must be a 3D array with two rows.")
n_points = cartesian_points.shape[1]

# Parent constructor
Grid.__init__(self, n_points)

self.dim = 3

if spherical_points is not None:
# If a list of points was provided, use it

self.spherical[:, :] = spherical_points
self.cartesian[:] = spher2cart(self.azimuth, self.colatitude)

# transform to cartesian coordinates
self.x[:] = np.cos(self.azimuth) * np.sin(self.colatitude)
self.y[:] = np.sin(self.azimuth) * np.sin(self.colatitude)
self.z[:] = np.cos(self.colatitude)
elif cartesian_points is not None:
# normalize all
norms = np.linalg.norm(cartesian_points, axis=0, keepdims=True)
self.cartesian[:] = cartesian_points / norms
self.azimuth[:], self.colatitude[:], _ = cart2spher(self.cartesian)

else:
# If no list was provided, samples points on the sphere
@@ -194,6 +215,17 @@ def __init__(self, n_points=1000, spherical_points=None):
self.azimuth[:] = np.arctan2(self.y, self.x)
self.colatitude[:] = np.arctan2(np.sqrt(self.x**2 + self.y**2), self.z)

self._neighbors = None
if precompute_neighbors:
self._compute_neighbors()

@property
def neighbors(self):
if self._neighbors is None:
self._compute_neighbors()
return self._neighbors

def _compute_neighbors(self):
# To perform the peak detection in 2D on a non-squared grid it is
# necessary to know the neighboring points of each grid point. The
# Convex Hull of points on the sphere is equivalent to the Delauney
@@ -217,7 +249,7 @@ def __init__(self, n_points=1000, spherical_points=None):
adjacency[tri[2]].add(tri[1])

# convert to list of lists
self.neighbors = [list(x) for x in adjacency]
self._neighbors = [list(x) for x in adjacency]

def apply(self, func, spherical=False):
"""
464 changes: 194 additions & 270 deletions pyroomacoustics/open_sofa_interpolate.py

Large diffs are not rendered by default.

162 changes: 113 additions & 49 deletions pyroomacoustics/tests/test_sofa_directivities.py
Original file line number Diff line number Diff line change
@@ -5,10 +5,12 @@
To generate the samples run this file: `python ./test_sofa_directivities.py`
"""
import argparse
import os
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pytest

@@ -24,10 +26,12 @@
DirectivityPattern,
DIRPATRir,
)
from pyroomacoustics.doa import GridSphere
from pyroomacoustics.open_sofa_interpolate import (
RegularGrid,
_detect_regular_grid,
calculation_pinv_voronoi_cells,
calculation_pinv_voronoi_cells_general,
_detect_regular_grid,
)

sofa_info = get_sofa_db_info()
@@ -94,22 +98,27 @@ def test_dirpat_download():


SOFA_ONE_SIDE_PARAMETERS = [
("AKG_c480", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414K", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414N", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414S", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414A", "AKG_c480_c414_CUBE.sofa", False),
("EM_32_0", "EM32_Directivity.sofa", False),
("EM_32_31", "EM32_Directivity.sofa", False),
("Genelec_8020", "LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa", False),
("Vibrolux_2x10inch", "LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa", False),
("AKG_c480", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414K", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414N", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414S", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414A", "AKG_c480_c414_CUBE.sofa", False, False),
("EM_32_0", "EM32_Directivity.sofa", False, False),
("EM_32_31", "EM32_Directivity.sofa", False, False),
("Genelec_8020", "LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa", False, False),
(
"Vibrolux_2x10inch",
"LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa",
False,
False,
),
]


@pytest.mark.parametrize(
"pattern_id,sofa_file_name,save_flag", SOFA_ONE_SIDE_PARAMETERS
"pattern_id,sofa_file_name,save_flag,plot_flag", SOFA_ONE_SIDE_PARAMETERS
)
def test_sofa_one_side(pattern_id, sofa_file_name, save_flag):
def test_sofa_one_side(pattern_id, sofa_file_name, save_flag, plot_flag):
"""
Tests with only microphone *or* source from a SOFA file
"""
@@ -131,7 +140,7 @@ def test_sofa_one_side(pattern_id, sofa_file_name, save_flag):

# define source with figure_eight directivity
directivity = DIRPATRir(
orientation=DirectionVector(azimuth=90, colatitude=90, degrees=True),
orientation=DirectionVector(azimuth=0, colatitude=0, degrees=True),
path=Path(DEFAULT_SOFA_PATH) / sofa_file_name,
DIRPAT_pattern_enum=pattern_id,
fs=16000,
@@ -172,6 +181,12 @@ def test_sofa_one_side(pattern_id, sofa_file_name, save_flag):
"Rel diff.:",
abs(reference_data - rir_1_0).max() / abs(reference_data).max(),
)
if plot_flag:
fig, ax = plt.subplots(1, 1)
ax.plot(rir_1_0, label="test")
ax.plot(reference_data, label="ref")
ax.legend()
fig.savefig(test_file_path.with_suffix(".pdf"))
assert np.allclose(reference_data, rir_1_0, atol=atol, rtol=rtol)
else:
warnings.warn("Did not find the reference data. Output was not checked.")
@@ -184,37 +199,47 @@ def test_sofa_one_side(pattern_id, sofa_file_name, save_flag):
"AKG_c480",
"AKG_c480_c414_CUBE.sofa",
False,
False,
),
(
"Vibrolux_2x10inch",
"LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa",
"AKG_c414K",
"AKG_c480_c414_CUBE.sofa",
False,
False,
),
(
"Vibrolux_2x10inch",
"LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa",
"EM_32_0",
"EM32_Directivity.sofa",
False,
False,
),
(
"Genelec_8020",
"LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa",
"EM_32_31",
"EM32_Directivity.sofa",
False,
False,
),
]


@pytest.mark.parametrize(
"src_pattern_id, src_sofa_file_name, mic_pattern_id, mic_sofa_file_name, save_flag",
"src_pattern_id, src_sofa_file_name, mic_pattern_id, "
"mic_sofa_file_name, save_flag, plot_flag",
SOFA_TWO_SIDES_PARAMETERS,
)
def test_sofa_two_sides(
src_pattern_id, src_sofa_file_name, mic_pattern_id, mic_sofa_file_name, save_flag
src_pattern_id,
src_sofa_file_name,
mic_pattern_id,
mic_sofa_file_name,
save_flag,
plot_flag,
):
"""
Tests with only microphone *or* source from a SOFA file
@@ -236,14 +261,14 @@ def test_sofa_two_sides(
)

src_directivity = DIRPATRir(
orientation=DirectionVector(azimuth=90, colatitude=90, degrees=True),
orientation=DirectionVector(azimuth=0, colatitude=0, degrees=True),
path=Path(DEFAULT_SOFA_PATH) / src_sofa_file_name,
DIRPAT_pattern_enum=src_pattern_id,
fs=16000,
)

mic_directivity = DIRPATRir(
orientation=DirectionVector(azimuth=90, colatitude=90, degrees=True),
orientation=DirectionVector(azimuth=0, colatitude=0, degrees=True),
path=Path(DEFAULT_SOFA_PATH) / mic_sofa_file_name,
DIRPAT_pattern_enum=mic_pattern_id,
fs=16000,
@@ -279,33 +304,47 @@ def test_sofa_two_sides(
np.save(test_file_path, rir_1_0)
elif test_file_path.exists():
reference_data = np.load(test_file_path)

print("Max diff.:", abs(reference_data - rir_1_0).max())
print(
"Rel diff.:",
abs(reference_data - rir_1_0).max() / abs(reference_data).max(),
)

if plot_flag:
fig, ax = plt.subplots(1, 1)
ax.plot(rir_1_0, label="test")
ax.plot(reference_data, label="ref")
ax.legend()
fig.savefig(test_file_path.with_suffix(".pdf"))

assert np.allclose(reference_data, rir_1_0, atol=atol, rtol=rtol)
else:
warnings.warn("Did not find the reference data. Output was not checked.")


SOFA_CARDIOID_PARAMETERS = [
("AKG_c480", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414K", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414N", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414S", "AKG_c480_c414_CUBE.sofa", False),
("AKG_c414A", "AKG_c480_c414_CUBE.sofa", False),
("EM_32_0", "EM32_Directivity.sofa", False),
("EM_32_31", "EM32_Directivity.sofa", False),
("Genelec_8020", "LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa", False),
("Vibrolux_2x10inch", "LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa", False),
("AKG_c480", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414K", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414N", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414S", "AKG_c480_c414_CUBE.sofa", False, False),
("AKG_c414A", "AKG_c480_c414_CUBE.sofa", False, False),
("EM_32_0", "EM32_Directivity.sofa", False, False),
("EM_32_31", "EM32_Directivity.sofa", False, False),
("Genelec_8020", "LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa", False, False),
(
"Vibrolux_2x10inch",
"LSPs_HATS_GuitarCabinets_Akustikmessplatz.sofa",
False,
False,
),
]


@pytest.mark.parametrize(
"pattern_id,sofa_file_name,save_flag", SOFA_CARDIOID_PARAMETERS
"pattern_id,sofa_file_name,save_flag,plot_flag", SOFA_CARDIOID_PARAMETERS
)
def test_sofa_and_cardioid(pattern_id, sofa_file_name, save_flag):
def test_sofa_and_cardioid(pattern_id, sofa_file_name, save_flag, plot_flag):
"""
Tests with only microphone *or* source from a SOFA file
"""
@@ -327,7 +366,7 @@ def test_sofa_and_cardioid(pattern_id, sofa_file_name, save_flag):

# define source with figure_eight directivity
directivity = DIRPATRir(
orientation=DirectionVector(azimuth=270, colatitude=90, degrees=True),
orientation=DirectionVector(azimuth=0, colatitude=0, degrees=True),
path=Path(DEFAULT_SOFA_PATH) / sofa_file_name,
DIRPAT_pattern_enum=pattern_id,
fs=16000,
@@ -369,11 +408,20 @@ def test_sofa_and_cardioid(pattern_id, sofa_file_name, save_flag):
np.save(test_file_path, rir_1_0)
elif test_file_path.exists():
reference_data = np.load(test_file_path)

print("Max diff.:", abs(reference_data - rir_1_0).max())
print(
"Rel diff.:",
abs(reference_data - rir_1_0).max() / abs(reference_data).max(),
)

if plot_flag:
fig, ax = plt.subplots(1, 1)
ax.plot(rir_1_0, label="test")
ax.plot(reference_data, label="ref")
ax.legend()
fig.savefig(test_file_path.with_suffix(".pdf"))

assert np.allclose(reference_data, rir_1_0, atol=atol, rtol=rtol)
else:
warnings.warn("Did not find the reference data. Output was not checked.")
@@ -423,18 +471,18 @@ def test_weighted_pinv(n_azimuth, n_col, col_start, col_end, order, atol, rtol):

@pytest.mark.parametrize("n_az, n_co", [(36, 12), (72, 11), (360, 180)])
def test_detect_grid_regular(n_az, n_co):

azimuth = np.linspace(0, 2 * np.pi, n_az, endpoint=False)
colatitude = np.linspace(np.pi / 2.0 / n_co, np.pi - np.pi / 2.0 / n_co, n_co)
A, C = np.meshgrid(azimuth, colatitude)
alin = A.flatten()
clin = C.flatten()

dic = _detect_regular_grid(alin, clin)
grid = GridSphere(spherical_points=np.array((alin, clin)))
reg_grid = _detect_regular_grid(grid)

assert isinstance(dic, dict)
assert np.allclose(dic["azimuth"], azimuth)
assert np.allclose(dic["colatitude"], colatitude)
assert isinstance(reg_grid, RegularGrid)
assert np.allclose(reg_grid.azimuth, azimuth)
assert np.allclose(reg_grid.colatitude, colatitude)


@pytest.mark.parametrize(
@@ -443,8 +491,9 @@ def test_detect_grid_regular(n_az, n_co):
def test_detect_not_grid(n_points):
alin = np.random.rand(n_points) * 2 * np.pi
clin = np.random.rand(n_points) * np.pi
dic = _detect_regular_grid(alin, clin)
assert dic is None
grid = GridSphere(spherical_points=np.array((alin, clin)))
reg_grid = _detect_regular_grid(grid)
assert reg_grid is None


@pytest.mark.parametrize("n_az, n_co", [(36, 12), (72, 11), (360, 180)])
@@ -455,9 +504,10 @@ def test_detect_grid_irregular_azimuth(n_az, n_co):
alin = A.flatten()
clin = C.flatten()

dic = _detect_regular_grid(alin, clin)
grid = GridSphere(spherical_points=np.array((alin, clin)))
reg_grid = _detect_regular_grid(grid)

assert dic is None # should fail when azimuth is irregular
assert reg_grid is None # should fail when azimuth is irregular


@pytest.mark.parametrize("n_az, n_co", [(36, 12), (72, 11), (360, 180)])
@@ -468,12 +518,13 @@ def test_detect_grid_irregular_colatitude(n_az, n_co):
alin = A.flatten()
clin = C.flatten()

dic = _detect_regular_grid(alin, clin)
grid = GridSphere(spherical_points=np.array((alin, clin)))
reg_grid = _detect_regular_grid(grid)

# should succeed when azimuth is regular
assert isinstance(dic, dict)
assert np.allclose(dic["azimuth"], azimuth)
assert np.allclose(dic["colatitude"], colatitude)
assert isinstance(reg_grid, RegularGrid)
assert np.allclose(reg_grid.azimuth, azimuth)
assert np.allclose(reg_grid.colatitude, colatitude)


@pytest.mark.parametrize("n_az, n_co", [(36, 12), (72, 11), (360, 180)])
@@ -488,27 +539,40 @@ def test_detect_grid_point_duplicate(n_az, n_co):
clin[i] = clin[i + 1]
alin[i] = alin[i + 1]

dic = _detect_regular_grid(alin, clin)
grid = GridSphere(spherical_points=np.array((alin, clin)))
reg_grid = _detect_regular_grid(grid)

# should fail because this is not a grid
assert dic is None
assert reg_grid is None


if __name__ == "__main__":
# generate the test files for regression testing
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"--save", action="store_true", help="save the signal as a reference"
)
parser.add_argument(
"--plot", action="store_true", help="plot the generated signals"
)
args = parser.parse_args()

download_sofa_files(verbose=True)

for params in SOFA_ONE_SIDE_PARAMETERS:
new_params = params[:-1] + (True,)
new_params = params[:-2] + (args.save, args.plot)
test_sofa_one_side(*new_params)

for params in SOFA_TWO_SIDES_PARAMETERS:
new_params = params[:-1] + (True,)
new_params = params[:-2] + (args.save, args.plot)
test_sofa_two_sides(*new_params)

for params in SOFA_CARDIOID_PARAMETERS:
new_params = params[:-1] + (True,)
new_params = params[:-2] + (args.save, args.plot)
test_sofa_and_cardioid(*new_params)

for params in PINV_PARAMETERS:
test_weighted_pinv(*params)
"""

for p in [(36, 12), (72, 11), (360, 180)]:
test_detect_grid_irregular_colatitude(*p)