
Commit 6393dc5

Use correct TF device depending on configuration (huggingface#492)
1 parent 8c158f2 commit 6393dc5


3 files changed: +39 -35 lines changed

shark/shark_benchmark_runner.py

Lines changed: 21 additions & 19 deletions
@@ -125,27 +125,29 @@ def benchmark_tf(self, modelname):
         import tensorflow as tf
         from tank.model_utils_tf import get_tf_model
 
-        model, input, = get_tf_model(
-            modelname
-        )[:2]
-        frontend_model = model
+        tf_device = "/GPU:0" if self.device == "cuda" else "/CPU:0"
+        with tf.device(tf_device):
+            model, input, = get_tf_model(
+                modelname
+            )[:2]
+            frontend_model = model
 
-        for i in range(shark_args.num_warmup_iterations):
-            frontend_model.forward(*input)
+            for i in range(shark_args.num_warmup_iterations):
+                frontend_model.forward(*input)
 
-        begin = time.time()
-        for i in range(shark_args.num_iterations):
-            out = frontend_model.forward(*input)
-            if i == shark_args.num_iterations - 1:
-                end = time.time()
-                break
-        print(
-            f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
-        )
-        return [
-            f"{shark_args.num_iterations/(end-begin)}",
-            f"{((end-begin)/shark_args.num_iterations)*1000}",
-        ]
+            begin = time.time()
+            for i in range(shark_args.num_iterations):
+                out = frontend_model.forward(*input)
+                if i == shark_args.num_iterations - 1:
+                    end = time.time()
+                    break
+            print(
+                f"TF benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}"
+            )
+            return [
+                f"{shark_args.num_iterations/(end-begin)}",
+                f"{((end-begin)/shark_args.num_iterations)*1000}",
+            ]
 
     def benchmark_c(self):
        print(self.benchmark_cl)
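
The core of this change is device placement: pick a TensorFlow device string from the runner's configured device and run the whole benchmark inside a tf.device() scope. Below is a minimal, self-contained sketch of that pattern, not code from the repo: requested_device, the Dense layer, and the tensor shape are placeholder assumptions; only the "/GPU:0" / "/CPU:0" selection and the tf.device context mirror the diff.

import time

import tensorflow as tf

# Placeholder for a configuration value such as self.device in the runner.
requested_device = "cuda"

# Same selection as the diff, with an extra fallback to CPU when no GPU is visible.
tf_device = (
    "/GPU:0"
    if requested_device == "cuda" and tf.config.list_physical_devices("GPU")
    else "/CPU:0"
)

with tf.device(tf_device):
    # Tensors and ops created inside this context are placed on tf_device.
    x = tf.random.uniform((1, 128, 768))
    layer = tf.keras.layers.Dense(768)

    begin = time.time()
    for _ in range(10):
        _ = layer(x)
    end = time.time()
    print(f"{10 / (end - begin):.2f} iter/second on {tf_device}")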

tank/model_utils_tf.py

Lines changed: 0 additions & 10 deletions
@@ -6,16 +6,6 @@
     TFBertModel,
 )
 
-visible_default = tf.config.list_physical_devices("GPU")
-try:
-    tf.config.set_visible_devices([], "GPU")
-    visible_devices = tf.config.get_visible_devices()
-    for device in visible_devices:
-        assert device.device_type != "GPU"
-except:
-    # Invalid device or cannot modify virtual devices once initialized.
-    pass
-
 BATCH_SIZE = 1
 MAX_SEQUENCE_LENGTH = 128
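
For reference, the deleted block above used TensorFlow's device-visibility API to hide every GPU from the process at import time, which would override any later per-benchmark placement; removing it is presumably what lets the new with tf.device("/GPU:0") scope in benchmark_tf actually reach the GPU. A compact sketch of the removed pattern (with narrowed exceptions in place of the original bare except):

import tensorflow as tf

try:
    # An empty visible-device list for the "GPU" type hides all GPUs from
    # this process; it only works before TensorFlow initializes the devices.
    tf.config.set_visible_devices([], "GPU")
    assert all(
        d.device_type != "GPU" for d in tf.config.get_visible_devices()
    )
except (ValueError, RuntimeError):
    # Invalid device, or the runtime was already initialized.
    pass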

tank/test_models.py

Lines changed: 18 additions & 6 deletions
@@ -20,6 +20,7 @@
 import tempfile
 import os
 import shutil
+import multiprocessing
 
 
 def load_csv_and_convert(filename, gen=False):
@@ -241,6 +242,16 @@ def postprocess_outputs(self, golden_out, result):
         return expected, logits
 
 
+def run_test(module_tester, dynamic, device):
+    tempdir = tempfile.TemporaryDirectory(
+        prefix=module_tester.tmp_prefix, dir="./shark_tmp/"
+    )
+    module_tester.temp_dir = tempdir.name
+
+    with ireec.tools.TempFileSaver(tempdir.name):
+        module_tester.create_and_check_module(dynamic, device)
+
+
 class SharkModuleTest(unittest.TestCase):
     @pytest.fixture(autouse=True)
     def configure(self, pytestconfig):
@@ -485,10 +496,11 @@ def test_module(self, dynamic, device, config):
         if not os.path.isdir("./shark_tmp/"):
             os.mkdir("./shark_tmp/")
 
-        tempdir = tempfile.TemporaryDirectory(
-            prefix=self.module_tester.tmp_prefix, dir="./shark_tmp/"
+        # We must create a new process each time we benchmark a model to allow
+        # for Tensorflow to release GPU resources. Using the same process to
+        # benchmark multiple models leads to OOM.
+        p = multiprocessing.Process(
+            target=run_test, args=(self.module_tester, dynamic, device)
         )
-        self.module_tester.temp_dir = tempdir.name
-
-        with ireec.tools.TempFileSaver(tempdir.name):
-            self.module_tester.create_and_check_module(dynamic, device)
+        p.start()
+        p.join()
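
The comment in the last hunk explains the design: as it notes, TensorFlow does not release GPU memory while a process is alive, so each benchmark is run in its own multiprocessing.Process and the memory is reclaimed when that child exits. A standalone sketch of the same pattern follows; run_once, the model names, and the tiny TF workload are placeholders, not repo code.

import multiprocessing


def run_once(model_name: str) -> None:
    # Import TensorFlow inside the child so each run gets a fresh runtime
    # whose GPU allocations die with the process.
    import tensorflow as tf

    x = tf.random.uniform((1, 16))
    print(model_name, float(tf.reduce_sum(x)))


if __name__ == "__main__":
    for name in ["model_a", "model_b"]:
        p = multiprocessing.Process(target=run_once, args=(name,))
        p.start()
        p.join()  # wait for the child to exit and release its GPU memory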
