Skip to content

Commit 506a524

Browse files
committed
Update stop criteria for SGD Gaussian barycenter
1 parent fe3d9db commit 506a524

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

ot/gaussian.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def bures_barycenter_gradient_descent(
541541
log=False,
542542
step_size=1,
543543
batch_size=None,
544+
averaged=False,
544545
nx=None,
545546
):
546547
r"""Return the (Bures-)Wasserstein barycenter between centered Gaussian distributions.
@@ -570,6 +571,8 @@ def bures_barycenter_gradient_descent(
570571
step size for the gradient descent, 1 by default
571572
batch_size : int, optional
572573
batch size if use a stochastic gradient descent
574+
averaged : bool, optional
575+
if True, use the averaged procedure of :ref:`[74] <references-OT-bures-barycenter-gradient_descent>`
573576
nx : module, optional
574577
The numerical backend module to use. If not provided, the backend will
575578
be fetched from the input matrices `C`.
@@ -607,7 +610,9 @@ def bures_barycenter_gradient_descent(
607610
Cb = nx.mean(C * weights[:, None, None], axis=0)
608611
Id = nx.eye(C.shape[-1], type_as=Cb)
609612

610-
L_grads = []
613+
L_diff = []
614+
615+
Cb_averaged = nx.copy(Cb)
611616

612617
for it in range(num_iter):
613618
Cb12 = nx.sqrtm(Cb)
@@ -627,40 +632,38 @@ def bures_barycenter_gradient_descent(
627632

628633
# step size from [74] (page 15)
629634
step_size = 2 / (0.7 * (it + 2 / 0.7 + 1))
630-
631-
# TODO: Add one where we take samples in order, + averaging? cf [74]
632635
else: # gradient descent
633636
M = nx.sqrtm(nx.einsum("ij,njk,kl -> nil", Cb12, C, Cb12))
634637
ot_maps = nx.einsum("ij,njk,kl -> nil", Cb12_, M, Cb12_)
635638
grad_bw = Id - nx.sum(ot_maps * weights[:, None, None], axis=0)
636639

637640
Cnew = exp_bures(Cb, -step_size * grad_bw, nx=nx)
638641

642+
if averaged:
643+
# ot map between Cb_averaged and Cnew
644+
Cb_averaged12 = nx.sqrtm(Cb_averaged)
645+
Cb_averaged12inv = nx.inv(Cb_averaged12)
646+
M = nx.sqrtm(nx.einsum("ij,jk,kl->il", Cb_averaged12, Cnew, Cb_averaged12))
647+
ot_map = nx.einsum("ij,jk,kl->il", Cb_averaged12inv, M, Cb_averaged12inv)
648+
map = Id * step_size / (step_size + 1) + ot_map / (step_size + 1)
649+
Cb_averaged = nx.einsum("ij,jk,kl->il", map, Cb_averaged, map)
650+
639651
# check convergence
640-
if batch_size is not None and batch_size < n:
641-
# TODO: criteria for SGD: on gradients? + test SGD
642-
# TOO slow, test with value? (but don't want to compute the full barycenter)
643-
# + need to make bures_wasserstein_distance batchable (TODO)
644-
L_grads.append(nx.sum(grad_bw**2))
645-
diff = np.mean(L_grads)
646-
647-
# L_values.append(nx.norm(Cb - Cnew))
648-
# print(diff, np.mean(L_values))
649-
else:
650-
diff = nx.norm(Cb - Cnew)
652+
L_diff.append(nx.norm(Cb - Cnew))
651653

652-
if diff <= eps:
654+
# Criteria to stop
655+
if np.mean(L_diff[-100:]) <= eps:
653656
break
654657

655658
Cb = Cnew
656659

657-
if diff > eps:
658-
print("Dit not converge.")
660+
if averaged:
661+
Cb = Cb_averaged
659662

660663
if log:
661664
dict_log = {}
662665
dict_log["num_iter"] = it
663-
dict_log["final_diff"] = diff
666+
dict_log["final_diff"] = L_diff[-1]
664667
return Cb, dict_log
665668
else:
666669
return Cb
@@ -708,7 +711,8 @@ def bures_wasserstein_barycenter(
708711
weights : array-like (k), optional
709712
weights for each distribution
710713
method : str
711-
method used for the solver, either 'fixed_point' or 'gradient_descent'
714+
method used for the solver, either 'fixed_point', 'gradient_descent', 'stochastic_gradient_descent' or
715+
'averaged_stochastic_gradient_descent'
712716
num_iter : int, optional
713717
number of iteration for the fixed point algorithm
714718
eps : float, optional
@@ -756,15 +760,35 @@ def bures_wasserstein_barycenter(
756760
# Compute the mean barycenter
757761
mb = nx.sum(m * weights[:, None], axis=0)
758762

759-
if method == "gradient_descent" or batch_size is not None:
763+
if method == "gradient_descent":
760764
out = bures_barycenter_gradient_descent(
761765
C,
762766
weights=weights,
763767
num_iter=num_iter,
764768
eps=eps,
765769
log=log,
766770
step_size=step_size,
767-
batch_size=batch_size,
771+
nx=nx,
772+
)
773+
elif method == "stochastic_gradient_descent":
774+
out = bures_barycenter_gradient_descent(
775+
C,
776+
weights=weights,
777+
num_iter=num_iter,
778+
eps=eps,
779+
log=log,
780+
batch_size=1 if batch_size is None else batch_size,
781+
nx=nx,
782+
)
783+
elif method == "averaged_stochastic_gradient_descent":
784+
out = bures_barycenter_gradient_descent(
785+
C,
786+
weights=weights,
787+
num_iter=num_iter,
788+
eps=eps,
789+
log=log,
790+
batch_size=1 if batch_size is None else batch_size,
791+
averaged=True,
768792
nx=nx,
769793
)
770794
elif method == "fixed_point":

test/test_gaussian.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,15 @@ def test_empirical_bures_wasserstein_distance(nx, bias):
176176
np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)
177177

178178

179-
@pytest.mark.parametrize("method", ["fixed_point", "gradient_descent"])
179+
@pytest.mark.parametrize(
180+
"method",
181+
[
182+
"fixed_point",
183+
"gradient_descent",
184+
"stochastic_gradient_descent",
185+
"averaged_stochastic_gradient_descent",
186+
],
187+
)
180188
def test_bures_wasserstein_barycenter(nx, method):
181189
n = 50
182190
k = 10
@@ -203,15 +211,15 @@ def test_bures_wasserstein_barycenter(nx, method):
203211
)
204212
mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, method=method, log=False)
205213

206-
np.testing.assert_allclose(Cb, Cblog, rtol=1e-2, atol=1e-2)
214+
np.testing.assert_allclose(Cb, Cblog, rtol=1e-1, atol=1e-1)
207215
np.testing.assert_allclose(mb, mblog, rtol=1e-2, atol=1e-2)
208216

209217
# Test weights argument
210218
weights = nx.ones(k) / k
211219
mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(
212220
m, C, weights=weights, method=method, log=False
213221
)
214-
np.testing.assert_allclose(Cbw, Cb, rtol=1e-2, atol=1e-2)
222+
np.testing.assert_allclose(Cbw, Cb, rtol=1e-1, atol=1e-1)
215223

216224
# test with closed form for diagonal covariance matrices
217225
Cdiag = [nx.diag(nx.diag(C[i])) for i in range(k)]
@@ -266,7 +274,10 @@ def test_fixedpoint_vs_gradientdescent_bures_wasserstein_barycenter(nx):
266274
np.testing.assert_allclose(Cbw2, Cb2, rtol=1e-5, atol=1e-5)
267275

268276

269-
def test_stochastic_gd_bures_wasserstein_barycenter(nx):
277+
@pytest.mark.parametrize(
278+
"method", ["stochastic_gradient_descent", "averaged_stochastic_gradient_descent"]
279+
)
280+
def test_stochastic_gd_bures_wasserstein_barycenter(nx, method):
270281
n = 50
271282
k = 10
272283
X = []
@@ -296,7 +307,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx):
296307
n_samples = [1, 5]
297308
for n in n_samples:
298309
mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(
299-
m, C, method="gradient_descent", log=False, batch_size=n
310+
m, C, method=method, log=False, batch_size=n
300311
)
301312

302313
loss2 = nx.mean(
@@ -311,7 +322,7 @@ def test_stochastic_gd_bures_wasserstein_barycenter(nx):
311322

312323
with pytest.raises(ValueError):
313324
mb2, Cb2 = ot.gaussian.bures_wasserstein_barycenter(
314-
m, C, method="gradient_descent", log=False, batch_size=-5
325+
m, C, method=method, log=False, batch_size=-5
315326
)
316327

317328

0 commit comments

Comments
 (0)