Commit 25034e5

Add profiling to benchmarking (#2032)
1 parent 9a56a1d commit 25034e5

File tree

6 files changed: +356 -158 lines changed

benchmarks/microbenchmarks/benchmark_inference.py

Lines changed: 75 additions & 65 deletions
@@ -15,6 +15,9 @@

 import torch

+from benchmarks.microbenchmarks.profiler import (
+    generate_model_profile,
+)
 from benchmarks.microbenchmarks.utils import (
     BenchmarkConfig,
     BenchmarkResult,
@@ -29,70 +32,77 @@

 def run(config: BenchmarkConfig) -> BenchmarkResult:
     """Run inference benchmarks"""
-    clean_caches()  # Clean caches
-
-    # Create output directory if it doesn't exist
-    Path(config.output_dir).mkdir(parents=True, exist_ok=True)
-
-    base_model, input_data = create_model_and_input(
-        config.model_type,
-        config.m,
-        config.k,
-        config.n,
-        high_precision_dtype=config.high_precision_dtype,
-        device=config.device,
-    )
-
-    # Use quantize_ to apply each quantization function to the model
-    m_copy = deepcopy(base_model).eval().to(config.device)
-    ao_base_config = string_to_config(
-        config.quantization,
-        config.sparsity,
-        high_precision_dtype=config.high_precision_dtype,
-    )
-
-    # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
-    is_cuda = config.device == "cuda" and torch.cuda.is_available()
-
-    if config.sparsity is not None and (
-        config.quantization is None or "baseline" in config.quantization
-    ):
-        if is_cuda:
-            print(f"Applying {config.sparsity} sparsity to model")
-            sparsify_(m_copy, ao_base_config)
+    try:
+        clean_caches()  # Clean caches
+
+        # Create output directory if it doesn't exist
+        Path(config.output_dir).mkdir(parents=True, exist_ok=True)
+
+        base_model, input_data = create_model_and_input(
+            config.model_type,
+            config.m,
+            config.k,
+            config.n,
+            high_precision_dtype=config.high_precision_dtype,
+            device=config.device,
+        )
+
+        # Use quantize_ to apply each quantization function to the model
+        m_copy = deepcopy(base_model).eval().to(config.device)
+        ao_base_config = string_to_config(
+            config.quantization,
+            config.sparsity,
+            high_precision_dtype=config.high_precision_dtype,
+        )
+
+        # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
+        is_cuda = config.device == "cuda" and torch.cuda.is_available()
+
+        if config.sparsity is not None and (
+            config.quantization is None or "baseline" in config.quantization
+        ):
+            if is_cuda:
+                print(f"Applying {config.sparsity} sparsity to model")
+                sparsify_(m_copy, ao_base_config)
+            else:
+                print(
+                    f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+                )
+        elif config.sparsity is None and (
+            config.quantization is None or "baseline" in config.quantization
+        ):
+            pass  # No quantization or sparsity specified, do nothing
         else:
-            print(
-                f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+            print("Quantizing model....")
+            quantize_(m_copy, ao_base_config)
+
+        if config.use_torch_compile:
+            print("Compiling model....")
+            m_copy = torch.compile(
+                m_copy, mode=config.torch_compile_mode, fullgraph=True
             )
-    elif config.sparsity is None and (
-        config.quantization is None or "baseline" in config.quantization
-    ):
-        pass  # No quantization or sparsity specified, do nothing
-    else:
-        print("Quantizing model....")
-        quantize_(m_copy, ao_base_config)
-
-    if config.use_torch_compile:
-        print("Compiling model....")
-        m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
-
-    # Run benchmarks
-    result = BenchmarkResult(config=config)
-
-    # Benchmark time to run an inference call for quantized model
-    result.model_inference_time_in_ms = model_inference_time_in_ms(
-        model=m_copy, input_data=input_data
-    )
-
-    # TODO: Benchmark time using profiler
-    # Profile dtype model evaluation
-    # prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
-    # prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json")  # Save profiling details
-
-    # TODO: Benchmark gemm time using cuda graph
-    # gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)
-
-    # TODO: Benchmark op with cuda graph
-    # time = benchmark_op_with_cuda_graph(op, args)
-
-    return result
+
+        # Run benchmarks
+        result = BenchmarkResult(config=config)
+        # Store result in model for memory profiling
+        m_copy._benchmark_result = result
+
+        # Benchmark time to run an inference call for quantized model
+        result.model_inference_time_in_ms = model_inference_time_in_ms(
+            model=m_copy, input_data=input_data
+        )
+
+        # Run profiler if enabled
+        if config.enable_profiler:
+            print("Running profiler...")
+            try:
+                result.profiler_json_path = generate_model_profile(
+                    m_copy, input_data, config.profiler_file_name
+                )
+            except Exception as e:
+                print(f"Error running profiler for {config.name} with error: {e}")

+        return result
+    except Exception as e:
+        print(f"Error in benchmark run: {config.name} with error: {e}")
+        return None
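Note on the compile step above: torch.compile with mode="max-autotune" and fullgraph=True is standard PyTorch 2.x API. A minimal standalone sketch of that call follows; the toy module and shapes are illustrative only, not taken from the benchmark suite:

import torch

# Toy stand-in for the benchmarked model; autotuning mainly pays off on CUDA,
# but the call also runs on CPU.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
).eval()

# Same flags as the diff: fullgraph=True raises if the model cannot be
# captured as a single graph, instead of silently falling back to eager.
compiled = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.no_grad():
    out = compiled(torch.randn(16, 1024))  # first call triggers compilation
print(out.shape)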

benchmarks/microbenchmarks/benchmark_runner.py

Lines changed: 11 additions & 8 deletions
@@ -164,16 +164,19 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
                 f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
             )
             result = run_inference(config)  # Pass the config object directly
-            results.append(result)
-        except Exception:
-            print(f"Error running benchmark {config.name}")
+            if result is not None:  # Only add successful results
+                results.append(result)
+        except Exception as e:
+            print(f"Error running benchmark {config.name} with error: {e}")
             continue

-    # Add results to csv
-    generate_results_csv(results, configs[0].output_dir)
-
-    # Print results
-    print_results(results)
+    # Add results to csv if there are any
+    if results:
+        generate_results_csv(results, configs[0].output_dir)
+        # Print results
+        print_results(results)
+    else:
+        print("No benchmark results were collected. All benchmarks failed.")

     # TODO: Process results: Speedups:
     #   1. For different shapes for same model and quantization
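The runner change boils down to: keep only non-None results, and aggregate only if anything survived. A self-contained sketch of that pattern, where run_one and write_csv are hypothetical stand-ins for run_inference and generate_results_csv:

from typing import List, Optional

def run_one(name: str) -> Optional[dict]:
    # Stand-in for run_inference: returns None when a run fails.
    return None if name == "bad" else {"name": name, "ms": 1.23}

def write_csv(rows: List[dict]) -> None:
    print(f"writing {len(rows)} rows")  # stand-in for generate_results_csv

results = []
for name in ["a", "bad", "b"]:
    result = run_one(name)
    if result is not None:  # only add successful results
        results.append(result)

if results:
    write_csv(results)
else:
    print("No benchmark results were collected. All benchmarks failed.")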
benchmarks/microbenchmarks/profiler.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+
+import torch
+from torch.profiler import ProfilerActivity
+
+
+def generate_model_profile(model, input_data, profile_file_path):
+    """Function to benchmark model evaluation with profiling.
+
+    Args:
+        model: The model to profile
+        input_data: Input data for the model
+        profile_file_path: Path to save the profiler output
+
+    Returns:
+        profile_file_path
+    """
+    # Create parent directory if it doesn't exist
+    os.makedirs(os.path.dirname(profile_file_path), exist_ok=True)
+
+    # Set up profiler activities based on device
+    activities = [ProfilerActivity.CPU]
+    device = next(model.parameters()).device
+    if device.type == "cuda" and torch.cuda.is_available():
+        activities.append(ProfilerActivity.CUDA)
+
+    # Warm up
+    with torch.no_grad():
+        for _ in range(3):
+            _ = model(input_data)
+            if device.type == "cuda":
+                torch.cuda.synchronize()
+
+    # Run profiler with minimal settings to ensure compatibility
+    with torch.profiler.profile(
+        activities=activities,
+        record_shapes=True,
+        with_stack=True,
+        profile_memory=True,
+        with_flops=True,  # Experimental; might be unreliable for some layers
+    ) as prof:
+        with torch.no_grad():
+            for _ in range(3):
+                _ = model(input_data)
+                if device.type == "cuda":
+                    torch.cuda.synchronize()
+
+    # Save profiling details
+    prof.export_chrome_trace(profile_file_path)
+    print(f"Chrome trace saved at: {profile_file_path}")
+    print("You can now visualize it using:")
+    print("1. Chrome Trace Viewer: chrome://tracing")
+    print("2. Perfetto UI: https://ui.perfetto.dev")
+
+    return profile_file_path
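Usage sketch for the new helper: the model and output path below are made up for illustration. The path must contain a directory component, since the helper calls os.path.dirname on it before os.makedirs:

import torch

from benchmarks.microbenchmarks.profiler import generate_model_profile

model = torch.nn.Linear(1024, 1024).eval()  # tiny stand-in model
input_data = torch.randn(16, 1024)

trace_path = generate_model_profile(
    model,
    input_data,
    "benchmarks/microbenchmarks/results/profiler/linear_trace.json",  # hypothetical output path
)
# trace_path now points at a Chrome trace, viewable in chrome://tracing or Perfetto.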

benchmarks/microbenchmarks/test/benchmark_config.yml

Lines changed: 8 additions & 27 deletions
@@ -2,46 +2,27 @@
 benchmark_mode: "inference"
 quantization_config_recipe_names:
   # Will run a baseline inference for model by default, without quantization for comparison
-  - "int4wo-32"
-  - "marlin"
-sparsity_config_recipe_names:
+  - "int8wo"
+  - "int8dq"
+  - "float8dq"
+  - "float8wo"
+# sparsity_config_recipe_names:
   # Will run a baseline inference for model by default, without sparsity for comparison
-  - "semi-sparse"
-  - "block"
+  # - "semi-sparse"
+  # - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
   - name: "small_bf16_linear"
     matrix_shapes:
       - name: "custom"
         shapes: [
           [1024, 1024, 1024], # [m, k, n]
-        ]
-    high_precision_dtype: "torch.bfloat16"
-    use_torch_compile: true
-    torch_compile_mode: "max-autotune"
-    device: "cuda"
-    model_type: "linear"
-
-  - name: "large_bf16_ln_linear"
-    matrix_shapes:
-      - name: "custom"
-        shapes: [
           [2048, 4096, 1024],
           [4096, 4096, 1024]
         ]
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
     device: "cuda"
-    model_type: "ln_linear_sigmoid"
-
-  - name: "cpu_fp32_linear"
-    matrix_shapes:
-      - name: "custom"
-        shapes: [
-          [4096, 4096, 1024]
-        ]
-    high_precision_dtype: "torch.float32"
-    use_torch_compile: false
-    device: "cpu"
     model_type: "linear"
+    enable_profiler: true # Enable profiling for this model
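For reference, a minimal sketch of reading this config with PyYAML; the benchmark runner's actual loader may differ, and the path assumes the repo root as the working directory:

import yaml

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["quantization_config_recipe_names"])  # ['int8wo', 'int8dq', 'float8dq', 'float8wo']
model = cfg["model_params"][0]
print(model["name"], model["enable_profiler"])  # small_bf16_linear True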
