Skip to content

Commit 8b9e641

Browse files
authored
[MRG] SinkhornL1L2 bug solve (#313)
* Now limiting alpha to a minimum value as well as a max value * Docs * typo
1 parent e235b08 commit 8b9e641

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

ot/optim.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
# The corresponding scipy function does not work for matrices
1919

2020

21-
def line_search_armijo(f, xk, pk, gfk, old_fval,
22-
args=(), c1=1e-4, alpha0=0.99):
21+
def line_search_armijo(
22+
f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
23+
alpha0=0.99, alpha_min=None, alpha_max=None
24+
):
2325
r"""
2426
Armijo linesearch function that works with matrices
2527
@@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
4446
:math:`c_1` const in armijo rule (>0)
4547
alpha0 : float, optional
4648
initial step (>0)
49+
alpha_min : float, optional
50+
minimum value for alpha
51+
alpha_max : float, optional
52+
maximum value for alpha
4753
4854
Returns
4955
-------
@@ -80,13 +86,15 @@ def phi(alpha1):
8086
if alpha is None:
8187
return 0., fc[0], phi0
8288
else:
83-
# scalar_search_armijo can return alpha > 1
84-
alpha = min(1, alpha)
89+
if alpha_min is not None or alpha_max is not None:
90+
alpha = np.clip(alpha, alpha_min, alpha_max)
8591
return alpha, fc[0], phi1
8692

8793

88-
def solve_linesearch(cost, G, deltaG, Mi, f_val,
89-
armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None):
94+
def solve_linesearch(
95+
cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None,
96+
reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None
97+
):
9098
"""
9199
Solve the linesearch in the FW iterations
92100
@@ -117,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
117125
Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
118126
M : array-like (ns,nt), optional
119127
Cost matrix between the features. Only used and necessary when armijo=False
128+
alpha_min : float, optional
129+
Minimum value for alpha
130+
alpha_max : float, optional
131+
Maximum value for alpha
120132
121133
Returns
122134
-------
@@ -136,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
136148
International Conference on Machine Learning (ICML). 2019.
137149
"""
138150
if armijo:
139-
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, Mi, f_val)
151+
alpha, fc, f_val = line_search_armijo(
152+
cost, G, deltaG, Mi, f_val, alpha_min=alpha_min, alpha_max=alpha_max
153+
)
140154
else: # requires symetric matrices
141155
G, deltaG, C1, C2, constC, M = list_to_array(G, deltaG, C1, C2, constC, M)
142156
if isinstance(M, int) or isinstance(M, float):
@@ -150,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
150164
c = cost(G)
151165

152166
alpha = solve_1d_linesearch_quad(a, b, c)
167+
if alpha_min is not None or alpha_max is not None:
168+
alpha = np.clip(alpha, alpha_min, alpha_max)
153169
fc = None
154170
f_val = cost(G + alpha * deltaG)
155171

@@ -274,7 +290,10 @@ def cost(G):
274290
deltaG = Gc - G
275291

276292
# line search
277-
alpha, fc, f_val = solve_linesearch(cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc, **kwargs)
293+
alpha, fc, f_val = solve_linesearch(
294+
cost, G, deltaG, Mi, f_val, reg=reg, M=M, Gc=Gc,
295+
alpha_min=0., alpha_max=1., **kwargs
296+
)
278297

279298
G = G + alpha * deltaG
280299

@@ -420,7 +439,9 @@ def cost(G):
420439

421440
# line search
422441
dcost = Mi + reg1 * (1 + nx.log(G)) # ??
423-
alpha, fc, f_val = line_search_armijo(cost, G, deltaG, dcost, f_val)
442+
alpha, fc, f_val = line_search_armijo(
443+
cost, G, deltaG, dcost, f_val, alpha_min=0., alpha_max=1.
444+
)
424445

425446
G = G + alpha * deltaG
426447

0 commit comments

Comments
 (0)