|
| 1 | +import jax.numpy as jnp |
| 2 | +from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler |
| 3 | +import os |
| 4 | +from maxdiffusion import max_logging |
| 5 | +import torch |
| 6 | +import unittest |
| 7 | +from absl.testing import absltest |
| 8 | +import numpy as np |
| 9 | + |
| 10 | + |
| 11 | + |
| 12 | +class rfTest(unittest.TestCase): |
| 13 | + |
| 14 | + def test_rf_steps(self): |
| 15 | + # --- Simulation Parameters --- |
| 16 | + latent_tensor_shape = (1, 256, 128) # Example latent tensor shape (Batch, Channels, Height, Width) |
| 17 | + inference_steps_count = 5 # Number of steps for the denoising process |
| 18 | + |
| 19 | + # --- Run the Simulation --- |
| 20 | + max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---") |
| 21 | + |
| 22 | + seed = 42 |
| 23 | + device = 'cpu' |
| 24 | + max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}") |
| 25 | + |
| 26 | + generator = torch.Generator(device=device).manual_seed(seed) |
| 27 | + |
| 28 | + # 1. Instantiate the scheduler |
| 29 | + config = {'_class_name': 'RectifiedFlowScheduler', '_diffusers_version': '0.25.1', 'num_train_timesteps': 1000, 'shifting': None, 'base_resolution': None, 'sampler': 'LinearQuadratic'} |
| 30 | + flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config) |
| 31 | + |
| 32 | + # 2. Create and set initial state for the scheduler |
| 33 | + flax_state = flax_scheduler.create_state() |
| 34 | + flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape) |
| 35 | + max_logging.log("\nScheduler initialized.") |
| 36 | + max_logging.log(f" flax_state timesteps shape: {flax_state.timesteps.shape}") |
| 37 | + |
| 38 | + # 3. Prepare the initial noisy latent sample |
| 39 | + # In a real scenario, this would typically be pure random noise (e.g., N(0,1)) |
| 40 | + # For simulation, we'll generate it. |
| 41 | + |
| 42 | + sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) |
| 43 | + max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}") |
| 44 | + |
| 45 | + # 4. Simulate the denoising loop |
| 46 | + max_logging.log("\nStarting denoising loop:") |
| 47 | + for i, t in enumerate(flax_state.timesteps): |
| 48 | + max_logging.log(f" Step {i+1}/{inference_steps_count}, Timestep: {t.item()}") |
| 49 | + |
| 50 | + # Simulate model_output (e.g., noise prediction from a UNet) |
| 51 | + model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy()) |
| 52 | + |
| 53 | + # Call the scheduler's step function |
| 54 | + scheduler_output = flax_scheduler.step( |
| 55 | + state=flax_state, |
| 56 | + model_output=model_output, |
| 57 | + timestep=t, # Pass the current timestep from the scheduler's sequence |
| 58 | + sample=sample, |
| 59 | + return_dict=True # Return a SchedulerOutput dataclass |
| 60 | + ) |
| 61 | + |
| 62 | + sample = scheduler_output.prev_sample # Update the sample for the next step |
| 63 | + flax_state = scheduler_output.state # Update the state for the next step |
| 64 | + |
| 65 | + # Compare with pytorch implementation |
| 66 | + base_dir = os.path.dirname(__file__) |
| 67 | + ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref") |
| 68 | + ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy") |
| 69 | + if os.path.exists(ref_filename): |
| 70 | + pt_sample = np.load(ref_filename) |
| 71 | + torch.testing.assert_close(np.array(sample), pt_sample) |
| 72 | + else: |
| 73 | + max_logging.log(f"Warning: Reference file not found: {ref_filename}") |
| 74 | + |
| 75 | + |
| 76 | + max_logging.log("\nDenoising loop completed.") |
| 77 | + max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}") |
| 78 | + max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}") |
| 79 | + |
| 80 | + max_logging.log("\nSimulation of RectifiedMultistepScheduler usage complete.") |
| 81 | + |
| 82 | + |
| 83 | +if __name__ == "__main__": |
| 84 | + absltest.main() |
0 commit comments