Skip to content

Commit 7415c5c

Browse files
author
Nathan Lambert
authored
Fix scheduler inference steps error with power of 3 (huggingface#466)
* initial attempt at solving * fix pndm power of 3 inference_step * add power of 3 test * fix index in pndm test, remove ddim test * add comments, change to round()
1 parent c2df7ed commit 7415c5c

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

schedulers/scheduling_ddim.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0):
145145
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
146146
"""
147147
self.num_inference_steps = num_inference_steps
148-
self.timesteps = np.arange(
149-
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
150-
)[::-1].copy()
148+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
149+
# creates integer timesteps by multiplying by ratio
150+
# casting to int to avoid issues when num_inference_step is power of 3
151+
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
151152
self.timesteps += offset
152153
self.set_format(tensor_format=self.tensor_format)
153154

schedulers/scheduling_pndm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,10 @@ def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.Floa
143143
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
144144
"""
145145
self.num_inference_steps = num_inference_steps
146-
self._timesteps = list(
147-
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
148-
)
146+
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
147+
# creates integer timesteps by multiplying by ratio
148+
# casting to int to avoid issues when num_inference_step is power of 3
149+
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
149150
self._offset = offset
150151
self._timesteps = np.array([t + self._offset for t in self._timesteps])
151152

0 commit comments

Comments
 (0)