18
18
# The corresponding scipy function does not work for matrices
19
19
20
20
21
- def line_search_armijo (f , xk , pk , gfk , old_fval ,
22
- args = (), c1 = 1e-4 , alpha0 = 0.99 ):
21
+ def line_search_armijo (
22
+ f , xk , pk , gfk , old_fval , args = (), c1 = 1e-4 ,
23
+ alpha0 = 0.99 , alpha_min = None , alpha_max = None
24
+ ):
23
25
r"""
24
26
Armijo linesearch function that works with matrices
25
27
@@ -44,6 +46,10 @@ def line_search_armijo(f, xk, pk, gfk, old_fval,
44
46
:math:`c_1` constant in the Armijo rule (>0)
45
47
alpha0 : float, optional
46
48
initial step (>0)
49
+ alpha_min : float, optional
50
+ minimum value for alpha
51
+ alpha_max : float, optional
52
+ maximum value for alpha
47
53
48
54
Returns
49
55
-------
@@ -80,13 +86,15 @@ def phi(alpha1):
80
86
if alpha is None :
81
87
return 0. , fc [0 ], phi0
82
88
else :
83
- # scalar_search_armijo can return alpha > 1
84
- alpha = min ( 1 , alpha )
89
+ if alpha_min is not None or alpha_max is not None :
90
+ alpha = np . clip ( alpha , alpha_min , alpha_max )
85
91
return alpha , fc [0 ], phi1
86
92
87
93
88
- def solve_linesearch (cost , G , deltaG , Mi , f_val ,
89
- armijo = True , C1 = None , C2 = None , reg = None , Gc = None , constC = None , M = None ):
94
+ def solve_linesearch (
95
+ cost , G , deltaG , Mi , f_val , armijo = True , C1 = None , C2 = None ,
96
+ reg = None , Gc = None , constC = None , M = None , alpha_min = None , alpha_max = None
97
+ ):
90
98
"""
91
99
Solve the linesearch in the FW iterations
92
100
@@ -117,6 +125,10 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
117
125
Constant for the gromov cost. See :ref:`[24] <references-solve-linesearch>`. Only used and necessary when armijo=False
118
126
M : array-like (ns,nt), optional
119
127
Cost matrix between the features. Only used and necessary when armijo=False
128
+ alpha_min : float, optional
129
+ Minimum value for alpha
130
+ alpha_max : float, optional
131
+ Maximum value for alpha
120
132
121
133
Returns
122
134
-------
@@ -136,7 +148,9 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
136
148
International Conference on Machine Learning (ICML). 2019.
137
149
"""
138
150
if armijo :
139
- alpha , fc , f_val = line_search_armijo (cost , G , deltaG , Mi , f_val )
151
+ alpha , fc , f_val = line_search_armijo (
152
+ cost , G , deltaG , Mi , f_val , alpha_min = alpha_min , alpha_max = alpha_max
153
+ )
140
154
else : # requires symmetric matrices
141
155
G , deltaG , C1 , C2 , constC , M = list_to_array (G , deltaG , C1 , C2 , constC , M )
142
156
if isinstance (M , int ) or isinstance (M , float ):
@@ -150,6 +164,8 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
150
164
c = cost (G )
151
165
152
166
alpha = solve_1d_linesearch_quad (a , b , c )
167
+ if alpha_min is not None or alpha_max is not None :
168
+ alpha = np .clip (alpha , alpha_min , alpha_max )
153
169
fc = None
154
170
f_val = cost (G + alpha * deltaG )
155
171
@@ -274,7 +290,10 @@ def cost(G):
274
290
deltaG = Gc - G
275
291
276
292
# line search
277
- alpha , fc , f_val = solve_linesearch (cost , G , deltaG , Mi , f_val , reg = reg , M = M , Gc = Gc , ** kwargs )
293
+ alpha , fc , f_val = solve_linesearch (
294
+ cost , G , deltaG , Mi , f_val , reg = reg , M = M , Gc = Gc ,
295
+ alpha_min = 0. , alpha_max = 1. , ** kwargs
296
+ )
278
297
279
298
G = G + alpha * deltaG
280
299
@@ -420,7 +439,9 @@ def cost(G):
420
439
421
440
# line search
422
441
dcost = Mi + reg1 * (1 + nx .log (G )) # ??
423
- alpha , fc , f_val = line_search_armijo (cost , G , deltaG , dcost , f_val )
442
+ alpha , fc , f_val = line_search_armijo (
443
+ cost , G , deltaG , dcost , f_val , alpha_min = 0. , alpha_max = 1.
444
+ )
424
445
425
446
G = G + alpha * deltaG
426
447
0 commit comments