@@ -183,29 +183,15 @@ def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
     utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val])
 
 
-def test_convolve2d_fft():
-    data = matrix("data")
-    kernel = matrix("kernel")
-    out_fft = convolve2d(data, kernel, mode="same", method="fft")
-    out_direct = convolve2d(data, kernel, mode="same", method="direct")
-
-    rng = np.random.default_rng()
-    data_val = rng.normal(size=(7, 5)).astype(config.floatX)
-    kernel_val = rng.normal(size=(3, 2)).astype(config.floatX)
-
-    fn = function([data, kernel], [out_fft, out_direct])
-    fft_res, direct_res = fn(data_val, kernel_val)
-    np.testing.assert_allclose(fft_res, direct_res)
-
-
 @pytest.mark.parametrize("mode", ["full", "valid", "same"])
-def test_batched_1d_agrees_with_2d_row_filter(mode):
+@pytest.mark.parametrize("method", ["direct", "fft"])
+def test_batched_1d_agrees_with_2d_row_filter(mode, method):
     data = matrix("data")
     kernel_1d = vector("kernel_1d")
     kernel_2d = expand_dims(kernel_1d, 0)
 
     output_1d = convolve1d(data, kernel_1d, mode=mode)
-    output_2d = convolve2d(data, kernel_2d, mode=mode)
+    output_2d = convolve2d(data, kernel_2d, mode=mode, method=method)
 
     grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
     grad_2d = grad(output_2d.sum(), kernel_1d).ravel()
@@ -216,5 +202,13 @@ def test_batched_1d_agrees_with_2d_row_filter(mode):
     kernel_1d_val = np.random.normal(size=(3,)).astype(config.floatX)
 
     forward_1d, forward_2d, backward_1d, backward_2d = fn(data_val, kernel_1d_val)
-    np.testing.assert_allclose(forward_1d, forward_2d)
-    np.testing.assert_allclose(backward_1d, backward_2d)
+    np.testing.assert_allclose(
+        forward_1d,
+        forward_2d,
+        rtol=1e-5 if config.floatX == "float32" else 1e-13,
+    )
+    np.testing.assert_allclose(
+        backward_1d,
+        backward_2d,
+        rtol=1e-5 if config.floatX == "float32" else 1e-13,
+    )
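
The parametrized test above leans on two numerical facts: applying a 1-D kernel to every row of a matrix gives the same result as a 2-D convolution with that kernel reshaped to a single row, and FFT-based convolution agrees with the direct method only up to floating-point round-off, which is why the test loosens rtol under float32. A minimal sketch of both checks, using SciPy as an independent reference implementation; the variable names below are illustrative and are not part of the test module:

import numpy as np
from scipy.signal import convolve2d as sp_convolve2d
from scipy.signal import fftconvolve

rng = np.random.default_rng(0)
data = rng.normal(size=(5, 8))
kernel_1d = rng.normal(size=3)

# Row-wise 1-D convolution of each row of `data` with `kernel_1d`.
row_wise = np.stack([np.convolve(row, kernel_1d, mode="same") for row in data])

# The same operation written as a 2-D convolution with a one-row kernel.
as_2d = sp_convolve2d(data, kernel_1d[None, :], mode="same")
np.testing.assert_allclose(row_wise, as_2d, atol=1e-12)

# FFT-based convolution matches the direct result up to round-off.
np.testing.assert_allclose(
    fftconvolve(data, kernel_1d[None, :], mode="same"),
    as_2d,
    atol=1e-12,
)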