Skip to content

Adding three distributions: Spherical, OrthogonalMatrices and NormalSingularValues (working names) #465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions pymc_extras/distributions/multivariate/normal_singular_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2025 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytensor.tensor as pt

from pymc.distributions.continuous import Continuous
from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.distributions.shape_utils import (
rv_size_is_none,
)
from pymc.distributions.transforms import _default_transform
from pymc.pytensorf import normalize_rng_param
from pytensor.tensor import get_underlying_scalar_constant_value
from pytensor.tensor.random.utils import (
normalize_size_param,
)

__all__ = ["NormalSingularValues"]

from pymc.logprob.transforms import Transform


# TODO: this is a lot of work to just get a list normally distributed variables
class NormalSingularValuesRV(SymbolicRandomVariable):
name = "normalsingularvalues"
extended_signature = "[rng],[size],(),(m)->[rng],(m)" # TODO: check if this is correct
_print_name = ("NormalSingularValuesRV", "\\operatorname{NormalSingularValuesRV}")

def make_node(self, rng, size, n, m):
n = pt.as_tensor_variable(n)
m = pt.as_tensor_variable(m)
if not all(n.type.broadcastable) or not all(m.type.broadcastable):
raise ValueError("n and m must be scalars.")

return super().make_node(rng, size, n, m)

@classmethod
def rv_op(cls, n: int, m: int, *, rng=None, size=None):
# We flatten the size to make operations easier, and then rebuild it
n = pt.as_tensor(n, ndim=0, dtype=int)
m = pt.as_tensor(m, ndim=0, dtype=int)

rng = normalize_rng_param(rng)
size = normalize_size_param(size)

# TODO: currently assume size = 1. Fix this once everything is working
D = get_underlying_scalar_constant_value(n)
Q = get_underlying_scalar_constant_value(m)

# Perform a direct computation via SVD of a normal matrix
sz = [] if rv_size_is_none(size) else size
next_rng, z = pt.random.normal(0, 1, size=(*sz, D, Q), rng=rng).owner.outputs
_, samples, _ = pt.linalg.svd(z)

return cls(
inputs=[rng, size, n, m],
outputs=[next_rng, samples],
)(rng, size, n, m)

return samples


# This is adapted from ordered transform.
# Might make sense to just make that transform more generic by
# allowing it to take parameters "positive" and "ascending"
# and then just use that here.
class PosRevOrdered(Transform):
name = "posrevordered"

def __init__(self, ndim_supp=None):
pass

def backward(self, value, *inputs):
return pt.cumsum(pt.exp(value[..., ::-1]), axis=-1)[..., ::-1]

def forward(self, value, *inputs):
y = pt.zeros(value.shape)
y = pt.set_subtensor(y[..., -1], pt.log(value[..., -1]))
y = pt.set_subtensor(y[..., :-1], pt.log(value[..., :-1] - value[..., 1:]))
return y

def log_jac_det(self, value, *inputs):
return pt.sum(value, axis=-1)


class NormalSingularValues(Continuous):
rv_type = NormalSingularValuesRV
rv_op = NormalSingularValuesRV.rv_op

@classmethod
def dist(cls, n, m, **kwargs):
n = pt.as_tensor_variable(n).astype(int)
m = pt.as_tensor_variable(m).astype(int)
return super().dist([n, m], **kwargs)

def support_point(rv, *args):
return pt.linspace(1, 0.5, rv.shape[-1])

def logp(sigma, n, m):
# First term: prod[exp(-0.5*sigma**2)]
log_p = -0.5 * pt.sum(sigma**2)

# Second + Fourth term (ignoring constant factor)
# prod(sigma**(D-Q-1)) + prod(2*sigma)) = prod(2*sigma**(D-Q))
log_p += (n - m) * pt.sum(pt.log(sigma))

# Third term: prod[prod[ |s1**2-s2**2| ]]
# li = pt.triu_indices(m,k=1)
# log_p += pt.log((sigma[:,None]**2 - sigma[None,:]**2)[li]).sum()
log_p += (
pt.log(pt.eye(m) + pt.abs(sigma[:, None] ** 2 - sigma[None, :] ** 2) + 1e-6).sum() / 2.0
)

return log_p


@_default_transform.register(NormalSingularValues)
def lkjcorr_default_transform(op, rv):
return PosRevOrdered()
77 changes: 77 additions & 0 deletions pymc_extras/distributions/multivariate/orthogonal_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2025 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytensor.tensor as pt

from pytensor.tensor import TensorVariable

from pymc_extras.distributions.multivariate.spherical import Spherical

__all__ = ["SemiOrthogonalMatrix"]


class SemiOrthogonalMatrix:
def __new__(cls, name, D, Q, **kwargs):
dof = D * Q - Q * (Q - 1) // 2 # Total degrees of freedom

vs, pos = pt.zeros(dof), 0
for q in range(Q):
vq = Spherical(f"{name}_v{q}", D - q)
vs = pt.set_subtensor(vs[pos : pos + D - q], vq)
pos += D - q

return cls.orth_from_vs(vs, D, Q)

# Create a householder matrix from a vector
@classmethod
def _householder_matrix(cls, v: TensorVariable, D: int) -> TensorVariable:
Q = v.shape[0]
H = pt.eye(D)
sgn = 1.0 # Original paper recommends sign(v[0]) but that causes divergences
u = pt.inc_subtensor(v[0], sgn * pt.linalg.norm(v))
H = pt.set_subtensor(
H[-Q:, -Q:], -sgn * (pt.eye(Q, Q) - 2 * u[:, None] * u[None, :] / (pt.dot(u, u) + 1e-6))
)
return H

# Construct an orthogonal matrix from a vector of normally distributed values
# as a cumulative product of householder matrices
@classmethod
def orth_from_vs(cls, vs: TensorVariable, D: int, Q: int) -> TensorVariable:
"""Construct an orthogonal matrix from a set of direction vectors v"""
H_p = pt.eye(D)
pos, q = 0, 0
dof = D * Q - Q * (Q - 1) // 2
while pos < dof:
v = vs[pos : pos + D - q]
H = cls._householder_matrix(v, D)
H_p = H @ H_p
pos += D - q
q += 1
return H_p[:q, :]

@classmethod
def vs_from_orth(cls, U: TensorVariable, D: int, Q: int) -> TensorVariable:
"""Get the vs values that would lead to orthogonal matrix U. Inverse of orth_from_vs"""
vs = []
vl = D * Q - Q * (Q - 1) // 2
vs, pos = pt.zeros(vl), 0
for q in range(Q):
v = U[q:, q] # Top row of the remaining submatrix

vs = pt.set_subtensor(vs[pos : pos + D - q], v)
H = cls._householder_matrix(v, D)
U = H.dot(U)
pos += D - q
return vs
89 changes: 89 additions & 0 deletions pymc_extras/distributions/multivariate/spherical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright 2025 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pymc as pm
import pytensor.tensor as pt

from pymc.distributions.continuous import Continuous
from pymc.distributions.distribution import SymbolicRandomVariable
from pymc.distributions.shape_utils import (
rv_size_is_none,
)
from pymc.pytensorf import normalize_rng_param
from pytensor.tensor import get_underlying_scalar_constant_value
from pytensor.tensor.random.utils import (
normalize_size_param,
)

__all__ = ["Spherical"]


class SphericalRV(SymbolicRandomVariable):
name = "spherical"
extended_signature = "[rng],[size],(n)->[rng],(n)" # TODO: check if this is correct
_print_name = ("SphericalRV", "\\operatorname{SphericalRV}")

def make_node(self, rng, size, n):
n = pt.as_tensor_variable(n)
return super().make_node(rng, size, n)

@classmethod
def rv_op(cls, n, *, rng=None, size=None):
rng = normalize_rng_param(rng)
size = normalize_size_param(size)
n = pt.as_tensor(n, ndim=0, dtype=int)
nv = get_underlying_scalar_constant_value(n)

# Perform a direct computation via SVD of a normal matrix
sz = [] if rv_size_is_none(size) else size

next_rng, z = pt.random.normal(0, 1, size=(*sz, nv), rng=rng).owner.outputs
samples = z / pt.sqrt(z * z.sum(axis=-1, keepdims=True) + 1e-6)
# TODO: scale by the .dist given

return cls(
inputs=[rng, size, n],
outputs=[next_rng, samples],
)(rng, size, n)

return samples


class Spherical(Continuous):
rv_type = SphericalRV
rv_op = SphericalRV.rv_op

@classmethod
def dist(cls, n, **kwargs):
n = pt.as_tensor_variable(n).astype(int)
return super().dist([n], **kwargs)

def support_point(rv, size, n, *args):
return pt.ones(rv.shape) / pt.sqrt(n)

def logp(value, n):
# TODO: take dist as a parameter instead of hardcoding
dist = pm.Gamma.dist(50, 50)

# Get the radius
r = pt.sqrt(pt.sum(value**2))

# Get the log prior of the radius
log_p = pm.logp(dist, r)
# log_p = pm.logp(pm.TruncatedNormal.dist(1,lower=0),r)

# Add the log det jacobian for radius
log_p += (value.shape[-1] - 1) * pt.log(r)

return log_p
Loading