Skip to content

Commit 2c5f872

Browse files
committed
improve time broadcasting
1 parent 56f88b3 commit 2c5f872

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,9 @@ def from_config(cls, config):
107107
return cls(**config)
108108

109109
def velocity(self, xz: Tensor, t: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
110-
if not keras.ops.is_tensor(t):
111-
t = keras.ops.convert_to_tensor(t, dtype=keras.ops.dtype(xz))
112-
113-
if keras.ops.ndim(t) == 0:
114-
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1])
115-
110+
t = keras.ops.convert_to_tensor(t)
116111
t = expand_right_as(t, xz)
117-
t = keras.ops.tile(t, [1] + list(keras.ops.shape(xz)[1:-1]) + [1])
112+
t = keras.ops.broadcast_to(t, keras.ops.shape(xz)[:-1] + (1,))
118113

119114
if conditions is None:
120115
xtc = keras.ops.concatenate([xz, t], axis=-1)
@@ -196,7 +191,7 @@ def compute_metrics(
196191
else:
197192
# not pre-configured, resample
198193
x1 = x
199-
x0 = keras.random.normal(keras.ops.shape(x1), dtype=keras.ops.dtype(x1), seed=self.seed_generator)
194+
x0 = self.base_distribution.sample(keras.ops.shape(x1), seed=self.seed_generator)
200195

201196
if self.use_optimal_transport:
202197
x1, x0, conditions = optimal_transport(

0 commit comments

Comments
 (0)