Implement labeled RVs #1446

Draft · wants to merge 3 commits into base: labeled_tensors
4 changes: 2 additions & 2 deletions pytensor/tensor/random/basic.py
@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)


nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
nbinom = negative_binomial = NegBinomialRV()


class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):

multinomial = MultinomialRV()


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")


4 changes: 1 addition & 3 deletions pytensor/tensor/rewriting/basic.py
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
return [new_var]


@register_infer_shape
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
See the :ref:`unsafe` section.

"""
if not isinstance(node.op, Assert):
return

return [node.inputs[0]]


29 changes: 29 additions & 0 deletions pytensor/tensor/utils.py
@@ -9,6 +9,7 @@
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code


@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
https://github.com/numpy/numpy/issues/28921
"""
return product(*(range(s) for s in shape))


def get_static_shape_from_size_variables(
size_vars: Sequence[Variable],
) -> tuple[int | None, ...]:
"""Get static shape from size variables.

Parameters
----------
size_vars : Sequence[Variable]
A sequence of variables representing the size of each dimension.

Returns
-------
tuple[int | None, ...]
A tuple containing the static lengths of each dimension, or None if
the length is not statically known.
"""
from pytensor.tensor.basic import get_scalar_constant_value

static_lengths = [None] * len(size_vars)
for i, length in enumerate(size_vars):
try:
static_length = get_scalar_constant_value(length)
except NotScalarConstantError:
pass
else:
static_lengths[i] = int(static_length)
return tuple(static_lengths)
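
A quick sketch of the intended behavior of the new helper; the variable names below are made up for illustration:

```python
import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

n = pt.scalar("n", dtype="int64")  # symbolic length, unknown when the graph is built
sizes = [pt.as_tensor(3), n, pt.as_tensor(5)]

# Constant entries resolve to ints, symbolic ones fall back to None
print(get_static_shape_from_size_variables(sizes))  # (3, None, 5)
```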
3 changes: 3 additions & 0 deletions pytensor/tensor/variable.py
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
if ds_op.new_order == tuple(range(self.type.ndim)):
# No-op
return self
return ds_op(self)

def flatten(self, ndim=1):
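A minimal sketch of what the identity-dimshuffle shortcut above buys; the variable is assumed to be a plain matrix:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
assert x.dimshuffle(0, 1) is x       # identity pattern: no DimShuffle node is built
assert x.dimshuffle(1, 0) is not x   # a real transpose still goes through DimShuffle
```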
4 changes: 2 additions & 2 deletions pytensor/xtensor/basic.py
@@ -67,8 +67,8 @@ def make_node(self, x):
return Apply(self, [x], [output])


def xtensor_from_tensor(x, dims):
return XTensorFromTensor(dims=dims)(x)
def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)


class Rename(XTypeCastOp):
127 changes: 127 additions & 0 deletions pytensor/xtensor/random.py
@@ -0,0 +1,127 @@
from collections.abc import Sequence
from functools import wraps

import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.vectorization import XRV


def _as_xrv(
core_op: RandomVariable,
core_inps_dims_map: Sequence[Sequence[int]] | None = None,
core_out_dims_map: Sequence[int] | None = None,
):
if core_inps_dims_map is None:
# Assume core_dims map positionally from left to right
core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
if core_out_dims_map is None:
# Assume core_dims map positionally from left to right
core_out_dims_map = tuple(range(core_op.ndim_supp))

core_dims_needed = max(
(*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
)

@wraps(core_op)
def xrv_constructor(
*params,
core_dims: Sequence[str] | str | None = None,
extra_dims: dict[str, Variable] | None = None,
rng: Variable | None = None,
):
if core_dims is None:
core_dims = ()
if core_dims_needed:
raise ValueError(
f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
)
elif isinstance(core_dims, str):
core_dims = (core_dims,)

if len(core_dims) != core_dims_needed:
raise ValueError(
f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
)

full_input_core_dims = tuple(
tuple(core_dims[i] for i in inp_dims_map)
for inp_dims_map in core_inps_dims_map
)
full_output_core_dims = (tuple(core_dims[i] for i in core_out_dims_map),)
full_core_dims = (full_input_core_dims, full_output_core_dims)

if extra_dims is None:
extra_dims = {}

return XRV(core_op, core_dims=full_core_dims, extra_dims=extra_dims.keys())(
rng, *extra_dims.values(), *params
)

return xrv_constructor


bernoulli = _as_xrv(ptr.bernoulli)
beta = _as_xrv(ptr.beta)
betabinom = _as_xrv(ptr.betabinom)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical)
cauchy = _as_xrv(ptr.cauchy)
dirichlet = _as_xrv(ptr.dirichlet)
exponential = _as_xrv(ptr.exponential)
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma)
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel)
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal)
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma)
laplace = _as_xrv(ptr.laplace)
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial)
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson)
t = _as_xrv(ptr.t)
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform)
vonmises = _as_xrv(ptr.vonmises)
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)


def multivariate_normal(
mean,
cov,
*,
core_dims: Sequence[str],
extra_dims=None,
rng=None,
method="cholesky",
):
mean = as_xtensor(mean)
if len(core_dims) != 2:
raise ValueError(
f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
)

# Align core_dims so that the one present in mean comes first
if core_dims[0] not in mean.type.dims:
core_dims = core_dims[::-1]

xop = _as_xrv(ptr.MvNormalRV(method=method))
return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)


# Missing special cases
# standard_normal
# chisquare
# rayleigh
# choice
# permutation
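
A hedged usage sketch of the new constructors. The dim names, shapes, integer extra_dims lengths, default rng, and the exact `xtensor` helper signature are assumptions for illustration, not part of this diff:

```python
import pytensor.xtensor.random as pxr
from pytensor.xtensor.type import xtensor

mu = xtensor("mu", dims=("city",), shape=(3,))

# Scalar-parameter RV: no core_dims needed; extra_dims adds named batch dimensions
x = pxr.normal(mu, 1.0, extra_dims={"draw": 100})

# RVs with core dimensions name them via core_dims
alpha = xtensor("alpha", dims=("option",), shape=(4,))
w = pxr.dirichlet(alpha, core_dims="option")

# multivariate_normal takes exactly two core_dims; the one present in `mean`
# is treated as the leading core dimension
cov = xtensor("cov", dims=("city", "city2"), shape=(3, 3))
y = pxr.multivariate_normal(mu, cov, core_dims=("city", "city2"))
```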
46 changes: 45 additions & 1 deletion pytensor/xtensor/rewriting/vectorization.py
@@ -1,9 +1,10 @@
from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import compute_batch_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.vectorization import XBlockwise, XElemwise
from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise


@register_lower_xtensor
@@ -62,3 +63,46 @@ def lower_blockwise(fgraph, node):
for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
]
return new_outs


@register_lower_xtensor
@node_rewriter(tracks=[XRV])
def lower_rv(fgraph, node):
op: XRV = node.op
core_op = op.core_op

_, old_out = node.outputs
rng, *extra_dim_lengths_and_params = node.inputs
extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)]
params = extra_dim_lengths_and_params[len(op.extra_dims) :]

batch_ndim = old_out.type.ndim - len(op.core_dims[1][0])
param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

# Convert params from XTensors to plain Tensors, aligning batch dimensions and placing core dimensions at the end
tensor_params = []
for inp, core_dims in zip(params, op.core_dims[0]):
inp_dims = inp.type.dims
# Align the batch dims of the input, and place the core dims on the right
batch_order = [
inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
for batch_dim in param_batch_dims
]
core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
tensor_params.append(tensor_inp)

size = None
if op.extra_dims:
# RV size contains the lengths of all batch dimensions, including those coming from the parameters
param_batch_shape = compute_batch_shape(
tensor_params, ndims_params=core_op.ndims_params
)
size = [*extra_dim_lengths, *tuple(param_batch_shape)]

# RVs are their own core Op
new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs

# Convert output Tensors to XTensors
new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
return [new_next_rng, new_out]
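
A plain-Python illustration of how the dimshuffle pattern in `lower_rv` is assembled; the dim names are invented for the example:

```python
param_batch_dims = ("draw", "city")  # batch dims of the XRV output
inp_dims = ("city", "support")       # dims of one parameter: one batch, one core dim
core_dims = ("support",)             # that parameter's core dims

# Batch dims missing from the parameter are broadcast in with "x"
batch_order = [inp_dims.index(d) if d in inp_dims else "x" for d in param_batch_dims]
core_order = [inp_dims.index(d) for d in core_dims]
print(batch_order + core_order)  # ['x', 0, 1]
```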
15 changes: 5 additions & 10 deletions pytensor/xtensor/shape.py
@@ -5,8 +5,8 @@

from pytensor.graph import Apply
from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor import as_tensor
from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor

@@ -127,14 +127,9 @@ def make_node(self, x, *unstacked_length):
)
)

static_unstacked_lengths = [None] * len(unstacked_lengths)
for i, length in enumerate(unstacked_lengths):
try:
static_length = get_scalar_constant_value(length)
except NotScalarConstantError:
pass
else:
static_unstacked_lengths[i] = int(static_length)
static_unstacked_lengths = get_static_shape_from_size_variables(
unstacked_lengths
)

output = xtensor(
dtype=x.type.dtype,
24 changes: 15 additions & 9 deletions pytensor/xtensor/type.py
@@ -65,6 +65,8 @@ def __init__(
)
self.ndim = len(self.dims)
self.name = name
self.numpy_dtype = np.dtype(self.dtype)
self.filter_checks_isfinite = False

def clone(
self,
@@ -82,8 +84,9 @@ def clone(
return type(self)(dtype=dtype, shape=shape, dims=dims, **kwargs)

def filter(self, value, strict=False, allow_downcast=None):
# TODO implement this
return value
return TensorType.filter(
self, value, strict=strict, allow_downcast=allow_downcast
)

def convert_variable(self, var):
# TODO: Implement this
@@ -633,7 +636,7 @@ def signature(self):

def xtensor_constant(x, name=None, dims: None | Sequence[str] = None):
x_dims: tuple[str, ...]
if isinstance(x, xr.DataArray):
if XARRAY_AVAILABLE and isinstance(x, xr.DataArray):
xarray_dims = x.dims
if not all(isinstance(dim, str) for dim in xarray_dims):
raise NotImplementedError(
@@ -689,17 +692,20 @@ def as_xtensor(x, name=None, dims: Sequence[str] | None = None):
if isinstance(x.type, XTensorType):
return x
if isinstance(x.type, TensorType):
if x.type.ndim > 0 and dims is None:
raise TypeError(
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
)
return px.basic.xtensor_from_tensor(x, dims)
if dims is None:
if x.type.ndim == 0:
dims = ()
else:
raise TypeError(
"non-scalar TensorVariable cannot be converted to XTensorVariable without dims."
)
return px.basic.xtensor_from_tensor(x, dims=dims, name=name)
else:
raise TypeError(
"Variable with type {x.type} cannot be converted to XTensorVariable."
)
try:
return xtensor_constant(x, name=name, dims=dims)
return xtensor_constant(x, dims=dims, name=name)
except TypeError as err:
raise TypeError(f"Cannot convert {x} to XTensorType {type(x)}") from err
