Skip to content

Commit

Permalink
Support multi threads under send feature flag
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Jul 31, 2024
1 parent c1395ab commit 658f2a1
Show file tree
Hide file tree
Showing 13 changed files with 233 additions and 113 deletions.
15 changes: 9 additions & 6 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,12 @@ impl IntoLua for WrappedAsyncFunction {
}
}

// #[cfg(test)]
// mod assertions {
// use super::*;

// static_assertions::assert_not_impl_any!(Function: Send);
// }
#[cfg(test)]
mod assertions {
use super::*;

#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(Function: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Function: Send, Sync);
}
2 changes: 1 addition & 1 deletion src/hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use std::ops::{BitOr, BitOrAssign};
use std::os::raw::c_int;

use ffi::lua_Debug;
use parking_lot::ReentrantMutexGuard;

use crate::state::RawLua;
use crate::types::ReentrantMutexGuard;
use crate::util::{linenumber_to_usize, ptr_to_lossy_str, ptr_to_str};

/// Contains information about currently executing Lua code.
Expand Down
61 changes: 31 additions & 30 deletions src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::os::raw::{c_int, c_void};
use std::panic::Location;
use std::rc::Rc;
use std::result::Result as StdResult;
use std::sync::{Arc, Weak};
use std::{mem, ptr};

use parking_lot::{ReentrantMutex, ReentrantMutexGuard};

use crate::chunk::{AsChunk, Chunk};
use crate::error::{Error, Result};
use crate::function::Function;
Expand All @@ -24,7 +22,7 @@ use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, ArcReentrantMutexGuard, Integer, LightUserData, MaybeSend, Number,
RegistryKey,
ReentrantMutex, ReentrantMutexGuard, RegistryKey, XRc, XWeak,
};
use crate::userdata::{AnyUserData, UserData, UserDataProxy, UserDataRegistry, UserDataVariant};
use crate::util::{assert_stack, check_stack, push_string, push_table, rawset_field, StackGuard};
Expand All @@ -49,11 +47,11 @@ use util::{callback_error_ext, StateGuard};
/// Top level Lua struct which represents an instance of Lua VM.
#[derive(Clone)]
#[repr(transparent)]
pub struct Lua(Arc<ReentrantMutex<RawLua>>);
pub struct Lua(XRc<ReentrantMutex<RawLua>>);

#[derive(Clone)]
#[repr(transparent)]
pub(crate) struct WeakLua(Weak<ReentrantMutex<RawLua>>);
pub(crate) struct WeakLua(XWeak<ReentrantMutex<RawLua>>);

pub(crate) struct LuaGuard(ArcReentrantMutexGuard<RawLua>);

Expand Down Expand Up @@ -142,11 +140,6 @@ impl LuaOptions {
}
}

/// Requires `feature = "send"`
#[cfg(feature = "send")]
#[cfg_attr(docsrs, doc(cfg(feature = "send")))]
unsafe impl Send for Lua {}

#[cfg(not(feature = "module"))]
impl Drop for Lua {
fn drop(&mut self) {
Expand Down Expand Up @@ -421,7 +414,8 @@ impl Lua {
#[doc(hidden)]
#[cfg(feature = "module")]
pub fn skip_memory_check(&self, skip: bool) {
unsafe { (*self.extra.get()).skip_memory_check = skip };
let lua = self.lock();
unsafe { (*lua.extra.get()).skip_memory_check = skip };
}

/// Enables (or disables) sandbox mode on this Lua instance.
Expand Down Expand Up @@ -605,7 +599,7 @@ impl Lua {
let interrupt_cb = (*extra).interrupt_callback.clone();
let interrupt_cb =
mlua_expect!(interrupt_cb, "no interrupt callback set in interrupt_proc");
if Arc::strong_count(&interrupt_cb) > 2 {
if Rc::strong_count(&interrupt_cb) > 2 {
return Ok(VmState::Continue); // Don't allow recursion
}
let _guard = StateGuard::new((*extra).raw_lua(), state);
Expand All @@ -622,7 +616,7 @@ impl Lua {
// Set interrupt callback
let lua = self.lock();
unsafe {
(*lua.extra.get()).interrupt_callback = Some(Arc::new(callback));
(*lua.extra.get()).interrupt_callback = Some(Rc::new(callback));
(*ffi::lua_callbacks(lua.main_state)).interrupt = Some(interrupt_proc);
}
}
Expand Down Expand Up @@ -947,7 +941,8 @@ impl Lua {
#[cfg(any(feature = "luau-jit", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau-jit")))]
pub fn enable_jit(&self, enable: bool) {
unsafe { (*self.extra.get()).enable_jit = enable };
let lua = self.lock();
unsafe { (*lua.extra.get()).enable_jit = enable };
}

/// Sets Luau feature flag (global setting).
Expand Down Expand Up @@ -1879,15 +1874,15 @@ impl Lua {

#[inline(always)]
pub(crate) fn weak(&self) -> WeakLua {
WeakLua(Arc::downgrade(&self.0))
WeakLua(XRc::downgrade(&self.0))
}
}

impl WeakLua {
#[track_caller]
#[inline(always)]
pub(crate) fn lock(&self) -> LuaGuard {
LuaGuard::new(self.0.upgrade().unwrap())
LuaGuard::new(self.0.upgrade().expect("Lua instance is destroyed"))
}

#[inline(always)]
Expand All @@ -1898,15 +1893,21 @@ impl WeakLua {

impl PartialEq for WeakLua {
fn eq(&self, other: &Self) -> bool {
Weak::ptr_eq(&self.0, &other.0)
XWeak::ptr_eq(&self.0, &other.0)
}
}

impl Eq for WeakLua {}

impl LuaGuard {
pub(crate) fn new(handle: Arc<ReentrantMutex<RawLua>>) -> Self {
Self(handle.lock_arc())
#[cfg(feature = "send")]
pub(crate) fn new(handle: XRc<ReentrantMutex<RawLua>>) -> Self {
LuaGuard(handle.lock_arc())
}

#[cfg(not(feature = "send"))]
pub(crate) fn new(handle: XRc<ReentrantMutex<RawLua>>) -> Self {
LuaGuard(handle.into_lock_arc())
}
}

Expand All @@ -1922,15 +1923,15 @@ pub(crate) mod extra;
mod raw;
pub(crate) mod util;

// #[cfg(test)]
// mod assertions {
// use super::*;
#[cfg(test)]
mod assertions {
use super::*;

// // Lua has lots of interior mutability, should not be RefUnwindSafe
// static_assertions::assert_not_impl_any!(Lua: std::panic::RefUnwindSafe);
// Lua has lots of interior mutability, should not be RefUnwindSafe
static_assertions::assert_not_impl_any!(Lua: std::panic::RefUnwindSafe);

// #[cfg(not(feature = "send"))]
// static_assertions::assert_not_impl_any!(Lua: Send);
// #[cfg(feature = "send")]
// static_assertions::assert_impl_all!(Lua: Send);
// }
#[cfg(not(feature = "send"))]
static_assertions::assert_not_impl_any!(Lua: Send);
#[cfg(feature = "send")]
static_assertions::assert_impl_all!(Lua: Send, Sync);
}
26 changes: 13 additions & 13 deletions src/state/extra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use std::rc::Rc;
use std::mem::{self, MaybeUninit};
use std::os::raw::{c_int, c_void};
use std::ptr;
use std::sync::{Arc, Weak};
use std::sync::Arc;

use parking_lot::{Mutex, ReentrantMutex};
use parking_lot::Mutex;
use rustc_hash::FxHashMap;

use crate::error::Result;
use crate::state::RawLua;
use crate::stdlib::StdLib;
use crate::types::AppData;
use crate::types::{AppData, ReentrantMutex, XRc, XWeak};
use crate::util::{get_gc_metatable, push_gc_userdata, WrappedFailure};

#[cfg(any(feature = "luau", doc))]
Expand All @@ -34,9 +34,9 @@ const REF_STACK_RESERVE: c_int = 1;
/// Data associated with the Lua state.
pub(crate) struct ExtraData {
// Same layout as `Lua`
pub(super) lua: MaybeUninit<Arc<ReentrantMutex<RawLua>>>,
pub(super) lua: MaybeUninit<XRc<ReentrantMutex<RawLua>>>,
// Same layout as `WeakLua`
pub(super) weak: MaybeUninit<Weak<ReentrantMutex<RawLua>>>,
pub(super) weak: MaybeUninit<XWeak<ReentrantMutex<RawLua>>>,

pub(super) registered_userdata: FxHashMap<TypeId, c_int>,
pub(super) registered_userdata_mt: FxHashMap<*const c_void, Option<TypeId>>,
Expand Down Expand Up @@ -107,7 +107,7 @@ impl ExtraData {
#[cfg(any(feature = "lua51", feature = "luajit", feature = "luau"))]
pub(super) const ERROR_TRACEBACK_IDX: c_int = 1;

pub(super) unsafe fn init(state: *mut ffi::lua_State) -> Rc<UnsafeCell<Self>> {
pub(super) unsafe fn init(state: *mut ffi::lua_State) -> XRc<UnsafeCell<Self>> {
// Create ref stack thread and place it in the registry to prevent it
// from being garbage collected.
let ref_thread = mlua_expect!(
Expand All @@ -133,7 +133,7 @@ impl ExtraData {
assert_eq!(ffi::lua_gettop(ref_thread), Self::ERROR_TRACEBACK_IDX);
}

let extra = Rc::new(UnsafeCell::new(ExtraData {
let extra = XRc::new(UnsafeCell::new(ExtraData {
lua: MaybeUninit::uninit(),
weak: MaybeUninit::uninit(),
registered_userdata: FxHashMap::default(),
Expand Down Expand Up @@ -179,12 +179,12 @@ impl ExtraData {
extra
}

pub(super) unsafe fn set_lua(&mut self, lua: &Arc<ReentrantMutex<RawLua>>) {
self.lua.write(Arc::clone(lua));
pub(super) unsafe fn set_lua(&mut self, lua: &XRc<ReentrantMutex<RawLua>>) {
self.lua.write(XRc::clone(lua));
if cfg!(not(feature = "module")) {
Arc::decrement_strong_count(Arc::as_ptr(lua));
XRc::decrement_strong_count(XRc::as_ptr(lua));
}
self.weak.write(Arc::downgrade(lua));
self.weak.write(XRc::downgrade(lua));
}

pub(super) unsafe fn get(state: *mut ffi::lua_State) -> *mut Self {
Expand All @@ -206,14 +206,14 @@ impl ExtraData {
(*extra_ptr).get()
}

unsafe fn store(extra: &Rc<UnsafeCell<Self>>, state: *mut ffi::lua_State) -> Result<()> {
unsafe fn store(extra: &XRc<UnsafeCell<Self>>, state: *mut ffi::lua_State) -> Result<()> {
#[cfg(feature = "luau")]
if cfg!(not(feature = "module")) {
(*ffi::lua_callbacks(state)).userdata = extra.get() as *mut _;
return Ok(());
}

push_gc_userdata(state, Rc::clone(extra), true)?;
push_gc_userdata(state, XRc::clone(extra), true)?;
protect_lua!(state, 1, 0, fn(state) {
let extra_key = &EXTRA_REGISTRY_KEY as *const u8 as *const c_void;
ffi::lua_rawsetp(state, ffi::LUA_REGISTRYINDEX, extra_key);
Expand Down
31 changes: 16 additions & 15 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ use std::result::Result as StdResult;
use std::sync::Arc;
use std::{mem, ptr};

use parking_lot::ReentrantMutex;

use crate::chunk::ChunkMode;
use crate::error::{Error, Result};
use crate::function::Function;
Expand All @@ -21,7 +19,7 @@ use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
AppDataRef, AppDataRefMut, Callback, CallbackUpvalue, DestructedUserdata, Integer,
LightUserData, MaybeSend, RegistryKey, SubtypeId, ValueRef,
LightUserData, MaybeSend, ReentrantMutex, RegistryKey, SubtypeId, ValueRef, XRc,
};
use crate::userdata::{AnyUserData, MetaMethod, UserData, UserDataRegistry, UserDataVariant};
use crate::util::{
Expand Down Expand Up @@ -50,7 +48,7 @@ pub struct RawLua {
// The state is dynamic and depends on context
pub(super) state: Cell<*mut ffi::lua_State>,
pub(super) main_state: *mut ffi::lua_State,
pub(super) extra: Rc<UnsafeCell<ExtraData>>,
pub(super) extra: XRc<UnsafeCell<ExtraData>>,
}

#[cfg(not(feature = "module"))]
Expand All @@ -69,6 +67,9 @@ impl Drop for RawLua {
}
}

#[cfg(feature = "send")]
unsafe impl Send for RawLua {}

impl RawLua {
#[inline(always)]
pub(crate) fn lua(&self) -> &Lua {
Expand Down Expand Up @@ -96,7 +97,7 @@ impl RawLua {
unsafe { (*self.extra.get()).ref_thread }
}

pub(super) unsafe fn new(libs: StdLib, options: LuaOptions) -> Arc<ReentrantMutex<Self>> {
pub(super) unsafe fn new(libs: StdLib, options: LuaOptions) -> XRc<ReentrantMutex<Self>> {
let mem_state: *mut MemoryState = Box::into_raw(Box::default());
let mut state = ffi::lua_newstate(ALLOCATOR, mem_state as *mut c_void);
// If state is null then switch to Lua internal allocator
Expand Down Expand Up @@ -154,7 +155,7 @@ impl RawLua {
rawlua
}

pub(super) unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> Arc<ReentrantMutex<Self>> {
pub(super) unsafe fn init_from_ptr(state: *mut ffi::lua_State) -> XRc<ReentrantMutex<Self>> {
assert!(!state.is_null(), "Lua state is NULL");
if let Some(lua) = Self::try_from_ptr(state) {
return lua;
Expand Down Expand Up @@ -209,10 +210,10 @@ impl RawLua {
assert_stack(main_state, ffi::LUA_MINSTACK);

#[allow(clippy::arc_with_non_send_sync)]
let rawlua = Arc::new(ReentrantMutex::new(RawLua {
let rawlua = XRc::new(ReentrantMutex::new(RawLua {
state: Cell::new(state),
main_state,
extra: Rc::clone(&extra),
extra: XRc::clone(&extra),
}));
(*extra.get()).set_lua(&rawlua);

Expand All @@ -221,10 +222,10 @@ impl RawLua {

pub(super) unsafe fn try_from_ptr(
state: *mut ffi::lua_State,
) -> Option<Arc<ReentrantMutex<Self>>> {
) -> Option<XRc<ReentrantMutex<Self>>> {
match ExtraData::get(state) {
extra if extra.is_null() => None,
extra => Some(Arc::clone(&(*extra).lua().0)),
extra => Some(XRc::clone(&(*extra).lua().0)),
}
}

Expand Down Expand Up @@ -369,7 +370,7 @@ impl RawLua {
callback_error_ext(state, extra, move |_| {
let hook_cb = (*extra).hook_callback.clone();
let hook_cb = mlua_expect!(hook_cb, "no hook callback set in hook_proc");
if Arc::strong_count(&hook_cb) > 2 {
if Rc::strong_count(&hook_cb) > 2 {
return Ok(()); // Don't allow recursion
}
let rawlua = (*extra).raw_lua();
Expand All @@ -379,7 +380,7 @@ impl RawLua {
})
}

(*self.extra.get()).hook_callback = Some(Arc::new(callback));
(*self.extra.get()).hook_callback = Some(Rc::new(callback));
(*self.extra.get()).hook_thread = state; // Mark for what thread the hook is set
ffi::lua_sethook(state, Some(hook_proc), triggers.mask(), triggers.count());
}
Expand Down Expand Up @@ -1105,7 +1106,7 @@ impl RawLua {
check_stack(state, 4)?;

let func = mem::transmute::<Callback, Callback<'static>>(func);
let extra = Rc::clone(&self.extra);
let extra = XRc::clone(&self.extra);
let protect = !self.unlikely_memory_error();
push_gc_userdata(state, CallbackUpvalue { data: func, extra }, protect)?;
if protect {
Expand Down Expand Up @@ -1149,7 +1150,7 @@ impl RawLua {
let args = MultiValue::from_stack_multi(nargs, rawlua)?;
let func = &*(*upvalue).data;
let fut = func(rawlua, args);
let extra = Rc::clone(&(*upvalue).extra);
let extra = XRc::clone(&(*upvalue).extra);
let protect = !rawlua.unlikely_memory_error();
push_gc_userdata(state, AsyncPollUpvalue { data: fut, extra }, protect)?;
if protect {
Expand Down Expand Up @@ -1209,7 +1210,7 @@ impl RawLua {
check_stack(state, 4)?;

let func = mem::transmute::<AsyncCallback, AsyncCallback<'static>>(func);
let extra = Rc::clone(&self.extra);
let extra = XRc::clone(&self.extra);
let protect = !self.unlikely_memory_error();
let upvalue = AsyncCallbackUpvalue { data: func, extra };
push_gc_userdata(state, upvalue, protect)?;
Expand Down
Loading

0 comments on commit 658f2a1

Please sign in to comment.