跳到主要内容

使用Rust编写Lua扩展库

参考代码: https://github.com/sniper00/moon-extensions/tree/main/rust

Rust有完善的包管理机制，借助Rust可以极大丰富Lua扩展库，例如依赖 tokio、HTTP client、sqlx 等网络相关的库。使用Rust编写Lua扩展库也比较容易，因为Rust本身就支持编译出供C/C++调用的动态库（cdylib）。

编写Lua扩展库需要依赖lib-lua-sys，它参考了mlua-sys（Low level (FFI) bindings to Lua 5.4/5.3/5.2/5.1 (including LuaJIT) and Roblox Luau）。Rust的mlua库本身支持编写Lua扩展库，但它比较复杂，使用起来也不如Lua C API灵活，所以这里只使用它的FFI binding部分。此外lib-lua-sys也做了一些改动：mlua默认静态链接Lua库，而由于要为moon编写扩展库（需要与moon进程中的Lua共用同一份实现），lib-lua-sys改为动态链接Lua库。

编写基础库

新建Rust项目, 手动添加lib-lua-sys。具体Rust包管理机制请参考Rust相关文档,这里不再详细描述。

[dependencies]
lib-core = { path = "../../libs/lib-core"}
lib-lua = {package = "lib-lua-sys", path = "../../libs/lib-lua-sys",features = ["lua54"]}

然后就可以用类似Lua C API的方式编写Lua扩展库了，以lua_excel为例：

use calamine::{open_workbook, Data, Reader, Xlsx};
use csv::ReaderBuilder;
use lib_lua::{self, cstr, ffi, ffi::luaL_Reg, laux, lreg, lreg_null};
use std::{os::raw::c_int, path::Path};

/// Reads a CSV file into a Lua table with the same shape as the xlsx reader:
/// `{ [1] = { sheet_name = <file stem>, data = { {field, ...}, ... } } }`.
///
/// Every field is pushed as a Lua string; at most `max_row` records are read.
///
/// Returns 1 (the result table) on success, or 2 (`false`, error message) when the file
/// cannot be opened or a record fails to parse.
fn read_csv(state: *mut ffi::lua_State, path: &Path, max_row: usize) -> c_int {
    let mut reader = match ReaderBuilder::new().has_headers(false).from_path(path) {
        Ok(reader) => reader,
        Err(err) => {
            unsafe { ffi::lua_pushboolean(state, 0) };
            laux::lua_push(
                state,
                format!("open file '{}' error: {}", path.to_string_lossy(), err).as_str(),
            );
            return 2;
        }
    };

    unsafe {
        // result: a one-element array holding the single "sheet".
        ffi::lua_createtable(state, 1, 0);
        // sheet: { sheet_name = ..., data = ... }
        ffi::lua_createtable(state, 0, 2);
    }
    laux::lua_push(
        state,
        path.file_stem()
            .unwrap_or_default()
            .to_str()
            .unwrap_or_default(),
    );
    unsafe {
        ffi::lua_setfield(state, -2, cstr!("sheet_name"));
        // data: array of rows; the hint is capped so huge max_row values don't over-allocate.
        ffi::lua_createtable(state, max_row.min(1024) as c_int, 0);
    }

    for (idx, result) in reader.records().take(max_row).enumerate() {
        let record = match result {
            Ok(record) => record,
            Err(err) => {
                // Partially built tables stay below on the stack; Lua only sees the top 2 values.
                unsafe { ffi::lua_pushboolean(state, 0) };
                laux::lua_push(
                    state,
                    format!("read csv '{}' error: {}", path.to_string_lossy(), err).as_str(),
                );
                return 2;
            }
        };

        // Fields are stored under keys 1..n, so size the *array* part of the row table.
        unsafe { ffi::lua_createtable(state, record.len() as c_int, 0) };
        for (i, field) in record.iter().enumerate() {
            laux::lua_push(state, field);
            unsafe { ffi::lua_rawseti(state, -2, (i + 1) as ffi::lua_Integer) };
        }
        unsafe { ffi::lua_rawseti(state, -2, (idx + 1) as ffi::lua_Integer) };
    }

    unsafe {
        // sheet.data = data; result[1] = sheet
        ffi::lua_setfield(state, -2, cstr!("data"));
        ffi::lua_rawseti(state, -2, 1);
    }
    1
}

fn read_xlxs(state: *mut ffi::lua_State, path: &Path, max_row: usize) -> c_int {
let res: Result<Xlsx<_>, _> = open_workbook(path);
match res {
Ok(mut workbook) => {
unsafe {
ffi::lua_createtable(state, 0, 0);
}
let mut sheet_counter = 0;
workbook.sheet_names().iter().for_each(|sheet| {
if let Ok(range) = workbook.worksheet_range(sheet) {
unsafe {
ffi::lua_createtable(state, 0, 2);
laux::lua_push(state, sheet.as_str());

ffi::lua_setfield(state, -2, cstr!("sheet_name"));

ffi::lua_createtable(state, range.rows().len() as i32, 0);
for (i, row) in range.rows().enumerate() {
if i >= max_row {
break;
}
//rows
ffi::lua_createtable(state, row.len() as i32, 0);

for (j, cell) in row.iter().enumerate() {
//columns

match cell {
Data::Int(v) => {
ffi::lua_pushinteger(state, *v as ffi::lua_Integer)
}
Data::Float(v) => ffi::lua_pushnumber(state, *v),
Data::String(v) => laux::lua_push(state, v.as_str()),
Data::Bool(v) => ffi::lua_pushboolean(state, *v as i32),
Data::Error(v) => laux::lua_push(state, v.to_string()),
Data::Empty => ffi::lua_pushnil(state),
Data::DateTime(v) => laux::lua_push(state, v.to_string()),
_ => ffi::lua_pushnil(state),
}
ffi::lua_rawseti(state, -2, (j + 1) as i64);
}
ffi::lua_rawseti(state, -2, (i + 1) as i64);
}
ffi::lua_setfield(state, -2, cstr!("data"));
}
sheet_counter += 1;
unsafe {
ffi::lua_rawseti(state, -2, sheet_counter as i64);
}
}
});
1
}
Err(err) => unsafe {
ffi::lua_pushboolean(state, 0);
laux::lua_push(state, format!("{}", err).as_str());
2
},
}
}

/// Lua: `excel.read(filename [, max_row])`.
///
/// Dispatches on the file extension to the CSV or XLSX reader. The extension is compared
/// ASCII-case-insensitively, so `DATA.CSV` and `book.XLSX` are accepted too. `max_row`
/// defaults to no limit.
///
/// Returns whatever the chosen reader returns, or `false, "unsupport file type: ..."` for
/// any other file (or a path without an extension).
extern "C-unwind" fn lua_excel_read(state: *mut ffi::lua_State) -> c_int {
    let filename: &str = laux::lua_get(state, 1);
    let max_row: usize = laux::lua_opt(state, 2).unwrap_or(usize::MAX);
    let path = Path::new(filename);

    let Some(ext) = path.extension() else {
        unsafe { ffi::lua_pushboolean(state, 0) };
        laux::lua_push(
            state,
            format!("unsupport file type: {}", path.to_string_lossy()),
        );
        return 2;
    };

    let ext = ext.to_string_lossy();
    if ext.eq_ignore_ascii_case("csv") {
        read_csv(state, path, max_row)
    } else if ext.eq_ignore_ascii_case("xlsx") {
        read_xlxs(state, path, max_row)
    } else {
        unsafe { ffi::lua_pushboolean(state, 0) };
        laux::lua_push(state, format!("unsupport file type: {}", ext));
        2
    }
}

/// Lua module entry point returning `{ read = lua_excel_read }`. By Lua's
/// `luaopen_<name>` convention it is found by `require "rust.excel"`.
///
/// # Safety
///
/// This function is unsafe because it dereferences a raw pointer `state`.
/// The caller must ensure that `state` is a valid pointer to a `lua_State`
/// and that it remains valid for the duration of the function call.
#[no_mangle]
pub unsafe extern "C-unwind" fn luaopen_rust_excel(state: *mut ffi::lua_State) -> c_int {
    let l = [lreg!("read", lua_excel_read), lreg_null!()];

    // The trailing `lreg_null!()` sentinel is not a field, so leave it out of the size hint.
    ffi::lua_createtable(state, 0, (l.len() - 1) as c_int);
    ffi::luaL_setfuncs(state, l.as_ptr(), 0);

    1
}

编写异步库

带异步调用的库一般需要和框架的事件循环结合。由于moon基于Actor模型、一切皆消息，只需要把发送消息的函数导出给Rust，就可以接入moon的事件循环系统：异步任务完成后，通过消息把结果发回发起请求的服务。

// Exported function for Rust (declared there as `send_message`): copies `len` bytes from
// `data` into a moon::message and hands it to the server for delivery to `receiver`.
extern "C" {
void MOON_EXPORT
send_message(uint8_t type, uint32_t receiver, int64_t session, const char* data, size_t len) {
    // If the server has already been released, the message is silently dropped.
    auto server = wk_server.lock();
    if (!server) {
        return;
    }

    moon::message m(len);
    m.set_type(type);
    m.set_receiver(receiver);
    m.set_sessionid(session);
    m.write_data(std::string_view{data, len});
    server->send_message(std::move(m));
}
}

在Rust中使用导出的函数（moon_send 用到的 send_integer_message 同样需要由moon导出，上面未列出）：

/// Hands ownership of `res` to the moon runtime as a raw pointer message.
///
/// `res` is boxed and its address is passed as an integer to the exported
/// `send_integer_message`. The receiver must reclaim it with `Box::from_raw` of the same `T`
/// (the http `decode` function does this for `Response`). When `session` is 0, nothing is
/// sent and `res` is simply dropped.
pub fn moon_send<T>(protocol_type: u8, owner: u32, session: i64, res: T) {
    unsafe extern "C-unwind" {
        unsafe fn send_integer_message(type_: u8, receiver: u32, session: i64, val: isize);
    }

    if session != 0 {
        let raw: *mut T = Box::into_raw(Box::new(res));
        unsafe { send_integer_message(protocol_type, owner, session, raw as isize) };
    }
}

/// Sends `data` as a moon message of type `protocol_type` to service `owner`, tagged with
/// `session`.
///
/// The bytes are handed to the C++ `send_message` export, which writes them into a
/// `moon::message` during the call, so `data` only has to outlive this call.
pub fn moon_send_bytes(protocol_type: u8, owner: u32, session: i64, data: &[u8]) {
    unsafe extern "C-unwind" {
        // Mirrors `send_message(uint8_t, uint32_t, int64_t, const char*, size_t)`.
        // `c_char` (rather than `i8`) matches C's `char` on every target, including those
        // where `char` is unsigned.
        unsafe fn send_message(
            type_: u8,
            receiver: u32,
            session: i64,
            data: *const std::ffi::c_char,
            len: usize,
        );
    }

    unsafe {
        send_message(
            protocol_type,
            owner,
            session,
            data.as_ptr().cast::<std::ffi::c_char>(),
            data.len(),
        );
    }
}

注意：带异步运行时的Rust扩展库不能随Lua虚拟机关闭而卸载。异步运行时（如tokio）的后台线程可能仍在执行库中的代码，卸载动态库会让这些线程访问已被卸载的代码而导致崩溃。因此需要修改Lua源码，取消其中的 dlclose(lib)/FreeLibrary((HMODULE)lib) 调用。

这里以lua_http库为例：

use lib_core::context::CONTEXT;
use lib_lua::{
self, cstr,
ffi::{self, luaL_Reg},
laux::{self},
lreg, lreg_null, luaL_newlib, lua_rawsetfield,
};
use reqwest::{header::HeaderMap, Method, Response};
use std::{error::Error, ffi::c_int, str::FromStr};
use url::form_urlencoded::{self};

use crate::{moon_send, moon_send_bytes, PTYPE_ERROR};

/// One HTTP request captured from the Lua options table by `lua_http_request`, then moved
/// into a tokio task and executed by `http_request`.
struct HttpRequest {
    /// HTTP method name, e.g. `"GET"`; parsed with `Method::from_str` when the request runs.
    method: String,
    /// Target URL.
    url: String,
    /// Request headers taken from the optional `headers` table.
    headers: HeaderMap,
    /// Request body, sent as-is.
    body: String,
    /// Request timeout in seconds (the Lua side documents a 5 s default).
    timeout: u64,
    /// Proxy passed to `CONTEXT.get_http_client`; empty when the Lua side omits it.
    proxy: String,
    /// Service id that receives the response or error message.
    owner: u32,
    /// Session the reply is tagged with; `moon_send` sends no response when this is 0.
    session: i64,
}

/// Returns the HTTP version as a status-line style string, e.g. `"HTTP/1.1"`.
///
/// The result is always a string literal, so it is `'static` rather than tied to the lifetime
/// of `version`. Versions not known to this function map to `"Unknown"`.
fn version_to_string(version: &reqwest::Version) -> &'static str {
    match *version {
        reqwest::Version::HTTP_09 => "HTTP/0.9",
        reqwest::Version::HTTP_10 => "HTTP/1.0",
        reqwest::Version::HTTP_11 => "HTTP/1.1",
        reqwest::Version::HTTP_2 => "HTTP/2.0",
        reqwest::Version::HTTP_3 => "HTTP/3.0",
        _ => "Unknown",
    }
}

async fn http_request(req: HttpRequest, protocol_type: u8) -> Result<(), Box<dyn Error>> {
let http_client = &CONTEXT.get_http_client(req.timeout, &req.proxy);

let response = http_client
.request(Method::from_str(req.method.as_str())?, req.url)
.headers(req.headers)
.body(req.body)
.send()
.await?;

moon_send(protocol_type, req.owner, req.session, response);

Ok(())
}

/// Reads the optional `headers` field of the table at stack `index` into a [`HeaderMap`].
///
/// A missing or non-table `headers` field yields an empty map. Keys and values are read as
/// strings (anything `lua_opt` cannot read as `&str` falls back to `""`).
///
/// Returns `Err` with the parse error text when a key is not a valid header name or a value
/// is not a valid header value.
///
/// The Lua stack is left exactly as it was on every return path. `index` must be an absolute
/// (positive) stack index, because a value is pushed before `index` is used.
fn extract_headers(state: *mut ffi::lua_State, index: i32) -> Result<HeaderMap, String> {
    let mut headers = HeaderMap::new();

    laux::push_c_string(state, cstr!("headers"));
    if laux::lua_rawget(state, index) != ffi::LUA_TTABLE {
        // Pop the nil / non-table value that rawget left on the stack.
        laux::lua_pop(state, 1);
        return Ok(headers);
    }

    // Stack: ... headers_table
    laux::lua_pushnil(state);
    while laux::lua_next(state, -2) {
        // Stack: ... headers_table key value
        let key: &str = laux::lua_opt(state, -2).unwrap_or_default();
        let value: &str = laux::lua_opt(state, -1).unwrap_or_default();

        // Parse into owned header types while `key`/`value` still refer to live stack slots.
        let parsed = key
            .parse::<reqwest::header::HeaderName>()
            .map_err(|err| err.to_string())
            .and_then(|name| {
                value
                    .parse::<reqwest::header::HeaderValue>()
                    .map(|value| (name, value))
                    .map_err(|err| err.to_string())
            });

        // Pop the value; the key stays for the next lua_next call.
        laux::lua_pop(state, 1);

        match parsed {
            Ok((name, value)) => {
                headers.insert(name, value);
            }
            Err(err) => {
                // Pop the key and the headers table before bailing out.
                laux::lua_pop(state, 2);
                return Err(err);
            }
        }
    }
    // lua_next already removed the key at the end of traversal; pop the headers table.
    laux::lua_pop(state, 1);

    Ok(headers)
}

/// Lua: `httpc.request(opts, protocol_type) -> session` or `false, err`.
///
/// Reads these fields from `opts`: `owner`, `session`, `method` (default `"GET"`), `url`,
/// `body`, `headers`, `timeout` (default 5) and `proxy`. It then spawns the request on the
/// shared tokio runtime and returns `session` immediately.
///
/// The outcome arrives later as a moon message: the response (type `protocol_type`) or the
/// error text (type `PTYPE_ERROR`). Returns `false, err` synchronously when a header name or
/// value is invalid.
extern "C-unwind" fn lua_http_request(state: *mut ffi::lua_State) -> c_int {
    laux::lua_checktype(state, 1, ffi::LUA_TTABLE);

    let protocol_type = laux::lua_get::<u8>(state, 2);

    let headers = match extract_headers(state, 1) {
        Ok(headers) => headers,
        Err(err) => {
            laux::lua_push(state, false);
            laux::lua_push(state, err);
            return 2;
        }
    };

    let session = laux::opt_field(state, 1, "session").unwrap_or(0);

    let req = HttpRequest {
        owner: laux::opt_field(state, 1, "owner").unwrap_or_default(),
        session,
        // Lazily build the default so no String is allocated when `method` is given.
        method: laux::opt_field(state, 1, "method").unwrap_or_else(|| "GET".to_string()),
        url: laux::opt_field(state, 1, "url").unwrap_or_default(),
        body: laux::opt_field(state, 1, "body").unwrap_or_default(),
        headers,
        timeout: laux::opt_field(state, 1, "timeout").unwrap_or(5),
        proxy: laux::opt_field(state, 1, "proxy").unwrap_or_default(),
    };

    CONTEXT.tokio_runtime.spawn(async move {
        // Copy the routing info out first: `req` is moved into `http_request`.
        let session = req.session;
        let owner = req.owner;
        if let Err(err) = http_request(req, protocol_type).await {
            let err_string = err.to_string();
            moon_send_bytes(PTYPE_ERROR, owner, session, err_string.as_bytes());
        }
    });

    laux::lua_push(state, session);
    1
}

/// Lua: `httpc.form_urlencode(tbl) -> string`.
///
/// Serializes every key/value pair of the table at index 1 as
/// `application/x-www-form-urlencoded` text (`k1=v1&k2=v2`), encoding each side with
/// `form_urlencoded::byte_serialize`. Pair order follows `lua_next` and is unspecified.
extern "C-unwind" fn lua_http_form_urlencode(state: *mut ffi::lua_State) -> c_int {
    laux::lua_checktype(state, 1, ffi::LUA_TTABLE);
    let mut result = String::new();

    laux::lua_pushnil(state);
    while laux::lua_next(state, 1) {
        // Stack: ... key value
        if !result.is_empty() {
            result.push('&');
        }

        // Convert a *copy* of the key. Converting the key slot in place (as lua_tolstring does
        // for a numeric key) would confuse the next lua_next call.
        // NOTE(review): assumes to_string_unchecked may convert in place like lua_tolstring.
        unsafe { ffi::lua_pushvalue(state, -2) };
        result.extend(form_urlencoded::byte_serialize(
            laux::to_string_unchecked(state, -1).as_bytes(),
        ));
        laux::lua_pop(state, 1); // pop the key copy

        result.push('=');
        result.extend(form_urlencoded::byte_serialize(
            laux::to_string_unchecked(state, -1).as_bytes(),
        ));
        laux::lua_pop(state, 1); // pop the value; the key stays for lua_next
    }

    laux::lua_push(state, result);
    1
}

/// Lua: `httpc.form_urldecode(query) -> table`.
///
/// Parses an `application/x-www-form-urlencoded` string into a `{ key = value }` table.
/// When a key repeats, the last occurrence wins.
extern "C-unwind" fn lua_http_form_urldecode(state: *mut ffi::lua_State) -> c_int {
    // Borrowed from argument 1, which stays on the stack for the whole call.
    let query_string = laux::lua_get::<&str>(state, 1);

    unsafe { ffi::lua_createtable(state, 0, 8) };

    // Push pairs straight from the parser; the `Cow`s only allocate when decoding was needed.
    for (key, value) in form_urlencoded::parse(query_string.as_bytes()) {
        laux::lua_push(state, &*key);
        laux::lua_push(state, &*value);
        unsafe { ffi::lua_rawset(state, -3) };
    }
    1
}

/// Lua: `httpc.decode(ptr) -> response_table`; used as the protocol's `unpack` callback.
///
/// Takes back ownership of the boxed `Response` that `moon_send` sent as an integer, and
/// returns `{ version = "HTTP/1.1", status_code = 200, headers = { name = value, ... } }`.
/// Header names are lowercase (`HeaderName` is normalized to lowercase). Values that are not
/// visible ASCII become `""`. For repeated header names only the last value is kept.
///
/// Each pointer must be decoded exactly once: the response is freed when this returns.
///
/// NOTE(review): no `body` field is set here, although the Lua wrapper reads
/// `response.body`; confirm where the body is populated.
extern "C-unwind" fn decode(state: *mut ffi::lua_State) -> c_int {
    laux::luaL_checkstack(state, 4, std::ptr::null());
    let p_as_isize: isize = laux::lua_get(state, 1);
    // SAFETY: the integer is assumed to come from `Box::into_raw` in `moon_send::<Response>`;
    // reclaiming it here drops the response at the end of this function.
    let response = unsafe { Box::from_raw(p_as_isize as *mut Response) };

    unsafe {
        ffi::lua_createtable(state, 0, 6);
        lua_rawsetfield!(
            state,
            -1,
            "version",
            laux::lua_push(state, version_to_string(&response.version()))
        );
        lua_rawsetfield!(
            state,
            -1,
            "status_code",
            laux::lua_push(state, response.status().as_u16() as u32)
        );

        ffi::lua_pushstring(state, cstr!("headers"));
        ffi::lua_createtable(state, 0, response.headers().keys_len() as c_int);

        for (key, value) in response.headers() {
            // HeaderName is already lowercase, so no allocation or case folding is needed.
            laux::lua_push(state, key.as_str());
            laux::lua_push(state, value.to_str().unwrap_or("").trim());
            ffi::lua_rawset(state, -3);
        }
        ffi::lua_rawset(state, -3);
    }
    1
}

/// Lua module entry point; by Lua's `luaopen_<name>` convention it is found by
/// `require "rust.httpc"`.
///
/// Returns a table with `request`, `form_urlencode`, `form_urldecode` and `decode`.
#[no_mangle]
// Exported to C as a safe `extern` fn that receives a raw pointer, which this lint would flag.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub extern "C-unwind" fn luaopen_rust_httpc(state: *mut ffi::lua_State) -> c_int {
    let funcs = [
        lreg!("request", lua_http_request),
        lreg!("form_urlencode", lua_http_form_urlencode),
        lreg!("form_urldecode", lua_http_form_urldecode),
        lreg!("decode", decode),
        lreg_null!(),
    ];

    luaL_newlib!(state, funcs);
    1
}

这样就完成了Rust异步库与moon的集成。下面是Lua层的包装代码：

---@diagnostic disable: inject-field
local moon = require "moon"
local json = require "json"
local c = require "rust.httpc"

-- Message type id for HTTP replies. The same value is passed to `c.request`, and the Rust
-- side uses it as the type of the reply message, so both uses must agree.
local protocol_type = 21

-- Register the "http" protocol with moon. Replies of this type carry a pointer (sent as an
-- integer by the Rust `moon_send`); `unpack` turns it into a response table via `c.decode`,
-- which takes ownership of the pointer.
-- NOTE(review): this assumes moon calls `unpack` exactly once per reply — confirm.
moon.register_protocol {
name = "http",
PTYPE = protocol_type,
-- Outgoing values are passed through unchanged.
pack = function(...) return ... end,
unpack = function (val)
return c.decode(val) -- 'val' is rust object pointer
end
}

--- Decode a JSON response body; any non-200 response yields an empty table.
---@param response table response table produced by `c.decode`
---@return table
local function tojson(response)
    if response.status_code == 200 then
        return json.decode(response.body)
    end
    return {}
end

---@class HttpRequestOptions
---@field headers? table<string,string>
---@field timeout? integer Request timeout in seconds. default 5s
---@field proxy? string

local client = {}

--- Build the table passed to `c.request` for one call.
---
--- The caller's `opts` (and its `headers` table) are shallow-copied, never mutated. An
--- options table can therefore be reused across calls without leaking `url`/`method`/
--- `body`/`session` from an earlier request, and without `Content-Type` being injected
--- into the caller's headers.
---@param url string
---@param method string
---@param opts? HttpRequestOptions
---@param content_type? string Content-Type to add when the caller did not set one
---@return table
local function make_request(url, method, opts, content_type)
    local req = {}
    for k, v in pairs(opts or {}) do
        req[k] = v
    end

    if content_type then
        local headers = {}
        for k, v in pairs(req.headers or {}) do
            headers[k] = v
        end
        -- Only the exact key 'Content-Type' is checked; a differently-cased key is not detected.
        if not headers['Content-Type'] then
            headers['Content-Type'] = content_type
        end
        req.headers = headers
    end

    req.owner = moon.id
    req.session = moon.next_sequence()
    req.url = url
    req.method = method
    return req
end

---@param url string
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.get(url, opts)
    local req = make_request(url, "GET", opts)
    return moon.wait(c.request(req, protocol_type))
end

--- POST `data` encoded as JSON; on a 200 reply `res.body` is replaced by the decoded JSON.
---@param url string
---@param data table
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post_json(url, data, opts)
    local req = make_request(url, "POST", opts, "application/json")
    req.body = json.encode(data)

    local res = moon.wait(c.request(req, protocol_type))

    if res.status_code == 200 then
        res.body = tojson(res)
    end
    return res
end

---@param url string
---@param data string
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post(url, data, opts)
    local req = make_request(url, "POST", opts)
    req.body = data
    return moon.wait(c.request(req, protocol_type))
end

--- POST `data` as an `application/x-www-form-urlencoded` body; values are `tostring`-ed.
---@param url string
---@param data table<string,string>
---@param opts? HttpRequestOptions
---@return HttpResponse
function client.post_form(url, data, opts)
    local req = make_request(url, "POST", opts, "application/x-www-form-urlencoded")

    local fields = {}
    for k, v in pairs(data) do
        fields[k] = tostring(v)
    end
    req.body = c.form_urlencode(fields)

    return moon.wait(c.request(req, protocol_type))
end

return client