
self.attn ValueError: not enough values to unpack (expected 2, got 1) #10066


Description


Describe the bug

I have encountered an error when sampling from the SD3 transformer component: line 208 in attention.py unpacks two outputs from self.attn, but only one is returned:

  File "/home/p2p/pytorch/mc3/envs/marigold/lib/python3.10/site-packages/diffusers/src/diffusers/models/attention.py", line 208, in forward
    attn_output, context_attn_output = self.attn(
ValueError: not enough values to unpack (expected 2, got 1)
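
For context, the joint attention block unpacks two values from self.attn, while an image-only processor returns a single tensor; unpacking a tensor iterates over its batch dimension, which is exactly what produces this message. A minimal sketch of the unpack mechanics, with hypothetical processor functions rather than the actual diffusers code:

import torch

def joint_processor(hidden, context):
    # a joint-attention processor returns both streams
    return hidden, context

def image_only_processor(hidden, context):
    # an image-only processor returns a single tensor
    return hidden

hidden = torch.randn(1, 4, 8)   # batch size 1
context = torch.randn(1, 2, 8)

attn_output, context_attn_output = joint_processor(hidden, context)       # fine
attn_output, context_attn_output = image_only_processor(hidden, context)
# ValueError: not enough values to unpack (expected 2, got 1)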

Reproduction

https://colab.research.google.com/drive/1CkgjIaaClKUk4ZC-g_RR578BfFE24gps#scrollTo=1xcDHPHd56WH

import torch
from diffusers.models.transformers import SD3Transformer2DModel

device = torch.device("cuda")

# Dummy inputs matching SD3's expected shapes
cat_latents = torch.randn(1, 16, 128, 128).to(device)
timesteps = torch.tensor([453.9749]).to(device)
prompt_embeds = torch.randn(1, 154, 4096).to(device)
pooled_prompt_embeds = torch.randn(1, 2048).to(device)

# Randomly initialized transformer with the default SD3 config
model = SD3Transformer2DModel().to(device)
model.enable_xformers_memory_efficient_attention()

model_pred = model(
    hidden_states=cat_latents,
    timestep=timesteps,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    return_dict=False,
)[0]

Logs

Traceback (most recent call last):
  File "/home/p2p/src/trainer/trainer.py", line 347, in train
    model_pred = self.model.transformer(
  File "/home/p2p/pytorch/mc3/envs/marigold/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/p2p/pytorch/mc3/envs/marigold/lib/python3.10/site-packages/diffusers/src/diffusers/models/transformers/transformer_sd3.py", line 347, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/p2p/pytorch/mc3/envs/marigold/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/p2p/pytorch/mc3/envs/marigold/lib/python3.10/site-packages/diffusers/src/diffusers/models/attention.py", line 208, in forward
    attn_output, context_attn_output = self.attn(
ValueError: not enough values to unpack (expected 2, got 1)

System Info

  • 🤗 Diffusers version: 0.32.0.dev0
  • Platform: Linux-6.1.0-26-amd64-x86_64-with-glibc2.36
  • Running on Google Colab?: No
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.26.2
  • Transformers version: 4.46.1
  • Accelerate version: 0.34.2
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.3
  • xFormers version: 0.0.21
  • Accelerator: NVIDIA RTX A6000, 49140 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @sayakpaul @DN6 @asomoza

Activity

sayakpaul (Member) commented on Dec 1, 2024

This is not a reproducible code snippet. Could you please adjust it accordingly? We don't know any of the values here:

model_pred = self.model.transformer(
    hidden_states=cat_latents,
    timestep=timesteps,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    return_dict=False,
)[0]

tracyyy-zhu (Author) commented on Dec 2, 2024

The inputs are tensors with the following shapes:

cat_latents: torch.Size([1, 32, 128, 128])
timesteps: torch.Size([1])
prompt_embeds: torch.Size([1, 154, 4096])
pooled_prompt_embeds: torch.Size([1, 2048])

sayakpaul (Member) commented on Dec 3, 2024

I kindly ask you to modify your reproducer into a fully reproducible and minimal one.

tracyyy-zhu (Author) commented on Dec 8, 2024

Hi @sayakpaul, I have modified it to include the tensor values.

sayakpaul (Member) commented on Dec 9, 2024

Jamie-Cheung commented on Dec 23, 2024

> Hi @sayakpaul, I have modified it to include the tensor values.

I encountered the same error. Can you share how you solved it?

tracyyy-zhu (Author) commented on Dec 24, 2024

@Jamie-Cheung from my experience, commenting out this line solves the error:
# model.transformer.enable_xformers_memory_efficient_attention()
FYI, I was using diffusers version 0.31.0 in my code base for compatibility reasons. In the notebook below, upgrading to diffusers==0.32.0 also resolves the error.
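For the notebook, that upgrade is just:

pip install diffusers==0.32.0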

@sayakpaul This notebook reproduces the error: https://colab.research.google.com/drive/1CkgjIaaClKUk4ZC-g_RR578BfFE24gps#scrollTo=HU4WaNDM9vJk
From my debugging, enable_xformers_memory_efficient_attention() changes some attention processors to diffusers.models.attention_processor.XFormersAttnProcessor, which returns only one value.

I added a print statement print("self processor", self.processor) before this line: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L588
This is what I get when enable_xformers_memory_efficient_attention() is not commented out:

self processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x7f8e659a23e0>
self processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x7f8e659a23e0>
self processor <diffusers.models.attention_processor.XFormersAttnProcessor object at 0x7f8e6581fd60>

When I comment it out, I get (for one forward pass):

self processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x7f45f3646c50>
self processor <diffusers.models.attention_processor.AttnProcessor2_0 object at 0x7f45f3646c50>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3623070>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3623c70>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3644670>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36450c0>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3645990>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3646290>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3646dd0>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f3647670>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36e80a0>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36e89a0>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36e9240>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36e9e40>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36ea710>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36eb100>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f36eba90>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f35fc310>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f35fce20>
self processor <diffusers.models.attention_processor.JointAttnProcessor2_0 object at 0x7f45f35fd750>

Please feel free to leave any advice or thoughts.
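
For anyone comparing, the same information is available without editing library code; a minimal sketch using the public attn_processors property (assuming a CUDA device and the default SD3 config):

import torch
from diffusers.models.transformers import SD3Transformer2DModel

model = SD3Transformer2DModel().to("cuda")
model.enable_xformers_memory_efficient_attention()

# attn_processors maps each attention module's path to its processor instance;
# the SD3 joint blocks should report a joint-attention processor here.
for name, proc in model.attn_processors.items():
    print(name, type(proc).__name__)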

Jamie-Cheung commented on Dec 26, 2024

> @Jamie-Cheung from my experience, commenting out this line solves the error: # model.transformer.enable_xformers_memory_efficient_attention() [...]

If I want to use xformers to further reduce GPU memory usage, how do I solve the problem that XFormersAttnProcessor returns only one value?
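
One possible route, assuming the installed diffusers ships XFormersJointAttnProcessor (the joint-attention counterpart named later in this thread), is to set it explicitly rather than relying on enable_xformers_memory_efficient_attention(). A sketch against the API as described here, not a verified fix:

from diffusers.models.attention_processor import (
    JointAttnProcessor2_0,
    XFormersJointAttnProcessor,
)

# Swap only the joint-attention processors for their xformers counterpart;
# set_attn_processor also accepts a dict keyed by module path.
procs = {
    name: XFormersJointAttnProcessor() if isinstance(proc, JointAttnProcessor2_0) else proc
    for name, proc in model.attn_processors.items()
}
model.set_attn_processor(procs)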

github-actions (Contributor) commented on Jan 19, 2025

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

tsvs commented on Jun 16, 2025

Getting the same problem. Is there a way to run Flux with xformers in diffusers?

yiyixuxu (Collaborator) commented on Jun 17, 2025

Hey! I cannot reproduce the error.

Are you on the latest diffusers?

Can you print out model.attn_processors?

tsvs commented on Jun 17, 2025

model.attn_processors: self.attn.processor=<diffusers.models.attention_processor.XFormersAttnProcessor object at 0x722a2a1b3a10>

I'm using diffusers 0.33.1.

tsvs commented on Jun 17, 2025

The problem is in this line: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L151

With xformers enabled, attention_outputs is a tensor, but in the lines below the output is expected to be a tuple of length 2 or 3: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L158
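
In other words, the block branches on len(attention_outputs), which only makes sense for a tuple. Roughly, as an illustrative sketch rather than the actual diffusers source:

attention_outputs = attn(hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states)

if len(attention_outputs) == 2:
    attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
    attn_output, context_attn_output, ip_attn_output = attention_outputs
# With XFormersAttnProcessor, attention_outputs is a single tensor, so len()
# measures its first (batch) dimension and neither branch behaves as intended.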

yiyixuxu (Collaborator) commented on Jun 17, 2025

You should see XFormersJointAttnProcessor instead of XFormersAttnProcessor.

Could you install from source? Just want to rule out that it's a bug we've already fixed.

tsvs commented on Jun 19, 2025

With an installation from source (pip install git+https://github.com/huggingface/diffusers) and the latest version of xformers, I'm getting the same problem: XFormersAttnProcessor and the wrong output type for attention_outputs.

attentions - transformer.attn_processors maps every block to XFormersAttnProcessor:

transformer.attn_processors={'transformer_blocks.0.attn.processor': <diffusers.models.attention_processor.XFormersAttnProcessor object at 0x7d297134af50>, ..., 'single_transformer_blocks.37.attn.processor': <diffusers.models.attention_processor.XFormersAttnProcessor object at 0x7d297127a810>}

(all 19 transformer_blocks and all 38 single_transformer_blocks report XFormersAttnProcessor; repeated entries elided)

env for reference:

Using Python 3.11.11 environment

Package Version
accelerate 1.7.0
annotated-types 0.7.0
anyio 4.9.0
attrs 23.2.0
certifi 2025.6.15
charset-normalizer 3.4.2
click 8.2.1
cog 0.15.8
diffusers 0.34.0.dev0
fastapi 0.115.13
filelock 3.18.0
fsspec 2025.5.1
h11 0.16.0
hf-xet 1.1.4
httptools 0.6.4
huggingface-hub 0.33.0
idna 3.10
importlib-metadata 8.7.0
jinja2 3.1.6
markupsafe 3.0.2
mpmath 1.3.0
networkx 3.5
numpy 2.3.0
nvidia-cublas-cu12 12.6.4.1
nvidia-cuda-cupti-cu12 12.6.80
nvidia-cuda-nvrtc-cu12 12.6.77
nvidia-cuda-runtime-cu12 12.6.77
nvidia-cudnn-cu12 9.5.1.17
nvidia-cufft-cu12 11.3.0.4
nvidia-cufile-cu12 1.11.1.6
nvidia-curand-cu12 10.3.7.77
nvidia-cusolver-cu12 11.7.1.2
nvidia-cusparse-cu12 12.5.4.2
nvidia-cusparselt-cu12 0.6.3
nvidia-nccl-cu12 2.26.2
nvidia-nvjitlink-cu12 12.6.85
nvidia-nvtx-cu12 12.6.77
packaging 25.0
peft 0.15.2
pillow 11.2.1
protobuf 6.31.1
psutil 7.0.0
pydantic 2.11.7
pydantic-core 2.33.2
python-dotenv 1.1.0
pyyaml 6.0.2
regex 2024.11.6
requests 2.32.4
safetensors 0.5.3
sentencepiece 0.2.0
setuptools 80.9.0
sniffio 1.3.1
starlette 0.46.2
structlog 24.4.0
sympy 1.14.0
tokenizers 0.21.1
torch 2.7.0
torchvision 0.22.1
tqdm 4.67.1
transformers 4.52.4
triton 3.3.0
typing-extensions 4.14.0
typing-inspection 0.4.1
urllib3 2.5.0
uvicorn 0.34.3
uvloop 0.21.0
watchfiles 1.1.0
websockets 15.0.1
xformers 0.0.30
zipp 3.23.0
