Skip to content

Commit 56c2251

Browse files
committed
Implement C code for ExtractDiag and ARange
Set the view flag of ExtractDiag to True by default and respect it
1 parent af2bfec commit 56c2251

File tree

1 file changed

+89
-34
lines changed

1 file changed

+89
-34
lines changed

pytensor/tensor/basic.py

Lines changed: 89 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,13 +3207,14 @@ def tile(
32073207
return A_replicated.reshape(tiled_shape)
32083208

32093209

3210-
class ARange(Op):
3210+
class ARange(COp):
32113211
"""Create an array containing evenly spaced values within a given interval.
32123212
32133213
Parameters and behaviour are the same as numpy.arange().
32143214
32153215
"""
32163216

3217+
# TODO: ARange should work with scalars as inputs, not arrays
32173218
__props__ = ("dtype",)
32183219

32193220
def __init__(self, dtype):
@@ -3293,13 +3294,30 @@ def upcast(var):
32933294
)
32943295
]
32953296

3296-
def perform(self, node, inputs, output_storage):
    """Evaluate the range with NumPy and store it in the output cell.

    Parameters
    ----------
    node
        The node being evaluated (not used here).
    inputs
        Three objects ``(start, stop, step)``, each exposing ``.item()``.
    output_storage
        Output cells; the result is written to ``output_storage[0][0]``.
    """
    # ``.item()`` unwraps each input to a plain Python scalar before it is
    # handed to ``np.arange``.
    lower, upper, stride = inputs
    bounds = (lower.item(), upper.item(), stride.item())
    values = np.arange(*bounds, dtype=self.dtype)
    output_storage[0][0] = values
3302+
3303+
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that builds the output with ``PyArray_Arange``.

    Parameters
    ----------
    node, nodename
        Not used by this implementation.
    input_names
        C variable names of the ``start``, ``stop`` and ``step`` arrays.
    output_names
        C variable name of the single output array.
    sub
        Substitution dictionary; ``sub["fail"]`` is emitted when NumPy
        returns ``NULL``.

    Returns
    -------
    str
        The C code snippet.

    Notes
    -----
    The three inputs are read into C ``double`` variables, so integer
    inputs whose magnitude exceeds 2**53 cannot be represented exactly.
    The ``dtype_<name>`` typedefs are expected to be supplied by the
    surrounding generated code.
    """
    start_name, stop_name, step_name = input_names
    (out_name,) = output_names
    # NumPy type number of the requested output dtype, baked into the C call.
    typenum = np.dtype(self.dtype).num
    return f"""
    double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
    double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
    double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
    Py_XDECREF({out_name});
    {out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
    if (!{out_name}) {{
        {sub["fail"]}
    }}
    """
3318+
3319+
def c_code_cache_version(self):
    """Return the version tuple of this Op's C implementation.

    NOTE(review): presumably consumed by the C compilation cache, in which
    case it must be bumped whenever the generated C code changes.
    """
    version = (0,)
    return version
33033321

33043322
def connection_pattern(self, node):
    """Return the input/output connection pattern of this Op.

    The result holds one row per input, each with a single entry for the
    one output. With the input order ``(start, stop, step)`` used by
    ``perform``, the first and third inputs are flagged ``True`` and the
    second ``False``.
    """
    flags_per_input = (True, False, True)
    return [[flag] for flag in flags_per_input]
@@ -3686,7 +3704,7 @@ def inverse_permutation(perm):
36863704

36873705

36883706
# TODO: optimization to insert ExtractDiag with view=True
3689-
class ExtractDiag(Op):
3707+
class ExtractDiag(COp):
36903708
"""
36913709
Return specified diagonals.
36923710
@@ -3742,7 +3760,7 @@ class ExtractDiag(Op):
37423760

37433761
__props__ = ("offset", "axis1", "axis2", "view")
37443762

3745-
def __init__(self, offset=0, axis1=0, axis2=1, view=False):
3763+
def __init__(self, offset=0, axis1=0, axis2=1, view=True):
37463764
self.view = view
37473765
if self.view:
37483766
self.view_map = {0: [0]}
@@ -3765,24 +3783,74 @@ def make_node(self, x):
37653783
if x.ndim < 2:
37663784
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
37673785

3768-
out_shape = [
3769-
st_dim
3770-
for i, st_dim in enumerate(x.type.shape)
3771-
if i not in (self.axis1, self.axis2)
3772-
] + [None]
3786+
if (dim1 := x.type.shape[self.axis1]) is not None and (
3787+
dim2 := x.type.shape[self.axis2]
3788+
) is not None:
3789+
offset = self.offset
3790+
if offset > 0:
3791+
diag_size = int(np.clip(dim2 - offset, 0, dim1))
3792+
elif offset < 0:
3793+
diag_size = int(np.clip(dim1 + offset, 0, dim2))
3794+
else:
3795+
diag_size = int(np.minimum(dim1, dim2))
3796+
else:
3797+
diag_size = None
3798+
3799+
out_shape = (
3800+
*(
3801+
dim
3802+
for i, dim in enumerate(x.type.shape)
3803+
if i not in (self.axis1, self.axis2)
3804+
),
3805+
diag_size,
3806+
)
37733807

37743808
return Apply(
37753809
self,
37763810
[x],
3777-
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
3811+
[x.type.clone(dtype=x.dtype, shape=out_shape)()],
37783812
)
37793813

3780-
def perform(self, node, inputs, output_storage):
    """Extract the requested diagonal and store it in the output cell.

    When ``self.view`` is set, the diagonal returned by NumPy is made
    writeable and stored as is; if NumPy refuses to set the flag
    (``ValueError``), a copy is stored instead. When ``self.view`` is not
    set, a copy is always stored.

    Parameters
    ----------
    node
        The node being evaluated (not used here).
    inputs
        A single array-like exposing ``.diagonal``.
    output_storage
        Output cells; the result is written to ``output_storage[0][0]``.
    """
    (matrix,) = inputs
    diag = matrix.diagonal(self.offset, self.axis1, self.axis2)

    must_copy = not self.view
    if not must_copy:
        try:
            diag.flags.writeable = True
        except ValueError:
            # The flag cannot be enabled on this array, fall back to a copy
            must_copy = True

    if must_copy:
        diag = diag.copy()
    output_storage[0][0] = diag
3826+
3827+
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that extracts the diagonal with ``PyArray_Diagonal``.

    The generated code takes the diagonal of the input and then either
    enables the writeable flag on it (when ``self.view`` is set and the
    input is writeable) or replaces it with a copy.

    Parameters
    ----------
    node, nodename
        Not used by this implementation.
    input_names
        C variable name of the single input array.
    output_names
        C variable name of the single output array.
    sub
        Substitution dictionary; ``sub["fail"]`` is emitted on NumPy errors.

    Returns
    -------
    str
        The C code snippet.
    """
    (x_name,) = input_names
    (out_name,) = output_names
    # In the copy branch the output variable is overwritten with the copy
    # (or NULL) *before* the failure check. Otherwise the fail path would
    # run while the variable still points at the view that was just
    # released, and any cleanup that decrefs the output would act on a
    # dangling pointer.
    return f"""
    Py_XDECREF({out_name});

    {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
    if (!{out_name}) {{
        {sub["fail"]} // Error already set by Numpy
    }}

    if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
        // Make output writeable if input was writeable
        PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
    }} else {{
        // Make a copy
        PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
        Py_DECREF({out_name});
        {out_name} = {out_name}_copy;
        if (!{out_name}) {{
            {sub["fail"]} // Error already set by Numpy
        }}
    }}
    """

def c_code_cache_version(self):
    """Return the version tuple of this Op's C implementation.

    Bumped to ``(1,)`` because the generated C code changed (the output
    pointer is now reset before the copy-failure check).
    """
    return (1,)
37863854

37873855
def grad(self, inputs, gout):
37883856
# Avoid circular import
@@ -3829,19 +3897,6 @@ def infer_shape(self, fgraph, node, shapes):
38293897
out_shape.append(diag_size)
38303898
return [tuple(out_shape)]
38313899

3832-
def __setstate__(self, state):
3833-
self.__dict__.update(state)
3834-
3835-
if self.view:
3836-
self.view_map = {0: [0]}
3837-
3838-
if "offset" not in state:
3839-
self.offset = 0
3840-
if "axis1" not in state:
3841-
self.axis1 = 0
3842-
if "axis2" not in state:
3843-
self.axis2 = 1
3844-
38453900

38463901
def extract_diag(x):
38473902
warnings.warn(

0 commit comments

Comments
 (0)