Skip to content

Commit 51722bf

Browse files
committed
implementation comments
1 parent 6bd4af8 commit 51722bf

File tree

2 files changed

+178
-54
lines changed

2 files changed

+178
-54
lines changed

ot/lp/_barycenter_solvers.py

+88-45
Original file line numberDiff line numberDiff line change
@@ -199,14 +199,12 @@ def free_support_barycenter(
199199
measures_weights : list of N (k_i,) array-like
200200
Numpy arrays where each numpy array has :math:`k_i` non-negative values summing to one
201201
representing the weights of each discrete input measure
202-
203202
X_init : (k,d) array-like
204203
Initialization of the support locations (on `k` atoms) of the barycenter
205204
b : (k,) array-like
206205
Initialization of the weights of the barycenter (non-negative, summing to 1)
207206
weights : (N,) array-like
208207
Initialization of the coefficients of the barycenter (non-negative, summing to 1)
209-
210208
numItermax : int, optional
211209
Max number of iterations
212210
stopThr : float, optional
@@ -219,13 +217,11 @@ def free_support_barycenter(
219217
If compiled with OpenMP, chooses the number of threads to parallelize.
220218
"max" selects the highest number possible.
221219
222-
223220
Returns
224221
-------
225222
X : (k,d) array-like
226223
Support locations (on k atoms) of the barycenter
227224
228-
229225
.. _references-free-support-barycenter:
230226
References
231227
----------
@@ -428,20 +424,20 @@ def generalized_free_support_barycenter(
428424
return Y
429425

430426

431-
class StoppingCriterionReached(Exception):
432-
pass
433-
434-
435427
def free_support_barycenter_generic_costs(
436428
measure_locations,
437429
measure_weights,
438430
X_init,
439431
cost_list,
440-
B,
432+
ground_bary=None,
441433
a=None,
442434
numItermax=100,
443435
stopThr=1e-5,
444436
log=False,
437+
ground_bary_lr=1e-2,
438+
ground_bary_numItermax=100,
439+
ground_bary_stopThr=1e-5,
440+
ground_bary_solver="SGD",
445441
):
446442
r"""
447443
Solves the OT barycenter problem for generic costs using the fixed point
@@ -507,14 +503,15 @@ def free_support_barycenter_generic_costs(
507503
List of K arrays of measure weights, each of shape (m_k).
508504
X_init : array-like
509505
Array of shape (n, d) representing initial barycenter points.
510-
cost_list : list of callable
506+
cost_list : list of callable or callable
511507
List of K cost functions :math:`c_k: \mathbb{R}^{n\times
512508
d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times
513-
m_k}`.
514-
B : callable
509+
m_k}`. If cost_list is a single callable, the same cost is used K times.
510+
ground_bary : callable or None, optional
515511
Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays
516512
of shape (n\times d_k), computing the ground barycenters (broadcasted
517-
over n).
513+
over n). If not provided, it is computed by gradient descent with PyTorch
514+
(SGD or Adam, see `ground_bary_solver`), which requires the PyTorch backend.
518515
a : array-like, optional
519516
Array of shape (n,) representing weights of the barycenter
520517
measure. Defaults to uniform.
@@ -524,6 +521,16 @@ def free_support_barycenter_generic_costs(
524521
If the iterations move less than this, terminate (default is 1e-5).
525522
log : bool, optional
526523
Whether to return the log dictionary (default is False).
524+
ground_bary_lr : float, optional
525+
Learning rate for the automatic ground barycenter solver (only used when `ground_bary` is None).
526+
ground_bary_numItermax : int, optional
527+
Maximum number of iterations for the automatic ground barycenter solver
528+
(only used when `ground_bary` is None).
529+
ground_bary_stopThr : float, optional
530+
Stop threshold on the squared displacement for the automatic ground barycenter solver (only used when `ground_bary` is None).
531+
ground_bary_solver : str, optional
532+
Optimizer for the automatic ground barycenter solver: "SGD" or "Adam" (from
533+
torch.optim). Any value other than "Adam" selects SGD. Default is "SGD".
527534
528535
Returns
529536
-------
@@ -549,49 +556,85 @@ def free_support_barycenter_generic_costs(
549556
See Also
550557
--------
551558
ot.lp.free_support_barycenter : Free support solver for the case where
552-
:math:`c_k(x,y) = \|x-y\|_2^2`. ot.lp.generalized_free_support_barycenter :
553-
Free support solver for the case where :math:`c_k(x,y) = \|P_kx-y\|_2^2`
554-
with :math:`P_k` linear.
559+
:math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`.
560+
ot.lp.generalized_free_support_barycenter : Free support solver for the case
561+
where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear.
555562
"""
556563
nx = get_backend(X_init, measure_locations[0])
557564
K = len(measure_locations)
558565
n = X_init.shape[0]
559566
if a is None:
560567
a = nx.ones(n, type_as=X_init) / n
568+
if callable(cost_list): # use the given cost for all K pairs
569+
cost_list = [cost_list] * K
570+
auto_ground_bary = False
571+
572+
if ground_bary is None:
573+
auto_ground_bary = True
574+
assert str(nx) == "torch", (
575+
f"Backend {str(nx)} is not compatible with ground_bary=None, it"
576+
"must be provided if not using PyTorch backend"
577+
)
578+
try:
579+
import torch
580+
from torch.optim import Adam, SGD
581+
582+
def ground_bary(y, x_init):
583+
x = x_init.clone().detach().requires_grad_(True)
584+
solver = Adam if ground_bary_solver == "Adam" else SGD
585+
opt = solver([x], lr=ground_bary_lr)
586+
for _ in range(ground_bary_numItermax):
587+
x_prev = x.data.clone()
588+
opt.zero_grad()
589+
# inefficient cost computation but compatible
590+
# with the choice of cost_list[k] giving the cost matrix
591+
loss = torch.sum(
592+
torch.stack(
593+
[torch.diag(cost_list[k](x, y[k])) for k in range(K)]
594+
)
595+
)
596+
loss.backward()
597+
opt.step()
598+
diff = torch.sum((x.data - x_prev) ** 2)
599+
if diff < ground_bary_stopThr:
600+
break
601+
return x.detach()
602+
603+
except ImportError:
604+
raise ImportError("PyTorch is required to use ground_bary=None")
605+
561606
X_list = [X_init] if log else [] # store the iterations
562607
X = X_init
563608
dX_list = [] # store the displacement squared norms
564-
exit_status = "Unknown"
565-
566-
try:
567-
for _ in range(numItermax):
568-
pi_list = [ # compute the pairwise transport plans
569-
emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
570-
for k in range(K)
571-
]
572-
Y_perm = []
573-
for k in range(K): # compute barycentric projections
574-
Y_perm.append(n * pi_list[k] @ measure_locations[k])
575-
X_next = B(Y_perm)
576-
577-
if log:
578-
X_list.append(X_next)
609+
exit_status = "Max iterations reached"
610+
611+
for _ in range(numItermax):
612+
pi_list = [ # compute the pairwise transport plans
613+
emd(a, measure_weights[k], cost_list[k](X, measure_locations[k]))
614+
for k in range(K)
615+
]
616+
Y_perm = []
617+
for k in range(K): # compute barycentric projections
618+
Y_perm.append(n * pi_list[k] @ measure_locations[k])
619+
if auto_ground_bary: # use previous position as initialization
620+
X_next = ground_bary(Y_perm, X)
621+
else:
622+
X_next = ground_bary(Y_perm)
579623

580-
# stationary criterion: move less than the threshold
581-
dX = nx.sum((X - X_next) ** 2)
582-
X = X_next
624+
if log:
625+
X_list.append(X_next)
583626

584-
if log:
585-
dX_list.append(dX)
627+
# stationary criterion: move less than the threshold
628+
dX = nx.sum((X - X_next) ** 2)
629+
X = X_next
586630

587-
if dX < stopThr:
588-
exit_status = "Stationary Point"
589-
raise StoppingCriterionReached
631+
if log:
632+
dX_list.append(dX)
590633

591-
exit_status = "Max iterations reached"
592-
raise StoppingCriterionReached
634+
if dX < stopThr:
635+
exit_status = "Stationary Point"
636+
break
593637

594-
except StoppingCriterionReached:
595-
if log:
596-
return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list}
597-
return X
638+
if log:
639+
return X, {"X_list": X_list, "exit_status": exit_status, "dX_list": dX_list}
640+
return X

test/test_ot.py

+90-9
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
from ot.datasets import make_1D_gauss as gauss
1414
from ot.backend import torch, tf
1515

16-
# import ot.lp._barycenter_solvers # TODO: remove this import
17-
1816

1917
def test_emd_dimension_and_mass_mismatch():
2018
# test emd and emd2 for dimension mismatch
@@ -414,14 +412,14 @@ def cost(x, y):
414412

415413
cost_list = [cost, cost]
416414

417-
def B(y):
415+
def ground_bary(y):
418416
out = 0
419417
for yk in y:
420418
out += yk / len(y)
421419
return out
422420

423421
X = ot.lp.free_support_barycenter_generic_costs(
424-
measures_locations, measures_weights, X_init, cost_list, B
422+
measures_locations, measures_weights, X_init, cost_list, ground_bary
425423
)
426424

427425
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)
@@ -432,7 +430,7 @@ def B(y):
432430
measures_weights,
433431
X_init,
434432
cost_list,
435-
B,
433+
ground_bary,
436434
a=ot.unif(1),
437435
log=True,
438436
)
@@ -449,12 +447,95 @@ def B(y):
449447
measures_weights,
450448
X_init,
451449
cost_list,
452-
B,
450+
ground_bary,
453451
numItermax=1,
454452
log=True,
455453
)
456454
assert log2["exit_status"] == "Max iterations reached"
457455

456+
# test with a single callable cost
457+
X3, log3 = ot.lp.free_support_barycenter_generic_costs(
458+
measures_locations,
459+
measures_weights,
460+
X_init,
461+
cost,
462+
ground_bary,
463+
numItermax=1,
464+
log=True,
465+
)
466+
467+
# test with no ground_bary but in numpy: requires pytorch backend
468+
with pytest.raises(AssertionError):
469+
ot.lp.free_support_barycenter_generic_costs(
470+
measures_locations,
471+
measures_weights,
472+
X_init,
473+
cost_list,
474+
ground_bary=None,
475+
numItermax=1,
476+
)
477+
478+
479+
@pytest.mark.skipif(not torch, reason="No torch available")
480+
def test_free_support_barycenter_generic_costs_auto_ground_bary():
481+
measures_locations = [
482+
torch.tensor([1.0]).reshape((1, 1)),
483+
torch.tensor([2.0]).reshape((1, 1)),
484+
]
485+
measures_weights = [torch.tensor([1.0]), torch.tensor([1.0])]
486+
487+
X_init = torch.tensor([1.2]).reshape((1, 1))
488+
489+
def cost(x, y):
490+
return ot.dist(x, y)
491+
492+
cost_list = [cost, cost]
493+
494+
def ground_bary(y):
495+
out = 0
496+
for yk in y:
497+
out += yk / len(y)
498+
return out
499+
500+
X = ot.lp.free_support_barycenter_generic_costs(
501+
measures_locations,
502+
measures_weights,
503+
X_init,
504+
cost_list,
505+
ground_bary,
506+
numItermax=1,
507+
)
508+
509+
X2, log2 = ot.lp.free_support_barycenter_generic_costs(
510+
measures_locations,
511+
measures_weights,
512+
X_init,
513+
cost_list,
514+
ground_bary=None,
515+
ground_bary_lr=1e-2,
516+
ground_bary_stopThr=1e-20,
517+
ground_bary_numItermax=50,
518+
numItermax=10,
519+
log=True,
520+
)
521+
522+
np.testing.assert_allclose(X2.numpy(), X.numpy(), rtol=1e-4, atol=1e-4)
523+
524+
X3 = ot.lp.free_support_barycenter_generic_costs(
525+
measures_locations,
526+
measures_weights,
527+
X_init,
528+
cost_list,
529+
ground_bary=None,
530+
ground_bary_lr=1e-2,
531+
ground_bary_stopThr=1e-20,
532+
ground_bary_numItermax=50,
533+
numItermax=10,
534+
ground_bary_solver="Adam",
535+
)
536+
537+
np.testing.assert_allclose(X2.numpy(), X3.numpy(), rtol=1e-3, atol=1e-3)
538+
458539

459540
def test_free_support_barycenter_generic_costs_backends(nx):
460541
measures_locations = [
@@ -469,22 +550,22 @@ def cost(x, y):
469550

470551
cost_list = [cost, cost]
471552

472-
def B(y):
553+
def ground_bary(y):
473554
out = 0
474555
for yk in y:
475556
out += yk / len(y)
476557
return out
477558

478559
X = ot.lp.free_support_barycenter_generic_costs(
479-
measures_locations, measures_weights, X_init, cost_list, B
560+
measures_locations, measures_weights, X_init, cost_list, ground_bary
480561
)
481562

482563
measures_locations2 = nx.from_numpy(*measures_locations)
483564
measures_weights2 = nx.from_numpy(*measures_weights)
484565
X_init2 = nx.from_numpy(X_init)
485566

486567
X2 = ot.lp.free_support_barycenter_generic_costs(
487-
measures_locations2, measures_weights2, X_init2, cost_list, B
568+
measures_locations2, measures_weights2, X_init2, cost_list, ground_bary
488569
)
489570

490571
np.testing.assert_allclose(X, nx.to_numpy(X2))

0 commit comments

Comments
 (0)