Skip to content

Commit 8f56eff

Browse files
[WIP] Fix gromov examples gallery (#444)
* maj gw/ srgw/ generic cg solver * correct pep8 on current state * fix bug previous tests * fix pep8 * fix bug srGW constC in loss and gradient * fix doc html * fix doc html * start updating test_optim.py * update tests gromov and optim - plus fix gromov dependencies * add symmetry feature to entropic gw * add symmetry feature to entropic gw * add exemple for sr(F)GW matchings * small stuff * remove (reg,M) from line-search/ complete srgw tests with backend * remove backend repetitions / rename fG to costG/ fix innerlog to True * fix pep8 * take comments into account / new nx parameters still to test * factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criterions * split gromov.py in ot/gromov/ + update test_gromov with helper_backend functions * manual documentaion gromov * remove circular autosummary * trying stuff * debug documentation * alphabetic ordering of module * merge into branch * add note in entropic gw solvers * fix exemples/gromov doc * add fixed issue to releases.md --------- Co-authored-by: Rémi Flamary <[email protected]>
1 parent a5930d3 commit 8f56eff

File tree

3 files changed

+50
-46
lines changed

3 files changed

+50
-46
lines changed

RELEASES.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ PR #413)
4646
- Fix an issue where the parameter `stopThr` in `empirical_sinkhorn_divergence` was rendered useless by subcalls
4747
that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
4848
- Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
49-
49+
- Fixed an issue with the documentation gallery section (PR #444)
5050

5151
## 0.8.2
5252

examples/gromov/plot_gromov_wasserstein_dictionary_learning.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@
4545
import ot
4646
import networkx
4747
from networkx.generators.community import stochastic_block_model as sbm
48-
# %%
49-
# =============================================================================
48+
49+
#############################################################################
50+
#
5051
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
51-
# =============================================================================
52+
# ---------------------------------------------
5253

5354
np.random.seed(42)
5455

@@ -109,10 +110,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
109110
pl.tight_layout()
110111
pl.show()
111112

112-
# %%
113-
# =============================================================================
113+
#############################################################################
114+
#
114115
# Estimate the gromov-wasserstein dictionary from the dataset
115-
# =============================================================================
116+
# ---------------------------------------------
116117

117118

118119
np.random.seed(0)
@@ -140,10 +141,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
140141
pl.tight_layout()
141142
pl.show()
142143

143-
# %%
144-
# =============================================================================
144+
#############################################################################
145+
#
145146
# Visualization of the estimated dictionary atoms
146-
# =============================================================================
147+
# ---------------------------------------------
147148

148149

149150
# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 2: white)
@@ -164,10 +165,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
164165
pl.axis("off")
165166
pl.tight_layout()
166167
pl.show()
167-
#%%
168-
# =============================================================================
168+
169+
#############################################################################
170+
#
169171
# Visualization of the embedding space
170-
# =============================================================================
172+
# ---------------------------------------------
171173

172174
unmixings = []
173175
reconstruction_errors = []
@@ -211,11 +213,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
211213
pl.legend(fontsize=11)
212214
pl.tight_layout()
213215
pl.show()
214-
# %%
215-
# =============================================================================
216-
# Endow the dataset with node features
217-
# =============================================================================
218216

217+
#############################################################################
218+
#
219+
# Endow the dataset with node features
220+
# ---------------------------------------------
219221
# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters
220222
# 1 cluster --> 0 as nodes feature
221223
# 2 clusters --> 1 as nodes feature
@@ -251,10 +253,11 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
251253
pl.axis("off")
252254
pl.tight_layout()
253255
pl.show()
254-
# %%
255-
# =============================================================================
256+
257+
#############################################################################
258+
#
256259
# Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs
257-
# =============================================================================
260+
# ---------------------------------------------
258261
np.random.seed(0)
259262
ps = [ot.unif(C.shape[0]) for C in dataset]
260263
D = 3 # 6 atoms instead of 3
@@ -280,10 +283,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
280283
pl.tight_layout()
281284
pl.show()
282285

283-
# %%
284-
# =============================================================================
286+
#############################################################################
287+
#
285288
# Visualization of the estimated dictionary atoms
286-
# =============================================================================
289+
# ---------------------------------------------
287290

288291
pl.figure(7, (12, 8))
289292
pl.clf()
@@ -307,10 +310,10 @@ def plot_graph(x, C, binary=True, color='C0', s=None):
307310
pl.tight_layout()
308311
pl.show()
309312

310-
# %%
311-
# =============================================================================
313+
#############################################################################
314+
#
312315
# Visualization of the embedding space
313-
# =============================================================================
316+
# ---------------------------------------------
314317

315318
unmixings = []
316319
reconstruction_errors = []

examples/gromov/plot_semirelaxed_fgw.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
import networkx
3232
from networkx.generators.community import stochastic_block_model as sbm
3333

34-
# %%
35-
# =============================================================================
34+
#############################################################################
35+
#
3636
# Generate two graphs following Stochastic Block models of 2 and 3 clusters.
37-
# =============================================================================
37+
# ---------------------------------------------
3838

3939

4040
N2 = 20 # 2 communities
@@ -81,10 +81,11 @@
8181
weightedG3.add_edge(i, j, weight=weight_intra_G3)
8282
else:
8383
weightedG3.add_edge(i, j, weight=weight_inter_G3)
84-
# %%
85-
# =============================================================================
84+
85+
#############################################################################
86+
#
8687
# Compute their semi-relaxed Gromov-Wasserstein divergences
87-
# =============================================================================
88+
# ---------------------------------------------
8889

8990
# 0) GW(C2, h2, C3, h3) for reference
9091
OT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)
@@ -106,11 +107,11 @@
106107
print('srGW(C3, h3, C2) = ', srgw_32)
107108

108109

109-
# %%
110-
# =============================================================================
110+
#############################################################################
111+
#
111112
# Visualization of the semi-relaxed Gromov-Wasserstein matchings
112-
# =============================================================================
113-
113+
# ---------------------------------------------
114+
#
114115
# We color nodes of the graph on the right - then project its node colors
115116
# based on the optimal transport plan from the srGW matching
116117

@@ -222,10 +223,10 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
222223

223224
pl.show()
224225

225-
# %%
226-
# =============================================================================
226+
#############################################################################
227+
#
227228
# Add node features
228-
# =============================================================================
229+
# ---------------------------------------------
229230

230231
# We add node features with given mean - by clusters
231232
# and inversely proportional to clusters' intra-connectivity
@@ -238,10 +239,10 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
238239
for i, c in enumerate(part_G3):
239240
F3[i, 0] = np.random.normal(loc=2. - c, scale=0.01)
240241

241-
# %%
242-
# =============================================================================
242+
#############################################################################
243+
#
243244
# Compute their semi-relaxed Fused Gromov-Wasserstein divergences
244-
# =============================================================================
245+
# ---------------------------------------------
245246

246247
alpha = 0.5
247248
# Compute pairwise euclidean distance between node features
@@ -268,11 +269,11 @@ def draw_transp_colored_srGW(G1, C1, G2, C2, part_G1,
268269
print('srGW(C2, F2, h2, C3, F3) = ', srfgw_23)
269270
print('srGW(C3, F3, h3, C2, F2) = ', srfgw_32)
270271

271-
# %%
272-
# =============================================================================
272+
#############################################################################
273+
#
273274
# Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings
274-
# =============================================================================
275-
275+
# ---------------------------------------------
276+
#
276277
# We color nodes of the graph on the right - then project its node colors
277278
# based on the optimal transport plan from the srFGW matching
278279
# NB: colors refer to clusters - not to node features

0 commit comments

Comments
 (0)