Commit 1e61660

Implement Convolve2d and gradients

1 parent c2819a8 · commit 1e61660

4 files changed: +143 additions, −156 deletions

pytensor/gradient.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1951,10 +1951,10 @@ def random_projection():
     mode_for_cost = mode
 
     cost_fn = fn_maker(tensor_pt, cost, name="gradient.py cost", mode=mode_for_cost)
-
     symbolic_grad = grad(cost, tensor_pt, disconnected_inputs="ignore")
 
     grad_fn = fn_maker(tensor_pt, symbolic_grad, name="gradient.py symbolic grad")
+    grad_fn.dprint(print_shape=True)
 
     for test_num in range(n_tests):
         num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps, out_type)
```
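Note: the added `grad_fn.dprint(print_shape=True)` prints the compiled gradient graph with each variable annotated by its inferred static shape. A minimal sketch of the same call on a toy function (the toy graph is illustrative, not this commit's output):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
cost = (x**2).sum()
grad_fn = pytensor.function([x], pytensor.grad(cost, x))

# Print the optimized graph; print_shape=True annotates each
# variable with the static shape PyTensor inferred for it.
grad_fn.dprint(print_shape=True)
```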

pytensor/tensor/basic.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -921,7 +921,7 @@ def zeros_like(model, dtype=None, opt=False):
     return fill(_model, ret)
 
 
-def zeros(shape, dtype=None):
+def zeros(shape, dtype=None) -> TensorVariable:
     """Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
     if not (
         isinstance(shape, np.ndarray | Sequence)
@@ -933,7 +933,7 @@ def zeros(shape, dtype=None):
     return alloc(np.array(0, dtype=dtype), *shape)
 
 
-def ones(shape, dtype=None):
+def ones(shape, dtype=None) -> TensorVariable:
     """Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
     if not (
         isinstance(shape, np.ndarray | Sequence)
```
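The new `-> TensorVariable` annotations document existing behavior: both helpers build their result with `alloc`, which returns a `TensorVariable`. A quick sketch (values are illustrative):

```python
import pytensor.tensor as pt
from pytensor.tensor.variable import TensorVariable

x = pt.zeros((3, 2))
y = pt.ones(5, dtype="int64")

assert isinstance(x, TensorVariable)
assert isinstance(y, TensorVariable)
# With constant shape arguments, the static shape should be fully known
assert x.type.shape == (3, 2)
```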

pytensor/tensor/signal/conv.py

Lines changed: 99 additions & 116 deletions
```diff
@@ -1,4 +1,5 @@
-from typing import TYPE_CHECKING, Literal, cast
+from typing import TYPE_CHECKING, Literal
+from typing import cast as type_cast
 
 import numpy as np
 from numpy import convolve as numpy_convolve
@@ -13,62 +14,79 @@
 from pytensor.tensor.basic import as_tensor_variable, join, zeros
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import maximum, minimum, switch
-from pytensor.tensor.type import matrix, vector
+from pytensor.tensor.pad import pad
+from pytensor.tensor.subtensor import flip
+from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
 
 
 if TYPE_CHECKING:
     from pytensor.tensor import TensorLike
 
 
-class Convolve1d(COp):
+class AbstractConvolveNd:
     __props__ = ()
-    gufunc_signature = "(n),(k),()->(o)"
+    ndim: int
+
+    @property
+    def gufunc_signature(self):
+        data_signature = ",".join([f"n{i}" for i in range(self.ndim)])
+        kernel_signature = ",".join([f"k{i}" for i in range(self.ndim)])
+        output_signature = ",".join([f"o{i}" for i in range(self.ndim)])
+
+        return f"({data_signature}),({kernel_signature}),()->({output_signature})"
 
     def make_node(self, in1, in2, full_mode):
         in1 = as_tensor_variable(in1)
         in2 = as_tensor_variable(in2)
         full_mode = as_scalar(full_mode)
 
-        if not (in1.ndim == 1 and in2.ndim == 1):
-            raise ValueError("Convolution inputs must be vector (ndim=1)")
+        ndim = self.ndim
+        if not (in1.ndim == ndim and in2.ndim == self.ndim):
+            raise ValueError(
+                f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}"
+            )
         if not full_mode.dtype == "bool":
-            raise ValueError("Convolution mode must be a boolean type")
+            raise ValueError("Convolution full_mode flag must be a boolean type")
 
-        dtype = upcast(in1.dtype, in2.dtype)
-        n = in1.type.shape[0]
-        k = in2.type.shape[0]
         match full_mode:
             case Constant():
                 static_mode = "full" if full_mode.data else "valid"
             case _:
                 static_mode = None
 
-        if n is None or k is None or static_mode is None:
-            out_shape = (None,)
-        elif static_mode == "full":
-            out_shape = (n + k - 1,)
-        else:  # mode == "valid":
-            out_shape = (max(n, k) - min(n, k) + 1,)
+        if static_mode is None:
+            out_shape = (None,) * ndim
+        else:
+            out_shape = []
+            # TODO: Raise if static shapes are not valid (one input size doesn't dominate the other)
+            for n, k in zip(in1.type.shape, in2.type.shape):
+                if n is None or k is None:
+                    out_shape.append(None)
+                elif static_mode == "full":
+                    out_shape.append(
+                        n + k - 1,
+                    )
+                else:  # mode == "valid":
+                    out_shape.append(
+                        max(n, k) - min(n, k) + 1,
+                    )
+            out_shape = tuple(out_shape)
 
-        out = vector(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2, full_mode], [out])
+        dtype = upcast(in1.dtype, in2.dtype)
 
-    def perform(self, node, inputs, outputs):
-        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
-        # And mode != "same", which this Op doesn't cover anyway.
-        in1, in2, full_mode = inputs
-        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [in1, in2, full_mode], [out])
 
     def infer_shape(self, fgraph, node, shapes):
         _, _, full_mode = node.inputs
         in1_shape, in2_shape, _ = shapes
-        n = in1_shape[0]
-        k = in2_shape[0]
-        shape_valid = maximum(n, k) - minimum(n, k) + 1
-        shape_full = n + k - 1
-        shape = switch(full_mode, shape_full, shape_valid)
-        return [[shape]]
+        out_shape = [
+            switch(full_mode, n + k - 1, maximum(n, k) - minimum(n, k) + 1)
+            for n, k in zip(in1_shape, in2_shape)
+        ]
+
+        return [out_shape]
 
     def connection_pattern(self, node):
         return [[True], [True], [False]]
```
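For reference, the new `gufunc_signature` property generalizes the old hard-coded `"(n),(k),()->(o)"` string to any rank: one size label per dimension of the data, kernel, and output, plus the scalar `full_mode` flag. A standalone sketch of the same construction:

```python
def gufunc_signature(ndim: int) -> str:
    # Mirrors AbstractConvolveNd.gufunc_signature for a given ndim
    data = ",".join(f"n{i}" for i in range(ndim))
    kernel = ",".join(f"k{i}" for i in range(ndim))
    out = ",".join(f"o{i}" for i in range(ndim))
    return f"({data}),({kernel}),()->({out})"

assert gufunc_signature(1) == "(n0),(k0),()->(o0)"
assert gufunc_signature(2) == "(n0,n1),(k0,k1),()->(o0,o1)"
```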
```diff
@@ -77,22 +95,34 @@ def L_op(self, inputs, outputs, output_grads):
         in1, in2, full_mode = inputs
         [grad] = output_grads
 
-        n = in1.shape[0]
-        k = in2.shape[0]
+        n = in1.shape
+        k = in2.shape
+        # Note: this assumes the shape of one input dominates the other over all dimensions (which is required for a valid forward)
 
         # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
         # The expression below is equivalent to ~(full_mode | (k >= n))
-        full_mode_in1_bar = ~full_mode & (k < n)
+        full_mode_in1_bar = ~full_mode & (k < n).any()
         # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
         # The expression below is equivalent to ~(full_mode | (n >= k))
-        full_mode_in2_bar = ~full_mode & (n < k)
+        full_mode_in2_bar = ~full_mode & (n < k).any()
 
         return [
-            self(grad, in2[::-1], full_mode_in1_bar),
-            self(grad, in1[::-1], full_mode_in2_bar),
+            self(grad, flip(in2), full_mode_in1_bar),
+            self(grad, flip(in1), full_mode_in2_bar),
             DisconnectedType()(),
         ]
 
+
+class Convolve1d(AbstractConvolveNd, COp):
+    __props__ = ()
+    ndim = 1
+
+    def perform(self, node, inputs, outputs):
+        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
+        # And mode != "same", which this Op doesn't cover anyway.
+        in1, in2, full_mode = inputs
+        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+
     def c_code_cache_version(self):
         return (2,)
```
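The shared `L_op` encodes the usual reverse-mode rule for convolution: each input's gradient is the output cotangent convolved with the flipped other input, with "full" and "valid" swapped as the inline comments describe. A NumPy check of the 1D full-mode case (sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 6, 3
w = rng.normal(size=k)
g = rng.normal(size=n + k - 1)  # cotangent for a full-mode output

# Rule from L_op: in1_bar = convolve(g, flip(w)) in "valid" mode
x_bar = np.convolve(g, w[::-1], mode="valid")  # length n, same as in1

# Convolution is linear, so d(sum(g * y)) / dx[i] = sum_j g[j] * w[j - i];
# check one entry against that closed form
i = 2
expected = sum(g[j] * w[j - i] for j in range(i, i + k))
assert np.isclose(x_bar[i], expected)
```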
```diff
@@ -212,94 +242,29 @@ def convolve1d(
         mode = "valid"
 
     full_mode = as_scalar(np.bool_(mode == "full"))
-    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
-
-
-class Convolve2D(Op):
-    __props__ = ("mode", "boundary", "fillvalue")
-    gufunc_signature = "(n,m),(k,l)->(o,p)"
+    return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
 
-    def __init__(
-        self,
-        mode: Literal["full", "valid"] = "full",
-        boundary: Literal["fill", "wrap", "symm"] = "fill",
-        fillvalue: float | int = 0,
-    ):
-        if mode not in ("full", "valid"):
-            raise ValueError(f"Invalid mode: {mode}")
 
-        self.mode = mode
-        self.fillvalue = fillvalue
-        self.boundary = boundary
-
-    def make_node(self, in1, in2):
-        in1, in2 = map(as_tensor_variable, (in1, in2))
-
-        assert in1.ndim == 2
-        assert in2.ndim == 2
-
-        dtype = upcast(in1.dtype, in2.dtype)
-
-        n, m = in1.type.shape
-        k, l = in2.type.shape
-
-        if self.mode == "full":
-            shape_1 = None if (n is None or k is None) else n + k - 1
-            shape_2 = None if (m is None or l is None) else m + l - 1
-
-        elif self.mode == "valid":
-            shape_1 = None if (n is None or k is None) else max(n, k) - max(n, k) + 1
-            shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1
-
-        else:  # mode == "same"
-            shape_1 = n
-            shape_2 = m
-
-        out_shape = (shape_1, shape_2)
-        out = matrix(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2], [out])
+class Convolve2d(AbstractConvolveNd, Op):
+    __props__ = ()
+    ndim = 2
 
     def perform(self, node, inputs, outputs):
-        in1, in2 = inputs
+        in1, in2, full_mode = inputs
 
         # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
         #     outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
         #
         # else:
+        # TODO: Why is .item() needed???
         outputs[0][0] = scipy_convolve2d(
-            in1, in2, mode=self.mode, fillvalue=self.fillvalue, boundary=self.boundary
+            in1,
+            in2,
+            mode="full" if full_mode.item() else "valid",
         )
 
-    def infer_shape(self, fgraph, node, shapes):
-        in1_shape, in2_shape = shapes
-        n, m = in1_shape
-        k, l = in2_shape
-
-        if self.mode == "full":
-            shape = (n + k - 1, m + l - 1)
-        elif self.mode == "valid":
-            shape = (
-                maximum(n, k) - minimum(n, k) + 1,
-                maximum(m, l) - minimum(m, l) + 1,
-            )
-        else:  # self.mode == 'same':
-            shape = (n, m)
-
-        return [shape]
-
-    def L_op(self, inputs, outputs, output_grads):
-        in1, in2 = inputs
-        incoming_grads = output_grads[0]
-
-        if self.mode == "full":
-            prop_dict = self._props_dict()
-            prop_dict["mode"] = "valid"
-            conv_valid = type(self)(**prop_dict)
-
-            in1_grad = conv_valid(in2, incoming_grads)
-            in2_grad = conv_valid(in1, incoming_grads)
 
-        return [in1_grad, in2_grad]
+blockwise_convolve_2d = Blockwise(Convolve2d())
 
 
 def convolve2d(
```
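The new `perform` defers everything except the full/valid switch to `scipy.signal.convolve2d`; the resulting output shapes agree with the static shapes computed in `AbstractConvolveNd.make_node`:

```python
import numpy as np
from scipy.signal import convolve2d as scipy_convolve2d

x = np.ones((4, 5))
w = np.ones((2, 3))

# full: n + k - 1 per axis; valid: max(n, k) - min(n, k) + 1 per axis
assert scipy_convolve2d(x, w, mode="full").shape == (5, 7)
assert scipy_convolve2d(x, w, mode="valid").shape == (3, 3)
```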
```diff
@@ -340,10 +305,28 @@ def convolve2d(
     in1 = as_tensor_variable(in1)
     in2 = as_tensor_variable(in2)
 
-    # TODO: Handle boundaries symbolically
-    # TODO: Handle 'same' symbolically
+    if mode == "same":
+        raise NotImplementedError("same mode not implemented for convolve2d")
+
+    if mode != "valid" and (boundary != "fill" or fillvalue != 0):
+        # We use a valid convolution on an appropriately padded input
+        *_, k, l = in2.shape
+        ndim = max(in1.type.ndim, in2.type.ndim)
+
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, :].set(k - 1)
+        pad_width = pad_width[-1, :].set(l - 1)
+        if boundary == "fill":
+            in1 = pad(
+                in1, pad_width=pad_width, mode="constant", constant_values=fillvalue
+            )
+        elif boundary == "wrap":
+            in1 = pad(in1, pad_width=pad_width, mode="wrap")
+
+        elif boundary == "symm":
+            in1 = pad(in1, pad_width=pad_width, mode="symmetric")
 
-    blockwise_convolve = Blockwise(
-        Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
-    )
-    return cast(TensorVariable, blockwise_convolve(in1, in2))
+        mode = "valid"
+
+    full_mode = as_scalar(np.bool_(mode == "full"))
+    return type_cast(TensorVariable, blockwise_convolve_2d(in1, in2, full_mode))
```
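The boundary handling above relies on a padding identity: a full convolution under any boundary extension equals a plain valid convolution of the input padded by `(k - 1, l - 1)` with the matching padding mode, because full mode never reads more than `k - 1` (resp. `l - 1`) cells past the edge. A NumPy/SciPy check of the `wrap` case (array sizes are illustrative):

```python
import numpy as np
from scipy.signal import convolve2d

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 4))
w = rng.normal(size=(3, 3))
k, l = w.shape

# Reference: scipy applies the circular boundary internally
ref = convolve2d(x, w, mode="full", boundary="wrap")

# Rewrite used by the symbolic convolve2d: pad first, then a "valid" convolution
x_pad = np.pad(x, ((k - 1, k - 1), (l - 1, l - 1)), mode="wrap")
np.testing.assert_allclose(convolve2d(x_pad, w, mode="valid"), ref)
```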
