Skip to content

Commit 0411ea2

Browse files
rflamaryagramfort
andauthored
[MRG] New API for OT solver (with pre-computed ground cost matrix) (#388)
* new API for OT solver * use itertools for product of parameters * add tests for result class * add tests for result class * add tests for result class last time? * add sinkhorn * make partial OT bckend compatible * add TV as unbalanced flavor * better tests * make smoth backend compatible and add l2 tregularizatio to solve * add reularizedd unbalanced * add test for more complex attibutes * add test for more complex attibutes * add generic unbalaned solver and implement it for ot.solve * add entropy to possible regularization * star of documentation for ot.solv * weird new pep8 * documenttaion for function ot.solve done * pep8 * Update ot/solvers.py Co-authored-by: Alexandre Gramfort <[email protected]> * update release file * Apply suggestions from code review Co-authored-by: Alexandre Gramfort <[email protected]> * add test NotImplemented * pep8 * pep8gcmp pep8! * compute kl in backend * debug tensorflow kl backend Co-authored-by: Alexandre Gramfort <[email protected]>
1 parent 8490196 commit 0411ea2

13 files changed

+1011
-18
lines changed

RELEASES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
88
- Added Free Support Sinkhorn Barycenter + example (PR #387)
9+
- New API for OT solver using function `ot.solve` (PR #388)
10+
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
11+
912

1013
#### Closed issues
1114

ot/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from . import regpath
3535
from . import weak
3636
from . import factored
37+
from . import solvers
3738

3839
# OT functions
3940
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,7 +47,7 @@
4647
gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
4748
from .weak import weak_optimal_transport
4849
from .factored import factored_optimal_transport
49-
50+
from .solvers import solve
5051

5152
# utils functions
5253
from .utils import dist, unif, tic, toc, toq
@@ -61,5 +62,5 @@
6162
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
6263
'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
6364
'max_sliced_wasserstein_distance', 'weak_optimal_transport',
64-
'factored_optimal_transport',
65-
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']
65+
'factored_optimal_transport', 'solve',
66+
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers']

ot/backend.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,21 @@ def sqrtm(self, a):
854854
"""
855855
raise NotImplementedError()
856856

857+
def kl_div(self, p, q, eps=1e-16):
858+
r"""
859+
Computes the Kullback-Leibler divergence.
860+
861+
This function follows the api from :any:`scipy.stats.entropy`.
862+
863+
Parameter eps is used to avoid numerical errors and is added in the log.
864+
865+
.. math::
866+
KL(p,q) = \sum_i p(i) \log (\frac{p(i)}{q(i)}+\epsilon)
867+
868+
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html
869+
"""
870+
raise NotImplementedError()
871+
857872
def isfinite(self, a):
858873
r"""
859874
Tests element-wise for finiteness (not infinity and not Not a Number).
@@ -1158,6 +1173,9 @@ def inv(self, a):
11581173
def sqrtm(self, a):
11591174
return scipy.linalg.sqrtm(a)
11601175

1176+
def kl_div(self, p, q, eps=1e-16):
1177+
return np.sum(p * np.log(p / q + eps))
1178+
11611179
def isfinite(self, a):
11621180
return np.isfinite(a)
11631181

@@ -1481,6 +1499,9 @@ def sqrtm(self, a):
14811499
L, V = jnp.linalg.eigh(a)
14821500
return (V * jnp.sqrt(L)[None, :]) @ V.T
14831501

1502+
def kl_div(self, p, q, eps=1e-16):
1503+
return jnp.sum(p * jnp.log(p / q + eps))
1504+
14841505
def isfinite(self, a):
14851506
return jnp.isfinite(a)
14861507

@@ -1901,6 +1922,9 @@ def sqrtm(self, a):
19011922
L, V = torch.linalg.eigh(a)
19021923
return (V * torch.sqrt(L)[None, :]) @ V.T
19031924

1925+
def kl_div(self, p, q, eps=1e-16):
1926+
return torch.sum(p * torch.log(p / q + eps))
1927+
19041928
def isfinite(self, a):
19051929
return torch.isfinite(a)
19061930

@@ -2248,6 +2272,9 @@ def sqrtm(self, a):
22482272
L, V = cp.linalg.eigh(a)
22492273
return (V * self.sqrt(L)[None, :]) @ V.T
22502274

2275+
def kl_div(self, p, q, eps=1e-16):
2276+
return cp.sum(p * cp.log(p / q + eps))
2277+
22512278
def isfinite(self, a):
22522279
return cp.isfinite(a)
22532280

@@ -2608,6 +2635,9 @@ def inv(self, a):
26082635
def sqrtm(self, a):
26092636
return tf.linalg.sqrtm(a)
26102637

2638+
def kl_div(self, p, q, eps=1e-16):
2639+
return tnp.sum(p * tnp.log(p / q + eps))
2640+
26112641
def isfinite(self, a):
26122642
return tnp.isfinite(a)
26132643

ot/partial.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import numpy as np
1010
from .lp import emd
11+
from .backend import get_backend
12+
from .utils import list_to_array
1113

1214

1315
def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
@@ -114,14 +116,22 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
114116
ot.partial.partial_wasserstein : Partial Wasserstein with fixed mass
115117
"""
116118

117-
if np.sum(a) > 1 or np.sum(b) > 1:
119+
a, b, M = list_to_array(a, b, M)
120+
121+
nx = get_backend(a, b, M)
122+
123+
if nx.sum(a) > 1 or nx.sum(b) > 1:
118124
raise ValueError("Problem infeasible. Check that a and b are in the "
119125
"simplex")
120126

121127
if reg_m is None:
122-
reg_m = np.max(M) + 1
123-
if reg_m < -np.max(M):
124-
return np.zeros((len(a), len(b)))
128+
reg_m = float(nx.max(M)) + 1
129+
if reg_m < -nx.max(M):
130+
return nx.zeros((len(a), len(b)), type_as=M)
131+
132+
a0, b0, M0 = a, b, M
133+
# convert to humpy
134+
a, b, M = nx.to_numpy(a, b, M)
125135

126136
eps = 1e-20
127137
M = np.asarray(M, dtype=np.float64)
@@ -149,10 +159,16 @@ def partial_wasserstein_lagrange(a, b, M, reg_m=None, nb_dummies=1, log=False,
149159
gamma = np.zeros((len(a), len(b)))
150160
gamma[np.ix_(idx_x, idx_y)] = gamma_extended[:-nb_dummies, :-nb_dummies]
151161

162+
# convert back to backend
163+
gamma = nx.from_numpy(gamma, type_as=M0)
164+
152165
if log_emd['warning'] is not None:
153166
raise ValueError("Error in the EMD resolution: try to increase the"
154167
" number of dummy points")
155-
log_emd['cost'] = np.sum(gamma * M)
168+
log_emd['cost'] = nx.sum(gamma * M0)
169+
log_emd['u'] = nx.from_numpy(log_emd['u'], type_as=a0)
170+
log_emd['v'] = nx.from_numpy(log_emd['v'], type_as=b0)
171+
156172
if log:
157173
return gamma, log_emd
158174
else:
@@ -250,15 +266,23 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
250266
entropic regularization parameter
251267
"""
252268

269+
a, b, M = list_to_array(a, b, M)
270+
271+
nx = get_backend(a, b, M)
272+
253273
if m is None:
254274
return partial_wasserstein_lagrange(a, b, M, log=log, **kwargs)
255275
elif m < 0:
256276
raise ValueError("Problem infeasible. Parameter m should be greater"
257277
" than 0.")
258-
elif m > np.min((np.sum(a), np.sum(b))):
278+
elif m > nx.min((nx.sum(a), nx.sum(b))):
259279
raise ValueError("Problem infeasible. Parameter m should lower or"
260280
" equal than min(|a|_1, |b|_1).")
261281

282+
a0, b0, M0 = a, b, M
283+
# convert to humpy
284+
a, b, M = nx.to_numpy(a, b, M)
285+
262286
b_extended = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
263287
a_extended = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
264288
M_extended = np.zeros((len(a_extended), len(b_extended)))
@@ -267,15 +291,20 @@ def partial_wasserstein(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):
267291

268292
gamma, log_emd = emd(a_extended, b_extended, M_extended, log=True,
269293
**kwargs)
294+
295+
gamma = nx.from_numpy(gamma[:len(a), :len(b)], type_as=M)
296+
270297
if log_emd['warning'] is not None:
271298
raise ValueError("Error in the EMD resolution: try to increase the"
272299
" number of dummy points")
273-
log_emd['partial_w_dist'] = np.sum(M * gamma[:len(a), :len(b)])
300+
log_emd['partial_w_dist'] = nx.sum(M0 * gamma)
301+
log_emd['u'] = nx.from_numpy(log_emd['u'][:len(a)], type_as=a0)
302+
log_emd['v'] = nx.from_numpy(log_emd['v'][:len(b)], type_as=b0)
274303

275304
if log:
276-
return gamma[:len(a), :len(b)], log_emd
305+
return gamma, log_emd
277306
else:
278-
return gamma[:len(a), :len(b)]
307+
return gamma
279308

280309

281310
def partial_wasserstein2(a, b, M, m=None, nb_dummies=1, log=False, **kwargs):

ot/smooth.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import numpy as np
4646
from scipy.optimize import minimize
47+
from .backend import get_backend
4748

4849

4950
def projection_simplex(V, z=1, axis=None):
@@ -511,22 +512,28 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
511512
512513
"""
513514

515+
nx = get_backend(a, b, M)
516+
514517
if reg_type.lower() in ['l2', 'squaredl2']:
515518
regul = SquaredL2(gamma=reg)
516519
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
517520
regul = NegEntropy(gamma=reg)
518521
else:
519522
raise NotImplementedError('Unknown regularization')
520523

524+
a0, b0, M0 = a, b, M
525+
# convert to humpy
526+
a, b, M = nx.to_numpy(a, b, M)
527+
521528
# solve dual
522529
alpha, beta, res = solve_dual(a, b, M, regul, max_iter=numItermax,
523530
tol=stopThr, verbose=verbose)
524531

525532
# reconstruct transport matrix
526-
G = get_plan_from_dual(alpha, beta, M, regul)
533+
G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), type_as=M0)
527534

528535
if log:
529-
log = {'alpha': alpha, 'beta': beta, 'res': res}
536+
log = {'alpha': nx.from_numpy(alpha, type_as=a0), 'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
530537
return G, log
531538
else:
532539
return G

0 commit comments

Comments
 (0)