Skip to content

Commit 6335114

Browse files
authored
quantize : improve type name parsing (ggml-org#9570)
quantize : do not ignore invalid types in arg parsing quantize : ignore case of type and ftype arguments
1 parent d13edb1 commit 6335114

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

examples/quantize/quantize.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,24 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix
6363
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
6464
static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count";
6565

66+
static bool striequals(const char * a, const char * b) {
67+
while (*a && *b) {
68+
if (std::tolower(*a) != std::tolower(*b)) {
69+
return false;
70+
}
71+
a++; b++;
72+
}
73+
return *a == *b;
74+
}
75+
6676
static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {
6777
std::string ftype_str;
6878

6979
for (auto ch : ftype_str_in) {
7080
ftype_str.push_back(std::toupper(ch));
7181
}
7282
for (auto & it : QUANT_OPTIONS) {
73-
if (it.name == ftype_str) {
83+
if (striequals(it.name.c_str(), ftype_str.c_str())) {
7484
ftype = it.ftype;
7585
ftype_str_out = it.name;
7686
return true;
@@ -225,15 +235,15 @@ static int prepare_imatrix(const std::string & imatrix_file,
225235
}
226236

227237
static ggml_type parse_ggml_type(const char * arg) {
228-
ggml_type result = GGML_TYPE_COUNT;
229-
for (int j = 0; j < GGML_TYPE_COUNT; ++j) {
230-
auto type = ggml_type(j);
238+
for (int i = 0; i < GGML_TYPE_COUNT; ++i) {
239+
auto type = (ggml_type)i;
231240
const auto * name = ggml_type_name(type);
232-
if (name && strcmp(arg, name) == 0) {
233-
result = type; break;
241+
if (name && striequals(name, arg)) {
242+
return type;
234243
}
235244
}
236-
return result;
245+
fprintf(stderr, "%s: invalid ggml_type '%s'\n", __func__, arg);
246+
return GGML_TYPE_COUNT;
237247
}
238248

239249
int main(int argc, char ** argv) {
@@ -254,12 +264,18 @@ int main(int argc, char ** argv) {
254264
} else if (strcmp(argv[arg_idx], "--output-tensor-type") == 0) {
255265
if (arg_idx < argc-1) {
256266
params.output_tensor_type = parse_ggml_type(argv[++arg_idx]);
267+
if (params.output_tensor_type == GGML_TYPE_COUNT) {
268+
usage(argv[0]);
269+
}
257270
} else {
258271
usage(argv[0]);
259272
}
260273
} else if (strcmp(argv[arg_idx], "--token-embedding-type") == 0) {
261274
if (arg_idx < argc-1) {
262275
params.token_embedding_type = parse_ggml_type(argv[++arg_idx]);
276+
if (params.token_embedding_type == GGML_TYPE_COUNT) {
277+
usage(argv[0]);
278+
}
263279
} else {
264280
usage(argv[0]);
265281
}

0 commit comments

Comments
 (0)