@@ -748,9 +748,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748
748
return A , b
749
749
750
750
751
- def emd_laplace (a , b , xs , xt , M , sim , sim_param , reg , eta , alpha ,
752
- numItermax , stopThr , numInnerItermax ,
753
- stopInnerThr , log = False , verbose = False , ** kwargs ):
751
+ def emd_laplace (a , b , xs , xt , M , sim = 'knn' , sim_param = None , reg = 'pos' , eta = 1 , alpha = .5 ,
752
+ numItermax = 100 , stopThr = 1e-9 , numInnerItermax = 100000 ,
753
+ stopInnerThr = 1e-9 , log = False , verbose = False ):
754
754
r"""Solve the optimal transport problem (OT) with Laplacian regularization
755
755
756
756
.. math::
@@ -1765,15 +1765,14 @@ class EMDLaplaceTransport(BaseTransport):
1765
1765
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
1766
1766
"""
1767
1767
1768
- def __init__ (self , reg_type = 'pos' , reg_lap = 1. , reg_src = 1. , alpha = 0.5 ,
1769
- metric = "sqeuclidean" , norm = None , similarity = "knn" , similarity_param = None , max_iter = 100 , tol = 1e-9 ,
1768
+ def __init__ (self , reg_type = 'pos' , reg_lap = 1. , reg_src = 1. , metric = "sqeuclidean" ,
1769
+ norm = None , similarity = "knn" , similarity_param = None , max_iter = 100 , tol = 1e-9 ,
1770
1770
max_inner_iter = 100000 , inner_tol = 1e-9 , log = False , verbose = False ,
1771
1771
distribution_estimation = distribution_estimation_uniform ,
1772
1772
out_of_sample_map = 'ferradans' ):
1773
1773
self .reg = reg_type
1774
1774
self .reg_lap = reg_lap
1775
1775
self .reg_src = reg_src
1776
- self .alpha = alpha
1777
1776
self .metric = metric
1778
1777
self .norm = norm
1779
1778
self .similarity = similarity
@@ -1815,8 +1814,8 @@ class label
1815
1814
super (EMDLaplaceTransport , self ).fit (Xs , ys , Xt , yt )
1816
1815
1817
1816
returned_ = emd_laplace (a = self .mu_s , b = self .mu_t , xs = self .xs_ ,
1818
- xt = self .xt_ , M = self .cost_ , sim = self .similarity , sim_param = self .sim_param , reg = self .reg , eta = self .reg_lap , alpha = self . reg_src ,
1819
- numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
1817
+ xt = self .xt_ , M = self .cost_ , sim = self .similarity , sim_param = self .sim_param , reg = self .reg , eta = self .reg_lap ,
1818
+ alpha = self . reg_src , numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
1820
1819
stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose )
1821
1820
1822
1821
# coupling estimation
0 commit comments