Skip to content

Commit 79eb337

Browse files
[MRG] Partial Entropic FGW solvers (#702)
* merge * new dev version * first commit partial fgw * complete tests + solve_gromov * complete tests + solve_gromov * release * partial entropic fgw solvers * add tests * complete solve_gromov * update * fix solvers * fix solvers * fix solvers * improve example and doc * update readme * reset solve_gromov behavior for unbalanced=none * merge * merge --------- Co-authored-by: Rémi Flamary <[email protected]>
1 parent c128104 commit 79eb337

10 files changed

+1051
-42
lines changed

README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ POT provides the following generic OT solvers (links to examples):
4040
* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
4141
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
4242
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
43-
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
44-
formulations).
43+
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations).
4544
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
4645
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
4746
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]

RELEASES.md

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
88
- Automatic PR labeling and release file update check (PR #704)
99
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
10+
- Implement projected gradient descent solvers for entropic partial FGW (PR #702)
1011
- Fix documentation in the module `ot.gaussian` (PR #718)
1112

1213
#### Closed issues

examples/gromov/plot_barycenter_fgw.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def build_noisy_circular_graph(
9191
g = nx.Graph()
9292
g.add_nodes_from(list(range(N)))
9393
for i in range(N):
94-
noise = float(np.random.normal(mu, sigma, 1))
94+
noise = np.random.normal(mu, sigma, 1)[0]
9595
if with_noise:
9696
g.add_node(i, attr_name=math.sin((2 * i * math.pi / N)) + noise)
9797
else:
@@ -107,7 +107,7 @@ def build_noisy_circular_graph(
107107
if i == N - 1:
108108
g.add_edge(i, 1)
109109
g.add_edge(N, 0)
110-
noise = float(np.random.normal(mu, sigma, 1))
110+
noise = np.random.normal(mu, sigma, 1)[0]
111111
if with_noise:
112112
g.add_node(N, attr_name=math.sin((2 * N * math.pi / N)) + noise)
113113
else:
@@ -157,7 +157,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
157157
plt.subplot(3, 3, i + 1)
158158
g = X0[i]
159159
pos = nx.kamada_kawai_layout(g)
160-
nx.draw(
160+
nx.draw_networkx(
161161
g,
162162
pos=pos,
163163
node_color=graph_colors(g, vmin=-1, vmax=1),
@@ -173,7 +173,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
173173

174174
# %% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
175175
# Features distances are the euclidean distances
176-
Cs = [shortest_path(nx.adjacency_matrix(x).todense()) for x in X0]
176+
Cs = [shortest_path(nx.adjacency_matrix(x).toarray()) for x in X0]
177177
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
178178
Ys = [
179179
np.array([v for (k, v) in nx.get_node_attributes(x, "attr_name").items()]).reshape(
@@ -199,7 +199,7 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
199199

200200
# %%
201201
pos = nx.kamada_kawai_layout(bary)
202-
nx.draw(
202+
nx.draw_networkx(
203203
bary, pos=pos, node_color=graph_colors(bary, vmin=-1, vmax=1), with_labels=False
204204
)
205205
plt.suptitle("Barycenter", fontsize=20)

0 commit comments

Comments
 (0)