Skip to content

Commit 24a7a04

Browse files
panispanirflamary
andauthored
Check if alpha is not None when restricting it to be at most 1 (#199)
* Check if alpha is not None when restricting it to be at most 1 * Write check more clearly * Add no regression test for line search armijo returning None for alpha Co-authored-by: Rémi Flamary <[email protected]>
1 parent 679ed31 commit 24a7a04

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

Diff for: ot/optim.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ def phi(alpha1):
6969
alpha, phi1 = scalar_search_armijo(
7070
phi, phi0, derphi0, c1=c1, alpha0=alpha0)
7171

72-
return min(1, alpha), fc[0], phi1
72+
# scalar_search_armijo can return alpha > 1
73+
if alpha is not None:
74+
alpha = min(1, alpha)
75+
return alpha, fc[0], phi1
7376

7477

7578
def solve_linesearch(cost, G, deltaG, Mi, f_val,

Diff for: test/test_optim.py

+10
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,13 @@ def test_solve_1d_linesearch_quad_funct():
104104
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(1, -1, 0), 0.5)
105105
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 5, 0), 0)
106106
np.testing.assert_allclose(ot.optim.solve_1d_linesearch_quad(-1, 0.5, 0), 1)
107+
108+
109+
def test_line_search_armijo():
110+
xk = np.array([[0.25, 0.25], [0.25, 0.25]])
111+
pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
112+
gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
113+
old_fval = -123
114+
# Should not throw an exception and return None for alpha
115+
alpha, _, _ = ot.optim.line_search_armijo(lambda x: 1, xk, pk, gfk, old_fval)
116+
assert alpha is None

0 commit comments

Comments
 (0)