Skip to content

Commit 1669704

Browse files
authored
Merge pull request #25 from aje/master
Add iter_max to lp solver and fixes #24
2 parents a2ec6e5 + fadaf2a commit 1669704

File tree

6 files changed

+163
-100
lines changed

6 files changed

+163
-100
lines changed

ot/da.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from .bregman import sinkhorn
1515
from .lp import emd
16-
from .utils import unif, dist, kernel
16+
from .utils import unif, dist, kernel, cost_normalization
1717
from .utils import check_params, deprecated, BaseEstimator
1818
from .optim import cg
1919
from .optim import gcg
@@ -650,15 +650,16 @@ class OTDA(object):
650650
651651
"""
652652

653-
def __init__(self, metric='sqeuclidean'):
653+
def __init__(self, metric='sqeuclidean', norm=None):
654654
""" Class initialization"""
655655
self.xs = 0
656656
self.xt = 0
657657
self.G = 0
658658
self.metric = metric
659+
self.norm = norm
659660
self.computed = False
660661

661-
def fit(self, xs, xt, ws=None, wt=None, norm=None):
662+
def fit(self, xs, xt, ws=None, wt=None, max_iter=100000):
662663
"""Fit domain adaptation between samples is xs and xt
663664
(with optional weights)"""
664665
self.xs = xs
@@ -673,8 +674,8 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None):
673674
self.wt = wt
674675

675676
self.M = dist(xs, xt, metric=self.metric)
676-
self.normalizeM(norm)
677-
self.G = emd(ws, wt, self.M)
677+
self.M = cost_normalization(self.M, self.norm)
678+
self.G = emd(ws, wt, self.M, max_iter)
678679
self.computed = True
679680

680681
def interp(self, direction=1):
@@ -741,26 +742,6 @@ def predict(self, x, direction=1):
741742
# aply the delta to the interpolation
742743
return xf[idx, :] + x - x0[idx, :]
743744

744-
def normalizeM(self, norm):
745-
""" Apply normalization to the loss matrix
746-
747-
748-
Parameters
749-
----------
750-
norm : str
751-
type of normalization from 'median','max','log','loglog'
752-
753-
"""
754-
755-
if norm == "median":
756-
self.M /= float(np.median(self.M))
757-
elif norm == "max":
758-
self.M /= float(np.max(self.M))
759-
elif norm == "log":
760-
self.M = np.log(1 + self.M)
761-
elif norm == "loglog":
762-
self.M = np.log(1 + np.log(1 + self.M))
763-
764745

765746
@deprecated("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
766747
" removed in 0.5 \nUse class SinkhornTransport instead.")
@@ -772,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
772753
773754
"""
774755

775-
def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
756+
def fit(self, xs, xt, reg=1, ws=None, wt=None, **kwargs):
776757
"""Fit regularized domain adaptation between samples is xs and xt
777758
(with optional weights)"""
778759
self.xs = xs
@@ -787,7 +768,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
787768
self.wt = wt
788769

789770
self.M = dist(xs, xt, metric=self.metric)
790-
self.normalizeM(norm)
771+
self.M = cost_normalization(self.M, self.norm)
791772
self.G = sinkhorn(ws, wt, self.M, reg, **kwargs)
792773
self.computed = True
793774

@@ -799,8 +780,7 @@ class OTDA_lpl1(OTDA):
799780
"""Class for domain adaptation with optimal transport with entropic and
800781
group regularization"""
801782

802-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
803-
**kwargs):
783+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
804784
"""Fit regularized domain adaptation between samples is xs and xt
805785
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
806786
parameters"""
@@ -816,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
816796
self.wt = wt
817797

818798
self.M = dist(xs, xt, metric=self.metric)
819-
self.normalizeM(norm)
799+
self.M = cost_normalization(self.M, self.norm)
820800
self.G = sinkhorn_lpl1_mm(ws, ys, wt, self.M, reg, eta, **kwargs)
821801
self.computed = True
822802

@@ -828,8 +808,7 @@ class OTDA_l1l2(OTDA):
828808
"""Class for domain adaptation with optimal transport with entropic
829809
and group lasso regularization"""
830810

831-
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
832-
**kwargs):
811+
def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, **kwargs):
833812
"""Fit regularized domain adaptation between samples is xs and xt
834813
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
835814
parameters"""
@@ -845,7 +824,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
845824
self.wt = wt
846825

847826
self.M = dist(xs, xt, metric=self.metric)
848-
self.normalizeM(norm)
827+
self.M = cost_normalization(self.M, self.norm)
849828
self.G = sinkhorn_l1l2_gl(ws, ys, wt, self.M, reg, eta, **kwargs)
850829
self.computed = True
851830

@@ -1001,6 +980,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001980

1002981
# pairwise distance
1003982
self.cost_ = dist(Xs, Xt, metric=self.metric)
983+
self.cost_ = cost_normalization(self.cost_, self.norm)
1004984

1005985
if (ys is not None) and (yt is not None):
1006986

@@ -1202,6 +1182,9 @@ class SinkhornTransport(BaseTransport):
12021182
be transported from a domain to another one.
12031183
metric : string, optional (default="sqeuclidean")
12041184
The ground metric for the Wasserstein problem
1185+
norm : string, optional (default=None)
1186+
If given, normalize the ground metric to avoid numerical errors that
1187+
can occur with large metric values.
12051188
distribution : string, optional (default="uniform")
12061189
The kind of distribution estimation to employ
12071190
verbose : int, optional (default=0)
@@ -1231,7 +1214,7 @@ class SinkhornTransport(BaseTransport):
12311214

12321215
def __init__(self, reg_e=1., max_iter=1000,
12331216
tol=10e-9, verbose=False, log=False,
1234-
metric="sqeuclidean",
1217+
metric="sqeuclidean", norm=None,
12351218
distribution_estimation=distribution_estimation_uniform,
12361219
out_of_sample_map='ferradans', limit_max=np.infty):
12371220

@@ -1241,6 +1224,7 @@ def __init__(self, reg_e=1., max_iter=1000,
12411224
self.verbose = verbose
12421225
self.log = log
12431226
self.metric = metric
1227+
self.norm = norm
12441228
self.limit_max = limit_max
12451229
self.distribution_estimation = distribution_estimation
12461230
self.out_of_sample_map = out_of_sample_map
@@ -1296,6 +1280,9 @@ class EMDTransport(BaseTransport):
12961280
be transported from a domain to another one.
12971281
metric : string, optional (default="sqeuclidean")
12981282
The ground metric for the Wasserstein problem
1283+
norm : string, optional (default=None)
1284+
If given, normalize the ground metric to avoid numerical errors that
1285+
can occur with large metric values.
12991286
distribution : string, optional (default="uniform")
13001287
The kind of distribution estimation to employ
13011288
verbose : int, optional (default=0)
@@ -1306,6 +1293,9 @@ class EMDTransport(BaseTransport):
13061293
Controls the semi supervised mode. Transport between labeled source
13071294
and target samples of different classes will exhibit an infinite cost
13081295
(10 times the maximum value of the cost matrix)
1296+
max_iter : int, optional (default=100000)
1297+
The maximum number of iterations before stopping the optimization
1298+
algorithm if it has not converged.
13091299
13101300
Attributes
13111301
----------
@@ -1319,14 +1309,17 @@ class EMDTransport(BaseTransport):
13191309
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13201310
"""
13211311

1322-
def __init__(self, metric="sqeuclidean",
1312+
def __init__(self, metric="sqeuclidean", norm=None,
13231313
distribution_estimation=distribution_estimation_uniform,
1324-
out_of_sample_map='ferradans', limit_max=10):
1314+
out_of_sample_map='ferradans', limit_max=10,
1315+
max_iter=100000):
13251316

13261317
self.metric = metric
1318+
self.norm = norm
13271319
self.limit_max = limit_max
13281320
self.distribution_estimation = distribution_estimation
13291321
self.out_of_sample_map = out_of_sample_map
1322+
self.max_iter = max_iter
13301323

13311324
def fit(self, Xs, ys=None, Xt=None, yt=None):
13321325
"""Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1346,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13531346

13541347
# coupling estimation
13551348
self.coupling_ = emd(
1356-
a=self.mu_s, b=self.mu_t, M=self.cost_,
1349+
a=self.mu_s, b=self.mu_t, M=self.cost_, numItermax=self.max_iter
13571350
)
13581351

13591352
return self
@@ -1376,6 +1369,9 @@ class SinkhornLpl1Transport(BaseTransport):
13761369
be transported from a domain to another one.
13771370
metric : string, optional (default="sqeuclidean")
13781371
The ground metric for the Wasserstein problem
1372+
norm : string, optional (default=None)
1373+
If given, normalize the ground metric to avoid numerical errors that
1374+
can occur with large metric values.
13791375
distribution : string, optional (default="uniform")
13801376
The kind of distribution estimation to employ
13811377
max_iter : int, float, optional (default=10)
@@ -1410,7 +1406,7 @@ class SinkhornLpl1Transport(BaseTransport):
14101406
def __init__(self, reg_e=1., reg_cl=0.1,
14111407
max_iter=10, max_inner_iter=200,
14121408
tol=10e-9, verbose=False,
1413-
metric="sqeuclidean",
1409+
metric="sqeuclidean", norm=None,
14141410
distribution_estimation=distribution_estimation_uniform,
14151411
out_of_sample_map='ferradans', limit_max=np.infty):
14161412

@@ -1421,6 +1417,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14211417
self.tol = tol
14221418
self.verbose = verbose
14231419
self.metric = metric
1420+
self.norm = norm
14241421
self.distribution_estimation = distribution_estimation
14251422
self.out_of_sample_map = out_of_sample_map
14261423
self.limit_max = limit_max
@@ -1477,6 +1474,9 @@ class SinkhornL1l2Transport(BaseTransport):
14771474
be transported from a domain to another one.
14781475
metric : string, optional (default="sqeuclidean")
14791476
The ground metric for the Wasserstein problem
1477+
norm : string, optional (default=None)
1478+
If given, normalize the ground metric to avoid numerical errors that
1479+
can occur with large metric values.
14801480
distribution : string, optional (default="uniform")
14811481
The kind of distribution estimation to employ
14821482
max_iter : int, float, optional (default=10)
@@ -1516,7 +1516,7 @@ class SinkhornL1l2Transport(BaseTransport):
15161516
def __init__(self, reg_e=1., reg_cl=0.1,
15171517
max_iter=10, max_inner_iter=200,
15181518
tol=10e-9, verbose=False, log=False,
1519-
metric="sqeuclidean",
1519+
metric="sqeuclidean", norm=None,
15201520
distribution_estimation=distribution_estimation_uniform,
15211521
out_of_sample_map='ferradans', limit_max=10):
15221522

@@ -1528,6 +1528,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
15281528
self.verbose = verbose
15291529
self.log = log
15301530
self.metric = metric
1531+
self.norm = norm
15311532
self.distribution_estimation = distribution_estimation
15321533
self.out_of_sample_map = out_of_sample_map
15331534
self.limit_max = limit_max
@@ -1588,6 +1589,9 @@ class MappingTransport(BaseEstimator):
15881589
Estimate linear mapping with constant bias
15891590
metric : string, optional (default="sqeuclidean")
15901591
The ground metric for the Wasserstein problem
1592+
norm : string, optional (default=None)
1593+
If given, normalize the ground metric to avoid numerical errors that
1594+
can occur with large metric values.
15911595
kernel : string, optional (default="linear")
15921596
The kernel to use either linear or gaussian
15931597
sigma : float, optional (default=1)
@@ -1627,11 +1631,12 @@ class MappingTransport(BaseEstimator):
16271631
"""
16281632

16291633
def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
1630-
kernel="linear", sigma=1, max_iter=100, tol=1e-5,
1634+
norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5,
16311635
max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
16321636
verbose2=False):
16331637

16341638
self.metric = metric
1639+
self.norm = norm
16351640
self.mu = mu
16361641
self.eta = eta
16371642
self.bias = bias

ot/lp/EMD.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
using namespace lemon;
2424
typedef unsigned int node_id_type;
2525

26+
enum ProblemType {
27+
INFEASIBLE,
28+
OPTIMAL,
29+
UNBOUNDED
30+
};
2631

27-
void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost);
32+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter);
2833

2934
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
#include "EMD.h"
1616

1717

18-
void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost) {
18+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) {
1919
// beware M and C anre strored in row major C style!!!
2020
int n, m, i,cur;
2121
double max;
22-
int max_iter=10000;
2322

2423
typedef FullBipartiteDigraph Digraph;
2524
DIGRAPH_TYPEDEFS(FullBipartiteDigraph);
@@ -46,7 +45,7 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *
4645
std::vector<int> indI(n), indJ(m);
4746
std::vector<double> weights1(n), weights2(m);
4847
Digraph di(n, m);
49-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m,max_iter);
48+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
5049

5150
// Set supply and demand, don't account for 0 values (faster)
5251

@@ -116,5 +115,5 @@ void EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *
116115
};
117116

118117

119-
118+
return ret;
120119
}

0 commit comments

Comments
 (0)