Skip to content

Commit a6d5d75

Browse files
authored
[MRG] Add method argument to sinkhorn Transport (#440)
* add method argument to sinkhron transport' * update release
1 parent a313e21 commit a6d5d75

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
- New API for OT solver using function `ot.solve` (PR #388)
1515
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
1616
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
17+
- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
1718

1819
#### Closed issues
1920

ot/da.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,12 +1417,13 @@ class SinkhornTransport(BaseTransport):
14171417
Sciences, 7(3), 1853-1882.
14181418
"""
14191419

1420-
def __init__(self, reg_e=1., max_iter=1000,
1420+
def __init__(self, reg_e=1., method="sinkhorn", max_iter=1000,
14211421
tol=10e-9, verbose=False, log=False,
14221422
metric="sqeuclidean", norm=None,
14231423
distribution_estimation=distribution_estimation_uniform,
14241424
out_of_sample_map='ferradans', limit_max=np.infty):
14251425
self.reg_e = reg_e
1426+
self.method = method
14261427
self.max_iter = max_iter
14271428
self.tol = tol
14281429
self.verbose = verbose
@@ -1463,7 +1464,7 @@ class label
14631464
# coupling estimation
14641465
returned_ = sinkhorn(
14651466
a=self.mu_s, b=self.mu_t, M=self.cost_, reg=self.reg_e,
1466-
numItermax=self.max_iter, stopThr=self.tol,
1467+
method=self.method, numItermax=self.max_iter, stopThr=self.tol,
14671468
verbose=self.verbose, log=self.log)
14681469

14691470
# deal with the value of log

0 commit comments

Comments
 (0)