From ab0be625efbed21f83fa24facc4c018325bab96b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hans=20Olischl=C3=A4ger?= <106988117+han-ol@users.noreply.github.com> Date: Wed, 11 Dec 2024 10:32:16 +0100 Subject: [PATCH] Fix: warnings in softplus and inverse_softplus #270 (#275) --- bayesflow/utils/numpy_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/bayesflow/utils/numpy_utils.py b/bayesflow/utils/numpy_utils.py index 7338817d3..0e0d8c252 100644 --- a/bayesflow/utils/numpy_utils.py +++ b/bayesflow/utils/numpy_utils.py @@ -16,7 +16,9 @@ def inverse_shifted_softplus( def inverse_softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray: """Numerically stabilized inverse softplus function.""" - return np.where(beta * x > threshold, x, np.log(beta * np.expm1(x)) / beta) + with np.errstate(over="ignore"): + expm1_x = np.expm1(x) + return np.where(beta * x > threshold, x, np.log(beta * expm1_x) / beta) def one_hot(indices: np.ndarray, num_classes: int, dtype: str = "float32") -> np.ndarray: @@ -37,4 +39,6 @@ def shifted_softplus( def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray: """Numerically stabilized softplus function.""" - return np.where(beta * x > threshold, x, np.log1p(np.exp(beta * x)) / beta) + with np.errstate(over="ignore"): + exp_beta_x = np.exp(beta * x) + return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)