Commit cba637e: Tokenizer test

lucylq committed Feb 15, 2025 (1 parent: b3ba207)
Showing 5 changed files with 74 additions and 158 deletions.
test/test_base64.cpp (1 addition, 1 deletion)
@@ -6,8 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include <pytorch/tokenizers/base64.h>
 #include "gtest/gtest.h"
+#include <pytorch/tokenizers/base64.h>

 namespace tokenizers {

test/test_llama2c_tokenizer.cpp (5 additions, 20 deletions)
@@ -6,34 +6,19 @@
  * LICENSE file in the root directory of this source tree.
  */

-#ifdef TOKENIZERS_FB_BUCK
-#include <TestResourceUtils/TestResourceUtils.h>
-#endif
 #include <gtest/gtest.h>
 #include <pytorch/tokenizers/llama2c_tokenizer.h>

 using namespace ::testing;

 namespace tokenizers {

-namespace {
-// Test case based on llama2.c tokenizer
-static inline std::string _get_resource_path(const std::string& name) {
-#ifdef TOKENIZERS_FB_BUCK
-  return facebook::xplat::testing::getPathForTestResource(
-      "test/resources/" + name);
-#else
-  return std::getenv("RESOURCES_PATH") + std::string("/") + name;
-#endif
-}
-
-} // namespace
-
 class Llama2cTokenizerTest : public Test {
- public:
+public:
   void SetUp() override {
     tokenizer_ = std::make_unique<Llama2cTokenizer>();
-    modelPath_ = _get_resource_path("test_llama2c_tokenizer.bin");
+    modelPath_ = std::getenv("RESOURCES_PATH") +
+                 std::string("/test_llama2c_tokenizer.bin");
   }

   std::unique_ptr<Tokenizer> tokenizer_;
@@ -51,15 +36,15 @@ TEST_F(Llama2cTokenizerTest, DecodeWithoutLoadFails) {
 }

 TEST_F(Llama2cTokenizerTest, DecodeOutOfRangeFails) {
-  Error res = tokenizer_->load(modelPath_.c_str());
+  Error res = tokenizer_->load(modelPath_);
   EXPECT_EQ(res, Error::Ok);
   auto result = tokenizer_->decode(0, 64000);
   // The vocab size is 32000, and token 64000 is out of vocab range.
   EXPECT_EQ(result.error(), Error::OutOfRange);
 }

 TEST_F(Llama2cTokenizerTest, TokenizerMetadataIsExpected) {
-  Error res = tokenizer_->load(modelPath_.c_str());
+  Error res = tokenizer_->load(modelPath_);
   EXPECT_EQ(res, Error::Ok);
   // test_bpe_tokenizer.bin has vocab_size 0, bos_id 0, eos_id 0 recorded.
   EXPECT_EQ(tokenizer_->vocab_size(), 0);
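One thing the new pattern gives up is the null check a helper could centralize: std::getenv returns a null pointer when RESOURCES_PATH is unset, and concatenating that pointer with a std::string is undefined behavior, so these tests now quietly require the variable to be exported. As a hypothetical alternative, not part of this commit, a guarded helper could fail loudly instead:

// Hypothetical helper, not part of this commit: resolve a test resource
// path from RESOURCES_PATH, failing loudly if the variable is unset
// instead of concatenating a null char* (undefined behavior).
#include <cstdlib>
#include <stdexcept>
#include <string>

static std::string resource_path(const std::string& name) {
  const char* base = std::getenv("RESOURCES_PATH");
  if (base == nullptr) {
    throw std::runtime_error("RESOURCES_PATH is not set");
  }
  return std::string(base) + "/" + name;
}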
test/test_pre_tokenizer.cpp (25 additions, 62 deletions)
@@ -19,12 +19,11 @@ using namespace tokenizers;

 // Helpers /////////////////////////////////////////////////////////////////////

-static void assert_split_match(
-    const PreTokenizer& ptok,
-    const std::string& prompt,
-    const std::vector<std::string>& expected) {
+static void assert_split_match(const PreTokenizer &ptok,
+                               const std::string &prompt,
+                               const std::vector<std::string> &expected) {
   re2::StringPiece prompt_view(prompt);
-  const auto& got = ptok.pre_tokenize(prompt_view);
+  const auto &got = ptok.pre_tokenize(prompt_view);
   EXPECT_EQ(expected.size(), got.size());
   for (auto i = 0; i < got.size(); ++i) {
     EXPECT_EQ(expected[i], got[i]);
@@ -35,16 +34,14 @@ static void assert_split_match(
 class RegexPreTokenizerTest : public ::testing::Test {};

 // Test the basic construction
-TEST_F(RegexPreTokenizerTest, Construct) {
-  RegexPreTokenizer ptok("[0-9]+");
-}
+TEST_F(RegexPreTokenizerTest, Construct) { RegexPreTokenizer ptok("[0-9]+"); }

 // Test basic splitting using the expression for Tiktoken
 TEST_F(RegexPreTokenizerTest, TiktokenExpr) {
   RegexPreTokenizer ptok(
       R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)");
-  assert_split_match(
-      ptok, "How are you doing?", {"How", " are", " you", " doing", "?"});
+  assert_split_match(ptok, "How are you doing?",
+                     {"How", " are", " you", " doing", "?"});
 }

 // DigitsPreTokenizer //////////////////////////////////////////////////////////
@@ -54,18 +51,15 @@ class DigitsPreTokenizerTest : public ::testing::Test {};
 TEST_F(DigitsPreTokenizerTest, IndividualDigits) {
   DigitsPreTokenizer ptok(true);
   assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
+      ptok, "The number 1 then 234 then 5.",
       {"The number ", "1", " then ", "2", "3", "4", " then ", "5", "."});
 }

 // Test digit splitting with contiguous digits
 TEST_F(DigitsPreTokenizerTest, ContiguousDigits) {
   DigitsPreTokenizer ptok(false);
-  assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
-      {"The number ", "1", " then ", "234", " then ", "5", "."});
+  assert_split_match(ptok, "The number 1 then 234 then 5.",
+                     {"The number ", "1", " then ", "234", " then ", "5", "."});
 }

 // ByteLevelPreTokenizer ///////////////////////////////////////////////////////
@@ -75,8 +69,7 @@ TEST_F(ByteLevelPreTokenizerTest, PreTokenizeDefault) {
   ByteLevelPreTokenizer ptok;
   assert_split_match(ptok, "Hello World", {"ĠHello", "ĠWorld"});
   assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
+      ptok, "The number 1 then 234 then 5.",
       {"ĠThe", "Ġnumber", "Ġ1", "Ġthen", "Ġ234", "Ġthen", "Ġ5", "."});
 }

@@ -97,22 +90,9 @@ TEST_F(SequencePreTokenizerTest, PreTokenizeDigitAndByteLevel) {
   PreTokenizer::Ptr dptok(new DigitsPreTokenizer(true));
   PreTokenizer::Ptr bptok(new ByteLevelPreTokenizer(false));
   SequencePreTokenizer ptok({dptok, bptok});
-  assert_split_match(
-      ptok,
-      "The number 1 then 234 then 5.",
-      {"The",
-       "Ġnumber",
-       "Ġ",
-       "1",
-       "Ġthen",
-       "Ġ",
-       "2",
-       "3",
-       "4",
-       "Ġthen",
-       "Ġ",
-       "5",
-       "."});
+  assert_split_match(ptok, "The number 1 then 234 then 5.",
+                     {"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4",
+                      "Ġthen", "Ġ", "5", "."});
 }

 // PreTokenizerConfig //////////////////////////////////////////////////////////
@@ -152,14 +132,12 @@ TEST_F(PreTokenizerConfigTest, AllTypesFailureCases) {

   // Sequence
   EXPECT_THROW(PreTokenizerConfig("Sequence").create(), std::runtime_error);
-  EXPECT_THROW(
-      PreTokenizerConfig("Sequence").set_pretokenizers({}).create(),
-      std::runtime_error);
-  EXPECT_THROW(
-      PreTokenizerConfig("Sequence")
-          .set_pretokenizers({PreTokenizerConfig("Split")})
-          .create(),
-      std::runtime_error);
+  EXPECT_THROW(PreTokenizerConfig("Sequence").set_pretokenizers({}).create(),
+               std::runtime_error);
+  EXPECT_THROW(PreTokenizerConfig("Sequence")
+                   .set_pretokenizers({PreTokenizerConfig("Split")})
+                   .create(),
+               std::runtime_error);

   // Unsupported
   EXPECT_THROW(PreTokenizerConfig("Unsupported").create(), std::runtime_error);
@@ -183,22 +161,9 @@ TEST_F(PreTokenizerConfigTest, ParseJson) {
           }},
       })
       .create();
-  assert_split_match(
-      *ptok,
-      "The number 1 then 234 then 5.",
-      {"The",
-       "Ġnumber",
-       "Ġ",
-       "1",
-       "Ġthen",
-       "Ġ",
-       "2",
-       "3",
-       "4",
-       "Ġthen",
-       "Ġ",
-       "5",
-       "."});
+  assert_split_match(*ptok, "The number 1 then 234 then 5.",
+                     {"The", "Ġnumber", "Ġ", "1", "Ġthen", "Ġ", "2", "3", "4",
+                      "Ġthen", "Ġ", "5", "."});
 }

 TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
@@ -208,10 +173,8 @@ TEST_F(PreTokenizerConfigTest, ParseJsonOptionalKey) {
{"type", "Digits"},
})
.create();
assert_split_match(
*ptok,
"The number 1 then 234 then 5.",
{"The number ", "1", " then ", "234", " then ", "5", "."});
assert_split_match(*ptok, "The number 1 then 234 then 5.",
{"The number ", "1", " then ", "234", " then ", "5", "."});
}

TEST_F(PreTokenizerConfigTest, Split) {
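The reformatted pre-tokenizer tests keep their original behavior: each pre-tokenizer splits a prompt into string pieces (with "Ġ" marking a leading space under the byte-level mapping), and SequencePreTokenizer chains several splitters. As a standalone illustration, here is a minimal sketch of that flow; the header path, the re2::StringPiece argument, and the std::vector<std::string> return type are assumptions read off this diff rather than verified against the repository:

// Sketch only: compose the Digits and ByteLevel pre-tokenizers the way
// the SequencePreTokenizer test does, and print the resulting pieces.
// The header path and return type are assumptions based on this diff.
#include <pytorch/tokenizers/pre_tokenizer.h>
#include <re2/stringpiece.h>

#include <cstdio>

using namespace tokenizers;

int main() {
  PreTokenizer::Ptr digits(new DigitsPreTokenizer(true));    // split each digit
  PreTokenizer::Ptr bytes(new ByteLevelPreTokenizer(false)); // map spaces to "Ġ"
  SequencePreTokenizer seq({digits, bytes});

  re2::StringPiece prompt("The number 1 then 234 then 5.");
  for (const auto& piece : seq.pre_tokenize(prompt)) {
    std::printf("%s\n", piece.c_str()); // expect "The", "Ġnumber", "Ġ", "1", ...
  }
  return 0;
}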
test/test_sentencepiece.cpp (6 additions, 18 deletions)
@@ -7,26 +7,11 @@
  */
 // @lint-ignore-every LICENSELINT

-#ifdef TOKENIZERS_FB_BUCK
-#include <TestResourceUtils/TestResourceUtils.h>
-#endif
 #include <gtest/gtest.h>
 #include <pytorch/tokenizers/sentencepiece.h>

 namespace tokenizers {

-namespace {
-static inline std::string _get_resource_path(const std::string& name) {
-#ifdef TOKENIZERS_FB_BUCK
-  return facebook::xplat::testing::getPathForTestResource(
-      "test/resources/" + name);
-#else
-  return std::getenv("RESOURCES_PATH") + std::string("/") + name;
-#endif
-}
-
-} // namespace
-
 TEST(SPTokenizerTest, TestEncodeWithoutLoad) {
   SPTokenizer tokenizer;
   std::string text = "Hello world!";
@@ -42,7 +27,8 @@ TEST(SPTokenizerTest, TestDecodeWithoutLoad) {

 TEST(SPTokenizerTest, TestLoad) {
   SPTokenizer tokenizer;
-  auto path = _get_resource_path("test_sentencepiece.model");
+  auto path =
+      std::getenv("RESOURCES_PATH") + std::string("/test_sentencepiece.model");
   auto error = tokenizer.load(path);
   EXPECT_EQ(error, Error::Ok);
 }
@@ -55,7 +41,8 @@ TEST(SPTokenizerTest, TestLoadInvalidPath) {

TEST(SPTokenizerTest, TestEncode) {
SPTokenizer tokenizer;
auto path = _get_resource_path("test_sentencepiece.model");
auto path =
std::getenv("RESOURCES_PATH") + std::string("/test_sentencepiece.model");
auto error = tokenizer.load(path);
EXPECT_EQ(error, Error::Ok);
std::string text = "Hello world!";
Expand All @@ -70,7 +57,8 @@ TEST(SPTokenizerTest, TestEncode) {

 TEST(SPTokenizerTest, TestDecode) {
   SPTokenizer tokenizer;
-  auto path = _get_resource_path("test_sentencepiece.model");
+  auto path =
+      std::getenv("RESOURCES_PATH") + std::string("/test_sentencepiece.model");
   auto error = tokenizer.load(path);
   EXPECT_EQ(error, Error::Ok);
   std::vector<uint64_t> tokens = {1, 15043, 3186, 29991};
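Together these tests trace the whole SentencePiece surface: load a model from RESOURCES_PATH, encode text into ids, and decode ids back a token at a time. A condensed round trip is sketched below; encode(text, bos, eos), decode(prev, current), and the Result-style ok()/get() accessors are inferred from the tests above, not verified signatures:

// Round-trip sketch; the encode/decode signatures and Result accessors
// used here are assumptions inferred from the tests above.
#include <pytorch/tokenizers/sentencepiece.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

using namespace tokenizers;

int main() {
  SPTokenizer tok;
  const char* base = std::getenv("RESOURCES_PATH");
  if (base == nullptr) return 1; // see the guarded-helper note above
  if (tok.load(std::string(base) + "/test_sentencepiece.model") != Error::Ok) {
    return 1;
  }

  auto ids = tok.encode("Hello world!", /*bos=*/1, /*eos=*/0); // assumed signature
  if (!ids.ok()) return 1;

  uint64_t prev = 0;
  for (uint64_t id : ids.get()) {
    auto piece = tok.decode(prev, id); // assumed (prev, current) pair
    if (piece.ok()) std::printf("%s", piece.get().c_str());
    prev = id;
  }
  std::printf("\n");
  return 0;
}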