
Commit 24e7c66

Implement XTensorVariable version of RandomVariables

1 parent 496b7b2 commit 24e7c66

10 files changed: +737 -54 lines

pytensor/tensor/random/basic.py
Lines changed: 2 additions & 2 deletions

@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
         return stats.nbinom.rvs(n, p, size=size, random_state=rng)
 
 
-nbinom = NegBinomialRV()
-negative_binomial = NegBinomialRV()
+nbinom = negative_binomial = NegBinomialRV()
 
 
 class BetaBinomialRV(ScipyRandomVariable):

@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):
 
 multinomial = MultinomialRV()
 
+
 vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")

pytensor/tensor/rewriting/basic.py
Lines changed: 1 addition & 3 deletions

@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
     return [new_var]
 
 
+@register_infer_shape
 @node_rewriter([Assert])
 def local_remove_all_assert(fgraph, node):
     r"""A rewrite that removes all `Assert`\s from a graph.

@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
     See the :ref:`unsafe` section.
 
     """
-    if not isinstance(node.op, Assert):
-        return
-
     return [node.inputs[0]]

pytensor/tensor/utils.py
Lines changed: 29 additions & 0 deletions

@@ -9,6 +9,7 @@
 import pytensor
 from pytensor.graph import FunctionGraph, Variable
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.utils import hash_from_code

@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
     https://github.com/numpy/numpy/issues/28921
     """
     return product(*(range(s) for s in shape))
+
+
+def get_static_shape_from_size_variables(
+    size_vars: Sequence[Variable],
+) -> tuple[int | None, ...]:
+    """Get static shape from size variables.
+
+    Parameters
+    ----------
+    size_vars : Sequence[Variable]
+        A sequence of variables representing the size of each dimension.
+    Returns
+    -------
+    tuple[int | None, ...]
+        A tuple containing the static lengths of each dimension, or None if
+        the length is not statically known.
+    """
+    from pytensor.tensor.basic import get_scalar_constant_value
+
+    static_lengths = [None] * len(size_vars)
+    for i, length in enumerate(size_vars):
+        try:
+            static_length = get_scalar_constant_value(length)
+        except NotScalarConstantError:
+            pass
+        else:
+            static_lengths[i] = int(static_length)
+    return tuple(static_lengths)
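
A minimal sketch of how the new helper behaves (hypothetical usage, not part of the commit): constant sizes resolve to concrete ints, symbolic sizes stay None.

# Hypothetical usage sketch; `n` is an invented symbolic size
import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

n = pt.scalar("n", dtype="int64")  # unknown until runtime
sizes = [pt.constant(3), n, pt.constant(5)]
print(get_static_shape_from_size_variables(sizes))  # (3, None, 5)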

pytensor/tensor/variable.py
Lines changed: 3 additions & 0 deletions

@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
         if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
             pattern = pattern[0]
         ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
+        if ds_op.new_order == tuple(range(self.type.ndim)):
+            # No-op
+            return self
         return ds_op(self)
 
     def flatten(self, ndim=1):
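
A quick sketch of what the short-circuit changes (assuming this commit is applied; variable names invented):

# An identity dimshuffle now returns the same variable instead of a new DimShuffle node
import pytensor.tensor as pt

x = pt.matrix("x")
assert x.dimshuffle(0, 1) is x      # no-op pattern: returns self
assert x.dimshuffle(1, 0) is not x  # a real transpose still builds a DimShuffle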

pytensor/xtensor/basic.py
Lines changed: 2 additions & 2 deletions

@@ -67,8 +67,8 @@ def make_node(self, x):
         return Apply(self, [x], [output])
 
 
-def xtensor_from_tensor(x, dims):
-    return XTensorFromTensor(dims=dims)(x)
+def xtensor_from_tensor(x, dims, name=None):
+    return XTensorFromTensor(dims=dims)(x, name=name)
 
 
 class Rename(XTypeCastOp):
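
A small sketch of the extended signature (values and name invented; assumes `name` is forwarded to the output variable as in other PyTensor constructors):

# Hypothetical call showing the new `name` argument
import pytensor.tensor as pt
from pytensor.xtensor.basic import xtensor_from_tensor

z = xtensor_from_tensor(pt.zeros((2, 3)), dims=("a", "b"), name="z")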

pytensor/xtensor/random.py (new file)
Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
from collections.abc import Sequence
from functools import wraps

import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.vectorization import XRV


def _as_xrv(
    core_op: RandomVariable,
    core_inps_dims_map: Sequence[Sequence[int]] | None = None,
    core_out_dims_map: Sequence[int] | None = None,
):
    if core_inps_dims_map is None:
        # Assume core_dims map positionally from left to right
        core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
    if core_out_dims_map is None:
        # Assume core_dims map positionally from left to right
        core_out_dims_map = tuple(range(core_op.ndim_supp))

    core_dims_needed = max(
        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
    )

    @wraps(core_op)
    def xrv_constructor(
        *params,
        core_dims: Sequence[str] | str | None = None,
        extra_dims: dict[str, Variable] | None = None,
        rng: Variable | None = None,
    ):
        if core_dims is None:
            core_dims = ()
            if core_dims_needed:
                raise ValueError(
                    f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
                )
        elif isinstance(core_dims, str):
            core_dims = (core_dims,)

        if len(core_dims) != core_dims_needed:
            raise ValueError(
                f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
            )

        full_input_core_dims = tuple(
            tuple(core_dims[i] for i in inp_dims_map)
            for inp_dims_map in core_inps_dims_map
        )
        full_output_core_dims = (tuple(core_dims[i] for i in core_out_dims_map),)
        full_core_dims = (full_input_core_dims, full_output_core_dims)

        if extra_dims is None:
            extra_dims = {}

        return XRV(core_op, core_dims=full_core_dims, extra_dims=extra_dims.keys())(
            rng, *extra_dims.values(), *params
        )

    return xrv_constructor


bernoulli = _as_xrv(ptr.bernoulli)
beta = _as_xrv(ptr.beta)
betabinom = _as_xrv(ptr.betabinom)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical)
cauchy = _as_xrv(ptr.cauchy)
dirichlet = _as_xrv(ptr.dirichlet)
exponential = _as_xrv(ptr.exponential)
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma)
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel)
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal)
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma)
laplace = _as_xrv(ptr.laplace)
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial)
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson)
t = _as_xrv(ptr.t)
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform)
vonmises = _as_xrv(ptr.vonmises)
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)


def multivariate_normal(
    mean,
    cov,
    *,
    core_dims: Sequence[str],
    extra_dims=None,
    rng=None,
    method="cholesky",
):
    mean = as_xtensor(mean)
    if len(core_dims) != 2:
        raise ValueError(
            f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
        )

    # Align core_dims, so the one that exists in mean comes first
    if core_dims[0] not in mean.type.dims:
        core_dims = core_dims[::-1]

    xop = _as_xrv(ptr.MvNormalRV(method=method))
    return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)


# Missing special cases
# standard_normal
# chisquare
# rayleigh
# choice
# permutation
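
A minimal usage sketch of the new constructors (dimension names and values invented; assumes `as_xtensor` accepts a `dims` argument and that plain sizes in `extra_dims` are converted to variables):

# Hypothetical usage of the new XRV constructors
import pytensor.xtensor as ptx
import pytensor.xtensor.random as pxr

mu = ptx.as_xtensor([0.0, 1.0], dims=("city",))
sigma = ptx.as_xtensor([1.0, 2.0], dims=("city",))

# Scalar-core RV: batch dims come from the parameters, plus any extra_dims
draws = pxr.normal(mu, sigma, extra_dims={"trial": 5})
# expected dims: ("trial", "city")

# Vector-core RV: core_dims names the support dimension
alpha = ptx.as_xtensor([1.0, 1.0, 1.0], dims=("option",))
simplex = pxr.dirichlet(alpha, core_dims="option")
# expected dims: ("option",)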

pytensor/xtensor/rewriting/vectorization.py
Lines changed: 45 additions & 1 deletion

@@ -1,9 +1,10 @@
 from pytensor.graph import node_rewriter
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.random.utils import compute_batch_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.utils import register_lower_xtensor
-from pytensor.xtensor.vectorization import XBlockwise, XElemwise
+from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
 
 
 @register_lower_xtensor

@@ -62,3 +63,46 @@ def lower_blockwise(fgraph, node):
         for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
     ]
     return new_outs
+
+
+@register_lower_xtensor
+@node_rewriter(tracks=[XRV])
+def lower_rv(fgraph, node):
+    op: XRV = node.op
+    core_op = op.core_op
+
+    _, old_out = node.outputs
+    rng, *extra_dim_lengths_and_params = node.inputs
+    extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)]
+    params = extra_dim_lengths_and_params[len(op.extra_dims) :]
+
+    batch_ndim = old_out.type.ndim - len(op.core_dims[1][0])
+    param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]
+
+    # Convert params from XTensors to Tensors, align batch dimensions and place core dimensions at the end
+    tensor_params = []
+    for inp, core_dims in zip(params, op.core_dims[0]):
+        inp_dims = inp.type.dims
+        # Align the batch dims of the input, and place the core dims on the right
+        batch_order = [
+            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
+            for batch_dim in param_batch_dims
+        ]
+        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
+        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
+        tensor_params.append(tensor_inp)
+
+    size = None
+    if op.extra_dims:
+        # RV size contains the lengths of all batch dimensions, including those coming from the parameters
+        param_batch_shape = compute_batch_shape(
+            tensor_params, ndims_params=core_op.ndims_params
+        )
+        size = [*extra_dim_lengths, *tuple(param_batch_shape)]
+
+    # RVs are their own core Op
+    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+
+    # Convert output Tensor to an XTensor with the original dims
+    new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
+    return [new_next_rng, new_out]
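
To make the alignment step concrete, a standalone pure-Python sketch with invented dims (it mirrors the `batch_order`/`core_order` computation above):

# Standalone sketch of the batch/core alignment (example dims invented)
param_batch_dims = ("b", "a")   # batch dims of the XRV output, in order
inp_dims = ("a", "core")        # dims of one parameter
core_dims = ("core",)           # that parameter's core dims

batch_order = [
    inp_dims.index(d) if d in inp_dims else "x"  # "x" inserts a broadcastable dim
    for d in param_batch_dims
]
core_order = [inp_dims.index(d) for d in core_dims]
print(batch_order + core_order)  # ['x', 0, 1]: dimshuffle to ("b", "a", "core")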

pytensor/xtensor/shape.py
Lines changed: 5 additions & 10 deletions

@@ -5,8 +5,8 @@
 
 from pytensor.graph import Apply
 from pytensor.scalar import discrete_dtypes, upcast
-from pytensor.tensor import as_tensor, get_scalar_constant_value
-from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor import as_tensor
+from pytensor.tensor.utils import get_static_shape_from_size_variables
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor

@@ -127,14 +127,9 @@ def make_node(self, x, *unstacked_length):
                 )
             )
 
-        static_unstacked_lengths = [None] * len(unstacked_lengths)
-        for i, length in enumerate(unstacked_lengths):
-            try:
-                static_length = get_scalar_constant_value(length)
-            except NotScalarConstantError:
-                pass
-            else:
-                static_unstacked_lengths[i] = int(static_length)
+        static_unstacked_lengths = get_static_shape_from_size_variables(
+            unstacked_lengths
+        )
 
         output = xtensor(
             dtype=x.type.dtype,
