Skip to content

Commit 220b98b

Browse files
authored
profiler recipe update (#3435)
1 parent a96b470 commit 220b98b

File tree

1 file changed

+36
-23
lines changed

1 file changed

+36
-23
lines changed

recipes_source/recipes/profiler_recipe.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
"""
22
PyTorch Profiler
33
====================================
4+
**Author:** `Shivam Raikundalia <https://github.com/sraikund16>`_
5+
"""
6+
7+
######################################################################
8+
"""
49
This recipe explains how to use PyTorch profiler and measure the time and
510
memory consumption of the model's operators.
611
@@ -12,6 +17,10 @@
1217
In this recipe, we will use a simple Resnet model to demonstrate how to
1318
use profiler to analyze model performance.
1419
20+
Prerequisites
21+
---------------
22+
- ``torch >= 1.9``
23+
1524
Setup
1625
-----
1726
To install ``torch`` and ``torchvision`` use the following command:
@@ -20,10 +29,8 @@
2029
2130
pip install torch torchvision
2231
23-
2432
"""
2533

26-
2734
######################################################################
2835
# Steps
2936
# -----
@@ -45,7 +52,7 @@
4552

4653
import torch
4754
import torchvision.models as models
48-
from torch.profiler import profile, record_function, ProfilerActivity
55+
from torch.profiler import profile, ProfilerActivity, record_function
4956

5057

5158
######################################################################
@@ -135,7 +142,11 @@
135142
# To get a finer granularity of results and include operator input shapes, pass ``group_by_input_shape=True``
136143
# (note: this requires running the profiler with ``record_shapes=True``):
137144

138-
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))
145+
print(
146+
prof.key_averages(group_by_input_shape=True).table(
147+
sort_by="cpu_time_total", row_limit=10
148+
)
149+
)
139150

140151
########################################################################################
141152
# The output might look like this (omitting some columns):
@@ -167,14 +178,17 @@
167178
# Users could switch between cpu, cuda and xpu
168179
activities = [ProfilerActivity.CPU]
169180
if torch.cuda.is_available():
170-
device = 'cuda'
181+
device = "cuda"
171182
activities += [ProfilerActivity.CUDA]
172183
elif torch.xpu.is_available():
173-
device = 'xpu'
184+
device = "xpu"
174185
activities += [ProfilerActivity.XPU]
175186
else:
176-
print('Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices')
187+
print(
188+
"Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices"
189+
)
177190
import sys
191+
178192
sys.exit(0)
179193

180194
sort_by_keyword = device + "_time_total"
@@ -256,8 +270,9 @@
256270
model = models.resnet18()
257271
inputs = torch.randn(5, 3, 224, 224)
258272

259-
with profile(activities=[ProfilerActivity.CPU],
260-
profile_memory=True, record_shapes=True) as prof:
273+
with profile(
274+
activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True
275+
) as prof:
261276
model(inputs)
262277

263278
print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
@@ -312,14 +327,17 @@
312327
# Users could switch between cpu, cuda and xpu
313328
activities = [ProfilerActivity.CPU]
314329
if torch.cuda.is_available():
315-
device = 'cuda'
330+
device = "cuda"
316331
activities += [ProfilerActivity.CUDA]
317332
elif torch.xpu.is_available():
318-
device = 'xpu'
333+
device = "xpu"
319334
activities += [ProfilerActivity.XPU]
320335
else:
321-
print('Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices')
336+
print(
337+
"Neither CUDA nor XPU devices are available to demonstrate profiling on acceleration devices"
338+
)
322339
import sys
340+
323341
sys.exit(0)
324342

325343
model = models.resnet18().to(device)
@@ -347,6 +365,7 @@
347365
with profile(
348366
activities=activities,
349367
with_stack=True,
368+
experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
350369
) as prof:
351370
model(inputs)
352371

@@ -401,12 +420,7 @@
401420

402421
from torch.profiler import schedule
403422

404-
my_schedule = schedule(
405-
skip_first=10,
406-
wait=5,
407-
warmup=1,
408-
active=3,
409-
repeat=2)
423+
my_schedule = schedule(skip_first=10, wait=5, warmup=1, active=3, repeat=2)
410424

411425
######################################################################
412426
# Profiler assumes that the long-running job is composed of steps, numbered
@@ -444,18 +458,17 @@
444458

445459
sort_by_keyword = "self_" + device + "_time_total"
446460

461+
447462
def trace_handler(p):
448463
output = p.key_averages().table(sort_by=sort_by_keyword, row_limit=10)
449464
print(output)
450465
p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")
451466

467+
452468
with profile(
453469
activities=activities,
454-
schedule=torch.profiler.schedule(
455-
wait=1,
456-
warmup=1,
457-
active=2),
458-
on_trace_ready=trace_handler
470+
schedule=torch.profiler.schedule(wait=1, warmup=1, active=2),
471+
on_trace_ready=trace_handler,
459472
) as p:
460473
for idx in range(8):
461474
model(inputs)

0 commit comments

Comments
 (0)