Skip to content

Commit 50994ed

Browse files
committed
RELEASES.md
1 parent 2b317e2 commit 50994ed

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

ot/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1373,7 +1373,7 @@ def exp_bures(Sigma, S, nx=None):
13731373

13741374
return nx.einsum("ij,jk,kl -> il", C, Sigma, C)
13751375

1376-
1376+
13771377
def check_number_threads(numThreads):
13781378
"""Checks whether or not the requested number of threads has a valid value.
13791379

test/test_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -674,10 +674,10 @@ def test_exp_bures(nx):
674674
np.testing.assert_array_less(np.zeros(1), nx.to_numpy(z.T @ Lambda @ z))
675675

676676
# OT map from Lambda to Sigma
677-
Lambda_12 = nx.sqrtm(Lambda)
678-
Lambda_12_ = nx.inv(Lambda_12)
679-
M = nx.sqrtm(nx.einsum("ij, jk, kl -> il", Lambda_12, Sigma, Lambda_12))
680-
T = nx.einsum("ij, jk, kl -> il", Lambda_12_, M, Lambda_12_)
677+
Lambda12 = nx.sqrtm(Lambda)
678+
Lambda12inv = nx.inv(Lambda12)
679+
M = nx.sqrtm(nx.einsum("ij, jk, kl -> il", Lambda12, Sigma, Lambda12))
680+
T = nx.einsum("ij, jk, kl -> il", Lambda12inv, M, Lambda12inv)
681681

682682
# exp_\Lambda(log_\Lambda(Sigma)) = Sigma
683683
Sigma_exp = ot.utils.exp_bures(Lambda, T - nx.eye(d, type_as=T))

0 commit comments

Comments
 (0)