Skip to content

Commit 5952019

Browse files
committed
ot bar demo
1 parent 081e4eb commit 5952019

File tree

9 files changed

+177
-9
lines changed

9 files changed

+177
-9
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=====================================
4+
OT Barycenter with Generic Costs Demo
5+
=====================================
6+
7+
This example illustrates the computation of an Optimal Transport for a ground
8+
cost that is not a power of a norm. We take the example of ground costs
9+
:math:`c_k(x, y) = |P_k(x)-y|^2`, where :math:`P_k` is the (non-linear)
10+
projection onto a circle k. This is an example of the fixed-point barycenter
11+
solver introduced in [74] which generalises [20].
12+
13+
The ground barycenter function :math:`B(y_1, ..., y_K)` = \mathrm{argmin}_{x \in
14+
\mathbb{R}^2} \sum_k \lambda_k c_k(x, y_k) is computed by gradient descent over
15+
:math:`x` with Pytorch.
16+
17+
[74] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing
18+
Barycentres of Measures for Generic Transport
19+
Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
20+
21+
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein
22+
Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International
23+
Conference in Machine Learning
24+
25+
"""
26+
27+
# Author: Eloi Tanguy <[email protected]>
28+
#
29+
# License: MIT License
30+
31+
# sphinx_gallery_thumbnail_number = 1
32+
33+
# %% Generate data
34+
import torch
35+
from torch.optim import Adam
36+
from ot.utils import dist
37+
import numpy as np
38+
from ot.lp import free_support_barycenter_generic_costs
39+
import matplotlib.pyplot as plt
40+
41+
42+
torch.manual_seed(42)
43+
44+
n = 100 # number of points of the of the barycentre
45+
d = 2 # dimensions of the original measure
46+
K = 4 # number of measures to barycentre
47+
m = 50 # number of points of the measures
48+
b_list = [torch.ones(m) / m] * K # weights of the 4 measures
49+
weights = torch.ones(K) / K # weights for the barycentre
50+
stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo
51+
52+
53+
# map R^2 -> R^2 projection onto circle
54+
def proj_circle(X, origin, radius):
55+
diffs = X - origin[None, :]
56+
norms = torch.norm(diffs, dim=1)
57+
return origin[None, :] + radius * diffs / norms[:, None]
58+
59+
60+
# circles on which to project
61+
origin1 = torch.tensor([-1.0, -1.0])
62+
origin2 = torch.tensor([-1.0, 2.0])
63+
origin3 = torch.tensor([2.0, 2.0])
64+
origin4 = torch.tensor([2.0, -1.0])
65+
r = np.sqrt(2)
66+
P_list = [
67+
lambda X: proj_circle(X, origin1, r),
68+
lambda X: proj_circle(X, origin2, r),
69+
lambda X: proj_circle(X, origin3, r),
70+
lambda X: proj_circle(X, origin4, r),
71+
]
72+
73+
# measures to barycentre are projections of different random circles
74+
# onto the K circles
75+
Y_list = []
76+
for k in range(K):
77+
t = torch.rand(m) * 2 * np.pi
78+
X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1)
79+
X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :]
80+
Y_list.append(P_list[k](X_temp))
81+
82+
83+
# %% Define costs and ground barycenter function
84+
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
85+
# (n, n_k) matrix of costs
86+
def c1(x, y):
87+
return dist(P_list[0](x), y)
88+
89+
90+
def c2(x, y):
91+
return dist(P_list[1](x), y)
92+
93+
94+
def c3(x, y):
95+
return dist(P_list[2](x), y)
96+
97+
98+
def c4(x, y):
99+
return dist(P_list[3](x), y)
100+
101+
102+
cost_list = [c1, c2, c3, c4]
103+
104+
105+
# batched total ground cost function for candidate points x (n, d)
106+
# for computation of the ground barycenter B with gradient descent
107+
def C(x, y):
108+
"""
109+
Computes the barycenter cost for candidate points x (n, d) and
110+
measure supports y: List(n, d_k).
111+
"""
112+
n = x.shape[0]
113+
K = len(y)
114+
out = torch.zeros(n)
115+
for k in range(K):
116+
out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1)
117+
return out
118+
119+
120+
# ground barycenter function
121+
def B(y, its=150, lr=1, stop_threshold=stop_threshold):
122+
"""
123+
Computes the ground barycenter for measure supports y: List(n, d_k).
124+
Output: (n, d) array
125+
"""
126+
x = torch.randn(n, d)
127+
x.requires_grad_(True)
128+
opt = Adam([x], lr=lr)
129+
for _ in range(its):
130+
x_prev = x.data.clone()
131+
opt.zero_grad()
132+
loss = torch.sum(C(x, y))
133+
loss.backward()
134+
opt.step()
135+
diff = torch.sum((x.data - x_prev) ** 2)
136+
if diff < stop_threshold:
137+
break
138+
return x
139+
140+
141+
# %% Compute the barycenter measure
142+
fixed_point_its = 10
143+
X_init = torch.rand(n, d)
144+
X_bar = free_support_barycenter_generic_costs(
145+
X_init,
146+
Y_list,
147+
b_list,
148+
cost_list,
149+
B,
150+
max_its=fixed_point_its,
151+
stop_threshold=stop_threshold,
152+
)
153+
154+
# %% Plot Barycenter (Iteration 10)
155+
alpha = 0.5
156+
labels = ["circle 1", "circle 2", "circle 3", "circle 4"]
157+
for Y, label in zip(Y_list, labels):
158+
plt.scatter(*(Y.numpy()).T, alpha=alpha, label=label)
159+
plt.scatter(*(X_bar.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha)
160+
plt.axis("equal")
161+
plt.xlim(-0.3, 1.3)
162+
plt.ylim(-0.3, 1.3)
163+
plt.axis("off")
164+
plt.legend()
165+
plt.tight_layout()
166+
167+
# %%

examples/barycenters/plot_generalized_free_support_barycenter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
1515
"""
1616

17-
# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>
17+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
1818
#
1919
# License: MIT License
2020

examples/others/plot_GMMOT_plan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
1717
"""
1818

19-
# Author: Eloi Tanguy <eloi.tanguy@u-paris>
19+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
2020
# Remi Flamary <[email protected]>
2121
# Julie Delon <[email protected]>
2222
#

examples/others/plot_GMM_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
1111
"""
1212

13-
# Author: Eloi Tanguy <eloi.tanguy@u-paris>
13+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
1414
# Remi Flamary <[email protected]>
1515
# Julie Delon <[email protected]>
1616
#

examples/others/plot_SSNB.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
2017.
3939
"""
4040

41-
# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
41+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
4242
# License: MIT License
4343

4444
# sphinx_gallery_thumbnail_number = 3

ot/gmm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
Optimal transport for Gaussian Mixtures
44
"""
55

6-
# Author: Eloi Tanguy <eloi.tanguy@u-paris>
7-
# Remi Flamary <remi.flamary@polytehnique.edu>
6+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
7+
# Remi Flamary <remi.flamary@polytechnique.edu>
88
# Julie Delon <[email protected]>
99
#
1010
# License: MIT License

ot/lp/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
#
99
# License: MIT License
1010

11-
from . import cvx
1211
from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
1312
from ._network_simplex import emd, emd2
1413
from ._barycenter_solvers import (
1514
barycenter,
1615
free_support_barycenter,
1716
generalized_free_support_barycenter,
17+
free_support_barycenter_generic_costs,
1818
)
1919
from ..utils import check_number_threads
2020

@@ -46,4 +46,5 @@
4646
"dmmot_monge_1dgrid_loss",
4747
"dmmot_monge_1dgrid_optimize",
4848
"check_number_threads",
49+
"free_support_barycenter_generic_costs",
4950
]

ot/lp/_barycenter_solvers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ class StoppingCriterionReached(Exception):
428428
pass
429429

430430

431-
def solve_OT_barycenter_fixed_point(
431+
def free_support_barycenter_generic_costs(
432432
X_init,
433433
Y_list,
434434
b_list,

ot/mapping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use it you need to explicitly import :mod:`ot.mapping`
88
"""
99

10-
# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
10+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
1111
# Remi Flamary <[email protected]>
1212
#
1313
# License: MIT License

0 commit comments

Comments
 (0)