 import torchao
 import torch.nn as nn
 import numpy as np
+import pandas as pd
 import matplotlib.pyplot as plt
 from torch.profiler import profile, record_function, ProfilerActivity
 from torchao.quantization.quant_api import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
 import copy
 from utils import (
     get_name_to_shapes_iter,
     get_llm_mm_shapes,
+    get_diffusion_mm_shapes,
 )
 import tqdm
 from tabulate import tabulate
@@ -79,24 +81,14 @@ def get_gpu_kernel_times(profiler_chrome_trace, gpu_op_name):
         gpu_overhead_time += event[1]
     return gpu_op_time, gpu_overhead_time
 
-def run_gemm_benchmarks(name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only):
-    # Dictionary to store performance data
-    performance_data = {
-        'Input Size': [],
-        'float8 Op Kernel Times (ms)': [],
-        'bf16 Op Kernel Times (ms)': [],
-        'float8 Overhead Kernel Times (ms)': [],
-        'bf16 Overhead Kernel Times (ms)': [],
-        'float8 Total Kernel Times (ms)': [],
-        'bf16 Total Kernel Times (ms)': [],
-    }
+def run_gemm_benchmarks(performance_data, name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_dtype=torch.bfloat16, quantization_technique=float8_weight_only, batch_size=1):
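+    # Appends one row per (m, k, n) shape to the caller-provided performance_data dict and returns it,
+    # so results accumulate across repeated calls (e.g. one call per batch size).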
     # Run benchmarks for each input size
     for idx, (name, (m, k, n)) in enumerate(tqdm.tqdm(name_to_shapes)):
-        print(f"Profiling model with input size: {m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")
+        print(f"Profiling model with input size: {batch_size, m, k, n} for quantization technique: {quantization_technique}, dtype: {float8_dtype} vs {other_dtype}")
 
         # Initialize the model with the specified dimensions
-        model = ToyLinearModel(m, k, n).eval().to(device)
-        example_inputs = model.example_inputs(m)
+        model = ToyLinearModel(batch_size * m, k, n).eval().to(device)
+        example_inputs = model.example_inputs(batch_size * m)
         model_bf16 = copy.deepcopy(model).to(device)  # Copy the model to bf16
         model_bf16 = torch.compile(model_bf16)  # Compile the model
 
@@ -117,51 +109,96 @@ def run_gemm_benchmarks(name_to_shapes, float8_dtype=torch.float8_e4m3fn, other_
         float8_gpu_op_time, float8_gpu_overhead_time = get_gpu_kernel_times(prof_float8, 'gemm')
         bf16_gpu_op_time, bf16_gpu_overhead_time = get_gpu_kernel_times(prof_bf16, 'gemm')
 
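+        # Split GPU time into kernels matching the 'gemm' op name ("op") vs. all other kernels ("overhead")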
-        # # Print profiling details
-        # print(f"bfloat16_gpu_overhead_time: {bf16_gpu_overhead_time} gpu_op_time: {bf16_gpu_op_time}")
-        # print(f"float8_gpu_overhead_time: {float8_gpu_overhead_time} float8_gpu_op_time: {float8_gpu_op_time}")
-
         # Add the performance data to the dictionary
         # time/1000 -> Convert from microseconds to milliseconds
-        performance_data['Input Size'].append(f"{tuple(example_inputs[0].shape)}")
+        performance_data['Input Size'].append(f"{(m, k, n)}")
         performance_data['float8 Total Kernel Times (ms)'].append((float8_gpu_op_time + float8_gpu_overhead_time) / 1000)
         performance_data['bf16 Total Kernel Times (ms)'].append((bf16_gpu_op_time + bf16_gpu_overhead_time) / 1000)
         performance_data['float8 Op Kernel Times (ms)'].append(float8_gpu_op_time / 1000)
         performance_data['bf16 Op Kernel Times (ms)'].append(bf16_gpu_op_time / 1000)
         performance_data['float8 Overhead Kernel Times (ms)'].append(float8_gpu_overhead_time / 1000)
         performance_data['bf16 Overhead Kernel Times (ms)'].append(bf16_gpu_overhead_time / 1000)
+        performance_data['Batch Size'].append(batch_size)
 
     return performance_data
 
 
-def plot_performance_data(performance_data):
+def plot_performance_data(performance_data, x_col, plot_name='model_evaluation_gpu_kernel_performance'):
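+    # x_col selects the x-axis column ('Batch Size' or 'Input Size'); plot_name doubles as the title prefix and output filename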
     # Plotting the results
     plt.figure(figsize=(10, 6))
-    plt.plot(performance_data['Input Size'], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
-    plt.plot(performance_data['Input Size'], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
-    plt.xlabel('Batch Size')
+    plt.plot(performance_data[x_col], performance_data['float8 Total Kernel Times (ms)'], marker='o', label='float8')
+    plt.plot(performance_data[x_col], performance_data['bf16 Total Kernel Times (ms)'], marker='s', label='bf16')
+    plt.xlabel(x_col)
     plt.ylabel('Kernel Time (ms)')
-    plt.title('Model Evaluation GPU Kernel Performance: float8 vs bf16')
+    plt.title(plot_name + ' performance: float8 vs bf16')
     plt.legend()
     plt.grid(True)
-    plt.savefig('model_evaluation_gpu_kernel_performance.png')
+    plt.savefig(plot_name + '.png')
 
 
-if __name__ == '__main__':
-
-    # llm_model_names = ["bert-base-uncased", "gpt2", "t5-small", "meta-llama/Llama-3.2-3B", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"]
-    # name_to_shapes = get_name_to_shapes_iter("llama", None, None, None)
-    name_to_shapes = get_llm_mm_shapes("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", None, None, None)
-
-    print('Shapes:', name_to_shapes)
-    float8_dtype = torch.float8_e4m3fn  # Change to the float8 dtype you want to use
-    bf16_dtype = torch.bfloat16  # Change to the comparing dtype you want to use
+def plot_llm_performance_data_hf_model(model_name, quantization_dtype=torch.float8_e4m3fn, quantization_technique=float8_weight_only, baseline_dtype=torch.bfloat16, batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]):
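+    # Sweeps batch sizes at a fixed seq_len and plots float8 vs bf16 GPU kernel time for each GEMM shape of the model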
+    # Dictionary to store performance data
+    performance_data = {
+        'Input Size': [],
+        'float8 Op Kernel Times (ms)': [],
+        'bf16 Op Kernel Times (ms)': [],
+        'float8 Overhead Kernel Times (ms)': [],
+        'bf16 Overhead Kernel Times (ms)': [],
+        'float8 Total Kernel Times (ms)': [],
+        'bf16 Total Kernel Times (ms)': [],
+        'Batch Size': [],
+    }
+    name_to_shapes = get_llm_mm_shapes(model_name, seq_len=128)
+    print(f'For model: {model_name}, shapes: {name_to_shapes}')
-    quantization_technique = float8_weight_only  # Change to the quantization technique you want to use
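+    # Accumulate one row per (shape, batch size) into a single dict so rows can be grouped per shape below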
+    for batch_size in batch_sizes:
+        performance_data = run_gemm_benchmarks(
+            performance_data=performance_data,
+            name_to_shapes=name_to_shapes,
+            float8_dtype=quantization_dtype,
+            other_dtype=baseline_dtype,
+            quantization_technique=quantization_technique,
+            batch_size=batch_size,
+        )
+    df_performance_data = pd.DataFrame(performance_data)
+    df_grouped = df_performance_data.groupby('Input Size')
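+    # One plot per GEMM shape: kernel time vs. batch size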
+    for name, group in df_grouped:
+        print(f"Group: {name}")
+        plot_performance_data(group, 'Batch Size', plot_name=f'{model_name.split("/")[-1]}_input_size_{name}_quant_{quantization_technique.__name__}')
+
 
+if __name__ == '__main__':
+
+    # Run benchmarks for LLMs
+    llm_model_names = ["bert-base-uncased", "gpt2", "t5-small", "meta-llama/Llama-3.2-3B", "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF"]
+    for model_name in llm_model_names:
+        plot_llm_performance_data_hf_model(
+            model_name,
+            quantization_dtype=torch.float8_e4m3fn,
+            quantization_technique=float8_weight_only,
+            baseline_dtype=torch.bfloat16,
+        )
+
+    # Run benchmarks for different matrix shapes (m, k, n)
+    name_to_shapes = get_name_to_shapes_iter("square", None, None, None)
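+    # These shapes are profiled at the default batch_size=1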
+    # Dictionary to store performance data
+    performance_data = {
+        'Input Size': [],
+        'float8 Op Kernel Times (ms)': [],
+        'bf16 Op Kernel Times (ms)': [],
+        'float8 Overhead Kernel Times (ms)': [],
+        'bf16 Overhead Kernel Times (ms)': [],
+        'float8 Total Kernel Times (ms)': [],
+        'bf16 Total Kernel Times (ms)': [],
+        'Batch Size': [],
+    }
     performance_data = run_gemm_benchmarks(
+        performance_data=performance_data,
         name_to_shapes=name_to_shapes,
-        float8_dtype=float8_dtype,
-        other_dtype=bf16_dtype,
-        quantization_technique=quantization_technique
-    )
-    print('Performance data: \n', tabulate(performance_data, headers=performance_data.keys()))
+        float8_dtype=torch.float8_e4m3fn,
+        other_dtype=torch.bfloat16,
+        quantization_technique=float8_weight_only,
+    )
+    plot_performance_data(performance_data, 'Input Size', plot_name='different_matrix_shapes')
+    # print('Performance data: \n', tabulate(performance_data, headers=performance_data.keys()))