33
44import numpy as np
55from numpy import convolve as numpy_convolve
6- from scipy .signal import convolve2d as scipy_convolve2d
6+ from scipy .signal import convolve as scipy_convolve
77
88from pytensor .gradient import DisconnectedType
99from pytensor .graph import Apply , Constant
@@ -113,7 +113,7 @@ def L_op(self, inputs, outputs, output_grads):
113113 ]
114114
115115
116- class Convolve1d (AbstractConvolveNd , COp ):
116+ class Convolve1d (AbstractConvolveNd , COp ): # type: ignore[misc]
117117 __props__ = ()
118118 ndim = 1
119119
@@ -245,26 +245,19 @@ def convolve1d(
245245 return type_cast (TensorVariable , blockwise_convolve_1d (in1 , in2 , full_mode ))
246246
247247
248- class Convolve2d (AbstractConvolveNd , Op ):
249- __props__ = ()
248+ class Convolve2d (AbstractConvolveNd , Op ): # type: ignore[misc]
249+ __props__ = ("method" ,) # type: ignore[assignment]
250250 ndim = 2
251251
252+ def __init__ (self , method : Literal ["direct" , "fft" , "auto" ] = "auto" ):
253+ self .method = method
254+
252255 def perform (self , node , inputs , outputs ):
253256 in1 , in2 , full_mode = inputs
254257
255- # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
256- # outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
257- #
258- # else:
259- # TODO: Why is .item() needed???
260- outputs [0 ][0 ] = scipy_convolve2d (
261- in1 ,
262- in2 ,
263- mode = "full" if full_mode .item () else "valid" ,
264- )
265-
266-
267- blockwise_convolve_2d = Blockwise (Convolve2d ())
258+ # TODO: Why is .item() needed?
259+ mode : Literal ["full" , "valid" , "same" ] = "full" if full_mode .item () else "valid"
260+ outputs [0 ][0 ] = scipy_convolve (in1 , in2 , mode = mode , method = self .method )
268261
269262
270263def convolve2d (
@@ -273,6 +266,7 @@ def convolve2d(
273266 mode : Literal ["full" , "valid" , "same" ] = "full" ,
274267 boundary : Literal ["fill" , "wrap" , "symm" ] = "fill" ,
275268 fillvalue : float | int = 0 ,
269+ method : Literal ["direct" , "fft" , "auto" ] = "auto" ,
276270) -> TensorVariable :
277271 """Convolve two two-dimensional arrays.
278272
@@ -296,6 +290,10 @@ def convolve2d(
296290 - 'symm': Symmetrically reflects the input arrays.
297291 fillvalue : float or int, optional
298292 The value to use for padding when boundary is 'fill'. Default is 0.
293+ method : str, one of 'direct', 'fft', or 'auto'
294+ Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based convolution,
295+ and 'auto' lets the implementation choose the best method at runtime.
296+
299297 Returns
300298 -------
301299 out: tensor_variable
@@ -304,29 +302,44 @@ def convolve2d(
304302 """
305303 in1 = as_tensor_variable (in1 )
306304 in2 = as_tensor_variable (in2 )
305+ ndim = max (in1 .type .ndim , in2 .type .ndim )
306+
307+ def _pad_input (input_tensor , pad_width ):
308+ if boundary == "fill" :
309+ return pad (
310+ input_tensor ,
311+ pad_width = pad_width ,
312+ mode = "constant" ,
313+ constant_values = fillvalue ,
314+ )
315+ if boundary == "wrap" :
316+ return pad (input_tensor , pad_width = pad_width , mode = "wrap" )
317+ if boundary == "symm" :
318+ return pad (input_tensor , pad_width = pad_width , mode = "symmetric" )
319+ raise ValueError (f"Unsupported boundary mode: { boundary } " )
307320
308321 if mode == "same" :
309- raise NotImplementedError ("same mode not implemented for convolve2d" )
322+ # Same mode is implemented as "valid" with a padded input.
323+ pad_width = zeros ((ndim , 2 ), dtype = "int64" )
324+ pad_width = pad_width [- 2 , 0 ].set (in2 .shape [- 2 ] // 2 )
325+ pad_width = pad_width [- 2 , 1 ].set ((in2 .shape [- 2 ] - 1 ) // 2 )
326+ pad_width = pad_width [- 1 , 0 ].set (in2 .shape [- 1 ] // 2 )
327+ pad_width = pad_width [- 1 , 1 ].set ((in2 .shape [- 1 ] - 1 ) // 2 )
328+ in1 = _pad_input (in1 , pad_width )
329+ mode = "valid"
310330
311331 if mode != "valid" and (boundary != "fill" or fillvalue != 0 ):
312332 # We use a valid convolution on an appropriately padded kernel
313333 * _ , k , l = in2 .shape
314- ndim = max (in1 .type .ndim , in2 .type .ndim )
315334
316335 pad_width = zeros ((ndim , 2 ), dtype = "int64" )
317336 pad_width = pad_width [- 2 , :].set (k - 1 )
318337 pad_width = pad_width [- 1 , :].set (l - 1 )
319- if boundary == "fill" :
320- in1 = pad (
321- in1 , pad_width = pad_width , mode = "constant" , constant_values = fillvalue
322- )
323- elif boundary == "wrap" :
324- in1 = pad (in1 , pad_width = pad_width , mode = "wrap" )
325-
326- elif boundary == "symm" :
327- in1 = pad (in1 , pad_width = pad_width , mode = "symmetric" )
338+ in1 = _pad_input (in1 , pad_width )
328339
329340 mode = "valid"
330341
331342 full_mode = as_scalar (np .bool_ (mode == "full" ))
332- return type_cast (TensorVariable , blockwise_convolve_2d (in1 , in2 , full_mode ))
343+ return type_cast (
344+ TensorVariable , Blockwise (Convolve2d (method = method ))(in1 , in2 , full_mode )
345+ )
0 commit comments