Commit 6c580f3

Add same mode, and FFT support

1 parent 4f52e32 · commit 6c580f3

2 files changed: +64 −34 lines

pytensor/tensor/signal/conv.py

Lines changed: 42 additions & 29 deletions
@@ -3,7 +3,7 @@

 import numpy as np
 from numpy import convolve as numpy_convolve
-from scipy.signal import convolve2d as scipy_convolve2d
+from scipy.signal import convolve as scipy_convolve

 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply, Constant
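
The import swap above is what enables the new `method` argument: `scipy.signal.convolve` handles N-dimensional inputs and accepts `method="direct" | "fft" | "auto"`, while `scipy.signal.convolve2d` does not. For the default zero-fill boundary the two agree, as this standalone check (not part of the commit) illustrates:

```python
import numpy as np
from scipy.signal import convolve, convolve2d

rng = np.random.default_rng(0)
a = rng.normal(size=(6, 4))
b = rng.normal(size=(3, 3))

for mode in ("full", "valid", "same"):
    # With zero-fill boundaries, the N-d `convolve` matches the 2-D-only `convolve2d`.
    np.testing.assert_allclose(
        convolve(a, b, mode=mode, method="direct"),
        convolve2d(a, b, mode=mode, boundary="fill", fillvalue=0),
    )
```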
@@ -113,7 +113,7 @@ def L_op(self, inputs, outputs, output_grads):
         ]


-class Convolve1d(AbstractConvolveNd, COp):
+class Convolve1d(AbstractConvolveNd, COp):  # type: ignore[misc]
    __props__ = ()
    ndim = 1

@@ -245,26 +245,19 @@ def convolve1d(
     return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))


-class Convolve2d(AbstractConvolveNd, Op):
-    __props__ = ()
+class Convolve2d(AbstractConvolveNd, Op):  # type: ignore[misc]
+    __props__ = ("method",)  # type: ignore[assignment]
     ndim = 2

+    def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
+        self.method = method
+
     def perform(self, node, inputs, outputs):
         in1, in2, full_mode = inputs

-        # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
-        #     outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
-        #
-        # else:
-        # TODO: Why is .item() needed???
-        outputs[0][0] = scipy_convolve2d(
-            in1,
-            in2,
-            mode="full" if full_mode.item() else "valid",
-        )
-
-
-blockwise_convolve_2d = Blockwise(Convolve2d())
+        # TODO: Why is .item() needed?
+        mode: Literal["full", "valid", "same"] = "full" if full_mode.item() else "valid"
+        outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)


 def convolve2d(
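
For context (not part of the diff), the dispatch in the new `perform` reduces to the following standalone SciPy calls; `fft` and `direct` agree to floating-point tolerance, which is what makes `method` a free choice:

```python
import numpy as np
from scipy.signal import convolve as scipy_convolve

rng = np.random.default_rng(1)
in1 = rng.normal(size=(5, 5))
in2 = rng.normal(size=(2, 3))
full_mode = np.bool_(False)  # the scalar flag passed as the Op's third input

# Boolean flag -> mode string, then let SciPy run (or pick) the method.
mode = "full" if full_mode.item() else "valid"
direct = scipy_convolve(in1, in2, mode=mode, method="direct")
fft = scipy_convolve(in1, in2, mode=mode, method="fft")
np.testing.assert_allclose(direct, fft)
```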
@@ -273,6 +266,7 @@ def convolve2d(
     mode: Literal["full", "valid", "same"] = "full",
     boundary: Literal["fill", "wrap", "symm"] = "fill",
     fillvalue: float | int = 0,
+    method: Literal["direct", "fft", "auto"] = "auto",
 ) -> TensorVariable:
     """Convolve two two-dimensional arrays.

@@ -296,6 +290,10 @@
         - 'symm': Symmetrically reflects the input arrays.
     fillvalue : float or int, optional
         The value to use for padding when boundary is 'fill'. Default is 0.
+    method : str, one of 'direct', 'fft', or 'auto'
+        Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based convolution,
+        and 'auto' lets the implementation choose the best method at runtime.
+
     Returns
     -------
     out: tensor_variable
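
A sketch (not part of the commit) of how the new keyword is used from the graph-level API; the import path for `convolve2d` is the module edited in this diff:

```python
import numpy as np
from pytensor import config, function
from pytensor.tensor import matrix
from pytensor.tensor.signal.conv import convolve2d

x = matrix("x")
w = matrix("w")
# Force FFT-based convolution; method="auto" (the default) lets SciPy decide at runtime.
out = convolve2d(x, w, mode="same", method="fft")

f = function([x, w], out)
x_val = np.random.default_rng(2).normal(size=(7, 5)).astype(config.floatX)
w_val = np.random.default_rng(3).normal(size=(3, 2)).astype(config.floatX)
print(f(x_val, w_val).shape)  # (7, 5): "same" keeps the shape of the first input
```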
@@ -304,29 +302,44 @@
     """
     in1 = as_tensor_variable(in1)
     in2 = as_tensor_variable(in2)
+    ndim = max(in1.type.ndim, in2.type.ndim)
+
+    def _pad_input(input_tensor, pad_width):
+        if boundary == "fill":
+            return pad(
+                input_tensor,
+                pad_width=pad_width,
+                mode="constant",
+                constant_values=fillvalue,
+            )
+        if boundary == "wrap":
+            return pad(input_tensor, pad_width=pad_width, mode="wrap")
+        if boundary == "symm":
+            return pad(input_tensor, pad_width=pad_width, mode="symmetric")
+        raise ValueError(f"Unsupported boundary mode: {boundary}")

     if mode == "same":
-        raise NotImplementedError("same mode not implemented for convolve2d")
+        # Same mode is implemented as "valid" with a padded input.
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, 0].set(in2.shape[-2] // 2)
+        pad_width = pad_width[-2, 1].set((in2.shape[-2] - 1) // 2)
+        pad_width = pad_width[-1, 0].set(in2.shape[-1] // 2)
+        pad_width = pad_width[-1, 1].set((in2.shape[-1] - 1) // 2)
+        in1 = _pad_input(in1, pad_width)
+        mode = "valid"

     if mode != "valid" and (boundary != "fill" or fillvalue != 0):
         # We use a valid convolution on an appropriately padded kernel
         *_, k, l = in2.shape
-        ndim = max(in1.type.ndim, in2.type.ndim)

         pad_width = zeros((ndim, 2), dtype="int64")
         pad_width = pad_width[-2, :].set(k - 1)
         pad_width = pad_width[-1, :].set(l - 1)
-        if boundary == "fill":
-            in1 = pad(
-                in1, pad_width=pad_width, mode="constant", constant_values=fillvalue
-            )
-        elif boundary == "wrap":
-            in1 = pad(in1, pad_width=pad_width, mode="wrap")
-
-        elif boundary == "symm":
-            in1 = pad(in1, pad_width=pad_width, mode="symmetric")
+        in1 = _pad_input(in1, pad_width)

         mode = "valid"

     full_mode = as_scalar(np.bool_(mode == "full"))
-    return type_cast(TensorVariable, blockwise_convolve_2d(in1, in2, full_mode))
+    return type_cast(
+        TensorVariable, Blockwise(Convolve2d(method=method))(in1, in2, full_mode)
+    )
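
The `same` branch above pads the data by `k // 2` before and `(k - 1) // 2` after along each of the last two axes and then runs a `valid` convolution; the result has the shape of `in1` and matches SciPy's centered `same` output. A plain NumPy/SciPy restatement of that identity (not part of the commit, zero-fill boundary assumed):

```python
import numpy as np
from scipy.signal import convolve2d

rng = np.random.default_rng(4)
in1 = rng.normal(size=(7, 5))
in2 = rng.normal(size=(3, 2))

# Pad widths used by the "same" branch: k // 2 before and (k - 1) // 2 after,
# where k is the kernel extent along that axis.
k, l = in2.shape
padded = np.pad(in1, ((k // 2, (k - 1) // 2), (l // 2, (l - 1) // 2)))

# "valid" on the padded input has in1's shape and equals SciPy's "same" output.
same_via_valid = convolve2d(padded, in2, mode="valid")
np.testing.assert_allclose(same_via_valid, convolve2d(in1, in2, mode="same"))
```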

tests/tensor/signal/test_conv.py

Lines changed: 22 additions & 5 deletions
@@ -102,12 +102,13 @@ def test_convolve1d_valid_grad(static_shape):
             "local_useless_unbatched_blockwise",
         ),
     )
-    grad_out.dprint()
+
     [conv_node] = [
         node
        for node in io_toposort([larger, smaller], [grad_out])
        if isinstance(node.op, Convolve1d)
     ]
+
     full_mode = conv_node.inputs[-1]
     # If shape is static we get constant mode == "valid", otherwise it depends on the input shapes
     # ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean
@@ -148,7 +149,7 @@ def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
 @pytest.mark.parametrize(
     "data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
 )
-@pytest.mark.parametrize("mode", ["full", "valid", "same"][:-1])
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
 @pytest.mark.parametrize(
     "boundary, boundary_kwargs",
     [
@@ -181,13 +182,29 @@ def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
     utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val])


-def test_batched_1d_agrees_with_diagonal_2d():
+def test_convolve2d_fft():
+    data = matrix("data")
+    kernel = matrix("kernel")
+    out_fft = convolve2d(data, kernel, mode="same", method="fft")
+    out_direct = convolve2d(data, kernel, mode="same", method="direct")
+
+    rng = np.random.default_rng()
+    data_val = rng.normal(size=(7, 5)).astype(config.floatX)
+    kernel_val = rng.normal(size=(3, 2)).astype(config.floatX)
+
+    fn = function([data, kernel], [out_fft, out_direct])
+    fft_res, direct_res = fn(data_val, kernel_val)
+    np.testing.assert_allclose(fft_res, direct_res)
+
+
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
+def test_batched_1d_agrees_with_2d_row_filter(mode):
     data = matrix("data")
     kernel_1d = vector("kernel_1d")
     kernel_2d = expand_dims(kernel_1d, 0)

-    output_1d = convolve1d(data, kernel_1d, mode="valid")
-    output_2d = convolve2d(data, kernel_2d, mode="valid")
+    output_1d = convolve1d(data, kernel_1d, mode=mode)
+    output_2d = convolve2d(data, kernel_2d, mode=mode)

     grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
     grad_2d = grad(output_2d.sum(), kernel_1d).ravel()
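
The renamed test asserts that a batched 1-D convolution agrees with a 2-D convolution whose kernel is a single row. The property it relies on, restated with plain NumPy/SciPy (not part of the commit):

```python
import numpy as np
from scipy.signal import convolve2d

rng = np.random.default_rng(5)
data = rng.normal(size=(4, 9))
kernel_1d = rng.normal(size=3)
kernel_2d = kernel_1d[None, :]  # shape (1, 3): a pure row filter

for mode in ("full", "valid", "same"):
    # Convolving each row with the 1-D kernel matches the 2-D convolution row for row.
    rowwise = np.stack([np.convolve(row, kernel_1d, mode=mode) for row in data])
    np.testing.assert_allclose(rowwise, convolve2d(data, kernel_2d, mode=mode))
```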
