+import argparse
+import re
 import time

 import gradio as gr
+import modelopt.torch.quantization as mtq
 import torch
 import torch_tensorrt
 from diffusers import FluxPipeline

+parser = argparse.ArgumentParser(
+    description="Run Flux quantization with different dtypes"
+)
+
+parser.add_argument(
+    "--dtype",
+    choices=["fp8", "int8", "fp16"],
+    default="int8",
+    help="Select the data type to use (fp8 or int8 or fp16)",
+)
+args = parser.parse_args()
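+# Example invocation (the script name here is only illustrative):
+#   python flux_demo.py --dtype fp8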
+# Update enabled precisions based on dtype argument
+
+if args.dtype == "fp8":
+    enabled_precisions = {torch.float8_e4m3fn, torch.float16}
+    ptq_config = mtq.FP8_DEFAULT_CFG
+elif args.dtype == "int8":
+    enabled_precisions = {torch.int8, torch.float16}
+    ptq_config = mtq.INT8_DEFAULT_CFG
+    ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None
+elif args.dtype == "fp16":
+    enabled_precisions = {torch.float16}
+print(f"\nUsing {args.dtype} quantization")
+
+
 DEVICE = "cuda:0"
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     torch_dtype=torch.float16,
 )
-pipe.to(torch.float16)
+
+
+pipe.to(DEVICE).to(torch.float16)
 backbone = pipe.transformer
+backbone.eval()
+
+
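+# filter_func matches embedding, normalization, and projection layers by name;
+# quantization is disabled for these layers below via mtq.disable_quantizer.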
+def filter_func(name):
+    pattern = re.compile(
+        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
+    )
+    return pattern.match(name) is not None


+def do_calibrate(
+    pipe,
+    prompt: str,
+) -> None:
+    """
+    Run calibration steps on the pipeline using the given prompts.
+    """
+    image = pipe(
+        prompt,
+        output_type="pil",
+        num_inference_steps=20,
+        generator=torch.Generator("cuda").manual_seed(0),
+    ).images[0]
+
+
+def forward_loop(mod):
+    # Switch the pipeline's backbone, run calibration
+    pipe.transformer = mod
+    do_calibrate(
+        pipe=pipe,
+        prompt="test",
+    )
+
+
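+# For fp8/int8, mtq.quantize inserts quantizer modules into the backbone and
+# runs forward_loop once to collect calibration statistics; fp16 skips PTQ.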
+if args.dtype != "fp16":
+    backbone = mtq.quantize(backbone, ptq_config, forward_loop)
+    mtq.disable_quantizer(backbone, filter_func)
+
 batch_size = 2
-BATCH = torch.export.Dim("batch", min=1, max=8)

-# This particular min, max values for img_id input are recommended by torch dynamo during the export of the model.
-# To see this recommendation, you can try exporting using min=1, max=4096
+BATCH = torch.export.Dim("batch", min=1, max=8)
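+# Dim 0 of the inputs below is marked dynamic so one compiled engine can serve
+# batch sizes from 1 to 8 without recompilation.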
 dynamic_shapes = {
     "hidden_states": {0: BATCH},
     "encoder_hidden_states": {0: BATCH},
@@ ... @@
 settings = {
     "strict": False,
     "allow_complex_guards_as_runtime_asserts": True,
-    "enabled_precisions": {torch.float32},
+    "enabled_precisions": enabled_precisions,
     "truncate_double": True,
     "min_block_size": 1,
-    "use_fp32_acc": True,
-    "use_explicit_typing": True,
     "debug": False,
     "use_python_runtime": True,
     "immutable_weights": False,
-    "enable_cuda_graph": True,
+    "offload_module_to_cpu": True,
 }
-backbone.to(DEVICE)
+
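+# MutableTorchTensorRTModule lets the backbone be used like a regular nn.Module;
+# it builds the TensorRT engine lazily and can refit it when weights change
+# (hence immutable_weights=False above).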
 trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
 trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
 pipe.transformer = trt_gm
-pipe.to(DEVICE)


 def generate_image(prompt, inference_step, batch_size=2):