Skip to content

Commit 7efea81

Browse files
author
Hicham Janati
committed
same for unbalanced
1 parent c7269d3 commit 7efea81

File tree

1 file changed

+51
-51
lines changed

1 file changed

+51
-51
lines changed

ot/unbalanced.py

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -120,23 +120,23 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
120120
"""
121121

122122
if method.lower() == 'sinkhorn':
123-
return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
124-
numItermax=numItermax,
125-
stopThr=stopThr, verbose=verbose,
126-
log=log, **kwargs)
123+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
124+
numItermax=numItermax,
125+
stopThr=stopThr, verbose=verbose,
126+
log=log, **kwargs)
127127

128128
elif method.lower() == 'sinkhorn_stabilized':
129-
return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
130-
numItermax=numItermax,
131-
stopThr=stopThr,
132-
verbose=verbose,
133-
log=log, **kwargs)
129+
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
130+
numItermax=numItermax,
131+
stopThr=stopThr,
132+
verbose=verbose,
133+
log=log, **kwargs)
134134
elif method.lower() in ['sinkhorn_reg_scaling']:
135135
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
136-
return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
137-
numItermax=numItermax,
138-
stopThr=stopThr, verbose=verbose,
139-
log=log, **kwargs)
136+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
137+
numItermax=numItermax,
138+
stopThr=stopThr, verbose=verbose,
139+
log=log, **kwargs)
140140
else:
141141
raise ValueError("Unknown method '%s'." % method)
142142

@@ -241,29 +241,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
241241
if len(b.shape) < 2:
242242
b = b[:, None]
243243
if method.lower() == 'sinkhorn':
244-
return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
245-
numItermax=numItermax,
246-
stopThr=stopThr, verbose=verbose,
247-
log=log, **kwargs)
244+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
245+
numItermax=numItermax,
246+
stopThr=stopThr, verbose=verbose,
247+
log=log, **kwargs)
248248

249249
elif method.lower() == 'sinkhorn_stabilized':
250-
return _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
251-
numItermax=numItermax,
252-
stopThr=stopThr,
253-
verbose=verbose,
254-
log=log, **kwargs)
250+
return sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m,
251+
numItermax=numItermax,
252+
stopThr=stopThr,
253+
verbose=verbose,
254+
log=log, **kwargs)
255255
elif method.lower() in ['sinkhorn_reg_scaling']:
256256
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
257-
return _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
258-
numItermax=numItermax,
259-
stopThr=stopThr, verbose=verbose,
260-
log=log, **kwargs)
257+
return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m,
258+
numItermax=numItermax,
259+
stopThr=stopThr, verbose=verbose,
260+
log=log, **kwargs)
261261
else:
262262
raise ValueError('Unknown method %s.' % method)
263263

264264

265-
def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
266-
stopThr=1e-6, verbose=False, log=False, **kwargs):
265+
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
266+
stopThr=1e-6, verbose=False, log=False, **kwargs):
267267
r"""
268268
Solve the entropic regularization unbalanced optimal transport problem and return the loss
269269
@@ -300,7 +300,7 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
300300
numItermax : int, optional
301301
Max number of iterations
302302
stopThr : float, optional
303-
Stop threshol on error (>0)
303+
Stop threshol on error (> 0)
304304
verbose : bool, optional
305305
Print information along iterations
306306
log : bool, optional
@@ -439,9 +439,9 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
439439
return u[:, None] * K * v[None, :]
440440

441441

442-
def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
443-
stopThr=1e-6, verbose=False, log=False,
444-
**kwargs):
442+
def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=1000,
443+
stopThr=1e-6, verbose=False, log=False,
444+
**kwargs):
445445
r"""
446446
Solve the entropic regularization unbalanced optimal transport
447447
problem and return the loss
@@ -653,9 +653,9 @@ def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=100
653653
return ot_matrix
654654

655655

656-
def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
657-
numItermax=1000, stopThr=1e-6,
658-
verbose=False, log=False):
656+
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
657+
numItermax=1000, stopThr=1e-6,
658+
verbose=False, log=False):
659659
r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
660660
661661
The function solves the following optimization problem:
@@ -804,9 +804,9 @@ def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
804804
return q
805805

806806

807-
def _barycenter_unbalanced(A, M, reg, reg_m, weights=None,
808-
numItermax=1000, stopThr=1e-6,
809-
verbose=False, log=False):
807+
def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
808+
numItermax=1000, stopThr=1e-6,
809+
verbose=False, log=False):
810810
r"""Compute the entropic unbalanced wasserstein barycenter of A.
811811
812812
The function solves the following optimization problem with a
@@ -1001,22 +1001,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
10011001
"""
10021002

10031003
if method.lower() == 'sinkhorn':
1004-
return _barycenter_unbalanced(A, M, reg, reg_m,
1005-
numItermax=numItermax,
1006-
stopThr=stopThr, verbose=verbose,
1007-
log=log, **kwargs)
1004+
return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
1005+
numItermax=numItermax,
1006+
stopThr=stopThr, verbose=verbose,
1007+
log=log, **kwargs)
10081008

10091009
elif method.lower() == 'sinkhorn_stabilized':
1010-
return _barycenter_unbalanced_stabilized(A, M, reg, reg_m,
1011-
numItermax=numItermax,
1012-
stopThr=stopThr,
1013-
verbose=verbose,
1014-
log=log, **kwargs)
1010+
return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
1011+
numItermax=numItermax,
1012+
stopThr=stopThr,
1013+
verbose=verbose,
1014+
log=log, **kwargs)
10151015
elif method.lower() in ['sinkhorn_reg_scaling']:
10161016
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
1017-
return _barycenter_unbalanced(A, M, reg, reg_m,
1018-
numItermax=numItermax,
1019-
stopThr=stopThr, verbose=verbose,
1020-
log=log, **kwargs)
1017+
return barycenter_unbalanced(A, M, reg, reg_m,
1018+
numItermax=numItermax,
1019+
stopThr=stopThr, verbose=verbose,
1020+
log=log, **kwargs)
10211021
else:
10221022
raise ValueError("Unknown method '%s'." % method)

0 commit comments

Comments
 (0)