diff --git a/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py b/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py index ca361f810..329c0820b 100644 --- a/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py +++ b/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py @@ -1,5 +1,6 @@ -import numpy as np +from math import pi as PI_CONST + from keras import ops from bayesflow.experimental.types import Tensor @@ -7,25 +8,29 @@ class AffineTransform(Transform): + + def __init__(self, clamp_factor=1.9): + self.clamp_factor = clamp_factor + def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]: scale, shift = ops.split(parameters, 2, axis=-1) return {"scale": scale, "shift": shift} def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tensor]: - shift = np.log(np.e - 1) - parameters["scale"] = ops.softplus(parameters["scale"] + shift) + s = (2.0 * self.clamp_factor / PI_CONST) * ops.atan(parameters["scale"] / self.soft_clamp) + parameters["scale"] = ops.exp(s) return parameters def forward(self, x: Tensor, parameters: dict[str, Tensor]) -> (Tensor, Tensor): z = parameters["scale"] * x + parameters["shift"] - log_det = ops.mean(ops.log(parameters["scale"]), axis=-1) + log_det = ops.mean(parameters["scale"], axis=-1) return z, log_det def inverse(self, z: Tensor, parameters: dict[str, Tensor]) -> (Tensor, Tensor): x = (z - parameters["shift"]) / parameters["scale"] - log_det = -ops.mean(ops.log(parameters["scale"]), axis=-1) + log_det = -ops.mean(parameters["scale"], axis=-1) return x, log_det