Skip to content

Commit e1bd94b

Browse files
committed
code review1
1 parent d432038 commit e1bd94b

File tree

5 files changed

+204
-54
lines changed

5 files changed

+204
-54
lines changed

examples/plot_barycenter_fgw.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
125125
colors.append(val_map[node])
126126
return colors
127127

128-
#%% create dataset
128+
##############################################################################
129+
# Generate data
130+
# -------------
131+
132+
#%% circular dataset
129133
# We build a dataset of noisy circular graphs.
130134
# Noise is added on the structures by random connections and on the features by gaussian noise.
131135

@@ -135,7 +139,11 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
135139
for k in range(9):
136140
X0.append(build_noisy_circular_graph(np.random.randint(15, 25), with_noise=True, structure_noise=True, p=3))
137141

138-
#%% Plot dataset
142+
##############################################################################
143+
# Plot data
144+
# ---------
145+
146+
#%% Plot graphs
139147

140148
plt.figure(figsize=(8, 10))
141149
for i in range(len(X0)):
@@ -146,24 +154,28 @@ def graph_colors(nx_graph, vmin=0, vmax=7):
146154
plt.suptitle('Dataset of noisy graphs. Color indicates the label', fontsize=20)
147155
plt.show()
148156

157+
##############################################################################
158+
# Barycenter computation
159+
# ----------------------
149160

150-
#%%
151-
# We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
161+
#%% We compute the barycenter using FGW. Structure matrices are computed using the shortest_path distance in the graph
152162
# Features distances are the euclidean distances
153163
Cs = [shortest_path(nx.adjacency_matrix(x)) for x in X0]
154164
ps = [np.ones(len(x.nodes())) / len(x.nodes()) for x in X0]
155165
Ys = [np.array([v for (k, v) in nx.get_node_attributes(x, 'attr_name').items()]).reshape(-1, 1) for x in X0]
156166
lambdas = np.array([np.ones(len(Ys)) / len(Ys)]).ravel()
157167
sizebary = 15 # we choose a barycenter with 15 nodes
158168

159-
#%%
160-
161169
A, C, log = fgw_barycenters(sizebary, Ys, Cs, ps, lambdas, alpha=0.95)
162170

163-
#%%
171+
##############################################################################
172+
# Plot Barycenter
173+
# ---------------
174+
175+
#%% Create the barycenter
164176
bary = nx.from_numpy_matrix(sp_to_adjency(C, threshinf=0, threshsup=find_thresh(C, sup=100, step=100)[0]))
165-
for i in range(len(A.ravel())):
166-
bary.add_node(i, attr_name=float(A.ravel()[i]))
177+
for i, v in enumerate(A.ravel()):
178+
bary.add_node(i, attr_name=v)
167179

168180
#%%
169181
pos = nx.kamada_kawai_layout(bary)

examples/plot_fgw.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@
2222
import ot
2323
from ot.gromov import gromov_wasserstein, fused_gromov_wasserstein
2424

25+
##############################################################################
26+
# Generate data
27+
# -------------
28+
2529
#%% parameters
2630
# We create two 1D random measures
27-
n = 20
28-
n2 = 30
29-
sig = 1
30-
sig2 = 0.1
31+
n = 20 # number of points in the first distribution
32+
n2 = 30 # number of points in the second distribution
33+
sig = 1 # std of first distribution
34+
sig2 = 0.1 # std of second distribution
3135

3236
np.random.seed(0)
3337

@@ -43,6 +47,10 @@
4347
p = ot.unif(n)
4448
q = ot.unif(n2)
4549

50+
##############################################################################
51+
# Plot data
52+
# ---------
53+
4654
#%% plot the distributions
4755

4856
pl.close(10)
@@ -64,15 +72,22 @@
6472
pl.tight_layout()
6573
pl.show()
6674

75+
##############################################################################
76+
# Create structure matrices and across-feature distance matrix
77+
# ------------------------------------------------------------
6778

6879
#%% Structure matrices and across-features distance matrix
6980
C1 = ot.dist(xs)
70-
C2 = ot.dist(xt).T
81+
C2 = ot.dist(xt)
7182
M = ot.dist(ys, yt)
7283
w1 = ot.unif(C1.shape[0])
7384
w2 = ot.unif(C2.shape[0])
7485
Got = ot.emd([], [], M)
7586

87+
##############################################################################
88+
# Plot matrices
89+
# -------------
90+
7691
#%%
7792
cmap = 'Reds'
7893
pl.close(10)
@@ -112,6 +127,9 @@
112127
ax3.set_aspect('auto')
113128
pl.show()
114129

130+
##############################################################################
131+
# Compute FGW/GW
132+
# --------------
115133

116134
#%% Computing FGW and GW
117135
alpha = 1e-3
@@ -123,6 +141,10 @@
123141
#%reload_ext WGW
124142
Gg, log = gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True)
125143

144+
##############################################################################
145+
# Visualize transport matrices
146+
# ----------------------------
147+
126148
#%% visu OT matrix
127149
cmap = 'Blues'
128150
fs = 15

ot/gromov.py

Lines changed: 96 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# Nicolas Courty <[email protected]>
1111
# Rémi Flamary <[email protected]>
1212
# Titouan Vayer <[email protected]>
13+
#
1314
# License: MIT License
1415

1516
import numpy as np
@@ -351,9 +352,9 @@ def df(G):
351352
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
352353

353354

354-
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, **kwargs):
355+
def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
355356
"""
356-
Computes the FGW distance between two graphs see [3]
357+
Computes the FGW transport between two graphs see [24]
357358
.. math::
358359
\gamma = arg\min_\gamma (1-\alpha)*<\gamma,M>_F + alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
359360
s.t. \gamma 1 = p
@@ -377,7 +378,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
377378
distribution in the source space
378379
q : ndarray, shape (nt,)
379380
distribution in the target space
380-
loss_fun : string,optionnal
381+
loss_fun : string,optional
381382
loss function used for the solver
382383
max_iter : int, optional
383384
Max number of iterations
@@ -416,7 +417,86 @@ def f(G):
416417
def df(G):
417418
return gwggrad(constC, hC1, hC2, G)
418419

419-
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
420+
if log:
421+
res, log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
422+
log['fgw_dist'] = log['loss'][::-1][0]
423+
return res, log
424+
else:
425+
return cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
426+
427+
428+
def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
    r"""
    Computes the FGW distance between two graphs see [24]

    .. math::
        \min_\gamma (1-\alpha)*<\gamma,M>_F + \alpha* \sum_{i,j,k,l}
        L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}

        s.t. \gamma 1 = p

             \gamma^T 1= q

             \gamma\geq 0

    where :

    - M is the (ns,nt) metric cost matrix
    - p and q are source and target weights (sum to 1)
    - L is a loss function to account for the misfit between the similarity matrices

    The algorithm used for solving the problem is conditional gradient as discussed in [24]_

    Parameters
    ----------
    M : ndarray, shape (ns, nt)
        Metric cost matrix between features across domains
    C1 : ndarray, shape (ns, ns)
        Metric cost matrix representative of the structure in the source space
    C2 : ndarray, shape (nt, nt)
        Metric cost matrix representative of the structure in the target space
    p : ndarray, shape (ns,)
        distribution in the source space
    q : ndarray, shape (nt,)
        distribution in the target space
    loss_fun : string, optional
        loss function used for the solver
    alpha : float, optional
        Trade-off parameter, passed to the cg solver as its regularization weight
    armijo : bool, optional
        If True, the step of the line-search is found via an Armijo search.
        Otherwise the closed form is used.
        If there are convergence issues, use False.
    log : bool, optional
        If True, the solver log is returned in addition to the distance
    **kwargs : dict
        parameters passed directly to the ot.optim.cg solver
        (for example numItermax, stopThr or verbose)

    Returns
    -------
    fgw_dist : float
        Last value of the loss recorded by the solver
    log : dict
        Solver log, returned only if log==True in parameters. It additionally
        holds the distance under the key 'fgw_dist' and the transport matrix
        found by the solver under the key 'T'.

    References
    ----------
    .. [24] Vayer Titouan, Chapel Laetitia, Flamary Rémi, Tavenard Romain
        and Courty Nicolas
        "Optimal Transport for structured data with application on graphs"
        International Conference on Machine Learning (ICML). 2019.
    """

    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)

    # Initial coupling: independent (product) coupling of p and q
    G0 = p[:, None] * q[None, :]

    def f(G):
        return gwloss(constC, hC1, hC2, G)

    def df(G):
        return gwggrad(constC, hC1, hC2, G)

    # The solver log is always requested because the distance is read from it.
    # It is stored under its own name so that the boolean `log` argument is not
    # overwritten and still decides what the caller gets back.
    res, solver_log = cg(p, q, M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
    fgw_dist = solver_log['loss'][-1]

    if log:
        solver_log['fgw_dist'] = fgw_dist
        solver_log['T'] = res
        return fgw_dist, solver_log
    return fgw_dist
420500

421501

422502
def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
@@ -889,7 +969,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
889969

890970
def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
891971
p=None, loss_fun='square_loss', max_iter=100, tol=1e-9,
892-
verbose=False, log=True, init_C=None, init_X=None):
972+
verbose=False, log=False, init_C=None, init_X=None):
893973
"""
894974
Compute the fgw barycenter as presented eq (5) in [24].
895975
----------
@@ -919,7 +999,8 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
919999
Barycenters' features
9201000
C : ndarray, shape (N,N)
9211001
Barycenters' structure matrix
922-
log_:
1002+
log_: dictionary
1003+
Only returned when log=True
9231004
T : list of (N,ns) transport matrices
9241005
Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns)
9251006
References
@@ -1015,14 +1096,13 @@ class UndefinedParameter(Exception):
10151096
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
10161097

10171098
# T is N,ns
1018-
1019-
log_['Ts_iter'].append(T)
10201099
err_feature = np.linalg.norm(X - Xprev.reshape(N, d))
10211100
err_structure = np.linalg.norm(C - Cprev)
10221101

10231102
if log:
10241103
log_['err_feature'].append(err_feature)
10251104
log_['err_structure'].append(err_structure)
1105+
log_['Ts_iter'].append(T)
10261106

10271107
if verbose:
10281108
if cpt % 200 == 0:
@@ -1032,11 +1112,15 @@ class UndefinedParameter(Exception):
10321112
print('{:5d}|{:8e}|'.format(cpt, err_feature))
10331113

10341114
cpt += 1
1035-
log_['T'] = T # from target to Ys
1036-
log_['p'] = p
1037-
log_['Ms'] = Ms # Ms are N,ns
1115+
if log:
1116+
log_['T'] = T # from target to Ys
1117+
log_['p'] = p
1118+
log_['Ms'] = Ms # Ms are N,ns
10381119

1039-
return X, C, log_
1120+
if log:
1121+
return X, C, log_
1122+
else:
1123+
return X, C
10401124

10411125

10421126
def update_sructure_matrix(p, lambdas, T, Cs):

ot/optim.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# Author: Remi Flamary <[email protected]>
77
# Titouan Vayer <[email protected]>
8+
#
89
# License: MIT License
910

1011
import numpy as np
@@ -88,20 +89,20 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
8889
Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost
8990
f_val : float
9091
Value of the cost at G
91-
armijo : bool, optionnal
92+
armijo : bool, optional
9293
If True, the step of the line-search is found via an Armijo search. Otherwise the closed form is used.
9394
If there are convergence issues, use False.
94-
C1 : ndarray (ns,ns), optionnal
95+
C1 : ndarray (ns,ns), optional
9596
Structure matrix in the source domain. Only used when armijo=False
96-
C2 : ndarray (nt,nt), optionnal
97+
C2 : ndarray (nt,nt), optional
9798
Structure matrix in the target domain. Only used when armijo=False
98-
reg : float, optionnal
99+
reg : float, optional
99100
Regularization parameter. Only used when armijo=False
100101
Gc : ndarray (ns,nt)
101102
Optimal map found by linearization in the FW algorithm. Only used when armijo=False
102103
constC : ndarray (ns,nt)
103104
Constant for the gromov cost. See [24]. Only used when armijo=False
104-
M : ndarray (ns,nt), optionnal
105+
M : ndarray (ns,nt), optional
105106
Cost matrix between the features. Only used when armijo=False
106107
Returns
107108
-------
@@ -223,9 +224,9 @@ def cost(G):
223224
it = 0
224225

225226
if verbose:
226-
print('{:5s}|{:12s}|{:8s}'.format(
227-
'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
228-
print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
227+
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
228+
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
229+
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
229230

230231
while loop:
231232

@@ -261,8 +262,8 @@ def cost(G):
261262

262263
if verbose:
263264
if it % 20 == 0:
264-
print('{:5s}|{:12s}|{:8s}'.format(
265-
'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
265+
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
266+
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
266267
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
267268

268269
if log:
@@ -363,9 +364,9 @@ def cost(G):
363364
it = 0
364365

365366
if verbose:
366-
print('{:5s}|{:12s}|{:8s}'.format(
367-
'It.', 'Loss', 'Delta loss') + '\n' + '-' * 32)
368-
print('{:5d}|{:8e}|{:8e}'.format(it, f_val, 0))
367+
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
368+
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
369+
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, 0, 0))
369370

370371
while loop:
371372

@@ -402,8 +403,8 @@ def cost(G):
402403

403404
if verbose:
404405
if it % 20 == 0:
405-
print('{:5s}|{:12s}|{:8s}'.format(
406-
'It.', 'Loss', 'Relative variation loss', 'Absolute variation loss') + '\n' + '-' * 32)
406+
print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
407+
'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
407408
print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval))
408409

409410
if log:

0 commit comments

Comments
 (0)