Skip to content

Commit a53ede9

Browse files
authored
Merge pull request #29 from arolet/ot_dual_variables
Dual variables in EMD_wrapper
2 parents 62dcfbf + e52b6eb commit a53ede9

File tree

7 files changed

+281
-192
lines changed

7 files changed

+281
-192
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,12 @@ The contributors to this library are:
138138
* [Léo Gautheron](https://github.com/aje) (GPU implementation)
139139
* [Nathalie Gayraud](https://www.linkedin.com/in/nathalie-t-h-gayraud/?ppe=1)
140140
* [Stanislas Chambon](https://slasnista.github.io/)
141+
* [Antoine Rolet](https://arolet.github.io/)
141142

142143
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
143144

144145
* [Gabriel Peyré](http://gpeyre.github.io/) (Wasserstein Barycenters in Matlab)
145146
* [Nicolas Bonneel](http://liris.cnrs.fr/~nbonneel/) ( C++ code for EMD)
146-
* [Antoine Rolet](https://arolet.github.io/) ( Mex file for EMD )
147147
* [Marco Cuturi](http://marcocuturi.net/) (Sinkhorn Knopp in Matlab/Cuda)
148148

149149

ot/lp/EMD.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ typedef unsigned int node_id_type;
2626
enum ProblemType {
2727
INFEASIBLE,
2828
OPTIMAL,
29-
UNBOUNDED
29+
UNBOUNDED,
30+
MAX_ITER_REACHED
3031
};
3132

32-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter);
33+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
3334

3435
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 31 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -15,104 +15,92 @@
1515
#include "EMD.h"
1616

1717

18-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) {
18+
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
19+
double* alpha, double* beta, double *cost, int maxIter) {
1920
// beware M and C anre strored in row major C style!!!
20-
int n, m, i,cur;
21-
double max;
21+
int n, m, i, cur;
2222

2323
typedef FullBipartiteDigraph Digraph;
2424
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
2525

2626
// Get the number of non zero coordinates for r and c
2727
n=0;
28-
for (node_id_type i=0; i<n1; i++) {
28+
for (int i=0; i<n1; i++) {
2929
double val=*(X+i);
3030
if (val>0) {
3131
n++;
32-
}
32+
}else if(val<0){
33+
return INFEASIBLE;
34+
}
3335
}
3436
m=0;
35-
for (node_id_type i=0; i<n2; i++) {
37+
for (int i=0; i<n2; i++) {
3638
double val=*(Y+i);
3739
if (val>0) {
3840
m++;
39-
}
41+
}else if(val<0){
42+
return INFEASIBLE;
43+
}
4044
}
4145

42-
4346
// Define the graph
4447

4548
std::vector<int> indI(n), indJ(m);
4649
std::vector<double> weights1(n), weights2(m);
4750
Digraph di(n, m);
48-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
51+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, maxIter);
4952

5053
// Set supply and demand, don't account for 0 values (faster)
5154

52-
max=0;
5355
cur=0;
54-
for (node_id_type i=0; i<n1; i++) {
56+
for (int i=0; i<n1; i++) {
5557
double val=*(X+i);
5658
if (val>0) {
57-
weights1[ di.nodeFromId(cur) ] = val;
58-
max+=val;
59+
weights1[ cur ] = val;
5960
indI[cur++]=i;
6061
}
6162
}
6263

6364
// Demand is actually negative supply...
6465

65-
max=0;
6666
cur=0;
67-
for (node_id_type i=0; i<n2; i++) {
67+
for (int i=0; i<n2; i++) {
6868
double val=*(Y+i);
6969
if (val>0) {
70-
weights2[ di.nodeFromId(cur) ] = -val;
70+
weights2[ cur ] = -val;
7171
indJ[cur++]=i;
72-
73-
max-=val;
7472
}
7573
}
7674

7775

7876
net.supplyMap(&weights1[0], n, &weights2[0], m);
7977

8078
// Set the cost of each edge
81-
max=0;
82-
for (node_id_type i=0; i<n; i++) {
83-
for (node_id_type j=0; j<m; j++) {
79+
for (int i=0; i<n; i++) {
80+
for (int j=0; j<m; j++) {
8481
double val=*(D+indI[i]*n2+indJ[j]);
8582
net.setCost(di.arcFromId(i*m+j), val);
86-
if (val>max) {
87-
max=val;
88-
}
8983
}
9084
}
9185

9286

9387
// Solve the problem with the network simplex algorithm
9488

9589
int ret=net.run();
96-
if (ret!=(int)net.OPTIMAL) {
97-
if (ret==(int)net.INFEASIBLE) {
98-
std::cout << "Infeasible problem";
90+
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
91+
*cost = 0;
92+
Arc a; di.first(a);
93+
for (; a != INVALID; di.next(a)) {
94+
int i = di.source(a);
95+
int j = di.target(a);
96+
double flow = net.flow(a);
97+
*cost += flow * (*(D+indI[i]*n2+indJ[j-n]));
98+
*(G+indI[i]*n2+indJ[j-n]) = flow;
99+
*(alpha + indI[i]) = -net.potential(i);
100+
*(beta + indJ[j-n]) = net.potential(j);
99101
}
100-
if (ret==(int)net.UNBOUNDED)
101-
{
102-
std::cout << "Unbounded problem";
103-
}
104-
} else
105-
{
106-
for (node_id_type i=0; i<n; i++)
107-
{
108-
for (node_id_type j=0; j<m; j++)
109-
{
110-
*(G+indI[i]*n2+indJ[j]) = net.flow(di.arcFromId(i*m+j));
111-
}
112-
};
113-
*cost = net.totalCost();
114-
115-
};
102+
103+
}
116104

117105

118106
return ret;

ot/lp/__init__.py

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
#
88
# License: MIT License
99

10+
import multiprocessing
11+
1012
import numpy as np
13+
1114
# import compiled emd
12-
from .emd_wrap import emd_c, emd2_c
15+
from .emd_wrap import emd_c, check_result
1316
from ..utils import parmap
14-
import multiprocessing
1517

1618

17-
def emd(a, b, M, numItermax=100000):
19+
def emd(a, b, M, numItermax=100000, log=False):
1820
"""Solves the Earth Movers distance problem and returns the OT matrix
1921
2022
@@ -42,11 +44,17 @@ def emd(a, b, M, numItermax=100000):
4244
numItermax : int, optional (default=100000)
4345
The maximum number of iterations before stopping the optimization
4446
algorithm if it has not converged.
47+
log: boolean, optional (default=False)
48+
If True, returns a dictionary containing the cost and dual
49+
variables. Otherwise returns only the optimal transportation matrix.
4550
4651
Returns
4752
-------
4853
gamma: (ns x nt) ndarray
4954
Optimal transportation matrix for the given parameters
55+
log: dict
56+
If input log is true, a dictionary containing the cost and dual
57+
variables and exit status
5058
5159
5260
Examples
@@ -82,14 +90,24 @@ def emd(a, b, M, numItermax=100000):
8290

8391
# if empty array given then use unifor distributions
8492
if len(a) == 0:
85-
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
93+
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
8694
if len(b) == 0:
87-
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
88-
89-
return emd_c(a, b, M, numItermax)
90-
91-
92-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
95+
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
96+
97+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
98+
result_code_string = check_result(result_code)
99+
if log:
100+
log = {}
101+
log['cost'] = cost
102+
log['u'] = u
103+
log['v'] = v
104+
log['warning'] = result_code_string
105+
log['result_code'] = result_code
106+
return G, log
107+
return G
108+
109+
110+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False):
93111
"""Solves the Earth Movers distance problem and returns the loss
94112
95113
.. math::
@@ -116,11 +134,19 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
116134
numItermax : int, optional (default=100000)
117135
The maximum number of iterations before stopping the optimization
118136
algorithm if it has not converged.
137+
log: boolean, optional (default=False)
138+
If True, returns a dictionary containing the cost and dual
139+
variables. Otherwise returns only the optimal transportation cost.
140+
return_matrix: boolean, optional (default=False)
141+
If True, returns the optimal transportation matrix in the log.
119142
120143
Returns
121144
-------
122145
gamma: (ns x nt) ndarray
123146
Optimal transportation matrix for the given parameters
147+
log: dict
148+
If input log is true, a dictionary containing the cost and dual
149+
variables and exit status
124150
125151
126152
Examples
@@ -156,17 +182,31 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000):
156182

157183
# if empty array given then use unifor distributions
158184
if len(a) == 0:
159-
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
185+
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
160186
if len(b) == 0:
161-
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
187+
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]
162188

163-
if len(b.shape) == 1:
164-
return emd2_c(a, b, M, numItermax)
189+
if log or return_matrix:
190+
def f(b):
191+
G, cost, u, v, resultCode = emd_c(a, b, M, numItermax)
192+
result_code_string = check_result(resultCode)
193+
log = {}
194+
if return_matrix:
195+
log['G'] = G
196+
log['u'] = u
197+
log['v'] = v
198+
log['warning'] = result_code_string
199+
log['result_code'] = resultCode
200+
return [cost, log]
165201
else:
166-
nb = b.shape[1]
167-
# res = [emd2_c(a, b[:, i].copy(), M, numItermax) for i in range(nb)]
168-
169202
def f(b):
170-
return emd2_c(a, b, M, numItermax)
171-
res = parmap(f, [b[:, i] for i in range(nb)], processes)
172-
return np.array(res)
203+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax)
204+
check_result(result_code)
205+
return cost
206+
207+
if len(b.shape) == 1:
208+
return f(b)
209+
nb = b.shape[1]
210+
211+
res = parmap(f, [b[:, i] for i in range(nb)], processes)
212+
return res

0 commit comments

Comments
 (0)