autoquant api sharp edges #657

Open
@msaroufim

Description

Context

I was trying to run the new Flux model and ran into some sharp edges in the autoquant API.

import time

import torch
import torchao
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

def inference():
    prompt = "A cat holding a sign that says hello world"
    out = pipe(
        prompt=prompt,
        guidance_scale=0.,
        height=768,
        width=1360,
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]
    out.save("image.png")

tic = time.perf_counter()
inference()  # single cold call: may include one-time warm-up costs
toc = time.perf_counter()

print(f"Original running time: {toc - tic:.2f} s")
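The measurement above times a single cold call, so it can include one-time warm-up costs (and, once torch.compile is added, compilation time). A minimal sketch of a steadier measurement: run untimed warm-up calls, then report the median of several timed calls. It assumes the timed function blocks until the work is finished, which `inference()` does since it produces a PIL image on the host. `benchmark` is a hypothetical helper, not part of torch or torchao.

```python
import statistics
import time
from typing import Callable


def benchmark(fn: Callable[[], object], warmup: int = 1, iters: int = 3) -> float:
    """Return the median wall-clock time of fn() in seconds (hypothetical helper).

    Warm-up calls run first and are not timed, so one-time costs such as
    torch.compile tracing do not skew the reported number.
    """
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()  # assumes fn blocks until its work is done
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


if __name__ == "__main__":
    # Stand-in workload; with the script above this would be benchmark(inference).
    print(f"median: {benchmark(lambda: sum(range(100_000))):.4f} s")
```

The median is used instead of the mean so that a single slow outlier (for example a recompilation) does not dominate the result.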

What I tried

As a baseline, compiling the pipeline made it about 25% faster:

pipe = torch.compile(pipe)
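`torch.compile(pipe)` is accepted because torch.compile takes any callable, but it rebinds `pipe` to a compiled callable rather than the pipeline object. A pattern commonly used in diffusers workflows is to compile only the transformer submodule and leave the pipeline object intact. The sketch below uses a hypothetical `ToyPipeline` in place of `FluxPipeline` and `backend="eager"` so it runs on CPU without Triton; with Flux the equivalent would be compiling `pipe.transformer` (its speedup relative to the 25% above has not been measured here).

```python
import torch
from torch import nn


class ToyPipeline:
    """Hypothetical stand-in for FluxPipeline: a plain object holding nn.Modules."""

    def __init__(self) -> None:
        self.transformer = nn.Linear(16, 16)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # A real pipeline runs several denoising steps; one call is enough here.
        return self.transformer(x)


pipe = ToyPipeline()
# Compile only the heavy submodule; pipe itself stays a ToyPipeline object,
# and pipe.transformer stays an nn.Module (torch.compile wraps modules in one).
pipe.transformer = torch.compile(pipe.transformer, backend="eager")
out = pipe(torch.randn(2, 16))
```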

My first attempt was to autoquant the compiled pipeline:

# Running autoquant over the compiled pipe didn't work
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/quant_api.py", line 175, in _replace_with_custom_fn_if_matches_filter
#     for name, child in model.named_children():
# AttributeError: 'function' object has no attribute 'named_children'
# pipe = torchao.autoquant(torch.compile(pipe, mode='max-autotune'))

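The problem is that torch.compile accepts both nn.Modules and plain callables (for a callable such as the pipeline it returns a function), whereas torchao.autoquant only works with nn.Modules: it walks the model with named_children(), which a function does not have.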
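The distinction can be reproduced on toy inputs: compiling a plain function gives back a function with no module tree, while compiling an nn.Module gives back a wrapper that is still an nn.Module. `backend="eager"` is used only so the sketch runs on CPU. `require_module` is a hypothetical guard, not part of torchao; it sketches the kind of early, readable error that would make this sharp edge easier to diagnose than an AttributeError deep inside `_replace_with_custom_fn_if_matches_filter`.

```python
import torch
from torch import nn


def require_module(obj: object, api_name: str = "autoquant") -> nn.Module:
    """Raise a readable TypeError if obj is not an nn.Module (hypothetical guard)."""
    if not isinstance(obj, nn.Module):
        raise TypeError(
            f"{api_name} expects an nn.Module, got {type(obj).__name__}; "
            "pass a submodule such as pipe.transformer instead"
        )
    return obj


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


# Compiling a plain function returns a function: there is no named_children().
compiled_fn = torch.compile(double, backend="eager")
print(type(compiled_fn).__name__, hasattr(compiled_fn, "named_children"))

# Compiling an nn.Module returns a wrapper that is still an nn.Module.
compiled_mod = torch.compile(nn.Linear(4, 4), backend="eager")
print(isinstance(compiled_mod, nn.Module))
```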

So instead I tried autoquant on the pipe's transformer only, which is an nn.Module:

# Running over the transformer only
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torch/nn/modules/module.py", line 783, in compute_should_use_set_data
#     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
# TypeError: _has_compatible_shallow_copy_type(): argument 'from' (position 2) must be Tensor, not NoneType
# pipe.transformer = torchao.autoquant(torch.compile(pipe.transformer, mode="max-autotune"))
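Compiling first and then calling autoquant means autoquant receives torch.compile's wrapper module rather than the transformer itself. The toy example below (an `nn.Sequential` standing in for `pipe.transformer`) shows that the wrapper's module tree has the original model nested under the private `_orig_mod` attribute. `unwrap_compiled` is a hypothetical helper that relies on that private attribute, which may change between PyTorch releases. This only illustrates what autoquant walks; it does not establish that the wrapper is what causes the NoneType error above.

```python
import torch
from torch import nn

# Toy stand-in for pipe.transformer (hypothetical; the real model is far larger).
transformer = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
compiled = torch.compile(transformer, backend="eager")

# The compiled wrapper is an nn.Module whose child is the original model,
# registered under the private name "_orig_mod".
print([name for name, _ in compiled.named_children()])


def unwrap_compiled(module: nn.Module) -> nn.Module:
    """Return the eager module behind a torch.compile wrapper, or module itself."""
    return getattr(module, "_orig_mod", module)
```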

Another idea was to compile the pipe and autoquant only the transformer:

#   return t.type.__tensor_unflatten__(
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 177, in __tensor_unflatten__
#     return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
#   File "/home/marksaroufim/.conda/envs/ao/lib/python3.10/site-packages/torchao/quantization/autoquant.py", line 65, in __new__
#     return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
# torch._dynamo.exc.InternalTorchDynamoError: _make_wrapper_subclass(): argument 'dtype' must be torch.dtype, not torch._C._TensorMeta

pipe.transformer = torchao.autoquant(pipe.transformer)

Dependencies

pip freeze

accelerate==0.33.0
aiofiles==23.2.1
altair==5.4.0
annotated-types==0.7.0
anyio==4.4.0
attrs==24.2.0
blinker==1.8.2
cachetools==5.4.0
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.1
cycler==0.12.1
diffusers==0.30.0
einops==0.8.0
exceptiongroup==1.2.2
fastapi==0.112.0
ffmpy==0.4.0
filelock==3.13.1
fire==0.6.0
-e git+https://github.com/black-forest-labs/flux@c23ae247225daba30fbd56058d247cc1b1fc20a3#egg=flux
fonttools==4.53.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
gradio==4.41.0
gradio_client==1.3.0
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.24.5
idna==3.7
importlib_metadata==8.2.0
importlib_resources==6.4.0
invisible-watermark==0.2.0
Jinja2==3.1.4
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1.post1
mdurl==0.1.2
mpmath==1.3.0
narwhals==1.3.0
networkx==3.3
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
opencv-python==4.10.0.84
orjson==3.10.7
packaging==24.1
pandas==2.2.2
pillow==10.4.0
protobuf==5.27.3
psutil==6.0.0
pyarrow==17.0.0
pydantic==2.8.2
pydantic_core==2.20.1
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-multipart==0.0.9
pytorch-triton==3.0.0+dedb7bdf33
pytz==2024.1
PyWavelets==1.6.0
PyYAML==6.0.2
referencing==0.35.1
regex==2024.7.24
requests==2.32.3
rich==13.7.1
rpds-py==0.20.0
ruff==0.5.7
safetensors==0.4.4
semantic-version==2.10.0
sentencepiece==0.2.0
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
starlette==0.37.2
streamlit==1.37.1
streamlit-keyup==0.2.4
sympy==1.13.1
tenacity==8.5.0
termcolor==2.4.0
tokenizers==0.19.1
toml==0.10.2
tomlkit==0.12.0
torch==2.5.0.dev20240811+cu121
torchao @ file:///home/marksaroufim/ao
torchvision==0.19.0
tornado==6.4.1
tqdm==4.66.5
transformers==4.44.0
triton==3.0.0
typer==0.12.3
typing_extensions==4.12.2
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.5
watchdog==4.0.2
websockets==12.0
zipp==3.20.0
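The full freeze is long; for a quick report, the versions that matter most for this issue can be pulled with the standard library (PyTorch's own `python -m torch.utils.collect_env` is another option). The package list below is an illustrative choice, and `package_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages most relevant to this report (an illustrative selection).
PACKAGES = ["torch", "torchao", "diffusers", "transformers", "accelerate", "triton"]


def package_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in package_versions(PACKAGES).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```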
