Skip to content

Commit 3dc3b1d

Browse files
committed
add hardcoded data values for iterator.
1 parent 7623498 commit 3dc3b1d

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,11 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
356356
start_step = restore_args.get("step", 0)
357357
per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
358358
scheduler_state = pipeline.scheduler_state
359-
example_batch = load_next_batch(train_data_iterator, None, self.config)
359+
example_batch = {
360+
"latents" : jax.random.normal(rng, (jax.device_count() * self.config.global_batch_size_to_train_on, 16, 21, 90, 160), dtype=jnp.bfloat16),
361+
"encoder_hidden_states" : jax.random.normal(rng, (jax.device_count() * self.config.global_batch_size_to_train_on, 512, 4096), dtype=jnp.bfloat16)
362+
}
363+
example_batch = load_next_batch(train_data_iterator, example_batch, self.config)
360364

361365
with ThreadPoolExecutor(max_workers=1) as executor:
362366
for step in np.arange(start_step, self.config.max_train_steps):

0 commit comments

Comments
 (0)