From 1b35a239b8af9c594cbd4cefd93d3955e60ddd83 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 18 Feb 2025 17:59:01 -0800
Subject: [PATCH] Migrate extension/llm/tokenizer python users to use the new
 repo (#22)

Summary:
Migrate usages of
```
fbcode//executorch/extension/llm/tokenizer:tokenizer_py
```
to use:
```
fbcode//pytorch/tokenizers/pytorch_tokenizers:tokenizers
```

Differential Revision: D69820450
---
 pytorch_tokenizers/TARGETS         |  18 +++
 pytorch_tokenizers/__init__.py     |  26 ++++
 pytorch_tokenizers/hf_tokenizer.py |  56 +++++++
 pytorch_tokenizers/llama2c.py      | 110 ++++++++++++++
 pytorch_tokenizers/targets.bzl     |  34 +++++
 pytorch_tokenizers/tiktoken.py     | 225 +++++++++++++++++++++++++++++
 tools/llama2c/convert.py           | 112 +-------------
 tools/llama2c/targets.bzl          |  18 +--
 8 files changed, 481 insertions(+), 118 deletions(-)
 create mode 100644 pytorch_tokenizers/TARGETS
 create mode 100644 pytorch_tokenizers/__init__.py
 create mode 100644 pytorch_tokenizers/hf_tokenizer.py
 create mode 100644 pytorch_tokenizers/llama2c.py
 create mode 100644 pytorch_tokenizers/targets.bzl
 create mode 100644 pytorch_tokenizers/tiktoken.py

diff --git a/pytorch_tokenizers/TARGETS b/pytorch_tokenizers/TARGETS
new file mode 100644
index 0000000..d563279
--- /dev/null
+++ b/pytorch_tokenizers/TARGETS
@@ -0,0 +1,18 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain xplat-only targets.
+
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
+
+python_library(
+    name = "hf_tokenizer",
+    srcs = ["hf_tokenizer.py"],
+    labels = ["autodeps2_generated"],
+    deps = [
+        "fbsource//third-party/pypi/tokenizers:tokenizers",
+    ],
+)
diff --git a/pytorch_tokenizers/__init__.py b/pytorch_tokenizers/__init__.py
new file mode 100644
index 0000000..fb81b2f
--- /dev/null
+++ b/pytorch_tokenizers/__init__.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Optional
+
+from .hf_tokenizer import HuggingFaceTokenizer
+from .llama2c import Llama2cTokenizer
+from .tiktoken import TiktokenTokenizer
+
+__all__ = ["TiktokenTokenizer", "Llama2cTokenizer", "HuggingFaceTokenizer"]
+
+
+def get_tokenizer(tokenizer_path: str, tokenizer_config_path: Optional[str] = None):
+    if tokenizer_path.endswith(".json"):
+        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
+    else:
+        try:
+            tokenizer = Llama2cTokenizer(model_path=str(tokenizer_path))
+        except Exception:
+            print("Using TiktokenTokenizer")
+            tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path))
+    return tokenizer
diff --git a/pytorch_tokenizers/hf_tokenizer.py b/pytorch_tokenizers/hf_tokenizer.py
new file mode 100644
index 0000000..cc2e2cf
--- /dev/null
+++ b/pytorch_tokenizers/hf_tokenizer.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from typing import List, Optional
+
+from tokenizers import Tokenizer
+
+
+class HuggingFaceTokenizer:
+    """
+    Tokenizing and encoding/decoding text using the Hugging Face tokenizer.
+    """
+
+    def __init__(self, model_path: str, config_path: Optional[str] = None):
+        """
+        Initializes the Tokenizer with a tokenizer.json from HuggingFace.
+
+        Args:
+            model_path (str): The path to the tokenizer.json file.
+        """
+        assert os.path.isfile(model_path), model_path
+
+        self.model = tokenizer = Tokenizer.from_file(model_path)
+
+        self.n_words: int = tokenizer.get_vocab_size()
+        if config_path:
+            with open(config_path) as f:
+                tokenizer_config = json.load(f)
+                self.bos_id = (
+                    self.model.token_to_id(tokenizer_config["bos_token"])
+                    if tokenizer_config["bos_token"]
+                    else None
+                )
+                self.eos_id = self.model.token_to_id(tokenizer_config["eos_token"])
+        else:  # Fallback guess.
+            self.bos_id = self.model.token_to_id("<|begin_of_text|>")
+            self.eos_id = self.model.token_to_id("<|endoftext|>")
+
+        self.stop_tokens = [
+            self.eos_id,
+        ]
+
+    def encode(self, s: str, *, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        return self.model.encode(s).ids  # NOTE: the bos/eos flags are currently ignored
+
+    def decode(self, t: List[int]) -> str:
+        return self.model.decode(t)
+
+    def decode_token(self, t: int) -> str:
+        return self.model.decode([t])
diff --git a/pytorch_tokenizers/llama2c.py b/pytorch_tokenizers/llama2c.py
new file mode 100644
index 0000000..11715dd
--- /dev/null
+++ b/pytorch_tokenizers/llama2c.py
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+import struct
+from typing import List
+
+from sentencepiece import SentencePieceProcessor as SentencePieceProcessor
+
+
+class Llama2cTokenizer:
+    def __init__(self, model_path: str):
+        assert os.path.isfile(
+            model_path
+        ), f"Need a valid tokenizer model path but got {model_path}"
+        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        self.model_path = model_path
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        logging.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        return self.sp_model.decode(t)
+
+    def decode_token(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        return self.sp_model.decode(t)
+
+    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
+        """
+        Export tokenizer.model to another serialization format. Here we do some
+        lightweight processing such as prepending a padding token, prepending the
+        max token length, and replacing '▁' back with a space.
+
+        The binary format is:
+        1. vocab size: int32
+        2. bos token id: int32
+        3. eos token id: int32
+        4. max token length: int32
+        5. score: float32, len of bytes: int32, token bytes: [byte] for each token
+
+        :param output_path: output path of the new binary.
+        :param prepend_padding: a boolean to control if we want to prepend a padding token.
+
+        :return: None
+        """
+
+        # get all the tokens (postprocessed) and their scores as floats
+        tokens, scores = [], []
+
+        if prepend_padding:
+            # Here we use the default padding token and its score.
+            tokens.append("<pad>".encode("utf-8"))
+            scores.append(-1)
+
+        for i in range(self.n_words):
+            # decode the token and light postprocessing
+            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
+            t = self.sp_model.id_to_piece(i)
+            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
+            s = self.sp_model.get_score(i)
+            # sentencepiece uses '<s>' as BOS and '</s>' as EOS
+            if i == self.bos_id:
+                t = "<s>"
+            elif i == self.eos_id:
+                t = "</s>"
+            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
+            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded
+
+            tokens.append(b)
+            scores.append(s)
+
+        # record the max token length
+        max_token_length = 0 if not tokens else max(len(t) for t in tokens)
+
+        # write to a binary file
+        with open(output_path, "wb") as f:
+            # write the vocab size, bos/eos ids and max token length
+            f.write(
+                struct.pack(
+                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
+                )
+            )
+            for token_bytes, score in zip(tokens, scores):
+                f.write(struct.pack("fI", score, len(token_bytes)))
+                f.write(token_bytes)
+        logging.info(f"Wrote tokenizer to {output_path}")
diff --git a/pytorch_tokenizers/targets.bzl b/pytorch_tokenizers/targets.bzl
new file mode 100644
index 0000000..b59ff84
--- /dev/null
+++ b/pytorch_tokenizers/targets.bzl
@@ -0,0 +1,34 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+    """
+    runtime.python_library(
+        name = "tokenizers",
+        srcs = [
+            "__init__.py",
+            "llama2c.py",
+            "tiktoken.py",
+            "hf_tokenizer.py",
+        ],
+        base_module = "pytorch_tokenizers",
+        visibility = [
+            "//executorch/examples/...",
+            "//executorch/extension/llm/export/...",
+            "//bento/...",
+            "//bento_kernels/...",
+            "//pytorch/tokenizers/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        _is_external_target = True,
+        external_deps = [
+            "sentencepiece-py",
+        ],
+        deps = [
+            "fbsource//third-party/pypi/tiktoken:tiktoken",
+            "fbsource//third-party/pypi/tokenizers:tokenizers",
+        ],
+    )
diff --git a/pytorch_tokenizers/tiktoken.py b/pytorch_tokenizers/tiktoken.py
new file mode 100644
index 0000000..41b48d9
--- /dev/null
+++ b/pytorch_tokenizers/tiktoken.py
@@ -0,0 +1,225 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from logging import getLogger
+from pathlib import Path
+from typing import (
+    AbstractSet,
+    cast,
+    Collection,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Union,
+)
+
+import tiktoken
+
+from tiktoken.load import load_tiktoken_bpe
+
+logger = getLogger(__name__)
+
+
+# The tiktoken tokenizer can handle <=400k chars without
+# pyo3_runtime.PanicException.
+TIKTOKEN_MAX_ENCODE_CHARS = 400_000
+
+# https://github.com/openai/tiktoken/issues/195
+# Here we iterate over subsequences and split if we exceed the limit
+# of max consecutive non-whitespace or whitespace characters.
+MAX_NO_WHITESPACES_CHARS = 25_000
+
+
+_INSTANCE = None
+
+
+class TiktokenTokenizer:
+    """
+    Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
+    WARNING: The regex and special tokens are hardcoded from Llama 3+.
+    """
+
+    special_tokens: Dict[str, int]
+
+    num_reserved_special_tokens = 256
+
+    pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"  # noqa: E501
+
+    @classmethod
+    def get_instance(cls):
+        global _INSTANCE
+
+        if _INSTANCE is None:
+            _INSTANCE = TiktokenTokenizer(
+                os.path.join(os.path.dirname(__file__), "tokenizer.model")
+            )
+        return _INSTANCE
+
+    def __init__(self, model_path: str):
+        """
+        Initializes the Tokenizer with a Tiktoken model.
+
+        Args:
+            model_path (str): The path to the Tiktoken model file.
+        """
+        assert os.path.isfile(model_path), model_path
+
+        mergeable_ranks = load_tiktoken_bpe(model_path)
+        num_base_tokens = len(mergeable_ranks)
+        special_tokens = [
+            "<|begin_of_text|>",
+            "<|end_of_text|>",
+            "<|reserved_special_token_0|>",
+            "<|reserved_special_token_1|>",
+            "<|finetune_right_pad_id|>",
+            "<|step_id|>",
+            "<|start_header_id|>",
+            "<|end_header_id|>",
+            "<|eom_id|>",  # end of message
+            "<|eot_id|>",  # end of turn
+            "<|python_tag|>",
+            "<|image|>",
+        ]
+        reserved_tokens = [
+            f"<|reserved_special_token_{2 + i}|>"
+            for i in range(self.num_reserved_special_tokens - len(special_tokens))
+        ]
+        special_tokens = special_tokens + reserved_tokens
+
+        self.special_tokens = {
+            token: num_base_tokens + i for i, token in enumerate(special_tokens)
+        }
+        self.model = tiktoken.Encoding(
+            name=Path(model_path).name,
+            pat_str=self.pat_str,
+            mergeable_ranks=mergeable_ranks,
+            special_tokens=self.special_tokens,
+        )
+
+        self.n_words: int = num_base_tokens + len(special_tokens)
+        # BOS / EOS token IDs
+        self.bos_id: int = self.special_tokens["<|begin_of_text|>"]
+        self.eos_id: int = self.special_tokens["<|end_of_text|>"]
+        self.eot_id: int = self.special_tokens["<|eot_id|>"]
+        self.eom_id: int = self.special_tokens["<|eom_id|>"]
+        self.python_tag_id = self.special_tokens["<|python_tag|>"]
+        self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"]
+        self.stop_tokens = [
+            self.eos_id,
+            self.special_tokens["<|eom_id|>"],
+            self.special_tokens["<|eot_id|>"],
+        ]
+
+    def encode(
+        self,
+        s: str,
+        *,
+        bos: bool,
+        eos: bool,
+        allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None,
+        disallowed_special: Union[Literal["all"], Collection[str]] = (),
+    ) -> List[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+            allowed_special ("all"|set[str]): allowed special tokens in string
+            disallowed_special ("all"|set[str]): special tokens that raise an error when in string
+
+        Returns:
+            list[int]: A list of token IDs.
+
+        By default, setting disallowed_special=() encodes a string by ignoring
+        special tokens. Specifically:
+        - Setting `disallowed_special` to () will cause all text corresponding
+          to special tokens to be encoded as natural text (instead of raising
+          an error).
+        - Setting `allowed_special` to "all" will cause all text corresponding
+          to special tokens to be encoded as special tokens.
+        """
+        if allowed_special is None:
+            allowed_special = set()
+        assert type(s) is str
+
+        substrs = (
+            substr
+            for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS)
+            for substr in self._split_whitespaces_or_nonwhitespaces(
+                s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS
+            )
+        )
+        t: List[int] = []
+        for substr in substrs:
+            t.extend(
+                self.model.encode(
+                    substr,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                )
+            )
+        if bos:
+            t.insert(0, self.bos_id)
+        if eos:
+            t.append(self.eos_id)
+        return t
+
+    def decode(self, t: Sequence[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
+        return self.model.decode(cast(List[int], t))
+
+    def decode_token(self, t: int) -> str:
+        """
+        Decodes a single token ID into a string.
+
+        Args:
+            t (int): The token ID to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.model.decode_single_token_bytes(t).decode("utf-8")
+
+    @staticmethod
+    def _split_whitespaces_or_nonwhitespaces(
+        s: str, max_consecutive_slice_len: int
+    ) -> Iterator[str]:
+        """
+        Splits the string `s` so that each substring contains no more than
+        `max_consecutive_slice_len` consecutive whitespaces or consecutive
+        non-whitespaces.
+        """
+        current_slice_len = 0
+        current_slice_is_space = s[0].isspace() if len(s) > 0 else False
+        slice_start = 0
+
+        for i in range(len(s)):
+            is_now_space = s[i].isspace()
+
+            if current_slice_is_space ^ is_now_space:
+                current_slice_len = 1
+                current_slice_is_space = is_now_space
+            else:
+                current_slice_len += 1
+                if current_slice_len > max_consecutive_slice_len:
+                    yield s[slice_start:i]
+                    slice_start = i
+                    current_slice_len = 1
+        yield s[slice_start:]
diff --git a/tools/llama2c/convert.py b/tools/llama2c/convert.py
index 1f915fc..f6cbb08 100644
--- a/tools/llama2c/convert.py
+++ b/tools/llama2c/convert.py
@@ -9,113 +9,11 @@
 # postprocessing logic. The output can be consumed by llama2c_tokenizer.cpp.
 
 import argparse
-import logging
-import os
-import struct
-from typing import List
 
-from sentencepiece import SentencePieceProcessor as SentencePieceProcessor
+from pytorch_tokenizers.llama2c import Llama2cTokenizer
 
 
-class Tokenizer:
-    def __init__(self, model_path: str):
-        assert os.path.isfile(
-            model_path
-        ), f"Need a valid tokenizer model path but got {model_path}"
-        # pyre-fixme[28]: Unexpected keyword argument `model_file` to call `SentencePieceProcessor.__init__`.
-        self.sp_model = SentencePieceProcessor(model_file=model_path)
-        self.model_path = model_path
-
-        # BOS / EOS token IDs
-        self.n_words: int = self.sp_model.vocab_size()
-        self.bos_id: int = self.sp_model.bos_id()
-        self.eos_id: int = self.sp_model.eos_id()
-        logging.info(
-            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
-        )
-        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_piece_size`.
-        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
-
-    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
-        assert type(s) is str
-        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
-        t = self.sp_model.encode(s)
-        if bos:
-            t = [self.bos_id] + t
-        if eos:
-            t = t + [self.eos_id]
-        return t
-
-    def decode(self, t: List[int]) -> str:
-        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
-        return self.sp_model.decode(t)
-
-    def decode_token(self, t: int) -> str:
-        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
-        return self.sp_model.decode(t)
-
-    def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
-        """
-        Export tokenizer.model to another serialization format. Here we did some lightweight
-        processing such as supporting prepend padding token, prepend max token length and
-        replace '_' back to empty space.
-
-        The binary format is:
-        1. vocab size: int32
-        2. bos token id: int32
-        3. eos token id: int32
-        4. max token length: int32
-        5. score: float32, len of bytes: int32, token bytes: [byte] for each token
-
-        :param output_path: output path of the new binary.
-        :param prepend_padding: a boolean to control if we want to prepend a padding token.
-
-        :return: None
-        """
-
-        # get all the tokens (postprocessed) and their scores as floats
-        tokens, scores = [], []
-
-        if prepend_padding:
-            # Here we use the default padding token and its score.
-            tokens.append("<pad>".encode("utf-8"))
-            scores.append(-1)
-
-        for i in range(self.n_words):
-            # decode the token and light postprocessing
-            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `id_to_piece`.
-            t = self.sp_model.id_to_piece(i)
-            # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `get_score`.
-            s = self.sp_model.get_score(i)
-            # sentencepiece use '<s>' as BOS and '</s>' for EOS
-            if i == self.bos_id:
-                t = "<s>"
-            elif i == self.eos_id:
-                t = "</s>"
-            t = t.replace("▁", " ")  # sentencepiece uses this character as whitespace
-            b = t.encode("utf-8")  # bytes of this token, utf-8 encoded
-
-            tokens.append(b)
-            scores.append(s)
-
-        # record the max token length
-        max_token_length = 0 if not tokens else max(len(t) for t in tokens)
-
-        # write to a binary file
-        with open(output_path, "wb") as f:
-            # write the vocab size, bos/eos ids and max token length
-            f.write(
-                struct.pack(
-                    "IIII", self.n_words, self.bos_id, self.eos_id, max_token_length
-                )
-            )
-            for bytes, score in zip(tokens, scores):
-                f.write(struct.pack("fI", score, len(bytes)))
-                f.write(bytes)
-        logging.info(f"Wrote tokenizer to {output_path}")
-
-
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-t",
@@ -140,7 +38,7 @@ def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
 
     args = parser.parse_args()
 
-    t = Tokenizer(args.tokenizer_model)
+    t = Llama2cTokenizer(args.tokenizer_model)
 
     output_path = (
         args.output_path
@@ -148,3 +46,7 @@ def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
         else args.tokenizer_model.replace(".model", ".bin")
     )
     t.export(output_path, prepend_padding=args.prepend_padding)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/llama2c/targets.bzl b/tools/llama2c/targets.bzl
index 2449d4e..6ef87d1 100644
--- a/tools/llama2c/targets.bzl
+++ b/tools/llama2c/targets.bzl
@@ -6,23 +6,15 @@ def define_common_targets():
     The directory containing this targets.bzl file should also contain both
     TARGETS and BUCK files that call this function.
""" + runtime.python_library( - name = "convert_lib", + name = "lib", srcs = [ "__init__.py", "convert.py", ], - base_module = "pytorch.tokenizers.tools.llama2c", - visibility = [ - "//executorch/examples/...", - "//executorch/extension/llm/export/...", - "//bento/...", - "//bento_kernels/...", - "@EXECUTORCH_CLIENTS", - ], - _is_external_target = True, - external_deps = [ - "sentencepiece-py", + deps = [ + "//pytorch/tokenizers/pytorch_tokenizers:tokenizers", ], ) @@ -35,6 +27,6 @@ def define_common_targets(): ], _is_external_target = True, deps = [ - ":convert_lib", + ":lib", ], )