Commit 1581f0c

Use MutableTorchTensorRTModule to do quantization
1 parent 3dcf128 commit 1581f0c

File tree: 2 files changed, +106 −26 lines changed

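In outline, this commit lets a modelopt-quantized backbone be handed to MutableTorchTensorRTModule for compilation. A minimal sketch of that pattern, assuming a toy module and made-up inputs rather than the FLUX backbone used in the demo below:

import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt

# Toy stand-in for the FLUX transformer backbone (assumption, not from this diff).
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).cuda().eval()
example = torch.randn(8, 64, device="cuda")


def forward_loop(mod):
    # Calibration pass: run representative inputs through the module being quantized.
    mod(example)


# Post-training quantization with modelopt, then compile the quantized module.
quantized = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
trt_module = torch_tensorrt.MutableTorchTensorRTModule(
    quantized,
    enabled_precisions={torch.float8_e4m3fn, torch.float16},
    immutable_weights=False,
    use_python_runtime=True,
)
out = trt_module(example)  # compilation happens lazily on the first call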

examples/apps/flux-demo.py

+71 −10
@@ -1,24 +1,88 @@
+import argparse
+import re
 import time
 
 import gradio as gr
+import modelopt.torch.quantization as mtq
 import torch
 import torch_tensorrt
 from diffusers import FluxPipeline
 
+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8", "fp16"],
+    default="int8",
+    help="Select the data type to use (fp8 or int8 or fp16)",
+)
+args = parser.parse_args()
+# Update enabled precisions based on dtype argument
+
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+elif args.dtype == "int8":
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
+elif args.dtype == "fp16":
+    enabled_precisions = {torch.float16}
+print(f"\nUsing {args.dtype} quantization")
+
+
 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.float16,
 )
-pipe.to(torch.float16)
+
+
+pipe.to(DEVICE).to(torch.float16)
 backbone = pipe.transformer
+backbone.eval()
+
+
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    return pattern.match(name) is not None
 
 
+def do_calibrate(
+    pipe,
+    prompt: str,
+) -> None:
+    """
+    Run calibration steps on the pipeline using the given prompts.
+    """
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(0),
+    ).images[0]
+
+
+def forward_loop(mod):
+    # Switch the pipeline's backbone, run calibration
+    pipe.transformer = mod
+    do_calibrate(
+        pipe=pipe,
+        prompt="test",
+    )
+
+
+if args.dtype != "fp16":
+    backbone = mtq.quantize(backbone, ptq_config, forward_loop)
+    mtq.disable_quantizer(backbone, filter_func)
+
 batch_size = 2
-BATCH = torch.export.Dim("batch", min=1, max=8)
 
-# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
-# To see this recommendation, you can try exporting using min=1, max=4096
+BATCH = torch.export.Dim("batch", min=1, max=8)
 dynamic_shapes = {
     "hidden_states": {0: BATCH},
     "encoder_hidden_states": {0: BATCH},
@@ -34,21 +98,18 @@
 settings = {
     "strict": False,
     "allow_complex_guards_as_runtime_asserts": True,
-    "enabled_precisions": {torch.float32},
+    "enabled_precisions": enabled_precisions,
     "truncate_double": True,
     "min_block_size": 1,
-    "use_fp32_acc": True,
-    "use_explicit_typing": True,
     "debug": False,
     "use_python_runtime": True,
     "immutable_weights": False,
-    "enable_cuda_graph": True,
+    "offload_module_to_cpu": True,
 }
-backbone.to(DEVICE)
+
 trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
 trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
 pipe.transformer = trt_gm
-pipe.to(DEVICE)
 
 
 def generate_image(prompt, inference_step, batch_size=2):
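For context (not part of the hunk), once pipe.transformer is replaced with the compiled module the pipeline is driven as usual; a hedged sketch of a call that stays within the BATCH dim's min=1, max=8 range:

prompts = ["test"] * batch_size  # batch_size = 2, within BATCH's min=1, max=8
images = pipe(
    prompts,
    output_type="pil",
    num_inference_steps=20,
).images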

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+35 −16
@@ -306,23 +306,42 @@ def refit_gm(self) -> None:
         torch.cuda.empty_cache()
 
     def get_exported_program(self) -> torch.export.ExportedProgram:
-        if self.allow_complex_guards_as_runtime_asserts:
-            return _export(
-                self.original_model,
-                self.arg_inputs,
-                kwargs=self.kwarg_inputs,
-                dynamic_shapes=self._get_total_dynamic_shapes(),
-                strict=self.strict,
-                allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts,
-            )
+
+        def export_fn() -> torch.export.ExportedProgram:
+            if self.allow_complex_guards_as_runtime_asserts:
+                return _export(
+                    self.original_model,
+                    self.arg_inputs,
+                    kwargs=self.kwarg_inputs,
+                    dynamic_shapes=self._get_total_dynamic_shapes(),
+                    strict=self.strict,
+                    allow_complex_guards_as_runtime_asserts=self.allow_complex_guards_as_runtime_asserts,
+                )
+            else:
+                return torch.export.export(
+                    self.original_model,
+                    self.arg_inputs,
+                    kwargs=self.kwarg_inputs,
+                    dynamic_shapes=self._get_total_dynamic_shapes(),
+                    strict=self.strict,
+                )
+
+        if (
+            torch.float8_e4m3fn in self.additional_settings["enabled_precisions"]
+            or torch.int8 in self.additional_settings["enabled_precisions"]
+        ):
+            try:
+                from modelopt.torch.quantization.utils import export_torch_mode
+
+                assert torch.ops.tensorrt.quantize_op.default
+            except Exception as e:
+                logger.warning(
+                    "Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
+                )
+            with export_torch_mode():
+                return export_fn()
         else:
-            return torch.export.export(
-                self.original_model,
-                self.arg_inputs,
-                kwargs=self.kwarg_inputs,
-                dynamic_shapes=self._get_total_dynamic_shapes(),
-                strict=self.strict,
-            )
+            return export_fn()
 
     def compile(self) -> None:
         """
