Description
Context
I was trying to run the new Flux model but ran into some sharp bits with the autoquant API
import time
import torchao
from torchao.quantization.quant_api import quantize_, int8_weight_only
from torch import nn
from torch.utils.benchmark import Timer
import torch
from diffusers import FluxPipeline
# Load the FLUX.1-schnell text-to-image pipeline with bfloat16 weights.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# NOTE(review): per the diffusers docs, model CPU offload keeps sub-models on
# the CPU and moves each to the GPU only while it runs; this adds transfer
# time to every timed run and may interact with torch.compile / autoquant
# below — confirm it is intended for the benchmark.
pipe.enable_model_cpu_offload()
def inference(prompt="A cat holding a sign that says hello world", output_path="image.png"):
    """Run one text-to-image generation with the global ``pipe`` and save it.

    Args:
        prompt: Text prompt passed to the pipeline. The default is the prompt
            the benchmark has always used, so ``inference()`` is unchanged.
        output_path: File the generated image is written to.

    Returns:
        The first generated image (the object whose ``.save()`` is called);
        presumably a PIL image with the pipeline's default output type.
    """
    # Generation settings are fixed so successive timing runs are comparable.
    out = pipe(
        prompt=prompt,
        guidance_scale=0.,
        height=768,
        width=1360,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]
    out.save(output_path)
    return out
# Time a single end-to-end generation. time.perf_counter() is a monotonic,
# high-resolution clock meant for measuring elapsed time; time.time() is
# wall-clock time and can jump if the system clock is adjusted.
# NOTE(review): this is the first call to inference(). If the pipeline or its
# transformer has been torch.compile'd, compilation happens lazily on the
# first call and is included here, so run inference() once before timing to
# get steady-state numbers.
tic = time.perf_counter()
inference()
toc = time.perf_counter()
print(f"Original Running time is {toc - tic}")
What I tried
The baseline was compiling the model, which made it about 25% faster:
# Compile only the transformer (an nn.Module), not the whole pipeline.
# torch.compile on a non-Module callable such as the pipeline returns a plain
# function (the "'function' object has no attribute 'named_children'"
# traceback below comes from exactly this). Rebinding `pipe` to that function
# loses the pipeline object, which breaks torchao.autoquant(pipe) and other
# uses of the pipeline's own methods.
# NOTE(review): confirm this composes with enable_model_cpu_offload() above.
pipe.transformer = torch.compile(pipe.transformer)
So at first I tried combining autoquant with torch.compile like this:
# Running over the pipe didn't work
# File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 175, in _replace_with_custom_fn_if_matches_filter
# for name, child in model.named_children():
# AttributeError: 'function' object has no attribute 'named_children'
# pipe = torchao.autoquant(torch.compile(pipe, mode='max-autotune'))
The problem is that torch.compile works with both nn.Modules and functions (given the pipeline, it returned a plain function), while torchao.autoquant works only with nn.Modules.
So instead I tried running autoquant over only the pipe's transformer:
# Running over transformer only
# File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 783, in compute_should_use_set_data
# if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
# TypeError: _has_compatible_shallow_copy_type(): argument 'from' (position 2) must be Tensor, not NoneType
# pipe.transformer = torchao.autoquant(torch.compile(pipe.transformer, mode="max-autotune"))
Another idea was to compile the pipe and autoquant only the transformer:
# return t.type.__tensor_unflatten__(
# File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 177, in __tensor_unflatten__
# return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
# File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 65, in __new__
# return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
# torch._dynamo.exc.InternalTorchDynamoError: _make_wrapper_subclass(): argument 'dtype' must be torch.dtype, not torch._C._TensorMeta
# Apply torchao autoquant to the transformer only. autoquant walks the
# model's children via named_children() (see the quant_api traceback above),
# so it needs an nn.Module, not the pipeline or a compiled function.
# NOTE(review): the dynamo error quoted above comes from autoquant's
# __tensor_unflatten__, where `dtype` receives a tensor-subclass type
# (torch._C._TensorMeta) instead of a torch.dtype. That looks like a
# torchao-side bug in reconstructing the subclass during tracing — confirm.
pipe.transformer = torchao.autoquant(pipe.transformer)
Dependencies
pip freeze
accelerate==0.33.0
aiofiles==23.2.1
altair==5.4.0
annotated-types==0.7.0
anyio==4.4.0
attrs==24.2.0
blinker==1.8.2
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.1
cycler==0.12.1
diffusers==0.30.0
einops==0.8.0
exceptiongroup==1.2.2
fastapi==0.112.0
ffmpy==0.4.0
filelock==3.13.1
fire==0.6.0
-e git+https://github.com/black-forest-labs/flux@c23ae247225daba30fbd56058d247cc1b1fc20a3#egg=flux
fonttools==4.53.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
gradio==4.41.0
gradio_client==1.3.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.24.5
idna==3.7
importlib_metadata==8.2.0
importlib_resources==6.4.0
invisible-watermark==0.2.0
Jinja2==3.1.4
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1.post1
mdurl==0.1.2
mpmath==1.3.0
narwhals==1.3.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
opencv-python==4.10.0.84
orjson==3.10.7
packaging==24.1
pandas==2.2.2
pillow==10.4.0
protobuf==5.27.3
psutil==6.0.0
pyarrow==17.0.0
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytorch-triton==3.0.0+dedb7bdf33
pytz==2024.1
PyWavelets==1.6.0
PyYAML==6.0.2
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
rpds-py==0.20.0
ruff==0.5.7
safetensors==0.4.4
semantic-version==2.10.0
sentencepiece==0.2.0
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
starlette==0.37.2
streamlit==1.37.1
streamlit-keyup==0.2.4
sympy==1.13.1
tenacity==8.5.0
termcolor==2.4.0
tokenizers==0.19.1
toml==0.10.2
tomlkit==0.12.0
torch==2.5.0.dev20240811+cu121
torchao @ file:///home/marksaroufim/ao
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
transformers==4.44.0
triton==3.0.0
typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.5
watchdog==4.0.2
websockets==12.0
zipp==3.20.0