Skip to content

Commit f41b093

Browse files
committed
up test bures_wasserstein_distance
1 parent 3a7effc commit f41b093

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/test_gaussian.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def test_bures_wasserstein_distance_batch(nx):
148148
np.testing.assert_allclose(0, nx.to_numpy(Wb3[0]), atol=1e-5)
149149
np.testing.assert_allclose(0, nx.to_numpy(Wb3[1]), atol=1e-5)
150150

151-
m_rev = np.zeros((k, 2))
152-
C_rev = np.zeros((k, 2, 2))
151+
m_rev = nx.zeros((k, 2))
152+
C_rev = nx.zeros((k, 2, 2))
153153
m_rev[0] = m[1, 0]
154154
m_rev[1] = m[0, 0]
155155
C_rev[0] = C[1]

0 commit comments

Comments
 (0)