@pcuenca pcuenca commented Aug 31, 2022

This is the biggest change derived from #5.

In that branch I was previously using a standard Python dataclass and a dict to manage the scheduler state, and I converted between them when needed. I have now adopted flax.struct.dataclass, which is essentially the same but:

  • It's frozen so it enforces functional updates.
  • It can be easily replicated across devices.

I'm not super happy with the result for these reasons:

  • Invoking scheduler.set_timesteps() requires more information to initialize the state, such as the shape of the expected inputs. This makes the code a bit uglier because you need to prepare that in advance.
  • The dataclass has many state vars. In PyTorch we initialized the instance with just the betas and the number of training steps, and then added a few items when setting the inference timesteps. If we try to do the same here by providing defaults to most attributes, then Python synthesizes an initializer with just the betas and the num_training_steps, and then the replace method doesn't work. I had to create a specific class method to simplify the initial creation of the instance.
  • The updates in step are harder to follow than the PyTorch logic.
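
One possible reading of the initializer problem in the second point, sketched with the standard library only (flax.struct.dataclass builds on standard dataclasses). It assumes the extra attributes were excluded from the synthesized initializer with `init=False`; the PR may have hit a different variant. The class and field names are hypothetical, and `create` stands in for the class method mentioned above.

```python
import dataclasses
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class NarrowInitState:  # hypothetical: initializer takes only the first two fields
    betas: Tuple[float, ...]
    num_train_timesteps: int
    # Excluded from __init__, so replace() refuses to set it as well.
    timesteps: Optional[Tuple[int, ...]] = dataclasses.field(default=None, init=False)


narrow = NarrowInitState(betas=(0.1, 0.2), num_train_timesteps=1000)
try:
    dataclasses.replace(narrow, timesteps=(999, 499, 0))
    replace_worked = True
except ValueError:
    replace_worked = False


@dataclasses.dataclass(frozen=True)
class SchedulerState:  # hypothetical: every field is an init argument
    betas: Tuple[float, ...]
    num_train_timesteps: int
    timesteps: Optional[Tuple[int, ...]]
    counter: int

    @classmethod
    def create(cls, betas, num_train_timesteps):
        # Fill in the inference-time fields so callers pass only two arguments.
        return cls(
            betas=tuple(betas),
            num_train_timesteps=num_train_timesteps,
            timesteps=None,
            counter=0,
        )


state = SchedulerState.create([0.1, 0.2], 1000)
state = dataclasses.replace(state, timesteps=(999, 499, 0))
```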

Note also that the initial PRK steps have not been implemented yet (since we are skipping them by default). If we go with this approach they should be easy to add.

Also, make it compatible with any number of parallel devices instead of assuming 8. This is useful to me while testing on my computer with two different GPUs.
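
A rough sketch of what not assuming 8 devices can look like: the device count is queried at runtime and the batch is reshaped to match. The `shard` helper is hypothetical and is not taken from this PR.

```python
import jax
import jax.numpy as jnp


def shard(batch):
    """Reshape (batch, ...) to (num_devices, batch // num_devices, ...)."""
    num_devices = jax.local_device_count()  # queried at runtime, not hard-coded
    if batch.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    return batch.reshape((num_devices, -1) + batch.shape[1:])


num_devices = jax.local_device_count()
batch = jnp.ones((2 * num_devices, 4))
sharded = shard(batch)

# pmap maps over the leading axis, one slice per device.
doubled = jax.pmap(lambda x: x * 2)(sharded)
```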
prev_sample = self._get_prev_sample(state, sample, timestep, prev_timestep, model_output)

prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
Do we still use the self.counter variable?

Also, we should return the state somewhere, no?

@patrickvonplaten
That's a great PR, @pcuenca! I think some clean-ups are still needed to make it functional, but overall the design looks correct to me.

pcuenca commented Sep 1, 2022

@patrickvonplaten You are right! I stashed my changes to apply them in a new branch and messed up somehow. I have just pushed the correct version, I hope. Sorry about that!

@patrickvonplaten patrickvonplaten left a comment

Cool - looks clean to me now!
