Skip to content

Commit

Permalink
Add constrained transform (#764)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Dec 2, 2023
1 parent 2d4b260 commit 312afa2
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 52 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Add configuration facilities to Bambi (#745)
* Interpret submodule now outputs informative messages when computing default values (#745)
* Bambi supports weighted responses (#761)
* Bambi supports constrained responses (#764)

### Maintenance and fixes

Expand Down
36 changes: 35 additions & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,34 @@ def build_response_distribution(self, kwargs, pymc_backend):
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle constrained responses (through truncated distributions)
elif self.term.is_constrained:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")

# Get values of the response variable
observed = np.squeeze(data_matrix[:, 0])

# Get truncation values
lower = np.squeeze(data_matrix[:, 1])
upper = np.squeeze(data_matrix[:, 2])

# Handle 'None' and scalars appropriately
if np.all(lower == -np.inf):
lower = None
elif np.all(lower == lower[0]):
lower = lower[0]

if np.all(upper == np.inf):
upper = None
elif np.all(upper == upper[0]):
upper = upper[0]

stateless_dist = distribution.dist(**kwargs)
dist_rv = pm.Truncated(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle weighted responses
elif self.term.is_weighted:
dims = kwargs.pop("dims", None)
Expand All @@ -340,6 +368,7 @@ def build_response_distribution(self, kwargs, pymc_backend):
# Get a weighted version of the response distribution
weighted_dist = make_weighted_distribution(distribution)
dist_rv = weighted_dist(self.name, weights, **kwargs, observed=observed, dims=dims)
# All of the other response kinds are "not special" and thus are handled the same way
else:
dist_rv = distribution(self.name, **kwargs)

Expand All @@ -361,7 +390,12 @@ def robustify_dims(self, pymc_backend, kwargs):
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if self.term.is_censored or self.term.is_truncated or self.term.is_weighted:
if (
self.term.is_censored
or self.term.is_truncated
or self.term.is_weighted
or self.term.is_constrained
):
return kwargs

dims, data = kwargs["dims"], kwargs["observed"]
Expand Down
16 changes: 14 additions & 2 deletions bambi/families/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ def posterior_predictive(self, model, posterior, **kwargs):
A data array with the draws from the posterior predictive distribution
"""
response_dist = get_response_dist(model.family)
response_term = model.response_component.response_term
params = model.family.likelihood.params
response_aliased_name = get_aliased_name(model.response_component.response_term)
response_aliased_name = get_aliased_name(response_term)

kwargs.pop("data", None) # Remove the 'data' kwarg
dont_reshape = kwargs.pop("dont_reshape", [])
Expand Down Expand Up @@ -181,7 +182,18 @@ def posterior_predictive(self, model, posterior, **kwargs):
if hasattr(model.family, "transform_kwargs"):
kwargs = model.family.transform_kwargs(kwargs)

output_array = pm.draw(response_dist.dist(**kwargs))
# Handle constrained responses
if response_term.is_constrained:
# Bounds are scalars, we can safely pick them from the first row
lower, upper = response_term.data[0, 1:]
lower = lower if lower != -np.inf else None
upper = upper if upper != np.inf else None
output_array = pm.draw(
pm.Truncated.dist(response_dist.dist(**kwargs), lower=lower, upper=upper)
)
else:
output_array = pm.draw(response_dist.dist(**kwargs))

output_coords_all = xr.merge(output_dataset_list).coords

coord_names = ["chain", "draw", response_aliased_name + "_obs"]
Expand Down
29 changes: 5 additions & 24 deletions bambi/terms/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

from bambi.terms.base import BaseTerm

from bambi.terms.utils import is_censored_response, is_truncated_response, is_weighted_response
from bambi.terms.utils import is_response_of_kind


class ResponseTerm(BaseTerm):
def __init__(self, response, family):
self.term = response.term.term
self.family = family
self.is_censored = is_censored_response(self.term)
self.is_truncated = is_truncated_response(self.term)
self.is_weighted = is_weighted_response(self.term)
self.is_censored = is_response_of_kind(self.term, "censored")
self.is_constrained = is_response_of_kind(self.term, "constrained")
self.is_truncated = is_response_of_kind(self.term, "truncated")
self.is_weighted = is_response_of_kind(self.term, "weighted")

@property
def term(self):
Expand Down Expand Up @@ -81,23 +82,3 @@ def __str__(self):
else:
extras += [f"reference: {self.reference}"]
return self.make_str(extras)


# Categorical
# -> Nominal
# -> Binary
# -> Ordinal

# These aren't actually used to do something with data, but mostly to give information to the user
# Well, the ordinal kind can be useful as well.
# class Categorical:
# pass

# class Nominal(Categorical):
# pass

# class Ordinal(Categorical):
# pass

# class Binary(Nominal):
# pass
28 changes: 4 additions & 24 deletions bambi/terms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def is_call_component(component) -> bool:
return isinstance(component, fm.terms.call.Call)


def is_call_of_kind(call, kind):
def is_call_of_kind(call, kind: str) -> bool:
"""Determines if formulae call component is of certain kind
To do so, it checks whether the callee has metadata and whether the 'kind' slot matches the
Expand All @@ -22,31 +22,11 @@ def is_call_of_kind(call, kind):
return hasattr(function, "__metadata__") and function.__metadata__["kind"] == kind


def is_censored_response(term):
"""Determines if a formulae term represents a censored response"""
def is_response_of_kind(term, kind: str) -> bool:
"""Determines if a formulae term represents a response of a certain kind"""
if not is_single_component(term):
return False
component = term.components[0] # get the first (and single) component
if not is_call_component(component):
return False
return is_call_of_kind(component, "censored")


def is_truncated_response(term):
"""Determines if a formulae term represents a truncated response"""
if not is_single_component(term):
return False
component = term.components[0] # get the first (and single) component
if not is_call_component(component):
return False
return is_call_of_kind(component, "truncated")


def is_weighted_response(term):
"""Determines if a formulae term represents a weighted response"""
if not is_single_component(term):
return False
component = term.components[0] # get the first (and single) component
if not is_call_component(component):
return False
return is_call_of_kind(component, "weighted")
return is_call_of_kind(component, kind)
21 changes: 21 additions & 0 deletions bambi/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,26 @@ def truncated(x, lb=None, ub=None):
truncated.__metadata__ = {"kind": "truncated"}


def constrained(x, lb=None, ub=None):
    """Construct an array for a constrained response

    It's exactly like `truncated()`, but it's interpreted by Bambi in a different way as this
    one truncates/constrains the bounds of a probability distribution, while `truncated()` is
    interpreted as the missing data mechanism.

    Parameters
    ----------
    x
        Values of the response. Passed to `truncated()` unchanged.
    lb : None or real scalar
        Lower bound. `None` means there is no lower bound.
    ub : None or real scalar
        Upper bound. `None` means there is no upper bound.

    Returns
    -------
    The result of `truncated(x, lb, ub)`.

    Raises
    ------
    ValueError
        If `lb` or `ub` is neither `None` nor a real scalar (e.g. a list or an array).
    """
    # `numbers.Real` covers Python ints/floats as well as NumPy scalar types such as
    # `np.int64` or `np.float32`, which are scalars but are not instances of `int`/`float`.
    from numbers import Real

    # 'lb' is validated before 'ub' so the first offending bound is the one reported.
    for name, bound in (("lb", lb), ("ub", ub)):
        if not (bound is None or isinstance(bound, Real)):
            raise ValueError(f"'{name}' must be None or scalar.")
    return truncated(x, lb, ub)


constrained.__metadata__ = {"kind": "constrained"}


def weighted(x, weights):
"""Construct array for a weighted response
Expand Down Expand Up @@ -403,6 +423,7 @@ def get_distance(x):
transformations_namespace = {
"c": c,
"censored": censored,
"constrained": constrained,
"truncated": truncated,
"weighted": weighted,
"log": np.log,
Expand Down
27 changes: 27 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,33 @@ def test_truncated_response(self, truncated_data):
assert isinstance(model.backend.model.observed_RVs[0]._owner.op, pm.TruncatedNormal.rv_type)


class TestConstrainedResponse(FitPredictParent):
    """Fit/predict checks for models whose response is wrapped in `constrained()`."""

    def test_constrained_response(self, truncated_data):
        """Predictions returned by `predict_oos` must lie strictly inside the given bounds."""
        priors = {
            "Intercept": bmb.Prior("Normal", mu=0, sigma=1),
            "x": bmb.Prior("Normal", mu=0, sigma=1),
            "sigma": bmb.Prior("HalfNormal", sigma=1),
        }

        # (response term, lower bound, upper bound); `None` means that side is not checked
        cases = (
            ("constrained(y, -5)", -5, None),
            ("constrained(y, ub=5)", None, 5),
            ("constrained(y, -5, 5)", -5, 5),
        )

        for response, lower, upper in cases:
            model = bmb.Model(f"{response} ~ x", truncated_data, priors=priors)
            fitted = self.fit(model, random_seed=121195)
            idata = self.predict_oos(model, fitted)

            # The posterior predictive variable is keyed by the response term itself
            if lower is not None:
                assert idata.posterior_predictive[response].to_numpy().min() > lower
            if upper is not None:
                assert idata.posterior_predictive[response].to_numpy().max() < upper




class TestMultinomial(FitPredictParent):
def assert_posterior_predictive(self, model, idata):
y_name = model.response_component.response_term.name
Expand Down
30 changes: 29 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bambi.utils import listify
from bambi.backend.pymc import probit, cloglog
from bambi.backend.utils import make_weighted_distribution
from bambi.transformations import censored, truncated, weighted
from bambi.transformations import censored, constrained, truncated, weighted


def test_listify():
Expand Down Expand Up @@ -104,6 +104,34 @@ def test_truncated():
truncated(x, ub=np.column_stack([upper_arr, upper_arr]))


def test_constrained():
    """Check the columns produced by `constrained()` and its rejection of array bounds."""
    x = np.array([-3, -2, -1, 0, 0, 0, 1, 1, 2, 3])
    lower = -5
    upper = 4.5

    # Each case: (lb argument, ub argument, expected lower column, expected upper column).
    # A bound passed as None is expected to show up as an infinite value in its column.
    cases = [
        (lower, None, lower, np.inf),
        (None, upper, -np.inf, upper),
        (lower, upper, lower, upper),
    ]

    for lb, ub, expected_lower, expected_upper in cases:
        result = constrained(x, lb=lb, ub=ub)
        assert result.shape == (10, 3)
        # Column 0 holds the data, columns 1 and 2 the lower and upper bounds
        for column, expected in enumerate((x, expected_lower, expected_upper)):
            assert (result[:, column] == expected).all()

    # Array-valued bounds are not allowed, neither positionally nor by keyword
    with pytest.raises(ValueError, match="'lb' must be None or scalar."):
        constrained(x, np.array([lower, lower]))

    with pytest.raises(ValueError, match="'ub' must be None or scalar."):
        constrained(x, ub=np.array([upper, upper]))


def test_weighted():
rng = np.random.default_rng(1234)
weights = 1 + rng.poisson(lam=3, size=100)
Expand Down

0 comments on commit 312afa2

Please sign in to comment.