Skip to content

Commit c28f389

Browse files
committed
Fix TypeError: Value passed to parameter 'x' has DataType int32 not in list of allowed values: float16, bfloat16, float32, float64
1 parent 6487d54 commit c28f389

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def sample_k_fids_for_pid(pid, all_fids, all_pids, batch_k):
151151
# create a padded list of indices which contain a multiple of the
152152
# original FID count such that all of them will be sampled equally likely.
153153
count = tf.shape(possible_fids)[0]
154-
padded_count = tf.cast(tf.ceil(batch_k / count), tf.int32) * count
154+
padded_count = tf.cast(tf.ceil(batch_k / tf.cast(count, tf.float64)), tf.int32) * count
155155
full_range = tf.mod(tf.range(padded_count), count)
156156

157157
# Sampling is always performed by shuffling and taking the first k.

0 commit comments

Comments
 (0)