Skip to content

Commit a03d7a7

Browse files
committed
W1 GWB example
1 parent 2fd3f3b commit a03d7a7

20 files changed

+208
-11
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
*__pycache__/
2-
*.egg-info/
2+
*.egg-info/
3+
.vscode/
127 KB
Binary file not shown.
Loading
319 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
4.8 KB
Binary file not shown.
17 Bytes
Binary file not shown.

examples/2d_circles_L2/src.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import ot # type: ignore
99
from torch.optim import SGD
1010
import matplotlib.animation as animation
11+
from matplotlib import cm
1112

1213

1314
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -19,15 +20,16 @@
1920
weights = TT(ot.unif(n)) # weights for the barycentre
2021

2122

22-
# map 1: R^2 -> R^2 projection onto circle
23+
# map R^2 -> R^2 projection onto circle
2324
def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
2425
diffs = X - origin[None, :]
2526
norms = torch.norm(diffs, dim=1)
2627
return origin[None, :] + radius * diffs / norms[:, None]
2728

2829

2930
# build a measure as a 2D circle
30-
t = np.linspace(0, 2 * np.pi, n, endpoint=False)
31+
# t = np.linspace(0, 2 * np.pi, n, endpoint=False)
32+
t = np.random.rand(n) * 2 * np.pi
3133
X = .5 * TT(torch.tensor([np.cos(t), np.sin(t)]).T)
3234
X = X + TT(torch.tensor([.5, .5]))[None, :]
3335
origin1 = TT(torch.tensor([-1, -1]))
@@ -57,7 +59,7 @@ def proj_circle(X: torch.tensor, origin: torch.tensor, radius: float):
5759
its=its, stop_threshold=stop_threshold,
5860
gamma=gamma)
5961
dt = time() - t0
60-
print(f"Finished in {dt:.2f}s, exit status: {exit_status}, final loss: {loss_list[-1]:.2f}")
62+
print(f"Finished in {dt:.2f}s, exit status: {exit_status}, final loss: {loss_list[-1]:.10f}")
6163

6264
# %% Plot GD barycentre
6365
alpha = .5
@@ -106,15 +108,21 @@ def C(x, y):
106108
y = [(Y_list[k][y_idx[k]])[None, :] * torch.ones_like(x, device=device) for k in range(4)]
107109
M = C(x, y) # shape (n_vis**2)
108110
M = TN(M.reshape(n_vis, n_vis))
109-
plt.imshow(M.T, interpolation="nearest", origin="lower", cmap='gray')
110-
plt.savefig('B_energy_map.pdf')
111+
112+
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
113+
surf = ax.plot_surface(uu, vv, M, cmap=cm.CMRmap, linewidth=0,
114+
antialiased=False)
115+
for axis in [ax.xaxis, ax.yaxis, ax.zaxis]:
116+
axis.set_ticklabels([])
117+
plt.savefig("B_energy_map.pdf", format="pdf", bbox_inches="tight")
118+
plt.show(block=True)
111119

112120
# %% define B using GD on its energy
113121
np.random.seed(42)
114122
torch.manual_seed(42)
115123

116124

117-
def B(y, its=100, lr=1, log=False):
125+
def B(y, its=250, lr=1, log=False):
118126
"""
119127
Computes the barycenter images for candidate points x (n, d) and
120128
measure supports y: List(n, d_k).
@@ -191,8 +199,9 @@ def update(frame): # Update function for animation
191199
ani.save("fixed_point_barycentre_animation.gif", writer="pillow", fps=2)
192200

193201
# %% First 5 steps on a subplot
194-
fig, axes = plt.subplots(1, 5, figsize=(15, 3)) # 1 row, 5 columns
195-
fig.suptitle("First 5 Steps Fixed-point GWB solver", fontsize=16)
202+
n_plots = 5
203+
fig, axes = plt.subplots(1, n_plots, figsize=(3 * n_plots, 3))
204+
fig.suptitle(f"First {n_plots} Steps Fixed-point GWB solver", fontsize=16)
196205

197206
for i, ax in enumerate(axes):
198207
for Y, label in zip(Y_list, labels):
@@ -204,6 +213,6 @@ def update(frame): # Update function for animation
204213
ax.set_xlim(-.3, 1.3)
205214
ax.set_ylim(-.3, 1.3)
206215
ax.set_title(f"Step {i+1}", y=-0.2)
207-
plt.savefig('gwb_circles_fixed_point_5_steps.pdf')
216+
plt.savefig(f'gwb_circles_fixed_point_{n_plots}_steps.pdf')
208217

209218
# %%
12.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)