Skip to content

Commit

Permalink
Create truncatedprior.py
Browse files Browse the repository at this point in the history
fixes #779 
This function leverages the truncnorm distribution from the SciPy library to generate samples from a truncated normal distribution. You can adapt the function to other distributions by changing the import statement and parameters.
  • Loading branch information
speco29 authored Jan 22, 2025
1 parent 6b66691 commit be7a335
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions bambi/priors/truncatedprior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
from scipy.stats import truncnorm

def truncatedprior(mean, std, lower, upper, size=1):
"""
Generate truncated normal distribution samples.
Parameters:
- mean (float): Mean of the normal distribution.
- std (float): Standard deviation of the normal distribution.
- lower (float): Lower bound of the truncated distribution.
- upper (float): Upper bound of the truncated distribution.
- size (int): Number of samples to generate.
Returns:
- ndarray: Samples from the truncated normal distribution.
"""
a, b = (lower - mean) / std, (upper - mean) / std
return truncnorm(a, b, loc=mean, scale=std).rvs(size)

0 comments on commit be7a335

Please sign in to comment.