This repository has been archived by the owner on Dec 9, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 635
/
Copy pathcnn_util_test.py
125 lines (108 loc) · 4.33 KB
/
cnn_util_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf_cnn_benchmarks.cnn_util."""
import threading
import time
import tensorflow.compat.v1 as tf
import cnn_util
class CnnUtilBarrierTest(tf.test.TestCase):
  """Exercises cnn_util.Barrier from several Python threads."""

  def testBarrier(self):
    """Checks the ordering of timestamps taken around successive waits.

    Every task records a timestamp right before each of its barrier.wait()
    calls. The latest timestamp recorded for one wait must not be later than
    the earliest timestamp recorded for the following wait.
    """
    num_tasks = 20
    num_waits = 4
    barrier = cnn_util.Barrier(num_tasks)
    # One row of timestamps per task; each worker fills in its own row.
    sync_matrix = [[0] * num_waits for _ in range(num_tasks)]
    workers = []
    for row in sync_matrix:
      worker = threading.Thread(target=self._run_task, args=(barrier, row))
      worker.start()
      workers.append(worker)
    for worker in workers:
      worker.join()
    # Transpose so that each entry holds the timestamps of all tasks for one
    # wait, then compare every wait with the one that follows it.
    per_wait = list(zip(*sync_matrix))
    for current, following in zip(per_wait, per_wait[1:]):
      self.assertLessEqual(max(current), min(following))

  def _run_task(self, barrier, sync_times):
    """Stores the current time in each slot, then waits on the barrier.

    Args:
      barrier: Object whose wait() method is called once per slot.
      sync_times: Mutable list; entry k is overwritten with time.time() just
        before the k-th wait.
    """
    for slot, _ in enumerate(sync_times):
      sync_times[slot] = time.time()
      barrier.wait()

  def testBarrierAbort(self):
    """Checks that a lone waiter finishes once the barrier is aborted."""
    num_tasks = 2
    num_waits = 1
    barrier = cnn_util.Barrier(num_tasks)
    timestamps = [0] * num_waits
    waiter = threading.Thread(
        target=self._run_task, args=(barrier, timestamps))
    waiter.start()
    barrier.abort()
    # Only one of the two parties ever waits, so join() returning shows that
    # the aborted barrier did not keep the thread blocked.
    waiter.join()
class ImageProducerTest(tf.test.TestCase):
  """Tests cnn_util.ImageProducer against a simulated staging area."""

  def _slow_tensorflow_op(self):
    """Returns a TensorFlow op that takes approximately 0.1s to complete."""
    def slow_func(v):
      time.sleep(0.1)
      return v
    return tf.py_func(slow_func, [tf.constant(0.)], tf.float32).op

  def _assert_staging_area_bounded(self, sess, x, batch_group_size):
    """Asserts that 0 <= x <= 2 * batch_group_size.

    x being nonnegative ensures image_producer never causes an unstage op to
    block. x being at most 2 * batch_group_size ensures it doesn't use too
    much memory by storing too many batches in the staging area.

    Args:
      sess: Session used to read the current value of x.
      x: Variable simulating the number of batches in the staging area.
      batch_group_size: Number of batches added by a single put.
    """
    self.assertGreaterEqual(sess.run(x), 0)
    self.assertLessEqual(sess.run(x), 2 * batch_group_size)

  def _test_image_producer(self, batch_group_size, put_slower_than_get):
    """Runs an ImageProducer and checks the staging area stays bounded.

    Args:
      batch_group_size: Amount added to the simulated staging area per put;
        also passed to ImageProducer.
      put_slower_than_get: If True the put op depends on the slow op and the
        get op on a no-op; if False the roles are swapped.
    """
    # We use the variable x to simulate a staging area of images. x represents
    # the number of batches in the staging area.
    x = tf.Variable(0, dtype=tf.int32)
    if put_slower_than_get:
      put_dep = self._slow_tensorflow_op()
      get_dep = tf.no_op()
    else:
      put_dep = tf.no_op()
      get_dep = self._slow_tensorflow_op()
    with tf.control_dependencies([put_dep]):
      put_op = x.assign_add(batch_group_size, use_locking=True)
    with tf.control_dependencies([get_dep]):
      get_op = x.assign_sub(1, use_locking=True)
    with self.test_session() as sess:
      sess.run(tf.variables_initializer([x]))
      image_producer = cnn_util.ImageProducer(sess, put_op, batch_group_size,
                                              use_python32_barrier=False)
      image_producer.start()
      try:
        for _ in range(5 * batch_group_size):
          sess.run(get_op)
          self._assert_staging_area_bounded(sess, x, batch_group_size)
          image_producer.notify_image_consumption()
          self._assert_staging_area_bounded(sess, x, batch_group_size)
      finally:
        # Always tell the producer we are finished, even when an assertion
        # above failed, so a failing test does not leave a started producer
        # behind while the session is being torn down.
        image_producer.done()
      # Give any in-flight put a moment to land before the final check.
      time.sleep(0.1)
      self._assert_staging_area_bounded(sess, x, batch_group_size)

  def test_image_producer(self):
    """Covers several batch group sizes with both relative op speeds."""
    for batch_group_size in (1, 2, 3, 8):
      for put_slower_than_get in (False, True):
        self._test_image_producer(batch_group_size, put_slower_than_get)
if __name__ == '__main__':
  # The tests in this file run ops through sessions, so TF2 behavior is
  # switched off before handing control to the test runner.
  tf.disable_v2_behavior()
  tf.test.main()