Skip to content

Commit 0a9d499

Browse files
committed
up test bures_wasserstein_distance
1 parent f41b093 commit 0a9d499

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/test_gaussian.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,14 @@ def test_bures_wasserstein_distance_batch(nx):
148148
np.testing.assert_allclose(0, nx.to_numpy(Wb3[0]), atol=1e-5)
149149
np.testing.assert_allclose(0, nx.to_numpy(Wb3[1]), atol=1e-5)
150150

151-
m_rev = nx.zeros((k, 2))
152-
C_rev = nx.zeros((k, 2, 2))
151+
m_rev = np.zeros((k, 2))
152+
C_rev = np.zeros((k, 2, 2))
153153
m_rev[0] = m[1, 0]
154154
m_rev[1] = m[0, 0]
155155
C_rev[0] = C[1]
156156
C_rev[1] = C[0]
157+
m_rev = nx.from_numpy(m_rev)
158+
C_rev = nx.from_numpy(C_rev)
157159

158160
Wb3 = ot.gaussian.bures_wasserstein_distance(m_rev, m[:, 0], C_rev, C, paired=True)
159161
np.testing.assert_allclose(nx.to_numpy(Wb2)[0, 1], nx.to_numpy(Wb3)[0], atol=1e-5)

0 commit comments

Comments
 (0)