Skip to content

Commit

Permalink
Support Luau 0.650 native vector library
Browse files Browse the repository at this point in the history
  • Loading branch information
khvzak committed Nov 9, 2024
1 parent a4bfeb7 commit a3cd25d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 51 deletions.
18 changes: 1 addition & 17 deletions src/luau/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::ffi::CStr;
use std::os::raw::{c_float, c_int};
use std::os::raw::c_int;

use crate::error::Result;
use crate::state::Lua;
Expand All @@ -11,7 +11,6 @@ impl Lua {
let globals = self.globals();

globals.raw_set("collectgarbage", self.create_c_function(lua_collectgarbage)?)?;
globals.raw_set("vector", self.create_c_function(lua_vector)?)?;

// Set `_VERSION` global to include version number
// The environment variable `LUAU_VERSION` set by the build script
Expand Down Expand Up @@ -65,21 +64,6 @@ unsafe extern "C-unwind" fn lua_collectgarbage(state: *mut ffi::lua_State) -> c_
}
}

// Luau vector datatype constructor
unsafe extern "C-unwind" fn lua_vector(state: *mut ffi::lua_State) -> c_int {
let x = ffi::luaL_checknumber(state, 1) as c_float;
let y = ffi::luaL_checknumber(state, 2) as c_float;
let z = ffi::luaL_checknumber(state, 3) as c_float;
#[cfg(feature = "luau-vector4")]
let w = ffi::luaL_checknumber(state, 4) as c_float;

#[cfg(not(feature = "luau-vector4"))]
ffi::lua_pushvector(state, x, y, z);
#[cfg(feature = "luau-vector4")]
ffi::lua_pushvector(state, x, y, z, w);
1
}

pub(crate) use package::register_package_module;

mod package;
23 changes: 14 additions & 9 deletions src/state/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,12 @@ unsafe fn load_std_libs(state: *mut ffi::lua_State, libs: StdLib) -> Result<()>
ffi::lua_pop(state, 1);
}

#[cfg(feature = "luau")]
if libs.contains(StdLib::VECTOR) {
requiref(state, ffi::LUA_VECLIBNAME, ffi::luaopen_vector, 1)?;
ffi::lua_pop(state, 1);
}

if libs.contains(StdLib::MATH) {
requiref(state, ffi::LUA_MATHLIBNAME, ffi::luaopen_math, 1)?;
ffi::lua_pop(state, 1);
Expand All @@ -1369,16 +1375,15 @@ unsafe fn load_std_libs(state: *mut ffi::lua_State, libs: StdLib) -> Result<()>
}

#[cfg(feature = "luajit")]
{
if libs.contains(StdLib::JIT) {
requiref(state, ffi::LUA_JITLIBNAME, ffi::luaopen_jit, 1)?;
ffi::lua_pop(state, 1);
}
if libs.contains(StdLib::JIT) {
requiref(state, ffi::LUA_JITLIBNAME, ffi::luaopen_jit, 1)?;
ffi::lua_pop(state, 1);
}

if libs.contains(StdLib::FFI) {
requiref(state, ffi::LUA_FFILIBNAME, ffi::luaopen_ffi, 1)?;
ffi::lua_pop(state, 1);
}
#[cfg(feature = "luajit")]
if libs.contains(StdLib::FFI) {
requiref(state, ffi::LUA_FFILIBNAME, ffi::luaopen_ffi, 1)?;
ffi::lua_pop(state, 1);
}

Ok(())
Expand Down
7 changes: 6 additions & 1 deletion src/stdlib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,17 @@ impl StdLib {
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub const BUFFER: StdLib = StdLib(1 << 9);

/// [`vector`](https://luau-lang.org/library#vector-library) library
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub const VECTOR: StdLib = StdLib(1 << 10);

/// [`jit`](http://luajit.org/ext_jit.html) library
///
/// Requires `feature = "luajit"`
#[cfg(any(feature = "luajit", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luajit")))]
pub const JIT: StdLib = StdLib(1 << 9);
pub const JIT: StdLib = StdLib(1 << 11);

/// (**unsafe**) [`ffi`](http://luajit.org/ext_ffi.html) library
///
Expand Down
24 changes: 14 additions & 10 deletions tests/luau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,19 @@ fn test_require() -> Result<()> {
fn test_vectors() -> Result<()> {
let lua = Lua::new();

let v: Vector = lua.load("vector(1, 2, 3) + vector(3, 2, 1)").eval()?;
let v: Vector = lua
.load("vector.create(1, 2, 3) + vector.create(3, 2, 1)")
.eval()?;
assert_eq!(v, [4.0, 4.0, 4.0]);

// Test conversion into Rust array
let v: [f64; 3] = lua.load("vector(1, 2, 3)").eval()?;
let v: [f64; 3] = lua.load("vector.create(1, 2, 3)").eval()?;
assert!(v == [1.0, 2.0, 3.0]);

// Test vector methods
lua.load(
r#"
local v = vector(1, 2, 3)
local v = vector.create(1, 2, 3)
assert(v.x == 1)
assert(v.y == 2)
assert(v.z == 3)
Expand All @@ -118,7 +120,7 @@ fn test_vectors() -> Result<()> {
// Test vector methods (fastcall)
lua.load(
r#"
local v = vector(1, 2, 3)
local v = vector.create(1, 2, 3)
assert(v.x == 1)
assert(v.y == 2)
assert(v.z == 3)
Expand All @@ -135,17 +137,19 @@ fn test_vectors() -> Result<()> {
fn test_vectors() -> Result<()> {
let lua = Lua::new();

let v: Vector = lua.load("vector(1, 2, 3, 4) + vector(4, 3, 2, 1)").eval()?;
let v: Vector = lua
.load("vector.create(1, 2, 3, 4) + vector.create(4, 3, 2, 1)")
.eval()?;
assert_eq!(v, [5.0, 5.0, 5.0, 5.0]);

// Test conversion into Rust array
let v: [f64; 4] = lua.load("vector(1, 2, 3, 4)").eval()?;
let v: [f64; 4] = lua.load("vector.create(1, 2, 3, 4)").eval()?;
assert!(v == [1.0, 2.0, 3.0, 4.0]);

// Test vector methods
lua.load(
r#"
local v = vector(1, 2, 3, 4)
local v = vector.create(1, 2, 3, 4)
assert(v.x == 1)
assert(v.y == 2)
assert(v.z == 3)
Expand All @@ -157,7 +161,7 @@ fn test_vectors() -> Result<()> {
// Test vector methods (fastcall)
lua.load(
r#"
local v = vector(1, 2, 3, 4)
local v = vector.create(1, 2, 3, 4)
assert(v.x == 1)
assert(v.y == 2)
assert(v.z == 3)
Expand All @@ -180,10 +184,10 @@ fn test_vector_metatable() -> Result<()> {
r#"
{
__index = {
new = vector,
new = vector.create,
product = function(a, b)
return vector(a.x * b.x, a.y * b.y, a.z * b.z)
return vector.create(a.x * b.x, a.y * b.y, a.z * b.z)
end
}
}
Expand Down
18 changes: 4 additions & 14 deletions tests/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,7 @@ fn test_serialize_failure() -> Result<(), Box<dyn StdError>> {
fn test_serialize_vector() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();

let globals = lua.globals();
globals.set(
"vector",
lua.create_function(|_, (x, y, z)| Ok(mlua::Vector::new(x, y, z)))?,
)?;

let val = lua.load("{_vector = vector(1, 2, 3)}").eval::<Value>()?;
let val = lua.load("{_vector = vector.create(1, 2, 3)}").eval::<Value>()?;
let json = serde_json::json!({
"_vector": [1.0, 2.0, 3.0],
});
Expand All @@ -156,13 +150,9 @@ fn test_serialize_vector() -> Result<(), Box<dyn StdError>> {
fn test_serialize_vector() -> Result<(), Box<dyn StdError>> {
let lua = Lua::new();

let globals = lua.globals();
globals.set(
"vector",
lua.create_function(|_, (x, y, z, w)| Ok(mlua::Vector::new(x, y, z, w)))?,
)?;

let val = lua.load("{_vector = vector(1, 2, 3, 4)}").eval::<Value>()?;
let val = lua
.load("{_vector = vector.create(1, 2, 3, 4)}")
.eval::<Value>()?;
let json = serde_json::json!({
"_vector": [1.0, 2.0, 3.0, 4.0],
});
Expand Down

0 comments on commit a3cd25d

Please sign in to comment.