Skip to content

Commit 566a0fc

Browse files
committed
made barycenter_solvers and network_simplex hidden + deprecated ot.lp.cvx
1 parent 3e3b444 commit 566a0fc

File tree

5 files changed

+163
-149
lines changed

5 files changed

+163
-149
lines changed

Diff for: RELEASES.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
- Implement CG solvers for partial FGW (PR #687)
77
- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)
88
- Automatic PR labeling and release file update check (PR #704)
9-
- Reorganize sub-module `ot/lp/__init__.py` into separate files. (PR #714) (PR #714)
9+
- Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714)
1010

1111
#### Closed issues
1212
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)

Diff for: ot/lp/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
# License: MIT License
1010

1111
from . import cvx
12-
from .cvx import barycenter
1312
from .dmmot import dmmot_monge_1dgrid_loss, dmmot_monge_1dgrid_optimize
14-
from .network_simplex import emd, emd2
15-
from .barycenter_solvers import (
13+
from ._network_simplex import emd, emd2
14+
from ._barycenter_solvers import (
15+
barycenter,
1616
free_support_barycenter,
1717
generalized_free_support_barycenter,
1818
)

Diff for: ot/lp/barycenter_solvers.py renamed to ot/lp/_barycenter_solvers.py

+155-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,160 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
OT Barycenter Solvers
4+
"""
5+
6+
# Author: Remi Flamary <[email protected]>
7+
# Eloi Tanguy <[email protected]>
8+
#
9+
# License: MIT License
10+
111
from ..backend import get_backend
212
from ..utils import dist
3-
from .network_simplex import emd
13+
from ._network_simplex import emd
14+
15+
import numpy as np
16+
import scipy as sp
17+
import scipy.sparse as sps
18+
19+
try:
20+
import cvxopt # for cvxopt barycenter solver
21+
from cvxopt import solvers, matrix, spmatrix
22+
except ImportError:
23+
cvxopt = False
24+
25+
26+
def scipy_sparse_to_spmatrix(A):
27+
"""Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
28+
coo = A.tocoo()
29+
SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
30+
return SP
31+
32+
33+
def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"):
34+
r"""Compute the Wasserstein barycenter of distributions A
35+
36+
The function solves the following optimization problem [16]:
37+
38+
.. math::
39+
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
40+
41+
where :
42+
43+
- :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
44+
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
45+
46+
The linear program is solved using the interior point solver from scipy.optimize.
47+
If cvxopt solver if installed it can use cvxopt
48+
49+
Note that this problem do not scale well (both in memory and computational time).
50+
51+
Parameters
52+
----------
53+
A : np.ndarray (d,n)
54+
n training distributions a_i of size d
55+
M : np.ndarray (d,d)
56+
loss matrix for OT
57+
reg : float
58+
Regularization term >0
59+
weights : np.ndarray (n,)
60+
Weights of each histogram a_i on the simplex (barycentric coordinates)
61+
verbose : bool, optional
62+
Print information along iterations
63+
log : bool, optional
64+
record log if True
65+
solver : string, optional
66+
the solver used, default 'interior-point' use the lp solver from
67+
scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
68+
69+
Returns
70+
-------
71+
a : (d,) ndarray
72+
Wasserstein barycenter
73+
log : dict
74+
log dictionary return only if log==True in parameters
75+
76+
77+
References
78+
----------
79+
80+
.. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
81+
82+
83+
"""
84+
85+
if weights is None:
86+
weights = np.ones(A.shape[1]) / A.shape[1]
87+
else:
88+
assert len(weights) == A.shape[1]
89+
90+
n_distributions = A.shape[1]
91+
n = A.shape[0]
92+
93+
n2 = n * n
94+
c = np.zeros((0))
95+
b_eq1 = np.zeros((0))
96+
for i in range(n_distributions):
97+
c = np.concatenate((c, M.ravel() * weights[i]))
98+
b_eq1 = np.concatenate((b_eq1, A[:, i]))
99+
c = np.concatenate((c, np.zeros(n)))
100+
101+
lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
102+
# row constraints
103+
A_eq1 = sps.hstack(
104+
(sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))
105+
)
106+
107+
# columns constraints
108+
lst_idiag2 = []
109+
lst_eye = []
110+
for i in range(n_distributions):
111+
if i == 0:
112+
lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
113+
lst_eye.append(-sps.eye(n))
114+
else:
115+
lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
116+
lst_eye.append(-sps.eye(n - 1, n))
117+
118+
A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
119+
b_eq2 = np.zeros((A_eq2.shape[0]))
120+
121+
# full problem
122+
A_eq = sps.vstack((A_eq1, A_eq2))
123+
b_eq = np.concatenate((b_eq1, b_eq2))
124+
125+
if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]:
126+
# cvxopt not installed or interior point
127+
128+
if solver is None:
129+
solver = "interior-point"
130+
131+
options = {"disp": verbose}
132+
sol = sp.optimize.linprog(
133+
c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options
134+
)
135+
x = sol.x
136+
b = x[-n:]
137+
138+
else:
139+
h = np.zeros((n_distributions * n2 + n))
140+
G = -sps.eye(n_distributions * n2 + n)
141+
142+
sol = solvers.lp(
143+
matrix(c),
144+
scipy_sparse_to_spmatrix(G),
145+
matrix(h),
146+
A=scipy_sparse_to_spmatrix(A_eq),
147+
b=matrix(b_eq),
148+
solver=solver,
149+
)
150+
151+
x = np.array(sol["x"])
152+
b = x[-n:].ravel()
153+
154+
if log:
155+
return b, sol
156+
else:
157+
return b
4158

5159

6160
def free_support_barycenter(
File renamed without changes.

Diff for: ot/lp/cvx.py

+4-144
Original file line numberDiff line numberDiff line change
@@ -1,152 +1,12 @@
11
# -*- coding: utf-8 -*-
22
"""
3-
LP solvers for optimal transport using cvxopt
3+
(DEPRECATED) LP solvers for optimal transport using cvxopt
44
"""
55

66
# Author: Remi Flamary <[email protected]>
77
#
88
# License: MIT License
99

10-
import numpy as np
11-
import scipy as sp
12-
import scipy.sparse as sps
13-
14-
try:
15-
import cvxopt
16-
from cvxopt import solvers, matrix, spmatrix
17-
except ImportError:
18-
cvxopt = False
19-
20-
21-
def scipy_sparse_to_spmatrix(A):
22-
"""Efficient conversion from scipy sparse matrix to cvxopt sparse matrix"""
23-
coo = A.tocoo()
24-
SP = spmatrix(coo.data.tolist(), coo.row.tolist(), coo.col.tolist(), size=A.shape)
25-
return SP
26-
27-
28-
def barycenter(A, M, weights=None, verbose=False, log=False, solver="highs-ipm"):
29-
r"""Compute the Wasserstein barycenter of distributions A
30-
31-
The function solves the following optimization problem [16]:
32-
33-
.. math::
34-
\mathbf{a} = arg\min_\mathbf{a} \sum_i W_{1}(\mathbf{a},\mathbf{a}_i)
35-
36-
where :
37-
38-
- :math:`W_1(\cdot,\cdot)` is the Wasserstein distance (see ot.emd.sinkhorn)
39-
- :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
40-
41-
The linear program is solved using the interior point solver from scipy.optimize.
42-
If cvxopt solver if installed it can use cvxopt
43-
44-
Note that this problem do not scale well (both in memory and computational time).
45-
46-
Parameters
47-
----------
48-
A : np.ndarray (d,n)
49-
n training distributions a_i of size d
50-
M : np.ndarray (d,d)
51-
loss matrix for OT
52-
reg : float
53-
Regularization term >0
54-
weights : np.ndarray (n,)
55-
Weights of each histogram a_i on the simplex (barycentric coordinates)
56-
verbose : bool, optional
57-
Print information along iterations
58-
log : bool, optional
59-
record log if True
60-
solver : string, optional
61-
the solver used, default 'interior-point' use the lp solver from
62-
scipy.optimize. None, or 'glpk' or 'mosek' use the solver from cvxopt.
63-
64-
Returns
65-
-------
66-
a : (d,) ndarray
67-
Wasserstein barycenter
68-
log : dict
69-
log dictionary return only if log==True in parameters
70-
71-
72-
References
73-
----------
74-
75-
.. [16] Agueh, M., & Carlier, G. (2011). Barycenters in the Wasserstein space. SIAM Journal on Mathematical Analysis, 43(2), 904-924.
76-
77-
78-
"""
79-
80-
if weights is None:
81-
weights = np.ones(A.shape[1]) / A.shape[1]
82-
else:
83-
assert len(weights) == A.shape[1]
84-
85-
n_distributions = A.shape[1]
86-
n = A.shape[0]
87-
88-
n2 = n * n
89-
c = np.zeros((0))
90-
b_eq1 = np.zeros((0))
91-
for i in range(n_distributions):
92-
c = np.concatenate((c, M.ravel() * weights[i]))
93-
b_eq1 = np.concatenate((b_eq1, A[:, i]))
94-
c = np.concatenate((c, np.zeros(n)))
95-
96-
lst_idiag1 = [sps.kron(sps.eye(n), np.ones((1, n))) for i in range(n_distributions)]
97-
# row constraints
98-
A_eq1 = sps.hstack(
99-
(sps.block_diag(lst_idiag1), sps.coo_matrix((n_distributions * n, n)))
100-
)
101-
102-
# columns constraints
103-
lst_idiag2 = []
104-
lst_eye = []
105-
for i in range(n_distributions):
106-
if i == 0:
107-
lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n)))
108-
lst_eye.append(-sps.eye(n))
109-
else:
110-
lst_idiag2.append(sps.kron(np.ones((1, n)), sps.eye(n - 1, n)))
111-
lst_eye.append(-sps.eye(n - 1, n))
112-
113-
A_eq2 = sps.hstack((sps.block_diag(lst_idiag2), sps.vstack(lst_eye)))
114-
b_eq2 = np.zeros((A_eq2.shape[0]))
115-
116-
# full problem
117-
A_eq = sps.vstack((A_eq1, A_eq2))
118-
b_eq = np.concatenate((b_eq1, b_eq2))
119-
120-
if not cvxopt or solver in ["interior-point", "highs", "highs-ipm", "highs-ds"]:
121-
# cvxopt not installed or interior point
122-
123-
if solver is None:
124-
solver = "interior-point"
125-
126-
options = {"disp": verbose}
127-
sol = sp.optimize.linprog(
128-
c, A_eq=A_eq, b_eq=b_eq, method=solver, options=options
129-
)
130-
x = sol.x
131-
b = x[-n:]
132-
133-
else:
134-
h = np.zeros((n_distributions * n2 + n))
135-
G = -sps.eye(n_distributions * n2 + n)
136-
137-
sol = solvers.lp(
138-
matrix(c),
139-
scipy_sparse_to_spmatrix(G),
140-
matrix(h),
141-
A=scipy_sparse_to_spmatrix(A_eq),
142-
b=matrix(b_eq),
143-
solver=solver,
144-
)
145-
146-
x = np.array(sol["x"])
147-
b = x[-n:].ravel()
148-
149-
if log:
150-
return b, sol
151-
else:
152-
return b
10+
print(
11+
"The module ot.lp.cvx is deprecated and will be removed in future versions. The function `barycenter` was moved to ot.lp._barycenter_solvers and can be importer via ot.lp."
12+
)

0 commit comments

Comments
 (0)