|
10 | 10 | # License: MIT License
|
11 | 11 |
|
12 | 12 | import numpy as np
|
| 13 | +from .utils import unif, dist |
13 | 14 |
|
14 | 15 |
|
15 | 16 | def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
|
@@ -1375,11 +1376,11 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
|
1375 | 1376 | '''
|
1376 | 1377 |
|
1377 | 1378 | if a is None:
|
1378 |
| - a = ot.unif(np.shape(X_s)[0]) |
| 1379 | + a = unif(np.shape(X_s)[0]) |
1379 | 1380 | if b is None:
|
1380 |
| - b = ot.unif(np.shape(X_t)[0]) |
| 1381 | + b = unif(np.shape(X_t)[0]) |
1381 | 1382 |
|
1382 |
| - M = ot.dist(X_s, X_t, metric=metric) |
| 1383 | + M = dist(X_s, X_t, metric=metric) |
1383 | 1384 |
|
1384 | 1385 | if log:
|
1385 | 1386 | pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
|
@@ -1465,11 +1466,11 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
|
1465 | 1466 | '''
|
1466 | 1467 |
|
1467 | 1468 | if a is None:
|
1468 |
| - a = ot.unif(np.shape(X_s)[0]) |
| 1469 | + a = unif(np.shape(X_s)[0]) |
1469 | 1470 | if b is None:
|
1470 |
| - b = ot.unif(np.shape(X_t)[0]) |
| 1471 | + b = unif(np.shape(X_t)[0]) |
1471 | 1472 |
|
1472 |
| - M = ot.dist(X_s, X_t, metric=metric) |
| 1473 | + M = dist(X_s, X_t, metric=metric) |
1473 | 1474 |
|
1474 | 1475 | if log:
|
1475 | 1476 | sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log, **kwargs)
|
|
0 commit comments