Skip to content

Commit 24b268c

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
import unif and dist in bregman file
1 parent f63712f commit 24b268c

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ot/bregman.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# License: MIT License
1111

1212
import numpy as np
13+
from .utils import unif, dist
1314

1415

1516
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
@@ -1375,11 +1376,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13751376
'''
13761377

13771378
if a is None:
1378-
a = ot.unif(np.shape(X_s)[0])
1379+
a = unif(np.shape(X_s)[0])
13791380
if b is None:
1380-
b = ot.unif(np.shape(X_t)[0])
1381+
b = unif(np.shape(X_t)[0])
13811382

1382-
M = ot.dist(X_s, X_t, metric=metric)
1383+
M = dist(X_s, X_t, metric=metric)
13831384

13841385
if log:
13851386
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
@@ -1465,11 +1466,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14651466
'''
14661467

14671468
if a is None:
1468-
a = ot.unif(np.shape(X_s)[0])
1469+
a = unif(np.shape(X_s)[0])
14691470
if b is None:
1470-
b = ot.unif(np.shape(X_t)[0])
1471+
b = unif(np.shape(X_t)[0])
14711472

1472-
M = ot.dist(X_s, X_t, metric=metric)
1473+
M = dist(X_s, X_t, metric=metric)
14731474

14741475
if log:
14751476
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)

0 commit comments

Comments
 (0)