From bba6759bc0d8efdf5d41969cd42b10389460ebc2 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 14 Feb 2025 00:57:56 -0800
Subject: [PATCH] Add llama2.c tokenizers

Differential Revision: D69579081

Pull Request resolved: https://github.com/pytorch-labs/tokenizers/pull/19
---
 README.md                                 |   8 +-
 include/llama2c_tokenizer.h               |  43 +++
 src/llama2c_tokenizer.cpp                 | 315 ++++++++++++++++++++++
 targets.bzl                               |  13 +
 test/resources/test_llama2c_tokenizer.bin | Bin 0 -> 16 bytes
 test/test_llama2c_tokenizer.cpp           |  80 ++++++
 tools/llama2c/TARGETS                     |   8 +
 tools/llama2c/__init__.py                 |   0
 tools/llama2c/convert.py                  | 150 +++++++++++
 tools/llama2c/targets.bzl                 |  40 +++
 10 files changed, 656 insertions(+), 1 deletion(-)
 create mode 100644 include/llama2c_tokenizer.h
 create mode 100644 src/llama2c_tokenizer.cpp
 create mode 100644 test/resources/test_llama2c_tokenizer.bin
 create mode 100644 test/test_llama2c_tokenizer.cpp
 create mode 100644 tools/llama2c/TARGETS
 create mode 100644 tools/llama2c/__init__.py
 create mode 100644 tools/llama2c/convert.py
 create mode 100644 tools/llama2c/targets.bzl

diff --git a/README.md b/README.md
index 1520df2..28528fe 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,13 @@ C++ implementations for various tokenizers (sentencepiece, tiktoken etc). Useful
 Depend on https://github.com/google/sentencepiece from Google.
 
 ## Tiktoken tokenizer
-Adopted from https://github.com/sewenew/tokenizer.
+Adapted from https://github.com/sewenew/tokenizer.
+
+## Huggingface tokenizer
+Compatible with https://github.com/huggingface/tokenizers/.
+
+## Llama2.c tokenizer
+Adapted from https://github.com/karpathy/llama2.c.
 
 ## License
 

diff --git a/include/llama2c_tokenizer.h b/include/llama2c_tokenizer.h
new file mode 100644
index 0000000..fc8418d
--- /dev/null
+++ b/include/llama2c_tokenizer.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
+#pragma once
+#include <memory>
+#include "tokenizer.h"
+
+namespace tokenizers {
+
+struct TokenIndex {
+  const char* str;
+  int32_t id;
+};
+
+// A simple Byte Pair Encoding (BPE) tokenizer. Note that this class cannot
+// load a raw sentencepiece model file directly; the model must first be
+// converted with tools/llama2c/convert.py.
+class Llama2cTokenizer : public Tokenizer {
+ public:
+  explicit Llama2cTokenizer();
+  ~Llama2cTokenizer() override;
+
+  Error load(const std::string& tokenizer_path) override;
+
+  Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos) const override;
+
+  Result<std::string> decode(uint64_t prev_token, uint64_t token)
+      const override;
+
+ private:
+  std::unique_ptr<char*[]> vocab_ = nullptr;
+  std::unique_ptr<float[]> vocab_scores_ = nullptr;
+  std::unique_ptr<TokenIndex[]> sorted_vocab_ = nullptr;
+  unsigned int max_token_length_ = 0;
+  unsigned char byte_pieces_[512]; // stores all single-byte strings
+};
+
+} // namespace tokenizers
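For orientation, a minimal usage sketch of the API declared above. This is not part of the patch; it assumes a model already converted with `tools/llama2c/convert.py`, and the `ok()`/`get()` accessors that `Result<T>` exposes elsewhere in this repo:

```cpp
#include <cstdint>
#include <iostream>
#include "llama2c_tokenizer.h"

int main() {
  tokenizers::Llama2cTokenizer tok;
  // "tokenizer.bin" is a placeholder path for a converted model.
  if (tok.load("tokenizer.bin") != tokenizers::Error::Ok) {
    std::cerr << "failed to load tokenizer\n";
    return 1;
  }
  // Note: bos/eos are counts of BOS/EOS tokens to add, not booleans.
  auto ids = tok.encode("hello world", /*bos=*/1, /*eos=*/0);
  if (!ids.ok()) {
    return 1;
  }
  // Decode piece by piece; decode() needs the previous token to decide
  // whether to strip the leading space after BOS.
  uint64_t prev = tok.bos_tok();
  for (uint64_t id : ids.get()) {
    auto piece = tok.decode(prev, id);
    if (piece.ok()) {
      std::cout << piece.get();
    }
    prev = id;
  }
  std::cout << std::endl;
  return 0;
}
```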
diff --git a/src/llama2c_tokenizer.cpp b/src/llama2c_tokenizer.cpp
new file mode 100644
index 0000000..e73089d
--- /dev/null
+++ b/src/llama2c_tokenizer.cpp
@@ -0,0 +1,315 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+// @lint-ignore-every CLANGTIDY facebook-hte-RelativeInclude
+#include "llama2c_tokenizer.h"
+#include <cstring>
+
+namespace tokenizers {
+
+static int compare_tokens(const void* a, const void* b) {
+  if (((TokenIndex*)a)->str == nullptr) {
+    return -1;
+  }
+  if (((TokenIndex*)b)->str == nullptr) {
+    return 1;
+  }
+  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+}
+
+Llama2cTokenizer::Llama2cTokenizer() : Tokenizer() {
+  for (int i = 0; i < 256; i++) {
+    byte_pieces_[i * 2] = (unsigned char)i;
+    byte_pieces_[i * 2 + 1] = '\0';
+  }
+}
+
+/**
+ * @brief Load the tokenizer from a file. The tokenizer file contains the
+ * vocabulary and scores. The format is four int32 metadata fields (vocab
+ * size, BOS id, EOS id, maximum token length), followed by one
+ * (float score, int32 word_len, word bytes) record per token. Here we
+ * read the whole vocabulary into memory and keep a copy sorted for fast
+ * lookup.
+ *
+ * @param tokenizer_path The path to the tokenizer file.
+ * @return Error
+ */
+Error Llama2cTokenizer::load(const std::string& tokenizer_path) {
+  if (initialized_) {
+    TK_LOG(Info, "Tokenizer already initialized");
+    return Error::Ok;
+  }
+  // read in the file
+  FILE* file = fopen(tokenizer_path.c_str(), "rb");
+  if (!file) {
+    TK_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
+    return Error::LoadFailure;
+  }
+  int32_t metadata[4];
+  for (int i = 0; i < 4; i++) {
+    if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
+      TK_LOG(
+          Error,
+          "Failed to read the metadata at position %d, the tokenizer file is not valid!",
+          i);
+      return Error::ParseFailure;
+    }
+  }
+
+  // now we have two vocab_sizes: one from the model and another from the
+  // tokenizer file.
+  int32_t tokenizer_vocab_size = metadata[0];
+  vocab_size_ = tokenizer_vocab_size;
+  bos_tok_ = metadata[1];
+  eos_tok_ = metadata[2];
+  max_token_length_ = metadata[3];
+
+  // allocate space for the vocabulary
+  vocab_ = std::make_unique<char*[]>(vocab_size_);
+  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
+  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);
+
+  // read in the vocabulary
+  for (int i = 0; i < vocab_size_; i++) {
+    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
+      // This is allowed; we just pad the rest of the vocab with <pad> strings
+      std::string padding = "<pad>";
+      vocab_[i] = new char[padding.length() + 1];
+      strcpy(vocab_[i], padding.c_str());
+      vocab_[i][padding.length()] = '\0';
+      continue;
+    }
+    int32_t len;
+    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
+      TK_LOG(Error, "Failed to read the length of the word at index %d", i);
+      return Error::ParseFailure;
+    }
+    vocab_[i] = new char[len + 1];
+    if (fread(vocab_[i], len, 1, file) != 1) {
+      TK_LOG(
+          Error,
+          "Failed to read the word, total length %d, index %d\n",
+          len,
+          i);
+      return Error::ParseFailure;
+    }
+    vocab_[i][len] = '\0'; // add the string terminating token
+  }
+  fclose(file);
+
+  for (int32_t i = 0; i < vocab_size_; i++) {
+    sorted_vocab_[i].str = vocab_[i];
+    sorted_vocab_[i].id = i;
+  }
+  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);
+
+  initialized_ = true;
+  return Error::Ok;
+}
+
+Llama2cTokenizer::~Llama2cTokenizer() {
+  for (int i = 0; i < vocab_size_; i++) {
+    delete[] vocab_[i];
+  }
+}
+
+/**
+ * @brief Decode a token into a string.
+ *
+ * @param prev_token The previous token.
+ * @param token The current token.
+ * @return Result<std::string> The string representation of the token.
+ */
+Result<std::string> Llama2cTokenizer::decode(
+    uint64_t prev_token,
+    uint64_t token) const {
+  TK_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
+  const char* piece = vocab_[token];
+  // following BOS token, sentencepiece decoder strips any leading
+  // whitespace
+  if (prev_token == bos_tok_ && piece[0] == ' ') {
+    piece++;
+  }
+  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
+  // parse this and convert and return the actual byte
+  unsigned char byte_val;
+  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
+    piece = (char*)byte_pieces_ + byte_val * 2;
+  }
+  std::string res(piece);
+  return res;
+}
+
+static int32_t
+str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
+  // efficiently find the perfect match for str in vocab, return its index
+  // or -1 if not found
+  TokenIndex tok = {.str = str}; // acts as the key to search for
+  TokenIndex* res = (TokenIndex*)bsearch(
+      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
+  return res != nullptr ? res->id : -1;
+}
+
+/**
+ * @brief Encode a string into a sequence of tokens.
+ *
+ * @param text The string to be encoded.
+ * @param bos The number of BOS tokens to prepend to the token list.
+ * @param eos The number of EOS tokens to append to the token list.
+ * @return Result<std::vector<uint64_t>> The encoded token ids.
+ */
+Result<std::vector<uint64_t>> Llama2cTokenizer::encode(
+    const std::string& text,
+    int8_t bos,
+    int8_t eos) const {
+  if (!initialized_) {
+    TK_LOG(Error, "Tokenizer not initialized");
+    return Error::Uninitialized;
+  }
+  // encode the string text (input) into a token sequence; bos and eos give
+  // the number of BOS/EOS tokens to prepend/append
+  if (text.empty()) {
+    TK_LOG(Error, "cannot encode empty text");
+    return Error::EncodeFailure;
+  }
+
+  // create a temporary buffer that will store merge candidates of always two
+  // consecutive tokens; *2 for concat, +1 for null terminator, +2 for UTF8
+  // (in case max_token_length is 1)
+  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
+  size_t str_len = 0;
+
+  // start at 0 tokens
+  std::vector<uint64_t> tokens;
+
+  // add optional BOS token, if desired
+  if (bos >= 0) {
+    while (bos--) {
+      tokens.push_back(bos_tok_);
+    }
+  } else {
+    TK_LOG(Error, "bos %d should be >= 0", bos);
+    return Error::EncodeFailure;
+  }
+
+  // add_dummy_prefix is true by default
+  // so prepend a dummy prefix token to the input string, but only if text != ""
+  // TODO: pretty sure this isn't correct in the general case but I don't have
+  // the energy to read more of the sentencepiece code to figure out what it's
+  // doing
+  const char* space = " ";
+  if (text[0] != '\0') {
+    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
+    tokens.push_back(dummy_prefix);
+  }
+
+  // Okay UTF-8 time. This will get messy. Here is the reference from
+  // Wikipedia:
+  // Code point ↔ UTF-8 conversion
+  // First code point Last code point Byte 1   Byte 2   Byte 3   Byte 4
+  // U+0000           U+007F          0xxxxxxx
+  // U+0080           U+07FF          110xxxxx 10xxxxxx
+  // U+0800           U+FFFF          1110xxxx 10xxxxxx 10xxxxxx
+  // U+10000          U+10FFFF        11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
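+  // For example, '€' (U+20AC) arrives as the three UTF-8 bytes
+  // 0xE2 0x82 0xAC; the loop below buffers all three of them before
+  // attempting a single vocab lookup for the full codepoint.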
+
+  // process the raw (UTF-8) byte sequence of the input string
+  for (const char* c = text.c_str(); *c != '\0'; c++) {
+    // reset buffer if the current byte is ASCII or a leading byte
+    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
+    // rest. 0x80 is 10000000; in UTF-8 all continuation bytes start with "10"
+    // in the first two bits, so in English this is: "if this byte is not a
+    // continuation byte"
+    if ((*c & 0xC0) != 0x80) {
+      // this byte must be either a leading byte (11...) or an ASCII char
+      // (0x...)
+      // => reset our location, as we're starting a new UTF-8 codepoint
+      str_len = 0;
+    }
+
+    // append the current byte to the buffer
+    str_buffer[str_len++] =
+        *c; // ++ is post-increment, incremented after this line
+    str_buffer[str_len] = '\0';
+
+    // while the next character is a continuation byte, continue appending
+    // but if there are too many of them, just stop to avoid overrunning
+    // str_buffer size.
+    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
+      continue;
+    }
+
+    // ok c+1 is not a continuation byte, so we've read in a full codepoint
+    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
+    if (id != -1) {
+      // we found this codepoint in vocab, add it as a token
+      tokens.push_back(id);
+    } else {
+      // byte_fallback encoding: just encode each byte as a token
+      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
+      // so the individual bytes only start at index 3
+      for (int i = 0; i < str_len; i++) {
+        tokens.push_back((unsigned char)str_buffer[i] + 3);
+      }
+    }
+    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
+  }
+
+  // merge the best consecutive pair each iteration, according to the scores
+  // in vocab_scores
+  while (1) {
+    float best_score = -1e10;
+    int best_id = -1;
+    int best_idx = -1;
+
+    for (int i = 0; i < tokens.size() - 1; i++) {
+      // check if we can merge the pair (tokens[i], tokens[i+1])
+      snprintf(
+          str_buffer,
+          max_token_length_ * 2 + 3,
+          "%s%s",
+          vocab_[tokens[i]],
+          vocab_[tokens[i + 1]]);
+      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
+      if (id != -1 && vocab_scores_[id] > best_score) {
+        // this merge pair exists in vocab! record its score and position
+        best_score = vocab_scores_[id];
+        best_id = id;
+        best_idx = i;
+      }
+    }
+
+    if (best_idx == -1) {
+      break; // we couldn't find any more pairs to merge, so we're done
+    }
+
+    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
+    tokens[best_idx] = best_id;
+    // delete token at position best_idx+1, shift the entire sequence back 1
+    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
+      tokens[i] = tokens[i + 1];
+    }
+    tokens.pop_back(); // token length decreased
+  }
+
+  // add optional EOS (=2) token, if desired
+  if (eos >= 0) {
+    while (eos--) {
+      tokens.push_back(eos_tok_);
+    }
+  } else {
+    TK_LOG(Error, "eos %d should be >= 0", eos);
+    return Error::EncodeFailure;
+  }
+
+  delete[] str_buffer;
+  return Result<std::vector<uint64_t>>(tokens);
+}
+
+} // namespace tokenizers
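The merge loop above is plain greedy BPE: at every step, try the concatenation of each adjacent pair, and merge the pair whose concatenation is in the vocab with the highest score, until no merge applies. A self-contained toy illustration of the same idea (hypothetical vocabulary and scores, independent of the classes in this patch):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Toy vocab: piece -> (id, score). Higher score merges first.
  std::unordered_map<std::string, std::pair<int, float>> vocab = {
      {"h", {0, 0.0f}},  {"e", {1, 0.0f}},    {"l", {2, 0.0f}},
      {"o", {3, 0.0f}},  {"he", {4, 1.0f}},   {"ll", {5, 2.0f}},
      {"hell", {6, 3.0f}}, {"hello", {7, 4.0f}}};

  // Start from single characters, as the byte/codepoint pass does.
  std::vector<std::string> pieces = {"h", "e", "l", "l", "o"};

  while (true) {
    float best_score = -1e10f;
    int best_idx = -1;
    for (size_t i = 0; i + 1 < pieces.size(); i++) {
      auto it = vocab.find(pieces[i] + pieces[i + 1]);
      if (it != vocab.end() && it->second.second > best_score) {
        best_score = it->second.second;
        best_idx = static_cast<int>(i);
      }
    }
    if (best_idx == -1) {
      break; // no adjacent pair concatenates to an in-vocab piece
    }
    // merge the winning pair and close the gap
    pieces[best_idx] += pieces[best_idx + 1];
    pieces.erase(pieces.begin() + best_idx + 1);
  }

  // Merges happen in score order: "ll", then "he", "hell", "hello",
  // so this prints a single piece: hello (id 7)
  for (const auto& p : pieces) {
    std::cout << p << " (id " << vocab[p].first << ")\n";
  }
  return 0;
}
```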
diff --git a/targets.bzl b/targets.bzl
index 9bd9116..dd26998 100644
--- a/targets.bzl
+++ b/targets.bzl
@@ -94,3 +94,16 @@ def define_common_targets():
             "nlohmann_json",
         ],
     )
+
+    runtime.cxx_library(
+        name = "llama2c_tokenizer",
+        srcs = [
+            "src/llama2c_tokenizer.cpp",
+        ],
+        exported_deps = [
+            ":headers",
+        ],
+        visibility = [
+            "@EXECUTORCH_CLIENTS",
+        ],
+    )

diff --git a/test/resources/test_llama2c_tokenizer.bin b/test/resources/test_llama2c_tokenizer.bin
new file mode 100644
index 0000000000000000000000000000000000000000..01d633b27e8ea9b17084fc911d0c8cc43a4170a9
GIT binary patch
literal 16
KcmZQzKm`B*5C8!H

literal 0
HcmV?d00001

diff --git a/test/test_llama2c_tokenizer.cpp b/test/test_llama2c_tokenizer.cpp
new file mode 100644
index 0000000..72abc48
--- /dev/null
+++ b/test/test_llama2c_tokenizer.cpp
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#ifdef TOKENIZERS_FB_BUCK
+#include <TestResourceUtils/TestResourceUtils.h>
+#endif
+#include <gtest/gtest.h>
+#include "llama2c_tokenizer.h"
+
+using namespace ::testing;
+
+namespace tokenizers {
+
+namespace {
+// Test case based on llama2.c tokenizer
+static inline std::string _get_resource_path(const std::string& name) {
+#ifdef TOKENIZERS_FB_BUCK
+  return facebook::xplat::testing::getPathForTestResource(
+      "test/resources/" + name);
+#else
+  return std::getenv("RESOURCES_PATH") + std::string("/") + name;
+#endif
+}
+
+} // namespace
+
+class Llama2cTokenizerTest : public Test {
+ public:
+  void SetUp() override {
+    tokenizer_ = std::make_unique<Llama2cTokenizer>();
+    modelPath_ = _get_resource_path("test_llama2c_tokenizer.bin");
+  }
+
+  std::unique_ptr<Tokenizer> tokenizer_;
+  std::string modelPath_;
+};
+
+TEST_F(Llama2cTokenizerTest, EncodeWithoutLoadFails) {
+  Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
+  EXPECT_EQ(res.error(), Error::Uninitialized);
+}
+
+TEST_F(Llama2cTokenizerTest, DecodeWithoutLoadFails) {
+  auto result = tokenizer_->decode(0, 0);
+  EXPECT_EQ(result.error(), Error::Uninitialized);
+}
+
+TEST_F(Llama2cTokenizerTest, DecodeOutOfRangeFails) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  auto result = tokenizer_->decode(0, 64000);
+  // Token 64000 is outside the vocab range recorded in the test binary.
+  EXPECT_EQ(result.error(), Error::OutOfRange);
+}
+
+TEST_F(Llama2cTokenizerTest, TokenizerMetadataIsExpected) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  // test_llama2c_tokenizer.bin has vocab_size 0, bos_id 0, eos_id 0 recorded.
+  EXPECT_EQ(tokenizer_->vocab_size(), 0);
+  EXPECT_EQ(tokenizer_->bos_tok(), 0);
+  EXPECT_EQ(tokenizer_->eos_tok(), 0);
+}
+
+TEST_F(Llama2cTokenizerTest, SafeToDestruct) {
+  // Safe to destruct initialized tokenizer.
+  tokenizer_->load(modelPath_);
+  tokenizer_.reset();
+
+  // Safe to destruct uninitialized tokenizer.
+  tokenizer_ = std::make_unique<Llama2cTokenizer>();
+  tokenizer_.reset();
+}
+
+} // namespace tokenizers
diff --git a/tools/llama2c/TARGETS b/tools/llama2c/TARGETS
new file mode 100644
index 0000000..1e8cc17
--- /dev/null
+++ b/tools/llama2c/TARGETS
@@ -0,0 +1,8 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain xplat-only targets.
+
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()

diff --git a/tools/llama2c/__init__.py b/tools/llama2c/__init__.py
new file mode 100644
index 0000000..e69de29

diff --git a/tools/llama2c/convert.py b/tools/llama2c/convert.py
new file mode 100644
index 0000000..1f915fc
--- /dev/null
+++ b/tools/llama2c/convert.py
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# Script to rewrite a tokenizer model produced by sentencepiece into the
+# llama2.c format, with lightweight postprocessing logic. The output can be
+# consumed by llama2c_tokenizer.cpp.
+
+import argparse
+import logging
+import os
+import struct
+from typing import List
+
+from sentencepiece import SentencePieceProcessor as SentencePieceProcessor
+
+
+class Tokenizer:
+    def __init__(self, model_path: str):
+        assert os.path.isfile(
+            model_path
+        ), f"Need a valid tokenizer model path but got {model_path}"
+        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        self.model_path = model_path
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        logging.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
+        return self.sp_model.decode(t)
+
+    def decode_token(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `decode`.
+        return self.sp_model.decode(t)
+
+    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
+        """
+        Export tokenizer.model to another serialization format. Here we do some
+        lightweight postprocessing, such as optionally prepending a padding
+        token, recording the max token length, and replacing '▁' with a space.
+
+        The binary format is:
+        1. vocab size: int32
+        2. bos token id: int32
+        3. eos token id: int32
+        4. max token length: int32
+        5. score: float32, len of bytes: int32, token bytes: [byte] for each token
+
+        :param output_path: output path of the new binary.
+        :param prepend_padding: a boolean to control if we want to prepend a padding token.
+
+        :return: None
+        """
+
+        # get all the tokens (postprocessed) and their scores as floats
+        tokens, scores = [], []
+
+        if prepend_padding:
+            # Here we use the default padding token and its score.
+            tokens.append("<pad>".encode("utf-8"))
+            scores.append(-1)
+
+        for i in range(self.n_words):
+            # decode the token and light postprocessing
+            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
+            t = self.sp_model.id_to_piece(i)
+            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
+            s = self.sp_model.get_score(i)
+            # sentencepiece uses '<s>' for BOS and '</s>' for EOS
+            if i == self.bos_id:
+                t = "<s>"
+            elif i == self.eos_id:
+                t = "</s>"
+            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
+            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded
+
+            tokens.append(b)
+            scores.append(s)
+
+        # record the max token length
+        max_token_length = 0 if not tokens else max(len(t) for t in tokens)
+
+        # write to a binary file
+        with open(output_path, "wb") as f:
+            # write the vocab size, bos/eos ids and max token length
+            f.write(
+                struct.pack(
+                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
+                )
+            )
+            for bytes, score in zip(tokens, scores):
+                f.write(struct.pack("fI", score, len(bytes)))
+                f.write(bytes)
+        logging.info(f"Wrote tokenizer to {output_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-t",
+        "--tokenizer-model",
+        type=str,
+        default="tokenizer.model",
+        help="path to tokenizer model, given by sentencepiece",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-path",
+        type=str,
+        default=None,
+        help="output path of postprocessed tokenizer model",
+    )
+    parser.add_argument(
+        "-p",
+        "--prepend-padding",
+        action="store_true",
+        help="whether to prepend a padding token to the beginning of the tokenizer",
+    )
+
+    args = parser.parse_args()
+
+    t = Tokenizer(args.tokenizer_model)
+
+    output_path = (
+        args.output_path
+        if args.output_path
+        else args.tokenizer_model.replace(".model", ".bin")
+    )
+    t.export(output_path, prepend_padding=args.prepend_padding)

diff --git a/tools/llama2c/targets.bzl b/tools/llama2c/targets.bzl
new file mode 100644
index 0000000..2449d4e
--- /dev/null
+++ b/tools/llama2c/targets.bzl
@@ -0,0 +1,40 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+    """
+    runtime.python_library(
+        name = "convert_lib",
+        srcs = [
+            "__init__.py",
+            "convert.py",
+        ],
+        base_module = "pytorch.tokenizers.tools.llama2c",
+        visibility = [
+            "//executorch/examples/...",
+            "//executorch/extension/llm/export/...",
+            "//bento/...",
+            "//bento_kernels/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        _is_external_target = True,
+        external_deps = [
+            "sentencepiece-py",
+        ],
+    )
+
+    runtime.python_binary(
+        name = "convert",
+        main_module = "pytorch.tokenizers.tools.llama2c.convert",
+        visibility = [
+            "//executorch/examples/...",
+            "fbsource//xplat/executorch/examples/...",
+        ],
+        _is_external_target = True,
+        deps = [
+            ":convert_lib",
+        ],
+    )
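A converted model can be produced with, e.g., `python tools/llama2c/convert.py -t tokenizer.model -o tokenizer.bin`, or via the `convert` buck target above. To sanity-check the output, a small standalone dump tool for the layout documented in `export()` (a hedged sketch, not part of the patch, written in C++ to match the consumer side):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Prints the header and the first few records of a llama2.c tokenizer
// binary: int32 vocab_size, bos_id, eos_id, max_token_length, then one
// (float32 score, int32 len, len bytes) record per token.
int main(int argc, char** argv) {
  if (argc != 2) {
    std::fprintf(stderr, "usage: %s tokenizer.bin\n", argv[0]);
    return 1;
  }
  FILE* f = std::fopen(argv[1], "rb");
  if (!f) {
    std::perror("fopen");
    return 1;
  }
  int32_t header[4];
  if (std::fread(header, sizeof(int32_t), 4, f) != 4) {
    std::fprintf(stderr, "truncated header\n");
    std::fclose(f);
    return 1;
  }
  std::printf(
      "vocab_size=%d bos=%d eos=%d max_token_length=%d\n",
      header[0], header[1], header[2], header[3]);
  for (int i = 0; i < header[0] && i < 5; i++) {
    float score;
    int32_t len;
    if (std::fread(&score, sizeof(float), 1, f) != 1 ||
        std::fread(&len, sizeof(int32_t), 1, f) != 1) {
      break; // truncated record
    }
    std::vector<char> word(len + 1, '\0');
    if (std::fread(word.data(), 1, len, f) != (size_t)len) {
      break;
    }
    std::printf("token %d: score=%f piece=\"%s\"\n", i, score, word.data());
  }
  std::fclose(f);
  return 0;
}
```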