Skip to content

Commit 51cfd3e

Browse files
committed
refactor Lock
1 parent d610b0c commit 51cfd3e

File tree

2 files changed

+133
-34
lines changed

2 files changed

+133
-34
lines changed

compiler/rustc_data_structures/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
#![feature(thread_id_value)]
2828
#![feature(vec_into_raw_parts)]
2929
#![feature(get_mut_unchecked)]
30+
#![feature(const_trait_impl)]
31+
#![feature(const_ptr_as_ref)]
32+
#![feature(const_mut_refs)]
3033
#![allow(rustc::default_hash_types)]
3134
#![allow(rustc::potential_query_instability)]
3235
#![deny(rustc::untranslatable_diagnostic)]

compiler/rustc_data_structures/src/sync.rs

+130-34
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,28 @@
1818
//! depending on the value of cfg!(parallel_compiler).
1919
2020
use crate::owning_ref::{Erased, OwningRef};
21+
use std::cell::Cell;
22+
use std::cell::UnsafeCell;
2123
use std::collections::HashMap;
24+
use std::fmt::{Debug, Formatter};
2225
use std::hash::{BuildHasher, Hash};
26+
use std::intrinsics::likely;
27+
use std::marker::PhantomData;
2328
use std::ops::{Deref, DerefMut};
2429
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
30+
use std::ptr::NonNull;
2531

32+
use parking_lot::lock_api::RawMutex as _;
33+
use parking_lot::RawMutex;
2634
pub use std::sync::atomic::Ordering;
2735
pub use std::sync::atomic::Ordering::SeqCst;
2836

2937
pub use vec::AppendOnlyVec;
3038

3139
mod vec;
3240

41+
static PARALLEL: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
42+
3343
cfg_if! {
3444
if #[cfg(not(parallel_compiler))] {
3545
pub auto trait Send {}
@@ -172,15 +182,11 @@ cfg_if! {
172182
pub use std::cell::Ref as MappedReadGuard;
173183
pub use std::cell::RefMut as WriteGuard;
174184
pub use std::cell::RefMut as MappedWriteGuard;
175-
pub use std::cell::RefMut as LockGuard;
176185
pub use std::cell::RefMut as MappedLockGuard;
177186

178187
pub use std::cell::OnceCell;
179188

180189
use std::cell::RefCell as InnerRwLock;
181-
use std::cell::RefCell as InnerLock;
182-
183-
use std::cell::Cell;
184190

185191
#[derive(Debug)]
186192
pub struct WorkerLocal<T>(OneThread<T>);
@@ -257,7 +263,6 @@ cfg_if! {
257263
pub use parking_lot::RwLockWriteGuard as WriteGuard;
258264
pub use parking_lot::MappedRwLockWriteGuard as MappedWriteGuard;
259265

260-
pub use parking_lot::MutexGuard as LockGuard;
261266
pub use parking_lot::MappedMutexGuard as MappedLockGuard;
262267

263268
pub use std::sync::OnceLock as OnceCell;
@@ -299,7 +304,6 @@ cfg_if! {
299304
}
300305
}
301306

302-
use parking_lot::Mutex as InnerLock;
303307
use parking_lot::RwLock as InnerRwLock;
304308

305309
use std::thread;
@@ -381,55 +385,106 @@ impl<K: Eq + Hash, V: Eq, S: BuildHasher> HashMapExt<K, V> for HashMap<K, V, S>
381385
}
382386
}
383387

384-
#[derive(Debug)]
385-
pub struct Lock<T>(InnerLock<T>);
388+
pub struct Lock<T> {
389+
single_thread: bool,
390+
data: UnsafeCell<T>,
391+
borrow: Cell<bool>,
392+
mutex: RawMutex,
393+
}
394+
395+
impl<T: Debug> Debug for Lock<T> {
396+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
397+
match self.try_lock() {
398+
Some(guard) => f.debug_struct("Lock").field("data", &&*guard).finish(),
399+
None => {
400+
struct LockedPlaceholder;
401+
impl Debug for LockedPlaceholder {
402+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
403+
f.write_str("<locked>")
404+
}
405+
}
406+
407+
f.debug_struct("Lock").field("data", &LockedPlaceholder).finish()
408+
}
409+
}
410+
}
411+
}
386412

387413
impl<T> Lock<T> {
388-
#[inline(always)]
389-
pub fn new(inner: T) -> Self {
390-
Lock(InnerLock::new(inner))
414+
#[inline]
415+
pub fn new(val: T) -> Self {
416+
Lock {
417+
single_thread: !PARALLEL.load(Ordering::Relaxed),
418+
data: UnsafeCell::new(val),
419+
borrow: Cell::new(false),
420+
mutex: RawMutex::INIT,
421+
}
391422
}
392423

393-
#[inline(always)]
424+
#[inline]
394425
pub fn into_inner(self) -> T {
395-
self.0.into_inner()
426+
self.data.into_inner()
396427
}
397428

398-
#[inline(always)]
429+
#[inline]
399430
pub fn get_mut(&mut self) -> &mut T {
400-
self.0.get_mut()
431+
self.data.get_mut()
401432
}
402433

403-
#[cfg(parallel_compiler)]
404-
#[inline(always)]
434+
#[inline]
405435
pub fn try_lock(&self) -> Option<LockGuard<'_, T>> {
406-
self.0.try_lock()
436+
// SAFETY: the `&mut T` is accessible as long as self exists.
437+
if likely(self.single_thread) {
438+
if self.borrow.get() {
439+
None
440+
} else {
441+
self.borrow.set(true);
442+
Some(LockGuard {
443+
value: unsafe { NonNull::new_unchecked(self.data.get()) },
444+
lock: &self,
445+
marker: PhantomData,
446+
})
447+
}
448+
} else {
449+
if !self.mutex.try_lock() {
450+
None
451+
} else {
452+
Some(LockGuard {
453+
value: unsafe { NonNull::new_unchecked(self.data.get()) },
454+
lock: &self,
455+
marker: PhantomData,
456+
})
457+
}
458+
}
407459
}
408460

409-
#[cfg(not(parallel_compiler))]
410-
#[inline(always)]
411-
pub fn try_lock(&self) -> Option<LockGuard<'_, T>> {
412-
self.0.try_borrow_mut().ok()
461+
#[inline(never)]
462+
fn lock_mt(&self) -> LockGuard<'_, T> {
463+
self.mutex.lock();
464+
LockGuard {
465+
value: unsafe { NonNull::new_unchecked(self.data.get()) },
466+
lock: &self,
467+
marker: PhantomData,
468+
}
413469
}
414470

415-
#[cfg(parallel_compiler)]
416-
#[inline(always)]
471+
#[inline]
417472
#[track_caller]
418473
pub fn lock(&self) -> LockGuard<'_, T> {
419-
if ERROR_CHECKING {
420-
self.0.try_lock().expect("lock was already held")
474+
// SAFETY: the `&mut T` is accessible as long as self exists.
475+
if likely(self.single_thread) {
476+
assert!(!self.borrow.get());
477+
self.borrow.set(true);
478+
LockGuard {
479+
value: unsafe { NonNull::new_unchecked(self.data.get()) },
480+
lock: &self,
481+
marker: PhantomData,
482+
}
421483
} else {
422-
self.0.lock()
484+
self.lock_mt()
423485
}
424486
}
425487

426-
#[cfg(not(parallel_compiler))]
427-
#[inline(always)]
428-
#[track_caller]
429-
pub fn lock(&self) -> LockGuard<'_, T> {
430-
self.0.borrow_mut()
431-
}
432-
433488
#[inline(always)]
434489
#[track_caller]
435490
pub fn with_lock<F: FnOnce(&mut T) -> R, R>(&self, f: F) -> R {
@@ -464,6 +519,47 @@ impl<T: Clone> Clone for Lock<T> {
464519
}
465520
}
466521

522+
// Just for speed test
523+
unsafe impl<T: Send> std::marker::Send for Lock<T> {}
524+
unsafe impl<T: Send> std::marker::Sync for Lock<T> {}
525+
526+
pub struct LockGuard<'a, T> {
527+
value: NonNull<T>,
528+
lock: &'a Lock<T>,
529+
marker: PhantomData<&'a mut T>,
530+
}
531+
532+
impl<T> const Deref for LockGuard<'_, T> {
533+
type Target = T;
534+
535+
fn deref(&self) -> &T {
536+
unsafe { self.value.as_ref() }
537+
}
538+
}
539+
540+
impl<T> const DerefMut for LockGuard<'_, T> {
541+
fn deref_mut(&mut self) -> &mut T {
542+
unsafe { self.value.as_mut() }
543+
}
544+
}
545+
546+
#[inline(never)]
547+
unsafe fn unlock_mt<T>(guard: &mut LockGuard<'_, T>) {
548+
guard.lock.mutex.unlock()
549+
}
550+
551+
impl<'a, T> Drop for LockGuard<'a, T> {
552+
#[inline]
553+
fn drop(&mut self) {
554+
if likely(self.lock.single_thread) {
555+
debug_assert!(self.lock.borrow.get());
556+
self.lock.borrow.set(false);
557+
} else {
558+
unsafe { unlock_mt(self) }
559+
}
560+
}
561+
}
562+
467563
#[derive(Debug, Default)]
468564
pub struct RwLock<T>(InnerRwLock<T>);
469565

0 commit comments

Comments
 (0)