Implement several RandomVariables as SymbolicRandomVariables #7239

Merged 3 commits on Apr 14, 2024
47 changes: 27 additions & 20 deletions pymc/distributions/censored.py
@@ -16,13 +16,19 @@

from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param

from pymc.distributions.distribution import (
Distribution,
SymbolicRandomVariable,
_support_point,
)
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
from pymc.distributions.shape_utils import (
_change_dist_size,
change_dist_size,
implicit_size_from_params,
rv_size_is_none,
)
from pymc.util import check_dist_not_registered


@@ -31,9 +37,27 @@ class CensoredRV(SymbolicRandomVariable):

inline_logprob = True
signature = "(),(),()->()"
ndim_supp = 0
_print_name = ("Censored", "\\operatorname{Censored}")

@classmethod
def rv_op(cls, dist, lower, upper, *, size=None):
# We don't allow passing `rng` because we don't fully control the rng of the components!
lower = pt.constant(-np.inf) if lower is None else pt.as_tensor(lower)
upper = pt.constant(np.inf) if upper is None else pt.as_tensor(upper)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(dist, lower, upper, ndims_params=cls.ndims_params)

# Censoring is achieved by clipping the base distribution between lower and upper
dist = change_dist_size(dist, size)
censored_rv = pt.clip(dist, lower, upper)

return CensoredRV(
inputs=[dist, lower, upper],
outputs=[censored_rv],
)(dist, lower, upper)


class Censored(Distribution):
r"""
@@ -85,6 +109,7 @@ class Censored(Distribution):
"""

rv_type = CensoredRV
rv_op = CensoredRV.rv_op

@classmethod
def dist(cls, dist, lower, upper, **kwargs):
@@ -101,24 +126,6 @@ def dist(cls, dist, lower, upper, **kwargs):
check_dist_not_registered(dist)
return super().dist([dist, lower, upper], **kwargs)

@classmethod
def rv_op(cls, dist, lower=None, upper=None, size=None):
lower = pt.constant(-np.inf) if lower is None else pt.as_tensor_variable(lower)
upper = pt.constant(np.inf) if upper is None else pt.as_tensor_variable(upper)

# When size is not specified, dist may have to be broadcasted according to lower/upper
dist_shape = size if size is not None else pt.broadcast_shape(dist, lower, upper)
dist = change_dist_size(dist, dist_shape)

# Censoring is achieved by clipping the base distribution between lower and upper
dist_, lower_, upper_ = dist.type(), lower.type(), upper.type()
censored_rv_ = pt.clip(dist_, lower_, upper_)

return CensoredRV(
inputs=[dist_, lower_, upper_],
outputs=[censored_rv_],
)(dist, lower, upper)


@_change_dist_size.register(CensoredRV)
def change_censored_size(cls, dist, new_size, expand=False):
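As a quick orientation for the refactor above, a minimal usage sketch (parameter values are illustrative): because `CensoredRV.rv_op` builds the draw as `pt.clip` of the base distribution, draws obtained through the public `pm.Censored` API respect the bounds by construction.

```python
import pymc as pm

# Censor a standard Normal below at 0; no upper bound.
base = pm.Normal.dist(mu=0.0, sigma=1.0)
censored = pm.Censored.dist(base, lower=0.0, upper=None)

draws = pm.draw(censored, draws=1_000, random_seed=1)
assert (draws >= 0.0).all()  # clipping collapses sub-zero draws onto the bound
```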
158 changes: 90 additions & 68 deletions pymc/distributions/continuous.py
@@ -52,10 +52,12 @@
vonmises,
)
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorConstant

from pymc.logprob.abstract import _logprob_helper
from pymc.logprob.basic import icdf
from pymc.pytensorf import normalize_rng_param

try:
from polyagamma import polyagamma_cdf, polyagamma_pdf, random_polyagamma
@@ -73,7 +75,6 @@ def polyagamma_cdf(*args, **kwargs):

from scipy import stats
from scipy.interpolate import InterpolatedUnivariateSpline
from scipy.special import expit

from pymc.distributions import transforms
from pymc.distributions.dist_math import (
@@ -90,8 +91,8 @@ def polyagamma_cdf(*args, **kwargs):
normal_lcdf,
zvalue,
)
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous, SymbolicRandomVariable
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.distributions.transforms import _default_transform
from pymc.math import invlogit, logdiffexp, logit

@@ -1236,20 +1237,28 @@ def icdf(value, alpha, beta):
)


class KumaraswamyRV(RandomVariable):
class KumaraswamyRV(SymbolicRandomVariable):
name = "kumaraswamy"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
signature = "[rng],[size],(),()->[rng],()"
_print_name = ("Kumaraswamy", "\\operatorname{Kumaraswamy}")

@classmethod
def rng_fn(cls, rng, a, b, size) -> np.ndarray:
u = rng.uniform(size=size)
return np.asarray((1 - (1 - u) ** (1 / b)) ** (1 / a))
def rv_op(cls, a, b, *, size=None, rng=None):
a = pt.as_tensor(a)
b = pt.as_tensor(b)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(a, b, ndims_params=cls.ndims_params)

kumaraswamy = KumaraswamyRV()
next_rng, u = uniform(size=size, rng=rng).owner.outputs
draws = (1 - (1 - u) ** (1 / b)) ** (1 / a)

return cls(
inputs=[rng, size, a, b],
outputs=[next_rng, draws],
)(rng, size, a, b)


class Kumaraswamy(UnitContinuous):
@@ -1296,13 +1305,11 @@ class Kumaraswamy(UnitContinuous):
b > 0.
"""

rv_op = kumaraswamy
rv_type = KumaraswamyRV
rv_op = KumaraswamyRV.rv_op

@classmethod
def dist(cls, a: DIST_PARAMETER_TYPES, b: DIST_PARAMETER_TYPES, *args, **kwargs):
a = pt.as_tensor_variable(a)
b = pt.as_tensor_variable(b)

return super().dist([a, b], *args, **kwargs)

def support_point(rv, size, a, b):
@@ -1533,24 +1540,32 @@ def icdf(value, mu, b):
return check_icdf_parameters(res, b > 0, msg="b > 0")


class AsymmetricLaplaceRV(RandomVariable):
class AsymmetricLaplaceRV(SymbolicRandomVariable):
name = "asymmetriclaplace"
ndim_supp = 0
ndims_params = [0, 0, 0]
dtype = "floatX"
signature = "[rng],[size],(),(),()->[rng],()"
_print_name = ("AsymmetricLaplace", "\\operatorname{AsymmetricLaplace}")

@classmethod
def rng_fn(cls, rng, b, kappa, mu, size=None) -> np.ndarray:
u = rng.uniform(size=size)
def rv_op(cls, b, kappa, mu, *, size=None, rng=None):
b = pt.as_tensor(b)
kappa = pt.as_tensor(kappa)
mu = pt.as_tensor(mu)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(b, kappa, mu, ndims_params=cls.ndims_params)

next_rng, u = uniform(size=size, rng=rng).owner.outputs
switch = kappa**2 / (1 + kappa**2)
non_positive_x = mu + kappa * np.log(u * (1 / switch)) / b
positive_x = mu - np.log((1 - u) * (1 + kappa**2)) / (kappa * b)
non_positive_x = mu + kappa * pt.log(u * (1 / switch)) / b
positive_x = mu - pt.log((1 - u) * (1 + kappa**2)) / (kappa * b)
draws = non_positive_x * (u <= switch) + positive_x * (u > switch)
return np.asarray(draws)


asymmetriclaplace = AsymmetricLaplaceRV()
return cls(
inputs=[rng, size, b, kappa, mu],
outputs=[next_rng, draws],
)(rng, size, b, kappa, mu)


class AsymmetricLaplace(Continuous):
@@ -1599,15 +1614,12 @@ class AsymmetricLaplace(Continuous):
of interest.
"""

rv_op = asymmetriclaplace
rv_type = AsymmetricLaplaceRV
rv_op = AsymmetricLaplaceRV.rv_op

@classmethod
def dist(cls, kappa=None, mu=None, b=None, q=None, *args, **kwargs):
kappa = cls.get_kappa(kappa, q)
b = pt.as_tensor_variable(b)
kappa = pt.as_tensor_variable(kappa)
mu = pt.as_tensor_variable(mu)

return super().dist([b, kappa, mu], *args, **kwargs)

@classmethod
@@ -2475,7 +2487,6 @@ def dist(cls, nu, **kwargs):
return Gamma.dist(alpha=nu / 2, beta=1 / 2, **kwargs)


# TODO: Remove this once logp for multiplication is working!
class WeibullBetaRV(RandomVariable):
name = "weibull"
ndim_supp = 0
@@ -2597,19 +2608,22 @@ def icdf(value, alpha, beta):
)


class HalfStudentTRV(RandomVariable):
class HalfStudentTRV(SymbolicRandomVariable):
name = "halfstudentt"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
signature = "[rng],[size],(),()->[rng],()"
_print_name = ("HalfStudentT", "\\operatorname{HalfStudentT}")

@classmethod
def rng_fn(cls, rng, nu, sigma, size=None) -> np.ndarray:
return np.asarray(np.abs(stats.t.rvs(nu, scale=sigma, size=size, random_state=rng)))
def rv_op(cls, nu, sigma, *, size=None, rng=None):
nu = pt.as_tensor(nu)
sigma = pt.as_tensor(sigma)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

next_rng, t_draws = t(df=nu, scale=sigma, size=size, rng=rng).owner.outputs
draws = pt.abs(t_draws)

halfstudentt = HalfStudentTRV()
return cls(inputs=[rng, size, nu, sigma], outputs=[next_rng, draws])(rng, size, nu, sigma)


class HalfStudentT(PositiveContinuous):
@@ -2671,14 +2685,12 @@ class HalfStudentT(PositiveContinuous):
x = pm.HalfStudentT('x', lam=4, nu=10)
"""

rv_op = halfstudentt
rv_type = HalfStudentTRV
rv_op = HalfStudentTRV.rv_op

@classmethod
def dist(cls, nu, sigma=None, lam=None, *args, **kwargs):
nu = pt.as_tensor_variable(nu)
lam, sigma = get_tau_sigma(lam, sigma)
sigma = pt.as_tensor_variable(sigma)

return super().dist([nu, sigma], *args, **kwargs)

def support_point(rv, size, nu, sigma):
@@ -2710,19 +2722,29 @@ def logp(value, nu, sigma):
)


class ExGaussianRV(RandomVariable):
class ExGaussianRV(SymbolicRandomVariable):
name = "exgaussian"
ndim_supp = 0
ndims_params = [0, 0, 0]
dtype = "floatX"
signature = "[rng],[size],(),(),()->[rng],()"
_print_name = ("ExGaussian", "\\operatorname{ExGaussian}")

@classmethod
def rng_fn(cls, rng, mu, sigma, nu, size=None) -> np.ndarray:
return np.asarray(rng.normal(mu, sigma, size=size) + rng.exponential(scale=nu, size=size))
def rv_op(cls, mu, sigma, nu, *, size=None, rng=None):
mu = pt.as_tensor(mu)
sigma = pt.as_tensor(sigma)
nu = pt.as_tensor(nu)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(mu, sigma, nu, ndims_params=cls.ndims_params)

exgaussian = ExGaussianRV()
next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
final_rng, exponential_draws = exponential(scale=nu, size=size, rng=next_rng).owner.outputs
draws = normal_draws + exponential_draws

return cls(inputs=[rng, size, mu, sigma, nu], outputs=[final_rng, draws])(
rng, size, mu, sigma, nu
)


class ExGaussian(Continuous):
@@ -2792,14 +2814,11 @@ class ExGaussian(Continuous):
Vol. 4, No. 1, pp 35-45.
"""

rv_op = exgaussian
rv_type = ExGaussianRV
rv_op = ExGaussianRV.rv_op

@classmethod
def dist(cls, mu=0.0, sigma=None, nu=None, *args, **kwargs):
mu = pt.as_tensor_variable(mu)
sigma = pt.as_tensor_variable(sigma)
nu = pt.as_tensor_variable(nu)

return super().dist([mu, sigma, nu], *args, **kwargs)

def support_point(rv, size, mu, sigma, nu):
@@ -3477,19 +3496,25 @@ def icdf(value, mu, s):
)


class LogitNormalRV(RandomVariable):
class LogitNormalRV(SymbolicRandomVariable):
name = "logit_normal"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
signature = "[rng],[size],(),()->[rng],()"
_print_name = ("logitNormal", "\\operatorname{logitNormal}")

@classmethod
def rng_fn(cls, rng, mu, sigma, size=None) -> np.ndarray:
return np.asarray(expit(stats.norm.rvs(loc=mu, scale=sigma, size=size, random_state=rng)))
def rv_op(cls, mu, sigma, *, size=None, rng=None):
mu = pt.as_tensor(mu)
sigma = pt.as_tensor(sigma)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

next_rng, normal_draws = normal(loc=mu, scale=sigma, size=size, rng=rng).owner.outputs
draws = pt.expit(normal_draws)

logit_normal = LogitNormalRV()
return cls(
inputs=[rng, size, mu, sigma],
outputs=[next_rng, draws],
)(rng, size, mu, sigma)


class LogitNormal(UnitContinuous):
@@ -3540,15 +3565,12 @@ class LogitNormal(UnitContinuous):
Defaults to 1.
"""

rv_op = logit_normal
rv_type = LogitNormalRV
rv_op = LogitNormalRV.rv_op

@classmethod
def dist(cls, mu=0, sigma=None, tau=None, **kwargs):
mu = pt.as_tensor_variable(mu)
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
sigma = pt.as_tensor_variable(sigma)
tau = pt.as_tensor_variable(tau)

_, sigma = get_tau_sigma(tau=tau, sigma=sigma)
return super().dist([mu, sigma], **kwargs)

def support_point(rv, size, mu, sigma):
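The continuous RVs rewritten above now express sampling as transformations of other PyTensor RVs instead of NumPy `rng_fn`s. A pure-NumPy sanity sketch of the inverse-CDF transform encoded by `KumaraswamyRV.rv_op` (seed and parameter values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(123)
a, b = 2.0, 3.0

# u ~ Uniform(0, 1); x = (1 - (1 - u) ** (1 / b)) ** (1 / a) is Kumaraswamy(a, b).
u = rng.uniform(size=10_000)
x = (1 - (1 - u) ** (1 / b)) ** (1 / a)

assert np.all((x > 0) & (x < 1))  # Kumaraswamy support is the unit interval
```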
36 changes: 22 additions & 14 deletions pymc/distributions/discrete.py
@@ -18,7 +18,6 @@

from pytensor.tensor import TensorConstant
from pytensor.tensor.random.basic import (
RandomVariable,
ScipyRandomVariable,
bernoulli,
betabinom,
@@ -28,7 +27,9 @@
hypergeometric,
nbinom,
poisson,
uniform,
)
from pytensor.tensor.random.utils import normalize_size_param
from scipy import stats

import pymc as pm
@@ -45,8 +46,8 @@
normal_lccdf,
normal_lcdf,
)
from pymc.distributions.distribution import Discrete
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.distribution import Discrete, SymbolicRandomVariable
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.logprob.basic import logcdf, logp
from pymc.math import sigmoid

@@ -65,6 +66,8 @@
"OrderedProbit",
]

from pymc.pytensorf import normalize_rng_param


class Binomial(Discrete):
R"""
@@ -387,20 +390,26 @@ def logcdf(value, p):
)


class DiscreteWeibullRV(RandomVariable):
class DiscreteWeibullRV(SymbolicRandomVariable):
name = "discrete_weibull"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "int64"
signature = "[rng],[size],(),()->[rng],()"
_print_name = ("dWeibull", "\\operatorname{dWeibull}")

@classmethod
def rng_fn(cls, rng, q, beta, size):
p = rng.uniform(size=size)
return np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1
def rv_op(cls, q, beta, *, size=None, rng=None):
q = pt.as_tensor(q)
beta = pt.as_tensor(beta)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(q, beta, ndims_params=cls.ndims_params)

next_rng, p = uniform(size=size, rng=rng).owner.outputs
draws = pt.ceil(pt.power(pt.log(1 - p) / pt.log(q), 1.0 / beta)) - 1
draws = draws.astype("int64")

discrete_weibull = DiscreteWeibullRV()
return cls(inputs=[rng, size, q, beta], outputs=[next_rng, draws])(rng, size, q, beta)


class DiscreteWeibull(Discrete):
@@ -452,12 +461,11 @@ def DiscreteWeibull(q, b, x):
"""

rv_op = discrete_weibull
rv_type = DiscreteWeibullRV
rv_op = DiscreteWeibullRV.rv_op

@classmethod
def dist(cls, q, beta, *args, **kwargs):
q = pt.as_tensor_variable(q)
beta = pt.as_tensor_variable(beta)
return super().dist([q, beta], **kwargs)

def support_point(rv, size, q, beta):
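For reference, the same inverse-CDF transform used by `DiscreteWeibullRV.rv_op`, written in plain NumPy (values are illustrative; the Op additionally casts the result to int64):

```python
import numpy as np

rng = np.random.default_rng(42)
q, beta = 0.9, 1.5

# p ~ Uniform(0, 1); x = ceil((log(1 - p) / log(q)) ** (1 / beta)) - 1
p = rng.uniform(size=10_000)
x = np.ceil(np.power(np.log(1 - p) / np.log(q), 1.0 / beta)) - 1

assert np.all(x >= 0)  # non-negative, integer-valued draws (stored as floats here)
```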
265 changes: 210 additions & 55 deletions pymc/distributions/distribution.py

Large diffs are not rendered by default.

207 changes: 99 additions & 108 deletions pymc/distributions/mixture.py
@@ -21,6 +21,7 @@
from pytensor.graph.basic import Node, equal_computations
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param

from pymc.distributions import transforms
from pymc.distributions.continuous import Gamma, LogNormal, Normal, get_tau_sigma
@@ -33,7 +34,7 @@
_support_point,
support_point,
)
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, rv_size_is_none
from pymc.distributions.transforms import _default_transform
from pymc.distributions.truncated import Truncated
from pymc.logprob.abstract import _logcdf, _logcdf_helper, _logprob
@@ -58,9 +59,103 @@
class MarginalMixtureRV(SymbolicRandomVariable):
"""A placeholder used to specify a log-likelihood for a mixture sub-graph."""

default_output = 1
_print_name = ("MarginalMixture", "\\operatorname{MarginalMixture}")

@classmethod
def rv_op(cls, weights, *components, size=None):
# We don't allow passing `rng` because we don't fully control the rng of the components!
mix_indexes_rng = pytensor.shared(np.random.default_rng())

single_component = len(components) == 1
ndim_supp = components[0].owner.op.ndim_supp

size = normalize_size_param(size)
if not rv_size_is_none(size):
components = cls._resize_components(size, *components)
elif not single_component:
# We might need to broadcast components when size is not specified
shape = tuple(pt.broadcast_shape(*components))
size = shape[: len(shape) - ndim_supp]
components = cls._resize_components(size, *components)

# Extract replication ndims from components and weights
ndim_batch = components[0].ndim - ndim_supp
if single_component:
# One dimension is taken by the mixture axis in the single component case
ndim_batch -= 1

# The weights may imply extra batch dimensions that go beyond what is already
# implied by the component dimensions (ndim_batch)
weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1)

# If weights are large enough that they would broadcast the component distributions
# we try to resize them. This is necessary to avoid duplicated values in the
# random method and for equivalency with the logp method
if weights_ndim_batch:
new_size = pt.concatenate(
[
weights.shape[:weights_ndim_batch],
components[0].shape[:ndim_batch],
]
)
components = cls._resize_components(new_size, *components)

# Extract support and batch ndims from components and weights
ndim_batch = components[0].ndim - ndim_supp
if single_component:
ndim_batch -= 1
weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1)

assert weights_ndim_batch == 0

mix_axis = -ndim_supp - 1

# Stack components across mixture axis
if single_component:
# If single component, we consider it as being already "stacked"
stacked_components = components[0]
else:
stacked_components = pt.stack(components, axis=mix_axis)

# Broadcast weights to (*batched dimensions, stack dimension), ignoring support dimensions
weights_broadcast_shape = stacked_components.shape[: ndim_batch + 1]
weights_broadcasted = pt.broadcast_to(weights, weights_broadcast_shape)

# Draw mixture indexes and append (stack + ndim_supp) broadcastable dimensions to the right
mix_indexes_rng_next, mix_indexes = pt.random.categorical(
weights_broadcasted, rng=mix_indexes_rng
).owner.outputs
mix_indexes_padded = pt.shape_padright(mix_indexes, ndim_supp + 1)

# Index components and squeeze mixture dimension
mix_out = pt.take_along_axis(stacked_components, mix_indexes_padded, axis=mix_axis)
mix_out = pt.squeeze(mix_out, axis=mix_axis)

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
comp_s = ",".join((*s, "w"))
signature = f"[rng],(w),({comp_s})->[rng],({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
signature = f"[rng],(w),{comps_s}->[rng],({s})"

return MarginalMixtureRV(
inputs=[mix_indexes_rng, weights, *components],
outputs=[mix_indexes_rng_next, mix_out],
signature=signature,
)(mix_indexes_rng, weights, *components)

@classmethod
def _resize_components(cls, size, *components):
if len(components) == 1:
# If we have a single component, we need to keep the length of the mixture
# axis intact, because that's what determines the number of mixture components
mix_axis = -components[0].owner.op.ndim_supp - 1
mix_size = components[0].shape[mix_axis]
size = (*size, mix_size)

return [change_dist_size(component, size) for component in components]

def update(self, node: Node):
# Update for the internal mix_indexes RV
return {node.inputs[0]: node.outputs[0]}
@@ -176,6 +271,7 @@ class Mixture(Distribution):
"""

rv_type = MarginalMixtureRV
rv_op = MarginalMixtureRV.rv_op

@classmethod
def dist(cls, w, comp_dists, **kwargs):
@@ -221,115 +317,10 @@ def dist(cls, w, comp_dists, **kwargs):
w = pt.as_tensor_variable(w)
return super().dist([w, *comp_dists], **kwargs)

@classmethod
def rv_op(cls, weights, *components, size=None):
# Create new rng for the mix_indexes internal RV
mix_indexes_rng = pytensor.shared(np.random.default_rng())

single_component = len(components) == 1
ndim_supp = components[0].owner.op.ndim_supp

if size is not None:
components = cls._resize_components(size, *components)
elif not single_component:
# We might need to broadcast components when size is not specified
shape = tuple(pt.broadcast_shape(*components))
size = shape[: len(shape) - ndim_supp]
components = cls._resize_components(size, *components)

# Extract replication ndims from components and weights
ndim_batch = components[0].ndim - ndim_supp
if single_component:
# One dimension is taken by the mixture axis in the single component case
ndim_batch -= 1

# The weights may imply extra batch dimensions that go beyond what is already
# implied by the component dimensions (ndim_batch)
weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1)

# If weights are large enough that they would broadcast the component distributions
# we try to resize them. This is necessary to avoid duplicated values in the
# random method and for equivalency with the logp method
if weights_ndim_batch:
new_size = pt.concatenate(
[
weights.shape[:weights_ndim_batch],
components[0].shape[:ndim_batch],
]
)
components = cls._resize_components(new_size, *components)

# Extract support and batch ndims from components and weights
ndim_batch = components[0].ndim - ndim_supp
if single_component:
ndim_batch -= 1
weights_ndim_batch = max(0, weights.ndim - ndim_batch - 1)

assert weights_ndim_batch == 0

# Create a OpFromGraph that encapsulates the random generating process
# Create dummy input variables with the same type as the ones provided
weights_ = weights.type()
components_ = [component.type() for component in components]
mix_indexes_rng_ = mix_indexes_rng.type()

mix_axis = -ndim_supp - 1

# Stack components across mixture axis
if single_component:
# If single component, we consider it as being already "stacked"
stacked_components_ = components_[0]
else:
stacked_components_ = pt.stack(components_, axis=mix_axis)

# Broadcast weights to (*batched dimensions, stack dimension), ignoring support dimensions
weights_broadcast_shape_ = stacked_components_.shape[: ndim_batch + 1]
weights_broadcasted_ = pt.broadcast_to(weights_, weights_broadcast_shape_)

# Draw mixture indexes and append (stack + ndim_supp) broadcastable dimensions to the right
mix_indexes_ = pt.random.categorical(weights_broadcasted_, rng=mix_indexes_rng_)
mix_indexes_padded_ = pt.shape_padright(mix_indexes_, ndim_supp + 1)

# Index components and squeeze mixture dimension
mix_out_ = pt.take_along_axis(stacked_components_, mix_indexes_padded_, axis=mix_axis)
mix_out_ = pt.squeeze(mix_out_, axis=mix_axis)

# Output mix_indexes rng update so that it can be updated in place
mix_indexes_rng_next_ = mix_indexes_.owner.outputs[0]

s = ",".join(f"s{i}" for i in range(components[0].owner.op.ndim_supp))
if len(components) == 1:
comp_s = ",".join((*s, "w"))
signature = f"(),(w),({comp_s})->({s})"
else:
comps_s = ",".join(f"({s})" for _ in components)
signature = f"(),(w),{comps_s}->({s})"
mix_op = MarginalMixtureRV(
inputs=[mix_indexes_rng_, weights_, *components_],
outputs=[mix_indexes_rng_next_, mix_out_],
signature=signature,
)

# Create the actual MarginalMixture variable
mix_out = mix_op(mix_indexes_rng, weights, *components)

return mix_out

@classmethod
def _resize_components(cls, size, *components):
if len(components) == 1:
# If we have a single component, we need to keep the length of the mixture
# axis intact, because that's what determines the number of mixture components
mix_axis = -components[0].owner.op.ndim_supp - 1
mix_size = components[0].shape[mix_axis]
size = (*size, mix_size)

return [change_dist_size(component, size) for component in components]


@_change_dist_size.register(MarginalMixtureRV)
def change_marginal_mixture_size(op, dist, new_size, expand=False):
weights, *components = dist.owner.inputs[1:]
rng, weights, *components = dist.owner.inputs

if expand:
component = components[0]
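A NumPy sketch of the categorical-index-and-gather scheme that `MarginalMixtureRV.rv_op` now encodes symbolically; for two scalar components the constructed signature is `"[rng],(w),(),()->[rng],()"`. Weights and component parameters below are illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = np.array([0.3, 0.7])
size = 5_000

# Stack component draws along a trailing mixture axis, then gather one per draw.
components = np.stack(
    [rng.normal(-5.0, 1.0, size=size), rng.normal(5.0, 1.0, size=size)],
    axis=-1,
)
mix_indexes = rng.choice(len(weights), p=weights, size=size)
mix_draws = np.take_along_axis(components, mix_indexes[:, None], axis=-1).squeeze(-1)

assert abs((mix_draws > 0).mean() - weights[1]) < 0.05  # ~70% from the second component
```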
215 changes: 88 additions & 127 deletions pymc/distributions/multivariate.py
@@ -35,6 +35,7 @@
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import (
broadcast_params,
normalize_size_param,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.type import TensorType
@@ -64,13 +65,14 @@
broadcast_dist_samples_shape,
change_dist_size,
get_support_shape,
implicit_size_from_params,
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
from pymc.logprob.abstract import _logprob
from pymc.math import kron_diag, kron_dot
from pymc.pytensorf import intX
from pymc.pytensorf import intX, normalize_rng_param
from pymc.util import check_dist_not_registered

__all__ = [
@@ -592,48 +594,28 @@ def logp(value, n, p):
)


class DirichletMultinomialRV(RandomVariable):
class DirichletMultinomialRV(SymbolicRandomVariable):
name = "dirichlet_multinomial"
ndim_supp = 1
ndims_params = [0, 1]
dtype = "int64"
_print_name = ("DirichletMN", "\\operatorname{DirichletMN}")

def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=1,
)
signature = "[rng],[size],(),(p)->[rng],(p)"
_print_name = ("DirichletMultinomial", "\\operatorname{DirichletMultinomial}")

@classmethod
def rng_fn(cls, rng, n, a, size):
if n.ndim > 0 or a.ndim > 1:
n, a = broadcast_params([n, a], cls.ndims_params)
size = tuple(size or ())

if size:
n = np.broadcast_to(n, size)
a = np.broadcast_to(a, (*size, a.shape[-1]))

res = np.empty(a.shape)
for idx in np.ndindex(a.shape[:-1]):
p = rng.dirichlet(a[idx])
res[idx] = rng.multinomial(n[idx], p)
return res
else:
# n is a scalar, a is a 1d array
p = rng.dirichlet(a, size=size) # (size, a.shape)

res = np.empty(p.shape)
for idx in np.ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n, p[idx])
def rv_op(cls, n, a, *, size=None, rng=None):
n = pt.as_tensor(n, dtype=int)
a = pt.as_tensor(a)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

return res
if rv_size_is_none(size):
size = implicit_size_from_params(n, a, ndims_params=cls.ndims_params)

next_rng, p = dirichlet(a, size=size, rng=rng).owner.outputs
final_rng, rv = multinomial(n, p, size=size, rng=next_rng).owner.outputs

dirichlet_multinomial = DirichletMultinomialRV()
return cls(
inputs=[rng, size, n, a],
outputs=[final_rng, rv],
)(rng, size, n, a)


class DirichletMultinomial(Discrete):
@@ -665,7 +647,8 @@ class DirichletMultinomial(Discrete):
the length of the last axis.
"""

rv_op = dirichlet_multinomial
rv_type = DirichletMultinomialRV
rv_op = DirichletMultinomialRV.rv_op

@classmethod
def dist(cls, n, a, *args, **kwargs):
@@ -1161,11 +1144,40 @@ def rng_fn(self, rng, n, eta, D, size):
# _LKJCholeskyCovBaseRV requires a properly shaped `D`, which means the variable can't
# be safely resized. Because of this, we add the thin SymbolicRandomVariable wrapper
class _LKJCholeskyCovRV(SymbolicRandomVariable):
default_output = 1
signature = "(),(),(),(n)->(),(n)"
ndim_supp = 1
signature = "[rng],(),(),(n)->[rng],(n)"
_print_name = ("_lkjcholeskycov", "\\operatorname{_lkjcholeskycov}")

@classmethod
def rv_op(cls, n, eta, sd_dist, *, size=None):
# We don't allow passing `rng` because we don't fully control the rng of the components!
n = pt.as_tensor(n, dtype="int64", ndim=0)
eta = pt.as_tensor_variable(eta, ndim=0)
rng = pytensor.shared(np.random.default_rng())
size = normalize_size_param(size)

# We resize the sd_dist automatically so that it has (size x n) independent
# draws which is what the `_LKJCholeskyCovBaseRV.rng_fn` expects. This makes the
# random and logp methods equivalent, as the latter also assumes a unique value
# for each diagonal element.
# Since `eta` and `n` are forced to be scalars we don't need to worry about
# implied batched dimensions from those for the time being.
if rv_size_is_none(size):
size = sd_dist.shape[:-1]

shape = (*size, n)
if sd_dist.owner.op.ndim_supp == 0:
sd_dist = change_dist_size(sd_dist, shape)
else:
# The support shape must be `n` but we have no way of controlling it
sd_dist = change_dist_size(sd_dist, shape[:-1])

next_rng, lkjcov = _ljk_cholesky_cov_base(n, eta, sd_dist, rng=rng).owner.outputs

return _LKJCholeskyCovRV(
inputs=[rng, n, eta, sd_dist],
outputs=[next_rng, lkjcov],
)(rng, n, eta, sd_dist)

def update(self, node):
return {node.inputs[0]: node.outputs[0]}

@@ -1176,12 +1188,10 @@ class _LKJCholeskyCov(Distribution):
"""

rv_type = _LKJCholeskyCovRV
rv_op = _LKJCholeskyCovRV.rv_op

@classmethod
def dist(cls, n, eta, sd_dist, **kwargs):
n = pt.as_tensor_variable(n, dtype=int)
eta = pt.as_tensor_variable(eta)

if not (
isinstance(sd_dist, Variable)
and sd_dist.owner is not None
@@ -1193,34 +1203,6 @@ def dist(cls, n, eta, sd_dist, **kwargs):
check_dist_not_registered(sd_dist)
return super().dist([n, eta, sd_dist], **kwargs)

@classmethod
def rv_op(cls, n, eta, sd_dist, size=None):
# We resize the sd_dist automatically so that it has (size x n) independent
# draws which is what the `_LKJCholeskyCovBaseRV.rng_fn` expects. This makes the
# random and logp methods equivalent, as the latter also assumes a unique value
# for each diagonal element.
# Since `eta` and `n` are forced to be scalars we don't need to worry about
# implied batched dimensions from those for the time being.
if size is None:
size = sd_dist.shape[:-1]
shape = (*size, n)
if sd_dist.owner.op.ndim_supp == 0:
sd_dist = change_dist_size(sd_dist, shape)
else:
# The support shape must be `n` but we have no way of controlling it
sd_dist = change_dist_size(sd_dist, shape[:-1])

# Create new rng for the _lkjcholeskycov internal RV
rng = pytensor.shared(np.random.default_rng())

rng_, n_, eta_, sd_dist_ = rng.type(), n.type(), eta.type(), sd_dist.type()
next_rng_, lkjcov_ = _ljk_cholesky_cov_base(n_, eta_, sd_dist_, rng=rng_).owner.outputs

return _LKJCholeskyCovRV(
inputs=[rng_, n_, eta_, sd_dist_],
outputs=[next_rng_, lkjcov_],
)(rng, n, eta, sd_dist)


@_change_dist_size.register(_LKJCholeskyCovRV)
def change_LKJCholeksyCovRV_size(op, dist, new_size, expand=False):
@@ -2630,7 +2612,34 @@ class ZeroSumNormalRV(SymbolicRandomVariable):
"""ZeroSumNormal random variable"""

_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
default_output = 0

@classmethod
def rv_op(cls, sigma, support_shape, *, size=None, rng=None):
n_zerosum_axes = pt.get_vector_length(support_shape)
sigma = pt.as_tensor(sigma)
support_shape = pt.as_tensor(support_shape, ndim=1)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
# Size is implied by shape of sigma
size = sigma.shape[:-n_zerosum_axes]

shape = tuple(size) + tuple(support_shape)
next_rng, normal_dist = pm.Normal.dist(sigma=sigma, shape=shape, rng=rng).owner.outputs

# Zerosum-normaling is achieved by subtracting the mean along the given n_zerosum_axes
zerosum_rv = normal_dist
for axis in range(n_zerosum_axes):
zerosum_rv -= zerosum_rv.mean(axis=-axis - 1, keepdims=True)

support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
signature = f"[rng],(),(s),[size]->[rng],({support_str})"
return ZeroSumNormalRV(
inputs=[rng, sigma, support_shape, size],
outputs=[next_rng, zerosum_rv],
signature=signature,
)(rng, sigma, support_shape, size)


class ZeroSumNormal(Distribution):
@@ -2695,6 +2704,7 @@ class ZeroSumNormal(Distribution):
"""

rv_type = ZeroSumNormalRV
rv_op = ZeroSumNormalRV.rv_op

def __new__(
cls, *args, zerosum_axes=None, n_zerosum_axes=None, support_shape=None, dims=None, **kwargs
@@ -2726,10 +2736,10 @@ def __new__(
)

@classmethod
def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
def dist(cls, sigma=1.0, n_zerosum_axes=None, support_shape=None, **kwargs):
n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)

sigma = pt.as_tensor_variable(sigma)
sigma = pt.as_tensor(sigma)
if not all(sigma.type.broadcastable[-n_zerosum_axes:]):
raise ValueError("sigma must have length one across the zero-sum axes")

@@ -2743,15 +2753,13 @@ def dist(cls, sigma=1, n_zerosum_axes=None, support_shape=None, **kwargs):
if n_zerosum_axes > 0:
raise ValueError("You must specify dims, shape or support_shape parameter")

support_shape = pt.as_tensor_variable(intX(support_shape))
support_shape = pt.as_tensor(support_shape, dtype="int64", ndim=1)

assert n_zerosum_axes == pt.get_vector_length(
support_shape
), "support_shape has to be as long as n_zerosum_axes"

return super().dist(
[sigma], n_zerosum_axes=n_zerosum_axes, support_shape=support_shape, **kwargs
)
return super().dist([sigma, support_shape], **kwargs)

@classmethod
def check_zerosum_axes(cls, n_zerosum_axes: int | None) -> int:
@@ -2763,52 +2771,6 @@ def check_zerosum_axes(cls, n_zerosum_axes: int | None) -> int:
raise ValueError("n_zerosum_axes has to be > 0")
return n_zerosum_axes

@classmethod
def rv_op(cls, sigma, n_zerosum_axes, support_shape, size=None):
if size is not None:
shape = tuple(size) + tuple(support_shape)
else:
# Size is implied by shape of sigma
shape = tuple(sigma.shape[:-n_zerosum_axes]) + tuple(support_shape)

normal_dist = pm.Normal.dist(sigma=sigma, shape=shape)

if n_zerosum_axes > normal_dist.ndim:
raise ValueError("Shape of distribution is too small for the number of zerosum axes")

normal_dist_, sigma_, support_shape_ = (
normal_dist.type(),
sigma.type(),
support_shape.type(),
)

# Zerosum-normaling is achieved by subtracting the mean along the given n_zerosum_axes
zerosum_rv_ = normal_dist_
for axis in range(n_zerosum_axes):
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

support_str = ",".join([f"d{i}" for i in range(n_zerosum_axes)])
signature = f"({support_str}),(),(s)->({support_str})"
return ZeroSumNormalRV(
inputs=[normal_dist_, sigma_, support_shape_],
outputs=[zerosum_rv_],
signature=signature,
)(normal_dist, sigma, support_shape)


@_change_dist_size.register(ZeroSumNormalRV)
def change_zerosum_size(op, normal_dist, new_size, expand=False):
normal_dist, sigma, support_shape = normal_dist.owner.inputs

if expand:
original_shape = tuple(normal_dist.shape)
old_size = original_shape[: len(original_shape) - op.ndim_supp]
new_size = tuple(new_size) + old_size

return ZeroSumNormal.rv_op(
sigma=sigma, n_zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
)


@_support_point.register(ZeroSumNormalRV)
def zerosumnormal_support_point(op, rv, *rv_inputs):
@@ -2822,11 +2784,10 @@ def zerosum_default_transform(op, rv):


@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
def zerosumnormal_logp(op, values, rng, sigma, support_shape, size, **kwargs):
(value,) = values
shape = value.shape
n_zerosum_axes = op.ndim_supp
*_, sigma = normal_dist.owner.inputs

_deg_free_support_shape = pt.inc_subtensor(shape[-n_zerosum_axes:], -1)
_full_size = pt.prod(shape).astype("floatX")
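A small NumPy analogue of what the new `ZeroSumNormalRV.rv_op` graph does: draw normals, then subtract the mean along each zero-sum axis (one axis here), so every slice sums to zero up to floating point:

```python
import numpy as np

rng = np.random.default_rng(7)
draws = rng.normal(size=(1_000, 4))
zerosum_draws = draws - draws.mean(axis=-1, keepdims=True)

np.testing.assert_allclose(zerosum_draws.sum(axis=-1), 0.0, atol=1e-12)
```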
27 changes: 26 additions & 1 deletion pymc/distributions/shape_utils.py
@@ -297,8 +297,10 @@ def find_size(
return None


def rv_size_is_none(size: Variable) -> bool:
def rv_size_is_none(size: Variable | None) -> bool:
"""Check whether an rv size is None (ie., pt.Constant([]))"""
if size is None:
return True
return size.type.shape == (0,) # type: ignore [attr-defined]


@@ -354,6 +356,7 @@ def change_dist_size(
else:
new_size = tuple(new_size) # type: ignore

# TODO: Get rid of unused expand argument
new_dist = _change_dist_size(dist.owner.op, dist, new_size=new_size, expand=expand)
_add_future_warning_tag(new_dist)

@@ -538,3 +541,25 @@ def get_support_shape_1d(
return support_shape_
else:
return None


def implicit_size_from_params(
*params: TensorVariable,
ndims_params: Sequence[int],
) -> TensorVariable:
"""Infer the size of a distribution from the batch dimenesions of its parameters."""
batch_shapes = []
for param, ndim in zip(params, ndims_params):
batch_shape = list(param.shape[:-ndim] if ndim > 0 else param.shape)
# Overwrite broadcastable dims
for i, broadcastable in enumerate(param.type.broadcastable):
if broadcastable:
batch_shape[i] = 1
batch_shapes.append(batch_shape)

return pt.as_tensor(
pt.broadcast_shape(
*batch_shapes,
arrays_are_shapes=True,
)
)
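A quick sketch of what `implicit_size_from_params` returns, using PyTensor tensors with known static shapes; batch shapes (3, 1) and (4,) broadcast to an implied size of (3, 4):

```python
import pytensor.tensor as pt

from pymc.distributions.shape_utils import implicit_size_from_params

mu = pt.zeros((3, 1))
sigma = pt.ones((4,))

# ndims_params=[0, 0]: neither parameter has core (support) dimensions.
size = implicit_size_from_params(mu, sigma, ndims_params=[0, 0])
print(size.eval())  # [3 4]
```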
447 changes: 211 additions & 236 deletions pymc/distributions/timeseries.py

Large diffs are not rendered by default.

272 changes: 140 additions & 132 deletions pymc/distributions/truncated.py
@@ -36,7 +36,12 @@
_support_point,
support_point,
)
from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
from pymc.distributions.shape_utils import (
_change_dist_size,
change_dist_size,
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
@@ -71,6 +76,137 @@ def __init__(
)
super().__init__(*args, **kwargs)

@classmethod
def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
# We don't accept rng because we don't have control over it when using a specialized Op
# and there may be a need for multiple RNGs in dist.

# Try to use specialized Op
try:
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
except NotImplementedError:
pass

lower = pt.as_tensor_variable(lower) if lower is not None else pt.constant(-np.inf)
upper = pt.as_tensor_variable(upper) if upper is not None else pt.constant(np.inf)

if size is not None:
size = pt.as_tensor(size, dtype="int64", ndim=1)

if rv_size_is_none(size):
size = pt.broadcast_shape(dist, lower, upper)

dist = change_dist_size(dist, new_size=size)

rv_inputs = [
inp
if not isinstance(inp.type, RandomType)
else pytensor.shared(np.random.default_rng())
for inp in dist.owner.inputs
]
graph_inputs = [*rv_inputs, lower, upper]

rv = dist.owner.op.make_node(*rv_inputs).default_output()

# Try to use inverted cdf sampling
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
try:
logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper)
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType))
uniform_next_rng, uniform = pt.random.uniform(
pt.exp(logcdf_lower),
pt.exp(logcdf_upper),
rng=uniform_rng,
size=rv.shape,
).owner.outputs
truncated_rv = icdf(rv, uniform, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, uniform_next_rng],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
except NotImplementedError:
pass

# Fallback to rejection sampling
# truncated_rv = zeros(rv.shape)
# reject_draws = ones(rv.shape, dtype=bool)
# while any(reject_draws):
# truncated_rv[reject_draws] = draw(rv)[reject_draws]
# reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
new_truncated_rv = dist.owner.op.make_node(*rv_inputs).default_output()
# Avoid scalar boolean indexing
if truncated_rv.type.ndim == 0:
truncated_rv = new_truncated_rv
else:
truncated_rv = pt.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))

return (
(truncated_rv, reject_draws),
collect_default_updates(new_truncated_rv, inputs=rv_inputs),
until(~pt.any(reject_draws)),
)

(truncated_rv, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
pt.zeros_like(rv),
pt.ones_like(rv, dtype=bool),
],
non_sequences=[lower, upper, *rv_inputs],
n_steps=max_n_steps,
strict=True,
)

truncated_rv = truncated_rv[-1]
convergence = ~pt.any(reject_draws_[-1])
truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv, convergence
)

# Sort updates of each RNG so that they show in the same order as the input RNGs
def sort_updates(update):
rng, next_rng = update
return graph_inputs.index(rng)

next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]

return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs,
outputs=[truncated_rv, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)

@staticmethod
def _create_logcdf_exprs(
base_rv: TensorVariable,
value: TensorVariable,
lower: TensorVariable,
upper: TensorVariable,
) -> tuple[TensorVariable, TensorVariable]:
"""Create lower and upper logcdf expressions for base_rv.
Uses `value` as a template for broadcasting.
"""
# For left truncated discrete RVs, we need to include the whole lower bound.
lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
upper_value = pt.full_like(value, upper, dtype=config.floatX)
lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf

def update(self, node: Node):
"""Return the update mapping for the internal RNGs.
@@ -152,6 +288,7 @@ class Truncated(Distribution):
"""

rv_type = TruncatedRV
rv_op = rv_type.rv_op

@classmethod
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
@@ -178,135 +315,6 @@ def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs)

return super().dist([dist, lower, upper, max_n_steps], **kwargs)

@classmethod
def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
# Try to use specialized Op
try:
return _truncated(dist.owner.op, lower, upper, size, *dist.owner.inputs)
except NotImplementedError:
pass

lower = pt.as_tensor_variable(lower) if lower is not None else pt.constant(-np.inf)
upper = pt.as_tensor_variable(upper) if upper is not None else pt.constant(np.inf)

if size is None:
size = pt.broadcast_shape(dist, lower, upper)
dist = change_dist_size(dist, new_size=size)
rv_inputs = [
inp
if not isinstance(inp.type, RandomType)
else pytensor.shared(np.random.default_rng())
for inp in dist.owner.inputs
]
graph_inputs = [*rv_inputs, lower, upper]

# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs_ = [
inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
]
*rv_inputs_, lower_, upper_ = graph_inputs_

rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

# Try to use inverted cdf sampling
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
try:
logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_)
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
uniform_next_rng_, uniform_ = pt.random.uniform(
pt.exp(logcdf_lower_),
pt.exp(logcdf_upper_),
rng=uniform_rng_,
size=rv_.shape,
).owner.outputs
truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs_,
outputs=[truncated_rv_, uniform_next_rng_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
except NotImplementedError:
pass

# Fallback to rejection sampling
# truncated_rv = zeros(rv.shape)
# reject_draws = ones(rv.shape, dtype=bool)
# while any(reject_draws):
# truncated_rv[reject_draws] = draw(rv)[reject_draws]
# reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
# Avoid scalar boolean indexing
if truncated_rv.type.ndim == 0:
truncated_rv = new_truncated_rv
else:
truncated_rv = pt.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))

return (
(truncated_rv, reject_draws),
collect_default_updates(new_truncated_rv),
until(~pt.any(reject_draws)),
)

(truncated_rv_, reject_draws_), updates = scan(
loop_fn,
outputs_info=[
pt.zeros_like(rv_),
pt.ones_like(rv_, dtype=bool),
],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)

truncated_rv_ = truncated_rv_[-1]
convergence_ = ~pt.any(reject_draws_[-1])
truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv_, convergence_
)
# Sort updates of each RNG so that they show in the same order as the input RNGs

def sort_updates(update):
rng, next_rng = update
return graph_inputs.index(rng)

next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]

return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs_,
outputs=[truncated_rv_, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)

@staticmethod
def _create_logcdf_exprs(
base_rv: TensorVariable,
value: TensorVariable,
lower: TensorVariable,
upper: TensorVariable,
) -> tuple[TensorVariable, TensorVariable]:
"""Create lower and upper logcdf expressions for base_rv.
Uses `value` as a template for broadcasting.
"""
# For left truncated discrete RVs, we need to include the whole lower bound.
lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
upper_value = pt.full_like(value, upper, dtype=config.floatX)
lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf


@_change_dist_size.register(TruncatedRV)
def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
@@ -367,7 +375,7 @@ def truncated_logprob(op, values, *inputs, **kwargs):
base_rv_op = op.base_rv_op
base_rv = base_rv_op.make_node(*rv_inputs).default_output()
base_logp = logp(base_rv, value)
lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
lower_logcdf, upper_logcdf = TruncatedRV._create_logcdf_exprs(base_rv, value, lower, upper)
if base_rv_op.name:
base_logp.name = f"{base_rv_op}_logprob"
lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
@@ -408,7 +416,7 @@ def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):

base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
base_logcdf = logcdf(base_rv, value)
lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
lower_logcdf, upper_logcdf = TruncatedRV._create_logcdf_exprs(base_rv, value, lower, upper)

is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
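End to end, the refactored `TruncatedRV` is still driven through `pm.Truncated`; a minimal sketch (Normal has an icdf, so the inverse-CDF branch above is taken rather than the scan-based rejection fallback):

```python
import pymc as pm

trunc_normal = pm.Truncated.dist(pm.Normal.dist(mu=0.0, sigma=1.0), lower=1.0)
draws = pm.draw(trunc_normal, draws=1_000, random_seed=3)
assert (draws >= 1.0).all()
```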
11 changes: 11 additions & 0 deletions pymc/pytensorf.py
@@ -1089,3 +1089,14 @@ def toposort_replace(
reverse=reverse,
)
fgraph.replace_all(sorted_replacements, import_missing=True)


def normalize_rng_param(rng: None | Variable) -> Variable:
"""Validate rng is a valid type or create a new one if None"""
if rng is None:
rng = pytensor.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType):
raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
return rng
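A minimal sketch of the helper's contract: None is replaced by a fresh shared Generator, an existing RNG variable is passed through unchanged, and anything else raises TypeError.

```python
from pymc.pytensorf import normalize_rng_param

fresh_rng = normalize_rng_param(None)      # new shared np.random.default_rng() variable
same_rng = normalize_rng_param(fresh_rng)  # existing RNG is returned as-is
assert same_rng is fresh_rng
```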
23 changes: 17 additions & 6 deletions pymc/testing.py
@@ -29,6 +29,7 @@
from pytensor.graph.basic import Variable
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable
from scipy import special as sp
from scipy import stats as st

@@ -897,7 +898,18 @@ def check_pymc_draws_match_reference(self):
)

def check_pymc_params_match_rv_op(self):
pytensor_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
op = self.pymc_rv.owner.op
if isinstance(op, RandomVariable):
_, _, _, *pytensor_dist_inputs = self.pymc_rv.owner.inputs
else:
inputs_signature, _ = op.signature.split("->")
pytensor_dist_inputs = [
inp
for inp, inp_signature in zip(
self.pymc_rv.owner.inputs, inputs_signature.split(",")
)
if inp_signature not in ("[rng]", "[size]")
]
assert len(self.expected_rv_op_params) == len(pytensor_dist_inputs)
for (expected_name, expected_value), actual_variable in zip(
self.expected_rv_op_params.items(), pytensor_dist_inputs
@@ -917,18 +929,17 @@ def check_rv_size(self):
expected_symbolic = tuple(pymc_rv.shape.eval())
actual = pymc_rv.eval().shape
assert actual == expected_symbolic
assert expected_symbolic == expected
assert expected_symbolic == expected, (size, expected_symbolic, expected)

# test multi-parameters sampling for univariate distributions (with univariate inputs)
if (
self.pymc_dist.rv_op.ndim_supp == 0
and self.pymc_dist.rv_op.ndims_params
and sum(self.pymc_dist.rv_op.ndims_params) == 0
self.pymc_dist.rv_type.ndim_supp == 0
and self.pymc_dist.rv_type.ndims_params
and sum(self.pymc_dist.rv_type.ndims_params) == 0
):
params = {
k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items()
}
self._instantiate_pymc_rv(params)
sizes_to_check = [None, self.repeated_params_shape, (5, self.repeated_params_shape)]
sizes_expected = [
(self.repeated_params_shape,),
20 changes: 10 additions & 10 deletions tests/distributions/test_distribution.py
@@ -749,27 +749,27 @@ def dist(p, size):

out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(p)->()")
# Size and updates are added automatically to the signature
assert out.owner.op.signature == "(),(p)->(),()"
assert out.owner.op.signature == "[size],(p),[rng]->(),[rng]"
assert out.owner.op.ndim_supp == 0
assert out.owner.op.ndims_params == [0, 1]
assert out.owner.op.ndims_params == [1]

# When recreated internally, the whole signature may already be known
out = CustomDist.dist([0.25, 0.75], dist=dist, signature="(),(p)->(),()")
assert out.owner.op.signature == "(),(p)->(),()"
out = CustomDist.dist([0.25, 0.75], dist=dist, signature="[size],(p),[rng]->(),[rng]")
assert out.owner.op.signature == "[size],(p),[rng]->(),[rng]"
assert out.owner.op.ndim_supp == 0
assert out.owner.op.ndims_params == [0, 1]
assert out.owner.op.ndims_params == [1]

# A safe signature can be inferred from ndim_supp and ndims_params
out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[0, 1])
assert out.owner.op.signature == "(),(i10)->(),()"
out = CustomDist.dist([0.25, 0.75], dist=dist, ndim_supp=0, ndims_params=[1])
assert out.owner.op.signature == "[size],(i00),[rng]->(),[rng]"
assert out.owner.op.ndim_supp == 0
assert out.owner.op.ndims_params == [0, 1]
assert out.owner.op.ndims_params == [1]

# Otherwise, by default we assume everything is scalar, even though it's wrong in this case
out = CustomDist.dist([0.25, 0.75], dist=dist)
assert out.owner.op.signature == "(),()->(),()"
assert out.owner.op.signature == "[size],(),[rng]->(),[rng]"
assert out.owner.op.ndim_supp == 0
assert out.owner.op.ndims_params == [0, 0]
assert out.owner.op.ndims_params == [0]


class TestSymbolicRandomVariable:
8 changes: 7 additions & 1 deletion tests/sampling/test_jax.py
@@ -34,7 +34,7 @@
import pymc as pm

from pymc import ImputationWarning
from pymc.distributions.multivariate import PosDefMatrix
from pymc.distributions.multivariate import DirichletMultinomial, PosDefMatrix
from pymc.sampling.jax import (
_get_batched_jittered_initial_points,
_get_log_likelihood,
@@ -511,3 +511,9 @@ def test_convergence_warnings(caplog, nuts_sampler):

[record] = caplog.records
assert re.match(r"There were \d+ divergences after tuning", record.message)


def test_dirichlet_multinomial():
dm = DirichletMultinomial.dist(n=5, a=np.eye(3) * 1e6 + 0.01)
dm_draws = pm.draw(dm, mode="JAX")
np.testing.assert_equal(dm_draws, np.eye(3) * 5)