
Commit 18b6455

Added quantization debug script
1 parent 1581f0c commit 18b6455

1 file changed: +202 -0 lines changed
@@ -0,0 +1,202 @@
# %%
# Import the following libraries
# -----------------------------
# Load the ModelOpt-modified model architecture and weights using Huggingface APIs
# Add argument parsing for dtype selection
import argparse
import re

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from diffusers import FluxPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from modelopt.torch.quantization.utils import export_torch_mode
from torch.export._trace import _export
from transformers import AutoModelForCausalLM

parser = argparse.ArgumentParser(
    description="Run Flux quantization with different dtypes"
)
parser.add_argument(
    "--dtype",
    choices=["fp8", "int8"],
    default="fp8",
    help="Quantization data type to use (fp8 or int8)",
)

args = parser.parse_args()

# Update enabled precisions based on dtype argument
if args.dtype == "fp8":
    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
    ptq_config = mtq.FP8_DEFAULT_CFG
else:  # int8
    enabled_precisions = {torch.int8, torch.float16}
    ptq_config = mtq.INT8_DEFAULT_CFG
    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
print(f"\nUsing {args.dtype} quantization")
# %%
DEVICE = "cuda:0"
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.float16,
)
pipe.transformer = FluxTransformer2DModel(
    num_layers=1, num_single_layers=1, guidance_embeds=True
)

pipe.to(DEVICE).to(torch.float16)
# Store the config and transformer backbone
config = pipe.transformer.config
# global backbone
backbone = pipe.transformer
backbone.eval()


def filter_func(name):
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
    )
    return pattern.match(name) is not None
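
# filter_func above matches embedding, projection, and normalization submodules
# (pos_embed, time_text_embed, context_embedder, norm_out, x_embedder, ...). It is
# passed to mtq.disable_quantizer() further down so the matched layers are excluded
# from quantization; keeping such layers in higher precision is a common way to
# preserve output quality (general practice, not something this script asserts).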


def generate_image(pipe, prompt, image_name):
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(f"{image_name}.png")
    print(f"Image generated using {image_name} model saved as {image_name}.png")


def benchmark(prompt, inference_step, batch_size=1, iterations=1):
    from time import time

    start = time()
    for i in range(iterations):
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=inference_step,
            num_images_per_prompt=batch_size,
        ).images
    end = time()
    print(f"Batch Size: {batch_size}")
    print("Time Elapsed for", iterations, "iterations:", end - start)
    print(
        "Average Latency Per Step:",
        (end - start) / inference_step / iterations / batch_size,
    )
    return image
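
# Example (hypothetical) call to the benchmark helper above, commented out so the
# script does not run extra inference; it would typically be invoked after the
# TensorRT engine has been built at the end of this script:
# benchmark(["A golden retriever"], inference_step=20, batch_size=2, iterations=3)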


# %%
# Quantization


def do_calibrate(
    pipe,
    prompt: str,
) -> None:
    """
    Run calibration steps on the pipeline using the given prompts.
    """
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=20,
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]


def forward_loop(mod):
    # Switch the pipeline's backbone, run calibration
    pipe.transformer = mod
    do_calibrate(
        pipe=pipe,
        prompt="test",
    )


backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)
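
# Optional sanity check: print where quantizers were inserted. This assumes
# print_quant_summary is available in your modelopt.torch.quantization version;
# skip it otherwise.
# mtq.print_quant_summary(backbone)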

batch_size = 2
BATCH = torch.export.Dim("batch", min=1, max=8)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=512)
# These min/max values for the img_id input are the ones recommended by torch dynamo during export of the model.
# To see this recommendation, you can try exporting using min=1, max=4096
IMG_ID = torch.export.Dim("img_id", min=3586, max=4096)
dynamic_shapes = {
    "hidden_states": {0: BATCH},
    "encoder_hidden_states": {0: BATCH, 1: SEQ_LEN},
    "pooled_projections": {0: BATCH},
    "timestep": {0: BATCH},
    "txt_ids": {0: SEQ_LEN},
    "img_ids": {0: IMG_ID},
    "guidance": {0: BATCH},
    "joint_attention_kwargs": {},
    "return_dict": None,
}
# The guidance factor is of type torch.float32
dummy_inputs = {
    "hidden_states": torch.randn((batch_size, 4096, 64), dtype=torch.float16).to(
        DEVICE
    ),
    "encoder_hidden_states": torch.randn(
        (batch_size, 512, 4096), dtype=torch.float16
    ).to(DEVICE),
    "pooled_projections": torch.randn((batch_size, 768), dtype=torch.float16).to(
        DEVICE
    ),
    "timestep": torch.tensor([1.0] * batch_size, dtype=torch.float16).to(DEVICE),
    "txt_ids": torch.randn((512, 3), dtype=torch.float16).to(DEVICE),
    "img_ids": torch.randn((4096, 3), dtype=torch.float16).to(DEVICE),
    "guidance": torch.tensor([1.0] * batch_size, dtype=torch.float32).to(DEVICE),
    "joint_attention_kwargs": {},
    "return_dict": False,
}

# This will create an exported program which is going to be compiled with Torch-TensorRT
with export_torch_mode():
    ep = _export(
        backbone,
        args=(),
        kwargs=dummy_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        allow_complex_guards_as_runtime_asserts=True,
    )
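
# Note: _export (imported from torch.export._trace) is a private torch.export entry
# point. It is presumably used here so that allow_complex_guards_as_runtime_asserts
# can be passed; the public torch.export.export API may not accept that keyword in
# every PyTorch release.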
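
# Notes on the compile options below (my reading of these Torch-TensorRT settings;
# check the torch_tensorrt documentation for your release):
#   min_block_size=1            convert even single-op subgraphs to TensorRT
#   use_python_runtime=True     execute engines via the Python runtime instead of C++
#   immutable_weights=True      weights are baked into the engine (no refitting)
#   offload_module_to_cpu=True  move the source module to CPU during compilation to
#                               reduce GPU memory pressure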
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=dummy_inputs,
    enabled_precisions=enabled_precisions,
    truncate_double=True,
    min_block_size=1,
    debug=False,
    use_python_runtime=True,
    immutable_weights=True,
    offload_module_to_cpu=True,
)


del ep
pipe.transformer = trt_gm
pipe.transformer.config = config


# %%
trt_gm.device = torch.device(DEVICE)
# Generate an image with the compiled Flux pipeline
generate_image(pipe, ["A golden retriever"], "dog_code2")


# For this dummy model, the fp16 engine size is around 1GB, fp32 engine size is around 2GB
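
# Example invocations (the file name is illustrative; use whatever name this script
# is saved under):
#   python flux_quant_debug.py --dtype fp8
#   python flux_quant_debug.py --dtype int8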
