Skip to content

Commit b853e6a

Browse files
authored
[MRG] Change the number of projection to match the predefined case (#419)
1 parent 0411ea2 commit b853e6a

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

ot/sliced.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2,
147147

148148
if projections is None:
149149
projections = get_random_projections(d, n_projections, seed, backend=nx, type_as=X_s)
150+
else:
151+
n_projections = projections.shape[1]
150152

151153
X_s_projections = nx.dot(X_s, projections)
152154
X_t_projections = nx.dot(X_t, projections)

test/test_sliced.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,20 @@ def test_max_sliced_different_dists():
110110
assert res > 0.
111111

112112

113+
def test_sliced_same_proj():
114+
n_projections = 10
115+
seed = 12
116+
rng = np.random.RandomState(0)
117+
X = rng.randn(8, 2)
118+
Y = rng.randn(8, 2)
119+
cost1, log1 = ot.sliced_wasserstein_distance(X, Y, seed=seed, n_projections=n_projections, log=True)
120+
P = get_random_projections(X.shape[1], n_projections=10, seed=seed)
121+
cost2, log2 = ot.sliced_wasserstein_distance(X, Y, projections=P, log=True)
122+
123+
assert np.allclose(log1['projections'], log2['projections'])
124+
assert np.isclose(cost1, cost2)
125+
126+
113127
def test_sliced_backend(nx):
114128

115129
n = 100

0 commit comments

Comments
 (0)