1- from typing import TYPE_CHECKING , Literal , cast
1+ from typing import TYPE_CHECKING , Literal
2+ from typing import cast as type_cast
23
34import numpy as np
45from numpy import convolve as numpy_convolve
1314from pytensor .tensor .basic import as_tensor_variable , join , zeros
1415from pytensor .tensor .blockwise import Blockwise
1516from pytensor .tensor .math import maximum , minimum , switch
16- from pytensor .tensor .type import matrix , vector
17+ from pytensor .tensor .pad import pad
18+ from pytensor .tensor .subtensor import flip
19+ from pytensor .tensor .type import tensor
1720from pytensor .tensor .variable import TensorVariable
1821
1922
2023if TYPE_CHECKING :
2124 from pytensor .tensor import TensorLike
2225
2326
24- class Convolve1d ( COp ) :
class AbstractConvolveNd:
    """Shared graph-level logic for N-dimensional convolution Ops.

    Subclasses set the class attribute ``ndim`` and implement ``perform``.
    The Op takes three inputs: two tensors with exactly ``ndim`` dimensions
    and a boolean scalar ``full_mode`` that selects "full" (True) or
    "valid" (False) convolution at runtime.
    """

    __props__ = ()
    ndim: int

    @property
    def gufunc_signature(self):
        """Core signature, e.g. ``"(n0,n1),(k0,k1),()->(o0,o1)"`` for ``ndim=2``."""

        def dims(prefix):
            return ",".join(f"{prefix}{i}" for i in range(self.ndim))

        return f"({dims('n')}),({dims('k')}),()->({dims('o')})"

    @staticmethod
    def _static_out_shape(in1_shape, in2_shape, static_mode):
        """Compute the static output shape for a statically known mode.

        Parameters
        ----------
        in1_shape, in2_shape
            Static shapes of the two inputs; entries may be ``None`` (unknown).
        static_mode
            Either ``"full"`` or ``"valid"``.

        Returns
        -------
        tuple
            One entry per dimension; ``None`` wherever either input length
            is unknown.

        Raises
        ------
        ValueError
            If the mode is "valid" and, judging by the known static lengths,
            neither input is at least as large as the other in every dimension.
        """
        pairs = list(zip(in1_shape, in2_shape))

        if static_mode == "valid":
            known = [(n, k) for n, k in pairs if n is not None and k is not None]
            in1_larger = any(n > k for n, k in known)
            in2_larger = any(n < k for n, k in known)
            if in1_larger and in2_larger:
                raise ValueError(
                    "For 'valid' mode, one input must be at least as large as the "
                    f"other in every dimension, got shapes {tuple(in1_shape)} "
                    f"and {tuple(in2_shape)}"
                )

        out_shape = []
        for n, k in pairs:
            if n is None or k is None:
                out_shape.append(None)
            elif static_mode == "full":
                out_shape.append(n + k - 1)
            else:  # static_mode == "valid"
                out_shape.append(max(n, k) - min(n, k) + 1)
        return tuple(out_shape)

    def make_node(self, in1, in2, full_mode):
        """Build the Apply node, validating ranks and the dtype of the mode flag.

        Raises
        ------
        ValueError
            If either input does not have ``self.ndim`` dimensions, if
            ``full_mode`` is not boolean, or if the mode is a constant "valid"
            and the static shapes show that neither input dominates the other.
        """
        in1 = as_tensor_variable(in1)
        in2 = as_tensor_variable(in2)
        full_mode = as_scalar(full_mode)

        ndim = self.ndim
        if not (in1.ndim == ndim and in2.ndim == ndim):
            raise ValueError(
                f"Convolution inputs must have ndim={ndim}, got: in1={in1.ndim}, in2={in2.ndim}"
            )
        if full_mode.dtype != "bool":
            raise ValueError("Convolution full_mode flag must be a boolean type")

        if isinstance(full_mode, Constant):
            # The mode is known at graph-construction time, so static output
            # lengths can be derived wherever both input lengths are known.
            static_mode = "full" if full_mode.data else "valid"
            out_shape = self._static_out_shape(
                in1.type.shape, in2.type.shape, static_mode
            )
        else:
            # Symbolic mode: the output lengths depend on a runtime value.
            out_shape = (None,) * ndim

        dtype = upcast(in1.dtype, in2.dtype)
        out = tensor(dtype=dtype, shape=out_shape)
        return Apply(self, [in1, in2, full_mode], [out])

    def infer_shape(self, fgraph, node, shapes):
        """Return the symbolic output shape, switching per dimension on ``full_mode``."""
        _, _, full_mode = node.inputs
        in1_shape, in2_shape, _ = shapes
        out_shape = [
            switch(full_mode, n + k - 1, maximum(n, k) - minimum(n, k) + 1)
            for n, k in zip(in1_shape, in2_shape)
        ]

        return [out_shape]

    def connection_pattern(self, node):
        """Mark the output as connected to both tensor inputs but not to ``full_mode``."""
        return [[True], [True], [False]]

    def L_op(self, inputs, outputs, output_grads):
        """Express the input gradients as convolutions of the output gradient.

        Each gradient convolves ``grad`` with the other input flipped along
        all axes; ``full_mode`` receives a disconnected gradient.
        """
        in1, in2, full_mode = inputs
        [grad] = output_grads

        n = in1.shape
        k = in2.shape
        # Note: this assumes the shape of one input dominates the other over all
        # dimensions (which is required for a valid forward).

        # in1_bar needs a "full" convolution only when the forward pass was "valid"
        # and in1 is strictly larger than in2 along at least one dimension.
        # Equivalent to ~(full_mode | (k >= n).all()).
        full_mode_in1_bar = ~full_mode & (k < n).any()
        # Symmetric condition for in2_bar: equivalent to ~(full_mode | (n >= k).all()).
        full_mode_in2_bar = ~full_mode & (n < k).any()

        return [
            self(grad, flip(in2), full_mode_in1_bar),
            self(grad, flip(in1), full_mode_in2_bar),
            DisconnectedType()(),
        ]
95114
115+
116+ class Convolve1d (AbstractConvolveNd , COp ):
117+ __props__ = ()
118+ ndim = 1
119+
120+ def perform (self , node , inputs , outputs ):
121+ # We use numpy_convolve as that's what scipy would use if method="direct" was passed.
122+ # And mode != "same", which this Op doesn't cover anyway.
123+ in1 , in2 , full_mode = inputs
124+ outputs [0 ][0 ] = numpy_convolve (in1 , in2 , mode = "full" if full_mode else "valid" )
125+
    def c_code_cache_version(self):
        # Returns the fixed version tuple (2,); presumably this keys the cache of
        # the Op's compiled C implementation and should be bumped whenever the
        # generated C code changes -- TODO confirm against the COp contract.
        return (2,)
98128
@@ -212,94 +242,29 @@ def convolve1d(
212242 mode = "valid"
213243
214244 full_mode = as_scalar (np .bool_ (mode == "full" ))
215- return cast (TensorVariable , blockwise_convolve_1d (in1 , in2 , full_mode ))
216-
217-
218- class Convolve2D (Op ):
219- __props__ = ("mode" , "boundary" , "fillvalue" )
220- gufunc_signature = "(n,m),(k,l)->(o,p)"
245+ return type_cast (TensorVariable , blockwise_convolve_1d (in1 , in2 , full_mode ))
221246
222- def __init__ (
223- self ,
224- mode : Literal ["full" , "valid" ] = "full" ,
225- boundary : Literal ["fill" , "wrap" , "symm" ] = "fill" ,
226- fillvalue : float | int = 0 ,
227- ):
228- if mode not in ("full" , "valid" ):
229- raise ValueError (f"Invalid mode: { mode } " )
230247
231- self .mode = mode
232- self .fillvalue = fillvalue
233- self .boundary = boundary
234-
235- def make_node (self , in1 , in2 ):
236- in1 , in2 = map (as_tensor_variable , (in1 , in2 ))
237-
238- assert in1 .ndim == 2
239- assert in2 .ndim == 2
240-
241- dtype = upcast (in1 .dtype , in2 .dtype )
242-
243- n , m = in1 .type .shape
244- k , l = in2 .type .shape
245-
246- if self .mode == "full" :
247- shape_1 = None if (n is None or k is None ) else n + k - 1
248- shape_2 = None if (m is None or l is None ) else m + l - 1
249-
250- elif self .mode == "valid" :
251- shape_1 = None if (n is None or k is None ) else max (n , k ) - max (n , k ) + 1
252- shape_2 = None if (m is None or l is None ) else max (m , l ) - min (m , l ) + 1
253-
254- else : # mode == "same"
255- shape_1 = n
256- shape_2 = m
257-
258- out_shape = (shape_1 , shape_2 )
259- out = matrix (dtype = dtype , shape = out_shape )
260- return Apply (self , [in1 , in2 ], [out ])
class Convolve2d(AbstractConvolveNd, Op):
    """2D convolution Op whose "full"/"valid" mode is a runtime boolean input.

    Graph construction, shape inference and gradients are inherited from
    ``AbstractConvolveNd``; this class only supplies the Python implementation.
    No ``boundary`` or ``fillvalue`` is forwarded to ``scipy_convolve2d``,
    so that function's defaults apply.
    """

    __props__ = ()
    ndim = 2

    def perform(self, node, inputs, outputs):
        """Convolve ``in1`` with ``in2`` and store the result in ``outputs[0][0]``."""
        in1, in2, full_mode = inputs

        # NOTE: an FFT-based path for float/complex inputs was considered as a
        # possible optimisation but is not implemented here.

        # The flag is reduced to a Python bool before picking the mode string.
        # Going through np.asarray keeps the behaviour for NumPy scalars and 0-d
        # arrays while also accepting a plain Python bool, which has no ``.item()``.
        use_full = np.asarray(full_mode).item()
        outputs[0][0] = scipy_convolve2d(
            in1,
            in2,
            mode="full" if use_full else "valid",
        )
272265
273- def infer_shape (self , fgraph , node , shapes ):
274- in1_shape , in2_shape = shapes
275- n , m = in1_shape
276- k , l = in2_shape
277-
278- if self .mode == "full" :
279- shape = (n + k - 1 , m + l - 1 )
280- elif self .mode == "valid" :
281- shape = (
282- maximum (n , k ) - minimum (n , k ) + 1 ,
283- maximum (m , l ) - minimum (m , l ) + 1 ,
284- )
285- else : # self.mode == 'same':
286- shape = (n , m )
287-
288- return [shape ]
289-
290- def L_op (self , inputs , outputs , output_grads ):
291- in1 , in2 = inputs
292- incoming_grads = output_grads [0 ]
293-
294- if self .mode == "full" :
295- prop_dict = self ._props_dict ()
296- prop_dict ["mode" ] = "valid"
297- conv_valid = type (self )(** prop_dict )
298-
299- in1_grad = conv_valid (in2 , incoming_grads )
300- in2_grad = conv_valid (in1 , incoming_grads )
301266
302- return [ in1_grad , in2_grad ]
267+ blockwise_convolve_2d = Blockwise ( Convolve2d ())
303268
304269
305270def convolve2d (
@@ -340,10 +305,28 @@ def convolve2d(
340305 in1 = as_tensor_variable (in1 )
341306 in2 = as_tensor_variable (in2 )
342307
343- # TODO: Handle boundaries symbolically
344- # TODO: Handle 'same' symbolically
308+ if mode == "same" :
309+ raise NotImplementedError ("same mode not implemented for convolve2d" )
310+
311+ if mode != "valid" and (boundary != "fill" or fillvalue != 0 ):
312+ # We use a valid convolution on an appropriately padded input (in1 is padded by the kernel size minus one on each side)
313+ * _ , k , l = in2 .shape
314+ ndim = max (in1 .type .ndim , in2 .type .ndim )
315+
316+ pad_width = zeros ((ndim , 2 ), dtype = "int64" )
317+ pad_width = pad_width [- 2 , :].set (k - 1 )
318+ pad_width = pad_width [- 1 , :].set (l - 1 )
319+ if boundary == "fill" :
320+ in1 = pad (
321+ in1 , pad_width = pad_width , mode = "constant" , constant_values = fillvalue
322+ )
323+ elif boundary == "wrap" :
324+ in1 = pad (in1 , pad_width = pad_width , mode = "wrap" )
325+
326+ elif boundary == "symm" :
327+ in1 = pad (in1 , pad_width = pad_width , mode = "symmetric" )
345328
346- blockwise_convolve = Blockwise (
347- Convolve2D ( mode = mode , boundary = boundary , fillvalue = fillvalue )
348- )
349- return cast (TensorVariable , blockwise_convolve (in1 , in2 ))
329+ mode = "valid"
330+
331+ full_mode = as_scalar ( np . bool_ ( mode == "full" ) )
332+ return type_cast (TensorVariable , blockwise_convolve_2d (in1 , in2 , full_mode ))
0 commit comments