Commit 5768433

Noeda, ggerganov, and mofosyne
authored
Make tokenize CLI tool have nicer command line arguments. (ggml-org#6188)
* Make tokenizer.cpp CLI tool nicer.

  Before this commit, tokenize was a simple CLI tool like this:

      tokenize MODEL_FILENAME PROMPT [--ids]

  This simple tool loads the model, takes the prompt, and shows the tokens
  llama.cpp is interpreting.

  This changeset makes the tokenize tool more sophisticated, and more useful
  for debugging and troubleshooting:

      tokenize [-m, --model MODEL_FILENAME]
               [--ids]
               [--stdin]
               [--prompt]
               [-f, --file]
               [--no-bos]
               [--log-disable]

  It also behaves nicer on Windows now, interpreting and rendering Unicode
  from command line arguments and pipes no matter what code page the user
  has set on their terminal.

* style fix: strlen(str) == 0 --> *str == 0

* Simplify tokenize.cpp by getting rid of handling positional style arguments.
  It must now be invoked with long --model, --prompt etc. arguments only.
  Shortens the code.

* tokenize.cpp: iostream header no longer required

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: brian khuu <[email protected]>
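For a concrete sense of the reworked interface, a few illustrative invocations follow. This is only a sketch: the model path and prompt text are placeholder values, not part of the commit; the flags are the ones introduced by this change.

    # tokenize a prompt given as an argument (default mode: one token per line with its piece string)
    tokenize --model path/to/model.gguf --prompt "Hello, world!"

    # print only numerical token IDs, as a Python-parseable list, reading the prompt from a file
    tokenize -m path/to/model.gguf -f prompt.txt --ids

    # read the prompt from standard input and keep stderr quiet while the model loads
    echo "Hello, world!" | tokenize -m path/to/model.gguf --stdin --log-disable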
1 parent b83bab1 commit 5768433

examples/tokenize/tokenize.cpp

Lines changed: 359 additions & 9 deletions
@@ -3,40 +3,390 @@
 
 #include <cmath>
 #include <cstdio>
+#include <fstream>
 #include <string>
 #include <vector>
 
-int main(int argc, char ** argv) {
-    if (argc < 3 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]);
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <shellapi.h> // For CommandLineToArgvW
+#endif
+
+static void print_usage_information(const char * argv0, FILE * stream) {
+    fprintf(stream, "usage: %s [options]\n\n", argv0);
+    fprintf(stream, "The tokenize program tokenizes a prompt using a given model,\n");
+    fprintf(stream, "and prints the resulting tokens to standard output.\n\n");
+    fprintf(stream, "It needs a model file, a prompt, and optionally other flags\n");
+    fprintf(stream, "to control the behavior of the tokenizer.\n\n");
+    fprintf(stream, "    The possible options are:\n");
+    fprintf(stream, "\n");
+    fprintf(stream, "    -h, --help                           print this help and exit\n");
+    fprintf(stream, "    -m MODEL_PATH, --model MODEL_PATH    path to model.\n");
+    fprintf(stream, "    --ids                                if given, only print numerical token IDs, and not token strings.\n");
+    fprintf(stream, "                                         The output format looks like [1, 2, 3], i.e. parseable by Python.\n");
+    fprintf(stream, "    -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n");
+    fprintf(stream, "    -p PROMPT, --prompt PROMPT           read prompt from the argument.\n");
+    fprintf(stream, "    --stdin                              read prompt from standard input.\n");
+    fprintf(stream, "    --no-bos                             do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
+    fprintf(stream, "    --log-disable                        disable logs. Makes stderr quiet when loading the model.\n");
+}
+
+static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) {
+    (void) level;
+    (void) text;
+    (void) user_data;
+}
+
+static std::string read_prompt_from_file(const char * filepath, bool & success) {
+    success = false;
+
+    std::ifstream in(filepath, std::ios::binary);
+    if (!in) {
+        fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
+        return std::string();
+    }
+    // do not assume the file is seekable (e.g. /dev/stdin)
+    std::stringstream buffer;
+    buffer << in.rdbuf();
+    if (in.fail()) {
+        fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
+        return std::string();
+    }
+
+    success = true;
+    return buffer.str();
+}
+
+//
+// Function: ingest_args(...) -> vector<string>
+//
+// Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded
+// strings, as an STL vector<string>.
+//
+// In particular, it handles character encoding shenanigans on Windows.
+//
+// Note: raw_argc and raw_argv are not actually read at all on Windows.
+//       On Windows we call GetCommandLineW to get the arguments in wchar_t
+//       format, ignoring the regular argc/argv arguments to main().
+//
+// TODO: potential opportunity to roll common stuff into common/console.cpp
+//       in relation to Windows wchar_t shenanigans.
+static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
+    std::vector<std::string> argv;
+
+    // Handle Windows, if given non-ASCII arguments.
+    // We convert wchar_t arguments into UTF-8 char* on this platform.
+    // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters
+    // without throwing tantrums.
+#if defined(_WIN32)
+    int argc;
+    const LPWSTR cmdline_wargv = GetCommandLineW();
+    LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc);
+
+    // silence unused arg warnings
+    (void) raw_argc;
+    (void) raw_argv;
+
+    for (int i = 0; i < argc; ++i) {
+        int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), 0, 0, NULL, NULL);
+        char * output_buf = (char *) calloc(length_needed+1, sizeof(char));
+        GGML_ASSERT(output_buf);
+
+        WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), output_buf, length_needed, NULL, NULL);
+        output_buf[length_needed] = '\0';
+
+        argv.push_back(output_buf);
+        free(output_buf);
+    }
+
+    LocalFree((HLOCAL) wargv);
+#else
+    int argc = raw_argc;
+    for (int i = 0; i < argc; ++i) {
+        argv.push_back(raw_argv[i]);
+    }
+#endif
+
+    GGML_ASSERT((unsigned int) argc == argv.size());
+
+    return argv;
+}
+
+//
+// Function: write_utf8_cstr_to_stdout(const char *) -> <writes to stdout>
+//
+// writes a string to standard output; taking into account that on Windows
+// to display correctly you have to use special handling. Works even if the
+// user has not set a unicode code page on a Windows cmd.exe.
+//
+// In case of invalid UTF-8, invalid_utf8 is set to true on Windows, and something
+// a human-readable is written instead.
+//
+// On non-Windows systems, simply printfs() the string.
+static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
+    invalid_utf8 = false;
+
+#if defined(_WIN32)
+    // Are we in a console?
+    HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
+    DWORD dwMode = 0;
+
+    // According to Microsoft docs:
+    // "WriteConsole fails if it is used with a standard handle that is redirected to a file."
+    // Also according to the docs, you can use GetConsoleMode to check for that.
+    if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
+        printf("%s", str);
+        return;
+    }
+
+    // MultiByteToWideChar reports an error if str is empty, don't report
+    // them as invalid_utf8.
+    if (*str == 0) {
+        return;
+    }
+    int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, strlen(str), NULL, 0);
+    if (length_needed == 0) {
+        DWORD err = GetLastError();
+        if (err == ERROR_NO_UNICODE_TRANSLATION) {
+            invalid_utf8 = true;
+            int len = strlen(str);
+            printf("<");
+            for (int i = 0; i < len; ++i) {
+                if (i > 0) {
+                    printf(" ");
+                }
+                printf("%02x", (uint8_t) str[i]);
+            }
+            printf(">");
+            return;
+        }
+        GGML_ASSERT(false && "MultiByteToWideChar() failed in an unexpected way.");
+    }
+
+    LPWSTR wstr = (LPWSTR) calloc(length_needed+1, sizeof(*wstr));
+    GGML_ASSERT(wstr);
+
+    MultiByteToWideChar(CP_UTF8, 0, str, strlen(str), wstr, length_needed);
+    WriteConsoleW(hConsole, wstr, length_needed, NULL, NULL);
+
+    free(wstr);
+#else
+    // TODO: reporting invalid_utf8 would be useful on non-Windows too.
+    //       printf will silently just write bad unicode.
+    printf("%s", str);
+#endif
+}
+
+int main(int raw_argc, char ** raw_argv) {
+    const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv);
+    const int argc = argv.size();
+
+    if (argc <= 1) {
+        print_usage_information(argv[0].c_str(), stderr);
+        return 1;
+    }
+
+    //////
+    // Read out all the command line arguments.
+    //////
+
+    // variables where to put any arguments we see.
+    bool printing_ids = false;
+    bool no_bos = false;
+    bool disable_logging = false;
+    const char * model_path = NULL;
+    const char * prompt_path = NULL;
+    const char * prompt_arg = NULL;
+
+    // track which arguments were explicitly given
+    // used for sanity checking down the line
+    bool model_path_set = false;
+    bool prompt_path_set = false;
+    bool prompt_set = false;
+    bool stdin_set = false;
+
+    int iarg = 1;
+    for (; iarg < argc; ++iarg) {
+        std::string arg{argv[iarg]};
+        if (arg == "-h" || arg == "--help") {
+            print_usage_information(argv[0].c_str(), stdout);
+            return 0;
+        }
+        else if (arg == "--ids") {
+            printing_ids = true;
+        }
+        else if (arg == "-m" || arg == "--model") {
+            if (model_path_set) {
+                fprintf(stderr, "Error: -m or --model specified multiple times.\n");
+                return 1;
+            }
+            model_path = argv[++iarg].c_str();
+            model_path_set = true;
+        }
+        else if (arg == "--no-bos") {
+            no_bos = true;
+        }
+        else if (arg == "-p" || arg == "--prompt") {
+            if (prompt_set) {
+                fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
+                return 1;
+            }
+            prompt_arg = argv[++iarg].c_str();
+            prompt_set = true;
+        }
+        else if (arg == "-f" || arg == "--file") {
+            if (prompt_path_set) {
+                fprintf(stderr, "Error: -f or --file specified multiple times.\n");
+                return 1;
+            }
+            prompt_path = argv[++iarg].c_str();
+            prompt_path_set = true;
+        }
+        else if (arg == "--stdin") {
+            stdin_set = true;
+        }
+        else if (arg == "--log-disable") {
+            disable_logging = true;
+        }
+        else {
+            fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str());
+            return 1;
+        }
+    }
+
+    //////
+    // Sanity check the command line arguments.
+    //////
+
+    // Check that we have the required stuff set.
+    if (model_path_set && model_path == NULL) {
+        fprintf(stderr, "Error: --model requires an argument.\n");
+        return 1;
+    }
+    if (!model_path_set) {
+        fprintf(stderr, "Error: must specify --model.\n");
+        return 1;
+    }
+    if (prompt_path_set && prompt_path == NULL) {
+        fprintf(stderr, "Error: --file requires an argument.\n");
+        return 1;
+    }
+    if (prompt_set && prompt_arg == NULL) {
+        fprintf(stderr, "Error: --prompt requires an argument.\n");
+        return 1;
+    }
+    const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set);
+    if (prompts_set > 1) {
+        fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n");
+        return 1;
+    }
+    // Must have some prompt.
+    if (prompts_set == 0) {
+        fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n");
         return 1;
     }
 
-    const char * model_path = argv[1];
-    const char * prompt = argv[2];
+    GGML_ASSERT(model_path);
+    GGML_ASSERT(prompt_path || prompt_arg || stdin_set);
 
-    const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";
+    //////
+    // Figure out where will the prompt come from.
+    //////
+
+    std::string prompt;
+    if (prompt_path_set) {
+        bool success = false;
+        prompt = read_prompt_from_file(prompt_path, success);
+        if (!success) {
+            return 1;
+        }
+    } else if (prompt_set) {
+        prompt = prompt_arg;
+    } else {
+        GGML_ASSERT(stdin_set);
+        // we read stdin *after* loading model (early exit if model cannot
+        // be loaded, which can be a nicer user experience)
+    }
+
+    //////
+    // Start actually doing the tokenizing stuff.
+    //////
+
+#ifdef LOG_DISABLE_LOGS
+    disable_logging = true;
+#endif
+
+    if (disable_logging) {
+        llama_log_set(llama_log_callback_null, NULL);
+    }
 
     llama_backend_init();
 
     llama_model_params model_params = llama_model_default_params();
     model_params.vocab_only = true;
     llama_model * model = llama_load_model_from_file(model_path, model_params);
+    if (!model) {
+        fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path);
+        return 1;
+    }
 
     llama_context_params ctx_params = llama_context_default_params();
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+    if (!ctx) {
+        fprintf(stderr, "Error: could not create context.\n");
+        return 1;
+    }
+
+    // read entire prompt from stdin?
+    if (stdin_set) {
+        GGML_ASSERT(!prompt_path_set && !prompt_set);
+
+        std::stringstream stdin_buffer;
+        stdin_buffer << std::cin.rdbuf();
+        if (std::cin.fail()) {
+            fprintf(stderr, "Error: could not read the entire standard input.\n");
+            return 1;
+        }
+
+        prompt = stdin_buffer.str();
+    }
+
+    const bool model_wants_add_bos = llama_should_add_bos_token(model);
+    const bool add_bos = model_wants_add_bos && !no_bos;
 
     std::vector<llama_token> tokens;
+    tokens = ::llama_tokenize(model, prompt, add_bos, true);
 
-    tokens = ::llama_tokenize(model, prompt, true, true);
+    if (printing_ids) {
+        printf("[");
+    }
 
     for (int i = 0; i < (int) tokens.size(); i++) {
         if (printing_ids) {
-            printf("%d\n", tokens[i]);
+            if (i > 0) {
+                printf(", ");
+            }
+            printf("%d", tokens[i]);
         } else {
-            printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str());
+            bool invalid_utf8 = false;
+            printf("%6d -> '", tokens[i]);
+            write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
+            if (invalid_utf8) {
+                printf("' (utf-8 decode failure)\n");
+            } else {
+                printf("'\n");
+            }
         }
     }
 
+    if (printing_ids) {
+        printf("]\n");
+    }
+
+    // silence valgrind
+    llama_free(ctx);
+    llama_free_model(model);
+
     return 0;
 }
