Commit 10d225d

Implement Convolve2D Op (#1397)
* Implement simple wrapper Op for `scipy.signal.convolve2d`
* Better shape inference
* conv2d gradient v0
* Implement Convolve2d and gradients
* Add same mode, and FFT support
* Relax test tolerance in float32
* Change seed
* Relax test tolerance in float32
* Add `convolve2d` to `signal.__init__`

---------

Co-authored-by: ricardoV94 <[email protected]>
1 parent 1d13f8c commit 10d225d

File tree

4 files changed (+249, -42 lines)
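Before the per-file diffs, a minimal usage sketch of the API this commit adds (illustrative only, not part of the commit; it assumes the usual `pytensor.function` compilation flow):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.signal import convolve2d  # export added by this commit

    # Symbolic 2D inputs; leading batch dimensions are handled via Blockwise.
    img = pt.tensor("img", shape=(None, None))
    kern = pt.tensor("kern", shape=(None, None))

    # mode/boundary/fillvalue/method mirror scipy.signal.convolve2d.
    out = convolve2d(img, kern, mode="same", boundary="fill", fillvalue=0.0, method="auto")
    fn = pytensor.function([img, kern], out)

    res = fn(np.ones((5, 5)), np.ones((3, 3)))
    print(res.shape)  # (5, 5): "same" mode keeps in1's shape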

pytensor/tensor/basic.py

Lines changed: 2 additions & 2 deletions

@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False):
     return fill(_model, ret)


-def zeros(shape, dtype=None):
+def zeros(shape, dtype=None) -> TensorVariable:
     """Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
     if not (
         isinstance(shape, np.ndarray | Sequence)
@@ -933,7 +933,7 @@ def zeros(shape, dtype=None):
     return alloc(np.array(0, dtype=dtype), *shape)


-def ones(shape, dtype=None):
+def ones(shape, dtype=None) -> TensorVariable:
     """Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
     if not (
         isinstance(shape, np.ndarray | Sequence)
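The only change above is the added return annotations; a small illustration (not from the commit) of what they buy for static type checkers:

    import pytensor.tensor as pt

    z = pt.zeros((2, 3))                # checkers now infer z: TensorVariable
    o = pt.ones((2, 3), dtype="int64")  # likewise for ones
    expr = (z + o).sum()                # TensorVariable methods resolve without casts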

pytensor/tensor/signal/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
-from pytensor.tensor.signal.conv import convolve1d
+from pytensor.tensor.signal.conv import convolve1d, convolve2d


-__all__ = ("convolve1d",)
+__all__ = ("convolve1d", "convolve2d")

pytensor/tensor/signal/conv.py

Lines changed: 168 additions & 36 deletions

@@ -1,72 +1,92 @@
-from typing import TYPE_CHECKING, Literal, cast
+from typing import TYPE_CHECKING, Literal
+from typing import cast as type_cast

 import numpy as np
 from numpy import convolve as numpy_convolve
+from scipy.signal import convolve as scipy_convolve

 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply, Constant
+from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
 from pytensor.scalar import as_scalar
 from pytensor.scalar.basic import upcast
 from pytensor.tensor.basic import as_tensor_variable, join, zeros
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import maximum, minimum, switch
-from pytensor.tensor.type import vector
+from pytensor.tensor.pad import pad
+from pytensor.tensor.subtensor import flip
+from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable


 if TYPE_CHECKING:
     from pytensor.tensor import TensorLike


-class Convolve1d(COp):
+class AbstractConvolveNd:
     __props__ = ()
-    gufunc_signature = "(n),(k),()->(o)"
+    ndim: int
+
+    @property
+    def gufunc_signature(self):
+        data_signature = ",".join([f"n{i}" for i in range(self.ndim)])
+        kernel_signature = ",".join([f"k{i}" for i in range(self.ndim)])
+        output_signature = ",".join([f"o{i}" for i in range(self.ndim)])
+
+        return f"({data_signature}),({kernel_signature}),()->({output_signature})"

     def make_node(self, in1, in2, full_mode):
         in1 = as_tensor_variable(in1)
         in2 = as_tensor_variable(in2)
         full_mode = as_scalar(full_mode)

-        if not (in1.ndim == 1 and in2.ndim == 1):
-            raise ValueError("Convolution inputs must be vector (ndim=1)")
+        ndim = self.ndim
+        if not (in1.ndim == ndim and in2.ndim == self.ndim):
+            raise ValueError(
+                f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}"
+            )
         if not full_mode.dtype == "bool":
-            raise ValueError("Convolution mode must be a boolean type")
+            raise ValueError("Convolution full_mode flag must be a boolean type")

-        dtype = upcast(in1.dtype, in2.dtype)
-        n = in1.type.shape[0]
-        k = in2.type.shape[0]
         match full_mode:
             case Constant():
                 static_mode = "full" if full_mode.data else "valid"
             case _:
                 static_mode = None

-        if n is None or k is None or static_mode is None:
-            out_shape = (None,)
-        elif static_mode == "full":
-            out_shape = (n + k - 1,)
-        else:  # mode == "valid":
-            out_shape = (max(n, k) - min(n, k) + 1,)
+        if static_mode is None:
+            out_shape = (None,) * ndim
+        else:
+            out_shape = []
+            # TODO: Raise if static shapes are not valid (one input size doesn't dominate the other)
+            for n, k in zip(in1.type.shape, in2.type.shape):
+                if n is None or k is None:
+                    out_shape.append(None)
+                elif static_mode == "full":
+                    out_shape.append(
+                        n + k - 1,
+                    )
+                else:  # mode == "valid":
+                    out_shape.append(
+                        max(n, k) - min(n, k) + 1,
+                    )
+            out_shape = tuple(out_shape)

-        out = vector(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2, full_mode], [out])
+        dtype = upcast(in1.dtype, in2.dtype)

-    def perform(self, node, inputs, outputs):
-        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
-        # And mode != "same", which this Op doesn't cover anyway.
-        in1, in2, full_mode = inputs
-        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [in1, in2, full_mode], [out])

     def infer_shape(self, fgraph, node, shapes):
         _, _, full_mode = node.inputs
         in1_shape, in2_shape, _ = shapes
-        n = in1_shape[0]
-        k = in2_shape[0]
-        shape_valid = maximum(n, k) - minimum(n, k) + 1
-        shape_full = n + k - 1
-        shape = switch(full_mode, shape_full, shape_valid)
-        return [[shape]]
+        out_shape = [
+            switch(full_mode, n + k - 1, maximum(n, k) - minimum(n, k) + 1)
+            for n, k in zip(in1_shape, in2_shape)
+        ]
+
+        return [out_shape]

     def connection_pattern(self, node):
         return [[True], [True], [False]]
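The gufunc_signature property above generalizes the previously hard-coded 1d signature "(n),(k),()->(o)"; a quick check (illustrative, not part of the diff) of what it now generates per ndim:

    from pytensor.tensor.signal.conv import Convolve1d, Convolve2d

    print(Convolve1d().gufunc_signature)  # (n0),(k0),()->(o0)
    print(Convolve2d().gufunc_signature)  # (n0,n1),(k0,k1),()->(o0,o1)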
@@ -75,22 +95,34 @@ def L_op(self, inputs, outputs, output_grads):
         in1, in2, full_mode = inputs
         [grad] = output_grads

-        n = in1.shape[0]
-        k = in2.shape[0]
+        n = in1.shape
+        k = in2.shape
+        # Note: this assumes the shape of one input dominates the other over all dimensions (which is required for a valid forward)

         # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
         # The expression below is equivalent to ~(full_mode | (k >= n))
-        full_mode_in1_bar = ~full_mode & (k < n)
+        full_mode_in1_bar = ~full_mode & (k < n).any()
         # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
         # The expression below is equivalent to ~(full_mode | (n >= k))
-        full_mode_in2_bar = ~full_mode & (n < k)
+        full_mode_in2_bar = ~full_mode & (n < k).any()

         return [
-            self(grad, in2[::-1], full_mode_in1_bar),
-            self(grad, in1[::-1], full_mode_in2_bar),
+            self(grad, flip(in2), full_mode_in1_bar),
+            self(grad, flip(in1), full_mode_in2_bar),
             DisconnectedType()(),
         ]

+
+class Convolve1d(AbstractConvolveNd, COp):  # type: ignore[misc]
+    __props__ = ()
+    ndim = 1
+
+    def perform(self, node, inputs, outputs):
+        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
+        # And mode != "same", which this Op doesn't cover anyway.
+        in1, in2, full_mode = inputs
+        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+
     def c_code_cache_version(self):
         return (2,)
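To make the gradient mode selection in L_op concrete: for a valid forward convolution of in1 with shape (8, 8) and in2 with shape (3, 3), the output is (6, 6). The gradient with respect to in1 must recover (8, 8), so the (6, 6) output gradient is convolved with the flipped kernel in full mode (6 + 3 - 1 = 8 per dimension), which is why full_mode_in1_bar is true when k < n. The gradient with respect to in2 must recover (3, 3), so the output gradient is convolved with the flipped in1 in valid mode (8 - 6 + 1 = 3 per dimension).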

@@ -210,4 +242,104 @@ def convolve1d(
         mode = "valid"

     full_mode = as_scalar(np.bool_(mode == "full"))
-    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
+    return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
+
+
+class Convolve2d(AbstractConvolveNd, Op):  # type: ignore[misc]
+    __props__ = ("method",)  # type: ignore[assignment]
+    ndim = 2
+
+    def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
+        self.method = method
+
+    def perform(self, node, inputs, outputs):
+        in1, in2, full_mode = inputs
+
+        # TODO: Why is .item() needed?
+        mode: Literal["full", "valid", "same"] = "full" if full_mode.item() else "valid"
+        outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)
+
+
+def convolve2d(
+    in1: "TensorLike",
+    in2: "TensorLike",
+    mode: Literal["full", "valid", "same"] = "full",
+    boundary: Literal["fill", "wrap", "symm"] = "fill",
+    fillvalue: float | int = 0,
+    method: Literal["direct", "fft", "auto"] = "auto",
+) -> TensorVariable:
+    """Convolve two two-dimensional arrays.
+
+    Convolve in1 and in2, with the output size determined by the mode argument.
+
+    Parameters
+    ----------
+    in1 : (..., N, M) tensor_like
+        First input.
+    in2 : (..., K, L) tensor_like
+        Second input.
+    mode : {'full', 'valid', 'same'}, optional
+        A string indicating the size of the output:
+        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
+        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
+        - 'same': The output is the same size as in1, centered with respect to the 'full' output.
+    boundary : {'fill', 'wrap', 'symm'}, optional
+        A string indicating how to handle boundaries:
+        - 'fill': Pads the input arrays with fillvalue.
+        - 'wrap': Circularly wraps the input arrays.
+        - 'symm': Symmetrically reflects the input arrays.
+    fillvalue : float or int, optional
+        The value to use for padding when boundary is 'fill'. Default is 0.
+    method : str, one of 'direct', 'fft', or 'auto'
+        Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based convolution,
+        and 'auto' lets the implementation choose the best method at runtime.
+
+    Returns
+    -------
+    out: tensor_variable
+        The discrete linear convolution of in1 with in2.
+
+    """
+    in1 = as_tensor_variable(in1)
+    in2 = as_tensor_variable(in2)
+    ndim = max(in1.type.ndim, in2.type.ndim)
+
+    def _pad_input(input_tensor, pad_width):
+        if boundary == "fill":
+            return pad(
+                input_tensor,
+                pad_width=pad_width,
+                mode="constant",
+                constant_values=fillvalue,
+            )
+        if boundary == "wrap":
+            return pad(input_tensor, pad_width=pad_width, mode="wrap")
+        if boundary == "symm":
+            return pad(input_tensor, pad_width=pad_width, mode="symmetric")
+        raise ValueError(f"Unsupported boundary mode: {boundary}")
+
+    if mode == "same":
+        # Same mode is implemented as "valid" with a padded input.
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, 0].set(in2.shape[-2] // 2)
+        pad_width = pad_width[-2, 1].set((in2.shape[-2] - 1) // 2)
+        pad_width = pad_width[-1, 0].set(in2.shape[-1] // 2)
+        pad_width = pad_width[-1, 1].set((in2.shape[-1] - 1) // 2)
+        in1 = _pad_input(in1, pad_width)
+        mode = "valid"
+
+    if mode != "valid" and (boundary != "fill" or fillvalue != 0):
+        # We use a valid convolution on an appropriately padded kernel
+        *_, k, l = in2.shape
+
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, :].set(k - 1)
+        pad_width = pad_width[-1, :].set(l - 1)
+        in1 = _pad_input(in1, pad_width)
+
+        mode = "valid"
+
+    full_mode = as_scalar(np.bool_(mode == "full"))
+    return type_cast(
+        TensorVariable, Blockwise(Convolve2d(method=method))(in1, in2, full_mode)
+    )
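A final sketch (not in the diff) of the static shape rules encoded in make_node and of gradients flowing through the new Op, assuming the standard pt.grad API:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.signal import convolve2d

    x = pt.tensor("x", shape=(8, 8))
    k = pt.tensor("k", shape=(3, 3))

    full = convolve2d(x, k, mode="full")    # static shape (8+3-1, 8+3-1)
    valid = convolve2d(x, k, mode="valid")  # static shape (8-3+1, 8-3+1)
    print(full.type.shape, valid.type.shape)  # (10, 10) (6, 6)

    # connection_pattern marks full_mode as disconnected, so gradients only
    # flow to the two data inputs:
    g_x, g_k = pt.grad(valid.sum(), [x, k])
    fn = pytensor.function([x, k], [g_x, g_k])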
