Skip to content

Commit e547fe3

Browse files
authored
[MRG] Correct pointer overflow in EMD (#381)
* avoid overflow on openmp version of emd solver
* monothread version updated
* Fixed typo in readme
* added PR in releases
* typo in releases.md
* added a precision to releases.md
* added a precision to releases.md
* correct readme
* forgot to cast
* lower error
1 parent 1f30759 commit e547fe3

File tree

4 files changed

+26
-23
lines changed

4 files changed

+26
-23
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ POT provides the following generic OT solvers (links to examples):
2626
* Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
2727
* [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
2828
* Weak OT solver between empirical distributions [39]
29-
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
30-
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
29+
* Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) with LP solver (only small scale).
30+
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from Graph Dictionary Learning [38]
3131
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
3232
* [Stochastic
3333
solver](https://pythonot.github.io/auto_examples/others/plot_stochastic.html) and

RELEASES.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
- Fixed an issue where Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
1414
- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
1515
- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
16-
- Fixed an issue where the max number of iterations in ot.emd was not allow to go beyond 2^31 (PR #380)
16+
- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)
17+
- Fixed an issue where pointers would overflow in the EMD solver, returning an
18+
incomplete transport plan above a certain size (slightly above 46k, its square being
19+
roughly 2^31) (PR #381)
1720

1821

1922
## 0.8.2

ot/lp/EMD_wrapper.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
2424
// beware M and C are stored in row major C style!!!
2525

2626
using namespace lemon;
27-
int n, m, cur;
27+
uint64_t n, m, cur;
2828

2929
typedef FullBipartiteDigraph Digraph;
3030
DIGRAPH_TYPEDEFS(Digraph);
@@ -51,15 +51,15 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
5151

5252
// Define the graph
5353

54-
std::vector<int> indI(n), indJ(m);
54+
std::vector<uint64_t> indI(n), indJ(m);
5555
std::vector<double> weights1(n), weights2(m);
5656
Digraph di(n, m);
57-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
57+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter);
5858

5959
// Set supply and demand, don't account for 0 values (faster)
6060

6161
cur=0;
62-
for (int i=0; i<n1; i++) {
62+
for (uint64_t i=0; i<n1; i++) {
6363
double val=*(X+i);
6464
if (val>0) {
6565
weights1[ cur ] = val;
@@ -70,7 +70,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
7070
// Demand is actually negative supply...
7171

7272
cur=0;
73-
for (int i=0; i<n2; i++) {
73+
for (uint64_t i=0; i<n2; i++) {
7474
double val=*(Y+i);
7575
if (val>0) {
7676
weights2[ cur ] = -val;
@@ -79,12 +79,12 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
7979
}
8080

8181

82-
net.supplyMap(&weights1[0], n, &weights2[0], m);
82+
net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
8383

8484
// Set the cost of each edge
8585
int64_t idarc = 0;
86-
for (int i=0; i<n; i++) {
87-
for (int j=0; j<m; j++) {
86+
for (uint64_t i=0; i<n; i++) {
87+
for (uint64_t j=0; j<m; j++) {
8888
double val=*(D+indI[i]*n2+indJ[j]);
8989
net.setCost(di.arcFromId(idarc), val);
9090
++idarc;
@@ -95,7 +95,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
9595
// Solve the problem with the network simplex algorithm
9696

9797
int ret=net.run();
98-
int i, j;
98+
uint64_t i, j;
9999
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
100100
*cost = 0;
101101
Arc a; di.first(a);
@@ -126,7 +126,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
126126
// beware M and C are stored in row major C style!!!
127127

128128
using namespace lemon_omp;
129-
int n, m, cur;
129+
uint64_t n, m, cur;
130130

131131
typedef FullBipartiteDigraph Digraph;
132132
DIGRAPH_TYPEDEFS(Digraph);
@@ -153,15 +153,15 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
153153

154154
// Define the graph
155155

156-
std::vector<int> indI(n), indJ(m);
156+
std::vector<uint64_t> indI(n), indJ(m);
157157
std::vector<double> weights1(n), weights2(m);
158158
Digraph di(n, m);
159-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
159+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, (int) (n + m), n * m, maxIter, numThreads);
160160

161161
// Set supply and demand, don't account for 0 values (faster)
162162

163163
cur=0;
164-
for (int i=0; i<n1; i++) {
164+
for (uint64_t i=0; i<n1; i++) {
165165
double val=*(X+i);
166166
if (val>0) {
167167
weights1[ cur ] = val;
@@ -172,7 +172,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
172172
// Demand is actually negative supply...
173173

174174
cur=0;
175-
for (int i=0; i<n2; i++) {
175+
for (uint64_t i=0; i<n2; i++) {
176176
double val=*(Y+i);
177177
if (val>0) {
178178
weights2[ cur ] = -val;
@@ -181,12 +181,12 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
181181
}
182182

183183

184-
net.supplyMap(&weights1[0], n, &weights2[0], m);
184+
net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);
185185

186186
// Set the cost of each edge
187187
int64_t idarc = 0;
188-
for (int i=0; i<n; i++) {
189-
for (int j=0; j<m; j++) {
188+
for (uint64_t i=0; i<n; i++) {
189+
for (uint64_t j=0; j<m; j++) {
190190
double val=*(D+indI[i]*n2+indJ[j]);
191191
net.setCost(di.arcFromId(idarc), val);
192192
++idarc;
@@ -197,7 +197,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
197197
// Solve the problem with the network simplex algorithm
198198

199199
int ret=net.run();
200-
int i, j;
200+
uint64_t i, j;
201201
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
202202
*cost = 0;
203203
Arc a; di.first(a);

ot/lp/network_simplex_simple_omp.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
#undef EPSILON
4242
#undef _EPSILON
4343
#undef MAX_DEBUG_ITER
44-
#define EPSILON std::numeric_limits<Cost>::epsilon()*10
45-
#define _EPSILON 1e-8
44+
#define EPSILON std::numeric_limits<Cost>::epsilon()
45+
#define _EPSILON 1e-14
4646
#define MAX_DEBUG_ITER 100000
4747

4848
/// \ingroup min_cost_flow_algs

0 commit comments

Comments
 (0)