Skip to content

Commit

Permalink
Fix UNET and VAE implementations for new diffusers version (#4663)
Browse files Browse the repository at this point in the history
This PR updates the DepSpeed UNET and VAE implementations to support
`diffusers>=0.23.0`.
  • Loading branch information
lekurile authored Nov 10, 2023
1 parent 4388a60 commit a361bac
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 7 additions & 1 deletion deepspeed/model_implementations/diffusers/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def _create_cuda_graph(self, *inputs, **kwargs):

self.cuda_graph_created = True

def _forward(self, sample, timestamp, encoder_hidden_states, return_dict=True, cross_attention_kwargs=None):
def _forward(self,
sample,
timestamp,
encoder_hidden_states,
return_dict=True,
cross_attention_kwargs=None,
timestep_cond=None):
if cross_attention_kwargs:
return self.unet(sample,
timestamp,
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/model_implementations/diffusers/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _graph_replay_decoder(self, *inputs, **kwargs):
self._decoder_cuda_graph.replay()
return self.static_decoder_output

def _decode(self, x, return_dict=True):
def _decode(self, x, return_dict=True, generator=None):
return self.vae.decode(x, return_dict=return_dict)

def _create_cuda_graph_decoder(self, *inputs, **kwargs):
Expand Down

0 comments on commit a361bac

Please sign in to comment.