@@ -120,23 +120,23 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', numItermax=1000,
120
120
"""
121
121
122
122
if method .lower () == 'sinkhorn' :
123
- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
124
- numItermax = numItermax ,
125
- stopThr = stopThr , verbose = verbose ,
126
- log = log , ** kwargs )
123
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
124
+ numItermax = numItermax ,
125
+ stopThr = stopThr , verbose = verbose ,
126
+ log = log , ** kwargs )
127
127
128
128
elif method .lower () == 'sinkhorn_stabilized' :
129
- return _sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
130
- numItermax = numItermax ,
131
- stopThr = stopThr ,
132
- verbose = verbose ,
133
- log = log , ** kwargs )
129
+ return sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
130
+ numItermax = numItermax ,
131
+ stopThr = stopThr ,
132
+ verbose = verbose ,
133
+ log = log , ** kwargs )
134
134
elif method .lower () in ['sinkhorn_reg_scaling' ]:
135
135
warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
136
- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
137
- numItermax = numItermax ,
138
- stopThr = stopThr , verbose = verbose ,
139
- log = log , ** kwargs )
136
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
137
+ numItermax = numItermax ,
138
+ stopThr = stopThr , verbose = verbose ,
139
+ log = log , ** kwargs )
140
140
else :
141
141
raise ValueError ("Unknown method '%s'." % method )
142
142
@@ -241,29 +241,29 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
241
241
if len (b .shape ) < 2 :
242
242
b = b [:, None ]
243
243
if method .lower () == 'sinkhorn' :
244
- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
245
- numItermax = numItermax ,
246
- stopThr = stopThr , verbose = verbose ,
247
- log = log , ** kwargs )
244
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
245
+ numItermax = numItermax ,
246
+ stopThr = stopThr , verbose = verbose ,
247
+ log = log , ** kwargs )
248
248
249
249
elif method .lower () == 'sinkhorn_stabilized' :
250
- return _sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
251
- numItermax = numItermax ,
252
- stopThr = stopThr ,
253
- verbose = verbose ,
254
- log = log , ** kwargs )
250
+ return sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m ,
251
+ numItermax = numItermax ,
252
+ stopThr = stopThr ,
253
+ verbose = verbose ,
254
+ log = log , ** kwargs )
255
255
elif method .lower () in ['sinkhorn_reg_scaling' ]:
256
256
warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
257
- return _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
258
- numItermax = numItermax ,
259
- stopThr = stopThr , verbose = verbose ,
260
- log = log , ** kwargs )
257
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m ,
258
+ numItermax = numItermax ,
259
+ stopThr = stopThr , verbose = verbose ,
260
+ log = log , ** kwargs )
261
261
else :
262
262
raise ValueError ('Unknown method %s.' % method )
263
263
264
264
265
- def _sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , numItermax = 1000 ,
266
- stopThr = 1e-6 , verbose = False , log = False , ** kwargs ):
265
+ def sinkhorn_knopp_unbalanced (a , b , M , reg , reg_m , numItermax = 1000 ,
266
+ stopThr = 1e-6 , verbose = False , log = False , ** kwargs ):
267
267
r"""
268
268
Solve the entropic regularization unbalanced optimal transport problem and return the loss
269
269
@@ -300,7 +300,7 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
300
300
numItermax : int, optional
301
301
Max number of iterations
302
302
stopThr : float, optional
303
- Stop threshol on error (>0)
303
+ Stop threshol on error (> 0)
304
304
verbose : bool, optional
305
305
Print information along iterations
306
306
log : bool, optional
@@ -439,9 +439,9 @@ def _sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, numItermax=1000,
439
439
return u [:, None ] * K * v [None , :]
440
440
441
441
442
- def _sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , tau = 1e5 , numItermax = 1000 ,
443
- stopThr = 1e-6 , verbose = False , log = False ,
444
- ** kwargs ):
442
+ def sinkhorn_stabilized_unbalanced (a , b , M , reg , reg_m , tau = 1e5 , numItermax = 1000 ,
443
+ stopThr = 1e-6 , verbose = False , log = False ,
444
+ ** kwargs ):
445
445
r"""
446
446
Solve the entropic regularization unbalanced optimal transport
447
447
problem and return the loss
@@ -653,9 +653,9 @@ def _sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, tau=1e5, numItermax=100
653
653
return ot_matrix
654
654
655
655
656
- def _barycenter_unbalanced_stabilized (A , M , reg , reg_m , weights = None , tau = 1e3 ,
657
- numItermax = 1000 , stopThr = 1e-6 ,
658
- verbose = False , log = False ):
656
+ def barycenter_unbalanced_stabilized (A , M , reg , reg_m , weights = None , tau = 1e3 ,
657
+ numItermax = 1000 , stopThr = 1e-6 ,
658
+ verbose = False , log = False ):
659
659
r"""Compute the entropic unbalanced wasserstein barycenter of A with stabilization.
660
660
661
661
The function solves the following optimization problem:
@@ -804,9 +804,9 @@ def _barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
804
804
return q
805
805
806
806
807
- def _barycenter_unbalanced (A , M , reg , reg_m , weights = None ,
808
- numItermax = 1000 , stopThr = 1e-6 ,
809
- verbose = False , log = False ):
807
+ def barycenter_unbalanced_sinkhorn (A , M , reg , reg_m , weights = None ,
808
+ numItermax = 1000 , stopThr = 1e-6 ,
809
+ verbose = False , log = False ):
810
810
r"""Compute the entropic unbalanced wasserstein barycenter of A.
811
811
812
812
The function solves the following optimization problem with a
@@ -1001,22 +1001,22 @@ def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
1001
1001
"""
1002
1002
1003
1003
if method .lower () == 'sinkhorn' :
1004
- return _barycenter_unbalanced (A , M , reg , reg_m ,
1005
- numItermax = numItermax ,
1006
- stopThr = stopThr , verbose = verbose ,
1007
- log = log , ** kwargs )
1004
+ return barycenter_unbalanced_sinkhorn (A , M , reg , reg_m ,
1005
+ numItermax = numItermax ,
1006
+ stopThr = stopThr , verbose = verbose ,
1007
+ log = log , ** kwargs )
1008
1008
1009
1009
elif method .lower () == 'sinkhorn_stabilized' :
1010
- return _barycenter_unbalanced_stabilized (A , M , reg , reg_m ,
1011
- numItermax = numItermax ,
1012
- stopThr = stopThr ,
1013
- verbose = verbose ,
1014
- log = log , ** kwargs )
1010
+ return barycenter_unbalanced_stabilized (A , M , reg , reg_m ,
1011
+ numItermax = numItermax ,
1012
+ stopThr = stopThr ,
1013
+ verbose = verbose ,
1014
+ log = log , ** kwargs )
1015
1015
elif method .lower () in ['sinkhorn_reg_scaling' ]:
1016
1016
warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
1017
- return _barycenter_unbalanced (A , M , reg , reg_m ,
1018
- numItermax = numItermax ,
1019
- stopThr = stopThr , verbose = verbose ,
1020
- log = log , ** kwargs )
1017
+ return barycenter_unbalanced (A , M , reg , reg_m ,
1018
+ numItermax = numItermax ,
1019
+ stopThr = stopThr , verbose = verbose ,
1020
+ log = log , ** kwargs )
1021
1021
else :
1022
1022
raise ValueError ("Unknown method '%s'." % method )
0 commit comments