Skip to content

Commit

Permalink
Rewrite Luau require function to support module loaders.
Browse files Browse the repository at this point in the history
Also add `package` library with `path`/`loaded`/`loaders`.
  • Loading branch information
khvzak committed Nov 16, 2023
1 parent 34476eb commit 2d77569
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 55 deletions.
2 changes: 1 addition & 1 deletion src/lua.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ impl Lua {
}

#[cfg(feature = "luau")]
mlua_expect!(lua.prepare_luau_state(), "Error preparing Luau state");
mlua_expect!(lua.prepare_luau_state(), "Error configuring Luau");

lua
}
Expand Down
192 changes: 141 additions & 51 deletions src/luau.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use std::ffi::CStr;
use std::fmt::Write;
use std::os::raw::{c_float, c_int};
use std::path::{PathBuf, MAIN_SEPARATOR_STR};
use std::string::String as StdString;
use std::{env, fs};

use crate::chunk::ChunkMode;
use crate::error::{Error, Result};
use crate::error::Result;
use crate::lua::Lua;
use crate::table::Table;
use crate::util::{check_stack, StackGuard};
use crate::value::Value;
use crate::types::RegistryKey;
use crate::value::{IntoLua, Value};

// Since Luau has some missing standard function, we re-implement them here

// We keep reference to the `package` table in registry under this key
struct PackageKey(RegistryKey);

impl Lua {
pub(crate) unsafe fn prepare_luau_state(&self) -> Result<()> {
let globals = self.globals();
Expand All @@ -19,7 +25,8 @@ impl Lua {
"collectgarbage",
self.create_c_function(lua_collectgarbage)?,
)?;
globals.raw_set("require", self.create_function(lua_require)?)?;
globals.raw_set("require", self.create_c_function(lua_require)?)?;
globals.raw_set("package", create_package_table(self)?)?;
globals.raw_set("vector", self.create_c_function(lua_vector)?)?;

// Set `_VERSION` global to include version number
Expand Down Expand Up @@ -69,56 +76,57 @@ unsafe extern "C-unwind" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_
}
}

fn lua_require(lua: &Lua, name: Option<StdString>) -> Result<Value> {
let name = name.ok_or_else(|| Error::runtime("invalid module name"))?;

// Find module in the cache
let state = lua.state();
let loaded = unsafe {
let _sg = StackGuard::new(state);
check_stack(state, 2)?;
protect_lua!(state, 0, 1, fn(state) {
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED"));
})?;
Table(lua.pop_ref())
};
if let Some(v) = loaded.raw_get(name.clone())? {
return Ok(v);
}

// Load file from filesystem
let mut search_path = std::env::var("LUAU_PATH").unwrap_or_default();
if search_path.is_empty() {
search_path = "?.luau;?.lua".into();
unsafe extern "C-unwind" fn lua_require(state: *mut ffi::lua_State) -> c_int {
ffi::lua_settop(state, 1);
let name = ffi::luaL_checkstring(state, 1);
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADED")); // _LOADED is at index 2
if ffi::lua_rawgetfield(state, 2, name) != ffi::LUA_TNIL {
return 1; // module is already loaded
}

let (mut source, mut source_name) = (None, String::new());
for path in search_path.split(';') {
let file_path = path.replacen('?', &name, 1);
if let Ok(buf) = std::fs::read(&file_path) {
source = Some(buf);
source_name = file_path;
break;
ffi::lua_pop(state, 1); // remove nil

// load the module
let err_buf = ffi::lua_newuserdata_t::<StdString>(state);
err_buf.write(StdString::new());
ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("_LOADERS")); // _LOADERS is at index 3
for i in 1.. {
if ffi::lua_rawgeti(state, -1, i) == ffi::LUA_TNIL {
// no more loaders?
if (*err_buf).is_empty() {
ffi::luaL_error(state, cstr!("module '%s' not found"), name);
} else {
let bytes = (*err_buf).as_bytes();
let extra = ffi::lua_pushlstring(state, bytes.as_ptr() as *const _, bytes.len());
ffi::luaL_error(state, cstr!("module '%s' not found:%s"), name, extra);
}
}
ffi::lua_pushvalue(state, 1); // name arg
ffi::lua_call(state, 1, 2); // call loader
match ffi::lua_type(state, -2) {
ffi::LUA_TFUNCTION => break, // loader found
ffi::LUA_TSTRING => {
// error message
let msg = ffi::lua_tostring(state, -2);
let msg = CStr::from_ptr(msg).to_string_lossy();
_ = write!(&mut *err_buf, "\n\t{msg}");
}
_ => {}
}
ffi::lua_pop(state, 2); // remove both results
}
ffi::lua_pushvalue(state, 1); // name is 1st argument to module loader
ffi::lua_rotate(state, -2, 1); // loader data <-> name

// stack: ...; loader function; module name; loader data
ffi::lua_call(state, 2, 1);
// stack: ...; result from loader function
if ffi::lua_isnil(state, -1) != 0 {
ffi::lua_pop(state, 1);
ffi::lua_pushboolean(state, 1); // use true as result
}
let source = source.ok_or_else(|| Error::runtime(format!("cannot find '{name}'")))?;

let value = lua
.load(&source)
.set_name(&format!("={source_name}"))
.set_mode(ChunkMode::Text)
.call::<_, Value>(())?;

// Save in the cache
loaded.raw_set(
name,
match value.clone() {
Value::Nil => Value::Boolean(true),
v => v,
},
)?;

Ok(value)
ffi::lua_pushvalue(state, -1); // make copy of entrypoint result
ffi::lua_setfield(state, 2, name); /* _LOADED[name] = returned value */
1
}

// Luau vector datatype constructor
Expand All @@ -135,3 +143,85 @@ unsafe extern "C-unwind" fn lua_vector(state: *mut ffi::lua_State) -> c_int {
ffi::lua_pushvector(state, x, y, z, w);
1
}

//
// package module
//

fn create_package_table(lua: &Lua) -> Result<Table> {
// Create the package table and store it in app_data for later use (bypassing globals lookup)
let package = lua.create_table()?;
lua.set_app_data(PackageKey(lua.create_registry_value(package.clone())?));

// Set `package.path`
let mut search_path = env::var("LUAU_PATH")
.or_else(|_| env::var("LUA_PATH"))
.unwrap_or_default();
if search_path.is_empty() {
search_path = "?.luau;?.lua".to_string();
}
package.raw_set("path", search_path)?;

// Set `package.loaded` (table with a list of loaded modules)
let loaded = lua.create_table()?;
package.raw_set("loaded", loaded.clone())?;
lua.set_named_registry_value("_LOADED", loaded)?;

// Set `package.loaders`
let loaders = lua.create_sequence_from([lua.create_function(lua_loader)?])?;
package.raw_set("loaders", loaders.clone())?;
lua.set_named_registry_value("_LOADERS", loaders)?;

Ok(package)
}

/// Searches for the given `name` in the given `path`.
///
/// `path` is a string containing a sequence of templates separated by semicolons.
fn package_searchpath(name: &str, search_path: &str, try_prefix: bool) -> Option<PathBuf> {
let mut names = vec![name.replace('.', MAIN_SEPARATOR_STR)];
if try_prefix && name.contains('.') {
let prefix = name.split_once('.').map(|(prefix, _)| prefix).unwrap();
names.push(prefix.to_string());
}
for path in search_path.split(';') {
for name in &names {
let file_path = PathBuf::from(path.replace('?', name));
if let Ok(true) = fs::metadata(&file_path).map(|m| m.is_file()) {
return Some(file_path);
}
}
}
None
}

//
// Module loaders
//

/// Tries to load a lua (text) file
fn lua_loader(lua: &Lua, modname: StdString) -> Result<Value> {
let package = {
let key = lua.app_data_ref::<PackageKey>().unwrap();
lua.registry_value::<Table>(&key.0)
}?;
let search_path = package.get::<_, StdString>("path").unwrap_or_default();

if let Some(file_path) = package_searchpath(&modname, &search_path, false) {
match fs::read(&file_path) {
Ok(buf) => {
return lua
.load(&buf)
.set_name(&format!("={}", file_path.display()))
.set_mode(ChunkMode::Text)
.into_function()
.map(Value::Function);
}
Err(err) => {
return format!("cannot open '{}': {err}", file_path.display()).into_lua(lua);
}
}
}

Ok(Value::Nil)
}
2 changes: 1 addition & 1 deletion src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ impl WrappedFailure {
#[cfg(feature = "luau")]
let ud = ffi::lua_newuserdata_t::<Self>(state);
#[cfg(not(feature = "luau"))]
let ud = ffi::lua_newuserdata(state, std::mem::size_of::<WrappedFailure>()) as *mut Self;
let ud = ffi::lua_newuserdata(state, std::mem::size_of::<Self>()) as *mut Self;
ptr::write(ud, WrappedFailure::None);
ud
}
Expand Down
6 changes: 4 additions & 2 deletions tests/luau.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#![cfg(feature = "luau")]

use std::env;
use std::fmt::Debug;
use std::fs;
use std::panic::{catch_unwind, AssertUnwindSafe};
Expand Down Expand Up @@ -37,7 +36,10 @@ fn test_require() -> Result<()> {
"#,
)?;

env::set_var("LUAU_PATH", temp_dir.path().join("?.luau"));
lua.globals()
.get::<_, Table>("package")?
.set("path", temp_dir.path().join("?.luau").to_string_lossy())?;

lua.load(
r#"
local module = require("module")
Expand Down

0 comments on commit 2d77569

Please sign in to comment.