diff --git a/bayesflow/utils/numpy_utils.py b/bayesflow/utils/numpy_utils.py index 7338817d3..0e0d8c252 100644 --- a/bayesflow/utils/numpy_utils.py +++ b/bayesflow/utils/numpy_utils.py @@ -16,7 +16,9 @@ def inverse_shifted_softplus( def inverse_softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray: """Numerically stabilized inverse softplus function.""" - return np.where(beta * x > threshold, x, np.log(beta * np.expm1(x)) / beta) + with np.errstate(over="ignore"): + expm1_x = np.expm1(x) + return np.where(beta * x > threshold, x, np.log(beta * expm1_x) / beta) def one_hot(indices: np.ndarray, num_classes: int, dtype: str = "float32") -> np.ndarray: @@ -37,4 +39,6 @@ def shifted_softplus( def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray: """Numerically stabilized softplus function.""" - return np.where(beta * x > threshold, x, np.log1p(np.exp(beta * x)) / beta) + with np.errstate(over="ignore"): + exp_beta_x = np.exp(beta * x) + return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)