Skip to content

Commit

Permalink
Fix: warnings in softplus and inverse_softplus #270 (#275)
Browse files Browse the repository at this point in the history
  • Loading branch information
han-ol authored Dec 11, 2024
1 parent 6449a92 commit ab0be62
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions bayesflow/utils/numpy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def inverse_shifted_softplus(

def inverse_softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
"""Numerically stabilized inverse softplus function."""
return np.where(beta * x > threshold, x, np.log(beta * np.expm1(x)) / beta)
with np.errstate(over="ignore"):
expm1_x = np.expm1(x)
return np.where(beta * x > threshold, x, np.log(beta * expm1_x) / beta)


def one_hot(indices: np.ndarray, num_classes: int, dtype: str = "float32") -> np.ndarray:
Expand All @@ -37,4 +39,6 @@ def shifted_softplus(

def softplus(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
"""Numerically stabilized softplus function."""
return np.where(beta * x > threshold, x, np.log1p(np.exp(beta * x)) / beta)
with np.errstate(over="ignore"):
exp_beta_x = np.exp(beta * x)
return np.where(beta * x > threshold, x, np.log1p(exp_beta_x) / beta)

0 comments on commit ab0be62

Please sign in to comment.