From 15b4a8ad72cd46a4d2170ec88c908590f1e206a5 Mon Sep 17 00:00:00 2001 From: larskue Date: Thu, 6 Feb 2025 15:11:22 +0100 Subject: [PATCH] fix #288 --- .../networks/flow_matching/flow_matching.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py index 2ff5c1bda..64ef76b4f 100644 --- a/bayesflow/networks/flow_matching/flow_matching.py +++ b/bayesflow/networks/flow_matching/flow_matching.py @@ -90,7 +90,7 @@ def from_config(cls, config): config = deserialize_value_or_type(config, "subnet") return cls(**config) - def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None) -> Tensor: + def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor: if not keras.ops.is_tensor(t): t = keras.ops.convert_to_tensor(t, dtype=keras.ops.dtype(xz)) @@ -105,11 +105,13 @@ def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None) -> else: xtc = keras.ops.concatenate([xz, t, conditions], axis=-1) - return self.output_projector(self.subnet(xtc)) + return self.output_projector(self.subnet(xtc, training=training), training=training) - def _velocity_trace(self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = 1) -> (Tensor, Tensor): + def _velocity_trace( + self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = 1, training: bool = False + ) -> (Tensor, Tensor): def f(x): - return self.velocity(x, t, conditions=conditions) + return self.velocity(x, t, conditions=conditions, training=training) v, trace = jacobian_trace(f, xz, max_steps=max_steps) @@ -121,7 +123,7 @@ def _forward( if density: def deltas(t, xz): - v, trace = self._velocity_trace(xz, t, conditions=conditions) + v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training) return {"xz": v, "trace": trace} state = {"xz": x, "trace": 0.0} @@ -130,7 +132,7 @@ def deltas(t, xz): return state["xz"], state["trace"] def deltas(t, xz): - return {"xz": self.velocity(xz, t, conditions=conditions)} + return {"xz": self.velocity(xz, t, conditions=conditions, training=training)} state = {"xz": x} state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **kwargs) @@ -143,7 +145,7 @@ def _inverse( if density: def deltas(t, xz): - v, trace = self._velocity_trace(xz, t, conditions=conditions) + v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training) return {"xz": v, "trace": trace} state = {"xz": z, "trace": 0.0} @@ -152,7 +154,7 @@ def deltas(t, xz): return state["xz"], state["trace"] def deltas(t, xz): - return {"xz": self.velocity(xz, t, conditions=conditions)} + return {"xz": self.velocity(xz, t, conditions=conditions, training=training)} state = {"xz": z} state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **kwargs) @@ -183,7 +185,7 @@ def compute_metrics( base_metrics = super().compute_metrics(x1, conditions, stage) - predicted_velocity = self.velocity(x, t, conditions) + predicted_velocity = self.velocity(x, t, conditions, training=stage == "training") loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity) loss = keras.ops.mean(loss)