
Commit 5835016

[MRG] fix bugs of gw_entropic and armijo to run on gpu (#446)
Squashed commit messages:

- Update gw / srgw / generic cg solver
- Correct pep8 on current state
- Fix bug in previous tests
- Fix pep8
- Fix bug in srGW constC in loss and gradient
- Fix doc html
- Fix doc html
- Start updating test_optim.py
- Update tests for gromov and optim, plus fix gromov dependencies
- Add symmetry feature to entropic gw
- Add symmetry feature to entropic gw
- Add example for sr(F)GW matchings
- Small stuff
- Remove (reg, M) from line-search / complete srgw tests with backend
- Remove backend repetitions / rename fG to costG / fix innerlog to True
- Fix pep8
- Take comments into account / new nx parameters still to test
- Factor (f)gw2 + test new backend parameters in ot.gromov + harmonize stopping criteria
- Split gromov.py into ot/gromov/ + update test_gromov with helper_backend functions
- Manual documentation for gromov
- Remove circular autosummary
- Trying stuff
- Debug documentation
- Alphabetic ordering of module
- Merge into branch
- Add note in entropic gw solvers
- Fix examples/gromov doc
- Add fixed issue to RELEASES.md
- Fix bugs of gw_entropic and armijo to run on gpu
- Add PR to RELEASES.md
- Fix pep8
- Fix call to backend in line_search_armijo
- Correct docstring of generic_conditional_gradient

Co-authored-by: Rémi Flamary <[email protected]>
1 parent 8f56eff · commit 5835016

File tree: 5 files changed, +100 −17 lines changed

- RELEASES.md
- ot/gromov/_bregman.py
- ot/optim.py
- test/test_gromov.py
- test/test_optim.py

RELEASES.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -47,6 +47,7 @@ PR #413)
   that explicitly specified `stopThr=1e-9` (Issue #421, PR #422).
 - Fixed a bug breaking an example where we would try to make an array of arrays of different shapes (Issue #424, PR #425)
 - Fixed an issue with the documentation gallery section (PR #444)
+- Fixed issues with cuda variables for `line_search_armijo` and `entropic_gromov_wasserstein` (Issue #445, PR #446)

 ## 0.8.2

@@ -571,4 +572,4 @@ It provides the following solvers:
 * Optimal transport for domain adaptation with group lasso regularization
 * Conditional gradient and Generalized conditional gradient for regularized OT.

-Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
+Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
```
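To illustrate the release note above, here is a minimal, hypothetical sketch (not part of the commit) of the use case reported in Issue #445: running the entropic GW solver entirely on CUDA tensors. The sizes and epsilon value are arbitrary; a CUDA-enabled PyTorch build is assumed, with CPU fallback.

```python
# Sketch of the fixed use case: entropic GW on GPU tensors.
import torch
import ot

device = "cuda" if torch.cuda.is_available() else "cpu"

n = 10
x = torch.rand(n, 2, device=device)
C1 = torch.cdist(x, x)
C2 = torch.cdist(x + 1.0, x + 1.0)
C1 = 0.5 * (C1 + C1.T)  # symmetrize explicitly to avoid float noise
C2 = 0.5 * (C2 + C2.T)
p = torch.full((n,), 1.0 / n, device=device)
q = torch.full((n,), 1.0 / n, device=device)

# Before this commit, the internal symmetry check used np.allclose and
# crashed on CUDA tensors; now the coupling stays on the input device.
T = ot.gromov.entropic_gromov_wasserstein(C1, C2, p, q, 'square_loss', epsilon=0.1)
print(T.device)
```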

ot/gromov/_bregman.py

Lines changed: 1 addition & 4 deletions
```diff
@@ -11,9 +11,6 @@
 #
 # License: MIT License

-import numpy as np
-
-
 from ..bregman import sinkhorn
 from ..utils import dist, list_to_array, check_random_state
 from ..backend import get_backend
@@ -109,7 +106,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, symmetric=None,
     T = G0
     constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun, nx)
     if symmetric is None:
-        symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
+        symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10)
     if not symmetric:
         constCt, hC1t, hC2t = init_matrix(C1.T, C2.T, p, q, loss_fun, nx)
     cpt = 0
```
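The one-line change above is the core of the `entropic_gromov_wasserstein` fix: `np.allclose` converts its inputs to NumPy arrays, which raises a `TypeError` for CUDA tensors, whereas the backend's `nx.allclose` dispatches to the active backend. A minimal sketch of the difference, assuming a CUDA-enabled PyTorch build (`get_backend` and `nx.allclose` are the POT backend utilities used in the patched code):

```python
import numpy as np
import torch
from ot.backend import get_backend

C1 = torch.rand(5, 5, device="cuda")  # requires a GPU; illustration only

# np.allclose implicitly converts the CUDA tensor to a NumPy array and
# raises a TypeError (CUDA tensors cannot be converted without .cpu()).
try:
    np.allclose(C1, C1.T, atol=1e-10)
except TypeError as e:
    print(e)

# The backend-aware check stays on the GPU and works for any backend:
nx = get_backend(C1)
print(nx.allclose(C1, C1.T, atol=1e-10))
```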

ot/optim.py

Lines changed: 29 additions & 9 deletions
```diff
@@ -35,6 +35,9 @@ def line_search_armijo(
     Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
     armijo conditions.

+    .. note:: If the loss function f returns a float (resp. a 1d array) then
+        the returned alpha and fa are float (resp. 1d arrays).
+
     Parameters
     ----------
     f : callable
@@ -45,7 +48,7 @@ def line_search_armijo(
        descent direction
    gfk : array-like
        gradient of `f` at :math:`x_k`
-    old_fval : float
+    old_fval : float or 1d array
        loss value at :math:`x_k`
    args : tuple, optional
        arguments given to `f`
@@ -61,42 +64,59 @@ def line_search_armijo(
        If let to its default value None, a backend test will be conducted.
    Returns
    -------
-    alpha : float
+    alpha : float or 1d array
        step that satisfy armijo conditions
    fc : int
        nb of function call
-    fa : float
+    fa : float or 1d array
        loss value at step alpha

    """
    if nx is None:
        xk, pk, gfk = list_to_array(xk, pk, gfk)
-        nx = get_backend(xk, pk)
+        xk0, pk0 = xk, pk
+        nx = get_backend(xk0, pk0)
+    else:
+        xk0, pk0 = xk, pk

    if len(xk.shape) == 0:
        xk = nx.reshape(xk, (-1,))

+    xk = nx.to_numpy(xk)
+    pk = nx.to_numpy(pk)
+    gfk = nx.to_numpy(gfk)
+
    fc = [0]

    def phi(alpha1):
+        # The callable function operates on nx backend
        fc[0] += 1
-        return f(xk + alpha1 * pk, *args)
+        alpha10 = nx.from_numpy(alpha1)
+        fval = f(xk0 + alpha10 * pk0, *args)
+        if type(fval) is float:
+            # prevent bug from nx.to_numpy that can look for .cpu or .gpu
+            return fval
+        else:
+            return nx.to_numpy(fval)

    if old_fval is None:
        phi0 = phi(0.)
-    else:
+    elif type(old_fval) is float:
+        # prevent bug from nx.to_numpy that can look for .cpu or .gpu
        phi0 = old_fval
+    else:
+        phi0 = nx.to_numpy(old_fval)

-    derphi0 = nx.sum(pk * gfk)  # Quickfix for matrices
+    derphi0 = np.sum(pk * gfk)  # Quickfix for matrices
    alpha, phi1 = scalar_search_armijo(
        phi, phi0, derphi0, c1=c1, alpha0=alpha0)

    if alpha is None:
-        return 0., fc[0], phi0
+        return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
    else:
        if alpha_min is not None or alpha_max is not None:
            alpha = np.clip(alpha, alpha_min, alpha_max)
-        return float(alpha), fc[0], phi1
+        return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)


 def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
```
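Taken together, these changes make `line_search_armijo` backend-safe: inputs are copied to NumPy for SciPy's `scalar_search_armijo`, the loss is still evaluated on the original backend arrays, and `alpha`/`fa` are converted back with `type_as` so they land on the caller's device. A minimal sketch of a GPU call, mirroring the new test further below (toy quadratic loss; CUDA-enabled PyTorch assumed, with CPU fallback):

```python
import torch
import ot

device = "cuda" if torch.cuda.is_available() else "cpu"

def f(x):
    # toy quadratic loss; returns a backend array, not a float
    return torch.sum((x - 5.0) ** 2)

xk = torch.full((2, 2), -5.0, device=device)
pk = torch.full((2, 2), 100.0, device=device)
gfk = 2 * (xk - 5.0)  # gradient of f at xk
old_fval = f(xk)

# alpha and fa come back as tensors on the same device/dtype as xk
# (they would be plain floats if f returned a float).
alpha, fc, fa = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
print(alpha.device, fa.device)
```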

test/test_gromov.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -214,6 +214,7 @@ def test_gromov2_gradients():
             C11 = torch.tensor(C1, requires_grad=True, device=device)
             C12 = torch.tensor(C2, requires_grad=True, device=device)

+            # Test with exact line-search
             val = ot.gromov_wasserstein2(C11, C12, p1, q1)

             val.backward()
@@ -224,6 +225,21 @@ def test_gromov2_gradients():
             assert C11.shape == C11.grad.shape
             assert C12.shape == C12.grad.shape

+            # Test with armijo line-search
+            q1.grad = None
+            p1.grad = None
+            C11.grad = None
+            C12.grad = None
+            val = ot.gromov_wasserstein2(C11, C12, p1, q1, armijo=True)
+
+            val.backward()
+
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
+

 def test_gw_helper_backend(nx):
     n_samples = 20  # nb samples
```

test/test_optim.py

Lines changed: 52 additions & 3 deletions
```diff
@@ -135,16 +135,18 @@ def test_line_search_armijo(nx):
     xk = np.array([[0.25, 0.25], [0.25, 0.25]])
     pk = np.array([[-0.25, 0.25], [0.25, -0.25]])
     gfk = np.array([[23.04273441, 23.0449082], [23.04273441, 23.0449082]])
-    old_fval = -123
+    old_fval = -123.

     xkb, pkb, gfkb = nx.from_numpy(xk, pk, gfk)

+    def f(x):
+        return 1.
     # Should not throw an exception and return 0. for alpha
     alpha, a, b = ot.optim.line_search_armijo(
-        lambda x: 1, xkb, pkb, gfkb, old_fval
+        f, xkb, pkb, gfkb, old_fval
     )
     alpha_np, anp, bnp = ot.optim.line_search_armijo(
-        lambda x: 1, xk, pk, gfk, old_fval
+        f, xk, pk, gfk, old_fval
     )
     assert a == anp
     assert b == bnp
@@ -182,3 +184,50 @@ def grad(x):
     old_fval = f(xk)
     alpha, _, _ = ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval)
     np.testing.assert_allclose(alpha, 0.1)
+
+
+def test_line_search_armijo_dtype_device(nx):
+    for tp in nx.__type_list__:
+        def f(x):
+            return nx.sum((x - 5.0) ** 2)
+
+        def grad(x):
+            return 2 * (x - 5.0)
+
+        xk = np.array([[[-5.0, -5.0]]])
+        pk = np.array([[[100.0, 100.0]]])
+        xkb, pkb = nx.from_numpy(xk, pk, type_as=tp)
+        gfkb = grad(xkb)
+        old_fval = f(xkb)
+
+        # check the case where the optimum is on the direction
+        alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+        alpha = nx.to_numpy(alpha)
+        np.testing.assert_allclose(alpha, 0.1)
+        nx.assert_same_dtype_device(old_fval, fval)
+
+        # check the case where the direction is not far enough
+        pk = np.array([[[3.0, 3.0]]])
+        pkb = nx.from_numpy(pk, type_as=tp)
+        alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval, alpha0=1.0)
+        alpha = nx.to_numpy(alpha)
+        np.testing.assert_allclose(alpha, 1.0)
+        nx.assert_same_dtype_device(old_fval, fval)
+
+        # check the case where checking the wrong direction
+        alpha, _, fval = ot.optim.line_search_armijo(f, xkb, -pkb, gfkb, old_fval)
+        alpha = nx.to_numpy(alpha)
+
+        assert alpha <= 0
+        nx.assert_same_dtype_device(old_fval, fval)
+
+        # check the case where the point is not a vector
+        xkb = nx.from_numpy(np.array(-5.0), type_as=tp)
+        pkb = nx.from_numpy(np.array(100), type_as=tp)
+        gfkb = grad(xkb)
+        old_fval = f(xkb)
+        alpha, _, fval = ot.optim.line_search_armijo(f, xkb, pkb, gfkb, old_fval)
+        alpha = nx.to_numpy(alpha)
+
+        np.testing.assert_allclose(alpha, 0.1)
+        nx.assert_same_dtype_device(old_fval, fval)
```
