Skip to content

Commit 36b2e92

Browse files
author
ievred
committed
added defaults for emd_laplace
1 parent fd115a5 commit 36b2e92

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

ot/da.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -748,9 +748,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None,
748748
return A, b
749749

750750

751-
def emd_laplace(a, b, xs, xt, M, sim, sim_param, reg, eta, alpha,
752-
numItermax, stopThr, numInnerItermax,
753-
stopInnerThr, log=False, verbose=False, **kwargs):
751+
def emd_laplace(a, b, xs, xt, M, sim='knn', sim_param=None, reg='pos', eta=1, alpha=.5,
752+
numItermax=100, stopThr=1e-9, numInnerItermax=100000,
753+
stopInnerThr=1e-9, log=False, verbose=False):
754754
r"""Solve the optimal transport problem (OT) with Laplacian regularization
755755
756756
.. math::
@@ -1765,15 +1765,14 @@ class EMDLaplaceTransport(BaseTransport):
17651765
in NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.
17661766
"""
17671767

1768-
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., alpha=0.5,
1769-
metric="sqeuclidean", norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9,
1768+
def __init__(self, reg_type='pos', reg_lap=1., reg_src=1., metric="sqeuclidean",
1769+
norm=None, similarity="knn", similarity_param=None, max_iter=100, tol=1e-9,
17701770
max_inner_iter=100000, inner_tol=1e-9, log=False, verbose=False,
17711771
distribution_estimation=distribution_estimation_uniform,
17721772
out_of_sample_map='ferradans'):
17731773
self.reg = reg_type
17741774
self.reg_lap = reg_lap
17751775
self.reg_src = reg_src
1776-
self.alpha = alpha
17771776
self.metric = metric
17781777
self.norm = norm
17791778
self.similarity = similarity
@@ -1815,8 +1814,8 @@ class label
18151814
super(EMDLaplaceTransport, self).fit(Xs, ys, Xt, yt)
18161815

18171816
returned_ = emd_laplace(a=self.mu_s, b=self.mu_t, xs=self.xs_,
1818-
xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap, alpha=self.reg_src,
1819-
numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
1817+
xt=self.xt_, M=self.cost_, sim=self.similarity, sim_param=self.sim_param, reg=self.reg, eta=self.reg_lap,
1818+
alpha=self.reg_src, numItermax=self.max_iter, stopThr=self.tol, numInnerItermax=self.max_inner_iter,
18201819
stopInnerThr=self.inner_tol, log=self.log, verbose=self.verbose)
18211820

18221821
# coupling estimation

0 commit comments

Comments
 (0)