From be7a3351cfd59003a5891b7dfb9dc0164ff1282a Mon Sep 17 00:00:00 2001 From: specsy Date: Thu, 23 Jan 2025 01:14:12 +0530 Subject: [PATCH] Create truncatedprior.py fixes #779 This function leverages the truncnorm distribution from the SciPy library to generate samples from a truncated normal distribution. You can adapt the function to other distributions by changing the import statement and parameters. --- bambi/priors/truncatedprior.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 bambi/priors/truncatedprior.py diff --git a/bambi/priors/truncatedprior.py b/bambi/priors/truncatedprior.py new file mode 100644 index 00000000..96c583d1 --- /dev/null +++ b/bambi/priors/truncatedprior.py @@ -0,0 +1,19 @@ +import numpy as np +from scipy.stats import truncnorm + +def truncatedprior(mean, std, lower, upper, size=1): + """ + Generate truncated normal distribution samples. + + Parameters: + - mean (float): Mean of the normal distribution. + - std (float): Standard deviation of the normal distribution. + - lower (float): Lower bound of the truncated distribution. + - upper (float): Upper bound of the truncated distribution. + - size (int): Number of samples to generate. + + Returns: + - ndarray: Samples from the truncated normal distribution. + """ + a, b = (lower - mean) / std, (upper - mean) / std + return truncnorm(a, b, loc=mean, scale=std).rvs(size)