    print(torch.cuda.get_device_properties(i).name)

DEVICE = "cuda:0"
- # pipe = FluxPipeline.from_pretrained(
- #     "black-forest-labs/FLUX.1-dev",
- #     torch_dtype=torch.float32,
- # )
- pipe.to(DEVICE).to(torch.float32)
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev",
+     torch_dtype=torch.bfloat16,
+ )
+ pipe.to(DEVICE).to(torch.bfloat16)
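# pipe.transformer is the Flux denoising transformer; it is the part handed to Torch-TensorRT below.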
backbone = pipe.transformer


    "debug": False,
    "use_python_runtime": True,
    "immutable_weights": False,
47+ "offload_module_to_cpu" : True ,
4748}
4849
4950
50- def generate_image (prompt , inference_step , batch_size = 2 , benchmark = False , iterations = 1 ):
51+ def generate_image (prompt , inference_step , batch_size = 1 , benchmark = False , iterations = 1 ):
5152
5253 start = time ()
5354 for i in range (iterations ):
@@ -62,35 +63,37 @@ def generate_image(prompt, inference_step, batch_size=2, benchmark=False, iterations=1):
6263 print ("Time Elapse for" , iterations , "iterations:" , end - start )
        print(
            "Average Latency Per Step:",
-             (end - start) / inference_step / iterations / batchsize,
+             (end - start) / inference_step / iterations / batch_size,
        )
    return image


- generate_image(["Test"], 2)
- print("Benchmark Original PyTorch Module Latency (float32)")
- generate_image(["Test"], 50, benchmark=True, iterations=3)
+ pipe.to(torch.bfloat16)
+ torch.cuda.empty_cache()
+ # Warmup
+ generate_image(["Test"], 20)
+ print("Benchmark Original PyTorch Module Latency (bfloat16)")
+ generate_image(["Test"], 20, benchmark=True, iterations=3)

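# Cast the pipeline to float16 and benchmark the eager baseline again for comparison.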
pipe.to(torch.float16)
print("Benchmark Original PyTorch Module Latency (float16)")
- generate_image(["Test"], 50, benchmark=True, iterations=3)
-
+ generate_image(["Test"], 20, benchmark=True, iterations=3)

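# Wrap the transformer in a MutableTorchTensorRTModule and register the expected dynamic shape range so inputs inside that range do not trigger a recompilation.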
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm

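# Compilation is deferred until the first forward call, so the first generate_image call below measures the engine build time.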
start = time()
- generate_image(["Test"], 2)
+ generate_image(["Test"], 2, batch_size=2)
end = time()
8689print ("Time Elapse compilation:" , end - start )
print()
print("Benchmark TRT Accelerated Latency")
- generate_image(["Test"], 50, benchmark=True, iterations=3)
+ generate_image(["Test"], 20, benchmark=True, iterations=3)
torch.cuda.empty_cache()


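# Run the same benchmark with CUDA Graphs enabled to reduce per-step launch overhead.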
with torch_tensorrt.runtime.enable_cudagraphs(trt_gm):
    generate_image(["Test"], 2)
    print("Benchmark TRT Accelerated Latency with CUDA Graphs")
-     generate_image(["Test"], 50, benchmark=True, iterations=3)
+     generate_image(["Test"], 20, benchmark=True, iterations=3)