/*
 * rwkv.h — public C API of rwkv.cpp.
 * Forked from RWKV/rwkv.cpp.
 */
#ifndef RWKV_H
#define RWKV_H
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
// RWKV_API decorates every exported library symbol.
// - RWKV_SHARED not defined (static use): expands to nothing.
// - RWKV_SHARED on Windows with a non-MinGW toolchain: dllexport while the
//   library itself is being built (RWKV_BUILD defined), dllimport for consumers.
// - RWKV_SHARED anywhere else: GCC/Clang default symbol visibility.
#if !defined(RWKV_SHARED)
#    define RWKV_API
#elif defined(_WIN32) && !defined(__MINGW32__)
#    if defined(RWKV_BUILD)
#        define RWKV_API __declspec(dllexport)
#    else
#        define RWKV_API __declspec(dllimport)
#    endif
#else
#    define RWKV_API __attribute__ ((visibility ("default")))
#endif
// Model file format identification.
// RWKV_FILE_MAGIC spells the ASCII bytes 'g' 'g' 'm' 'f' (0x67 0x67 0x6d 0x66).
#define RWKV_FILE_MAGIC   0x67676d66
#define RWKV_FILE_VERSION 100

// Evaluation modes, presumably the values accepted by the eval_mode argument
// of rwkv_eval. NOTE(review): what PARTIAL and REST do is not visible in this
// header — confirm against the implementation.
#define RWKV_EVAL_FULL    0x0
#define RWKV_EVAL_PARTIAL 0x1
#define RWKV_EVAL_REST    0x2

// Lowest and highest defined evaluation mode.
#define RWKV_EVAL_MIN     RWKV_EVAL_FULL
#define RWKV_EVAL_MAX     RWKV_EVAL_REST
#ifdef __cplusplus
extern "C" {
#endif

// Opaque handle to a loaded model. Obtained from rwkv_init_from_file and
// released with rwkv_free.
struct rwkv_context;

// Loads the model from a file and prepares it for inference.
// Returns NULL on any error. Error messages would be printed to stderr.
// - model_file_path: path to model file in ggml format.
// - n_threads: count of threads to use, must be positive.
// The returned context must be released with rwkv_free.
RWKV_API struct rwkv_context * rwkv_init_from_file(const char * model_file_path, uint32_t n_threads);

// Evaluates the model for a single token.
// Returns false on any error. Error messages would be printed to stderr.
// - ctx: context returned by rwkv_init_from_file.
// - token: next token index, in range 0 <= token < n_vocab.
// - state_in: FP32 buffer of size rwkv_get_state_buffer_element_count; or NULL, if this is a first pass.
// - state_out: FP32 buffer of size rwkv_get_state_buffer_element_count. This buffer will be written to.
// - logits_out: FP32 buffer of size rwkv_get_logits_buffer_element_count. This buffer will be written to.
// - eval_mode: presumably one of the RWKV_EVAL_* constants, i.e. in the range
//   RWKV_EVAL_MIN..RWKV_EVAL_MAX; RWKV_EVAL_FULL is the lowest value.
//   NOTE(review): the exact behavior of RWKV_EVAL_PARTIAL / RWKV_EVAL_REST and
//   how out-of-range values are handled are not visible from this header —
//   confirm against the implementation.
RWKV_API bool rwkv_eval(const struct rwkv_context * ctx, uint32_t token, const float * state_in, float * state_out, float * logits_out, int32_t eval_mode);

// Returns count of FP32 elements in state buffer.
RWKV_API uint32_t rwkv_get_state_buffer_element_count(const struct rwkv_context * ctx);

// Returns count of FP32 elements in logits buffer.
RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct rwkv_context * ctx);

// Returns the number of layers in the model.
RWKV_API uint32_t rwkv_get_layer_count(const struct rwkv_context * ctx);

// Returns the size of the embedding vector of the model.
RWKV_API uint32_t rwkv_get_embedding_size(const struct rwkv_context * ctx);

// Frees all allocated memory and the context.
// ctx must not be used after this call.
RWKV_API void rwkv_free(struct rwkv_context * ctx);

// Quantizes FP32 or FP16 model to one of quantized formats.
// Returns false on any error. Error messages would be printed to stderr.
// - model_file_path_in: path to model file in ggml format, must be either FP32 or FP16.
// - model_file_path_out: quantized model will be written here.
// - format_name: must be one of the available format names listed below.
// Available format names:
// - Q4_0
// - Q4_1
// - Q4_2
// - Q5_0
// - Q5_1
// - Q8_0
RWKV_API bool rwkv_quantize_model_file(const char * model_file_path_in, const char * model_file_path_out, const char * format_name);

// Returns system information string.
// NOTE(review): the string is returned as const char *; presumably the caller
// must not modify or free it — confirm lifetime in the implementation.
RWKV_API const char * rwkv_get_system_info_string(void);

#ifdef __cplusplus
}
#endif
#endif