13
13
14
14
from .bregman import sinkhorn
15
15
from .lp import emd
16
- from .utils import unif , dist , kernel
16
+ from .utils import unif , dist , kernel , cost_normalization
17
17
from .utils import check_params , deprecated , BaseEstimator
18
18
from .optim import cg
19
19
from .optim import gcg
@@ -650,15 +650,16 @@ class OTDA(object):
650
650
651
651
"""
652
652
653
def __init__(self, metric='sqeuclidean', norm=None):
    """Initialize the OTDA estimator.

    Parameters
    ----------
    metric : str, optional (default='sqeuclidean')
        Ground metric used to build the cost matrix in ``fit``.
    norm : str or None, optional (default=None)
        Cost-matrix normalization applied in ``fit`` (passed to
        ``cost_normalization``); ``None`` disables normalization.
    """
    # Placeholders for fitted state; populated by fit().
    self.xs = 0
    self.xt = 0
    self.G = 0
    self.metric = metric
    self.norm = norm
    # Flag flipped to True once a coupling has been computed.
    self.computed = False
660
661
661
- def fit (self , xs , xt , ws = None , wt = None , norm = None ):
662
+ def fit (self , xs , xt , ws = None , wt = None , max_iter = 100000 ):
662
663
"""Fit domain adaptation between samples is xs and xt
663
664
(with optional weights)"""
664
665
self .xs = xs
@@ -673,8 +674,8 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None):
673
674
self .wt = wt
674
675
675
676
self .M = dist (xs , xt , metric = self .metric )
676
- self .normalizeM ( norm )
677
- self .G = emd (ws , wt , self .M )
677
+ self .M = cost_normalization ( self . M , self . norm )
678
+ self .G = emd (ws , wt , self .M , max_iter )
678
679
self .computed = True
679
680
680
681
def interp (self , direction = 1 ):
@@ -741,26 +742,6 @@ def predict(self, x, direction=1):
741
742
# aply the delta to the interpolation
742
743
return xf [idx , :] + x - x0 [idx , :]
743
744
744
- def normalizeM (self , norm ):
745
- """ Apply normalization to the loss matrix
746
-
747
-
748
- Parameters
749
- ----------
750
- norm : str
751
- type of normalization from 'median','max','log','loglog'
752
-
753
- """
754
-
755
- if norm == "median" :
756
- self .M /= float (np .median (self .M ))
757
- elif norm == "max" :
758
- self .M /= float (np .max (self .M ))
759
- elif norm == "log" :
760
- self .M = np .log (1 + self .M )
761
- elif norm == "loglog" :
762
- self .M = np .log (1 + np .log (1 + self .M ))
763
-
764
745
765
746
@deprecated ("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
766
747
" removed in 0.5 \n Use class SinkhornTransport instead." )
@@ -772,7 +753,7 @@ class OTDA_sinkhorn(OTDA):
772
753
773
754
"""
774
755
775
- def fit (self , xs , xt , reg = 1 , ws = None , wt = None , norm = None , ** kwargs ):
756
+ def fit (self , xs , xt , reg = 1 , ws = None , wt = None , ** kwargs ):
776
757
"""Fit regularized domain adaptation between samples is xs and xt
777
758
(with optional weights)"""
778
759
self .xs = xs
@@ -787,7 +768,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
787
768
self .wt = wt
788
769
789
770
self .M = dist (xs , xt , metric = self .metric )
790
- self .normalizeM ( norm )
771
+ self .M = cost_normalization ( self . M , self . norm )
791
772
self .G = sinkhorn (ws , wt , self .M , reg , ** kwargs )
792
773
self .computed = True
793
774
@@ -799,8 +780,7 @@ class OTDA_lpl1(OTDA):
799
780
"""Class for domain adaptation with optimal transport with entropic and
800
781
group regularization"""
801
782
802
- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
803
- ** kwargs ):
783
+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
804
784
"""Fit regularized domain adaptation between samples is xs and xt
805
785
(with optional weights), See ot.da.sinkhorn_lpl1_mm for fit
806
786
parameters"""
@@ -816,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
816
796
self .wt = wt
817
797
818
798
self .M = dist (xs , xt , metric = self .metric )
819
- self .normalizeM ( norm )
799
+ self .M = cost_normalization ( self . M , self . norm )
820
800
self .G = sinkhorn_lpl1_mm (ws , ys , wt , self .M , reg , eta , ** kwargs )
821
801
self .computed = True
822
802
@@ -828,8 +808,7 @@ class OTDA_l1l2(OTDA):
828
808
"""Class for domain adaptation with optimal transport with entropic
829
809
and group lasso regularization"""
830
810
831
- def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , norm = None ,
832
- ** kwargs ):
811
+ def fit (self , xs , ys , xt , reg = 1 , eta = 1 , ws = None , wt = None , ** kwargs ):
833
812
"""Fit regularized domain adaptation between samples is xs and xt
834
813
(with optional weights), See ot.da.sinkhorn_lpl1_gl for fit
835
814
parameters"""
@@ -845,7 +824,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
845
824
self .wt = wt
846
825
847
826
self .M = dist (xs , xt , metric = self .metric )
848
- self .normalizeM ( norm )
827
+ self .M = cost_normalization ( self . M , self . norm )
849
828
self .G = sinkhorn_l1l2_gl (ws , ys , wt , self .M , reg , eta , ** kwargs )
850
829
self .computed = True
851
830
@@ -1001,6 +980,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001
980
1002
981
# pairwise distance
1003
982
self .cost_ = dist (Xs , Xt , metric = self .metric )
983
+ self .cost_ = cost_normalization (self .cost_ , self .norm )
1004
984
1005
985
if (ys is not None ) and (yt is not None ):
1006
986
@@ -1202,6 +1182,9 @@ class SinkhornTransport(BaseTransport):
1202
1182
be transported from a domain to another one.
1203
1183
metric : string, optional (default="sqeuclidean")
1204
1184
The ground metric for the Wasserstein problem
1185
+ norm : string, optional (default=None)
1186
+ If given, normalize the ground metric to avoid numerical errors that
1187
+ can occur with large metric values.
1205
1188
distribution : string, optional (default="uniform")
1206
1189
The kind of distribution estimation to employ
1207
1190
verbose : int, optional (default=0)
@@ -1231,7 +1214,7 @@ class SinkhornTransport(BaseTransport):
1231
1214
1232
1215
def __init__ (self , reg_e = 1. , max_iter = 1000 ,
1233
1216
tol = 10e-9 , verbose = False , log = False ,
1234
- metric = "sqeuclidean" ,
1217
+ metric = "sqeuclidean" , norm = None ,
1235
1218
distribution_estimation = distribution_estimation_uniform ,
1236
1219
out_of_sample_map = 'ferradans' , limit_max = np .infty ):
1237
1220
@@ -1241,6 +1224,7 @@ def __init__(self, reg_e=1., max_iter=1000,
1241
1224
self .verbose = verbose
1242
1225
self .log = log
1243
1226
self .metric = metric
1227
+ self .norm = norm
1244
1228
self .limit_max = limit_max
1245
1229
self .distribution_estimation = distribution_estimation
1246
1230
self .out_of_sample_map = out_of_sample_map
@@ -1296,6 +1280,9 @@ class EMDTransport(BaseTransport):
1296
1280
be transported from a domain to another one.
1297
1281
metric : string, optional (default="sqeuclidean")
1298
1282
The ground metric for the Wasserstein problem
1283
+ norm : string, optional (default=None)
1284
+ If given, normalize the ground metric to avoid numerical errors that
1285
+ can occur with large metric values.
1299
1286
distribution : string, optional (default="uniform")
1300
1287
The kind of distribution estimation to employ
1301
1288
verbose : int, optional (default=0)
@@ -1306,6 +1293,9 @@ class EMDTransport(BaseTransport):
1306
1293
Controls the semi supervised mode. Transport between labeled source
1307
1294
and target samples of different classes will exhibit an infinite cost
1308
1295
(10 times the maximum value of the cost matrix)
1296
+ max_iter : int, optional (default=100000)
1297
+ The maximum number of iterations before stopping the optimization
1298
+ algorithm if it has not converged.
1309
1299
1310
1300
Attributes
1311
1301
----------
@@ -1319,14 +1309,17 @@ class EMDTransport(BaseTransport):
1319
1309
on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
1320
1310
"""
1321
1311
1322
def __init__(self, metric="sqeuclidean", norm=None,
             distribution_estimation=distribution_estimation_uniform,
             out_of_sample_map='ferradans', limit_max=10,
             max_iter=100000):
    """Initialize the EMDTransport estimator.

    Parameters
    ----------
    metric : str, optional (default="sqeuclidean")
        Ground metric for the Wasserstein problem.
    norm : str or None, optional (default=None)
        If given, normalize the ground metric to avoid numerical errors
        that can occur with large metric values.
    distribution_estimation : callable, optional
        Estimator of the source/target marginal distributions
        (default: uniform weights).
    out_of_sample_map : str, optional (default='ferradans')
        Mapping method for out-of-sample points.
    limit_max : float, optional (default=10)
        Semi-supervised cost cap: transport between labeled samples of
        different classes costs ``limit_max`` times the max cost value.
    max_iter : int, optional (default=100000)
        Maximum number of iterations of the EMD solver before stopping
        if it has not converged.
    """
    # Store constructor arguments verbatim (scikit-learn convention:
    # no validation or transformation in __init__).
    self.metric = metric
    self.norm = norm
    self.limit_max = limit_max
    self.distribution_estimation = distribution_estimation
    self.out_of_sample_map = out_of_sample_map
    self.max_iter = max_iter
1330
1323
1331
1324
def fit (self , Xs , ys = None , Xt = None , yt = None ):
1332
1325
"""Build a coupling matrix from source and target sets of samples
@@ -1353,7 +1346,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1353
1346
1354
1347
# coupling estimation
1355
1348
self .coupling_ = emd (
1356
- a = self .mu_s , b = self .mu_t , M = self .cost_ ,
1349
+ a = self .mu_s , b = self .mu_t , M = self .cost_ , numItermax = self . max_iter
1357
1350
)
1358
1351
1359
1352
return self
@@ -1376,6 +1369,9 @@ class SinkhornLpl1Transport(BaseTransport):
1376
1369
be transported from a domain to another one.
1377
1370
metric : string, optional (default="sqeuclidean")
1378
1371
The ground metric for the Wasserstein problem
1372
+ norm : string, optional (default=None)
1373
+ If given, normalize the ground metric to avoid numerical errors that
1374
+ can occur with large metric values.
1379
1375
distribution : string, optional (default="uniform")
1380
1376
The kind of distribution estimation to employ
1381
1377
max_iter : int, float, optional (default=10)
@@ -1410,7 +1406,7 @@ class SinkhornLpl1Transport(BaseTransport):
1410
1406
def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1411
1407
max_iter = 10 , max_inner_iter = 200 ,
1412
1408
tol = 10e-9 , verbose = False ,
1413
- metric = "sqeuclidean" ,
1409
+ metric = "sqeuclidean" , norm = None ,
1414
1410
distribution_estimation = distribution_estimation_uniform ,
1415
1411
out_of_sample_map = 'ferradans' , limit_max = np .infty ):
1416
1412
@@ -1421,6 +1417,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
1421
1417
self .tol = tol
1422
1418
self .verbose = verbose
1423
1419
self .metric = metric
1420
+ self .norm = norm
1424
1421
self .distribution_estimation = distribution_estimation
1425
1422
self .out_of_sample_map = out_of_sample_map
1426
1423
self .limit_max = limit_max
@@ -1477,6 +1474,9 @@ class SinkhornL1l2Transport(BaseTransport):
1477
1474
be transported from a domain to another one.
1478
1475
metric : string, optional (default="sqeuclidean")
1479
1476
The ground metric for the Wasserstein problem
1477
+ norm : string, optional (default=None)
1478
+ If given, normalize the ground metric to avoid numerical errors that
1479
+ can occur with large metric values.
1480
1480
distribution : string, optional (default="uniform")
1481
1481
The kind of distribution estimation to employ
1482
1482
max_iter : int, float, optional (default=10)
@@ -1516,7 +1516,7 @@ class SinkhornL1l2Transport(BaseTransport):
1516
1516
def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1517
1517
max_iter = 10 , max_inner_iter = 200 ,
1518
1518
tol = 10e-9 , verbose = False , log = False ,
1519
- metric = "sqeuclidean" ,
1519
+ metric = "sqeuclidean" , norm = None ,
1520
1520
distribution_estimation = distribution_estimation_uniform ,
1521
1521
out_of_sample_map = 'ferradans' , limit_max = 10 ):
1522
1522
@@ -1528,6 +1528,7 @@ def __init__(self, reg_e=1., reg_cl=0.1,
1528
1528
self .verbose = verbose
1529
1529
self .log = log
1530
1530
self .metric = metric
1531
+ self .norm = norm
1531
1532
self .distribution_estimation = distribution_estimation
1532
1533
self .out_of_sample_map = out_of_sample_map
1533
1534
self .limit_max = limit_max
@@ -1588,6 +1589,9 @@ class MappingTransport(BaseEstimator):
1588
1589
Estimate linear mapping with constant bias
1589
1590
metric : string, optional (default="sqeuclidean")
1590
1591
The ground metric for the Wasserstein problem
1592
+ norm : string, optional (default=None)
1593
+ If given, normalize the ground metric to avoid numerical errors that
1594
+ can occur with large metric values.
1591
1595
kernel : string, optional (default="linear")
1592
1596
The kernel to use either linear or gaussian
1593
1597
sigma : float, optional (default=1)
@@ -1627,11 +1631,12 @@ class MappingTransport(BaseEstimator):
1627
1631
"""
1628
1632
1629
1633
def __init__ (self , mu = 1 , eta = 0.001 , bias = False , metric = "sqeuclidean" ,
1630
- kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1634
+ norm = None , kernel = "linear" , sigma = 1 , max_iter = 100 , tol = 1e-5 ,
1631
1635
max_inner_iter = 10 , inner_tol = 1e-6 , log = False , verbose = False ,
1632
1636
verbose2 = False ):
1633
1637
1634
1638
self .metric = metric
1639
+ self .norm = norm
1635
1640
self .mu = mu
1636
1641
self .eta = eta
1637
1642
self .bias = bias
0 commit comments