Skip to content

Commit

Permalink
fix #288
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Feb 6, 2025
1 parent 8466ff6 commit 15b4a8a
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def from_config(cls, config):
config = deserialize_value_or_type(config, "subnet")
return cls(**config)

def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None) -> Tensor:
def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
if not keras.ops.is_tensor(t):
t = keras.ops.convert_to_tensor(t, dtype=keras.ops.dtype(xz))

Expand All @@ -105,11 +105,13 @@ def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None) ->
else:
xtc = keras.ops.concatenate([xz, t, conditions], axis=-1)

return self.output_projector(self.subnet(xtc))
return self.output_projector(self.subnet(xtc, training=training), training=training)

def _velocity_trace(self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = 1) -> (Tensor, Tensor):
def _velocity_trace(
self, xz: Tensor, t: Tensor, conditions: Tensor = None, max_steps: int = 1, training: bool = False
) -> (Tensor, Tensor):
def f(x):
return self.velocity(x, t, conditions=conditions)
return self.velocity(x, t, conditions=conditions, training=training)

v, trace = jacobian_trace(f, xz, max_steps=max_steps)

Expand All @@ -121,7 +123,7 @@ def _forward(
if density:

def deltas(t, xz):
v, trace = self._velocity_trace(xz, t, conditions=conditions)
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": x, "trace": 0.0}
Expand All @@ -130,7 +132,7 @@ def deltas(t, xz):
return state["xz"], state["trace"]

def deltas(t, xz):
return {"xz": self.velocity(xz, t, conditions=conditions)}
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}

state = {"xz": x}
state = integrate(deltas, state, start_time=1.0, stop_time=0.0, **kwargs)
Expand All @@ -143,7 +145,7 @@ def _inverse(
if density:

def deltas(t, xz):
v, trace = self._velocity_trace(xz, t, conditions=conditions)
v, trace = self._velocity_trace(xz, t, conditions=conditions, training=training)
return {"xz": v, "trace": trace}

state = {"xz": z, "trace": 0.0}
Expand All @@ -152,7 +154,7 @@ def deltas(t, xz):
return state["xz"], state["trace"]

def deltas(t, xz):
return {"xz": self.velocity(xz, t, conditions=conditions)}
return {"xz": self.velocity(xz, t, conditions=conditions, training=training)}

state = {"xz": z}
state = integrate(deltas, state, start_time=0.0, stop_time=1.0, **kwargs)
Expand Down Expand Up @@ -183,7 +185,7 @@ def compute_metrics(

base_metrics = super().compute_metrics(x1, conditions, stage)

predicted_velocity = self.velocity(x, t, conditions)
predicted_velocity = self.velocity(x, t, conditions, training=stage == "training")

loss = keras.losses.mean_squared_error(target_velocity, predicted_velocity)
loss = keras.ops.mean(loss)
Expand Down

0 comments on commit 15b4a8a

Please sign in to comment.