Skip to content

Commit d04a1d0

Browse files
relax tests again
1 parent 1c0a2f4 commit d04a1d0

File tree

1 file changed

+19
-1
lines changed
  • pytensor/link/jax/dispatch/signal

1 file changed

+19
-1
lines changed

pytensor/link/jax/dispatch/signal/conv.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pytensor.link.jax.dispatch import jax_funcify
44
from pytensor.tensor.basic import get_underlying_scalar_constant_value
55
from pytensor.tensor.exceptions import NotScalarConstantError
6-
from pytensor.tensor.signal.conv import Convolve1d
6+
from pytensor.tensor.signal.conv import Convolve1d, Convolve2d
77

88

99
@jax_funcify.register(Convolve1d)
@@ -22,3 +22,21 @@ def conv1d(data, kernel, _runtime_full_mode):
2222
return jax.numpy.convolve(data, kernel, mode=static_mode)
2323

2424
return conv1d
25+
26+
27+
@jax_funcify.register(Convolve2d)
28+
def jax_funcify_Convolve2d(op, node, **kwargs):
29+
_, _, full_mode = node.inputs
30+
31+
try:
32+
full_mode = get_underlying_scalar_constant_value(full_mode)
33+
except NotScalarConstantError:
34+
raise NotImplementedError(
35+
"Cannot compile Convolve1D to jax without static mode"
36+
)
37+
static_mode = "full" if full_mode else "valid"
38+
39+
def conv2d(data, kernel, _runtime_full_mode):
40+
return jax.scipy.signal.convolve2d(data, kernel, mode=static_mode)
41+
42+
return conv2d

0 commit comments

Comments
 (0)