Stateless scheduler #8

Open · wants to merge 6 commits into main
Conversation

@pcuenca (Collaborator) commented Aug 31, 2022

This is the biggest change derived from #5.

In that branch I was previously using a standard Python dataclass and a dict to manage the scheduler state, converting between them when needed. I have now adopted flax.struct.dataclass, which is essentially the same but (see the sketch below):

  • It's frozen, so it enforces functional updates.
  • It can be easily replicated across devices.
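
For reference, a minimal sketch of the pattern, assuming illustrative field names (the real scheduler state holds more variables):

```python
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class SchedulerState:
    # Illustrative fields only; the actual state has many more.
    timesteps: jnp.ndarray
    counter: int

state = SchedulerState(timesteps=jnp.arange(50), counter=0)

# The dataclass is frozen, so attribute assignment raises;
# all updates are functional and return a new instance.
new_state = state.replace(counter=state.counter + 1)

# Being a pytree of arrays, the state can be replicated across devices:
# from flax.jax_utils import replicate
# replicated = replicate(state)
```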

I'm not super happy with the result, for these reasons:

  • Invoking scheduler.set_timesteps() requires more information to initialize the state, such as the shape of the expected inputs. This makes the calling code a bit uglier because you need to prepare that information in advance.
  • The dataclass has many state variables. In PyTorch we initialized the instance with just the betas and the number of training steps, then added a few items when setting the inference timesteps. If we try to do the same here by providing defaults for most attributes, Python synthesizes an initializer that takes just the betas and num_training_steps, and then the replace method no longer works. I had to create a specific class method to simplify the initial creation of the instance (see the sketch after this list).
  • The state updates in step are harder to follow than the PyTorch logic.
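
Concretely, the class-method workaround looks roughly like this (a sketch; create and the field names are illustrative, not the exact PR code):

```python
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class SchedulerState:
    betas: jnp.ndarray
    num_train_timesteps: int
    # No defaults here: giving these fields defaults changes the
    # synthesized __init__ in a way that breaks replace(), so they
    # stay required and are populated by the factory method instead.
    timesteps: jnp.ndarray
    counter: int

    @classmethod
    def create(cls, betas: jnp.ndarray, num_train_timesteps: int):
        # Fill the remaining state with placeholders; set_timesteps()
        # replaces them before inference.
        return cls(
            betas=betas,
            num_train_timesteps=num_train_timesteps,
            timesteps=jnp.array([], dtype=jnp.int32),
            counter=0,
        )

state = SchedulerState.create(jnp.linspace(1e-4, 2e-2, 1000), 1000)
```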

Note also that the initial PRK steps have not been implemented yet (since we are skipping them by default). If we go with this approach they should be easy to add.

This PR also makes the code compatible with any number of parallel devices instead of assuming 8. This is useful to me while testing on my computer with two different GPUs.
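
That is, something along these lines (a sketch of the idea, not the exact diff):

```python
import jax

# Query the runtime instead of hard-coding 8 devices.
num_devices = jax.local_device_count()

# Shard the work accordingly, e.g. one prompt per device:
batch_size = 1 * num_devices
print(f"Running on {num_devices} device(s), batch size {batch_size}")
```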

Inline review comments from a collaborator were left on these lines of step:

```python
sample = state.sample
model_output = state.model_output
prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)
```

```python
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
```
do we still use the self.counter variable?

Also, we should return the state somewhere, no?
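
For illustration, the stateless shape that comment points at looks roughly like this (a sketch with illustrative names, updating only the counter):

```python
from flax import struct

@struct.dataclass
class SchedulerState:
    counter: int  # illustrative; the real state holds much more

def step(state: SchedulerState, prev_sample):
    # Return a new state instead of mutating self.counter,
    # so the caller can thread it through the sampling loop.
    new_state = state.replace(counter=state.counter + 1)
    return prev_sample, new_state

state = SchedulerState(counter=0)
sample, state = step(state, prev_sample=None)
assert state.counter == 1
```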

@patrickvonplaten (Collaborator) commented:
That's a great PR @pcuenca! I think some clean-ups are left to make it functional, but overall the design seems like the right one to me.

@pcuenca (Collaborator, Author) commented Sep 1, 2022

@patrickvonplaten You are right! I stashed my changes to apply them in a new branch and messed up somehow. I have just pushed the correct version, I hope. Sorry about that!

@patrickvonplaten (Collaborator) left a review comment:

Cool - looks clean to me now!
