Skip to content

Commit ad450b0

Browse files
committed
changes forgotten comments
1 parent 89a2e0a commit ad450b0

File tree

4 files changed

+30
-42
lines changed

4 files changed

+30
-42
lines changed

ot/gromov.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
from .bregman import sinkhorn
20-
from .utils import dist
20+
from .utils import dist, UndefinedParameter
2121
from .optim import cg
2222

2323

@@ -1011,9 +1011,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_
10111011
International Conference on Machine Learning (ICML). 2019.
10121012
"""
10131013

1014-
class UndefinedParameter(Exception):
1015-
pass
1016-
10171014
S = len(Cs)
10181015
d = Ys[0].shape[1] # dimension on the node features
10191016
if p is None:
@@ -1049,10 +1046,7 @@ class UndefinedParameter(Exception):
10491046

10501047
T = [np.outer(p, q) for q in ps]
10511048

1052-
# X is N,d
1053-
# Ys is ns,d
1054-
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
1055-
# Ms is N,ns
1049+
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))] # Ms is N,ns
10561050

10571051
cpt = 0
10581052
err_feature = 1
@@ -1072,27 +1066,13 @@ class UndefinedParameter(Exception):
10721066
Ys_temp = [y.T for y in Ys]
10731067
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
10741068

1075-
# X must be N,d
1076-
# Ys must be ns,d
10771069
Ms = [np.asarray(dist(X, Ys[s]), dtype=np.float64) for s in range(len(Ys))]
10781070

10791071
if not fixed_structure:
10801072
if loss_fun == 'square_loss':
1081-
# T must be ns,N
1082-
# Cs must be ns,ns
1083-
# p must be N,1
10841073
T_temp = [t.T for t in T]
10851074
C = update_sructure_matrix(p, lambdas, T_temp, Cs)
10861075

1087-
# Ys must be d,ns
1088-
# Ts must be N,ns
1089-
# p must be N,1
1090-
# Ms is N,ns
1091-
# C is N,N
1092-
# Cs is ns,ns
1093-
# p is N,1
1094-
# ps is ns,1
1095-
10961076
T = [fused_gromov_wasserstein((1 - alpha) * Ms[s], C, Cs[s], p, ps[s], loss_fun, alpha, numItermax=max_iter, stopThr=1e-5, verbose=verbose) for s in range(S)]
10971077

10981078
# T is N,ns
@@ -1115,7 +1095,7 @@ class UndefinedParameter(Exception):
11151095
if log:
11161096
log_['T'] = T # from target to Ys
11171097
log_['p'] = p
1118-
log_['Ms'] = Ms # Ms are N,ns
1098+
log_['Ms'] = Ms
11191099

11201100
if log:
11211101
return X, C, log_

ot/optim.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ def phi(alpha1):
7373
return alpha, fc[0], phi1
7474

7575

76-
def do_linesearch(cost, G, deltaG, Mi, f_val,
77-
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
76+
def solve_linesearch(cost, G, deltaG, Mi, f_val,
77+
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
7878
"""
7979
Solve the linesearch in the FW iterations
8080
Parameters
@@ -93,17 +93,17 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
9393
If True the steps of the line-search is found via an armijo research. Else closed form is used.
9494
If there is convergence issues use False.
9595
C1 : ndarray (ns,ns), optional
96-
Structure matrix in the source domain. Only used when armijo=False
96+
Structure matrix in the source domain. Only used and necessary when armijo=False
9797
C2 : ndarray (nt,nt), optional
98-
Structure matrix in the target domain. Only used when armijo=False
98+
Structure matrix in the target domain. Only used and necessary when armijo=False
9999
reg : float, optional
100-
Regularization parameter. Only used when armijo=False
100+
Regularization parameter. Only used and necessary when armijo=False
101101
Gc : ndarray (ns,nt)
102-
Optimal map found by linearization in the FW algorithm. Only used when armijo=False
102+
Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False
103103
constC : ndarray (ns,nt)
104-
Constant for the gromov cost. See [24]. Only used when armijo=False
104+
Constant for the gromov cost. See [24]. Only used and necessary when armijo=False
105105
M : ndarray (ns,nt), optional
106-
Cost matrix between the features. Only used when armijo=False
106+
Cost matrix between the features. Only used and necessary when armijo=False
107107
Returns
108108
-------
109109
alpha : float
@@ -128,7 +128,7 @@ def do_linesearch(cost, G, deltaG, Mi, f_val,
128128
b = np.sum((M + reg * constC) * deltaG) - 2 * reg * (np.sum(dot12 * G) + np.sum(np.dot(C1, G).dot(C2) * deltaG))
129129
c = cost(G)
130130

131-
alpha = solve_1d_linesearch_quad_funct(a, b, c)
131+
alpha = solve_1d_linesearch_quad(a, b, c)
132132
fc = None
133133
f_val = cost(G + alpha * deltaG)
134134

@@ -181,7 +181,7 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
181181
Print information along iterations
182182
log : bool, optional
183183
record log if True
184-
kwargs : dict
184+
**kwargs : dict
185185
Parameters for linesearch
186186
187187
Returns
@@ -244,7 +244,7 @@ def cost(G):
244244
deltaG = Gc - G
245245

246246
# line search
247-
alpha, fc, f_val = do_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
247+
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
248248

249249
G = G + alpha * deltaG
250250

@@ -254,7 +254,7 @@ def cost(G):
254254

255255
abs_delta_fval = abs(f_val - old_fval)
256256
relative_delta_fval = abs_delta_fval / abs(f_val)
257-
if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
257+
if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
258258
loop = 0
259259

260260
if log:
@@ -395,7 +395,7 @@ def cost(G):
395395
abs_delta_fval = abs(f_val - old_fval)
396396
relative_delta_fval = abs_delta_fval / abs(f_val)
397397

398-
if relative_delta_fval < stopThr and abs_delta_fval < stopThr2:
398+
if relative_delta_fval < stopThr or abs_delta_fval < stopThr2:
399399
loop = 0
400400

401401
if log:
@@ -413,11 +413,11 @@ def cost(G):
413413
return G
414414

415415

416-
def solve_1d_linesearch_quad_funct(a, b, c):
416+
def solve_1d_linesearch_quad(a, b, c):
417417
"""
418-
Solve on 0,1 the following problem:
418+
For any convex or non-convex 1d quadratic function f, solve on [0,1] the following problem:
419419
.. math::
420-
\min f(x)=a*x^{2}+b*x+c
420+
\argmin f(x)=a*x^{2}+b*x+c
421421
422422
Parameters
423423
----------

ot/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,3 +487,11 @@ def set_params(self, **params):
487487
(key, self.__class__.__name__))
488488
setattr(self, key, value)
489489
return self
490+
491+
492+
class UndefinedParameter(Exception):
493+
"""
494+
Aim at raising an Exception when a undefined parameter is called
495+
496+
"""
497+
pass

test/test_optim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,6 @@ def df(G):
6868

6969

7070
def test_solve_1d_linesearch_quad_funct():
71-
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(1, -1, 0), 0.5)
72-
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 5, 0), 0)
73-
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad_funct(-1, 0.5, 0), 1)
71+
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
72+
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
73+
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)

0 commit comments

Comments
 (0)