Skip to content

Commit fd115a5

Browse files
author
ievred
committed
sim+sim param fixed
1 parent 1a36193 commit fd115a5

File tree

1 file changed

+25
-28
lines changed

1 file changed

+25
-28
lines changed

ot/da.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748
return A, b
749749

750750

751-
def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
751+
def emd_laplace(a, b, xs, xt, M, sim, sim_param, reg, eta, alpha,
752752
numItermax, stopThr, numInnerItermax,
753753
stopInnerThr, log=False, verbose=False, **kwargs):
754754
r"""Solve the optimal transport problem (OT) with Laplacian regularization
@@ -785,6 +785,11 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
785785
samples in the target domain
786786
M : np.ndarray (ns,nt)
787787
loss matrix
788+
sim : string, optional
789+
Type of similarity ('knn' or 'gauss') used to construct the Laplacian.
790+
sim_param : int or float, optional
791+
Parameter (number of the nearest neighbors for sim='knn'
792+
or bandwidth for sim='gauss' used to compute the Laplacian.
788793
reg : string
789794
Type of Laplacian regularization
790795
eta : float
@@ -803,11 +808,6 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
803808
Print information along iterations
804809
log : bool, optional
805810
record log if True
806-
kwargs : dict
807-
Dictionary with attributes 'sim' ('knn' or 'gauss') and
808-
'param' (int, float or None) for similarity type and its parameter to be used.
809-
If 'param' is None, it is computed as mean pairwise Euclidean distance over the data set
810-
or set to 3 when sim is 'gauss' or 'knn', respectively.
811811
812812
Returns
813813
-------
@@ -824,7 +824,7 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
824824
"Optimal Transport for Domain Adaptation," in IEEE
825825
Transactions on Pattern Analysis and Machine Intelligence ,
826826
vol.PP, no.99, pp.1-1
827-
.. [28] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
827+
.. [30] R. Flamary, N. Courty, D. Tuia, A. Rakotomamonjy,
828828
"Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching,"
829829
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
830830
@@ -834,28 +834,28 @@ def emd_laplace(a, b, xs, xt, M, reg, eta, alpha,
834834
ot.optim.cg : General regularized OT
835835
836836
"""
837-
if not isinstance(kwargs['param'], (int, float, type(None))):
837+
if not isinstance(sim_param, (int, float, type(None))):
838838
raise ValueError(
839-
'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(kwargs['param'])))
839+
'Similarity parameter should be an int or a float. Got {type} instead.'.format(type=type(sim_param).__name__))
840840

841-
if kwargs['sim'] == 'gauss':
842-
if kwargs['param'] is None:
843-
kwargs['param'] = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
844-
sS = kernel(xs, xs, method=kwargs['sim'], sigma=kwargs['param'])
845-
sT = kernel(xt, xt, method=kwargs['sim'], sigma=kwargs['param'])
841+
if sim == 'gauss':
842+
if sim_param is None:
843+
sim_param = 1 / (2 * (np.mean(dist(xs, xs, 'sqeuclidean')) ** 2))
844+
sS = kernel(xs, xs, method=sim, sigma=sim_param)
845+
sT = kernel(xt, xt, method=sim, sigma=sim_param)
846846

847-
elif kwargs['sim'] == 'knn':
848-
if kwargs['param'] is None:
849-
kwargs['param'] = 3
847+
elif sim == 'knn':
848+
if sim_param is None:
849+
sim_param = 3
850850

851851
from sklearn.neighbors import kneighbors_graph
852852

853-
sS = kneighbors_graph(X=xs, n_neighbors=int(kwargs['param'])).toarray()
853+
sS = kneighbors_graph(X=xs, n_neighbors=int(sim_param)).toarray()
854854
sS = (sS + sS.T) / 2
855-
sT = kneighbors_graph(xt, n_neighbors=int(kwargs['param'])).toarray()
855+
sT = kneighbors_graph(xt, n_neighbors=int(sim_param)).toarray()
856856
sT = (sT + sT.T) / 2
857857
else:
858-
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=kwargs['sim']))
858+
raise ValueError('Unknown similarity type {sim}. Currently supported similarity types are "knn" and "gauss".'.format(sim=sim))
859859

860860
lS = laplacian(sS)
861861
lT = laplacian(sT)
@@ -1729,9 +1729,10 @@ class EMDLaplaceTransport(BaseTransport):
17291729
can occur with large metric values.
17301730
similarity : string, optional (default="knn")
17311731
The similarity to use either knn or gaussian
1732-
similarity_param : int or float, optional (default=3)
1732+
similarity_param : int or float, optional (default=None)
17331733
Parameter for the similarity: number of nearest neighbors or bandwidth
1734-
if similarity="knn" or "gaussian", respectively.
1734+
if similarity="knn" or "gaussian", respectively. If None is provided,
1735+
it is set to 3 or the average pairwise squared Euclidean distance, respectively.
17351736
max_iter : int, optional (default=100)
17361737
Max number of BCD iterations
17371738
tol : float, optional (default=1e-5)
@@ -1813,14 +1814,10 @@ class label
18131814

18141815
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
18151816

1816-
kwargs = dict()
1817-
kwargs["sim"] = self.similarity
1818-
kwargs["param"] = self.sim_param
1819-
18201817
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
1821-
xt=self.xt_, M=self.cost_, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
1818+
xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
18221819
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
1823-
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose, **kwargs)
1820+
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
18241821

18251822
# coupling estimation
18261823
if self.log:

0 commit comments

Comments
 (0)