Stateless scheduler #8
base: main
Conversation
Also, make it compatible with any number of parallel devices instead of assuming 8. This is useful to me while testing on my computer with two different GPUs.
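As a sketch of what that change could look like (the variable names below are illustrative, not the PR's actual code), the device count comes from the JAX runtime instead of a hardcoded constant:

```python
import jax

# Query the runtime instead of assuming 8 parallel devices.
num_devices = jax.local_device_count()  # e.g. 2 on a machine with two GPUs

# Hypothetical usage: one PRNG key per device for a pmapped sampling loop.
rng = jax.random.PRNGKey(0)
device_rngs = jax.random.split(rng, num_devices)
```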
```diff
-prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+sample = state.sample
+model_output = state.model_output
+prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
 self.counter += 1
```
Do we still use the `self.counter` variable?
Also, we should return the state somewhere, no?
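For illustration, a state-returning `step` could look roughly like this; the signature and the `counter` field are assumptions, not the PR's actual code:

```python
# Rough sketch only: `step` reads everything from the state, mutates nothing
# on `self`, and hands the updated state back to the caller.
def step(self, state, timestep, prev_timestep):
    sample = state.sample
    model_output = state.model_output
    prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
    # Return a new state instead of mutating self (e.g. no self.counter += 1).
    new_state = state.replace(counter=state.counter + 1)
    return prev_sample, new_state
```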
That's a great PR @pcuenca! I think some clean-ups are left to make it functional, but overall the design seems correct to me.
@patrickvonplaten You are right! I stashed my changes to apply them in a new branch and messed up somehow. I have just pushed the correct version, I hope. Sorry about that!
Cool - looks clean to me now!
This is the biggest change derived from #5.
In that branch I was previously using a standard Python dataclass and a `dict` to manage the scheduler state, and I converted between them when needed. I have now adopted `flax.struct.dataclass`, which is essentially the same but immutable and registered as a JAX pytree.
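As a quick illustration of the difference, here is a minimal example; the field names are made up, but the frozen-instance and `replace` behavior are what `flax.struct.dataclass` provides:

```python
import jax
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class SchedulerState:
    # Illustrative fields only, not the real scheduler state.
    timesteps: jnp.ndarray
    counter: int

state = SchedulerState(timesteps=jnp.arange(50), counter=0)
# Instances are frozen; updates produce a new copy via `replace`.
state = state.replace(counter=state.counter + 1)
# The class is registered as a pytree, so it can cross jit/pmap boundaries.
print(jax.tree_util.tree_leaves(state))
```

I'm not super happy with the result, for these reasons: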
- `scheduler.set_timesteps()` requires more information to initialize the state, such as the shape of the expected inputs. This makes the code a bit uglier because you need to prepare that in advance (see the sketch after this list).
- The state has to be created from the `betas` and the `num_training_steps`, and then the `replace` method doesn't work for the initial construction. I had to create a specific class method to simplify the initial creation of the instance.
- The logic in `step` is harder to follow than the PyTorch version.

Note also that the initial PRK steps have not been implemented yet (since we are skipping them by default). If we go with this approach they should be easy to add.
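To make the first two points concrete, here is a hedged sketch; every name and field below is invented for illustration, not taken from the PR:

```python
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class SchedulerState:
    # Illustrative fields only.
    alphas_cumprod: jnp.ndarray  # derived from `betas` at creation time
    timesteps: jnp.ndarray
    cur_sample: jnp.ndarray      # buffer whose shape matches the model inputs

    @classmethod
    def create(cls, betas, num_training_steps, shape):
        # `replace` only derives a new instance from an existing one, so the
        # initial construction (with fields computed from `betas` and
        # `num_training_steps`) goes through a dedicated classmethod.
        return cls(
            alphas_cumprod=jnp.cumprod(1.0 - betas),
            timesteps=jnp.arange(num_training_steps)[::-1],
            cur_sample=jnp.zeros(shape),
        )

def set_timesteps(state, num_inference_steps, shape):
    # The caller has to know the input shape in advance, which is the
    # awkwardness mentioned in the first bullet.
    step_ratio = state.alphas_cumprod.shape[0] // num_inference_steps
    return state.replace(
        timesteps=(jnp.arange(num_inference_steps) * step_ratio)[::-1],
        cur_sample=jnp.zeros(shape),
    )
```

A caller would then do something like `state = SchedulerState.create(betas, 1000, latents.shape)` before the sampling loop, rather than constructing the state directly.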