8
8
import ot # type: ignore
9
9
from torch .optim import SGD
10
10
import matplotlib .animation as animation
11
+ from matplotlib import cm
11
12
12
13
13
14
device = 'cuda' if torch .cuda .is_available () else 'cpu'
19
20
weights = TT (ot .unif (n )) # weights for the barycentre
20
21
21
22
22
- # map 1: R^2 -> R^2 projection onto circle
23
+ # map R^2 -> R^2 projection onto circle
23
24
def proj_circle (X : torch .tensor , origin : torch .tensor , radius : float ):
24
25
diffs = X - origin [None , :]
25
26
norms = torch .norm (diffs , dim = 1 )
26
27
return origin [None , :] + radius * diffs / norms [:, None ]
27
28
28
29
29
30
# build a measure as a 2D circle
30
- t = np .linspace (0 , 2 * np .pi , n , endpoint = False )
31
+ # t = np.linspace(0, 2 * np.pi, n, endpoint=False)
32
+ t = np .random .rand (n ) * 2 * np .pi
31
33
X = .5 * TT (torch .tensor ([np .cos (t ), np .sin (t )]).T )
32
34
X = X + TT (torch .tensor ([.5 , .5 ]))[None , :]
33
35
origin1 = TT (torch .tensor ([- 1 , - 1 ]))
@@ -57,7 +59,7 @@ def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
57
59
its = its , stop_threshold = stop_threshold ,
58
60
gamma = gamma )
59
61
dt = time () - t0
60
- print (f"Finished in { dt :.2f} s, exit status: { exit_status } , final loss: { loss_list [- 1 ]:.2f } " )
62
+ print (f"Finished in { dt :.2f} s, exit status: { exit_status } , final loss: { loss_list [- 1 ]:.10f } " )
61
63
62
64
# %% Plot GD barycentre
63
65
alpha = .5
@@ -106,15 +108,21 @@ def C(x, y):
106
108
y = [(Y_list [k ][y_idx [k ]])[None , :] * torch .ones_like (x , device = device ) for k in range (4 )]
107
109
M = C (x , y ) # shape (n_vis**2)
108
110
M = TN (M .reshape (n_vis , n_vis ))
109
- plt .imshow (M .T , interpolation = "nearest" , origin = "lower" , cmap = 'gray' )
110
- plt .savefig ('B_energy_map.pdf' )
111
+
112
+ fig , ax = plt .subplots (subplot_kw = {"projection" : "3d" })
113
+ surf = ax .plot_surface (uu , vv , M , cmap = cm .CMRmap , linewidth = 0 ,
114
+ antialiased = False )
115
+ for axis in [ax .xaxis , ax .yaxis , ax .zaxis ]:
116
+ axis .set_ticklabels ([])
117
+ plt .savefig ("B_energy_map.pdf" , format = "pdf" , bbox_inches = "tight" )
118
+ plt .show (block = True )
111
119
112
120
# %% define B using GD on its energy
113
121
np .random .seed (42 )
114
122
torch .manual_seed (42 )
115
123
116
124
117
- def B (y , its = 100 , lr = 1 , log = False ):
125
+ def B (y , its = 250 , lr = 1 , log = False ):
118
126
"""
119
127
Computes the barycenter images for candidate points x (n, d) and
120
128
measure supports y: List(n, d_k).
@@ -191,8 +199,9 @@ def update(frame): # Update function for animation
191
199
ani .save ("fixed_point_barycentre_animation.gif" , writer = "pillow" , fps = 2 )
192
200
193
201
# %% First 5 steps on a subplot
194
- fig , axes = plt .subplots (1 , 5 , figsize = (15 , 3 )) # 1 row, 5 columns
195
- fig .suptitle ("First 5 Steps Fixed-point GWB solver" , fontsize = 16 )
202
+ n_plots = 5
203
+ fig , axes = plt .subplots (1 , n_plots , figsize = (3 * n_plots , 3 ))
204
+ fig .suptitle (f"First { n_plots } Steps Fixed-point GWB solver" , fontsize = 16 )
196
205
197
206
for i , ax in enumerate (axes ):
198
207
for Y , label in zip (Y_list , labels ):
@@ -204,6 +213,6 @@ def update(frame): # Update function for animation
204
213
ax .set_xlim (- .3 , 1.3 )
205
214
ax .set_ylim (- .3 , 1.3 )
206
215
ax .set_title (f"Step { i + 1 } " , y = - 0.2 )
207
- plt .savefig ('gwb_circles_fixed_point_5_steps .pdf' )
216
+ plt .savefig (f'gwb_circles_fixed_point_ { n_plots } _steps .pdf' )
208
217
209
218
# %%
0 commit comments