@@ -658,7 +658,7 @@ def __init__(self, metric='sqeuclidean'):
658
658
self .metric = metric
659
659
self .computed = False
660
660
661
- def fit (self , xs , xt , ws = None , wt = None , norm = None , numItermax = 10000 ):
661
+ def fit (self , xs , xt , ws = None , wt = None , norm = None , max_iter = 100000 ):
662
662
"""Fit domain adaptation between samples is xs and xt
663
663
(with optional weights)"""
664
664
self .xs = xs
@@ -674,7 +674,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, numItermax=10000):
674
674
675
675
self .M = dist (xs , xt , metric = self .metric )
676
676
self .normalizeM (norm )
677
- self .G = emd (ws , wt , self .M , numItermax )
677
+ self .G = emd (ws , wt , self .M , max_iter )
678
678
self .computed = True
679
679
680
680
def interp (self , direction = 1 ):
@@ -1001,6 +1001,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001
1001
1002
1002
# pairwise distance
1003
1003
self .cost_ = dist (Xs , Xt , metric = self .metric )
1004
+ self .normalizeCost_ (self .norm )
1004
1005
1005
1006
if (ys is not None ) and (yt is not None ):
1006
1007
@@ -1182,6 +1183,26 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
1182
1183
1183
1184
return transp_Xt
1184
1185
1186
+ def normalizeCost_ (self , norm ):
1187
+ """ Apply normalization to the loss matrix
1188
+
1189
+
1190
+ Parameters
1191
+ ----------
1192
+ norm : str
1193
+ type of normalization from 'median','max','log','loglog'
1194
+
1195
+ """
1196
+
1197
+ if norm == "median" :
1198
+ self .cost_ /= float (np .median (self .cost_ ))
1199
+ elif norm == "max" :
1200
+ self .cost_ /= float (np .max (self .cost_ ))
1201
+ elif norm == "log" :
1202
+ self .cost_ = np .log (1 + self .cost_ )
1203
+ elif norm == "loglog" :
1204
+ self .cost_ = np .log (1 + np .log (1 + self .cost_ ))
1205
+
1185
1206
1186
1207
class SinkhornTransport (BaseTransport ):
1187
1208
"""Domain Adapatation OT method based on Sinkhorn Algorithm
@@ -1202,6 +1223,9 @@ class SinkhornTransport(BaseTransport):
1202
1223
be transported from a domain to another one.
1203
1224
metric : string, optional (default="sqeuclidean")
1204
1225
The ground metric for the Wasserstein problem
1226
+ norm : string, optional (default=None)
1227
+ If given, normalize the ground metric to avoid numerical errors that
1228
+ can occur with large metric values.
1205
1229
distribution : string, optional (default="uniform")
1206
1230
The kind of distribution estimation to employ
1207
1231
verbose : int, optional (default=0)
@@ -1231,7 +1255,7 @@ class SinkhornTransport(BaseTransport):
1231
1255
1232
1256
def __init__ (self , reg_e = 1. , max_iter = 1000 ,
1233
1257
tol = 10e-9 , verbose = False , log = False ,
1234
- metric = "sqeuclidean" ,
1258
+ metric = "sqeuclidean" , norm = None ,
1235
1259
distribution_estimation = distribution_estimation_uniform ,
1236
1260
out_of_sample_map = 'ferradans' , limit_max = np .infty ):
1237
1261
@@ -1241,6 +1265,7 @@ def __init__(self, reg_e=1., max_iter=1000,
1241
1265
self .verbose = verbose
1242
1266
self .log = log
1243
1267
self .metric = metric
1268
+ self .norm = norm
1244
1269
self .limit_max = limit_max
1245
1270
self .distribution_estimation = distribution_estimation
1246
1271
self .out_of_sample_map = out_of_sample_map
@@ -1296,6 +1321,9 @@ class EMDTransport(BaseTransport):
1296
1321
be transported from a domain to another one.
1297
1322
metric : string, optional (default="sqeuclidean")
1298
1323
The ground metric for the Wasserstein problem
1324
+ norm : string, optional (default=None)
1325
+ If given, normalize the ground metric to avoid numerical errors that
1326
+ can occur with large metric values.
1299
1327
distribution : string, optional (default="uniform")
1300
1328
The kind of distribution estimation to employ
1301
1329
verbose : int, optional (default=0)
@@ -1306,6 +1334,9 @@ class EMDTransport(BaseTransport):
1306
1334
Controls the semi supervised mode. Transport between labeled source
1307
1335
and target samples of different classes will exhibit an infinite cost
1308
1336
(10 times the maximum value of the cost matrix)
1337
+ max_iter : int, optional (default=100000)
1338
+ The maximum number of iterations before stopping the optimization
1339
+ algorithm if it has not converged.
1309
1340
1310
1341
Attributes
1311
1342
----------
@@ -1319,14 +1350,17 @@ class EMDTransport(BaseTransport):
1319
1350
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
1320
1351
"""
1321
1352
1322
- def __init__ (self , metric = "sqeuclidean" ,
1353
+ def __init__ (self , metric = "sqeuclidean" , norm = None ,
1323
1354
distribution_estimation = distribution_estimation_uniform ,
1324
- out_of_sample_map = 'ferradans' , limit_max = 10 ):
1355
+ out_of_sample_map = 'ferradans' , limit_max = 10 ,
1356
+ max_iter = 100000 ):
1325
1357
1326
1358
self .metric = metric
1359
+ self .norm = norm
1327
1360
self .limit_max = limit_max
1328
1361
self .distribution_estimation = distribution_estimation
1329
1362
self .out_of_sample_map = out_of_sample_map
1363
+ self .max_iter = max_iter
1330
1364
1331
1365
def fit (self , Xs , ys = None , Xt = None , yt = None ):
1332
1366
"""Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1387,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1353
1387
1354
1388
# coupling estimation
1355
1389
self .coupling_ = emd (
1356
- a = self .mu_s , b = self .mu_t , M = self .cost_ ,
1390
+ a = self .mu_s , b = self .mu_t , M = self .cost_ , max_iter = self . max_iter
1357
1391
)
1358
1392
1359
1393
return self
@@ -1376,6 +1410,9 @@ class SinkhornLpl1Transport(BaseTransport):
1376
1410
be transported from a domain to another one.
1377
1411
metric : string, optional (default="sqeuclidean")
1378
1412
The ground metric for the Wasserstein problem
1413
+ norm : string, optional (default=None)
1414
+ If given, normalize the ground metric to avoid numerical errors that
1415
+ can occur with large metric values.
1379
1416
distribution : string, optional (default="uniform")
1380
1417
The kind of distribution estimation to employ
1381
1418
max_iter : int, float, optional (default=10)
@@ -1410,7 +1447,7 @@ class SinkhornLpl1Transport(BaseTransport):
1410
1447
def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1411
1448
max_iter = 10 , max_inner_iter = 200 ,
1412
1449
tol = 10e-9 , verbose = False ,
1413
- metric = "sqeuclidean" ,
1450
+ metric = "sqeuclidean" , norm = None ,
1414
1451
distribution_estimation = distribution_estimation_uniform ,
1415
1452
out_of_sample_map = 'ferradans' , limit_max = np .infty ):
1416
1453
@@ -1421,6 +1458,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
1421
1458
self .tol = tol
1422
1459
self .verbose = verbose
1423
1460
self .metric = metric
1461
+ self .norm = norm
1424
1462
self .distribution_estimation = distribution_estimation
1425
1463
self .out_of_sample_map = out_of_sample_map
1426
1464
self .limit_max = limit_max
@@ -1477,6 +1515,9 @@ class SinkhornL1l2Transport(BaseTransport):
1477
1515
be transported from a domain to another one.
1478
1516
metric : string, optional (default="sqeuclidean")
1479
1517
The ground metric for the Wasserstein problem
1518
+ norm : string, optional (default=None)
1519
+ If given, normalize the ground metric to avoid numerical errors that
1520
+ can occur with large metric values.
1480
1521
distribution : string, optional (default="uniform")
1481
1522
The kind of distribution estimation to employ
1482
1523
max_iter : int, float, optional (default=10)
@@ -1516,7 +1557,7 @@ class SinkhornL1l2Transport(BaseTransport):
1516
1557
def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1517
1558
max_iter = 10 , max_inner_iter = 200 ,
1518
1559
tol = 10e-9 , verbose = False , log = False ,
1519
- metric = "sqeuclidean" ,
1560
+ metric = "sqeuclidean" , norm = None ,
1520
1561
distribution_estimation = distribution_estimation_uniform ,
1521
1562
out_of_sample_map = 'ferradans' , limit_max = 10 ):
1522
1563
@@ -1528,6 +1569,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
1528
1569
self .verbose = verbose
1529
1570
self .log = log
1530
1571
self .metric = metric
1572
+ self .norm = norm
1531
1573
self .distribution_estimation = distribution_estimation
1532
1574
self .out_of_sample_map = out_of_sample_map
1533
1575
self .limit_max = limit_max
@@ -1588,6 +1630,9 @@ class MappingTransport(BaseEstimator):
1588
1630
Estimate linear mapping with constant bias
1589
1631
metric : string, optional (default="sqeuclidean")
1590
1632
The ground metric for the Wasserstein problem
1633
+ norm : string, optional (default=None)
1634
+ If given, normalize the ground metric to avoid numerical errors that
1635
+ can occur with large metric values.
1591
1636
kernel : string, optional (default="linear")
1592
1637
The kernel to use either linear or gaussian
1593
1638
sigma : float, optional (default=1)
@@ -1627,11 +1672,12 @@ class MappingTransport(BaseEstimator):
1627
1672
"""
1628
1673
1629
1674
def __init__ (self , mu = 1 , eta = 0.001 , bias = False , metric = "sqeuclidean" ,
1630
- kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1675
+ norm = None , kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1631
1676
max_inner_iter = 10 , inner_tol = 1e-6 , log = False , verbose = False ,
1632
1677
verbose2 = False ):
1633
1678
1634
1679
self .metric = metric
1680
+ self .norm = norm
1635
1681
self .mu = mu
1636
1682
self .eta = eta
1637
1683
self .bias = bias
0 commit comments