Skip to content

Commit

Permalink
Optimize callback creation
Browse files Browse the repository at this point in the history
Attach only one upvalue to callbacks rather than two.
This leads to less lookup to Lua registry.
  • Loading branch information
khvzak committed Jun 30, 2021
1 parent fc84e86 commit 41aae83
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 83 deletions.
17 changes: 17 additions & 0 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ fn create_string_table(c: &mut Criterion) {
});
}

fn create_function(c: &mut Criterion) {
let lua = Lua::new();

c.bench_function("create [function] 10", |b| {
b.iter_batched(
|| collect_gc_twice(&lua),
|_| {
for i in 0..10 {
lua.create_function(move |_, ()| Ok(i)).unwrap();
}
},
BatchSize::SmallInput,
);
});
}

fn call_lua_function(c: &mut Criterion) {
let lua = Lua::new();

Expand Down Expand Up @@ -258,6 +274,7 @@ criterion_group! {
create_table,
create_array,
create_string_table,
create_function,
call_lua_function,
call_sum_callback,
call_async_sum_callback,
Expand Down
119 changes: 62 additions & 57 deletions src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use crate::string::String;
use crate::table::Table;
use crate::thread::Thread;
use crate::types::{
Callback, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number, RegistryKey,
Callback, CallbackUpvalue, HookCallback, Integer, LightUserData, LuaRef, MaybeSend, Number,
RegistryKey,
};
use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods,
Expand All @@ -38,7 +39,7 @@ use std::rc::Rc;

#[cfg(feature = "async")]
use {
crate::types::AsyncCallback,
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
futures_core::{
future::{Future, LocalBoxFuture},
task::{Context, Poll, Waker},
Expand Down Expand Up @@ -396,12 +397,13 @@ impl Lua {
// to prevent them from being garbage collected.

init_gc_metatable_for::<Callback>(state, None)?;
init_gc_metatable_for::<Lua>(state, None)?;
init_gc_metatable_for::<CallbackUpvalue>(state, None)?;
init_gc_metatable_for::<Weak<Mutex<ExtraData>>>(state, None)?;
#[cfg(feature = "async")]
{
init_gc_metatable_for::<AsyncCallback>(state, None)?;
init_gc_metatable_for::<LocalBoxFuture<Result<MultiValue>>>(state, None)?;
init_gc_metatable_for::<AsyncCallbackUpvalue>(state, None)?;
init_gc_metatable_for::<AsyncPollUpvalue>(state, None)?;
init_gc_metatable_for::<Option<Waker>>(state, None)?;

// Create empty Waker slot
Expand Down Expand Up @@ -1777,22 +1779,22 @@ impl Lua {
'lua: 'callback,
{
unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int {
callback_error2(state, |nargs| {
let upvalue_idx1 = ffi::lua_upvalueindex(1);
let upvalue_idx2 = ffi::lua_upvalueindex(2);
if ffi::lua_type(state, upvalue_idx1) == ffi::LUA_TNIL
|| ffi::lua_type(state, upvalue_idx2) == ffi::LUA_TNIL
{
let get_extra = |state| {
let upvalue = get_userdata::<CallbackUpvalue>(state, ffi::lua_upvalueindex(1));
(*upvalue).lua.extra.clone()
};
callback_error_ext(state, get_extra, |nargs| {
let upvalue_idx = ffi::lua_upvalueindex(1);
if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL {
return Err(Error::CallbackDestructed);
}
let func = get_userdata::<Callback>(state, upvalue_idx1);
let lua = get_userdata::<Lua>(state, upvalue_idx2);
let upvalue = get_userdata::<CallbackUpvalue>(state, upvalue_idx);

if nargs < ffi::LUA_MINSTACK {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}

let lua = &mut *lua;
let lua = &mut (*upvalue).lua;
lua.state = state;

let mut args = MultiValue::new();
Expand All @@ -1801,7 +1803,7 @@ impl Lua {
args.push_front(lua.pop_value());
}

let results = (*func)(lua, args)?;
let results = ((*upvalue).func)(lua, args)?;
let nresults = results.len() as c_int;

check_stack(state, nresults)?;
Expand All @@ -1815,12 +1817,13 @@ impl Lua {

unsafe {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 5)?;
check_stack(self.state, 4)?;

push_gc_userdata::<Callback>(self.state, mem::transmute(func))?;
push_gc_userdata(self.state, self.clone())?;
protect_lua(self.state, 2, 1, |state| {
ffi::lua_pushcclosure(state, call_callback, 2);
let lua = self.clone();
let func = mem::transmute(func);
push_gc_userdata(self.state, CallbackUpvalue { lua, func })?;
protect_lua(self.state, 1, 1, |state| {
ffi::lua_pushcclosure(state, call_callback, 1);
})?;

Ok(Function(self.pop_ref()))
Expand All @@ -1844,22 +1847,22 @@ impl Lua {
}

unsafe extern "C" fn call_callback(state: *mut ffi::lua_State) -> c_int {
callback_error2(state, |nargs| {
let upvalue_idx1 = ffi::lua_upvalueindex(1);
let upvalue_idx2 = ffi::lua_upvalueindex(2);
if ffi::lua_type(state, upvalue_idx1) == ffi::LUA_TNIL
|| ffi::lua_type(state, upvalue_idx2) == ffi::LUA_TNIL
{
let get_extra = |state| {
let upvalue = get_userdata::<AsyncCallbackUpvalue>(state, ffi::lua_upvalueindex(1));
(*upvalue).lua.extra.clone()
};
callback_error_ext(state, get_extra, |nargs| {
let upvalue_idx = ffi::lua_upvalueindex(1);
if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL {
return Err(Error::CallbackDestructed);
}
let func = get_userdata::<AsyncCallback>(state, upvalue_idx1);
let lua = get_userdata::<Lua>(state, upvalue_idx2);
let upvalue = get_userdata::<AsyncCallbackUpvalue>(state, upvalue_idx);

if nargs < ffi::LUA_MINSTACK {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}

let lua = &mut *lua;
let lua = &mut (*upvalue).lua;
lua.state = state;

let mut args = MultiValue::new();
Expand All @@ -1868,35 +1871,34 @@ impl Lua {
args.push_front(lua.pop_value());
}

let fut = (*func)(lua, args);
push_gc_userdata(state, fut)?;
push_gc_userdata(state, lua.clone())?;

protect_lua(state, 2, 1, |state| {
ffi::lua_pushcclosure(state, poll_future, 2);
let fut = ((*upvalue).func)(lua, args);
let lua = lua.clone();
push_gc_userdata(state, AsyncPollUpvalue { lua, fut })?;
protect_lua(state, 1, 1, |state| {
ffi::lua_pushcclosure(state, poll_future, 1);
})?;

Ok(1)
})
}

unsafe extern "C" fn poll_future(state: *mut ffi::lua_State) -> c_int {
callback_error2(state, |nargs| {
let upvalue_idx1 = ffi::lua_upvalueindex(1);
let upvalue_idx2 = ffi::lua_upvalueindex(2);
if ffi::lua_type(state, upvalue_idx1) == ffi::LUA_TNIL
|| ffi::lua_type(state, upvalue_idx2) == ffi::LUA_TNIL
{
let get_extra = |state| {
let upvalue = get_userdata::<AsyncPollUpvalue>(state, ffi::lua_upvalueindex(1));
(*upvalue).lua.extra.clone()
};
callback_error_ext(state, get_extra, |nargs| {
let upvalue_idx = ffi::lua_upvalueindex(1);
if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL {
return Err(Error::CallbackDestructed);
}
let fut = get_userdata::<LocalBoxFuture<Result<MultiValue>>>(state, upvalue_idx1);
let lua = get_userdata::<Lua>(state, upvalue_idx2);
let upvalue = get_userdata::<AsyncPollUpvalue>(state, upvalue_idx);

if nargs < ffi::LUA_MINSTACK {
check_stack(state, ffi::LUA_MINSTACK - nargs)?;
}

let lua = &mut *lua;
let lua = &mut (*upvalue).lua;
lua.state = state;

// Try to get an outer poll waker
Expand All @@ -1910,7 +1912,8 @@ impl Lua {

let mut ctx = Context::from_waker(&waker);

match (*fut).as_mut().poll(&mut ctx) {
let fut = &mut (*upvalue).fut;
match fut.as_mut().poll(&mut ctx) {
Poll::Pending => {
check_stack(state, 1)?;
ffi::lua_pushboolean(state, 0);
Expand All @@ -1932,12 +1935,13 @@ impl Lua {

let get_poll = unsafe {
let _sg = StackGuard::new(self.state);
check_stack(self.state, 5)?;
check_stack(self.state, 4)?;

push_gc_userdata::<AsyncCallback>(self.state, mem::transmute(func))?;
push_gc_userdata(self.state, self.clone())?;
protect_lua(self.state, 2, 1, |state| {
ffi::lua_pushcclosure(state, call_callback, 2);
let lua = self.clone();
let func = mem::transmute(func);
push_gc_userdata(self.state, AsyncCallbackUpvalue { lua, func })?;
protect_lua(self.state, 1, 1, |state| {
ffi::lua_pushcclosure(state, call_callback, 1);
})?;

Function(self.pop_ref())
Expand Down Expand Up @@ -2287,13 +2291,14 @@ impl<'lua, T: AsRef<[u8]> + ?Sized> AsChunk<'lua> for T {

// An optimized version of `callback_error` that does not allocate `WrappedError+Panic` userdata
// and instead reuses unsed and cached values from previous calls (or allocates new).
// It assumes that ephemeral `Lua` struct is passed as a 2nd upvalue.
pub unsafe fn callback_error2<F, R>(state: *mut ffi::lua_State, f: F) -> R
// It requires `get_extra` function to return `ExtraData` value.
unsafe fn callback_error_ext<E, F, R>(state: *mut ffi::lua_State, get_extra: E, f: F) -> R
where
E: Fn(*mut ffi::lua_State) -> Arc<Mutex<ExtraData>>,
F: FnOnce(c_int) -> Result<R>,
{
let upvalue_idx2 = ffi::lua_upvalueindex(2);
if ffi::lua_type(state, upvalue_idx2) == ffi::LUA_TNIL {
let upvalue_idx = ffi::lua_upvalueindex(1);
if ffi::lua_type(state, upvalue_idx) == ffi::LUA_TNIL {
return callback_error(state, f);
}

Expand All @@ -2314,9 +2319,9 @@ where

// We cannot shadow Rust errors with Lua ones, so we need to obtain pre-allocated memory
// to store a wrapped error or panic *before* we proceed.
let lua = get_userdata::<Lua>(state, upvalue_idx2);
let extra = get_extra(state);
let prealloc_err = {
let mut extra = mlua_expect!((*lua).extra.lock(), "extra is poisoned");
let mut extra = mlua_expect!(extra.lock(), "extra is poisoned");
match extra.prealloc_wrapped_errors.pop() {
Some(index) => PreallocatedError::Cached(index),
None => {
Expand All @@ -2329,7 +2334,7 @@ where
};

let get_prealloc_err = || {
let mut extra = mlua_expect!((*lua).extra.lock(), "extra is poisoned");
let mut extra = mlua_expect!(extra.lock(), "extra is poisoned");
match prealloc_err {
PreallocatedError::New(ud) => {
ffi::lua_settop(state, 1);
Expand All @@ -2350,7 +2355,7 @@ where
match catch_unwind(AssertUnwindSafe(|| f(nargs))) {
Ok(Ok(r)) => {
// Return unused WrappedError+Panic to the cache
let mut extra = mlua_expect!((*lua).extra.lock(), "extra is poisoned");
let mut extra = mlua_expect!(extra.lock(), "extra is poisoned");
match prealloc_err {
PreallocatedError::New(_) if extra.prealloc_wrapped_errors.len() < 16 => {
ffi::lua_rotate(state, 1, -1);
Expand Down
36 changes: 10 additions & 26 deletions src/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::error::{Error, Result};
use crate::ffi;
use crate::function::Function;
use crate::lua::Lua;
use crate::types::{Callback, LuaRef, MaybeSend};
use crate::types::{Callback, CallbackUpvalue, LuaRef, MaybeSend};
use crate::userdata::{
AnyUserData, MetaMethod, UserData, UserDataCell, UserDataFields, UserDataMethods,
};
Expand All @@ -25,8 +25,8 @@ use crate::value::{FromLua, FromLuaMulti, MultiValue, ToLua, ToLuaMulti, Value};

#[cfg(feature = "async")]
use {
crate::types::AsyncCallback,
futures_core::future::{Future, LocalBoxFuture},
crate::types::{AsyncCallback, AsyncCallbackUpvalue, AsyncPollUpvalue},
futures_core::future::Future,
futures_util::future::{self, TryFutureExt},
};

Expand Down Expand Up @@ -224,7 +224,7 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
/// use [`Scope::create_userdata`] instead.
///
/// The main limitation that comes from using non-'static userdata is that the produced userdata
/// will no longer have a `TypeId` associated with it, becuase `TypeId` can only work for
/// will no longer have a `TypeId` associated with it, because `TypeId` can only work for
/// 'static types. This means that it is impossible, once the userdata is created, to get a
/// reference to it back *out* of an `AnyUserData` handle. This also implies that the
/// "function" type methods that can be added via [`UserDataMethods`] (the ones that accept
Expand Down Expand Up @@ -460,16 +460,11 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {
// We know the destructor has not run yet because we hold a reference to the callback.

ffi::lua_getupvalue(state, -1, 1);
let ud1 = take_userdata::<Callback>(state);
let ud = take_userdata::<CallbackUpvalue>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 1);

ffi::lua_getupvalue(state, -1, 2);
let ud2 = take_userdata::<Lua>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 2);

vec![Box::new(ud1), Box::new(ud2)]
vec![Box::new(ud)]
});
self.destructors
.borrow_mut()
Expand Down Expand Up @@ -510,32 +505,21 @@ impl<'lua, 'scope> Scope<'lua, 'scope> {

// Destroy all upvalues
ffi::lua_getupvalue(state, -1, 1);
let ud1 = take_userdata::<AsyncCallback>(state);
let upvalue1 = take_userdata::<AsyncCallbackUpvalue>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 1);

ffi::lua_getupvalue(state, -1, 2);
let ud2 = take_userdata::<Lua>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 2);

ffi::lua_pop(state, 1);
let mut data: Vec<Box<dyn Any>> = vec![Box::new(ud1), Box::new(ud2)];
let mut data: Vec<Box<dyn Any>> = vec![Box::new(upvalue1)];

// Finally, get polled future and destroy it
f.lua.push_ref(&poll_str.0);
if ffi::lua_rawget(state, -2) == ffi::LUA_TFUNCTION {
ffi::lua_getupvalue(state, -1, 1);
let ud3 = take_userdata::<LocalBoxFuture<Result<MultiValue>>>(state);
let upvalue2 = take_userdata::<AsyncPollUpvalue>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 1);
data.push(Box::new(ud3));

ffi::lua_getupvalue(state, -1, 2);
let ud4 = take_userdata::<Lua>(state);
ffi::lua_pushnil(state);
ffi::lua_setupvalue(state, -2, 2);
data.push(Box::new(ud4));
data.push(Box::new(upvalue2));
}

data
Expand Down
17 changes: 17 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,27 @@ pub struct LightUserData(pub *mut c_void);
pub(crate) type Callback<'lua, 'a> =
Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> Result<MultiValue<'lua>> + 'a>;

pub(crate) struct CallbackUpvalue<'lua> {
pub(crate) lua: Lua,
pub(crate) func: Callback<'lua, 'static>,
}

#[cfg(feature = "async")]
pub(crate) type AsyncCallback<'lua, 'a> =
Box<dyn Fn(&'lua Lua, MultiValue<'lua>) -> LocalBoxFuture<'lua, Result<MultiValue<'lua>>> + 'a>;

#[cfg(feature = "async")]
pub(crate) struct AsyncCallbackUpvalue<'lua> {
pub(crate) lua: Lua,
pub(crate) func: AsyncCallback<'lua, 'static>,
}

#[cfg(feature = "async")]
pub(crate) struct AsyncPollUpvalue<'lua> {
pub(crate) lua: Lua,
pub(crate) fut: LocalBoxFuture<'lua, Result<MultiValue<'lua>>>,
}

pub(crate) type HookCallback = Arc<RefCell<dyn FnMut(&Lua, Debug) -> Result<()>>>;

#[cfg(feature = "send")]
Expand Down

0 comments on commit 41aae83

Please sign in to comment.