-
Notifications
You must be signed in to change notification settings - Fork 5.9k
initial flax pndm schedular #492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The documentation is not available anymore as the PR was closed or merged. |
from ..utils import DummyObject, requires_backends | ||
|
||
|
||
class FlaxPNDMScheduler(metaclass=DummyObject): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool! Think it's just the set_format
logic that we can now remove here
* initial flax pndm * fix typo * use state * return state * add FlaxSchedulerOutput * fix style * add flax imports * make style * fix typos * return created state * make style * add torch/flax imports * docs * fixed typo * remove tensor_format * round instead of cast * ets is jnp array * remove copy
for issue #475 from PR patil-suraj/stable-diffusion-jax#8 and integrates fix from #466