16
16
from .gaussian import bures_wasserstein_mapping
17
17
18
18
19
- def gaussian_pdf (x , m , C ):
19
+ def gaussian_logpdf (x , m , C ):
20
20
r"""
21
- Compute the probability density function of a multivariate
21
+ Compute the log of the probability density function of a multivariate
22
22
Gaussian distribution.
23
23
24
24
Parameters
@@ -40,10 +40,35 @@ def gaussian_pdf(x, m, C):
40
40
x .shape [- 1 ] == m .shape [- 1 ] == C .shape [- 1 ] == C .shape [- 2 ]
41
41
), "Dimension mismatch"
42
42
nx = get_backend (x , m , C )
43
- d = x .shape [- 1 ]
44
- z = (2 * np .pi ) ** (- d / 2 ) * nx .det (C ) ** (- 0.5 )
45
- exp = nx .exp (- 0.5 * nx .sum (((x - m ) @ nx .inv (C )) * (x - m ), axis = - 1 ))
46
- return z * exp
43
+ d = m .shape [0 ]
44
+ diff = x - m
45
+ inv_C = nx .inv (C )
46
+ z = nx .sum (diff * (diff @ inv_C ), axis = - 1 )
47
+ _ , log_det_C = nx .slogdet (C )
48
+ return - 0.5 * (d * np .log (2 * np .pi ) + log_det_C + z )
49
+
50
+
51
+ def gaussian_pdf (x , m , C ):
52
+ r"""
53
+ Compute the probability density function of a multivariate
54
+ Gaussian distribution.
55
+
56
+ Parameters
57
+ ----------
58
+ x : array-like, shape (..., d)
59
+ The input samples.
60
+ m : array-like, shape (d,)
61
+ The mean vector of the Gaussian distribution.
62
+ C : array-like, shape (d, d)
63
+ The covariance matrix of the Gaussian distribution.
64
+
65
+ Returns
66
+ -------
67
+ pdf : array-like, shape (...,)
68
+ The probability density function evaluated at each sample.
69
+
70
+ """
71
+ return get_backend (x , m , C ).exp (gaussian_logpdf (x , m , C ))
47
72
48
73
49
74
def gmm_pdf (x , m , C , w ):
@@ -281,25 +306,28 @@ def gmm_ot_apply_map(
281
306
n_samples = x .shape [0 ]
282
307
283
308
if method == "bary" :
284
- normalization = gmm_pdf (x , m_s , C_s , w_s )[:, None ]
285
309
out = nx .zeros (x .shape )
286
- print ("where plan > 0" , nx .where (plan > 0 ))
310
+ logpdf = nx .stack (
311
+ [gaussian_logpdf (x , m_s [k ], C_s [k ])[:, None ] for k in range (k_s )]
312
+ )
287
313
288
314
# only need to compute for non-zero plan entries
289
315
for i , j in zip (* nx .where (plan > 0 )):
290
316
Cs12 = nx .sqrtm (C_s [i ])
291
317
Cs12inv = nx .inv (Cs12 )
292
- g = gaussian_pdf (x , m_s [i ], C_s [i ])[:, None ]
293
318
294
319
M0 = nx .sqrtm (Cs12 @ C_t [j ] @ Cs12 )
295
320
A = Cs12inv @ M0 @ Cs12inv
296
321
b = m_t [j ] - A @ m_s [i ]
297
322
298
323
# gaussian mapping between components i and j applied to x
299
324
T_ij_x = x @ A + b
300
- out = out + plan [i , j ] * g * T_ij_x
325
+ z = w_s [:, None , None ] * nx .exp (logpdf - logpdf [i ][None , :, :])
326
+ denom = nx .sum (z , axis = 0 )
301
327
302
- return out / normalization
328
+ out = out + plan [i , j ] * T_ij_x / denom
329
+
330
+ return out
303
331
304
332
else : # rand
305
333
# A[i, j] is the linear part of the gaussian mapping between components
@@ -318,13 +346,19 @@ def gmm_ot_apply_map(
318
346
A [i , j ] = Cs12inv @ M0 @ Cs12inv
319
347
b [i , j ] = m_t [j ] - A [i , j ] @ m_s [i ]
320
348
321
- normalization = gmm_pdf (x , m_s , C_s , w_s ) # (n_samples,)
322
- gs = np .stack ([gaussian_pdf (x , m_s [i ], C_s [i ]) for i in range (k_s )], axis = - 1 )
349
+ logpdf = nx .stack (
350
+ [gaussian_logpdf (x , m_s [k ], C_s [k ]) for k in range (k_s )], axis = - 1
351
+ )
323
352
# (n_samples, k_s)
324
353
out = nx .zeros (x .shape )
325
354
326
355
for i_sample in range (n_samples ):
327
- p_mat = plan * gs [i_sample ][:, None ] / normalization [i_sample ]
356
+ log_g = logpdf [i_sample ]
357
+ log_diff = log_g [:, None ] - log_g [None , :]
358
+ weighted_exp = w_s [:, None ] * nx .exp (log_diff )
359
+ denom = nx .sum (weighted_exp , axis = 0 )[:, None ] * nx .ones (plan .shape [1 ])
360
+ p_mat = plan / denom
361
+
328
362
p = p_mat .reshape (k_s * k_t ) # stack line-by-line
329
363
# sample between 0 and k_s * k_t - 1
330
364
ij_mat = rng .choice (k_s * k_t , p = p )
0 commit comments