Skip to content

Commit 6311e25

Browse files
samuelbxSamuel Boïtérflamary
authored
[FIX] numerical errors in ot.gmm (#690)
* fixed numerical errors in density computations * lint * hotfix * cholesky not useful anymore * vectorization * backend tests for slogdet * update releases.md * added contribution --------- Co-authored-by: Samuel Boïté <[email protected]> Co-authored-by: Rémi Flamary <[email protected]>
1 parent b5cfb91 commit 6311e25

File tree

5 files changed

+80
-14
lines changed

5 files changed

+80
-14
lines changed

Diff for: CONTRIBUTORS.md

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ The contributors to this library are:
5454
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
5555
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn, Low rank Gromov-Wasserstein samples)
5656
* [Julie Delon](https://judelo.github.io/) (GMM OT)
57+
* [Samuel Boïté](https://samuelbx.github.io/) (GMM OT)
5758

5859
## Acknowledgments
5960

Diff for: RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#### Closed issues
99
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
10+
- Fixed numerical errors in `ot.gmm` (PR #690, Issue #689)
1011

1112

1213
## 0.9.5

Diff for: ot/backend.py

+23
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,14 @@ def det(self, a):
10731073
"""
10741074
raise NotImplementedError()
10751075

1076+
def slogdet(self, a):
1077+
r"""
1078+
Compute the sign and (natural) logarithm of the determinant of an array.
1079+
1080+
See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html
1081+
"""
1082+
raise NotImplementedError()
1083+
10761084

10771085
class NumpyBackend(Backend):
10781086
"""
@@ -1433,6 +1441,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
14331441
def det(self, a):
14341442
return np.linalg.det(a)
14351443

1444+
def slogdet(self, a):
1445+
return np.linalg.slogdet(a)
1446+
14361447

14371448
_register_backend_implementation(NumpyBackend)
14381449

@@ -1826,6 +1837,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
18261837
def det(self, x):
18271838
return jnp.linalg.det(x)
18281839

1840+
def slogdet(self, a):
1841+
return jnp.linalg.slogdet(a)
1842+
18291843

18301844
if jax:
18311845
# Only register jax backend if it is installed
@@ -2359,6 +2373,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
23592373
def det(self, x):
23602374
return torch.linalg.det(x)
23612375

2376+
def slogdet(self, a):
2377+
return torch.linalg.slogdet(a)
2378+
23622379

23632380
if torch:
23642381
# Only register torch backend if it is installed
@@ -2767,6 +2784,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
27672784
def det(self, x):
27682785
return cp.linalg.det(x)
27692786

2787+
def slogdet(self, a):
2788+
return cp.linalg.slogdet(a)
2789+
27702790

27712791
if cp:
27722792
# Only register cp backend if it is installed
@@ -3205,6 +3225,9 @@ def nan_to_num(self, x, copy=True, nan=0.0, posinf=None, neginf=None):
32053225
def det(self, x):
32063226
return tf.linalg.det(x)
32073227

3228+
def slogdet(self, a):
3229+
return tf.linalg.slogdet(a)
3230+
32083231

32093232
if tf:
32103233
# Only register tensorflow backend if it is installed

Diff for: ot/gmm.py

+48-14
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
from .gaussian import bures_wasserstein_mapping
1717

1818

19-
def gaussian_pdf(x, m, C):
19+
def gaussian_logpdf(x, m, C):
2020
r"""
21-
Compute the probability density function of a multivariate
21+
Compute the log of the probability density function of a multivariate
2222
Gaussian distribution.
2323
2424
Parameters
@@ -40,10 +40,35 @@ def gaussian_pdf(x, m, C):
4040
x.shape[-1] == m.shape[-1] == C.shape[-1] == C.shape[-2]
4141
), "Dimension mismatch"
4242
nx = get_backend(x, m, C)
43-
d = x.shape[-1]
44-
z = (2 * np.pi) ** (-d / 2) * nx.det(C) ** (-0.5)
45-
exp = nx.exp(-0.5 * nx.sum(((x - m) @ nx.inv(C)) * (x - m), axis=-1))
46-
return z * exp
43+
d = m.shape[0]
44+
diff = x - m
45+
inv_C = nx.inv(C)
46+
z = nx.sum(diff * (diff @ inv_C), axis=-1)
47+
_, log_det_C = nx.slogdet(C)
48+
return -0.5 * (d * np.log(2 * np.pi) + log_det_C + z)
49+
50+
51+
def gaussian_pdf(x, m, C):
52+
r"""
53+
Compute the probability density function of a multivariate
54+
Gaussian distribution.
55+
56+
Parameters
57+
----------
58+
x : array-like, shape (..., d)
59+
The input samples.
60+
m : array-like, shape (d,)
61+
The mean vector of the Gaussian distribution.
62+
C : array-like, shape (d, d)
63+
The covariance matrix of the Gaussian distribution.
64+
65+
Returns
66+
-------
67+
pdf : array-like, shape (...,)
68+
The probability density function evaluated at each sample.
69+
70+
"""
71+
return get_backend(x, m, C).exp(gaussian_logpdf(x, m, C))
4772

4873

4974
def gmm_pdf(x, m, C, w):
@@ -281,25 +306,28 @@ def gmm_ot_apply_map(
281306
n_samples = x.shape[0]
282307

283308
if method == "bary":
284-
normalization = gmm_pdf(x, m_s, C_s, w_s)[:, None]
285309
out = nx.zeros(x.shape)
286-
print("where plan > 0", nx.where(plan > 0))
310+
logpdf = nx.stack(
311+
[gaussian_logpdf(x, m_s[k], C_s[k])[:, None] for k in range(k_s)]
312+
)
287313

288314
# only need to compute for non-zero plan entries
289315
for i, j in zip(*nx.where(plan > 0)):
290316
Cs12 = nx.sqrtm(C_s[i])
291317
Cs12inv = nx.inv(Cs12)
292-
g = gaussian_pdf(x, m_s[i], C_s[i])[:, None]
293318

294319
M0 = nx.sqrtm(Cs12 @ C_t[j] @ Cs12)
295320
A = Cs12inv @ M0 @ Cs12inv
296321
b = m_t[j] - A @ m_s[i]
297322

298323
# gaussian mapping between components i and j applied to x
299324
T_ij_x = x @ A + b
300-
out = out + plan[i, j] * g * T_ij_x
325+
z = w_s[:, None, None] * nx.exp(logpdf - logpdf[i][None, :, :])
326+
denom = nx.sum(z, axis=0)
301327

302-
return out / normalization
328+
out = out + plan[i, j] * T_ij_x / denom
329+
330+
return out
303331

304332
else: # rand
305333
# A[i, j] is the linear part of the gaussian mapping between components
@@ -318,13 +346,19 @@ def gmm_ot_apply_map(
318346
A[i, j] = Cs12inv @ M0 @ Cs12inv
319347
b[i, j] = m_t[j] - A[i, j] @ m_s[i]
320348

321-
normalization = gmm_pdf(x, m_s, C_s, w_s) # (n_samples,)
322-
gs = np.stack([gaussian_pdf(x, m_s[i], C_s[i]) for i in range(k_s)], axis=-1)
349+
logpdf = nx.stack(
350+
[gaussian_logpdf(x, m_s[k], C_s[k]) for k in range(k_s)], axis=-1
351+
)
323352
# (n_samples, k_s)
324353
out = nx.zeros(x.shape)
325354

326355
for i_sample in range(n_samples):
327-
p_mat = plan * gs[i_sample][:, None] / normalization[i_sample]
356+
log_g = logpdf[i_sample]
357+
log_diff = log_g[:, None] - log_g[None, :]
358+
weighted_exp = w_s[:, None] * nx.exp(log_diff)
359+
denom = nx.sum(weighted_exp, axis=0)[:, None] * nx.ones(plan.shape[1])
360+
p_mat = plan / denom
361+
328362
p = p_mat.reshape(k_s * k_t) # stack line-by-line
329363
# sample between 0 and k_s * k_t - 1
330364
ij_mat = rng.choice(k_s * k_t, p=p)

Diff for: test/test_backend.py

+7
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ def test_empty_backend():
271271
nx.eigh(M)
272272
with pytest.raises(NotImplementedError):
273273
nx.det(M)
274+
with pytest.raises(NotImplementedError):
275+
nx.slogdet(M)
274276

275277

276278
def test_func_backends(nx):
@@ -691,6 +693,11 @@ def test_func_backends(nx):
691693
lst_b.append(nx.to_numpy(d))
692694
lst_name.append("det")
693695

696+
s, logabsd = nx.slogdet(M1b)
697+
s, logabsd = nx.to_numpy(s), nx.to_numpy(logabsd)
698+
lst_b.append(np.array([s, logabsd]))
699+
lst_name.append("slogdet")
700+
694701
assert not nx.array_equal(Mb, vb), "array_equal (shape)"
695702
assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
696703
assert not nx.array_equal(

0 commit comments

Comments
 (0)