Skip to content

Fix group selection in sample_posterior_predictive when predictions=True is passed in kwargs #426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
27 changes: 20 additions & 7 deletions pymc_extras/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ def predict(
self,
X_pred: np.ndarray | pd.DataFrame | pd.Series,
extend_idata: bool = True,
predictions: bool = True,
**kwargs,
) -> np.ndarray:
"""
Expand All @@ -542,6 +543,9 @@ def predict(
The input data used for prediction.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to True.
predictions : bool
Whether to use the predictions group for posterior predictive sampling.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

Returns
Expand All @@ -559,7 +563,7 @@ def predict(
"""

posterior_predictive_samples = self.sample_posterior_predictive(
X_pred, extend_idata, combined=False, **kwargs
X_pred, extend_idata, combined=False, predictions=predictions, **kwargs
)

if self.output_var not in posterior_predictive_samples:
Expand Down Expand Up @@ -624,7 +628,9 @@ def sample_prior_predictive(

return prior_predictive_samples

def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
def sample_posterior_predictive(
self, X_pred, extend_idata, combined, predictions=True, **kwargs
):
"""
Sample from the model's posterior predictive distribution.

Expand All @@ -634,6 +640,8 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
The input data used for prediction using prior distribution.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to False.
predictions : Boolean determining whether to use the predictions group for posterior predictive sampling.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive
Expand All @@ -646,13 +654,15 @@ def sample_posterior_predictive(self, X_pred, extend_idata, combined, **kwargs):
self._data_setter(X_pred)

with self.model: # sample with new input data
post_pred = pm.sample_posterior_predictive(self.idata, **kwargs)
post_pred = pm.sample_posterior_predictive(
self.idata, predictions=predictions, **kwargs
)
if extend_idata:
self.idata.extend(post_pred, join="right")

posterior_predictive_samples = az.extract(
post_pred, "posterior_predictive", combined=combined
)
group_name = "predictions" if predictions else "posterior_predictive"

posterior_predictive_samples = az.extract(post_pred, group_name, combined=combined)

return posterior_predictive_samples

Expand Down Expand Up @@ -700,6 +710,7 @@ def predict_posterior(
X_pred: np.ndarray | pd.DataFrame | pd.Series,
extend_idata: bool = True,
combined: bool = True,
predictions: bool = True,
**kwargs,
) -> xr.DataArray:
"""
Expand All @@ -713,6 +724,8 @@ def predict_posterior(
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
predictions : Boolean determining whether to use the predictions group for posterior predictive sampling.
Defaults to True.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

Returns
Expand All @@ -723,7 +736,7 @@ def predict_posterior(

X_pred = self._validate_data(X_pred)
posterior_predictive_samples = self.sample_posterior_predictive(
X_pred, extend_idata, combined, **kwargs
X_pred, extend_idata, combined, predictions=predictions, **kwargs
)

if self.output_var not in posterior_predictive_samples:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,36 @@ def test_id():
).hexdigest()[:16]

assert model_builder.id == expected_id


@pytest.mark.parametrize("predictions", [True, False])
def test_predict_respects_predictions_flag(fitted_model_instance, predictions):
    """Check that ``predict`` routes samples to the group selected by ``predictions``.

    With ``predictions=True`` the samples must land in a new ``predictions``
    group and leave ``posterior_predictive`` untouched; with
    ``predictions=False`` no ``predictions`` group may appear and
    ``posterior_predictive`` must be overwritten.
    """
    # Seeded generator so the prediction inputs are reproducible across runs.
    rng = np.random.default_rng(0)
    x_pred = rng.uniform(0, 1, 100)
    prediction_data = pd.DataFrame({"input": x_pred})
    output_var = fitted_model_instance.output_var
    idata = fitted_model_instance.idata

    # Snapshot (copy) the original posterior_predictive values so a later
    # in-place extension of idata cannot alias the "before" array.
    pp_before = idata.posterior_predictive[output_var].values.copy()

    # Precondition: the 'predictions' group must not exist yet.
    # NOTE(review): this assumes the fixture hands each parametrization a
    # fresh idata — confirm the fixture's scope.
    assert "predictions" not in idata.groups()

    # `combined` is deliberately NOT passed here: `predict` already forwards
    # `combined=False` explicitly next to **kwargs, so supplying it again
    # raises "TypeError: got multiple values for keyword argument 'combined'".
    fitted_model_instance.predict(
        prediction_data["input"],
        extend_idata=True,
        predictions=predictions,
    )

    # Re-read through the model instance in case extend rebinds the group.
    idata_after = fitted_model_instance.idata
    pp_after = idata_after.posterior_predictive[output_var].values

    if predictions:
        assert "predictions" in idata_after.groups()
        # Posterior predictive should remain unchanged.
        np.testing.assert_array_equal(pp_before, pp_after)
    else:
        assert "predictions" not in idata_after.groups()
        # Posterior predictive should be updated. numpy.testing has no
        # `assert_array_not_equal`; np.array_equal also returns False when
        # the shapes differ, which counts as "updated" here.
        assert not np.array_equal(pp_before, pp_after)