Skip to content

Commit

Permalink
Switch between shared and exclusive lock for UserDataRef depending …
Browse files Browse the repository at this point in the history
…if `T: Sync` or not.
  • Loading branch information
khvzak committed Nov 2, 2024
1 parent 928e1d9 commit bb31134
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/userdata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,7 @@ mod cell;
mod lock;
mod object;
mod registry;
mod util;

#[cfg(test)]
mod assertions {
Expand Down
25 changes: 21 additions & 4 deletions src/userdata/cell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::util::get_userdata;
use crate::value::Value;

use super::lock::{RawLock, UserDataLock};
use super::util::is_sync;

#[cfg(all(feature = "serialize", not(feature = "send")))]
type DynSerialize = dyn erased_serde::Serialize;
Expand Down Expand Up @@ -164,7 +165,11 @@ impl<T> Deref for UserDataRef<T> {
impl<T> Drop for UserDataRef<T> {
#[inline]
fn drop(&mut self) {
unsafe { self.0.raw_lock().unlock_shared() };
if !cfg!(feature = "send") || is_sync::<T>() {
unsafe { self.0.raw_lock().unlock_shared() };
} else {
unsafe { self.0.raw_lock().unlock_exclusive() };
}
}
}

Expand All @@ -185,7 +190,11 @@ impl<T> TryFrom<UserDataVariant<T>> for UserDataRef<T> {

#[inline]
fn try_from(variant: UserDataVariant<T>) -> Result<Self> {
if !variant.raw_lock().try_lock_shared() {
if !cfg!(feature = "send") || is_sync::<T>() {
if !variant.raw_lock().try_lock_shared() {
return Err(Error::UserDataBorrowError);
}
} else if !variant.raw_lock().try_lock_exclusive() {
return Err(Error::UserDataBorrowError);
}
Ok(UserDataRef(variant))
Expand Down Expand Up @@ -282,7 +291,11 @@ pub(crate) struct UserDataBorrowRef<'a, T>(&'a UserDataVariant<T>);
impl<T> Drop for UserDataBorrowRef<'_, T> {
#[inline]
fn drop(&mut self) {
unsafe { self.0.raw_lock().unlock_shared() };
if !cfg!(feature = "send") || is_sync::<T>() {
unsafe { self.0.raw_lock().unlock_shared() };
} else {
unsafe { self.0.raw_lock().unlock_exclusive() };
}
}
}

Expand All @@ -301,7 +314,11 @@ impl<'a, T> TryFrom<&'a UserDataVariant<T>> for UserDataBorrowRef<'a, T> {

#[inline(always)]
fn try_from(variant: &'a UserDataVariant<T>) -> Result<Self> {
if !variant.raw_lock().try_lock_shared() {
if !cfg!(feature = "send") || is_sync::<T>() {
if !variant.raw_lock().try_lock_shared() {
return Err(Error::UserDataBorrowError);
}
} else if !variant.raw_lock().try_lock_exclusive() {
return Err(Error::UserDataBorrowError);
}
Ok(UserDataBorrowRef(variant))
Expand Down
14 changes: 7 additions & 7 deletions src/userdata/lock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,32 +63,32 @@ mod lock_impl {

#[cfg(feature = "send")]
mod lock_impl {
use parking_lot::lock_api::RawMutex;
use parking_lot::lock_api::RawRwLock;

pub(crate) type RawLock = parking_lot::RawMutex;
pub(crate) type RawLock = parking_lot::RawRwLock;

impl super::UserDataLock for RawLock {
#[allow(clippy::declare_interior_mutable_const)]
const INIT: Self = <Self as parking_lot::lock_api::RawMutex>::INIT;
const INIT: Self = <Self as parking_lot::lock_api::RawRwLock>::INIT;

#[inline(always)]
fn try_lock_shared(&self) -> bool {
RawLock::try_lock(self)
RawRwLock::try_lock_shared(self)
}

#[inline(always)]
fn try_lock_exclusive(&self) -> bool {
RawLock::try_lock(self)
RawRwLock::try_lock_exclusive(self)
}

#[inline(always)]
unsafe fn unlock_shared(&self) {
RawLock::unlock(self)
RawRwLock::unlock_shared(self)
}

#[inline(always)]
unsafe fn unlock_exclusive(&self) {
RawLock::unlock(self)
RawRwLock::unlock_exclusive(self)
}
}
}
31 changes: 31 additions & 0 deletions src/userdata/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use std::cell::Cell;
use std::marker::PhantomData;

// This is a trick to check if a type is `Sync` or not.
// It uses leaked specialization feature from stdlib.
struct IsSync<'a, T> {
is_sync: &'a Cell<bool>,
_marker: PhantomData<T>,
}

impl<T> Clone for IsSync<'_, T> {
fn clone(&self) -> Self {
self.is_sync.set(false);
IsSync {
is_sync: self.is_sync,
_marker: PhantomData,
}
}
}

impl<T: Sync> Copy for IsSync<'_, T> {}

pub(crate) fn is_sync<T>() -> bool {
let is_sync = Cell::new(true);
let _ = [IsSync::<T> {
is_sync: &is_sync,
_marker: PhantomData,
}]
.clone();
is_sync.get()
}
69 changes: 60 additions & 9 deletions tests/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,36 @@ use std::cell::UnsafeCell;
use std::marker::PhantomData;
use std::string::String as StdString;

use mlua::{AnyUserData, Error, Lua, Result, UserDataRef};
use mlua::{AnyUserData, Error, Lua, ObjectLike, Result, UserData, UserDataMethods, UserDataRef};
use static_assertions::{assert_impl_all, assert_not_impl_all};

#[test]
fn test_userdata_multithread_access() -> Result<()> {
fn test_userdata_multithread_access_send_only() -> Result<()> {
let lua = Lua::new();

// This type is `Send` but not `Sync`.
struct MyUserData(#[allow(unused)] StdString, PhantomData<UnsafeCell<()>>);

struct MyUserData(StdString, PhantomData<UnsafeCell<()>>);
assert_impl_all!(MyUserData: Send);
assert_not_impl_all!(MyUserData: Sync);

lua.globals().set(
"ud",
AnyUserData::wrap(MyUserData("hello".to_string(), PhantomData)),
)?;
impl UserData for MyUserData {
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
methods.add_method("method", |lua, this, ()| {
let ud = lua.globals().get::<AnyUserData>("ud")?;
assert!((ud.call_method::<()>("method2", ()).err().unwrap().to_string())
.contains("error borrowing userdata"));
Ok(this.0.clone())
});

methods.add_method("method2", |_, _, ()| Ok(()));
}
}

lua.globals()
.set("ud", MyUserData("hello".to_string(), PhantomData))?;

// We acquired the exclusive reference.
let _ud1 = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;
let ud = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;

std::thread::scope(|s| {
s.spawn(|| {
Expand All @@ -31,5 +42,45 @@ fn test_userdata_multithread_access() -> Result<()> {
});
});

drop(ud);
lua.load("ud:method()").exec().unwrap();

Ok(())
}

#[test]
fn test_userdata_multithread_access_sync() -> Result<()> {
let lua = Lua::new();

// This type is `Send` and `Sync`.
struct MyUserData(StdString);
assert_impl_all!(MyUserData: Send, Sync);

impl UserData for MyUserData {
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
methods.add_method("method", |lua, this, ()| {
let ud = lua.globals().get::<AnyUserData>("ud")?;
assert!(ud.call_method::<()>("method2", ()).is_ok());
Ok(this.0.clone())
});

methods.add_method("method2", |_, _, ()| Ok(()));
}
}

lua.globals().set("ud", MyUserData("hello".to_string()))?;

// We acquired the shared reference.
let _ud = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;

std::thread::scope(|s| {
s.spawn(|| {
// Getting another shared reference for `Sync` type is allowed.
let _ = lua.globals().get::<UserDataRef<MyUserData>>("ud").unwrap();
});
});

lua.load("ud:method()").exec().unwrap();

Ok(())
}

0 comments on commit bb31134

Please sign in to comment.