Skip to content

Commit a412c9b

Browse files
Relax test tolerance in float32
1 parent ea268f7 commit a412c9b

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

tests/tensor/signal/test_conv.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -183,29 +183,15 @@ def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
183183
utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val])
184184

185185

186-
def test_convolve2d_fft():
187-
data = matrix("data")
188-
kernel = matrix("kernel")
189-
out_fft = convolve2d(data, kernel, mode="same", method="fft")
190-
out_direct = convolve2d(data, kernel, mode="same", method="direct")
191-
192-
rng = np.random.default_rng()
193-
data_val = rng.normal(size=(7, 5)).astype(config.floatX)
194-
kernel_val = rng.normal(size=(3, 2)).astype(config.floatX)
195-
196-
fn = function([data, kernel], [out_fft, out_direct])
197-
fft_res, direct_res = fn(data_val, kernel_val)
198-
np.testing.assert_allclose(fft_res, direct_res)
199-
200-
201186
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
202-
def test_batched_1d_agrees_with_2d_row_filter(mode):
187+
@pytest.mark.parametrize("method", ["direct", "fft"])
188+
def test_batched_1d_agrees_with_2d_row_filter(mode, method):
203189
data = matrix("data")
204190
kernel_1d = vector("kernel_1d")
205191
kernel_2d = expand_dims(kernel_1d, 0)
206192

207193
output_1d = convolve1d(data, kernel_1d, mode=mode)
208-
output_2d = convolve2d(data, kernel_2d, mode=mode)
194+
output_2d = convolve2d(data, kernel_2d, mode=mode, method=method)
209195

210196
grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
211197
grad_2d = grad(output_1d.sum(), kernel_1d).ravel()
@@ -216,5 +202,13 @@ def test_batched_1d_agrees_with_2d_row_filter(mode):
216202
kernel_1d_val = np.random.normal(size=(3,)).astype(config.floatX)
217203

218204
forward_1d, forward_2d, backward_1d, backward_2d = fn(data_val, kernel_1d_val)
219-
np.testing.assert_allclose(forward_1d, forward_2d)
220-
np.testing.assert_allclose(backward_1d, backward_2d)
205+
np.testing.assert_allclose(
206+
forward_1d,
207+
forward_2d,
208+
rtol=1e-5 if config.floatX == "float32" else 1e-13,
209+
)
210+
np.testing.assert_allclose(
211+
backward_1d,
212+
backward_2d,
213+
rtol=1e-5 if config.floatX == "float32" else 1e-13,
214+
)

0 commit comments

Comments
 (0)