Skip to content

Commit 65815f5

Browse files
committed
Implement proper type.filter
1 parent 7b8877b commit 65815f5

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

pytensor/xtensor/type.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def __init__(
6565
)
6666
self.ndim = len(self.dims)
6767
self.name = name
68+
self.numpy_dtype = np.dtype(self.dtype)
69+
self.filter_checks_isfinite = False
6870

6971
def clone(
7072
self,
@@ -82,8 +84,9 @@ def clone(
8284
return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs)
8385

8486
def filter(self, value, strict=False, allow_downcast=None):
85-
# TODO implement this
86-
return value
87+
return TensorType.filter(
88+
self, value, strict=strict, allow_downcast=allow_downcast
89+
)
8790

8891
def convert_variable(self, var):
8992
# TODO: Implement this
@@ -689,17 +692,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
689692
if isinstance(x.type, XTensorType):
690693
return x
691694
if isinstance(x.type, TensorType):
692-
if x.type.ndim > 0 and dims is None:
693-
raise TypeError(
694-
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
695-
)
696-
return px.basic.xtensor_from_tensor(x, dims)
695+
if dims is None:
696+
if x.type.ndim == 0:
697+
dims = ()
698+
else:
699+
raise TypeError(
700+
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
701+
)
702+
return px.basic.xtensor_from_tensor(x, dims=dims, name=name)
697703
else:
698704
raise TypeError(
699705
"Variable with type {x.type} cannot be converted to XTensorVariable."
700706
)
701707
try:
702-
return xtensor_constant(x, name=name, dims=dims)
708+
return xtensor_constant(x, dims=dims, name=name)
703709
except TypeError as err:
704710
raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err
705711

0 commit comments

Comments
 (0)