Add profiling to benchmarking #2032

Merged: 7 commits, Apr 18, 2025
140 changes: 75 additions & 65 deletions benchmarks/microbenchmarks/benchmark_inference.py
@@ -15,6 +15,9 @@

import torch

from benchmarks.microbenchmarks.profiler import (
generate_model_profile,
)
from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
BenchmarkResult,
@@ -29,70 +32,77 @@

def run(config: BenchmarkConfig) -> BenchmarkResult:
"""Run inference benchmarks"""
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
config.model_type,
config.m,
config.k,
config.n,
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)

# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
ao_base_config = string_to_config(
config.quantization,
config.sparsity,
high_precision_dtype=config.high_precision_dtype,
)

# Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
is_cuda = config.device == "cuda" and torch.cuda.is_available()

if config.sparsity is not None and (
config.quantization is None or "baseline" in config.quantization
):
if is_cuda:
print(f"Applying {config.sparsity} sparsity to model")
sparsify_(m_copy, ao_base_config)
try:
clean_caches() # Clean caches

# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
config.model_type,
config.m,
config.k,
config.n,
high_precision_dtype=config.high_precision_dtype,
device=config.device,
)

# Use quantize_ to apply each quantization function to the model
m_copy = deepcopy(base_model).eval().to(config.device)
ao_base_config = string_to_config(
config.quantization,
config.sparsity,
high_precision_dtype=config.high_precision_dtype,
)

# Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
is_cuda = config.device == "cuda" and torch.cuda.is_available()

if config.sparsity is not None and (
config.quantization is None or "baseline" in config.quantization
):
if is_cuda:
print(f"Applying {config.sparsity} sparsity to model")
sparsify_(m_copy, ao_base_config)
else:
print(
f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
)
elif config.sparsity is None and (
config.quantization is None or "baseline" in config.quantization
):
pass # No quantization or sparsity specified, do nothing
else:
print(
f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
print("Quantizing model....")
quantize_(m_copy, ao_base_config)

if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(
m_copy, mode=config.torch_compile_mode, fullgraph=True
)
elif config.sparsity is None and (
config.quantization is None or "baseline" in config.quantization
):
pass # No quantization or sparsity specified, do nothing
else:
print("Quantizing model....")
quantize_(m_copy, ao_base_config)

if config.use_torch_compile:
print("Compiling model....")
m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)

# Run benchmarks
result = BenchmarkResult(config=config)

# Benchmark time to run an inference call for quantized model
result.model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# TODO: Benchmark time using profiler
# Profile dtype model evaluation
# prof_dtype = benchmark_model_op_with_profiler_in_microseconds(m_copy, input_data, quantized_dtype)
# prof_dtype.export_chrome_trace(f"{quantization}_model_{input_data[0].size()[0]}.json") # Save profiling details

# TODO: Benchmark gemm time using cuda graph
# gemm_time = benchmark_torch_function_in_microseconds(gemm_op, *args, **kwargs)

# TODO: Benchmark op with cuda graph
# time = benchmark_op_with_cuda_graph(op, args)

return result

# Run benchmarks
result = BenchmarkResult(config=config)
# Store result in model for memory profiling
m_copy._benchmark_result = result

# Benchmark time to run an inference call for quantized model
result.model_inference_time_in_ms = model_inference_time_in_ms(
model=m_copy, input_data=input_data
)

# Run profiler if enabled
if config.enable_profiler:
print("Running profiler...")
try:
result.profiler_json_path = generate_model_profile(
m_copy, input_data, config.profiler_file_name
)
except Exception as e:
print(f"Error running profiler for {config.name} with error: {e}")

return result
except Exception as e:
print(f"Error in benchmark run: {config.name} with error: {e}")
return None
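
For context, here is a minimal sketch of calling the updated run() directly with profiling enabled. The BenchmarkConfig keyword arguments mirror the new tests further down in this PR; the shape and output directory values are illustrative, not taken from the diff.

import torch

from benchmarks.microbenchmarks.benchmark_inference import run
from benchmarks.microbenchmarks.utils import BenchmarkConfig

# Illustrative config; only enable_profiler and device are passed via params,
# matching what the new tests do.
config = BenchmarkConfig(
    quantization=None,  # baseline run, no quantization applied
    sparsity=None,
    params={
        "enable_profiler": True,  # exercise the new profiler path
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    },
    shape_name="custom",
    shape=[1024, 1024, 1024],  # [m, k, n]
    output_dir="benchmarks/microbenchmarks/results",
    benchmark_mode="inference",
)

result = run(config)  # returns None if the benchmark fails
if result is not None:
    print(f"Inference time: {result.model_inference_time_in_ms:.2f} ms")
    print(f"Chrome trace: {result.profiler_json_path}")
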
19 changes: 11 additions & 8 deletions benchmarks/microbenchmarks/benchmark_runner.py
@@ -164,16 +164,19 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
)
result = run_inference(config) # Pass the config object directly
results.append(result)
except Exception:
print(f"Error running benchmark {config.name}")
if result is not None: # Only add successful results
results.append(result)
except Exception as e:
print(f"Error running benchmark {config.name} with error: {e}")
continue

# Add results to csv
generate_results_csv(results, configs[0].output_dir)

# Print results
print_results(results)
# Add results to csv if there are any
if results:
generate_results_csv(results, configs[0].output_dir)
# Print results
print_results(results)
else:
print("No benchmark results were collected. All benchmarks failed.")

# TODO: Process results: Speedups:
# 1. For different shapes for same model and quantization
60 changes: 60 additions & 0 deletions benchmarks/microbenchmarks/profiler.py
@@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import os

import torch
from torch.profiler import ProfilerActivity


def generate_model_profile(model, input_data, profile_file_path):
"""Function to benchmark model evaluation with profiling.
Args:
model: The model to profile
input_data: Input data for the model
profile_file_path: Path to save the profiler output
Returns:
profile_file_path
"""
# Create parent directory if it doesn't exist
os.makedirs(os.path.dirname(profile_file_path), exist_ok=True)

# Set up profiler activities based on device
activities = [ProfilerActivity.CPU]
device = next(model.parameters()).device
if device.type == "cuda" and torch.cuda.is_available():
activities.append(ProfilerActivity.CUDA)

# Warm up
with torch.no_grad():
for _ in range(3):
_ = model(input_data)
if device.type == "cuda":
torch.cuda.synchronize()

# Run profiler with minimal settings to ensure compatibility
with torch.profiler.profile(
activities=activities,
record_shapes=True,
with_stack=True,
profile_memory=True,
with_flops=True, # Experimental; might be unreliable for some layers
) as prof:
with torch.no_grad():
for _ in range(3):
_ = model(input_data)
if device.type == "cuda":
torch.cuda.synchronize()

# Save profiling details
prof.export_chrome_trace(profile_file_path)
print(f"Chrome trace saved at: {profile_file_path}")
print("You can now visualize it using:")
print("1. Chrome Trace Viewer: chrome://tracing")
print("2. Perfetto UI: https://ui.perfetto.dev")

return profile_file_path
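
As a standalone usage sketch (not part of the diff), generate_model_profile works on any eval-mode module that has parameters; the toy Linear model and output path below are illustrative.

import torch

from benchmarks.microbenchmarks.profiler import generate_model_profile

# Illustrative model and input; any nn.Module with parameters works the same way
model = torch.nn.Linear(1024, 1024).eval()
x = torch.randn(1, 1024)

trace_path = generate_model_profile(
    model, x, "benchmarks/microbenchmarks/results/profiler/linear_profile.json"
)
# trace_path points at the exported Chrome trace, viewable in chrome://tracing or Perfetto
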
35 changes: 8 additions & 27 deletions benchmarks/microbenchmarks/test/benchmark_config.yml
@@ -2,46 +2,27 @@
benchmark_mode: "inference"
quantization_config_recipe_names:
# Will run a baseline inference for model by default, without quantization for comparison
- "int4wo-32"
- "marlin"
sparsity_config_recipe_names:
- "int8wo"
- "int8dq"
- "float8dq"
- "float8wo"
# sparsity_config_recipe_names:
# Will run a baseline inference for model by default, without sparsity for comparison
- "semi-sparse"
- "block"
# - "semi-sparse"
# - "block"
output_dir: "benchmarks/microbenchmarks/results"
model_params:
- name: "small_bf16_linear"
matrix_shapes:
- name: "custom"
shapes: [
[1024, 1024, 1024], # [m, k, n]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"

- name: "large_bf16_ln_linear"
matrix_shapes:
- name: "custom"
shapes: [
[2048, 4096, 1024],
[4096, 4096, 1024]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "ln_linear_sigmoid"

- name: "cpu_fp32_linear"
matrix_shapes:
- name: "custom"
shapes: [
[4096, 4096, 1024]
]
high_precision_dtype: "torch.float32"
use_torch_compile: false
device: "cpu"
model_type: "linear"
enable_profiler: true # Enable profiling for this model
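
As a quick sanity check (illustrative, not part of the diff), the updated YAML can be loaded with PyYAML to confirm the keys the runner consumes, including the new enable_profiler flag:

import yaml  # PyYAML

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    cfg = yaml.safe_load(f)

print(cfg["benchmark_mode"])                      # "inference"
print(cfg["quantization_config_recipe_names"])    # quantization recipes to sweep
print(cfg["model_params"][0]["enable_profiler"])  # True for the remaining model entry
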
156 changes: 156 additions & 0 deletions benchmarks/microbenchmarks/test/test_benchmark_profiler.py
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import unittest

import torch

from benchmarks.microbenchmarks.profiler import (
generate_model_profile,
)
from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
ToyLinearModel,
)


class TestBenchmarkProfiler(unittest.TestCase):
def setUp(self):
self.test_dir = os.path.dirname(os.path.abspath(__file__))
self.results_dir = os.path.join(self.test_dir, "results")
os.makedirs(self.results_dir, exist_ok=True)

# Set up a simple model and input for testing
self.m, self.k, self.n = 1024, 1024, 1024
self.dtype = torch.bfloat16
self.model = ToyLinearModel(k=self.k, n=self.n, dtype=self.dtype)
self.input_data = torch.randn(1, self.k, dtype=self.dtype)

# Move to appropriate device
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = self.model.to(self.device)
self.input_data = self.input_data.to(self.device)

def tearDown(self):
# Clean up any generated files
import shutil

if os.path.exists(self.results_dir):
shutil.rmtree(self.results_dir)

def test_profiler_enabled(self):
"""Test that profiler works when enabled"""
config = BenchmarkConfig(
quantization=None,
sparsity=None,
params={
"enable_profiler": True,
"device": self.device,
},
shape_name="test",
shape=[self.m, self.k, self.n],
output_dir=self.results_dir,
benchmark_mode="inference",
)

profile_path = os.path.join(
self.results_dir,
"profiler",
f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json",
)

# Generate profile
result_path = generate_model_profile(self.model, self.input_data, profile_path)

# Check that profile file exists and is not empty
self.assertTrue(os.path.exists(result_path))
self.assertGreater(os.path.getsize(result_path), 0)

# Verify it's valid JSON
with open(result_path) as f:
profile_data = json.load(f)
self.assertIsInstance(profile_data, dict)

def test_profiler_basic_output(self):
"""Test that profiler output contains expected basic fields"""
config = BenchmarkConfig(
quantization=None,
sparsity=None,
params={
"enable_profiler": True,
"device": self.device,
},
shape_name="test",
shape=[self.m, self.k, self.n],
output_dir=self.results_dir,
benchmark_mode="inference",
)

profile_path = os.path.join(
self.results_dir,
"profiler",
f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json",
)

result_path = generate_model_profile(self.model, self.input_data, profile_path)

with open(result_path) as f:
data = json.load(f)

# Check for required Chrome Trace Event format fields
self.assertIn("traceEvents", data)
self.assertTrue(isinstance(data["traceEvents"], list))

# Check that we have some events
self.assertGreater(len(data["traceEvents"]), 0)

# Check event format
event = data["traceEvents"][0]
self.assertIn("name", event)
self.assertIn("ph", event) # Phase
self.assertIn("ts", event) # Timestamp
self.assertIn("pid", event) # Process ID

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_cuda_profiling(self):
"""Test CUDA profiling when available"""
config = BenchmarkConfig(
quantization=None,
sparsity=None,
params={
"enable_profiler": True,
"device": "cuda",
},
shape_name="test",
shape=[self.m, self.k, self.n],
output_dir=self.results_dir,
benchmark_mode="inference",
)

profile_path = os.path.join(
self.results_dir,
"profiler",
f"{config.name}_{self.m}_{self.k}_{self.n}_profile.json",
)

result_path = generate_model_profile(
self.model.cuda(), self.input_data.cuda(), profile_path
)

with open(result_path) as f:
data = json.load(f)

# Check for CUDA events
cuda_events = [
event for event in data["traceEvents"] if "cuda" in event.get("name", "")
]
self.assertGreater(len(cuda_events), 0)


if __name__ == "__main__":
unittest.main()
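
The new tests are plain unittest cases; a minimal way to run just this module from the repository root (assuming the package is importable from the current working directory) is:

import unittest

# Load and run only the new profiler tests; verbosity=2 prints each test name
suite = unittest.defaultTestLoader.loadTestsFromName(
    "benchmarks.microbenchmarks.test.test_benchmark_profiler"
)
unittest.TextTestRunner(verbosity=2).run(suite)
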
104 changes: 46 additions & 58 deletions benchmarks/microbenchmarks/utils.py
@@ -84,6 +84,14 @@ def __init__(
"name",
f"benchmark_{self.quantization}_{self.model_type}_m{self.m}_k{self.k}_n{self.n}{'_compile' if self.use_torch_compile else ''}",
)
self.enable_profiler = bool(params.get("enable_profiler", False))
# Create profiler directory path without leading slash
profiler_dir = os.path.join(self.output_dir, "profiler")
os.makedirs(profiler_dir, exist_ok=True)
file_name = f"{self.name}_{self.m}_{self.k}_{self.n}_quant_{self.quantization}_sparsity_{self.sparsity}"
self.profiler_file_name = os.path.join(
profiler_dir, f"{file_name}_profile.json"
)

@staticmethod
def _parse_precision(precision_str: str) -> torch.dtype:
@@ -105,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]:
"device": self.device,
"model_type": self.model_type,
"output_dir": self.output_dir,
"enable_profiler": self.enable_profiler,
}


@@ -116,13 +125,16 @@ def __init__(
self.config = config
self.output_dir = config.output_dir
self.model_inference_time_in_ms = 0.0
self.profiler_json_path: Optional[str] = None

def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary for main function"""
return {
result_dict = {
**self.config.to_dict(),
"model_inference_time_in_ms": self.model_inference_time_in_ms,
"profiler_json_path": self.profiler_json_path,
}
return result_dict


class ToyLinearModel(torch.nn.Module):
@@ -379,6 +391,11 @@ def generate_results_csv(
output_dir (str): Directory to save the CSV file.
file_name (str, optional): Name of the CSV file. Defaults to "results.csv".
"""
# Check if results list is empty
if len(results) == 0:
print("No results to save to CSV.")
return

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
file_path = os.path.join(output_dir, file_name)
@@ -396,68 +413,39 @@ def generate_results_csv(


def print_results(results: List[BenchmarkResult]):
"""Print benchmark results in a formatted table.
Args:
results (List[BenchmarkResult]): List of benchmark results
"""
"""Print results in a table format"""
if not results:
print("No results to display")
return

# Extract relevant columns for display
display_columns = [
"quantization",
"sparsity",
"model_type",
"m",
"k",
"n",
"model_inference_time_in_ms",
"use_torch_compile",
]

# Format data for tabulate
headers = {
"quantization": "Quantization",
"sparsity": "Sparsity",
"model_type": "Model Type",
"m": "M",
"k": "K",
"n": "N",
"model_inference_time_in_ms": "Time (μs)",
"use_torch_compile": "Compile Mode",
}

# Extract and format data
table_data = []
for result in results:
result_dict = result.to_dict()
row = []
for col in display_columns:
value = result_dict.get(col, "N/A")
if value is None:
value = "N/A"
if col == "model_inference_time_in_ms":
value = f"{value:.2f}" if isinstance(value, (int, float)) else value
elif col == "use_torch_compile":
# Show compile mode if compile is True, otherwise show False
value = (
result_dict.get("torch_compile_mode", "default")
if result_dict.get("use_torch_compile")
else "False"
)
row.append(value)
if result is None:
continue

row = [
result.config.name,
result.config.quantization or "baseline",
result.config.sparsity or "none",
f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})",
f"{result.model_inference_time_in_ms:.2f}",
str(result.config.enable_profiler),
]

table_data.append(row)

# Print formatted table
print("\nBenchmark Results:")
print(
tabulate(
table_data,
headers=[headers[col] for col in display_columns],
tablefmt="grid",
floatfmt=".2f",
)
)
print()
# Define headers
headers = [
"Name",
"Quantization",
"Sparsity",
"Shape",
"Inference Time (ms)",
"Profiler Enabled",
]

if table_data:
print("\nBenchmark Results:")
print(tabulate(table_data, headers=headers, tablefmt="grid"))
else:
print("\nNo valid results to display")