Skip to content

Commit 881056f

Browse files
committed
m: New impl
Signed-off-by: John Nunley <[email protected]>
1 parent a80088e commit 881056f

File tree

2 files changed

+167
-48
lines changed

2 files changed

+167
-48
lines changed

Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ exclude = ["/.*"]
1919
static = []
2020

2121
[dependencies]
22-
ahash = "0.8.11"
2322
async-task = "4.4.0"
2423
concurrent-queue = "2.5.0"
2524
fastrand = "2.0.0"

src/lib.rs

Lines changed: 167 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,17 @@
3939
)]
4040
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
4141

42+
use std::cell::{Cell, RefCell};
43+
use std::cmp::Reverse;
44+
use std::collections::VecDeque;
4245
use std::fmt;
4346
use std::marker::PhantomData;
4447
use std::panic::{RefUnwindSafe, UnwindSafe};
4548
use std::rc::Rc;
4649
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
4750
use std::sync::{Arc, Mutex, RwLock, TryLockError};
4851
use std::task::{Poll, Waker};
49-
use std::thread::{self, ThreadId};
5052

51-
use ahash::AHashMap;
5253
use async_task::{Builder, Runnable};
5354
use concurrent_queue::ConcurrentQueue;
5455
use futures_lite::{future, prelude::*};
@@ -355,8 +356,8 @@ impl<'a> Executor<'a> {
355356
.local_queues
356357
.read()
357358
.unwrap()
358-
.get(&thread_id())
359-
.and_then(|list| list.first())
359+
.get(thread_id())
360+
.and_then(|list| list.as_ref())
360361
{
361362
match local_queue.queue.push(runnable) {
362363
Ok(()) => {
@@ -692,8 +693,9 @@ struct State {
692693

693694
/// Local queues created by runners.
694695
///
695-
/// These are keyed by the thread that the runner originated in.
696-
local_queues: RwLock<AHashMap<ThreadId, Vec<Arc<LocalQueue>>>>,
696+
/// These are keyed by the thread that the runner originated in. See the `thread_id` function
697+
/// for more information.
698+
local_queues: RwLock<Vec<Option<Arc<LocalQueue>>>>,
697699

698700
/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
699701
notified: AtomicBool,
@@ -710,7 +712,7 @@ impl State {
710712
const fn new() -> State {
711713
State {
712714
queue: ConcurrentQueue::unbounded(),
713-
local_queues: RwLock::new(AHashMap::new()),
715+
local_queues: RwLock::new(Vec::new()),
714716
notified: AtomicBool::new(true),
715717
sleepers: Mutex::new(Sleepers {
716718
count: 0,
@@ -1025,7 +1027,9 @@ struct Runner<'a> {
10251027
ticker: Ticker<'a>,
10261028

10271029
/// The ID of the thread we originated from.
1028-
origin_id: ThreadId,
1030+
///
1031+
/// This is `None` if we don't own the local runner for this thread.
1032+
origin_id: Option<usize>,
10291033

10301034
/// The local queue.
10311035
local: Arc<LocalQueue>,
@@ -1041,23 +1045,42 @@ impl Runner<'_> {
10411045
let runner_id = ID_GENERATOR.fetch_add(1, Ordering::SeqCst);
10421046

10431047
let origin_id = thread_id();
1044-
let runner = Runner {
1048+
let mut runner = Runner {
10451049
state,
10461050
ticker: Ticker::for_runner(state, runner_id),
10471051
local: Arc::new(LocalQueue {
10481052
queue: ConcurrentQueue::bounded(512),
10491053
runner_id,
10501054
}),
10511055
ticks: 0,
1052-
origin_id,
1056+
origin_id: Some(origin_id),
10531057
};
1054-
state
1058+
1059+
// If this is the highest thread ID this executor has seen, make more slots.
1060+
let mut local_queues = state.local_queues.write().unwrap();
1061+
if local_queues.len() <= origin_id {
1062+
local_queues.resize_with(origin_id + 1, || None);
1063+
}
1064+
1065+
// Try to reserve the thread-local slot.
1066+
match state
10551067
.local_queues
10561068
.write()
10571069
.unwrap()
1058-
.entry(origin_id)
1059-
.or_default()
1060-
.push(runner.local.clone());
1070+
.get_mut(origin_id)
1071+
.unwrap()
1072+
{
1073+
slot @ None => {
1074+
// We won the race, insert our queue.
1075+
*slot = Some(runner.local.clone());
1076+
}
1077+
1078+
Some(_) => {
1079+
// We lost the race, indicate we don't own this ID.
1080+
runner.origin_id = None;
1081+
}
1082+
}
1083+
10611084
runner
10621085
}
10631086

@@ -1085,8 +1108,8 @@ impl Runner<'_> {
10851108
let start = rng.usize(..n);
10861109
let iter = local_queues
10871110
.iter()
1088-
.flat_map(|(_, list)| list)
1089-
.chain(local_queues.iter().flat_map(|(_, list)| list))
1111+
.filter_map(|list| list.as_ref())
1112+
.chain(local_queues.iter().filter_map(|list| list.as_ref()))
10901113
.skip(start)
10911114
.take(n);
10921115

@@ -1120,13 +1143,15 @@ impl Runner<'_> {
11201143
impl Drop for Runner<'_> {
11211144
fn drop(&mut self) {
11221145
// Remove the local queue.
1123-
self.state
1124-
.local_queues
1125-
.write()
1126-
.unwrap()
1127-
.get_mut(&self.origin_id)
1128-
.unwrap()
1129-
.retain(|local| !Arc::ptr_eq(local, &self.local));
1146+
if let Some(origin_id) = self.origin_id {
1147+
*self
1148+
.state
1149+
.local_queues
1150+
.write()
1151+
.unwrap()
1152+
.get_mut(origin_id)
1153+
.unwrap() = None;
1154+
}
11301155

11311156
// Re-schedule remaining tasks in the local queue.
11321157
while let Ok(r) = self.local.queue.pop() {
@@ -1206,25 +1231,7 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Re
12061231
}
12071232
}
12081233

1209-
/// Debug wrapper for the local runners.
1210-
struct LocalRunners<'a>(&'a RwLock<AHashMap<ThreadId, Vec<Arc<LocalQueue>>>>);
1211-
1212-
impl fmt::Debug for LocalRunners<'_> {
1213-
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1214-
match self.0.try_read() {
1215-
Ok(lock) => f
1216-
.debug_list()
1217-
.entries(
1218-
lock.iter()
1219-
.flat_map(|(_, list)| list)
1220-
.map(|queue| queue.queue.len()),
1221-
)
1222-
.finish(),
1223-
Err(TryLockError::WouldBlock) => f.write_str("<locked>"),
1224-
Err(TryLockError::Poisoned(_)) => f.write_str("<poisoned>"),
1225-
}
1226-
}
1227-
}
1234+
// TODO: Add wrapper for thread-local queues.
12281235

12291236
/// Debug wrapper for the sleepers.
12301237
struct SleepCount<'a>(&'a Mutex<Sleepers>);
@@ -1242,18 +1249,131 @@ fn debug_state(state: &State, name: &str, f: &mut fmt::Formatter<'_>) -> fmt::Re
12421249
f.debug_struct(name)
12431250
.field("active", &ActiveTasks(&state.active))
12441251
.field("global_tasks", &state.queue.len())
1245-
.field("local_runners", &LocalRunners(&state.local_queues))
12461252
.field("sleepers", &SleepCount(&state.sleepers))
12471253
.finish()
12481254
}
12491255

1250-
fn thread_id() -> ThreadId {
1256+
fn thread_id() -> usize {
1257+
// TODO: This strategy does not work for WASM, figure out a better way!
1258+
1259+
/// An allocator for thread IDs.
1260+
struct Allocator {
1261+
/// The next thread ID to yield.
1262+
free_id: usize,
1263+
1264+
/// The list of thread IDs that have been released.
1265+
///
1266+
/// This exists to defend against the case where a user spawns a million threads, then calls
1267+
/// this function, then drops all of those threads. In a few moments this strategy could take up
1268+
/// all of the available thread ID space. Therefore we try to reuse thread IDs after they've been
1269+
/// dropped.
1270+
///
1271+
/// We prefer lower thread IDs, as larger thread IDs require more memory in the const-time addressing
1272+
/// strategy we use for thread-specific storage.
1273+
///
1274+
/// This is only `None` at program startup, it's only `Option` for const initialization.
1275+
///
1276+
/// TODO(notgull): make an entry in the "useful features" table for this
1277+
released_ids: Option<VecDeque<Reverse<usize>>>,
1278+
}
1279+
1280+
impl Allocator {
1281+
/// Run a closure with the address allocator.
1282+
fn with<R>(f: impl FnOnce(&mut Allocator) -> R) -> R {
1283+
static ALLOCATOR: Mutex<Allocator> = Mutex::new(Allocator {
1284+
free_id: 0,
1285+
released_ids: None,
1286+
});
1287+
1288+
f(&mut ALLOCATOR.lock().unwrap_or_else(|x| x.into_inner()))
1289+
}
1290+
1291+
/// Get the queue for released IDs.
1292+
fn released_ids(&mut self) -> &mut VecDeque<Reverse<usize>> {
1293+
self.released_ids.get_or_insert_with(VecDeque::default)
1294+
}
1295+
1296+
/// Get the newest ID.
1297+
fn alloc(&mut self) -> usize {
1298+
// If we can, reuse an existing ID.
1299+
if let Some(Reverse(id)) = self.released_ids().pop_front() {
1300+
return id;
1301+
}
1302+
1303+
// Increment our ID counter.
1304+
let id = self.free_id;
1305+
self.free_id = self
1306+
.free_id
1307+
.checked_add(1)
1308+
.expect("took up all available thread-ID space");
1309+
id
1310+
}
1311+
1312+
/// Free an ID that was previously allocated.
1313+
fn free(&mut self, id: usize) {
1314+
self.released_ids().push_front(Reverse(id));
1315+
}
1316+
}
1317+
12511318
thread_local! {
1252-
static ID: ThreadId = thread::current().id();
1319+
/// The unique ID for this thread.
1320+
static THREAD_ID: Cell<Option<usize>> = const { Cell::new(None) };
1321+
}
1322+
1323+
thread_local! {
1324+
/// A destructor that frees this ID before the thread exits.
1325+
///
1326+
/// This is separate from `THREAD_ID` so that accessing it does not involve a thread-local
1327+
/// destructor.
1328+
static THREAD_GUARD: RefCell<Option<ThreadGuard>> = const { RefCell::new(None) };
1329+
}
1330+
1331+
struct ThreadGuard(usize);
1332+
1333+
impl Drop for ThreadGuard {
1334+
fn drop(&mut self) {
1335+
// DEADLOCK: Allocator is only ever held by this and the first call to "thread_id".
1336+
Allocator::with(|alloc| {
1337+
// De-allocate the ID.
1338+
alloc.free(self.0);
1339+
});
1340+
}
1341+
}
1342+
1343+
/// Fast path for getting the thread ID.
1344+
#[inline]
1345+
fn get() -> usize {
1346+
// Try to use the cached thread ID.
1347+
THREAD_ID.with(|thread_id| {
1348+
if let Some(thread_id) = thread_id.get() {
1349+
return thread_id;
1350+
}
1351+
1352+
// Use the slow path.
1353+
get_slow(thread_id)
1354+
})
1355+
}
1356+
1357+
/// Slow path for getting the thread ID.
1358+
#[cold]
1359+
fn get_slow(slot: &Cell<Option<usize>>) -> usize {
1360+
// Allocate a new thread ID.
1361+
let id = Allocator::with(|alloc| alloc.alloc());
1362+
1363+
// Store the thread ID.
1364+
let old = slot.replace(Some(id));
1365+
debug_assert!(old.is_none());
1366+
1367+
// Store the destructor.
1368+
THREAD_GUARD.with(|guard| {
1369+
*guard.borrow_mut() = Some(ThreadGuard(id));
1370+
});
1371+
1372+
// Return the ID.
1373+
id
12531374
}
12541375

1255-
ID.try_with(|id| *id)
1256-
.unwrap_or_else(|_| thread::current().id())
1376+
get()
12571377
}
12581378

12591379
/// Runs a closure when dropped.

0 commit comments

Comments
 (0)