
Commit ea87e03

Add same mode, C-code, and FFT support
1 parent 1e61660 commit ea87e03
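For context, a minimal usage sketch of the API surface this commit extends (illustrative, not part of the commit; the import path follows the file touched below, and the exact public re-export may differ):

import numpy as np

import pytensor
from pytensor.tensor import matrix
from pytensor.tensor.signal.conv import convolve2d

data = matrix("data")
kernel = matrix("kernel")

# New in this commit: mode="same" is now supported, and method= selects
# "direct", "fft", or "auto" for the underlying computation.
out = convolve2d(data, kernel, mode="same", method="fft")
fn = pytensor.function([data, kernel], out)
result = fn(np.ones((7, 5), dtype=data.dtype), np.ones((3, 3), dtype=kernel.dtype))
# `result` has the same shape as `data`, i.e. (7, 5).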

2 files changed (+134 / -30 lines)

pytensor/tensor/signal/conv.py

Lines changed: 112 additions & 25 deletions
@@ -3,7 +3,7 @@
 
 import numpy as np
 from numpy import convolve as numpy_convolve
-from scipy.signal import convolve2d as scipy_convolve2d
+from scipy.signal import convolve as scipy_convolve
 
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply, Constant
@@ -246,25 +246,92 @@ def convolve1d(
 
 
 class Convolve2d(AbstractConvolveNd, Op):
-    __props__ = ()
+    __props__ = ("method",)
     ndim = 2
 
+    def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
+        self.method = method
+
     def perform(self, node, inputs, outputs):
         in1, in2, full_mode = inputs
 
-        # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
-        #     outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
-        #
-        # else:
-        # TODO: Why is .item() needed???
-        outputs[0][0] = scipy_convolve2d(
-            in1,
-            in2,
-            mode="full" if full_mode.item() else "valid",
-        )
+        # TODO: Why is .item() needed?
+        mode: Literal["full", "valid", "same"] = "full" if full_mode.item() else "valid"
+        outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)
+
+    def c_code_cache_version(self):
+        return (1,)
+
+    def c_code(self, node, name, inputs, outputs, sub):
+        in1, in2, full_mode = inputs
+        [out] = outputs
+
+        # For now, only the direct/correlation-based implementation is provided in C.
+        # FFT-based convolution would require us to link to a vendored FFT library. Scipy uses
+        # pypocketfft for that, but I'm not sure if we can easily call into that from here.
+        code = f"""
+        {{
+            if (PyArray_NDIM({in1}) != 2 || PyArray_NDIM({in2}) != 2) {{
+                PyErr_SetString(PyExc_ValueError, "Convolve2d C code expects 2D arrays.");
+                {sub["fail"]};
+            }}
 
+            npy_intp k0 = PyArray_DIM({in2}, 0);
+            npy_intp k1 = PyArray_DIM({in2}, 1);
 
-blockwise_convolve_2d = Blockwise(Convolve2d())
+            if (k0 == 0 || k1 == 0) {{
+                PyErr_SetString(PyExc_ValueError, "Convolve2d: second input (kernel) cannot be empty.");
+                {sub["fail"]};
+            }}
+
+            npy_intp dims[2] = {{k0, k1}};
+            npy_intp strides[2];
+            strides[0] = -PyArray_STRIDES({in2})[0];
+            strides[1] = -PyArray_STRIDES({in2})[1];
+
+            char* data = (char*)PyArray_DATA({in2})
+                + (k0 - 1) * PyArray_STRIDES({in2})[0]
+                + (k1 - 1) * PyArray_STRIDES({in2})[1];
+
+            PyArrayObject* in2_flipped_view = (PyArrayObject*)PyArray_NewFromDescr(
+                Py_TYPE({in2}),
+                PyArray_DESCR({in2}),
+                2,
+                dims,
+                strides,
+                data,
+                (PyArray_FLAGS({in2}) & ~NPY_ARRAY_WRITEABLE),
+                NULL
+            );
+
+            if (!in2_flipped_view) {{
+                PyErr_SetString(PyExc_RuntimeError, "Failed to create flipped kernel view for Convolve2d.");
+                {sub["fail"]};
+            }}
+
+            Py_INCREF({in2});
+            if (PyArray_SetBaseObject(in2_flipped_view, (PyObject*){in2}) < 0) {{
+                Py_DECREF({in2});
+                Py_DECREF(in2_flipped_view);
+                in2_flipped_view = NULL;
+                PyErr_SetString(PyExc_RuntimeError, "Failed to set base object for flipped kernel view in Convolve2d.");
+                {sub["fail"]};
+            }}
+
+            PyArray_UpdateFlags(in2_flipped_view, (NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS));
+
+            int mode_int = {full_mode} ? 2 : 0;
+
+            Py_XDECREF({out});
+            {out} = (PyArrayObject*)PyArray_Correlate2((PyObject*){in1}, (PyObject*)in2_flipped_view, mode_int);
+            Py_XDECREF(in2_flipped_view);
+
+            if (!{out}) {{
+                {sub["fail"]};
+            }}
+        }}
+        """
+        return code
 
 
 def convolve2d(
@@ -273,6 +340,7 @@ def convolve2d(
     mode: Literal["full", "valid", "same"] = "full",
     boundary: Literal["fill", "wrap", "symm"] = "fill",
     fillvalue: float | int = 0,
+    method: Literal["direct", "fft", "auto"] = "auto",
 ) -> TensorVariable:
     """Convolve two two-dimensional arrays.
 
@@ -296,6 +364,10 @@
         - 'symm': Symmetrically reflects the input arrays.
     fillvalue : float or int, optional
         The value to use for padding when boundary is 'fill'. Default is 0.
+    method : str, one of 'direct', 'fft', or 'auto'
+        Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based convolution,
+        and 'auto' lets the implementation choose the best method at runtime.
+
     Returns
     -------
     out: tensor_variable
@@ -304,29 +376,44 @@
     """
     in1 = as_tensor_variable(in1)
     in2 = as_tensor_variable(in2)
+    ndim = max(in1.type.ndim, in2.type.ndim)
+
+    def _pad_input(input_tensor, pad_width):
+        if boundary == "fill":
+            return pad(
+                input_tensor,
+                pad_width=pad_width,
+                mode="constant",
+                constant_values=fillvalue,
+            )
+        if boundary == "wrap":
+            return pad(input_tensor, pad_width=pad_width, mode="wrap")
+        if boundary == "symm":
+            return pad(input_tensor, pad_width=pad_width, mode="symmetric")
+        raise ValueError(f"Unsupported boundary mode: {boundary}")
 
     if mode == "same":
-        raise NotImplementedError("same mode not implemented for convolve2d")
+        # Same mode is implemented as "valid" with a padded input.
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, 0].set(in2.shape[-2] // 2)
+        pad_width = pad_width[-2, 1].set((in2.shape[-2] - 1) // 2)
+        pad_width = pad_width[-1, 0].set(in2.shape[-1] // 2)
+        pad_width = pad_width[-1, 1].set((in2.shape[-1] - 1) // 2)
+        in1 = _pad_input(in1, pad_width)
+        mode = "valid"
 
     if mode != "valid" and (boundary != "fill" or fillvalue != 0):
         # We use a valid convolution on an appropriately padded kernel
         *_, k, l = in2.shape
-        ndim = max(in1.type.ndim, in2.type.ndim)
 
         pad_width = zeros((ndim, 2), dtype="int64")
         pad_width = pad_width[-2, :].set(k - 1)
         pad_width = pad_width[-1, :].set(l - 1)
-        if boundary == "fill":
-            in1 = pad(
-                in1, pad_width=pad_width, mode="constant", constant_values=fillvalue
-            )
-        elif boundary == "wrap":
-            in1 = pad(in1, pad_width=pad_width, mode="wrap")
-
-        elif boundary == "symm":
-            in1 = pad(in1, pad_width=pad_width, mode="symmetric")
+        in1 = _pad_input(in1, pad_width)
 
         mode = "valid"
 
     full_mode = as_scalar(np.bool_(mode == "full"))
-    return type_cast(TensorVariable, blockwise_convolve_2d(in1, in2, full_mode))
+    return type_cast(
+        TensorVariable, Blockwise(Convolve2d(method=method))(in1, in2, full_mode)
+    )

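The new mode="same" branch in convolve2d above reduces "same" to a "valid" convolution on a padded input: the last two axes are padded by kernel_size // 2 before and (kernel_size - 1) // 2 after. A quick standalone check of that identity with scipy (illustrative only, not part of the commit):

import numpy as np
from scipy.signal import convolve

rng = np.random.default_rng(0)
data = rng.normal(size=(7, 5))
kernel = rng.normal(size=(3, 2))

# Pad each axis by (k // 2) before and ((k - 1) // 2) after, then a "valid"
# convolution reproduces scipy's "same" output (zero / "fill" boundary).
pad_width = [(k // 2, (k - 1) // 2) for k in kernel.shape]
same_via_valid = convolve(np.pad(data, pad_width), kernel, mode="valid")

np.testing.assert_allclose(same_via_valid, convolve(data, kernel, mode="same"))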
tests/tensor/signal/test_conv.py

Lines changed: 22 additions & 5 deletions
@@ -102,12 +102,13 @@ def test_convolve1d_valid_grad(static_shape):
             "local_useless_unbatched_blockwise",
         ),
     )
-    grad_out.dprint()
+
     [conv_node] = [
         node
        for node in io_toposort([larger, smaller], [grad_out])
        if isinstance(node.op, Convolve1d)
     ]
+
     full_mode = conv_node.inputs[-1]
     # If shape is static we get constant mode == "valid", otherwise it depends on the input shapes
     # ignoring E712 because np.True_ and np.False_ need to be compared with `==` to produce a valid boolean
@@ -148,7 +149,7 @@ def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
 @pytest.mark.parametrize(
     "data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
 )
-@pytest.mark.parametrize("mode", ["full", "valid", "same"][:-1])
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
 @pytest.mark.parametrize(
     "boundary, boundary_kwargs",
     [
@@ -181,13 +182,29 @@ def test_convolve2d(kernel_shape, data_shape, mode, boundary, boundary_kwargs):
     utt.verify_grad(lambda k: op(data_val, k).sum(), [kernel_val])
 
 
-def test_batched_1d_agrees_with_diagonal_2d():
+def test_convolve2d_fft():
+    data = matrix("data")
+    kernel = matrix("kernel")
+    out_fft = convolve2d(data, kernel, mode="same", method="fft")
+    out_direct = convolve2d(data, kernel, mode="same", method="direct")
+
+    rng = np.random.default_rng()
+    data_val = rng.normal(size=(7, 5)).astype(config.floatX)
+    kernel_val = rng.normal(size=(3, 2)).astype(config.floatX)
+
+    fn = function([data, kernel], [out_fft, out_direct])
+    fft_res, direct_res = fn(data_val, kernel_val)
+    np.testing.assert_allclose(fft_res, direct_res)
+
+
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
+def test_batched_1d_agrees_with_2d_row_filter(mode):
     data = matrix("data")
     kernel_1d = vector("kernel_1d")
     kernel_2d = expand_dims(kernel_1d, 0)
 
-    output_1d = convolve1d(data, kernel_1d, mode="valid")
-    output_2d = convolve2d(data, kernel_2d, mode="valid")
+    output_1d = convolve1d(data, kernel_1d, mode=mode)
+    output_2d = convolve2d(data, kernel_2d, mode=mode)
 
     grad_1d = grad(output_1d.sum(), kernel_1d).ravel()
     grad_2d = grad(output_1d.sum(), kernel_1d).ravel()
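The C implementation added in conv.py builds a negative-stride view that flips the kernel along both axes and feeds it to a correlation routine. The identity it relies on, that convolution equals correlation with a flipped kernel, can be checked independently with scipy (illustrative only, not part of the commit):

import numpy as np
from scipy.signal import convolve, correlate

rng = np.random.default_rng(0)
data = rng.normal(size=(6, 4))
kernel = rng.normal(size=(3, 3))

# Flipping the kernel along every axis turns correlation into convolution,
# which is what the flipped negative-stride kernel view achieves without a copy.
np.testing.assert_allclose(
    convolve(data, kernel, mode="full"),
    correlate(data, kernel[::-1, ::-1], mode="full"),
)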

0 commit comments
