Skip to content

Commit 982f36c

Browse files
committed
Changes:
- Rename numItermax to max_iter - Default value to 100000 instead of 10000 - Add max_iter to class SinkhornTransport(BaseTransport) - Add norm to all BaseTransport
1 parent 308ce24 commit 982f36c

File tree

5 files changed

+123
-75
lines changed

5 files changed

+123
-75
lines changed

ot/da.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def __init__(self, metric='sqeuclidean'):
658658
self.metric = metric
659659
self.computed = False
660660

661-
def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
661+
def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
662662
"""Fit domain adaptation between samples is xs and xt
663663
(with optional weights)"""
664664
self.xs = xs
@@ -674,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
674674

675675
self.M = dist(xs, xt, metric=self.metric)
676676
self.normalizeM(norm)
677-
self.G = emd(ws, wt, self.M, numItermax)
677+
self.G = emd(ws, wt, self.M, max_iter)
678678
self.computed = True
679679

680680
def interp(self, direction=1):
@@ -1001,6 +1001,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10011001

10021002
# pairwise distance
10031003
self.cost_ = dist(Xs, Xt, metric=self.metric)
1004+
self.normalizeCost_(self.norm)
10041005

10051006
if (ys is not None) and (yt is not None):
10061007

@@ -1182,6 +1183,26 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11821183

11831184
return transp_Xt
11841185

1186+
def normalizeCost_(self, norm):
1187+
""" Apply normalization to the loss matrix
1188+
1189+
1190+
Parameters
1191+
----------
1192+
norm : str
1193+
type of normalization from 'median','max','log','loglog'
1194+
1195+
"""
1196+
1197+
if norm == "median":
1198+
self.cost_ /= float(np.median(self.cost_))
1199+
elif norm == "max":
1200+
self.cost_ /= float(np.max(self.cost_))
1201+
elif norm == "log":
1202+
self.cost_ = np.log(1 + self.cost_)
1203+
elif norm == "loglog":
1204+
self.cost_ = np.log(1 + np.log(1 + self.cost_))
1205+
11851206

11861207
class SinkhornTransport(BaseTransport):
11871208
"""Domain Adapatation OT method based on Sinkhorn Algorithm
@@ -1202,6 +1223,9 @@ class SinkhornTransport(BaseTransport):
12021223
be transported from a domain to another one.
12031224
metric : string, optional (default="sqeuclidean")
12041225
The ground metric for the Wasserstein problem
1226+
norm : string, optional (default=None)
1227+
If given, normalize the ground metric to avoid numerical errors that
1228+
can occur with large metric values.
12051229
distribution : string, optional (default="uniform")
12061230
The kind of distribution estimation to employ
12071231
verbose : int, optional (default=0)
@@ -1231,7 +1255,7 @@ class SinkhornTransport(BaseTransport):
12311255

12321256
def __init__(self, reg_e=1., max_iter=1000,
12331257
tol=10e-9, verbose=False, log=False,
1234-
metric="sqeuclidean",
1258+
metric="sqeuclidean", norm=None,
12351259
distribution_estimation=distribution_estimation_uniform,
12361260
out_of_sample_map='ferradans', limit_max=np.infty):
12371261

@@ -1241,6 +1265,7 @@ def __init__(self, reg_e=1., max_iter=1000,
12411265
self.verbose = verbose
12421266
self.log = log
12431267
self.metric = metric
1268+
self.norm = norm
12441269
self.limit_max = limit_max
12451270
self.distribution_estimation = distribution_estimation
12461271
self.out_of_sample_map = out_of_sample_map
@@ -1296,6 +1321,9 @@ class EMDTransport(BaseTransport):
12961321
be transported from a domain to another one.
12971322
metric : string, optional (default="sqeuclidean")
12981323
The ground metric for the Wasserstein problem
1324+
norm : string, optional (default=None)
1325+
If given, normalize the ground metric to avoid numerical errors that
1326+
can occur with large metric values.
12991327
distribution : string, optional (default="uniform")
13001328
The kind of distribution estimation to employ
13011329
verbose : int, optional (default=0)
@@ -1306,6 +1334,9 @@ class EMDTransport(BaseTransport):
13061334
Controls the semi supervised mode. Transport between labeled source
13071335
and target samples of different classes will exhibit an infinite cost
13081336
(10 times the maximum value of the cost matrix)
1337+
max_iter : int, optional (default=100000)
1338+
The maximum number of iterations before stopping the optimization
1339+
algorithm if it has not converged.
13091340
13101341
Attributes
13111342
----------
@@ -1319,14 +1350,17 @@ class EMDTransport(BaseTransport):
13191350
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
13201351
"""
13211352

1322-
def __init__(self, metric="sqeuclidean",
1353+
def __init__(self, metric="sqeuclidean", norm=None,
13231354
distribution_estimation=distribution_estimation_uniform,
1324-
out_of_sample_map='ferradans', limit_max=10):
1355+
out_of_sample_map='ferradans', limit_max=10,
1356+
max_iter=100000):
13251357

13261358
self.metric = metric
1359+
self.norm = norm
13271360
self.limit_max = limit_max
13281361
self.distribution_estimation = distribution_estimation
13291362
self.out_of_sample_map = out_of_sample_map
1363+
self.max_iter = max_iter
13301364

13311365
def fit(self, Xs, ys=None, Xt=None, yt=None):
13321366
"""Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1387,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
13531387

13541388
# coupling estimation
13551389
self.coupling_ = emd(
1356-
a=self.mu_s, b=self.mu_t, M=self.cost_,
1390+
a=self.mu_s, b=self.mu_t, M=self.cost_, max_iter=self.max_iter
13571391
)
13581392

13591393
return self
@@ -1376,6 +1410,9 @@ class SinkhornLpl1Transport(BaseTransport):
13761410
be transported from a domain to another one.
13771411
metric : string, optional (default="sqeuclidean")
13781412
The ground metric for the Wasserstein problem
1413+
norm : string, optional (default=None)
1414+
If given, normalize the ground metric to avoid numerical errors that
1415+
can occur with large metric values.
13791416
distribution : string, optional (default="uniform")
13801417
The kind of distribution estimation to employ
13811418
max_iter : int, float, optional (default=10)
@@ -1410,7 +1447,7 @@ class SinkhornLpl1Transport(BaseTransport):
14101447
def __init__(self, reg_e=1., reg_cl=0.1,
14111448
max_iter=10, max_inner_iter=200,
14121449
tol=10e-9, verbose=False,
1413-
metric="sqeuclidean",
1450+
metric="sqeuclidean", norm=None,
14141451
distribution_estimation=distribution_estimation_uniform,
14151452
out_of_sample_map='ferradans', limit_max=np.infty):
14161453

@@ -1421,6 +1458,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
14211458
self.tol = tol
14221459
self.verbose = verbose
14231460
self.metric = metric
1461+
self.norm = norm
14241462
self.distribution_estimation = distribution_estimation
14251463
self.out_of_sample_map = out_of_sample_map
14261464
self.limit_max = limit_max
@@ -1477,6 +1515,9 @@ class SinkhornL1l2Transport(BaseTransport):
14771515
be transported from a domain to another one.
14781516
metric : string, optional (default="sqeuclidean")
14791517
The ground metric for the Wasserstein problem
1518+
norm : string, optional (default=None)
1519+
If given, normalize the ground metric to avoid numerical errors that
1520+
can occur with large metric values.
14801521
distribution : string, optional (default="uniform")
14811522
The kind of distribution estimation to employ
14821523
max_iter : int, float, optional (default=10)
@@ -1516,7 +1557,7 @@ class SinkhornL1l2Transport(BaseTransport):
15161557
def __init__(self, reg_e=1., reg_cl=0.1,
15171558
max_iter=10, max_inner_iter=200,
15181559
tol=10e-9, verbose=False, log=False,
1519-
metric="sqeuclidean",
1560+
metric="sqeuclidean", norm=None,
15201561
distribution_estimation=distribution_estimation_uniform,
15211562
out_of_sample_map='ferradans', limit_max=10):
15221563

@@ -1528,6 +1569,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
15281569
self.verbose = verbose
15291570
self.log = log
15301571
self.metric = metric
1572+
self.norm = norm
15311573
self.distribution_estimation = distribution_estimation
15321574
self.out_of_sample_map = out_of_sample_map
15331575
self.limit_max = limit_max
@@ -1588,6 +1630,9 @@ class MappingTransport(BaseEstimator):
15881630
Estimate linear mapping with constant bias
15891631
metric : string, optional (default="sqeuclidean")
15901632
The ground metric for the Wasserstein problem
1633+
norm : string, optional (default=None)
1634+
If given, normalize the ground metric to avoid numerical errors that
1635+
can occur with large metric values.
15911636
kernel : string, optional (default="linear")
15921637
The kernel to use either linear or gaussian
15931638
sigma : float, optional (default=1)
@@ -1627,11 +1672,12 @@ class MappingTransport(BaseEstimator):
16271672
"""
16281673

16291674
def __init__(self, mu=1, eta=0.001, bias=False, metric="sqeuclidean",
1630-
kernel="linear", sigma=1, max_iter=100, tol=1e-5,
1675+
norm=None, kernel="linear", sigma=1, max_iter=100, tol=1e-5,
16311676
max_inner_iter=10, inner_tol=1e-6, log=False, verbose=False,
16321677
verbose2=False):
16331678

16341679
self.metric = metric
1680+
self.norm = norm
16351681
self.mu = mu
16361682
self.eta = eta
16371683
self.bias = bias

ot/lp/EMD.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@ enum ProblemType {
2929
UNBOUNDED
3030
};
3131

32-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int numItermax);
32+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter);
3333

3434
#endif

ot/lp/EMD_wrapper.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include "EMD.h"
1616

1717

18-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int numItermax) {
18+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *cost, int max_iter) {
1919
// beware M and C anre strored in row major C style!!!
2020
int n, m, i,cur;
2121
double max;
@@ -45,7 +45,7 @@ int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double *c
4545
std::vector<int> indI(n), indJ(m);
4646
std::vector<double> weights1(n), weights2(m);
4747
Digraph di(n, m);
48-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, numItermax);
48+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, n*m, max_iter);
4949

5050
// Set supply and demand, don't account for 0 values (faster)
5151

ot/lp/__init__.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
import multiprocessing
1515

1616

17-
18-
def emd(a, b, M, numItermax=10000):
17+
def emd(a, b, M, max_iter=100000):
1918
"""Solves the Earth Movers distance problem and returns the OT matrix
2019
2120
@@ -40,8 +39,9 @@ def emd(a, b, M, numItermax=10000):
4039
Target histogram (uniform weigth if empty list)
4140
M : (ns,nt) ndarray, float64
4241
loss matrix
43-
numItermax : int
44-
Maximum number of iterations made by the LP solver.
42+
max_iter : int, optional (default=100000)
43+
The maximum number of iterations before stopping the optimization
44+
algorithm if it has not converged.
4545
4646
Returns
4747
-------
@@ -54,7 +54,7 @@ def emd(a, b, M, numItermax=10000):
5454
5555
Simple example with obvious solution. The function emd accepts lists and
5656
perform automatic conversion to numpy arrays
57-
57+
5858
>>> import ot
5959
>>> a=[.5,.5]
6060
>>> b=[.5,.5]
@@ -86,10 +86,11 @@ def emd(a, b, M, numItermax=10000):
8686
if len(b) == 0:
8787
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
8888

89-
return emd_c(a, b, M, numItermax)
89+
return emd_c(a, b, M, max_iter)
90+
9091

91-
def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
92-
"""Solves the Earth Movers distance problem and returns the loss
92+
def emd2(a, b, M, processes=multiprocessing.cpu_count(), max_iter=100000):
93+
"""Solves the Earth Movers distance problem and returns the loss
9394
9495
.. math::
9596
\gamma = arg\min_\gamma <\gamma,M>_F
@@ -112,8 +113,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
112113
Target histogram (uniform weigth if empty list)
113114
M : (ns,nt) ndarray, float64
114115
loss matrix
115-
numItermax : int
116-
Maximum number of iterations made by the LP solver.
116+
max_iter : int, optional (default=100000)
117+
The maximum number of iterations before stopping the optimization
118+
algorithm if it has not converged.
117119
118120
Returns
119121
-------
@@ -126,15 +128,15 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
126128
127129
Simple example with obvious solution. The function emd accepts lists and
128130
perform automatic conversion to numpy arrays
129-
130-
131+
132+
131133
>>> import ot
132134
>>> a=[.5,.5]
133135
>>> b=[.5,.5]
134136
>>> M=[[0.,1.],[1.,0.]]
135137
>>> ot.emd2(a,b,M)
136138
0.0
137-
139+
138140
References
139141
----------
140142
@@ -157,16 +159,14 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=10000):
157159
a = np.ones((M.shape[0], ), dtype=np.float64)/M.shape[0]
158160
if len(b) == 0:
159161
b = np.ones((M.shape[1], ), dtype=np.float64)/M.shape[1]
160-
161-
if len(b.shape)==1:
162-
return emd2_c(a, b, M, numItermax)
162+
163+
if len(b.shape) == 1:
164+
return emd2_c(a, b, M, max_iter)
163165
else:
164-
nb=b.shape[1]
165-
#res=[emd2_c(a,b[:,i].copy(),M, numItermax) for i in range(nb)]
166+
nb = b.shape[1]
167+
# res = [emd2_c(a, b[:, i].copy(), M, max_iter) for i in range(nb)]
168+
166169
def f(b):
167-
return emd2_c(a,b,M, numItermax)
168-
res= parmap(f, [b[:,i] for i in range(nb)],processes)
170+
return emd2_c(a, b, M, max_iter)
171+
res = parmap(f, [b[:, i] for i in range(nb)], processes)
169172
return np.array(res)
170-
171-
172-

0 commit comments

Comments
 (0)