@@ -107,14 +107,9 @@ def from_config(cls, config):
107
107
return cls (** config )
108
108
109
109
def velocity (self , xz : Tensor , t : float | Tensor , conditions : Tensor = None , training : bool = False ) -> Tensor :
110
- if not keras .ops .is_tensor (t ):
111
- t = keras .ops .convert_to_tensor (t , dtype = keras .ops .dtype (xz ))
112
-
113
- if keras .ops .ndim (t ) == 0 :
114
- t = keras .ops .broadcast_to (t , keras .ops .shape (xz )[:- 1 ])
115
-
110
+ t = keras .ops .convert_to_tensor (t )
116
111
t = expand_right_as (t , xz )
117
- t = keras .ops .tile (t , [ 1 ] + list ( keras .ops .shape (xz )[1 :- 1 ]) + [ 1 ] )
112
+ t = keras .ops .broadcast_to (t , keras .ops .shape (xz )[:- 1 ] + ( 1 ,) )
118
113
119
114
if conditions is None :
120
115
xtc = keras .ops .concatenate ([xz , t ], axis = - 1 )
@@ -196,7 +191,7 @@ def compute_metrics(
196
191
else :
197
192
# not pre-configured, resample
198
193
x1 = x
199
- x0 = keras . random . normal (keras .ops .shape ( x1 ), dtype = keras . ops . dtype (x1 ), seed = self .seed_generator )
194
+ x0 = self . base_distribution . sample (keras .ops .shape (x1 ), seed = self .seed_generator )
200
195
201
196
if self .use_optimal_transport :
202
197
x1 , x0 , conditions = optimal_transport (
0 commit comments