Skip to content

Commit bb31134

Browse files
committed
Switch between shared and exclusive lock for UserDataRef depending if T: Sync or not.
1 parent 928e1d9 commit bb31134

File tree

5 files changed

+120
-20
lines changed

5 files changed

+120
-20
lines changed

src/userdata.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,7 @@ mod cell;
10821082
mod lock;
10831083
mod object;
10841084
mod registry;
1085+
mod util;
10851086

10861087
#[cfg(test)]
10871088
mod assertions {

src/userdata/cell.rs

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use crate::util::get_userdata;
1616
use crate::value::Value;
1717

1818
use super::lock::{RawLock, UserDataLock};
19+
use super::util::is_sync;
1920

2021
#[cfg(all(feature = "serialize", not(feature = "send")))]
2122
type DynSerialize = dyn erased_serde::Serialize;
@@ -164,7 +165,11 @@ impl<T> Deref for UserDataRef<T> {
164165
impl<T> Drop for UserDataRef<T> {
165166
#[inline]
166167
fn drop(&mut self) {
167-
unsafe { self.0.raw_lock().unlock_shared() };
168+
if !cfg!(feature = "send") || is_sync::<T>() {
169+
unsafe { self.0.raw_lock().unlock_shared() };
170+
} else {
171+
unsafe { self.0.raw_lock().unlock_exclusive() };
172+
}
168173
}
169174
}
170175

@@ -185,7 +190,11 @@ impl<T> TryFrom<UserDataVariant<T>> for UserDataRef<T> {
185190

186191
#[inline]
187192
fn try_from(variant: UserDataVariant<T>) -> Result<Self> {
188-
if !variant.raw_lock().try_lock_shared() {
193+
if !cfg!(feature = "send") || is_sync::<T>() {
194+
if !variant.raw_lock().try_lock_shared() {
195+
return Err(Error::UserDataBorrowError);
196+
}
197+
} else if !variant.raw_lock().try_lock_exclusive() {
189198
return Err(Error::UserDataBorrowError);
190199
}
191200
Ok(UserDataRef(variant))
@@ -282,7 +291,11 @@ pub(crate) struct UserDataBorrowRef<'a, T>(&'a UserDataVariant<T>);
282291
impl<T> Drop for UserDataBorrowRef<'_, T> {
283292
#[inline]
284293
fn drop(&mut self) {
285-
unsafe { self.0.raw_lock().unlock_shared() };
294+
if !cfg!(feature = "send") || is_sync::<T>() {
295+
unsafe { self.0.raw_lock().unlock_shared() };
296+
} else {
297+
unsafe { self.0.raw_lock().unlock_exclusive() };
298+
}
286299
}
287300
}
288301

@@ -301,7 +314,11 @@ impl<'a, T> TryFrom<&'a UserDataVariant<T>> for UserDataBorrowRef<'a, T> {
301314

302315
#[inline(always)]
303316
fn try_from(variant: &'a UserDataVariant<T>) -> Result<Self> {
304-
if !variant.raw_lock().try_lock_shared() {
317+
if !cfg!(feature = "send") || is_sync::<T>() {
318+
if !variant.raw_lock().try_lock_shared() {
319+
return Err(Error::UserDataBorrowError);
320+
}
321+
} else if !variant.raw_lock().try_lock_exclusive() {
305322
return Err(Error::UserDataBorrowError);
306323
}
307324
Ok(UserDataBorrowRef(variant))

src/userdata/lock.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,32 +63,32 @@ mod lock_impl {
6363

6464
#[cfg(feature = "send")]
6565
mod lock_impl {
66-
use parking_lot::lock_api::RawMutex;
66+
use parking_lot::lock_api::RawRwLock;
6767

68-
pub(crate) type RawLock = parking_lot::RawMutex;
68+
pub(crate) type RawLock = parking_lot::RawRwLock;
6969

7070
impl super::UserDataLock for RawLock {
7171
#[allow(clippy::declare_interior_mutable_const)]
72-
const INIT: Self = <Self as parking_lot::lock_api::RawMutex>::INIT;
72+
const INIT: Self = <Self as parking_lot::lock_api::RawRwLock>::INIT;
7373

7474
#[inline(always)]
7575
fn try_lock_shared(&self) -> bool {
76-
RawLock::try_lock(self)
76+
RawRwLock::try_lock_shared(self)
7777
}
7878

7979
#[inline(always)]
8080
fn try_lock_exclusive(&self) -> bool {
81-
RawLock::try_lock(self)
81+
RawRwLock::try_lock_exclusive(self)
8282
}
8383

8484
#[inline(always)]
8585
unsafe fn unlock_shared(&self) {
86-
RawLock::unlock(self)
86+
RawRwLock::unlock_shared(self)
8787
}
8888

8989
#[inline(always)]
9090
unsafe fn unlock_exclusive(&self) {
91-
RawLock::unlock(self)
91+
RawRwLock::unlock_exclusive(self)
9292
}
9393
}
9494
}

src/userdata/util.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use std::cell::Cell;
2+
use std::marker::PhantomData;
3+
4+
// This is a trick to check if a type is `Sync` or not.
5+
// It uses leaked specialization feature from stdlib.
6+
struct IsSync<'a, T> {
7+
is_sync: &'a Cell<bool>,
8+
_marker: PhantomData<T>,
9+
}
10+
11+
impl<T> Clone for IsSync<'_, T> {
12+
fn clone(&self) -> Self {
13+
self.is_sync.set(false);
14+
IsSync {
15+
is_sync: self.is_sync,
16+
_marker: PhantomData,
17+
}
18+
}
19+
}
20+
21+
impl<T: Sync> Copy for IsSync<'_, T> {}
22+
23+
pub(crate) fn is_sync<T>() -> bool {
24+
let is_sync = Cell::new(true);
25+
let _ = [IsSync::<T> {
26+
is_sync: &is_sync,
27+
_marker: PhantomData,
28+
}]
29+
.clone();
30+
is_sync.get()
31+
}

tests/send.rs

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,36 @@ use std::cell::UnsafeCell;
44
use std::marker::PhantomData;
55
use std::string::String as StdString;
66

7-
use mlua::{AnyUserData, Error, Lua, Result, UserDataRef};
7+
use mlua::{AnyUserData, Error, Lua, ObjectLike, Result, UserData, UserDataMethods, UserDataRef};
88
use static_assertions::{assert_impl_all, assert_not_impl_all};
99

1010
#[test]
11-
fn test_userdata_multithread_access() -> Result<()> {
11+
fn test_userdata_multithread_access_send_only() -> Result<()> {
1212
let lua = Lua::new();
1313

1414
// This type is `Send` but not `Sync`.
15-
struct MyUserData(#[allow(unused)] StdString, PhantomData<UnsafeCell<()>>);
16-
15+
struct MyUserData(StdString, PhantomData<UnsafeCell<()>>);
1716
assert_impl_all!(MyUserData: Send);
1817
assert_not_impl_all!(MyUserData: Sync);
1918

20-
lua.globals().set(
21-
"ud",
22-
AnyUserData::wrap(MyUserData("hello".to_string(), PhantomData)),
23-
)?;
19+
impl UserData for MyUserData {
20+
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
21+
methods.add_method("method", |lua, this, ()| {
22+
let ud = lua.globals().get::<AnyUserData>("ud")?;
23+
assert!((ud.call_method::<()>("method2", ()).err().unwrap().to_string())
24+
.contains("error borrowing userdata"));
25+
Ok(this.0.clone())
26+
});
27+
28+
methods.add_method("method2", |_, _, ()| Ok(()));
29+
}
30+
}
31+
32+
lua.globals()
33+
.set("ud", MyUserData("hello".to_string(), PhantomData))?;
34+
2435
// We acquired the exclusive reference.
25-
let _ud1 = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;
36+
let ud = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;
2637

2738
std::thread::scope(|s| {
2839
s.spawn(|| {
@@ -31,5 +42,45 @@ fn test_userdata_multithread_access() -> Result<()> {
3142
});
3243
});
3344

45+
drop(ud);
46+
lua.load("ud:method()").exec().unwrap();
47+
48+
Ok(())
49+
}
50+
51+
#[test]
52+
fn test_userdata_multithread_access_sync() -> Result<()> {
53+
let lua = Lua::new();
54+
55+
// This type is `Send` and `Sync`.
56+
struct MyUserData(StdString);
57+
assert_impl_all!(MyUserData: Send, Sync);
58+
59+
impl UserData for MyUserData {
60+
fn add_methods<M: UserDataMethods<Self>>(methods: &mut M) {
61+
methods.add_method("method", |lua, this, ()| {
62+
let ud = lua.globals().get::<AnyUserData>("ud")?;
63+
assert!(ud.call_method::<()>("method2", ()).is_ok());
64+
Ok(this.0.clone())
65+
});
66+
67+
methods.add_method("method2", |_, _, ()| Ok(()));
68+
}
69+
}
70+
71+
lua.globals().set("ud", MyUserData("hello".to_string()))?;
72+
73+
// We acquired the shared reference.
74+
let _ud = lua.globals().get::<UserDataRef<MyUserData>>("ud")?;
75+
76+
std::thread::scope(|s| {
77+
s.spawn(|| {
78+
// Getting another shared reference for `Sync` type is allowed.
79+
let _ = lua.globals().get::<UserDataRef<MyUserData>>("ud").unwrap();
80+
});
81+
});
82+
83+
lua.load("ud:method()").exec().unwrap();
84+
3485
Ok(())
3586
}

0 commit comments

Comments
 (0)