diff --git a/src/ruleset.c b/src/ruleset.c index e83d307..b8b7081 100644 --- a/src/ruleset.c +++ b/src/ruleset.c @@ -411,28 +411,32 @@ static int marshal_(lua_State *restrict L) return 1; } -struct unmarshal_status { +struct reader_status { struct stream *s; const char *prefix; size_t prefixlen; }; -static const char *unmarshal_stream(lua_State *L, void *ud, size_t *restrict sz) +static const char *read_stream(lua_State *L, void *ud, size_t *restrict sz) { UNUSED(L); - struct unmarshal_status *restrict ctx = ud; - const void *buf = ctx->prefix; + struct reader_status *restrict rd = ud; + const void *buf = rd->prefix; if (buf != NULL) { - ctx->prefix = NULL; - *sz = ctx->prefixlen; + *sz = rd->prefixlen; + rd->prefix = NULL; + rd->prefixlen = 0; return buf; } *sz = SIZE_MAX; /* Lua allows arbitrary length */ - const int err = stream_direct_read(ctx->s, &buf, sz); + const int err = stream_direct_read(rd->s, &buf, sz); if (err != 0) { - LOGE_F("unmarshal_stream: error %d", err); + LOGE_F("read_stream: error %d", err); + } + if (*sz == 0) { + return NULL; } - return *sz > 0 ? buf : NULL; + return buf; } enum ruleset_functions { @@ -491,25 +495,10 @@ static int ruleset_loadfile_(lua_State *restrict L) return 0; } -static const char *read_stream(lua_State *L, void *ud, size_t *restrict sz) -{ - UNUSED(L); - const void *buf; - *sz = SIZE_MAX; /* Lua allows arbitrary length */ - const int err = stream_direct_read(ud, &buf, sz); - if (err != 0) { - LOGE_F("read_stream: error %d", err); - } - if (*sz == 0) { - return NULL; - } - return buf; -} - static int ruleset_invoke_(lua_State *restrict L) { - struct stream *s = (struct stream *)lua_topointer(L, 1); - if (lua_load(L, read_stream, s, "=invoke", NULL) != LUA_OK) { + struct reader_status rd = { .s = (struct stream *)lua_topointer(L, 1) }; + if (lua_load(L, read_stream, &rd, "=invoke", NULL) != LUA_OK) { return lua_error(L); } lua_call(L, 0, 0); @@ -518,12 +507,12 @@ static int ruleset_invoke_(lua_State *restrict L) static int ruleset_rpcall_(lua_State *restrict L) { - struct stream *s = (struct stream *)lua_topointer(L, 1); + struct reader_status rd = { .s = (struct stream *)lua_topointer(L, 1) }; const void **result = (const void **)lua_topointer(L, 2); size_t *resultlen = (size_t *)lua_topointer(L, 3); lua_settop(L, 0); lua_pushcfunction(L, marshal_); - if (lua_load(L, read_stream, s, "=rpc", NULL) != LUA_OK) { + if (lua_load(L, read_stream, &rd, "=rpc", NULL) != LUA_OK) { return lua_error(L); } /* stack: marshal f */ @@ -582,10 +571,10 @@ static int package_replace_(lua_State *restrict L) static int ruleset_update_(lua_State *restrict L) { const char *modname = lua_topointer(L, 1); - struct stream *s = (struct stream *)lua_topointer(L, 2); + struct reader_status rd = { .s = (struct stream *)lua_topointer(L, 2) }; lua_settop(L, 0); if (modname == NULL) { - if (lua_load(L, read_stream, s, "=ruleset", NULL) != LUA_OK) { + if (lua_load(L, read_stream, &rd, "=ruleset", NULL) != LUA_OK) { return lua_error(L); } lua_pushliteral(L, "ruleset"); @@ -600,7 +589,7 @@ static int ruleset_update_(lua_State *restrict L) name[0] = '='; memcpy(name + 1, modname, namelen); name[1 + namelen] = '\0'; - if (lua_load(L, read_stream, s, name, NULL) != LUA_OK) { + if (lua_load(L, read_stream, &rd, name, NULL) != LUA_OK) { return lua_error(L); } } @@ -1321,12 +1310,12 @@ await_invoke_k_(lua_State *restrict L, const int status, lua_KContext ctx) } /* unmarshal */ const int base = lua_gettop(L); - struct unmarshal_status u = { + struct reader_status rd = { + .s = (struct stream *)data, .prefix = "return ", .prefixlen = 7, - .s = (struct stream *)data, }; - if (lua_load(L, unmarshal_stream, &u, "=unmarshal", NULL) != LUA_OK) { + if (lua_load(L, read_stream, &rd, "=unmarshal", NULL) != LUA_OK) { return lua_error(L); } lua_call(L, 0, LUA_MULTRET);