From e3a4393a72a77d624b1bc504584d42666a704873 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Sat, 13 Apr 2024 07:52:06 +0200 Subject: [PATCH] convienent methods for getting inference names and kwargs --- bambi/backend/inference_methods.py | 88 ++++++++++++++++++++++++------ bambi/backend/pymc.py | 32 ++--------- bambi/models.py | 4 +- 3 files changed, 78 insertions(+), 46 deletions(-) diff --git a/bambi/backend/inference_methods.py b/bambi/backend/inference_methods.py index b66c97eef..1904b9779 100644 --- a/bambi/backend/inference_methods.py +++ b/bambi/backend/inference_methods.py @@ -1,15 +1,44 @@ import importlib +import inspect import operator +import pymc as pm + class InferenceMethods: """Obtain a dictionary of available inference methods for Bambi - models, and or the kwargs that each inference method accepts. + models and or the default kwargs of each inference method. """ def __init__(self): + # In order to access inference methods, a bayeux model must be initialized self.bayeux_model = bayeux_model() - + self.bayeux_methods = self._get_bayeux_methods(bayeux_model()) + self.pymc_methods = self._pymc_methods() + + def _get_bayeux_methods(self, model): + # Bambi only supports bayeux MCMC methods + mcmc_methods = model.methods.get("mcmc") + return {"mcmc": mcmc_methods} + + def _pymc_methods(self): + return {"mcmc": ["mcmc"], "vi": ["vi"]} + + def _remove_parameters(self, fn_signature_dict): + # Remove 'pm.sample' parameters that are irrelevant for Bambi users + params_to_remove = [ + "progressbar", + "progressbar_theme", + "var_names", + "nuts_sampler", + "return_inferencedata", + "idata_kwargs", + "callback", + "mp_ctx", + "model", + ] + return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove} + def get_kwargs(self, method): """Get the default kwargs for a given inference method. @@ -23,29 +52,31 @@ def get_kwargs(self, method): dict The default kwargs for the inference method. """ - # TODO: Somehow add the ability to retrieve PyMC kwargs of - # TODO: `pymc.sampling.mcmc.sample` - # Bambi only supports bayeux MCMC methods - if method not in self.bayeux_model.methods["mcmc"]: + if method in self.bayeux_methods.get("mcmc"): + bx_method = operator.attrgetter(method)( + self.bayeux_model.mcmc # pylint: disable=no-member + ) + return bx_method.get_kwargs() + elif method in self.pymc_methods.get("mcmc"): + return self._remove_parameters(get_default_signature(pm.sample)) + elif method in self.pymc_methods.get("vi"): + return get_default_signature(pm.ADVI.fit) + else: raise ValueError( f"Inference method '{method}' not found in the list of available" - " methods" - ) + " methods. Use `bmb.inference_methods.names` to list the available methods." + ) - bx_method = operator.attrgetter(method)(self.bayeux_model.mcmc) - return bx_method.get_kwargs() - @property def names(self): - # TODO: Add PyMC MCMC methods - return self.bayeux_model.methods.get("mcmc") + return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} def bayeux_model(): """Dummy bayeux model for obtaining inference methods. - A dummy model is needed because algorithms are dynamically determined at - runtime, based on the libraries that are installed. A model can give + A dummy model is needed because algorithms are dynamically determined at + runtime, based on the libraries that are installed. A model can give programmatic access to the available algorithms via the `methods` attribute. Returns @@ -57,7 +88,32 @@ def bayeux_model(): return {"mcmc": []} import bayeux as bx # pylint: disable=import-outside-toplevel + return bx.Model(lambda x: -(x**2), 0.0) -inference_methods = InferenceMethods() \ No newline at end of file +def get_default_signature(fn): + """Get the default parameter values of a function. + + This function inspects the signature of the provided function and returns + a dictionary containing the default values of its parameters. + + Parameters + ---------- + fn : callable + The function for which default argument values are to be retrieved. + + Returns + ------- + dict + A dictionary mapping argument names to their default values. + + """ + defaults = {} + for key, val in inspect.signature(fn).parameters.items(): + if val.default is not inspect.Signature.empty: + defaults[key] = val.default + return defaults + + +inference_methods = InferenceMethods() diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index 82b646ebe..b5d7865eb 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -1,5 +1,4 @@ import functools -import importlib import logging import operator import traceback @@ -14,6 +13,7 @@ import pytensor.tensor as pt from pytensor.tensor.special import softmax +from bambi.backend.inference_methods import inference_methods from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2 from bambi.backend.model_components import ConstantComponent, DistributionalComponent from bambi.utils import get_aliased_name @@ -47,8 +47,8 @@ def __init__(self): self.model = None self.spec = None self.components = {} - self.bayeux_methods = _get_bayeux_methods() - self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]} + self.bayeux_methods = inference_methods.names["bayeux"] + self.pymc_methods = inference_methods.names["pymc"] def build(self, spec): """Compile the PyMC model from an abstract model specification. @@ -338,8 +338,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean): Mainly for pedagogical use, provides reasonable results for approximately Gaussian posteriors. The approximation can be very poor for some models - like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods - for better approximations. + like hierarchical ones. Use MCMC or VI methods for better approximations. Parameters ---------- @@ -388,10 +387,6 @@ def constant_components(self): def distributional_components(self): return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)} - @property - def inference_methods(self): - return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} - def _posterior_samples_to_idata(samples, model): """Create InferenceData from samples. @@ -431,22 +426,3 @@ def _posterior_samples_to_idata(samples, model): idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model) return idata - - -def _get_bayeux_methods(): - """Gets a dictionary of usable bayeux methods if the bayeux package is installed - within the user's environment. - - Returns - ------- - dict - A dict where the keys are the module names and the values are the methods - available in that module. - """ - if importlib.util.find_spec("bayeux") is None: - return {"mcmc": []} - - import bayeux as bx # pylint: disable=import-outside-toplevel - - # Dummy log density to get access to all methods - return bx.Model(lambda x: -(x**2), 0.0).methods diff --git a/bambi/models.py b/bambi/models.py index ecb57700f..74286dbed 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -267,7 +267,7 @@ def fit( Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. To get a list of JAX based inference methods, call - ``model.backend.inference_methods['bayeux']``. This will return a dictionary of the + ``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others. init : str Initialization method. Defaults to ``"auto"``. The available methods are: @@ -307,7 +307,7 @@ def fit( ------- An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default), "laplace", or one of the MCMC methods in - ``model.backend.inference_methods['bayeux']['mcmc]``. + ``bmb.inference_methods.names['bayeux']['mcmc]``. An ``Approximation`` object if ``"vi"``. """ method = kwargs.pop("method", None)