|
3 | 3 |
|
4 | 4 | #include <cmath>
|
5 | 5 | #include <cstdio>
|
| 6 | +#include <fstream> |
6 | 7 | #include <string>
|
7 | 8 | #include <vector>
|
8 | 9 |
|
9 |
| -int main(int argc, char ** argv) { |
10 |
| - if (argc < 3 || argv[1][0] == '-') { |
11 |
| - printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]); |
| 10 | +#if defined(_WIN32) |
| 11 | +#define WIN32_LEAN_AND_MEAN |
| 12 | +#include <windows.h> |
| 13 | +#include <shellapi.h> // For CommandLineToArgvW |
| 14 | +#endif |
| 15 | + |
// Writes the command-line help text of the tokenize program to `stream`.
//
//   argv0  - program name substituted into the "usage:" line.
//   stream - destination stream (main() passes stdout for --help and
//            stderr when it is invoked without arguments).
static void print_usage_information(const char * argv0, FILE * stream) {
    // Fixed part of the help text, one entry per output line. None of the
    // entries contains a '%', so emitting them with fputs() is equivalent
    // to passing them to fprintf() as the format string.
    static const char * const help_lines[] = {
        "The tokenize program tokenizes a prompt using a given model,\n",
        "and prints the resulting tokens to standard output.\n\n",
        "It needs a model file, a prompt, and optionally other flags\n",
        "to control the behavior of the tokenizer.\n\n",
        " The possible options are:\n",
        "\n",
        " -h, --help print this help and exit\n",
        " -m MODEL_PATH, --model MODEL_PATH path to model.\n",
        " --ids if given, only print numerical token IDs, and not token strings.\n",
        " The output format looks like [1, 2, 3], i.e. parseable by Python.\n",
        " -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n",
        " -p PROMPT, --prompt PROMPT read prompt from the argument.\n",
        " --stdin read prompt from standard input.\n",
        " --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n",
        " --log-disable disable logs. Makes stderr quiet when loading the model.\n",
    };

    fprintf(stream, "usage: %s [options]\n\n", argv0);
    for (const char * line : help_lines) {
        fputs(line, stream);
    }
}
| 34 | + |
| 35 | +static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) { |
| 36 | + (void) level; |
| 37 | + (void) text; |
| 38 | + (void) user_data; |
| 39 | +} |
| 40 | + |
// Reads the whole file at `filepath` (opened in binary mode) into a string.
//
//   filepath - path of the file to read.
//   success  - out-parameter: set to true when the file was opened and read
//              to the end, false otherwise.
//
// Returns the file contents; on failure an empty string is returned and a
// diagnostic is printed to stderr. An empty file is not an error.
static std::string read_prompt_from_file(const char * filepath, bool & success) {
    success = false;

    std::ifstream in(filepath, std::ios::binary);
    if (!in) {
        fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
        return std::string();
    }

    // do not assume the file is seekable (e.g. /dev/stdin): read sequentially
    // in fixed-size chunks instead of asking for the file size up front.
    //
    // The reads go through `in` itself (not through in.rdbuf()), so that a
    // failure while reading is recorded in the stream state checked below.
    std::string contents;
    char chunk[4096];
    while (in.read(chunk, sizeof(chunk)) || in.gcount() > 0) {
        contents.append(chunk, (size_t) in.gcount());
    }

    // Reaching end-of-file sets eofbit|failbit, which is the normal way for
    // the loop to end; only badbit signals a real read error.
    if (in.bad()) {
        fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
        return std::string();
    }

    success = true;
    return contents;
}
| 60 | + |
//
// Function: ingest_args(...) -> vector<string>
//
//  Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded
//  strings, as an STL vector<string>.
//
//  In particular, it handles character encoding shenanigans on Windows.
//
// Note: raw_argc and raw_argv are not actually read at all on Windows.
//       On Windows we call GetCommandLineW to get the arguments in wchar_t
//       format, ignoring the regular argc/argv arguments to main().
//
//       On other platforms the arguments are copied verbatim; they are not
//       validated or re-encoded.
//
// TODO: potential opportunity to roll common stuff into common/console.cpp
//       in relation to Windows wchar_t shenanigans.
static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
    std::vector<std::string> argv;

    // Handle Windows, if given non-ASCII arguments.
    // We convert wchar_t arguments into UTF-8 char* on this platform.
    // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters
    // without throwing tantrums.
#if defined(_WIN32)
    // silence unused arg warnings
    (void) raw_argc;
    (void) raw_argv;

    int argc = 0;
    LPWSTR * wargv = CommandLineToArgvW(GetCommandLineW(), &argc);
    // CommandLineToArgvW returns NULL on failure; do not dereference it.
    GGML_ASSERT(wargv != NULL && "CommandLineToArgvW() failed.");

    argv.reserve(argc);
    for (int i = 0; i < argc; ++i) {
        const int wide_len = (int) wcslen(wargv[i]);

        std::string utf8;
        // WideCharToMultiByte treats a zero-length input as an error, so an
        // empty argument is passed through as an empty string without
        // calling it.
        if (wide_len > 0) {
            const int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wide_len, NULL, 0, NULL, NULL);
            GGML_ASSERT(length_needed > 0 && "WideCharToMultiByte() failed.");

            utf8.resize(length_needed);
            const int written = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wide_len, &utf8[0], length_needed, NULL, NULL);
            GGML_ASSERT(written == length_needed && "WideCharToMultiByte() failed.");
        }

        argv.push_back(utf8);
    }

    LocalFree((HLOCAL) wargv);

    GGML_ASSERT((size_t) argc == argv.size());
#else
    for (int i = 0; i < raw_argc; ++i) {
        argv.push_back(raw_argv[i]);
    }
#endif

    return argv;
}
| 115 | + |
//
// Function: write_utf8_cstr_to_stdout(const char *) -> <writes to stdout>
//
// writes a string to standard output; taking into account that on Windows
// to display correctly you have to use special handling. Works even if the
// user has not set a unicode code page on a Windows cmd.exe.
//
// In case of invalid UTF-8, invalid_utf8 is set to true on Windows, and a
// human-readable hex dump of the bytes (e.g. <c3 28>) is written instead.
// In every other case invalid_utf8 is set to false.
//
// On non-Windows systems, simply printfs() the string.
static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
    invalid_utf8 = false;

#if defined(_WIN32)
    // Are we in a console?
    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
    DWORD dwMode = 0;

    // According to Microsoft docs:
    // "WriteConsole fails if it is used with a standard handle that is redirected to a file."
    // Also according to the docs, you can use GetConsoleMode to check for that.
    if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
        printf("%s", str);
        return;
    }

    // MultiByteToWideChar reports an error if str is empty, don't report
    // them as invalid_utf8.
    if (*str == 0) {
        return;
    }

    const int str_len = (int) strlen(str);

    const int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, str_len, NULL, 0);
    if (length_needed == 0) {
        const DWORD err = GetLastError();
        if (err == ERROR_NO_UNICODE_TRANSLATION) {
            // Not valid UTF-8: dump the raw bytes as space-separated hex.
            invalid_utf8 = true;
            printf("<");
            for (int i = 0; i < str_len; ++i) {
                if (i > 0) {
                    printf(" ");
                }
                printf("%02x", (unsigned int) (unsigned char) str[i]);
            }
            printf(">");
            return;
        }
        GGML_ASSERT(false && "MultiByteToWideChar() failed in an unexpected way.");
    }

    std::wstring wstr(length_needed, L'\0');
    MultiByteToWideChar(CP_UTF8, 0, str, str_len, &wstr[0], length_needed);

    // WriteConsoleW bypasses the C stdio buffer. Flush anything the caller
    // already printf()'d so the console output keeps its order.
    fflush(stdout);
    WriteConsoleW(hConsole, wstr.c_str(), length_needed, NULL, NULL);
#else
    // TODO: reporting invalid_utf8 would be useful on non-Windows too.
    // printf will silently just write bad unicode.
    printf("%s", str);
#endif
}
| 180 | + |
| 181 | +int main(int raw_argc, char ** raw_argv) { |
| 182 | + const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv); |
| 183 | + const int argc = argv.size(); |
| 184 | + |
| 185 | + if (argc <= 1) { |
| 186 | + print_usage_information(argv[0].c_str(), stderr); |
| 187 | + return 1; |
| 188 | + } |
| 189 | + |
| 190 | + ////// |
| 191 | + // Read out all the command line arguments. |
| 192 | + ////// |
| 193 | + |
| 194 | + // variables where to put any arguments we see. |
| 195 | + bool printing_ids = false; |
| 196 | + bool no_bos = false; |
| 197 | + bool disable_logging = false; |
| 198 | + const char * model_path = NULL; |
| 199 | + const char * prompt_path = NULL; |
| 200 | + const char * prompt_arg = NULL; |
| 201 | + |
| 202 | + // track which arguments were explicitly given |
| 203 | + // used for sanity checking down the line |
| 204 | + bool model_path_set = false; |
| 205 | + bool prompt_path_set = false; |
| 206 | + bool prompt_set = false; |
| 207 | + bool stdin_set = false; |
| 208 | + |
| 209 | + int iarg = 1; |
| 210 | + for (; iarg < argc; ++iarg) { |
| 211 | + std::string arg{argv[iarg]}; |
| 212 | + if (arg == "-h" || arg == "--help") { |
| 213 | + print_usage_information(argv[0].c_str(), stdout); |
| 214 | + return 0; |
| 215 | + } |
| 216 | + else if (arg == "--ids") { |
| 217 | + printing_ids = true; |
| 218 | + } |
| 219 | + else if (arg == "-m" || arg == "--model") { |
| 220 | + if (model_path_set) { |
| 221 | + fprintf(stderr, "Error: -m or --model specified multiple times.\n"); |
| 222 | + return 1; |
| 223 | + } |
| 224 | + model_path = argv[++iarg].c_str(); |
| 225 | + model_path_set = true; |
| 226 | + } |
| 227 | + else if (arg == "--no-bos") { |
| 228 | + no_bos = true; |
| 229 | + } |
| 230 | + else if (arg == "-p" || arg == "--prompt") { |
| 231 | + if (prompt_set) { |
| 232 | + fprintf(stderr, "Error: -p or --prompt specified multiple times.\n"); |
| 233 | + return 1; |
| 234 | + } |
| 235 | + prompt_arg = argv[++iarg].c_str(); |
| 236 | + prompt_set = true; |
| 237 | + } |
| 238 | + else if (arg == "-f" || arg == "--file") { |
| 239 | + if (prompt_path_set) { |
| 240 | + fprintf(stderr, "Error: -f or --file specified multiple times.\n"); |
| 241 | + return 1; |
| 242 | + } |
| 243 | + prompt_path = argv[++iarg].c_str(); |
| 244 | + prompt_path_set = true; |
| 245 | + } |
| 246 | + else if (arg == "--stdin") { |
| 247 | + stdin_set = true; |
| 248 | + } |
| 249 | + else if (arg == "--log-disable") { |
| 250 | + disable_logging = true; |
| 251 | + } |
| 252 | + else { |
| 253 | + fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str()); |
| 254 | + return 1; |
| 255 | + } |
| 256 | + } |
| 257 | + |
| 258 | + ////// |
| 259 | + // Sanity check the command line arguments. |
| 260 | + ////// |
| 261 | + |
| 262 | + // Check that we have the required stuff set. |
| 263 | + if (model_path_set && model_path == NULL) { |
| 264 | + fprintf(stderr, "Error: --model requires an argument.\n"); |
| 265 | + return 1; |
| 266 | + } |
| 267 | + if (!model_path_set) { |
| 268 | + fprintf(stderr, "Error: must specify --model.\n"); |
| 269 | + return 1; |
| 270 | + } |
| 271 | + if (prompt_path_set && prompt_path == NULL) { |
| 272 | + fprintf(stderr, "Error: --file requires an argument.\n"); |
| 273 | + return 1; |
| 274 | + } |
| 275 | + if (prompt_set && prompt_arg == NULL) { |
| 276 | + fprintf(stderr, "Error: --prompt requires an argument.\n"); |
| 277 | + return 1; |
| 278 | + } |
| 279 | + const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set); |
| 280 | + if (prompts_set > 1) { |
| 281 | + fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n"); |
| 282 | + return 1; |
| 283 | + } |
| 284 | + // Must have some prompt. |
| 285 | + if (prompts_set == 0) { |
| 286 | + fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n"); |
12 | 287 | return 1;
|
13 | 288 | }
|
14 | 289 |
|
15 |
| - const char * model_path = argv[1]; |
16 |
| - const char * prompt = argv[2]; |
| 290 | + GGML_ASSERT(model_path); |
| 291 | + GGML_ASSERT(prompt_path || prompt_arg || stdin_set); |
17 | 292 |
|
18 |
| - const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids"; |
| 293 | + ////// |
| 294 | + // Figure out where will the prompt come from. |
| 295 | + ////// |
| 296 | + |
| 297 | + std::string prompt; |
| 298 | + if (prompt_path_set) { |
| 299 | + bool success = false; |
| 300 | + prompt = read_prompt_from_file(prompt_path, success); |
| 301 | + if (!success) { |
| 302 | + return 1; |
| 303 | + } |
| 304 | + } else if (prompt_set) { |
| 305 | + prompt = prompt_arg; |
| 306 | + } else { |
| 307 | + GGML_ASSERT(stdin_set); |
| 308 | + // we read stdin *after* loading model (early exit if model cannot |
| 309 | + // be loaded, which can be a nicer user experience) |
| 310 | + } |
| 311 | + |
| 312 | + ////// |
| 313 | + // Start actually doing the tokenizing stuff. |
| 314 | + ////// |
| 315 | + |
| 316 | +#ifdef LOG_DISABLE_LOGS |
| 317 | + disable_logging = true; |
| 318 | +#endif |
| 319 | + |
| 320 | + if (disable_logging) { |
| 321 | + llama_log_set(llama_log_callback_null, NULL); |
| 322 | + } |
19 | 323 |
|
20 | 324 | llama_backend_init();
|
21 | 325 |
|
22 | 326 | llama_model_params model_params = llama_model_default_params();
|
23 | 327 | model_params.vocab_only = true;
|
24 | 328 | llama_model * model = llama_load_model_from_file(model_path, model_params);
|
| 329 | + if (!model) { |
| 330 | + fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path); |
| 331 | + return 1; |
| 332 | + } |
25 | 333 |
|
26 | 334 | llama_context_params ctx_params = llama_context_default_params();
|
27 | 335 | llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
| 336 | + if (!ctx) { |
| 337 | + fprintf(stderr, "Error: could not create context.\n"); |
| 338 | + return 1; |
| 339 | + } |
| 340 | + |
| 341 | + // read entire prompt from stdin? |
| 342 | + if (stdin_set) { |
| 343 | + GGML_ASSERT(!prompt_path_set && !prompt_set); |
| 344 | + |
| 345 | + std::stringstream stdin_buffer; |
| 346 | + stdin_buffer << std::cin.rdbuf(); |
| 347 | + if (std::cin.fail()) { |
| 348 | + fprintf(stderr, "Error: could not read the entire standard input.\n"); |
| 349 | + return 1; |
| 350 | + } |
| 351 | + |
| 352 | + prompt = stdin_buffer.str(); |
| 353 | + } |
| 354 | + |
| 355 | + const bool model_wants_add_bos = llama_should_add_bos_token(model); |
| 356 | + const bool add_bos = model_wants_add_bos && !no_bos; |
28 | 357 |
|
29 | 358 | std::vector<llama_token> tokens;
|
| 359 | + tokens = ::llama_tokenize(model, prompt, add_bos, true); |
30 | 360 |
|
31 |
| - tokens = ::llama_tokenize(model, prompt, true, true); |
| 361 | + if (printing_ids) { |
| 362 | + printf("["); |
| 363 | + } |
32 | 364 |
|
33 | 365 | for (int i = 0; i < (int) tokens.size(); i++) {
|
34 | 366 | if (printing_ids) {
|
35 |
| - printf("%d\n", tokens[i]); |
| 367 | + if (i > 0) { |
| 368 | + printf(", "); |
| 369 | + } |
| 370 | + printf("%d", tokens[i]); |
36 | 371 | } else {
|
37 |
| - printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str()); |
| 372 | + bool invalid_utf8 = false; |
| 373 | + printf("%6d -> '", tokens[i]); |
| 374 | + write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); |
| 375 | + if (invalid_utf8) { |
| 376 | + printf("' (utf-8 decode failure)\n"); |
| 377 | + } else { |
| 378 | + printf("'\n"); |
| 379 | + } |
38 | 380 | }
|
39 | 381 | }
|
40 | 382 |
|
| 383 | + if (printing_ids) { |
| 384 | + printf("]\n"); |
| 385 | + } |
| 386 | + |
| 387 | + // silence valgrind |
| 388 | + llama_free(ctx); |
| 389 | + llama_free_model(model); |
| 390 | + |
41 | 391 | return 0;
|
42 | 392 | }
|
0 commit comments