    diff --git a/bambi/__init__.py b/bambi/__init__.py index 660d6724c..fdec8ec16 100644 --- a/bambi/__init__.py +++ b/bambi/__init__.py @@ -4,7 +4,7 @@ from pymc import math -from .backend import PyMCModel +from .backend import inference_methods, PyMCModel from .config import config from .data import clear_data_home, load_data from .families import Family, Likelihood, Link @@ -25,6 +25,7 @@ "Formula", "clear_data_home", "config", + "inference_methods", "load_data", "math", ] diff --git a/bambi/backend/__init__.py b/bambi/backend/__init__.py index 6ee2a4aa3..daef1924c 100644 --- a/bambi/backend/__init__.py +++ b/bambi/backend/__init__.py @@ -1,3 +1,4 @@ from .pymc import PyMCModel +from .inference_methods import inference_methods -__all__ = ["PyMCModel"] +__all__ = ["inference_methods", "PyMCModel"] diff --git a/bambi/backend/inference_methods.py b/bambi/backend/inference_methods.py new file mode 100644 index 000000000..900d9c262 --- /dev/null +++ b/bambi/backend/inference_methods.py @@ -0,0 +1,119 @@ +import importlib +import inspect +import operator + +import pymc as pm + + +class InferenceMethods: + """Obtain a dictionary of available inference methods for Bambi + models and/or the default kwargs of each inference method. 
    
+ """ + + def __init__(self): + # In order to access inference methods, a bayeux model must be initialized + self.bayeux_model = bayeux_model() + self.bayeux_methods = self._get_bayeux_methods(bayeux_model()) + self.pymc_methods = self._pymc_methods() + + def _get_bayeux_methods(self, model): + # Bambi only supports bayeux MCMC methods + mcmc_methods = model.methods.get("mcmc") + return {"mcmc": mcmc_methods} + + def _pymc_methods(self): + return {"mcmc": ["mcmc"], "vi": ["vi"]} + + def _remove_parameters(self, fn_signature_dict): + # Remove 'pm.sample' parameters that are irrelevant for Bambi users + params_to_remove = [ + "progressbar", + "progressbar_theme", + "var_names", + "nuts_sampler", + "return_inferencedata", + "idata_kwargs", + "callback", + "mp_ctx", + "model", + ] + return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove} + + def get_kwargs(self, method): + """Get the default kwargs for a given inference method. + + Parameters + ---------- + method : str + The name of the inference method. + + Returns + ------- + dict + The default kwargs for the inference method. + """ + if method in self.bayeux_methods.get("mcmc"): + bx_method = operator.attrgetter(method)( + self.bayeux_model.mcmc # pylint: disable=no-member + ) + return bx_method.get_kwargs() + elif method in self.pymc_methods.get("mcmc"): + return self._remove_parameters(get_default_signature(pm.sample)) + elif method in self.pymc_methods.get("vi"): + return get_default_signature(pm.ADVI.fit) + else: + raise ValueError( + f"Inference method '{method}' not found in the list of available" + " methods. Use `bmb.inference_methods.names` to list the available methods." + ) + + @property + def names(self): + return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} + + +def bayeux_model(): + """Dummy bayeux model for obtaining inference methods. 
+ + A dummy model is needed because algorithms are dynamically determined at + runtime, based on the libraries that are installed. A model can give + programmatic access to the available algorithms via the `methods` attribute. + + Returns + ------- + bayeux.Model + A dummy model with a simple quadratic likelihood function. + """ + if importlib.util.find_spec("bayeux") is None: + return {"mcmc": []} + + import bayeux as bx # pylint: disable=import-outside-toplevel + + return bx.Model(lambda x: -(x**2), 0.0) + + +def get_default_signature(fn): + """Get the default parameter values of a function. + + This function inspects the signature of the provided function and returns + a dictionary containing the default values of its parameters. + + Parameters + ---------- + fn : callable + The function for which default argument values are to be retrieved. + + Returns + ------- + dict + A dictionary mapping argument names to their default values. + + """ + defaults = {} + for key, val in inspect.signature(fn).parameters.items(): + if val.default is not inspect.Signature.empty: + defaults[key] = val.default + return defaults + + +inference_methods = InferenceMethods() diff --git a/bambi/backend/pymc.py b/bambi/backend/pymc.py index a84d3c571..75a1fe318 100644 --- a/bambi/backend/pymc.py +++ b/bambi/backend/pymc.py @@ -1,5 +1,4 @@ import functools -import importlib import logging import operator import traceback @@ -14,6 +13,7 @@ import pytensor.tensor as pt from pytensor.tensor.special import softmax +from bambi.backend.inference_methods import inference_methods from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2 from bambi.backend.model_components import ConstantComponent, DistributionalComponent from bambi.utils import get_aliased_name @@ -47,8 +47,8 @@ def __init__(self): self.model = None self.spec = None self.components = {} - self.bayeux_methods = _get_bayeux_methods() - self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]} + 
self.bayeux_methods = inference_methods.names["bayeux"] + self.pymc_methods = inference_methods.names["pymc"] def build(self, spec): """Compile the PyMC model from an abstract model specification. @@ -348,8 +348,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean): Mainly for pedagogical use, provides reasonable results for approximately Gaussian posteriors. The approximation can be very poor for some models - like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods - for better approximations. + like hierarchical ones. Use MCMC or VI methods for better approximations. Parameters ---------- @@ -398,10 +397,6 @@ def constant_components(self): def distributional_components(self): return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)} - @property - def inference_methods(self): - return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods} - def _posterior_samples_to_idata(samples, model): """Create InferenceData from samples. @@ -441,22 +436,3 @@ def _posterior_samples_to_idata(samples, model): idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model) return idata - - -def _get_bayeux_methods(): - """Gets a dictionary of usable bayeux methods if the bayeux package is installed - within the user's environment. - - Returns - ------- - dict - A dict where the keys are the module names and the values are the methods - available in that module. - """ - if importlib.util.find_spec("bayeux") is None: - return {"mcmc": []} - - import bayeux as bx # pylint: disable=import-outside-toplevel - - # Dummy log density to get access to all methods - return bx.Model(lambda x: -(x**2), 0.0).methods diff --git a/bambi/models.py b/bambi/models.py index ecb57700f..74286dbed 100644 --- a/bambi/models.py +++ b/bambi/models.py @@ -267,7 +267,7 @@ def fit( Finally, ``"laplace"``, in which case a Laplace approximation is used and is not recommended other than for pedagogical use. 
    To get a list of JAX based inference methods, call - ``model.backend.inference_methods['bayeux']``. This will return a dictionary of the + ``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others. init : str Initialization method. Defaults to ``"auto"``. The available methods are: @@ -307,7 +307,7 @@ def fit( ------- An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default), "laplace", or one of the MCMC methods in - ``model.backend.inference_methods['bayeux']['mcmc]``. + ``bmb.inference_methods.names['bayeux']['mcmc']``. An ``Approximation`` object if ``"vi"``. """ method = kwargs.pop("method", None) diff --git a/docs/notebooks/alternative_samplers.ipynb b/docs/notebooks/alternative_samplers.ipynb index 24d610d96..7610df6d6 100644 --- a/docs/notebooks/alternative_samplers.ipynb +++ b/docs/notebooks/alternative_samplers.ipynb @@ -15,9 +15,17 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + } + ], "source": [ "import arviz as az\n", "import bambi as bmb\n", @@ -62,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -74,12 +82,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can call `model.backend.inference_methods` that returns a nested dictionary of the backends and list of inference methods." + "We can call `bmb.inference_methods.names` that returns a nested dictionary of the backends and list of inference methods." 
    
] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 27, "metadata": {}, "outputs": [ { @@ -100,47 +108,16 @@ " 'flowmc_realnvp_hmc',\n", " 'flowmc_realnvp_mala',\n", " 'numpyro_hmc',\n", - " 'numpyro_nuts'],\n", - " 'optimize': ['jaxopt_bfgs',\n", - " 'jaxopt_gradient_descent',\n", - " 'jaxopt_lbfgs',\n", - " 'jaxopt_nonlinear_cg',\n", - " 'optimistix_bfgs',\n", - " 'optimistix_chord',\n", - " 'optimistix_dogleg',\n", - " 'optimistix_gauss_newton',\n", - " 'optimistix_indirect_levenberg_marquardt',\n", - " 'optimistix_levenberg_marquardt',\n", - " 'optimistix_nelder_mead',\n", - " 'optimistix_newton',\n", - " 'optimistix_nonlinear_cg',\n", - " 'optax_adabelief',\n", - " 'optax_adafactor',\n", - " 'optax_adagrad',\n", - " 'optax_adam',\n", - " 'optax_adamw',\n", - " 'optax_adamax',\n", - " 'optax_amsgrad',\n", - " 'optax_fromage',\n", - " 'optax_lamb',\n", - " 'optax_lion',\n", - " 'optax_noisy_sgd',\n", - " 'optax_novograd',\n", - " 'optax_radam',\n", - " 'optax_rmsprop',\n", - " 'optax_sgd',\n", - " 'optax_sm3',\n", - " 'optax_yogi'],\n", - " 'vi': ['tfp_factored_surrogate_posterior']}}" + " 'numpyro_nuts']}}" ] }, - "execution_count": 4, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "methods = model.backend.inference_methods\n", + "methods = bmb.inference_methods.names\n", "methods" ] }, @@ -153,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -162,7 +139,7 @@ "{'mcmc': ['mcmc'], 'vi': ['vi']}" ] }, - "execution_count": 5, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -180,36 +157,36 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['tfp_hmc',\n", - " 'tfp_nuts',\n", - " 'tfp_snaper_hmc',\n", - " 'blackjax_hmc',\n", - " 'blackjax_chees_hmc',\n", - " 'blackjax_meads_hmc',\n", - " 'blackjax_nuts',\n", - " 
'blackjax_hmc_pathfinder',\n", - " 'blackjax_nuts_pathfinder',\n", - " 'flowmc_rqspline_hmc',\n", - " 'flowmc_rqspline_mala',\n", - " 'flowmc_realnvp_hmc',\n", - " 'flowmc_realnvp_mala',\n", - " 'numpyro_hmc',\n", - " 'numpyro_nuts']" + "{'mcmc': ['tfp_hmc',\n", + " 'tfp_nuts',\n", + " 'tfp_snaper_hmc',\n", + " 'blackjax_hmc',\n", + " 'blackjax_chees_hmc',\n", + " 'blackjax_meads_hmc',\n", + " 'blackjax_nuts',\n", + " 'blackjax_hmc_pathfinder',\n", + " 'blackjax_nuts_pathfinder',\n", + " 'flowmc_rqspline_hmc',\n", + " 'flowmc_rqspline_mala',\n", + " 'flowmc_realnvp_hmc',\n", + " 'flowmc_realnvp_mala',\n", + " 'numpyro_hmc',\n", + " 'numpyro_nuts']}" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "methods[\"bayeux\"][\"mcmc\"]" + "methods[\"bayeux\"]" ] }, { @@ -242,9 +219,46 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "909c6a6f539145ab8348ebdeb1d42a3b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -256,8 +270,8 @@ "
  • created_at :
    2024-04-13T05:34:49.761913+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-04-13T05:34:49.763427+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", @@ -1447,7 +1461,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -1490,7 +1505,7 @@ "\t> sample_stats" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -1506,7 +1521,7 @@ "source": [ "Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own `kwargs` to the `fit` method.\n", "\n", - "Each algorithm has a `.get_kwargs()` method that tells you how it will be called, and what functions are being called." + "The following can be performend to identify the kwargs specific to each method." ] }, { @@ -1542,7 +1557,7 @@ } ], "source": [ - "bx.Model.from_pymc(model.backend.model).mcmc.blackjax_nuts.get_kwargs()" + "bmb.inference_methods.get_kwargs(\"blackjax_nuts\")" ] }, { @@ -1554,9 +1569,46 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f12c14ad9394476085d96b2ebbaa837d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
    +       "
    \n",
    +       "
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -1568,8 +1620,8 @@ "
  • created_at :
    2024-04-13T05:36:20.439151+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-04-13T05:36:20.441267+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", @@ -3057,7 +3109,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -3100,7 +3153,7 @@ "\t> sample_stats" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -3126,9 +3179,46 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a25653608a3f4f26a20237fb94775629", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
    +       "
    \n",
    +       "
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -3140,8 +3230,8 @@ "
  • created_at :
    2024-04-13T05:36:30.303342+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-04-13T05:36:30.304788+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", @@ -4325,7 +4415,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -4368,7 +4459,7 @@ "\t> sample_stats" ] }, - "execution_count": 5, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -4394,9 +4485,46 @@ "name": "stderr", "output_type": "stream", "text": [ - "sample: 100%|██████████| 1500/1500 [00:02<00:00, 551.76it/s]\n" + "sample: 100%|██████████| 1500/1500 [00:02<00:00, 599.25it/s]\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "851ea515d7c54968926f9eb0dc8b30c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
    +       "
    \n",
    +       "
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -4408,8 +4536,8 @@ "
  • created_at :
    2024-04-13T05:36:33.599519+00:00
    arviz_version :
    0.18.0
    inference_library :
    numpyro
    inference_library_version :
    0.14.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", " \n", "
  • \n", - " \n", - " \n", + " \n", + " \n", "
    \n", "
    \n", "
  • created_at :
    2024-04-13T05:36:33.623197+00:00
    arviz_version :
    0.18.0
    inference_library :
    numpyro
    inference_library_version :
    0.14.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", @@ -5603,7 +5731,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -5680,7 +5809,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.23s/it]\n" + "Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.37s/it]\n" ] }, { @@ -5694,9 +5823,53 @@ "name": "stderr", "output_type": "stream", "text": [ - "Production run: 100%|██████████| 5/5 [00:00<00:00, 9.38it/s]\n" + "Production run: 100%|██████████| 5/5 [00:00<00:00, 14.38it/s]" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1865c421a05b46109fcf06c8b7da2cf4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" ] }, + { + "data": { + "text/html": [ + "
    \n"
    +      ],
    +      "text/plain": []
    +     },
    +     "metadata": {},
    +     "output_type": "display_data"
    +    },
    +    {
    +     "data": {
    +      "text/html": [
    +       "
    \n",
    +       "
    \n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -5708,8 +5881,8 @@ "
  • created_at :
    2024-04-13T05:37:29.798250+00:00
    arviz_version :
    0.18.0
    modeling_interface :
    bambi
    modeling_interface_version :
    0.13.1.dev25+g1e7f677e.d20240413

  • \n", " \n", " \n", " \n", @@ -6440,7 +6613,8 @@ " grid-template-columns: 125px auto;\n", "}\n", "\n", - ".xr-attrs dt, dd {\n", + ".xr-attrs dt,\n", + ".xr-attrs dd {\n", " padding: 0;\n", " margin: 0;\n", " float: left;\n", @@ -6503,7 +6677,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -6540,40 +6714,40 @@ " \n", " \n", " \n", - " y_sigma\n", - " 0.945\n", - " 0.070\n", - " 0.819\n", - " 1.080\n", - " 0.002\n", - " 0.002\n", - " 1044.0\n", - " 667.0\n", - " 1.0\n", - " \n", - " \n", " Intercept\n", - " 0.018\n", - " 0.089\n", - " -0.156\n", - " 0.185\n", + " 0.023\n", + " 0.097\n", + " -0.141\n", + " 0.209\n", + " 0.004\n", " 0.003\n", - " 0.002\n", - " 844.0\n", - " 733.0\n", - " 1.0\n", + " 694.0\n", + " 508.0\n", + " 1.00\n", " \n", " \n", " x\n", - " 0.358\n", - " 0.105\n", - " 0.163\n", - " 0.554\n", + " 0.356\n", + " 0.111\n", + " 0.162\n", + " 0.571\n", " 0.004\n", " 0.003\n", - " 829.0\n", - " 767.0\n", - " 1.0\n", + " 970.0\n", + " 675.0\n", + " 1.00\n", + " \n", + " \n", + " y_sigma\n", + " 0.950\n", + " 0.069\n", + " 0.827\n", + " 1.072\n", + " 0.002\n", + " 0.001\n", + " 1418.0\n", + " 842.0\n", + " 1.01\n", " \n", " \n", "\n", @@ -6581,17 +6755,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "y_sigma 0.945 0.070 0.819 1.080 0.002 0.002 1044.0 \n", - "Intercept 0.018 0.089 -0.156 0.185 0.003 0.002 844.0 \n", - "x 0.358 0.105 0.163 0.554 0.004 0.003 829.0 \n", + "Intercept 0.023 0.097 -0.141 0.209 0.004 0.003 694.0 \n", + "x 0.356 0.111 0.162 0.571 0.004 0.003 970.0 \n", + "y_sigma 0.950 0.069 0.827 1.072 0.002 0.001 1418.0 \n", "\n", " ess_tail r_hat \n", - "y_sigma 667.0 1.0 \n", - "Intercept 733.0 1.0 \n", - "x 767.0 1.0 " + "Intercept 508.0 1.00 \n", + "x 675.0 1.00 \n", + "y_sigma 842.0 1.01 " ] }, - "execution_count": 16, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -6602,7 +6776,7 @@ }, { "cell_type": 
"code", - "execution_count": 6, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -6639,39 +6813,39 @@ " \n", " \n", " \n", - " y_sigma\n", - " 0.948\n", - " 0.067\n", - " 0.824\n", - " 1.073\n", + " Intercept\n", + " 0.023\n", + " 0.097\n", + " -0.157\n", + " 0.205\n", " 0.001\n", " 0.001\n", - " 8107.0\n", - " 5585.0\n", + " 6785.0\n", + " 5740.0\n", " 1.0\n", " \n", " \n", - " Intercept\n", - " 0.025\n", - " 0.095\n", - " -0.152\n", - " 0.200\n", + " x\n", + " 0.360\n", + " 0.105\n", + " 0.169\n", + " 0.563\n", " 0.001\n", " 0.001\n", - " 6772.0\n", - " 5624.0\n", + " 6988.0\n", + " 5116.0\n", " 1.0\n", " \n", " \n", - " x\n", - " 0.361\n", - " 0.104\n", - " 0.157\n", - " 0.551\n", + " y_sigma\n", + " 0.946\n", + " 0.067\n", + " 0.831\n", + " 1.081\n", " 0.001\n", " 0.001\n", - " 6682.0\n", - " 5414.0\n", + " 7476.0\n", + " 5971.0\n", " 1.0\n", " \n", " \n", @@ -6680,17 +6854,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "y_sigma 0.948 0.067 0.824 1.073 0.001 0.001 8107.0 \n", - "Intercept 0.025 0.095 -0.152 0.200 0.001 0.001 6772.0 \n", - "x 0.361 0.104 0.157 0.551 0.001 0.001 6682.0 \n", + "Intercept 0.023 0.097 -0.157 0.205 0.001 0.001 6785.0 \n", + "x 0.360 0.105 0.169 0.563 0.001 0.001 6988.0 \n", + "y_sigma 0.946 0.067 0.831 1.081 0.001 0.001 7476.0 \n", "\n", " ess_tail r_hat \n", - "y_sigma 5585.0 1.0 \n", - "Intercept 5624.0 1.0 \n", - "x 5414.0 1.0 " + "Intercept 5740.0 1.0 \n", + "x 5116.0 1.0 \n", + "y_sigma 5971.0 1.0 " ] }, - "execution_count": 6, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -6701,7 +6875,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -6739,38 +6913,38 @@ " \n", " \n", " Intercept\n", - " 0.022\n", - " 0.097\n", - " -0.149\n", - " 0.217\n", + " 0.024\n", + " 0.095\n", + " -0.162\n", + " 0.195\n", " 0.001\n", " 0.001\n", - " 7412.0\n", - " 5758.0\n", + " 6851.0\n", + " 5614.0\n", " 
1.0\n", " \n", " \n", " x\n", - " 0.359\n", - " 0.105\n", - " 0.159\n", - " 0.555\n", + " 0.362\n", + " 0.104\n", + " 0.176\n", + " 0.557\n", " 0.001\n", " 0.001\n", - " 7406.0\n", - " 5967.0\n", + " 9241.0\n", + " 6340.0\n", " 1.0\n", " \n", " \n", " y_sigma\n", - " 0.947\n", - " 0.069\n", - " 0.822\n", + " 0.946\n", + " 0.068\n", + " 0.826\n", " 1.079\n", " 0.001\n", " 0.001\n", - " 7371.0\n", - " 5405.0\n", + " 7247.0\n", + " 5711.0\n", " 1.0\n", " \n", " \n", @@ -6779,17 +6953,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "Intercept 0.022 0.097 -0.149 0.217 0.001 0.001 7412.0 \n", - "x 0.359 0.105 0.159 0.555 0.001 0.001 7406.0 \n", - "y_sigma 0.947 0.069 0.822 1.079 0.001 0.001 7371.0 \n", + "Intercept 0.024 0.095 -0.162 0.195 0.001 0.001 6851.0 \n", + "x 0.362 0.104 0.176 0.557 0.001 0.001 9241.0 \n", + "y_sigma 0.946 0.068 0.826 1.079 0.001 0.001 7247.0 \n", "\n", " ess_tail r_hat \n", - "Intercept 5758.0 1.0 \n", - "x 5967.0 1.0 \n", - "y_sigma 5405.0 1.0 " + "Intercept 5614.0 1.0 \n", + "x 6340.0 1.0 \n", + "y_sigma 5711.0 1.0 " ] }, - "execution_count": 17, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -6800,7 +6974,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -6837,39 +7011,39 @@ " \n", " \n", " \n", - " y_sigma\n", - " 0.946\n", - " 0.067\n", - " 0.825\n", - " 1.076\n", - " 0.001\n", - " 0.001\n", - " 6260.0\n", - " 5213.0\n", - " 1.00\n", - " \n", - " \n", " Intercept\n", - " 0.013\n", - " 0.093\n", - " -0.165\n", + " 0.015\n", + " 0.100\n", + " -0.186\n", " 0.190\n", + " 0.004\n", " 0.003\n", - " 0.002\n", - " 924.0\n", - " 1302.0\n", + " 758.0\n", + " 1233.0\n", " 1.02\n", " \n", " \n", " x\n", - " 0.359\n", - " 0.103\n", - " 0.166\n", - " 0.556\n", + " 0.361\n", + " 0.105\n", + " 0.174\n", + " 0.565\n", " 0.001\n", " 0.001\n", - " 5132.0\n", - " 5790.0\n", + " 5084.0\n", + " 4525.0\n", + " 1.00\n", + " \n", + 
" \n", + " y_sigma\n", + " 0.951\n", + " 0.070\n", + " 0.823\n", + " 1.079\n", + " 0.001\n", + " 0.001\n", + " 5536.0\n", + " 5080.0\n", " 1.00\n", " \n", " \n", @@ -6878,17 +7052,17 @@ ], "text/plain": [ " mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk \\\n", - "y_sigma 0.946 0.067 0.825 1.076 0.001 0.001 6260.0 \n", - "Intercept 0.013 0.093 -0.165 0.190 0.003 0.002 924.0 \n", - "x 0.359 0.103 0.166 0.556 0.001 0.001 5132.0 \n", + "Intercept 0.015 0.100 -0.186 0.190 0.004 0.003 758.0 \n", + "x 0.361 0.105 0.174 0.565 0.001 0.001 5084.0 \n", + "y_sigma 0.951 0.070 0.823 1.079 0.001 0.001 5536.0 \n", "\n", " ess_tail r_hat \n", - "y_sigma 5213.0 1.00 \n", - "Intercept 1302.0 1.02 \n", - "x 5790.0 1.00 " + "Intercept 1233.0 1.02 \n", + "x 4525.0 1.00 \n", + "y_sigma 5080.0 1.00 " ] }, - "execution_count": 18, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -6908,25 +7082,24 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Last updated: Fri Mar 01 2024\n", + "Last updated: Sat Apr 13 2024\n", "\n", "Python implementation: CPython\n", - "Python version : 3.11.7\n", - "IPython version : 8.21.0\n", + "Python version : 3.12.2\n", + "IPython version : 8.20.0\n", "\n", - "arviz : 0.17.0\n", - "bambi : 0.13.1.dev16+g9a1387a7.d20240204\n", - "numpy : 1.26.3\n", - "pandas : 2.2.0\n", - "bayeux : 0.1.9\n", - "matplotlib: 3.8.2\n", + "bambi : 0.13.1.dev25+g1e7f677e.d20240413\n", + "pandas: 2.2.1\n", + "numpy : 1.26.4\n", + "bayeux: 0.1.10\n", + "arviz : 0.18.0\n", "\n", "Watermark: 2.4.3\n", "\n" @@ -6955,7 +7128,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index 6222f3df3..be260b6bd 100644 --- a/tests/test_alternative_samplers.py +++ 
    b/tests/test_alternative_samplers.py @@ -7,7 +7,9 @@ MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__] -MCMC_METHODS_FILTERED = [i for i in MCMC_METHODS if not any(x in i for x in ("flowmc", "chees", "meads"))] +MCMC_METHODS_FILTERED = [ + i for i in MCMC_METHODS if not any(x in i for x in ("flowmc", "chees", "meads")) +] @pytest.fixture(scope="module") @@ -30,6 +32,24 @@ def data_n100(): return data +def test_inference_method_names_and_kwargs(): + names = bmb.inference_methods.names + + # Check PyMC inference method family + assert "mcmc" in names["pymc"].keys() + assert "vi" in names["pymc"].keys() + + # Check bayeux inference method family. Currently, only MCMC methods are supported + assert "mcmc" in names["bayeux"].keys() + + # Ensure get_kwargs method raises an error if a non-supported method name is passed + with pytest.raises( + ValueError, + match="Inference method 'not_a_method' not found in the list of available methods. Use `bmb.inference_methods.names` to list the available methods.", + ): + bmb.inference_methods.get_kwargs("not_a_method") + + def test_laplace(): data = pd.DataFrame(np.repeat((0, 1), (30, 60)), columns=["w"]) priors = {"Intercept": bmb.Prior("Uniform", lower=0, upper=1)} @@ -56,7 +76,7 @@ def test_vi(): (mode_n.item(), std_n.item()), (mode_a.item(), std_a.item()), decimal=2 ) -# + @pytest.mark.parametrize("sampler", MCMC_METHODS_FILTERED) def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli")
    