Add python sentencepiece tokenizer #88

Open · wants to merge 1 commit into main
46 changes: 42 additions & 4 deletions pytorch_tokenizers/__init__.py
@@ -6,22 +6,60 @@
# @lint-ignore-every LICENSELINT


from enum import Enum
from typing import Optional

from .hf_tokenizer import HuggingFaceTokenizer
from .llama2c import Llama2cTokenizer
from .sentencepiece import SentencePieceTokenizer
from .tiktoken import TiktokenTokenizer


class TokenizerType(Enum):
    LLAMA2C = "llama2c"
    SENTENCEPIECE = "sentencepiece"
    TIKTOKEN = "tiktoken"
    HUGGINGFACE = "huggingface"

    @classmethod
    def from_str(cls, value: str) -> "TokenizerType":
        """Create a TokenizerType from a string value (case-insensitive)."""
        value_lower = value.lower()
        for tokenizer_type in cls:
            if tokenizer_type.value == value_lower:
                return tokenizer_type
        raise ValueError(
            f"Invalid tokenizer type: {value}. Valid options: {[t.value for t in cls]}"
        )


__all__ = [
    "TiktokenTokenizer",
    "Llama2cTokenizer",
    "HuggingFaceTokenizer",
    "SentencePieceTokenizer",
    "TokenizerType",
]


def get_tokenizer(
    tokenizer_path: str,
    tokenizer_config_path: Optional[str] = None,
    tokenizer_type: Optional[TokenizerType] = None,
):
    # An explicit tokenizer_type takes precedence over auto-detection.
    if tokenizer_type is not None:
        if tokenizer_type == TokenizerType.HUGGINGFACE:
            return HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
        elif tokenizer_type == TokenizerType.LLAMA2C:
            return Llama2cTokenizer(model_path=str(tokenizer_path))
        elif tokenizer_type == TokenizerType.SENTENCEPIECE:
            return SentencePieceTokenizer(model_path=str(tokenizer_path))
        elif tokenizer_type == TokenizerType.TIKTOKEN:
            return TiktokenTokenizer(model_path=str(tokenizer_path))

    # Default fallback: auto-detect from the file, trying each format in turn.
    if tokenizer_path.endswith(".json"):
        tokenizer = HuggingFaceTokenizer(tokenizer_path, tokenizer_config_path)
    else:
        try:
            tokenizer = Llama2cTokenizer(model_path=str(tokenizer_path))
        except Exception:
            try:
                print("Using SentencePiece tokenizer")
                tokenizer = SentencePieceTokenizer(model_path=str(tokenizer_path))
            except Exception:
                print("Using Tiktokenizer")
                tokenizer = TiktokenTokenizer(model_path=str(tokenizer_path))
    return tokenizer
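
For reference, a minimal usage sketch of the new entry point; the tokenizer.model path is a placeholder, not a file shipped with this PR:

    from pytorch_tokenizers import TokenizerType, get_tokenizer

    # Select the SentencePiece backend explicitly...
    tok = get_tokenizer("tokenizer.model", tokenizer_type=TokenizerType.SENTENCEPIECE)

    # ...or resolve the type from a user-supplied string (case-insensitive).
    tok = get_tokenizer(
        "tokenizer.model",
        tokenizer_type=TokenizerType.from_str("SentencePiece"),
    )

Omitting tokenizer_type preserves the old behavior: auto-detection by file extension, then by trial construction.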
43 changes: 43 additions & 0 deletions pytorch_tokenizers/sentencepiece.py
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# @lint-ignore-every LICENSELINT

import logging
import os
from typing import List

from sentencepiece import SentencePieceProcessor


class SentencePieceTokenizer:
    """Thin wrapper around a trained SentencePiece model file."""

    def __init__(self, model_path: str):
        assert os.path.isfile(
            model_path
        ), f"Need a valid tokenizer model path but got {model_path}"
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.model_path = model_path

        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        logging.info(
            f"SentencePiece - #words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
        )

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        assert type(s) is str
        t = self.sp_model.encode(s)
        # Optionally wrap the sequence in BOS/EOS control tokens.
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)

    def decode_token(self, t: int) -> str:
        return self.sp_model.id_to_piece(t)
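
A quick round-trip sketch of the wrapper, assuming tokenizer.model is any valid trained SentencePiece model on disk:

    from pytorch_tokenizers.sentencepiece import SentencePieceTokenizer

    tok = SentencePieceTokenizer("tokenizer.model")

    ids = tok.encode("hello world", bos=True, eos=True)  # [bos_id, ..., eos_id]
    text = tok.decode(ids)            # back to a string
    piece = tok.decode_token(ids[1])  # a single piece, e.g. "▁hello"

Note that decode_token returns the raw SentencePiece piece (including the "▁" whitespace marker), not detokenized text.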