Skip to content

Commit d0971ee

Browse files
committed
finished gmm xp + documentation + tweaked circles xp + algo comparison WIP
1 parent 3df869d commit d0971ee

32 files changed

+627
-312
lines changed
0 Bytes
Binary file not shown.
Loading
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
19 Bytes
Binary file not shown.
Binary file not shown.

examples/2d_circle_L2_inexact/src.py

+42-52
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# %%
22
import numpy as np
33
import torch
4-
from ot_bar.solvers import solve_OT_barycenter_fixed_point, solve_NLGWB_GD, StoppingCriterionReached
4+
from ot_bar.solvers import solve_OT_barycenter_fixed_point, solve_OT_barycenter_GD, StoppingCriterionReached
55
from ot_bar.utils import TT, TN
66
import matplotlib.pyplot as plt
77
from time import time
@@ -19,7 +19,7 @@
1919
d = 2 # dimensions of the original measure
2020
K = 4 # number of measures to barycentre
2121
m = 50 # number of points of the measures
22-
a_list = TT([ot.unif(m)] * K) # weights of the 4 measures
22+
b_list = TT([ot.unif(m)] * K) # weights of the 4 measures
2323
weights = TT(ot.unif(K)) # weights for the barycentre
2424

2525

@@ -51,21 +51,38 @@ def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
5151
Y_list.append(P_list[k](X_temp))
5252

5353

54+
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
55+
# (n, n_k) matrix of costs
56+
def c1(x, y):
57+
return ot.dist(P_list[0](x), y)
58+
59+
60+
def c2(x, y):
61+
return ot.dist(P_list[1](x), y)
62+
63+
64+
def c3(x, y):
65+
return ot.dist(P_list[2](x), y)
66+
67+
68+
def c4(x, y):
69+
return ot.dist(P_list[3](x), y)
70+
71+
72+
cost_list = [c1, c2, c3, c4]
5473
# %% Find generalised barycenter using gradient descent
5574
# optimiser parameters
5675
learning_rate = 30 # initial learning rate
5776
its = 2000 # Gradient Descent iterations
58-
stop_threshold = 1e-20 # stops if |loss_{t+1} - loss_{t}| < this
77+
stop_threshold = 1e-5 # stops if |X_{t+1} - X_{t}| / |X_t| < this
5978
gamma = 1 # learning rate at step t is initial learning rate * gamma^t
6079
np.random.seed(42)
6180
torch.manual_seed(42)
6281
t0 = time()
63-
X_bar, b, loss_list, exit_status = solve_NLGWB_GD(Y_list, a_list, weights,
64-
P_list, n, d, return_exit_status=True, eta_init=learning_rate,
65-
its=its, stop_threshold=stop_threshold,
66-
gamma=gamma)
82+
X_bar, a, log_dict = solve_OT_barycenter_GD(
83+
Y_list, b_list, weights, cost_list, n, d, eta_init=learning_rate, its=its, stop_threshold=stop_threshold, gamma=gamma, log=True)
6784
dt = time() - t0
68-
print(f"Finished in {dt:.2f}s, exit status: {exit_status}, final loss: {loss_list[-1]:.10f}")
85+
print(f"Finished in {dt:.2f}s, exit status: {log_dict['exit_status']}, final loss: {log_dict['loss_list'][-1]:.10f}")
6986

7087
# %% Plot GD barycentre
7188
alpha = .5
@@ -81,33 +98,12 @@ def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
8198
plt.savefig('gwb_circles_gd.pdf')
8299

83100
# %% Plot GD barycentre loss
84-
plt.plot(loss_list)
101+
plt.plot(log_dict['loss_list'])
85102
plt.yscale('log')
86103
plt.savefig('gwb_circles_gd_loss.pdf')
87104

88105

89106
# %% Solve with fixed-point iterations: studying the energy for the function B
90-
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
91-
# (n, n_k) matrix of costs
92-
def c1(x, y):
93-
return ot.dist(P_list[0](x), y)
94-
95-
96-
def c2(x, y):
97-
return ot.dist(P_list[1](x), y)
98-
99-
100-
def c3(x, y):
101-
return ot.dist(P_list[2](x), y)
102-
103-
104-
def c4(x, y):
105-
return ot.dist(P_list[3](x), y)
106-
107-
108-
cost_list = [c1, c2, c3, c4]
109-
110-
111107
def C(x, y):
112108
"""
113109
Computes the barycenter cost for candidate points x (n, d) and
@@ -144,7 +140,7 @@ def C(x, y):
144140
torch.manual_seed(42)
145141

146142

147-
def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
143+
def B(y, its=150, lr=1, log=False, stop_threshold=stop_threshold):
148144
"""
149145
Computes the barycenter images for candidate points x (n, d) and
150146
measure supports y: List(n, d_k).
@@ -157,12 +153,15 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
157153
exit_status = 'unknown'
158154
try:
159155
for _ in range(its):
156+
x_prev = x.data.clone()
160157
opt.zero_grad()
161158
loss = torch.sum(C(x, y))
162159
loss.backward()
163160
opt.step()
164161
loss_list.append(loss.item())
165-
if stop_threshold > loss_list[-2] - loss_list[-1] >= 0:
162+
diff = torch.sum((x.data - x_prev)**2)
163+
current = torch.sum((x_prev)**2)
164+
if diff / current < stop_threshold:
166165
exit_status = 'Local optimum'
167166
raise StoppingCriterionReached
168167
exit_status = 'Max iterations reached'
@@ -180,7 +179,7 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
180179
Y_perm = []
181180
for k in range(K):
182181
Y_perm.append(n * pi_list[k] @ Y_list[k])
183-
Bx, log = B(Y_perm, its=500, lr=1, log=True)
182+
Bx, log = B(Y_perm, its=150, lr=1, log=True)
184183
plt.plot(log['loss_list'])
185184
plt.yscale('log')
186185
plt.savefig('gwb_circles_B_loss.pdf')
@@ -190,19 +189,21 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
190189
np.random.seed(0)
191190
torch.manual_seed(0)
192191

192+
t0 = time()
193193
fixed_point_its = 15
194194
X_init = torch.rand(n, d, device=device, dtype=torch.double)
195195
b_list = [TT(ot.unif(m))] * K
196-
X_bar, X_bar_list = solve_OT_barycenter_fixed_point(X_init, Y_list, b_list,
197-
cost_list,
198-
B, max_its=fixed_point_its, pbar=True, log=True)
196+
X_bar, log_dict = solve_OT_barycenter_fixed_point(
197+
X_init, Y_list, b_list, cost_list, B, max_its=fixed_point_its, pbar=True, log=True, stop_threshold=stop_threshold)
198+
dt = time() - t0
199+
print(f"Finished in {dt:.2f}s, exit status: {log_dict['exit_status']}")
199200

200201
# %% plot fixed-point barycentre final step
201202
alpha = .5
202203
labels = ['circle 1', 'circle 2', 'circle 3', 'circle 4']
203204
for Y, label in zip(Y_list, labels):
204205
plt.scatter(*TN(Y).T, alpha=alpha, label=label)
205-
plt.scatter(*TN(X_bar_list[-1]).T, label='GWB', c='black', alpha=alpha)
206+
plt.scatter(*TN(log_dict['X_list'][-1]).T, label='GWB', c='black', alpha=alpha)
206207
plt.axis('equal')
207208
plt.xlim(-.3, 1.3)
208209
plt.ylim(-.3, 1.3)
@@ -211,7 +212,7 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
211212
plt.savefig('gwb_circles_fixed_point.pdf')
212213

213214
# %% animate fixed-point barycentre steps
214-
num_frames = fixed_point_its + 1 # +1 for initialisation
215+
num_frames = len(log_dict['X_list'])
215216
fig, ax = plt.subplots()
216217
ax.set_xlim(-.3, 1.3)
217218
ax.set_ylim(-.3, 1.3)
@@ -230,7 +231,7 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
230231

231232
def update(frame): # Update function for animation
232233
# Update moving scatterplot data
233-
moving_scatter.set_offsets(TN(X_bar_list[frame]))
234+
moving_scatter.set_offsets(TN(log_dict['X_list'][frame]))
234235
return moving_scatter,
235236

236237

@@ -245,7 +246,7 @@ def update(frame): # Update function for animation
245246
for i, ax in enumerate(axes):
246247
for Y, label in zip(Y_list, labels):
247248
ax.scatter(*TN(Y).T, alpha=alpha, label=label)
248-
ax.scatter(*TN(X_bar_list[i]).T, label='GWB', c='black', alpha=alpha)
249+
ax.scatter(*TN(log_dict['X_list'][i]).T, label='GWB', c='black', alpha=alpha)
249250
ax.axis('equal')
250251
ax.axis('off')
251252
ax.set_xlim(-.3, 1.3)
@@ -257,7 +258,7 @@ def update(frame): # Update function for animation
257258
V_list = []
258259
a = TT(ot.unif(n))
259260
b = TT(ot.unif(m))
260-
for X in X_bar_list:
261+
for X in log_dict['X_list']:
261262
V = 0
262263
for k in range(K):
263264
V += (1 / K) * ot.emd2(a, b, ot.dist(P_list[k](X), Y_list[k]))
@@ -268,14 +269,3 @@ def update(frame): # Update function for animation
268269
plt.savefig('gwb_circles_fixed_point_V.pdf')
269270

270271
# %%
271-
X = X_bar_list[0]
272-
pi_list = [ot.emd(a, b, cost_list[k](X, Y_list[k])) for k in range(K)]
273-
Y_perm = []
274-
for k in range(K):
275-
Y_perm.append(n * pi_list[k] @ Y_list[k])
276-
X, log = B(Y_perm, log=True)
277-
plt.plot(log['loss_list'])
278-
plt.yscale('log')
279-
plt.savefig('gwb_circles_B_loss.pdf')
280-
281-
# %%
0 Bytes
Binary file not shown.
Loading
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
-9 Bytes
Binary file not shown.
-925 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)