-from typing import TYPE_CHECKING, Literal, cast
+from typing import TYPE_CHECKING, Literal
+from typing import cast as type_cast
 
 import numpy as np
 from numpy import convolve as numpy_convolve
+from scipy.signal import convolve as scipy_convolve
 
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply, Constant
+from pytensor.graph.op import Op
 from pytensor.link.c.op import COp
 from pytensor.scalar import as_scalar
 from pytensor.scalar.basic import upcast
 from pytensor.tensor.basic import as_tensor_variable, join, zeros
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import maximum, minimum, switch
-from pytensor.tensor.type import vector
+from pytensor.tensor.pad import pad
+from pytensor.tensor.subtensor import flip
+from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable
 
 
 if TYPE_CHECKING:
     from pytensor.tensor import TensorLike
 
 
-class Convolve1d(COp):
+class AbstractConvolveNd:
     __props__ = ()
-    gufunc_signature = "(n),(k),()->(o)"
+    ndim: int
+
+    @property
+    def gufunc_signature(self):
+        data_signature = ",".join([f"n{i}" for i in range(self.ndim)])
+        kernel_signature = ",".join([f"k{i}" for i in range(self.ndim)])
+        output_signature = ",".join([f"o{i}" for i in range(self.ndim)])
+
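+        # e.g. ndim=1 -> "(n0),(k0),()->(o0)"; ndim=2 -> "(n0,n1),(k0,k1),()->(o0,o1)"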
+        return f"({data_signature}),({kernel_signature}),()->({output_signature})"
 
     def make_node(self, in1, in2, full_mode):
         in1 = as_tensor_variable(in1)
         in2 = as_tensor_variable(in2)
         full_mode = as_scalar(full_mode)
 
-        if not (in1.ndim == 1 and in2.ndim == 1):
-            raise ValueError("Convolution inputs must be vector (ndim=1)")
+        ndim = self.ndim
+        if not (in1.ndim == ndim and in2.ndim == ndim):
+            raise ValueError(
+                f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}"
+            )
         if not full_mode.dtype == "bool":
-            raise ValueError("Convolution mode must be a boolean type")
+            raise ValueError("Convolution full_mode flag must be a boolean type")
 
-        dtype = upcast(in1.dtype, in2.dtype)
-        n = in1.type.shape[0]
-        k = in2.type.shape[0]
         match full_mode:
             case Constant():
                 static_mode = "full" if full_mode.data else "valid"
             case _:
                 static_mode = None
 
-        if n is None or k is None or static_mode is None:
-            out_shape = (None,)
-        elif static_mode == "full":
-            out_shape = (n + k - 1,)
-        else:  # mode == "valid"
-            out_shape = (max(n, k) - min(n, k) + 1,)
+        if static_mode is None:
+            out_shape = (None,) * ndim
+        else:
+            out_shape = []
+            # TODO: Raise if static shapes are not valid (one input size doesn't dominate the other)
+            for n, k in zip(in1.type.shape, in2.type.shape):
+                if n is None or k is None:
+                    out_shape.append(None)
+                elif static_mode == "full":
+                    out_shape.append(n + k - 1)
+                else:  # static_mode == "valid"
+                    out_shape.append(max(n, k) - min(n, k) + 1)
+            out_shape = tuple(out_shape)
 
-        out = vector(dtype=dtype, shape=out_shape)
-        return Apply(self, [in1, in2, full_mode], [out])
+        dtype = upcast(in1.dtype, in2.dtype)
 
-    def perform(self, node, inputs, outputs):
-        # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
-        # And mode != "same", which this Op doesn't cover anyway.
-        in1, in2, full_mode = inputs
-        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+        out = tensor(dtype=dtype, shape=out_shape)
+        return Apply(self, [in1, in2, full_mode], [out])
 
     def infer_shape(self, fgraph, node, shapes):
         _, _, full_mode = node.inputs
         in1_shape, in2_shape, _ = shapes
-        n = in1_shape[0]
-        k = in2_shape[0]
-        shape_valid = maximum(n, k) - minimum(n, k) + 1
-        shape_full = n + k - 1
-        shape = switch(full_mode, shape_full, shape_valid)
-        return [[shape]]
+        out_shape = [
+            switch(full_mode, n + k - 1, maximum(n, k) - minimum(n, k) + 1)
+            for n, k in zip(in1_shape, in2_shape)
+        ]
+
+        return [out_shape]
 
     def connection_pattern(self, node):
         return [[True], [True], [False]]
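A minimal numpy sketch of the shape rules encoded above, applied per axis: "full" yields n + k - 1 and "valid" yields max(n, k) - min(n, k) + 1.

    import numpy as np

    a, b = np.ones(5), np.ones(3)
    assert np.convolve(a, b, mode="full").shape == (5 + 3 - 1,)   # n + k - 1
    assert np.convolve(a, b, mode="valid").shape == (5 - 3 + 1,)  # max - min + 1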
@@ -75,22 +95,34 @@ def L_op(self, inputs, outputs, output_grads):
         in1, in2, full_mode = inputs
         [grad] = output_grads
 
-        n = in1.shape[0]
-        k = in2.shape[0]
+        n = in1.shape
+        k = in2.shape
+        # Note: this assumes the shape of one input dominates the other over all dimensions (which is required for a valid forward)
 
         # If mode is "full", or mode is "valid" and k >= n, then in1_bar mode should use "valid" convolve
         # The expression below is equivalent to ~(full_mode | (k >= n))
-        full_mode_in1_bar = ~full_mode & (k < n)
+        full_mode_in1_bar = ~full_mode & (k < n).any()
         # If mode is "full", or mode is "valid" and n >= k, then in2_bar mode should use "valid" convolve
         # The expression below is equivalent to ~(full_mode | (n >= k))
-        full_mode_in2_bar = ~full_mode & (n < k)
+        full_mode_in2_bar = ~full_mode & (n < k).any()
 
         return [
-            self(grad, in2[::-1], full_mode_in1_bar),
-            self(grad, in1[::-1], full_mode_in2_bar),
+            self(grad, flip(in2), full_mode_in1_bar),
+            self(grad, flip(in1), full_mode_in2_bar),
             DisconnectedType()(),
         ]
 
+
+class Convolve1d(AbstractConvolveNd, COp):  # type: ignore[misc]
+    __props__ = ()
+    ndim = 1
+
+    def perform(self, node, inputs, outputs):
+        # We use numpy_convolve as that's what scipy would use if method="direct" was passed
+        # and mode != "same", which this Op doesn't cover anyway.
+        in1, in2, full_mode = inputs
+        outputs[0][0] = numpy_convolve(in1, in2, mode="full" if full_mode else "valid")
+
     def c_code_cache_version(self):
         return (2,)
 
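The flipped-kernel gradients above follow the standard identity: for y = convolve(x, k, "valid") with x the dominating input, the gradient with respect to x is the "full" convolution of the output cotangent with the flipped kernel. A minimal numpy sketch of the shape bookkeeping:

    import numpy as np

    x, k = np.random.randn(6), np.random.randn(3)
    g = np.ones(6 - 3 + 1)  # cotangent of the "valid" output
    # d/dx sum(convolve(x, k, "valid") * g) == convolve(g, k[::-1], "full")
    gx = np.convolve(g, k[::-1], mode="full")
    assert gx.shape == x.shape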
@@ -210,4 +242,104 @@ def convolve1d(
         mode = "valid"
 
     full_mode = as_scalar(np.bool_(mode == "full"))
-    return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
+    return type_cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
+
+
+class Convolve2d(AbstractConvolveNd, Op):  # type: ignore[misc]
+    __props__ = ("method",)  # type: ignore[assignment]
+    ndim = 2
+
+    def __init__(self, method: Literal["direct", "fft", "auto"] = "auto"):
+        self.method = method
+
+    def perform(self, node, inputs, outputs):
+        in1, in2, full_mode = inputs
+
+        # TODO: Why is .item() needed?
+        mode: Literal["full", "valid"] = "full" if full_mode.item() else "valid"
+        outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)
+
+
+def convolve2d(
+    in1: "TensorLike",
+    in2: "TensorLike",
+    mode: Literal["full", "valid", "same"] = "full",
+    boundary: Literal["fill", "wrap", "symm"] = "fill",
+    fillvalue: float | int = 0,
+    method: Literal["direct", "fft", "auto"] = "auto",
+) -> TensorVariable:
+    """Convolve two two-dimensional arrays.
+
+    Convolve in1 and in2, with the output size determined by the mode argument.
+
+    Parameters
+    ----------
+    in1 : (..., N, M) tensor_like
+        First input.
+    in2 : (..., K, L) tensor_like
+        Second input.
+    mode : {'full', 'valid', 'same'}, optional
+        A string indicating the size of the output:
+        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
+        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
+        - 'same': The output is the same size as in1, centered with respect to the 'full' output.
+    boundary : {'fill', 'wrap', 'symm'}, optional
+        A string indicating how to handle boundaries:
+        - 'fill': Pads the input arrays with fillvalue.
+        - 'wrap': Circularly wraps the input arrays.
+        - 'symm': Symmetrically reflects the input arrays.
+    fillvalue : float or int, optional
+        The value to use for padding when boundary is 'fill'. Default is 0.
+    method : {'direct', 'fft', 'auto'}, optional
+        Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based
+        convolution, and 'auto' lets the implementation choose the best method at runtime.
+
+    Returns
+    -------
+    out : tensor_variable
+        The discrete linear convolution of in1 with in2.
+
+    """
+    in1 = as_tensor_variable(in1)
+    in2 = as_tensor_variable(in2)
+    ndim = max(in1.type.ndim, in2.type.ndim)
+
+    def _pad_input(input_tensor, pad_width):
+        if boundary == "fill":
+            return pad(
+                input_tensor,
+                pad_width=pad_width,
+                mode="constant",
+                constant_values=fillvalue,
+            )
+        if boundary == "wrap":
+            return pad(input_tensor, pad_width=pad_width, mode="wrap")
+        if boundary == "symm":
+            return pad(input_tensor, pad_width=pad_width, mode="symmetric")
+        raise ValueError(f"Unsupported boundary mode: {boundary}")
+
+    if mode == "same":
+        # Same mode is implemented as "valid" with a padded input.
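+        # Padding an axis of size n by k // 2 on one side and (k - 1) // 2 on the other
+        # gives n + k - 1 samples, so the subsequent "valid" convolution returns exactly n.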
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, 0].set(in2.shape[-2] // 2)
+        pad_width = pad_width[-2, 1].set((in2.shape[-2] - 1) // 2)
+        pad_width = pad_width[-1, 0].set(in2.shape[-1] // 2)
+        pad_width = pad_width[-1, 1].set((in2.shape[-1] - 1) // 2)
+        in1 = _pad_input(in1, pad_width)
+        mode = "valid"
+
+    if mode != "valid" and (boundary != "fill" or fillvalue != 0):
+        # We use a valid convolution on an appropriately padded input
+        *_, k, l = in2.shape
+
+        pad_width = zeros((ndim, 2), dtype="int64")
+        pad_width = pad_width[-2, :].set(k - 1)
+        pad_width = pad_width[-1, :].set(l - 1)
+        in1 = _pad_input(in1, pad_width)
+
+        mode = "valid"
+
+    full_mode = as_scalar(np.bool_(mode == "full"))
+    return type_cast(
+        TensorVariable, Blockwise(Convolve2d(method=method))(in1, in2, full_mode)
+    )
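A minimal usage sketch of the new convolve2d (the import path is an assumption; adjust to wherever the function ends up being exposed):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.signal import convolve2d  # assumed export location

    x = pt.tensor("x", shape=(8, 8))
    k = pt.tensor("k", shape=(3, 3))
    out = convolve2d(x, k, mode="same", boundary="symm")
    fn = pytensor.function([x, k], out)
    assert fn(np.ones((8, 8)), np.ones((3, 3))).shape == (8, 8)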