@@ -695,7 +695,7 @@ class RunQueue {
695
695
696
696
static std::atomic<uint32_t > next_tag{1 };
697
697
698
- template <typename Environment, bool kIsHybrid >
698
+ template <typename Environment>
699
699
class ThreadPoolTempl : public onnxruntime ::concurrency::ExtendedThreadPoolInterface {
700
700
private:
701
701
struct PerThread ;
@@ -767,29 +767,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
767
767
typedef std::function<void ()> Task;
768
768
typedef RunQueue<Task, Tag, 1024 > Queue;
769
769
770
- // Class for waiting w/ exponential backoff.
771
- // Template argument is maximum number of spins in backoff loop.
772
- template <unsigned kMaxBackoff >
773
- class ThreadPoolWaiter {
774
- // Current number of spins in backoff loop
775
- unsigned pause_time_;
776
-
777
- public:
778
- void wait () {
779
- // If kMaxBackoff is zero don't do any pausing.
780
- if constexpr (kMaxBackoff == 1 ) {
781
- onnxruntime::concurrency::SpinPause ();
782
- } else if constexpr (kMaxBackoff > 1 ) {
783
- // Exponential backoff
784
- unsigned pause_time = pause_time_ + 1U ;
785
- for (unsigned i = 0 ; i < pause_time; ++i) {
786
- onnxruntime::concurrency::SpinPause ();
787
- }
788
- pause_time_ = (pause_time * 2U ) % kMaxBackoff ;
789
- }
790
- }
791
- };
792
-
793
770
ThreadPoolTempl (const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env,
794
771
const ThreadOptions& thread_options)
795
772
: profiler_(num_threads, name),
@@ -931,9 +908,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
931
908
// finish dispatch work. This avoids new tasks being started
932
909
// concurrently with us attempting to end the parallel section.
933
910
if (ps.dispatch_q_idx != -1 ) {
934
- ThreadPoolWaiter<4 > waiter{};
935
911
while (!ps.dispatch_done .load (std::memory_order_acquire)) {
936
- waiter. wait ();
912
+ onnxruntime::concurrency::SpinPause ();
937
913
}
938
914
}
939
915
@@ -955,17 +931,15 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
955
931
956
932
// Wait for the dispatch task's own work...
957
933
if (ps.dispatch_q_idx > -1 ) {
958
- ThreadPoolWaiter<kIsHybrid ? 0 : 1 > waiter{};
959
934
while (!ps.work_done .load (std::memory_order_acquire)) {
960
- waiter. wait ();
935
+ onnxruntime::concurrency::SpinPause ();
961
936
}
962
937
}
963
938
964
939
// ...and wait for any other tasks not revoked to finish their work
965
940
auto tasks_to_wait_for = tasks_started - ps.tasks_revoked ;
966
- ThreadPoolWaiter<kIsHybrid ? 0 : 1 > waiter{};
967
941
while (ps.tasks_finished < tasks_to_wait_for) {
968
- waiter. wait ();
942
+ onnxruntime::concurrency::SpinPause ();
969
943
}
970
944
971
945
// Clear status to allow the ThreadPoolParallelSection to be
@@ -1283,10 +1257,9 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
1283
1257
// Increase the worker count if needed. Each worker will pick up
1284
1258
// loops to execute from the current parallel section.
1285
1259
std::function<void (unsigned )> worker_fn = [&ps](unsigned par_idx) {
1286
- ThreadPoolWaiter<kIsHybrid ? 4 : 0 > waiter{};
1287
1260
while (ps.active ) {
1288
1261
if (ps.current_loop .load () == nullptr ) {
1289
- waiter. wait ();
1262
+ onnxruntime::concurrency::SpinPause ();
1290
1263
} else {
1291
1264
ps.workers_in_loop ++;
1292
1265
ThreadPoolLoop* work_item = ps.current_loop ;
@@ -1307,9 +1280,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
1307
1280
1308
1281
// Wait for workers to exit the loop
1309
1282
ps.current_loop = 0 ;
1310
- ThreadPoolWaiter<kIsHybrid ? 1 : 4 > waiter{};
1311
1283
while (ps.workers_in_loop ) {
1312
- waiter. wait ();
1284
+ onnxruntime::concurrency::SpinPause ();
1313
1285
}
1314
1286
profiler_.LogEnd (ThreadPoolProfiler::WAIT);
1315
1287
}
@@ -1560,30 +1532,13 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
1560
1532
1561
1533
assert (td.GetStatus () == WorkerData::ThreadStatus::Spinning);
1562
1534
1563
- // The exact value of spin_count and steal_count are arbitrary and
1564
- // were experimentally determined. These numbers yielded the best
1565
- // performance across a range of workloads and
1566
- // machines. Generally, the goal of tuning spin_count is to make
1567
- // the number as small as possible while ensuring there is enough
1568
- // slack so that if each core is doing the same amount of work it
1569
- // won't sleep before they have all finished. The idea here is
1570
- // that in pipelined workloads, it won't sleep during each stage
1571
- // if it's done a bit faster than its neighbors, but that if there
1572
- // are non-equal sizes of work distributed, it won't take too long
1573
- // to reach sleep giving power (and thus frequency/performance) to
1574
- // its neighbors. Since hybrid has P/E cores, a lower value is
1575
- // chosen. On hybrid systems, even with equal sized workloads
1576
- // distributed the compute time won't stay synced. Typically in
1577
- // the hybrid case the P cores finish first (and are thus waiting)
1578
- // which is essentially a priority inversion.
1579
- constexpr int pref_spin_count = kIsHybrid ? 5000 : 10000 ;
1580
- const int spin_count = allow_spinning_ ? pref_spin_count : 0 ;
1581
- constexpr int steal_count = pref_spin_count / (kIsHybrid ? 25 : 100 );
1535
+ constexpr int log2_spin = 20 ;
1536
+ const int spin_count = allow_spinning_ ? (1ull << log2_spin) : 0 ;
1537
+ const int steal_count = spin_count / 100 ;
1582
1538
1583
1539
SetDenormalAsZero (set_denormal_as_zero_);
1584
1540
profiler_.LogThreadId (thread_id);
1585
1541
1586
- ThreadPoolWaiter<kIsHybrid ? 1 : 8 > waiter{};
1587
1542
while (!should_exit) {
1588
1543
Task t = q.PopFront ();
1589
1544
if (!t) {
@@ -1599,7 +1554,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter
1599
1554
if (spin_loop_status_.load (std::memory_order_relaxed) == SpinLoopStatus::kIdle ) {
1600
1555
break ;
1601
1556
}
1602
- waiter. wait ();
1557
+ onnxruntime::concurrency::SpinPause ();
1603
1558
}
1604
1559
1605
1560
// Attempt to block
0 commit comments