Commit 6ec20ea

Rectified Flow Scheduler Test (#220)
* added rf scheduler test
* rectified flow scheduler test added
* removed safetensors downloading
* replaced print statements with max_logging
1 parent de60c6c commit 6ec20ea

File tree

7 files changed: +84 -0 lines changed

6 binary files not shown.

tests/schedulers/test_scheduler_rf.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import jax.numpy as jnp
from maxdiffusion.schedulers.scheduling_rectified_flow import FlaxRectifiedFlowMultistepScheduler
import os
from maxdiffusion import max_logging
import torch
import unittest
from absl.testing import absltest
import numpy as np


class rfTest(unittest.TestCase):

  def test_rf_steps(self):
    # --- Simulation Parameters ---
    latent_tensor_shape = (1, 256, 128)  # Example latent tensor shape (Batch, Sequence, Channels)
    inference_steps_count = 5  # Number of steps for the denoising process

    # --- Run the Simulation ---
    max_logging.log("\n--- Simulating RectifiedFlowMultistepScheduler ---")

    seed = 42
    device = 'cpu'
    max_logging.log(f"Sample shape: {latent_tensor_shape}, Inference steps: {inference_steps_count}, Seed: {seed}")

    generator = torch.Generator(device=device).manual_seed(seed)

    # 1. Instantiate the scheduler
    config = {'_class_name': 'RectifiedFlowScheduler', '_diffusers_version': '0.25.1', 'num_train_timesteps': 1000, 'shifting': None, 'base_resolution': None, 'sampler': 'LinearQuadratic'}
    flax_scheduler = FlaxRectifiedFlowMultistepScheduler.from_config(config)

    # 2. Create and set the initial state for the scheduler
    flax_state = flax_scheduler.create_state()
    flax_state = flax_scheduler.set_timesteps(flax_state, inference_steps_count, latent_tensor_shape)
    max_logging.log("\nScheduler initialized.")
    max_logging.log(f"  flax_state timesteps shape: {flax_state.timesteps.shape}")

    # 3. Prepare the initial noisy latent sample.
    # In a real scenario this would typically be pure random noise (e.g., N(0, 1));
    # for the simulation we generate it from the seeded torch generator.
    sample = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())
    max_logging.log(f"\nInitial sample shape: {sample.shape}, dtype: {sample.dtype}")

    # 4. Simulate the denoising loop
    max_logging.log("\nStarting denoising loop:")
    for i, t in enumerate(flax_state.timesteps):
      max_logging.log(f"  Step {i+1}/{inference_steps_count}, Timestep: {t.item()}")

      # Simulate model_output (e.g., a noise prediction from a UNet)
      model_output = jnp.array(torch.randn(latent_tensor_shape, generator=generator, dtype=torch.float32).to(device).numpy())

      # Call the scheduler's step function
      scheduler_output = flax_scheduler.step(
          state=flax_state,
          model_output=model_output,
          timestep=t,  # Pass the current timestep from the scheduler's sequence
          sample=sample,
          return_dict=True,  # Return a SchedulerOutput dataclass
      )

      sample = scheduler_output.prev_sample  # Update the sample for the next step
      flax_state = scheduler_output.state  # Update the state for the next step

      # Compare with the PyTorch reference implementation
      base_dir = os.path.dirname(__file__)
      ref_dir = os.path.join(base_dir, "rf_scheduler_test_ref")
      ref_filename = os.path.join(ref_dir, f"step_{i+1:02d}.npy")
      if os.path.exists(ref_filename):
        pt_sample = np.load(ref_filename)
        torch.testing.assert_close(np.array(sample), pt_sample)
      else:
        max_logging.log(f"Warning: Reference file not found: {ref_filename}")

    max_logging.log("\nDenoising loop completed.")
    max_logging.log(f"Final sample shape: {sample.shape}, dtype: {sample.dtype}")
    max_logging.log(f"Final sample min: {sample.min().item():.4f}, max: {sample.max().item():.4f}")

    max_logging.log("\nSimulation of FlaxRectifiedFlowMultistepScheduler usage complete.")


if __name__ == "__main__":
  absltest.main()
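
The binary files added in this commit are presumably the precomputed PyTorch reference arrays that the test loads from rf_scheduler_test_ref/step_NN.npy. The commit does not include the script that produced them, but a minimal sketch of how such references could be generated is shown below. ReferenceRectifiedFlowScheduler is a hypothetical stand-in for the PyTorch rectified-flow scheduler, and its set_timesteps/step signatures are assumptions, not a documented API; only the seeding and draw order mirror the JAX test above.

# Hypothetical sketch (not part of this commit): generate reference arrays
# like rf_scheduler_test_ref/step_NN.npy with a PyTorch rectified-flow
# scheduler. `reference_scheduler` is a stand-in object; its set_timesteps/
# step signatures are assumptions.
import os

import numpy as np
import torch


def generate_reference_samples(reference_scheduler, out_dir, shape=(1, 256, 128), steps=5, seed=42):
  os.makedirs(out_dir, exist_ok=True)
  generator = torch.Generator(device="cpu").manual_seed(seed)

  # Draw randoms in the same order as the JAX test (initial sample first,
  # then one model_output per step) so both runs see identical inputs.
  sample = torch.randn(shape, generator=generator, dtype=torch.float32)
  reference_scheduler.set_timesteps(steps)  # assumed signature

  for i, t in enumerate(reference_scheduler.timesteps):
    model_output = torch.randn(shape, generator=generator, dtype=torch.float32)
    sample = reference_scheduler.step(model_output, t, sample).prev_sample  # assumed signature
    np.save(os.path.join(out_dir, f"step_{i + 1:02d}.npy"), sample.numpy())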

0 commit comments
