From faae8df956c1e68e5c6de05ec4cfe7154c55fbb3 Mon Sep 17 00:00:00 2001
From: Radev
Date: Tue, 14 May 2024 18:45:11 -0400
Subject: [PATCH] Add soft clamping to affine transform, otherwise the loss
 explodes even for simple cases

---
 .../coupling_flow/transforms/affine_transform.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py b/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py
index ca361f810..329c0820b 100644
--- a/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py
+++ b/bayesflow/experimental/networks/coupling_flow/transforms/affine_transform.py
@@ -1,5 +1,6 @@
-import numpy as np
+from math import pi as PI_CONST
+
 
 from keras import ops
 
 from bayesflow.experimental.types import Tensor
@@ -7,21 +8,29 @@
 
 
 class AffineTransform(Transform):
+
+    def __init__(self, clamp_factor: float = 1.9):
+        super().__init__()
+        self.clamp_factor = clamp_factor
+
     def split_parameters(self, parameters: Tensor) -> dict[str, Tensor]:
         scale, shift = ops.split(parameters, 2, axis=-1)
         return {"scale": scale, "shift": shift}
 
     def constrain_parameters(self, parameters: dict[str, Tensor]) -> dict[str, Tensor]:
-        shift = np.log(np.e - 1)
-        parameters["scale"] = ops.softplus(parameters["scale"] + shift)
+        # Soft-clamp the raw log-scale to (-clamp_factor, clamp_factor) via a
+        # scaled arctan, so extreme network outputs cannot blow up the scale.
+        log_scale = (2.0 * self.clamp_factor / PI_CONST) * ops.arctan(parameters["scale"] / self.clamp_factor)
+        parameters["log_scale"] = log_scale
+        parameters["scale"] = ops.exp(log_scale)
         return parameters
 
     def forward(self, x: Tensor, parameters: dict[str, Tensor]) -> (Tensor, Tensor):
         z = parameters["scale"] * x + parameters["shift"]
-        log_det = ops.mean(ops.log(parameters["scale"]), axis=-1)
+        log_det = ops.mean(parameters["log_scale"], axis=-1)
         return z, log_det
 
     def inverse(self, z: Tensor, parameters: dict[str, Tensor]) -> (Tensor, Tensor):
         x = (z - parameters["shift"]) / parameters["scale"]
-        log_det = -ops.mean(ops.log(parameters["scale"]), axis=-1)
+        log_det = -ops.mean(parameters["log_scale"], axis=-1)
         return x, log_det
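-- 

Reviewer note (illustration only, not part of the patch): a minimal sketch of
what the soft clamp does, with NumPy standing in for keras.ops. The previous
softplus parameterization left the scale unbounded above; the scaled arctan
bounds the raw log-scale to (-clamp_factor, clamp_factor), so the
multiplicative scale exp(log_scale) stays within roughly (0.15, 6.69) for the
default clamp_factor of 1.9, no matter how large the network output grows.

    import numpy as np

    clamp_factor = 1.9
    raw = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])  # raw network outputs

    # Scaled arctan maps the whole real line onto (-clamp_factor, clamp_factor).
    log_scale = (2.0 * clamp_factor / np.pi) * np.arctan(raw / clamp_factor)
    scale = np.exp(log_scale)

    print(log_scale)  # approx. [-1.898, -0.586, 0.000, 0.586, 1.898]
    print(scale)      # approx. [ 0.150,  0.556, 1.000, 1.797, 6.670]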