Skip to content

Commit ae0f260

Browse files
authored
Update for PyMC v5.11 (#320)
1 parent ee07caa commit ae0f260

11 files changed

+17
-17
lines changed

conda-envs/environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ dependencies:
1111
- xhistogram
1212
- statsmodels
1313
- pip:
14-
- pymc>=5.10.0 # CI was failing to resolve
14+
- pymc>=5.11.0 # CI was failing to resolve
1515
- blackjax
1616
- scikit-learn

conda-envs/windows-environment-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.10.0 # CI was failing to resolve
13+
- pymc>=5.11.0 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn

pymc_experimental/distributions/continuous.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def logcdf(value, mu, sigma, xi):
211211
logc, sigma > 0, pt.and_(xi > -1, xi < 1), msg="sigma > 0 or -1 < xi < 1"
212212
)
213213

214-
def moment(rv, size, mu, sigma, xi):
214+
def support_point(rv, size, mu, sigma, xi):
215215
r"""
216216
Using the mode, as the mean can be infinite when :math:`\xi > 1`
217217
"""

pymc_experimental/distributions/discrete.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def dist(cls, mu, lam, **kwargs):
143143
lam = pt.as_tensor_variable(lam)
144144
return super().dist([mu, lam], **kwargs)
145145

146-
def moment(rv, size, mu, lam):
146+
def support_point(rv, size, mu, lam):
147147
mean = pt.floor(mu / (1 - lam))
148148
if not rv_size_is_none(size):
149149
mean = pt.full(size, mean)

pymc_experimental/distributions/timeseries.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from pymc.distributions.distribution import (
1010
Distribution,
1111
SymbolicRandomVariable,
12-
_moment,
13-
moment,
12+
_support_point,
13+
support_point,
1414
)
1515
from pymc.distributions.shape_utils import (
1616
_change_dist_size,
@@ -221,9 +221,9 @@ def change_mc_size(op, dist, new_size, expand=False):
221221
return DiscreteMarkovChain.rv_op(*dist.owner.inputs[:-1], size=new_size, n_lags=op.n_lags)
222222

223223

224-
@_moment.register(DiscreteMarkovChainRV)
224+
@_support_point.register(DiscreteMarkovChainRV)
225225
def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
226-
init_dist_moment = moment(init_dist)
226+
init_dist_moment = support_point(init_dist)
227227
n_lags = op.n_lags
228228

229229
def greedy_transition(*args):

pymc_experimental/model/marginal_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
655655
return tuple(range(pt.get_vector_length(p_param)))
656656
elif isinstance(op, DiscreteUniform):
657657
lower, upper = constant_fold(rv.owner.inputs[3:])
658-
return tuple(range(lower, upper + 1))
658+
return tuple(np.arange(lower, upper + 1))
659659
elif isinstance(op, DiscreteMarkovChain):
660660
P = rv.owner.inputs[0]
661661
return tuple(range(pt.get_vector_length(P[-1])))

pymc_experimental/tests/distributions/test_continuous.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
R,
2626
Rplus,
2727
Rplusbig,
28-
assert_moment_is_expected,
28+
assert_support_point_is_expected,
2929
check_logcdf,
3030
check_logp,
3131
seeded_scipy_distribution_builder,
@@ -111,10 +111,10 @@ def ref_logcdf(value, mu, sigma, xi):
111111
),
112112
],
113113
)
114-
def test_genextreme_moment(self, mu, sigma, xi, size, expected):
114+
def test_genextreme_support_point(self, mu, sigma, xi, size, expected):
115115
with pm.Model() as model:
116116
GenExtreme("x", mu=mu, sigma=sigma, xi=xi, size=size)
117-
assert_moment_is_expected(model, expected)
117+
assert_support_point_is_expected(model, expected)
118118

119119
def test_gen_extreme_scipy_kwarg(self):
120120
dist = GenExtreme.dist(xi=1, scipy=False)

pymc_experimental/tests/distributions/test_discrete.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
Domain,
2424
I,
2525
Rplus,
26-
assert_moment_is_expected,
26+
assert_support_point_is_expected,
2727
check_logp,
2828
discrete_random_tester,
2929
)
@@ -123,7 +123,7 @@ def test_logp_lam_expected_moments(self):
123123
def test_moment(self, mu, lam, size, expected):
124124
with pm.Model() as model:
125125
GeneralizedPoisson("x", mu=mu, lam=lam, size=size)
126-
assert_moment_is_expected(model, expected)
126+
assert_support_point_is_expected(model, expected)
127127

128128

129129
class TestBetaNegativeBinomial:

pymc_experimental/tests/distributions/test_discrete_markov_chain.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_moment_function(self):
129129
state = chain_np[i]
130130
chain_np[i + 1] = np.argmax(P_np[state])
131131

132-
dmc_chain = pm.distributions.distribution.moment(chain).eval()
132+
dmc_chain = pm.distributions.distribution.support_point(chain).eval()
133133

134134
assert np.allclose(dmc_chain, chain_np)
135135

pymc_experimental/tests/model/test_marginal_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_marginalized_bernoulli_logp():
5858
marginal_rv_node = FiniteDiscreteMarginalRV(
5959
[mu],
6060
[idx, y],
61-
ndim_supp=None,
61+
ndim_supp=0,
6262
n_updates=0,
6363
)(
6464
mu

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pymc>=5.10.0
1+
pymc>=5.11.0
22
scikit-learn

0 commit comments

Comments
 (0)