33
44import numpy as np
55from numpy import convolve as numpy_convolve
6- from scipy .signal import convolve2d as scipy_convolve2d
6+ from scipy .signal import convolve as scipy_convolve
77
88from pytensor .gradient import DisconnectedType
99from pytensor .graph import Apply , Constant
@@ -246,25 +246,92 @@ def convolve1d(
246246
247247
248248class Convolve2d (AbstractConvolveNd , Op ):
249- __props__ = ()
249+ __props__ = ("method" , )
250250 ndim = 2
251251
252+ def __init__ (self , method : Literal ["direct" , "fft" , "auto" ] = "auto" ):
253+ self .method = method
254+
252255 def perform (self , node , inputs , outputs ):
253256 in1 , in2 , full_mode = inputs
254257
255- # if all(inpt.dtype.kind in ['f', 'c'] for inpt in inputs):
256- # outputs[0][0] = scipy_convolve(in1, in2, mode=self.mode, method='fft')
257- #
258- # else:
259- # TODO: Why is .item() needed???
260- outputs [0 ][0 ] = scipy_convolve2d (
261- in1 ,
262- in2 ,
263- mode = "full" if full_mode .item () else "valid" ,
264- )
258+ # TODO: Why is .item() needed?
259+ mode : Literal ["full" , "valid" , "same" ] = "full" if full_mode .item () else "valid"
260+ outputs [0 ][0 ] = scipy_convolve (in1 , in2 , mode = mode , method = self .method )
261+
262+ def c_code_cache_version (self ):
263+ return (1 ,)
264+
265+ def c_code (self , node , name , inputs , outputs , sub ):
266+ in1 , in2 , full_mode = inputs
267+ [out ] = outputs
268+
269+ # For now, only the direct/correlation-based implementation is provided in C.
270+ # FFT-based convolution would require us to link to a vendored FFT library. Scipy uses
271+ # pypocketfft for that, but I'm not sure if we can easily call into that from here.
272+ code = f"""
273+ {{
274+ if (PyArray_NDIM({ in1 } ) != 2 || PyArray_NDIM({ in2 } ) != 2) {{
275+ PyErr_SetString(PyExc_ValueError, "Convolve2d C code expects 2D arrays.");
276+ { sub ["fail" ]} ;
277+ }}
265278
279+ npy_intp k0 = PyArray_DIM({ in2 } , 0);
280+ npy_intp k1 = PyArray_DIM({ in2 } , 1);
266281
267- blockwise_convolve_2d = Blockwise (Convolve2d ())
282+ if (k0 == 0 || k1 == 0) {{
283+ PyErr_SetString(PyExc_ValueError, "Convolve2d: second input (kernel) cannot be empty.");
284+ { sub ["fail" ]} ;
285+ }}
286+
287+ npy_intp dims[2] = {{k0, k1}};
288+ npy_intp strides[2];
289+ strides[0] = -PyArray_STRIDES({ in2 } )[0];
290+ strides[1] = -PyArray_STRIDES({ in2 } )[1];
291+
292+ char* data = (char*)PyArray_DATA({ in2 } )
293+ + (k0 - 1) * PyArray_STRIDES({ in2 } )[0]
294+ + (k1 - 1) * PyArray_STRIDES({ in2 } )[1];
295+
296+ PyArrayObject* in2_flipped_view = (PyArrayObject*)PyArray_NewFromDescr(
297+ Py_TYPE({ in2 } ),
298+ PyArray_DESCR({ in2 } ),
299+ 2,
300+ dims,
301+ strides,
302+ data,
303+ (PyArray_FLAGS({ in2 } ) & ~NPY_ARRAY_WRITEABLE),
304+ NULL
305+ );
306+
307+ if (!in2_flipped_view) {{
308+ PyErr_SetString(PyExc_RuntimeError, "Failed to create flipped kernel view for Convolve2d.");
309+ { sub ["fail" ]} ;
310+ }}
311+
312+ Py_INCREF({ in2 } );
313+ if (PyArray_SetBaseObject(in2_flipped_view, (PyObject*){ in2 } ) < 0) {{
314+ Py_DECREF({ in2 } );
315+ Py_DECREF(in2_flipped_view);
316+ in2_flipped_view = NULL;
317+ PyErr_SetString(PyExc_RuntimeError, "Failed to set base object for flipped kernel view in Convolve2d.");
318+ { sub ["fail" ]} ;
319+ }}
320+
321+ PyArray_UpdateFlags(in2_flipped_view, (NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS));
322+
323+ int mode_int = { full_mode } ? 2 : 0;
324+
325+ Py_XDECREF({ out } );
326+ { out } = (PyArrayObject*)PyArray_Correlate2((PyObject*){ in1 } , (PyObject*)in2_flipped_view, mode_int);
327+ Py_XDECREF(in2_flipped_view);
328+
329+ if (!{ out } ) {{
330+ { sub ["fail" ]} ;
331+ }}
332+ }}
333+ """
334+ return code
268335
269336
270337def convolve2d (
@@ -273,6 +340,7 @@ def convolve2d(
273340 mode : Literal ["full" , "valid" , "same" ] = "full" ,
274341 boundary : Literal ["fill" , "wrap" , "symm" ] = "fill" ,
275342 fillvalue : float | int = 0 ,
343+ method : Literal ["direct" , "fft" , "auto" ] = "auto" ,
276344) -> TensorVariable :
277345 """Convolve two two-dimensional arrays.
278346
@@ -296,6 +364,10 @@ def convolve2d(
296364 - 'symm': Symmetrically reflects the input arrays.
297365 fillvalue : float or int, optional
298366 The value to use for padding when boundary is 'fill'. Default is 0.
367+ method : str, one of 'direct', 'fft', or 'auto'
368+ Computation method to use. 'direct' uses direct convolution, 'fft' uses FFT-based convolution,
369+ and 'auto' lets the implementation choose the best method at runtime.
370+
299371 Returns
300372 -------
301373 out: tensor_variable
@@ -304,29 +376,44 @@ def convolve2d(
304376 """
305377 in1 = as_tensor_variable (in1 )
306378 in2 = as_tensor_variable (in2 )
379+ ndim = max (in1 .type .ndim , in2 .type .ndim )
380+
381+ def _pad_input (input_tensor , pad_width ):
382+ if boundary == "fill" :
383+ return pad (
384+ input_tensor ,
385+ pad_width = pad_width ,
386+ mode = "constant" ,
387+ constant_values = fillvalue ,
388+ )
389+ if boundary == "wrap" :
390+ return pad (input_tensor , pad_width = pad_width , mode = "wrap" )
391+ if boundary == "symm" :
392+ return pad (input_tensor , pad_width = pad_width , mode = "symmetric" )
393+ raise ValueError (f"Unsupported boundary mode: { boundary } " )
307394
308395 if mode == "same" :
309- raise NotImplementedError ("same mode not implemented for convolve2d" )
396+ # Same mode is implemented as "valid" with a padded input.
397+ pad_width = zeros ((ndim , 2 ), dtype = "int64" )
398+ pad_width = pad_width [- 2 , 0 ].set (in2 .shape [- 2 ] // 2 )
399+ pad_width = pad_width [- 2 , 1 ].set ((in2 .shape [- 2 ] - 1 ) // 2 )
400+ pad_width = pad_width [- 1 , 0 ].set (in2 .shape [- 1 ] // 2 )
401+ pad_width = pad_width [- 1 , 1 ].set ((in2 .shape [- 1 ] - 1 ) // 2 )
402+ in1 = _pad_input (in1 , pad_width )
403+ mode = "valid"
310404
311405 if mode != "valid" and (boundary != "fill" or fillvalue != 0 ):
312406 # We use a valid convolution on an appropriately padded kernel
313407 * _ , k , l = in2 .shape
314- ndim = max (in1 .type .ndim , in2 .type .ndim )
315408
316409 pad_width = zeros ((ndim , 2 ), dtype = "int64" )
317410 pad_width = pad_width [- 2 , :].set (k - 1 )
318411 pad_width = pad_width [- 1 , :].set (l - 1 )
319- if boundary == "fill" :
320- in1 = pad (
321- in1 , pad_width = pad_width , mode = "constant" , constant_values = fillvalue
322- )
323- elif boundary == "wrap" :
324- in1 = pad (in1 , pad_width = pad_width , mode = "wrap" )
325-
326- elif boundary == "symm" :
327- in1 = pad (in1 , pad_width = pad_width , mode = "symmetric" )
412+ in1 = _pad_input (in1 , pad_width )
328413
329414 mode = "valid"
330415
331416 full_mode = as_scalar (np .bool_ (mode == "full" ))
332- return type_cast (TensorVariable , blockwise_convolve_2d (in1 , in2 , full_mode ))
417+ return type_cast (
418+ TensorVariable , Blockwise (Convolve2d (method = method ))(in1 , in2 , full_mode )
419+ )
0 commit comments