Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions pymc_extras/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def custom_transform(x):

from pydantic import InstanceOf, validate_call
from pydantic.dataclasses import dataclass
from pymc.distributions.shape_utils import Dims
from pymc.distributions.shape_utils import Dims, StrongDims

from pymc_extras.deserialize import deserialize, register_deserialization

Expand Down Expand Up @@ -576,7 +576,7 @@ def __init__(
) -> None:
self.distribution = distribution
self.parameters = parameters
self.dims = dims
self.dims: StrongDims = dims
self.centered = centered
self.transform = transform

Expand Down Expand Up @@ -606,12 +606,16 @@ def transform(self, transform: str | None) -> None:
self.pytensor_transform = not transform or _get_transform(transform) # type: ignore

@property
def dims(self) -> Dims:
"""The dimensions of the variable."""
def dims(self) -> StrongDims:
"""The dimensions of the variable.

It will always be a tuple. Empty tuple for scalar variables.

"""
return self._dims

@dims.setter
def dims(self, dims) -> None:
def dims(self, dims: Dims | None) -> None:
if isinstance(dims, str):
dims = (dims,)

Expand Down
Loading