- Finish implementing the `UNet2D` model in `modeling_unet2d.py`. Port the weights of any existing LDM UNet from `diffusers` and verify equivalence. I've added the skeleton of the modules that we need to implement in the file.
- Adapt the `PNDMScheduler` from `diffusers` for JAX: use `jnp` arrays and make it stateless.
- Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) in the `modeling_vae.py` file. We don't really need it for inference, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
- Add an inference loop in `pipeline_stable_diffusion`. We should be able to `jit`/`pmap` the loop to deploy on TPUs.
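For the weight-porting tasks, most of the work is layout conversion. A minimal sketch, assuming the PyTorch weights have already been pulled out of the `state_dict` as NumPy arrays: PyTorch stores `Conv2d` kernels as (out, in, kh, kw) and `Linear` weights as (out, in), while Flax's `nn.Conv` and `nn.Dense` expect (kh, kw, in, out) and (in, out). The helper names are hypothetical, and the equivalence check runs `jax.lax.conv_general_dilated` in both layouts instead of calling PyTorch itself.

```python
import jax
import jax.numpy as jnp
import numpy as np


def torch_conv_to_flax(weight_oihw: np.ndarray) -> jnp.ndarray:
    """Hypothetical helper: PyTorch Conv2d weight (out, in, kh, kw) -> Flax kernel (kh, kw, in, out)."""
    return jnp.asarray(weight_oihw).transpose(2, 3, 1, 0)


def torch_linear_to_flax(weight_oi: np.ndarray) -> jnp.ndarray:
    """Hypothetical helper: PyTorch Linear weight (out, in) -> Flax Dense kernel (in, out)."""
    return jnp.asarray(weight_oi).T


def conv_nchw(x, w_oihw):
    # Reference convolution in PyTorch's layout (NCHW activations, OIHW kernel).
    return jax.lax.conv_general_dilated(
        x, w_oihw, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))


def conv_nhwc(x, w_hwio):
    # The same convolution in Flax's layout (NHWC activations, HWIO kernel).
    return jax.lax.conv_general_dilated(
        x, w_hwio, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 4, 8, 8)).astype(np.float32)
w_oihw = rng.standard_normal((6, 4, 3, 3)).astype(np.float32)

ref = conv_nchw(jnp.asarray(x_nchw), jnp.asarray(w_oihw))
out = conv_nhwc(jnp.asarray(x_nchw).transpose(0, 2, 3, 1), torch_conv_to_flax(w_oihw))

# Move the Flax-layout output back to NCHW before comparing.
max_diff = float(jnp.max(jnp.abs(ref - out.transpose(0, 3, 1, 2))))
print(f"max abs difference: {max_diff:.2e}")
```

The same pattern scales to a whole model: convert every tensor, run both implementations on identical inputs, and compare the outputs within a small tolerance.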
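For the scheduler, "stateless" can mean moving everything that `PNDMScheduler` mutates (the history of model outputs and the step counter) into an explicit state object that a pure `step` function takes and returns. The sketch below is a simplified assumption, not a port of the `diffusers` code: it covers only the linear multistep (PLMS) part, uses lower-order Adams-Bashforth steps in place of PNDM's Runge-Kutta warm-up, and expects the caller to look up `alphas_cumprod` for the current and previous timestep. `PLMSState`, `init_state` and `step` are hypothetical names, and a zero tensor stands in for the UNet output.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

# Adams-Bashforth coefficients for orders 1 to 4. Row k weights the (k + 1) most
# recent model outputs, newest first; unused slots are zero.
AB_COEFFS = jnp.array([
    [1.0, 0.0, 0.0, 0.0],
    [3.0, -1.0, 0.0, 0.0],
    [23.0, -16.0, 5.0, 0.0],
    [55.0, -59.0, 37.0, -9.0],
]) / jnp.array([1.0, 2.0, 12.0, 24.0])[:, None]


class PLMSState(NamedTuple):
    ets: jnp.ndarray      # (4, *sample_shape): last four model outputs, newest first
    counter: jnp.ndarray  # int32 scalar: number of steps taken so far


def init_state(sample_shape) -> PLMSState:
    return PLMSState(ets=jnp.zeros((4, *sample_shape)), counter=jnp.array(0, dtype=jnp.int32))


def prev_sample_from_eps(sample, eps, alpha_prod_t, alpha_prod_t_prev):
    # PNDM transfer function: move `sample` from timestep t to t_prev given the
    # noise estimate `eps` (alpha_prod_* are cumulative products of alphas).
    sample_coeff = jnp.sqrt(alpha_prod_t_prev / alpha_prod_t)
    denom = alpha_prod_t * jnp.sqrt(1.0 - alpha_prod_t_prev) + jnp.sqrt(
        alpha_prod_t * (1.0 - alpha_prod_t) * alpha_prod_t_prev
    )
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * eps / denom


@jax.jit
def step(state: PLMSState, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
    # Pure function: the history is returned in a new state, never mutated in place.
    ets = jnp.roll(state.ets, shift=1, axis=0).at[0].set(model_output)
    order = jnp.minimum(state.counter + 1, 4)  # warm up with lower orders
    eps = jnp.tensordot(AB_COEFFS[order - 1], ets, axes=1)
    prev_sample = prev_sample_from_eps(sample, eps, alpha_prod_t, alpha_prod_t_prev)
    return prev_sample, PLMSState(ets=ets, counter=state.counter + 1)


# Usage with the "scaled linear" beta schedule used by Stable Diffusion.
betas = jnp.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2
alphas_cumprod = jnp.cumprod(1.0 - betas)

sample0 = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 4))
sample, state = sample0, init_state(sample0.shape)
timesteps = jnp.arange(999, -1, -100)  # 999, 899, ..., 99
for t, t_prev in zip(timesteps[:-1], timesteps[1:]):
    model_output = jnp.zeros_like(sample)  # stand-in for the UNet's noise prediction
    sample, state = step(state, model_output, sample, alphas_cumprod[t], alphas_cumprod[t_prev])
```

Because `step` returns a new `PLMSState`, it can be traced by `jax.jit` and carried through `jax.lax.fori_loop` or `jax.lax.scan`, which is awkward with a mutable Python list of past outputs.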
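The KL autoencoder wraps its encoder output in a diagonal Gaussian posterior. A minimal sketch modeled on the `DiagonalGaussianDistribution` class in the CompVis code, assuming Flax's channels-last (NHWC) layout, so the mean and log-variance are split along the last axis; `DiagonalGaussian` is a hypothetical name and the random `moments` tensor stands in for a real encoder output.

```python
import jax
import jax.numpy as jnp


class DiagonalGaussian:
    """Posterior of a KL-regularized autoencoder, channels-last layout."""

    def __init__(self, moments: jnp.ndarray):
        # The encoder predicts mean and log-variance stacked along the channel axis.
        self.mean, logvar = jnp.split(moments, 2, axis=-1)
        self.logvar = jnp.clip(logvar, -30.0, 20.0)  # same clamp range as the CompVis code
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)

    def sample(self, key) -> jnp.ndarray:
        # Reparameterization trick: mean + std * N(0, I).
        return self.mean + self.std * jax.random.normal(key, self.mean.shape)

    def mode(self) -> jnp.ndarray:
        return self.mean

    def kl(self) -> jnp.ndarray:
        # KL(q || N(0, I)) per example, summed over the H, W and C axes.
        return 0.5 * jnp.sum(self.mean**2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))


moments = jax.random.normal(jax.random.PRNGKey(0), (1, 8, 8, 8))  # stand-in encoder output
posterior = DiagonalGaussian(moments)
z = posterior.sample(jax.random.PRNGKey(1))  # latents with 4 channels
```

`kl()` is the regularization term used during training; for inference, only `sample()` or `mode()` would be called.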
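For the inference loop, one way to keep it `jit`/`pmap`-friendly is to express the denoising steps with `jax.lax.fori_loop` and wrap the whole generation function in `jax.pmap`, which compiles it and runs one copy per device. A minimal sketch, assuming hypothetical `dummy_unet` and `denoise_step` functions in place of the real UNet apply function and scheduler step; the latent and text-embedding shapes are chosen only for illustration.

```python
import functools

import jax
import jax.numpy as jnp


def dummy_unet(latents, t, text_emb):
    # Hypothetical stand-in for the UNet's noise prediction.
    del t, text_emb  # unused by this stand-in
    return 0.1 * latents


def denoise_step(latents, eps):
    # Hypothetical stand-in for a scheduler step.
    return latents - eps


@functools.partial(jax.pmap, static_broadcasted_argnums=(2,))
def generate(latents, text_emb, num_steps):
    def body(i, latents):
        t = num_steps - 1 - i  # iterate timesteps from high to low
        eps = dummy_unet(latents, t, text_emb)
        return denoise_step(latents, eps)

    # fori_loop keeps the whole denoising loop inside one compiled program.
    return jax.lax.fori_loop(0, num_steps, body, latents)


n_dev = jax.local_device_count()
# Leading axis is the device axis; each device gets its own batch of 2.
latents = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 2, 8, 8, 4))
text_emb = jnp.zeros((n_dev, 2, 77, 768))
images = generate(latents, text_emb, 10)
```

In a real pipeline, the scheduler state would travel through the `fori_loop` carry alongside the latents, and the model parameters would be replicated across devices (for example with `flax.jax_utils.replicate`).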
patil-suraj/stable-diffusion-jax