@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748
748
return A , b
749
749
750
750
751
- def emd_laplace (a , b , xs , xt , M , reg , eta , alpha ,
751
+ def emd_laplace (a , b , xs , xt , M , sim , sim_param , reg , eta , alpha ,
752
752
numItermax , stopThr , numInnerItermax ,
753
753
stopInnerThr , log = False , verbose = False , ** kwargs ):
754
754
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,11 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
785
785
samples in the target domain
786
786
M : np.ndarray (ns,nt)
787
787
loss matrix
788
+ sim : string
789
+ Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
790
+ sim_param : int, float or None
791
+ Parameter (number of the nearest neighbors for sim='knn'
792
+ or bandwidth for sim='gauss') used to compute the Laplacian.
788
793
reg : string
789
794
Type of Laplacian regularization
790
795
eta : float
@@ -803,11 +808,6 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
803
808
Print information along iterations
804
809
log : bool, optional
805
810
record log if True
806
- kwargs : dict
807
- Dictionary with attributes 'sim' ('knn' or 'gauss') and
808
- 'param' (int, float or None) for similarity type and its parameter to be used.
809
- If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set
810
- or set to 3 when sim is 'gauss' or 'knn', respectively.
811
811
812
812
Returns
813
813
-------
@@ -824,7 +824,7 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
824
824
"Optimal Transport for Domain Adaptation," in IEEE
825
825
Transactions on Pattern Analysis and Machine Intelligence ,
826
826
vol.PP, no.99, pp.1-1
827
- .. [28 ] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
827
+ .. [30 ] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
828
828
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
829
829
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
830
830
@@ -834,28 +834,28 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
834
834
ot.optim.cg : General regularized OT
835
835
836
836
"""
837
- if not isinstance (kwargs [ 'param' ] , (int , float , type (None ))):
837
+ if not isinstance (sim_param , (int , float , type (None ))):
838
838
raise ValueError (
839
- 'Similarity parameter should be an int or a float. Got {type} instead.' .format (type = type (kwargs [ 'param' ]) ))
839
+ 'Similarity parameter should be an int or a float. Got {type} instead.' .format (type = type (sim_param ). __name__ ))
840
840
841
- if kwargs [ ' sim' ] == 'gauss' :
842
- if kwargs [ 'param' ] is None :
843
- kwargs [ 'param' ] = 1 / (2 * (np .mean (dist (xs , xs , 'sqeuclidean' )) ** 2 ))
844
- sS = kernel (xs , xs , method = kwargs [ ' sim' ] , sigma = kwargs [ 'param' ] )
845
- sT = kernel (xt , xt , method = kwargs [ ' sim' ] , sigma = kwargs [ 'param' ] )
841
+ if sim == 'gauss' :
842
+ if sim_param is None :
843
+ sim_param = 1 / (2 * (np .mean (dist (xs , xs , 'sqeuclidean' )) ** 2 ))
844
+ sS = kernel (xs , xs , method = sim , sigma = sim_param )
845
+ sT = kernel (xt , xt , method = sim , sigma = sim_param )
846
846
847
- elif kwargs [ ' sim' ] == 'knn' :
848
- if kwargs [ 'param' ] is None :
849
- kwargs [ 'param' ] = 3
847
+ elif sim == 'knn' :
848
+ if sim_param is None :
849
+ sim_param = 3
850
850
851
851
from sklearn .neighbors import kneighbors_graph
852
852
853
- sS = kneighbors_graph (X = xs , n_neighbors = int (kwargs [ 'param' ] )).toarray ()
853
+ sS = kneighbors_graph (X = xs , n_neighbors = int (sim_param )).toarray ()
854
854
sS = (sS + sS .T ) / 2
855
- sT = kneighbors_graph (xt , n_neighbors = int (kwargs [ 'param' ] )).toarray ()
855
+ sT = kneighbors_graph (xt , n_neighbors = int (sim_param )).toarray ()
856
856
sT = (sT + sT .T ) / 2
857
857
else :
858
- raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = kwargs [ ' sim' ] ))
858
+ raise ValueError ('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".' .format (sim = sim ))
859
859
860
860
lS = laplacian (sS )
861
861
lT = laplacian (sT )
@@ -1729,9 +1729,10 @@ class EMDLaplaceTransport(BaseTransport):
1729
1729
can occur with large metric values.
1730
1730
similarity : string, optional (default="knn")
1731
1731
The similarity to use: either "knn" or "gauss"
1732
- similarity_param : int or float, optional (default=3 )
1732
+ similarity_param : int or float, optional (default=None )
1733
1733
Parameter for the similarity: number of nearest neighbors or bandwidth
1734
- if similarity="knn" or "gaussian", respectively.
1734
+ if similarity="knn" or "gauss", respectively. If None is provided,
1735
+ it is set to 3 or to 1 / (2 * m ** 2), where m is the mean pairwise squared Euclidean distance between source samples, respectively.
1735
1736
max_iter : int, optional (default=100)
1736
1737
Max number of BCD iterations
1737
1738
tol : float, optional (default=1e-5)
@@ -1813,14 +1814,10 @@ class label
1813
1814
1814
1815
super (EMDLaplaceTransport , self ).fit (Xs , ys , Xt , yt )
1815
1816
1816
- kwargs = dict ()
1817
- kwargs ["sim" ] = self .similarity
1818
- kwargs ["param" ] = self .sim_param
1819
-
1820
1817
returned_ = emd_laplace (a = self .mu_s , b = self .mu_t , xs = self .xs_ ,
1821
- xt = self .xt_ , M = self .cost_ , reg = self .reg , eta = self .reg_lap , alpha = self .reg_src ,
1818
+ xt = self .xt_ , M = self .cost_ , sim = self . similarity , sim_param = self . sim_param , reg = self .reg , eta = self .reg_lap , alpha = self .reg_src ,
1822
1819
numItermax = self .max_iter , stopThr = self .tol , numInnerItermax = self .max_inner_iter ,
1823
- stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose , ** kwargs )
1820
+ stopInnerThr = self .inner_tol , log = self .log , verbose = self .verbose )
1824
1821
1825
1822
# coupling estimation
1826
1823
if self .log :
0 commit comments