Skip to content

Commit

Permalink
improve time broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Feb 7, 2025
1 parent 56f88b3 commit 2c5f872
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions bayesflow/networks/flow_matching/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,9 @@ def from_config(cls, config):
return cls(**config)

def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
if not keras.ops.is_tensor(t):
t = keras.ops.convert_to_tensor(t, dtype=keras.ops.dtype(xz))

if keras.ops.ndim(t) == 0:
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1])

t = keras.ops.convert_to_tensor(t)
t = expand_right_as(t, xz)
t = keras.ops.tile(t, [1] + list(keras.ops.shape(xz)[1:-1]) + [1])
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))

if conditions is None:
xtc = keras.ops.concatenate([xz, t], axis=-1)
Expand Down Expand Up @@ -196,7 +191,7 @@ def compute_metrics(
else:
# not pre-configured, resample
x1 = x
x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)

if self.use_optimal_transport:
x1, x0, conditions = optimal_transport(
Expand Down

0 comments on commit 2c5f872

Please sign in to comment.