1
1
# %%
2
2
import numpy as np
3
3
import torch
4
- from ot_bar .solvers import solve_OT_barycenter_fixed_point , solve_NLGWB_GD , StoppingCriterionReached
4
+ from ot_bar .solvers import solve_OT_barycenter_fixed_point , solve_OT_barycenter_GD , StoppingCriterionReached
5
5
from ot_bar .utils import TT , TN
6
6
import matplotlib .pyplot as plt
7
7
from time import time
19
19
d = 2 # dimensions of the original measure
20
20
K = 4 # number of measures to barycentre
21
21
m = 50 # number of points of the measures
22
- a_list = TT ([ot .unif (m )] * K ) # weights of the 4 measures
22
+ b_list = TT ([ot .unif (m )] * K ) # weights of the 4 measures
23
23
weights = TT (ot .unif (K )) # weights for the barycentre
24
24
25
25
@@ -51,21 +51,38 @@ def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
51
51
Y_list .append (P_list [k ](X_temp ))
52
52
53
53
54
# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning an
# (n, n_k) matrix of costs
def _make_cost(k):
    """
    Builds the cost function attached to the k-th measure.

    The returned function maps the candidate points x through P_list[k] and
    returns ot.dist between the mapped points and y.
    """
    def cost(x, y):
        # P_list is resolved when the cost is called, not when it is built
        mapped = P_list[k](x)
        return ot.dist(mapped, y)

    return cost


# one cost per measure; the individual names c1..c4 stay available
c1, c2, c3, c4 = (_make_cost(k) for k in range(4))
cost_list = [c1, c2, c3, c4]
54
73
# %% Find generalised barycenter using gradient descent
# optimiser parameters
learning_rate = 30  # initial learning rate
its = 2000  # Gradient Descent iterations
stop_threshold = 1e-5  # stops if |X_{t+1} - X_t| / |X_t| < this
gamma = 1  # learning rate at step t is initial learning rate * gamma^t

# fix both RNGs so the run is reproducible
np.random.seed(42)
torch.manual_seed(42)

# time the solver call only
t0 = time()
# log=True makes the solver return a log dictionary as third output;
# `a` is presumably the barycentre weights -- TODO confirm against the solver
X_bar, a, log_dict = solve_OT_barycenter_GD(
    Y_list, b_list, weights, cost_list, n, d,
    eta_init=learning_rate, its=its, stop_threshold=stop_threshold,
    gamma=gamma, log=True)
dt = time() - t0

# the log keys carry no leading space: 'exit_status' and 'loss_list' are the
# spellings used by the other log lookups in this script
print(f"Finished in {dt:.2f} s, exit status: {log_dict['exit_status']} , final loss: {log_dict['loss_list'][-1]:.10f} ")
69
86
70
87
# %% Plot GD barycentre
71
88
alpha = .5
@@ -81,33 +98,12 @@ def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
81
98
plt .savefig ('gwb_circles_gd.pdf' )
82
99
83
100
# %% Plot GD barycentre loss
# one loss value per gradient-descent iteration, on a log scale;
# the key is 'loss_list' with no leading space, as in the other log lookups
plt.plot(log_dict['loss_list'])
plt.yscale('log')
plt.savefig('gwb_circles_gd_loss.pdf')
87
104
88
105
89
106
# %% Solve with fixed-point iterations: studying the energy for the function B
90
- # cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a
91
- # (n, n_k) matrix of costs
92
- def c1 (x , y ):
93
- return ot .dist (P_list [0 ](x ), y )
94
-
95
-
96
- def c2 (x , y ):
97
- return ot .dist (P_list [1 ](x ), y )
98
-
99
-
100
- def c3 (x , y ):
101
- return ot .dist (P_list [2 ](x ), y )
102
-
103
-
104
- def c4 (x , y ):
105
- return ot .dist (P_list [3 ](x ), y )
106
-
107
-
108
- cost_list = [c1 , c2 , c3 , c4 ]
109
-
110
-
111
107
def C (x , y ):
112
108
"""
113
109
Computes the barycenter cost for candidate points x (n, d) and
@@ -144,7 +140,7 @@ def C(x, y):
144
140
torch .manual_seed (42 )
145
141
146
142
147
- def B (y , its = 250 , lr = 1 , log = False , stop_threshold = 1e-20 ):
143
+ def B (y , its = 150 , lr = 1 , log = False , stop_threshold = stop_threshold ):
148
144
"""
149
145
Computes the barycenter images for candidate points x (n, d) and
150
146
measure supports y: List(n, d_k).
@@ -157,12 +153,15 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
157
153
exit_status = 'unknown'
158
154
try :
159
155
for _ in range (its ):
156
+ x_prev = x .data .clone ()
160
157
opt .zero_grad ()
161
158
loss = torch .sum (C (x , y ))
162
159
loss .backward ()
163
160
opt .step ()
164
161
loss_list .append (loss .item ())
165
- if stop_threshold > loss_list [- 2 ] - loss_list [- 1 ] >= 0 :
162
+ diff = torch .sum ((x .data - x_prev )** 2 )
163
+ current = torch .sum ((x_prev )** 2 )
164
+ if diff / current < stop_threshold :
166
165
exit_status = 'Local optimum'
167
166
raise StoppingCriterionReached
168
167
exit_status = 'Max iterations reached'
@@ -180,7 +179,7 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
180
179
# for each measure k, combine its support points Y_list[k] with the plan
# pi_list[k], scaled by n
Y_perm = [n * pi_list[k] @ Y_list[k] for k in range(K)]

# run B on these points and keep its log to inspect the loss values
Bx, log = B(Y_perm, its=150, lr=1, log=True)

loss_curve = log['loss_list']
plt.plot(loss_curve)
plt.yscale('log')
plt.savefig('gwb_circles_B_loss.pdf')
@@ -190,19 +189,21 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
190
189
# reseed both RNGs so the random initialisation below is reproducible
np.random.seed(0)
torch.manual_seed(0)

# the timer covers the initialisation as well as the solver call
t0 = time()
fixed_point_its = 15  # maximum number of fixed-point iterations
X_init = torch.rand(n, d, device=device, dtype=torch.double)
b_list = [TT(ot.unif(m))] * K  # the same uniform weights for every measure
X_bar, log_dict = solve_OT_barycenter_fixed_point(
    X_init,
    Y_list,
    b_list,
    cost_list,
    B,
    max_its=fixed_point_its,
    pbar=True,
    log=True,
    stop_threshold=stop_threshold,
)
dt = time() - t0
print(f"Finished in {dt:.2f} s, exit status: {log_dict['exit_status']} ")
199
200
200
201
# %% plot fixed-point barycentre final step
201
202
alpha = .5
202
203
labels = ['circle 1' , 'circle 2' , 'circle 3' , 'circle 4' ]
203
204
for Y , label in zip (Y_list , labels ):
204
205
plt .scatter (* TN (Y ).T , alpha = alpha , label = label )
205
- plt .scatter (* TN (X_bar_list [- 1 ]).T , label = 'GWB' , c = 'black' , alpha = alpha )
206
+ plt .scatter (* TN (log_dict [ 'X_list' ] [- 1 ]).T , label = 'GWB' , c = 'black' , alpha = alpha )
206
207
plt .axis ('equal' )
207
208
plt .xlim (- .3 , 1.3 )
208
209
plt .ylim (- .3 , 1.3 )
@@ -211,7 +212,7 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
211
212
plt .savefig ('gwb_circles_fixed_point.pdf' )
212
213
213
214
# %% animate fixed-point barycentre steps
214
- num_frames = fixed_point_its + 1 # +1 for initialisation
215
+ num_frames = len ( log_dict [ 'X_list' ])
215
216
fig , ax = plt .subplots ()
216
217
ax .set_xlim (- .3 , 1.3 )
217
218
ax .set_ylim (- .3 , 1.3 )
@@ -230,7 +231,7 @@ def B(y, its=250, lr=1, log=False, stop_threshold=1e-20):
230
231
231
232
def update(frame):
    """
    Update function for the animation.

    Moves the moving scatterplot to the entry number `frame` of
    log_dict['X_list'] and returns the updated artist in a 1-tuple.
    """
    # TN presumably converts a torch tensor to a numpy array -- TODO confirm
    points = TN(log_dict['X_list'][frame])
    moving_scatter.set_offsets(points)
    return (moving_scatter,)
235
236
236
237
@@ -245,7 +246,7 @@ def update(frame): # Update function for animation
245
246
for i , ax in enumerate (axes ):
246
247
for Y , label in zip (Y_list , labels ):
247
248
ax .scatter (* TN (Y ).T , alpha = alpha , label = label )
248
- ax .scatter (* TN (X_bar_list [i ]).T , label = 'GWB' , c = 'black' , alpha = alpha )
249
+ ax .scatter (* TN (log_dict [ 'X_list' ] [i ]).T , label = 'GWB' , c = 'black' , alpha = alpha )
249
250
ax .axis ('equal' )
250
251
ax .axis ('off' )
251
252
ax .set_xlim (- .3 , 1.3 )
@@ -257,7 +258,7 @@ def update(frame): # Update function for animation
257
258
V_list = []
258
259
a = TT (ot .unif (n ))
259
260
b = TT (ot .unif (m ))
260
- for X in X_bar_list :
261
+ for X in log_dict [ 'X_list' ] :
261
262
V = 0
262
263
for k in range (K ):
263
264
V += (1 / K ) * ot .emd2 (a , b , ot .dist (P_list [k ](X ), Y_list [k ]))
@@ -268,14 +269,3 @@ def update(frame): # Update function for animation
268
269
plt .savefig ('gwb_circles_fixed_point_V.pdf' )
269
270
270
271
# %%
271
- X = X_bar_list [0 ]
272
- pi_list = [ot .emd (a , b , cost_list [k ](X , Y_list [k ])) for k in range (K )]
273
- Y_perm = []
274
- for k in range (K ):
275
- Y_perm .append (n * pi_list [k ] @ Y_list [k ])
276
- X , log = B (Y_perm , log = True )
277
- plt .plot (log ['loss_list' ])
278
- plt .yscale ('log' )
279
- plt .savefig ('gwb_circles_B_loss.pdf' )
280
-
281
- # %%
0 commit comments