Skip to content

Commit d9fe929

Browse files
committed
add old pytorch example
1 parent ed06e36 commit d9fe929

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
import time
3+
import torch
4+
import urllib
5+
import sys
6+
7+
if __name__ == "__main__":
8+
start = time.time()
9+
model = torch.hub.load('pytorch/vision:v0.6.0', 'alexnet', pretrained=True)
10+
# assert time.time() - start < 3, "looks like we just did the first-time download, run this benchmark again to get a clean run"
11+
model.eval()
12+
13+
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
14+
urllib.request.urlretrieve(url, filename)
15+
16+
from PIL import Image
17+
from torchvision import transforms
18+
input_image = Image.open(filename)
19+
preprocess = transforms.Compose([
20+
transforms.Resize(256),
21+
transforms.CenterCrop(224),
22+
transforms.ToTensor(),
23+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
24+
])
25+
input_tensor = preprocess(input_image)
26+
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
27+
28+
n = 1000
29+
if len(sys.argv) > 1:
30+
n = int(sys.argv[1])
31+
32+
with torch.no_grad():
33+
times = []
34+
for i in range(n):
35+
times.append(time.time())
36+
if i % 10 == 0:
37+
print(i)
38+
output = model(input_batch)
39+
times.append(time.time())
40+
print((len(times) - 1) / (times[-1] - times[0]) , "/s")
41+
42+
if len(sys.argv) > 2:
43+
json.dump(times, open(sys.argv[2], 'w'))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
future==0.18.2
2+
numpy==1.19.0
3+
Pillow==8.0.0
4+
torch==1.5.1

run_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ set -x
1313
mkdir -p results
1414

1515
ENV=/tmp/macrobenchmark_env
16-
for bench in flaskblogging djangocms mypy_bench pylint_bench pycparser_bench; do
16+
for bench in flaskblogging djangocms mypy_bench pylint_bench pycparser_bench pytorch_alexnet_inference; do
1717
rm -rf $ENV
1818
virtualenv -p $BINARY $ENV
1919
$ENV/bin/pip install -r benchmarks/${bench}_requirements.txt

0 commit comments

Comments
 (0)