Skip to content

Commit b56f102

Browse files
author
Nathan Lambert
authored
Fix scheduler inference steps error with power of 3 (#466)
* initial attempt at solving * fix pndm power of 3 inference_step * add power of 3 test * fix index in pndm test, remove ddim test * add comments, change to round()
1 parent da99063 commit b56f102

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

src/diffusers/schedulers/scheduling_ddim.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0):
145145
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
146146
"""
147147
self.num_inference_steps = num_inference_steps
148-
self.timesteps = np.arange(
149-
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
150-
)[::-1].copy()
148+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
149+
# creates integer timesteps by multiplying by ratio
150+
# casting to int to avoid issues when num_inference_step is power of 3
151+
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
151152
self.timesteps += offset
152153
self.set_format(tensor_format=self.tensor_format)
153154

src/diffusers/schedulers/scheduling_pndm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.Floa
143143
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
144144
"""
145145
self.num_inference_steps = num_inference_steps
146-
self._timesteps = list(
147-
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
148-
)
146+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
147+
# creates integer timesteps by multiplying by ratio
148+
# casting to int to avoid issues when num_inference_step is power of 3
149+
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
149150
self._offset = offset
150151
self._timesteps = np.array([t + self._offset for t in self._timesteps])
151152

tests/test_scheduler.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def test_time_indices(self):
379379

380380
def test_inference_steps(self):
381381
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
382-
self.check_over_forward(num_inference_steps=num_inference_steps)
382+
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
383383

384384
def test_eta(self):
385385
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
@@ -622,6 +622,23 @@ def test_inference_steps(self):
622622
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
623623
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
624624

625+
def test_pow_of_3_inference_steps(self):
626+
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
627+
num_inference_steps = 27
628+
629+
for scheduler_class in self.scheduler_classes:
630+
sample = self.dummy_sample
631+
residual = 0.1 * sample
632+
633+
scheduler_config = self.get_scheduler_config()
634+
scheduler = scheduler_class(**scheduler_config)
635+
636+
scheduler.set_timesteps(num_inference_steps)
637+
638+
# before power of 3 fix, would error on first step, so we only need to do two
639+
for i, t in enumerate(scheduler.prk_timesteps[:2]):
640+
sample = scheduler.step_prk(residual, t, sample).prev_sample
641+
625642
def test_inference_plms_no_past_residuals(self):
626643
with self.assertRaises(ValueError):
627644
scheduler_class = self.scheduler_classes[0]

0 commit comments

Comments
 (0)