@@ -1315,7 +1315,10 @@ class SinkhornTransport(BaseTransport):
1315
1315
1316
1316
Attributes
1317
1317
----------
1318
- coupling_ : the optimal coupling
1318
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1319
+ The optimal coupling
1320
+ log_ : dictionary
1321
+ The dictionary of log, empty dic if parameter log is not True
1319
1322
1320
1323
References
1321
1324
----------
@@ -1367,11 +1370,18 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1367
1370
super (SinkhornTransport , self ).fit (Xs , ys , Xt , yt )
1368
1371
1369
1372
# coupling estimation
1370
- self . coupling_ = sinkhorn (
1373
+ returned_ = sinkhorn (
1371
1374
a = self .mu_s , b = self .mu_t , M = self .cost_ , reg = self .reg_e ,
1372
1375
numItermax = self .max_iter , stopThr = self .tol ,
1373
1376
verbose = self .verbose , log = self .log )
1374
1377
1378
+ # deal with the value of log
1379
+ if self .log :
1380
+ self .coupling_ , self .log_ = returned_
1381
+ else :
1382
+ self .coupling_ = returned_
1383
+ self .log_ = dict ()
1384
+
1375
1385
return self
1376
1386
1377
1387
@@ -1400,7 +1410,8 @@ class EMDTransport(BaseTransport):
1400
1410
1401
1411
Attributes
1402
1412
----------
1403
- coupling_ : the optimal coupling
1413
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1414
+ The optimal coupling
1404
1415
1405
1416
References
1406
1417
----------
@@ -1475,15 +1486,14 @@ class SinkhornLpl1Transport(BaseTransport):
1475
1486
The number of iteration in the inner loop
1476
1487
verbose : int, optional (default=0)
1477
1488
Controls the verbosity of the optimization algorithm
1478
- log : int, optional (default=0)
1479
- Controls the logs of the optimization algorithm
1480
1489
limit_max: float, optional (defaul=np.infty)
1481
1490
Controls the semi supervised mode. Transport between labeled source
1482
1491
and target samples of different classes will exhibit an infinite cost
1483
1492
1484
1493
Attributes
1485
1494
----------
1486
- coupling_ : the optimal coupling
1495
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1496
+ The optimal coupling
1487
1497
1488
1498
References
1489
1499
----------
@@ -1500,7 +1510,7 @@ class SinkhornLpl1Transport(BaseTransport):
1500
1510
1501
1511
def __init__ (self , reg_e = 1. , reg_cl = 0.1 ,
1502
1512
max_iter = 10 , max_inner_iter = 200 ,
1503
- tol = 10e-9 , verbose = False , log = False ,
1513
+ tol = 10e-9 , verbose = False ,
1504
1514
metric = "sqeuclidean" ,
1505
1515
distribution_estimation = distribution_estimation_uniform ,
1506
1516
out_of_sample_map = 'ferradans' , limit_max = np .infty ):
@@ -1511,7 +1521,6 @@ def __init__(self, reg_e=1., reg_cl=0.1,
1511
1521
self .max_inner_iter = max_inner_iter
1512
1522
self .tol = tol
1513
1523
self .verbose = verbose
1514
- self .log = log
1515
1524
self .metric = metric
1516
1525
self .distribution_estimation = distribution_estimation
1517
1526
self .out_of_sample_map = out_of_sample_map
@@ -1544,7 +1553,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1544
1553
a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .cost_ ,
1545
1554
reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
1546
1555
numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1547
- verbose = self .verbose , log = self . log )
1556
+ verbose = self .verbose )
1548
1557
1549
1558
return self
1550
1559
@@ -1584,7 +1593,10 @@ class SinkhornL1l2Transport(BaseTransport):
1584
1593
1585
1594
Attributes
1586
1595
----------
1587
- coupling_ : the optimal coupling
1596
+ coupling_ : array-like, shape (n_source_samples, n_target_samples)
1597
+ The optimal coupling
1598
+ log_ : dictionary
1599
+ The dictionary of log, empty dic if parameter log is not True
1588
1600
1589
1601
References
1590
1602
----------
@@ -1641,12 +1653,19 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
1641
1653
1642
1654
super (SinkhornL1l2Transport , self ).fit (Xs , ys , Xt , yt )
1643
1655
1644
- self . coupling_ = sinkhorn_l1l2_gl (
1656
+ returned_ = sinkhorn_l1l2_gl (
1645
1657
a = self .mu_s , labels_a = ys , b = self .mu_t , M = self .cost_ ,
1646
1658
reg = self .reg_e , eta = self .reg_cl , numItermax = self .max_iter ,
1647
1659
numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
1648
1660
verbose = self .verbose , log = self .log )
1649
1661
1662
+ # deal with the value of log
1663
+ if self .log :
1664
+ self .coupling_ , self .log_ = returned_
1665
+ else :
1666
+ self .coupling_ = returned_
1667
+ self .log_ = dict ()
1668
+
1650
1669
return self
1651
1670
1652
1671
@@ -1683,14 +1702,15 @@ class MappingTransport(BaseEstimator):
1683
1702
1684
1703
Attributes
1685
1704
----------
1686
- coupling_ : array-like, shape (n_source_samples, n_features )
1705
+ coupling_ : array-like, shape (n_source_samples, n_target_samples )
1687
1706
The optimal coupling
1688
1707
mapping_ : array-like, shape (n_features (+ 1), n_features)
1689
1708
(if bias) for kernel == linear
1690
1709
The associated mapping
1691
-
1692
1710
array-like, shape (n_source_samples (+ 1), n_features)
1693
1711
(if bias) for kernel == gaussian
1712
+ log_ : dictionary
1713
+ The dictionary of log, empty dic if parameter log is not True
1694
1714
1695
1715
References
1696
1716
----------
@@ -1745,19 +1765,26 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1745
1765
self .Xt = Xt
1746
1766
1747
1767
if self .kernel == "linear" :
1748
- self . coupling_ , self . mapping_ = joint_OT_mapping_linear (
1768
+ returned_ = joint_OT_mapping_linear (
1749
1769
Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1750
1770
verbose = self .verbose , verbose2 = self .verbose2 ,
1751
1771
numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1752
1772
stopThr = self .tol , stopInnerThr = self .inner_tol , log = self .log )
1753
1773
1754
1774
elif self .kernel == "gaussian" :
1755
- self . coupling_ , self . mapping_ = joint_OT_mapping_kernel (
1775
+ returned_ = joint_OT_mapping_kernel (
1756
1776
Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1757
1777
sigma = self .sigma , verbose = self .verbose , verbose2 = self .verbose ,
1758
1778
numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1759
1779
stopInnerThr = self .inner_tol , stopThr = self .tol , log = self .log )
1760
1780
1781
+ # deal with the value of log
1782
+ if self .log :
1783
+ self .coupling_ , self .mapping_ , self .log_ = returned_
1784
+ else :
1785
+ self .coupling_ , self .mapping_ = returned_
1786
+ self .log_ = dict ()
1787
+
1761
1788
return self
1762
1789
1763
1790
def transform (self , Xs ):
0 commit comments