@@ -65,6 +65,8 @@ def __init__(
65
65
)
66
66
self .ndim = len (self .dims )
67
67
self .name = name
68
+ self .numpy_dtype = np .dtype (self .dtype )
69
+ self .filter_checks_isfinite = False
68
70
69
71
def clone (
70
72
self ,
@@ -82,8 +84,9 @@ def clone(
82
84
return type (self )(dtype = dtype , shape = shape , dims = dims , ** kwargs )
83
85
84
86
def filter (self , value , strict = False , allow_downcast = None ):
85
- # TODO implement this
86
- return value
87
+ return TensorType .filter (
88
+ self , value , strict = strict , allow_downcast = allow_downcast
89
+ )
87
90
88
91
def convert_variable (self , var ):
89
92
# TODO: Implement this
@@ -689,17 +692,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
689
692
if isinstance (x .type , XTensorType ):
690
693
return x
691
694
if isinstance (x .type , TensorType ):
692
- if x .type .ndim > 0 and dims is None :
693
- raise TypeError (
694
- "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
695
- )
696
- return px .basic .xtensor_from_tensor (x , dims )
695
+ if dims is None :
696
+ if x .type .ndim == 0 :
697
+ dims = ()
698
+ else :
699
+ raise TypeError (
700
+ "non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
701
+ )
702
+ return px .basic .xtensor_from_tensor (x , dims = dims , name = name )
697
703
else :
698
704
raise TypeError (
699
705
"Variable with type {x.type} cannot be converted to XTensorVariable."
700
706
)
701
707
try :
702
- return xtensor_constant (x , name = name , dims = dims )
708
+ return xtensor_constant (x , dims = dims , name = name )
703
709
except TypeError as err :
704
710
raise TypeError (f"Cannot convert { x } to XTensorType { type (x )} " ) from err
705
711
0 commit comments