Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions core/iwasm/common/wasm_native.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ static NativeSymbolsList g_native_symbols_list = NULL;
static void *g_wasi_context_key;
#endif /* WASM_ENABLE_LIBC_WASI */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
static void *g_wasi_nn_context_key;
#endif

uint32
get_libc_builtin_export_apis(NativeSymbol **p_libc_builtin_apis);

Expand Down Expand Up @@ -473,6 +477,32 @@ wasi_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
}
#endif /* end of WASM_ENABLE_LIBC_WASI */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
WASINNGlobalContext *
wasm_runtime_get_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm)
{
return wasm_native_get_context(module_inst_comm, g_wasi_nn_context_key);
}

void
wasm_runtime_set_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst_comm,
WASINNGlobalContext *wasi_nn_ctx)
{
wasm_native_set_context(module_inst_comm, g_wasi_nn_context_key,
wasi_nn_ctx);
}

static void
wasi_nn_context_dtor(WASMModuleInstanceCommon *inst, void *ctx)
{
if (ctx == NULL) {
return;
}

wasm_runtime_destroy_wasi_nn_global_ctx(inst);
}
#endif

#if WASM_ENABLE_QUICK_AOT_ENTRY != 0
static bool
quick_aot_entry_init(void);
Expand Down Expand Up @@ -582,6 +612,12 @@ wasm_native_init()
#endif /* WASM_ENABLE_LIB_RATS */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
g_wasi_nn_context_key =
wasm_native_create_context_key(wasi_nn_context_dtor);
if (g_wasi_nn_context_key == NULL) {
goto fail;
}

if (!wasi_nn_initialize())
goto fail;

Expand Down Expand Up @@ -648,6 +684,10 @@ wasm_native_destroy()
#endif

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
if (g_wasi_nn_context_key != NULL) {
wasm_native_destroy_context_key(g_wasi_nn_context_key);
g_wasi_nn_context_key = NULL;
}
wasi_nn_destroy();
#endif

Expand Down
220 changes: 220 additions & 0 deletions core/iwasm/common/wasm_runtime_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -1795,6 +1795,83 @@ wasm_runtime_instantiation_args_set_wasi_ns_lookup_pool(
}
#endif /* WASM_ENABLE_LIBC_WASI != 0 */

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
typedef struct WASINNArguments WASINNArguments;

void
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(WASINNArguments *args)
{
memset(args, 0, sizeof(*args));
}

bool
wasi_nn_graph_registry_set_args(WASINNArguments *registry,
const char **model_names, const char **encoding,
const char **target, uint32_t n_graphs,
const char **graph_paths)
{
if (!registry || !model_names || !encoding || !target || !graph_paths) {
return false;
}

registry->n_graphs = n_graphs;
registry->target = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
registry->encoding = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
registry->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
registry->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
memset(registry->target, 0, sizeof(uint32_t *) * n_graphs);
memset(registry->encoding, 0, sizeof(uint32_t *) * n_graphs);
memset(registry->model_names, 0, sizeof(uint32_t *) * n_graphs);
memset(registry->graph_paths, 0, sizeof(uint32_t *) * n_graphs);

for (uint32_t i = 0; i < registry->n_graphs; i++) {
registry->graph_paths[i] = strdup(graph_paths[i]);
registry->model_names[i] = strdup(model_names[i]);
registry->encoding[i] = strdup(encoding[i]);
registry->target[i] = strdup(target[i]);
}

return true;
}

int
wasi_nn_graph_registry_create(WASINNArguments **registryp)
{
WASINNArguments *args = wasm_runtime_malloc(sizeof(*args));
if (args == NULL) {
return -1;
}
wasm_runtime_wasi_nn_graph_registry_args_set_defaults(args);
*registryp = args;
return 0;
}

void
wasi_nn_graph_registry_destroy(WASINNArguments *registry)
{
if (registry) {
for (uint32_t i = 0; i < registry->n_graphs; i++)
if (registry->graph_paths[i]) {
free(registry->graph_paths[i]);
if (registry->model_names[i])
free(registry->model_names[i]);
if (registry->encoding[i])
free(registry->encoding[i]);
if (registry->target[i])
free(registry->target[i]);
}
free(registry);
}
}

void
wasm_runtime_instantiation_args_set_wasi_nn_graph_registry(
struct InstantiationArgs2 *p, WASINNArguments *registry)
{
p->nn_registry = *registry;
}
#endif

WASMModuleInstanceCommon *
wasm_runtime_instantiate_ex2(WASMModuleCommon *module,
const struct InstantiationArgs2 *args,
Expand Down Expand Up @@ -8080,3 +8157,146 @@ wasm_runtime_check_and_update_last_used_shared_heap(
return false;
}
#endif

#if WASM_ENABLE_WASI_NN != 0 || WASM_ENABLE_WASI_EPHEMERAL_NN != 0
bool
wasm_runtime_init_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst,
const char **model_names,
const char **encoding, const char **target,
const uint32_t n_graphs,
char *graph_paths[], char *error_buf,
uint32_t error_buf_size)
{
WASINNGlobalContext *ctx;
bool ret = false;

ctx = runtime_malloc(sizeof(*ctx), module_inst, error_buf, error_buf_size);
if (!ctx)
return false;

ctx->n_graphs = n_graphs;

ctx->encoding = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
memset(ctx->encoding, 0, sizeof(uint32_t) * n_graphs);
ctx->target = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
memset(ctx->target, 0, sizeof(uint32_t) * n_graphs);
ctx->loaded = (uint32_t *)malloc(sizeof(uint32_t) * n_graphs);
memset(ctx->loaded, 0, sizeof(uint32_t) * n_graphs);
ctx->model_names = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
memset(ctx->model_names, 0, sizeof(uint32_t *) * n_graphs);
ctx->graph_paths = (uint32_t **)malloc(sizeof(uint32_t *) * n_graphs);
memset(ctx->graph_paths, 0, sizeof(uint32_t *) * n_graphs);

for (uint32_t i = 0; i < n_graphs; i++) {
ctx->graph_paths[i] = strdup(graph_paths[i]);
ctx->model_names[i] = strdup(model_names[i]);
ctx->target[i] = strdup(target[i]);
ctx->encoding[i] = strdup(encoding[i]);
}

wasm_runtime_set_wasi_nn_global_ctx(module_inst, ctx);

ret = true;

return ret;
}

void
wasm_runtime_destroy_wasi_nn_global_ctx(WASMModuleInstanceCommon *module_inst)
{
WASINNGlobalContext *wasi_nn_global_ctx =
wasm_runtime_get_wasi_nn_global_ctx(module_inst);

for (uint32 i = 0; i < wasi_nn_global_ctx->n_graphs; i++) {
// All graphs will be unregistered in deinit()
if (wasi_nn_global_ctx->graph_paths[i])
free(wasi_nn_global_ctx->graph_paths[i]);
if (wasi_nn_global_ctx->model_names[i])
free(wasi_nn_global_ctx->model_names[i]);
if (wasi_nn_global_ctx->encoding[i])
free(wasi_nn_global_ctx->encoding[i]);
if (wasi_nn_global_ctx->target[i])
free(wasi_nn_global_ctx->target[i]);
}
free(wasi_nn_global_ctx->encoding);
free(wasi_nn_global_ctx->target);
free(wasi_nn_global_ctx->loaded);
free(wasi_nn_global_ctx->model_names);
free(wasi_nn_global_ctx->graph_paths);

if (wasi_nn_global_ctx) {
wasm_runtime_free(wasi_nn_global_ctx);
}
}

uint32_t
wasm_runtime_get_wasi_nn_global_ctx_ngraphs(
WASINNGlobalContext *wasi_nn_global_ctx)
{
if (wasi_nn_global_ctx)
return wasi_nn_global_ctx->n_graphs;

return -1;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_model_names_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->model_names[idx];

return NULL;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_graph_paths_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->graph_paths[idx];

return NULL;
}

uint32_t
wasm_runtime_get_wasi_nn_global_ctx_loaded_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->loaded[idx];

return -1;
}

uint32_t
wasm_runtime_set_wasi_nn_global_ctx_loaded_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx, uint32_t value)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
wasi_nn_global_ctx->loaded[idx] = value;

return 0;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_encoding_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->encoding[idx];

return NULL;
}

char *
wasm_runtime_get_wasi_nn_global_ctx_target_i(
WASINNGlobalContext *wasi_nn_global_ctx, uint32_t idx)
{
if (wasi_nn_global_ctx && (idx < wasi_nn_global_ctx->n_graphs))
return wasi_nn_global_ctx->target[idx];

return NULL;
}

#endif
Loading
Loading