
Support gemma3 HF tokenizer.json #96


Open · wants to merge 5 commits into main
Changes from all commits
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -46,6 +46,7 @@ set(tokenizers_source_files
${CMAKE_CURRENT_SOURCE_DIR}/src/bpe_tokenizer_base.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/hf_tokenizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/llama2c_tokenizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/normalizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/pre_tokenizer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/re2_regex.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/regex.cpp
77 changes: 76 additions & 1 deletion include/pytorch/tokenizers/bpe_tokenizer_base.h
@@ -125,6 +125,73 @@ inline Result<std::unique_ptr<IRegex>> build_special_token_regex(
return create_regex(special_pattern);
}

// Hash function for std::pair<uint64_t, uint64_t>
struct PairHash {
std::size_t operator()(const std::pair<uint64_t, uint64_t>& p) const {
return std::hash<uint64_t>{}(p.first) ^
(std::hash<uint64_t>{}(p.second) << 1);
}
};

// Type alias for BPE merge map: (token_id_1, token_id_2) -> (rank,
// merged_token_id)
using MergeMap = std::unordered_map<
std::pair<uint64_t, uint64_t>,
std::pair<uint64_t, uint64_t>,
PairHash>;

// Utility function to build merge ranks map from merge rules
template <typename TMergeMap>
inline Result<TokenMap> buildMergeRanksMap(
const TMergeMap& merge_map,
const TokenMap& token_map) {
// Static assertions to verify TMergeMap has the expected key and value types
using KeyType = typename TMergeMap::key_type;
using ValueType = typename TMergeMap::mapped_type;

static_assert(
std::is_same_v<KeyType, std::pair<uint64_t, uint64_t>>,
"TMergeMap key type must be std::pair<uint64_t, uint64_t>");

static_assert(
std::is_same_v<ValueType, std::pair<uint64_t, uint64_t>>,
"TMergeMap value type must be std::pair<uint64_t, uint64_t>");

// Use a map to handle duplicates - keep the lowest rank (highest priority)
std::unordered_map<std::string, uint64_t> unique_merge_ranks;

for (const auto& [pair, rank_and_id] : merge_map) {
uint64_t first_id = pair.first;
uint64_t second_id = pair.second;
uint64_t rank = rank_and_id.first;

// Get the token strings for the pair
auto first_token = token_map.tryGetString(first_id);
auto second_token = token_map.tryGetString(second_id);

if (first_token && second_token) {
std::string merged_token =
std::string(*first_token) + std::string(*second_token);

// Keep the entry with the lowest rank (highest priority in BPE)
auto it = unique_merge_ranks.find(merged_token);
if (it == unique_merge_ranks.end() || rank < it->second) {
unique_merge_ranks[merged_token] = rank;
}
}
}

// Convert to vector for buildTokenMap
std::vector<std::pair<std::string, uint64_t>> merge_rank_pairs;
merge_rank_pairs.reserve(unique_merge_ranks.size());

for (const auto& [token, rank] : unique_merge_ranks) {
merge_rank_pairs.emplace_back(token, rank);
}

return buildTokenMap(std::move(merge_rank_pairs));
}

class BPETokenizerBase : public Tokenizer {
public:
Result<std::vector<uint64_t>>
@@ -152,14 +219,22 @@ class BPETokenizerBase : public Tokenizer {
const std::string& text,
const TokenMap& allowed_special) const;

- Result<std::vector<uint64_t>> byte_pair_encode_(
+ virtual Result<std::vector<uint64_t>> byte_pair_encode_(
const std::string& piece,
const TokenMap& encoder) const;

// Virtual method for BPE merging - can be overridden by derived classes
virtual std::vector<uint64_t> _byte_pair_merge(
const std::string& piece,
const TokenMap& ranks,
std::function<uint64_t(uint64_t, uint64_t)> func) const;

// Protected members that can be reused by derived BPE tokenizers
std::unique_ptr<IRegex> special_token_regex_;
std::optional<TokenMap> token_map_;
std::optional<TokenMap> special_token_map_;
std::unique_ptr<MergeMap> merge_map_;
std::optional<TokenMap> merge_ranks_; // Pre-computed merge ranks for BPE

private:
virtual Error _encode(
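For illustration, a minimal standalone sketch of the rank-construction logic in buildMergeRanksMap above: id pairs are resolved to strings, concatenated, and the lowest rank wins when two rules produce the same merged string. The vocabulary and merge rules here are invented for the example, and plain std::unordered_map stands in for the library's TokenMap/MergeMap types.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

int main() {
  // Hypothetical vocabulary: token id -> token string.
  std::unordered_map<uint64_t, std::string> id_to_token = {
      {0, "a"}, {1, "bc"}, {2, "ab"}, {3, "c"}, {4, "abc"}};

  // Hypothetical merge rules: (first_id, second_id) -> (rank, merged_id).
  // ("a","bc") and ("ab","c") both concatenate to "abc"; rank 3 must win.
  std::pair<std::pair<uint64_t, uint64_t>, std::pair<uint64_t, uint64_t>>
      merges[] = {{{0, 1}, {7, 4}}, {{2, 3}, {3, 4}}};

  // Same dedup rule as buildMergeRanksMap: keep the lowest rank.
  std::unordered_map<std::string, uint64_t> merge_ranks;
  for (const auto& [pair, rank_and_id] : merges) {
    std::string merged = id_to_token[pair.first] + id_to_token[pair.second];
    auto it = merge_ranks.find(merged);
    if (it == merge_ranks.end() || rank_and_id.first < it->second) {
      merge_ranks[merged] = rank_and_id.first;
    }
  }

  std::cout << "rank(\"abc\") = " << merge_ranks["abc"] << "\n";  // prints 3
}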
85 changes: 85 additions & 0 deletions include/pytorch/tokenizers/hf_tokenizer.h
@@ -18,11 +18,85 @@
// Local
#include <pytorch/tokenizers/bpe_tokenizer_base.h>
#include <pytorch/tokenizers/error.h>
#include <pytorch/tokenizers/normalizer.h>
#include <pytorch/tokenizers/pre_tokenizer.h>
#include <pytorch/tokenizers/result.h>
#include <pytorch/tokenizers/token_decoder.h>

namespace tokenizers {

// Simple Word structure to mimic Rust's Word behavior
struct HFWord {
std::vector<uint64_t> tokens;
std::vector<size_t> byte_lengths;

void add(uint64_t token_id, size_t byte_len) {
tokens.push_back(token_id);
byte_lengths.push_back(byte_len);
}

size_t size() const {
return tokens.size();
}

// Apply all possible merges using the merge ranks
void merge_all(
const detail::TokenMap& merge_ranks,
const detail::TokenMap& token_map) {
while (tokens.size() > 1) {
std::optional<std::pair<size_t, uint32_t>> best_merge;

// Find the best merge (lowest rank) among adjacent token pairs
for (size_t i = 0; i < tokens.size() - 1; ++i) {
// Create the merged token string to look up its rank
auto first_token = token_map.tryGetString(tokens[i]);
auto second_token = token_map.tryGetString(tokens[i + 1]);

if (first_token && second_token) {
std::string merged_token =
std::string(*first_token) + std::string(*second_token);
auto rank = merge_ranks.tryGetInteger(merged_token);

if (rank && (!best_merge || *rank < best_merge->second)) {
best_merge = std::make_pair(i, static_cast<uint32_t>(*rank));
}
}
}

if (!best_merge) {
break; // No more merges possible
}

// Apply the best merge
size_t merge_idx = best_merge->first;

// Get the merged token ID
auto first_token = token_map.tryGetString(tokens[merge_idx]);
auto second_token = token_map.tryGetString(tokens[merge_idx + 1]);

if (first_token && second_token) {
std::string merged_token =
std::string(*first_token) + std::string(*second_token);
auto merged_id = token_map.tryGetInteger(merged_token);

if (merged_id) {
// Replace the two tokens with the merged token
tokens[merge_idx] = *merged_id;
byte_lengths[merge_idx] += byte_lengths[merge_idx + 1];

// Remove the second token
tokens.erase(tokens.begin() + merge_idx + 1);
byte_lengths.erase(byte_lengths.begin() + merge_idx + 1);
} else {
break; // Merged token not found in vocabulary
}
} else {
break; // Original tokens not found in vocabulary
}
}
}
};

class HFTokenizer : public detail::BPETokenizerBase {
public:
/*-- Public Interface --*/
@@ -46,6 +120,17 @@ class HFTokenizer : public detail::BPETokenizerBase {

void _decode(const std::string& input, std::string& ret) const override;

Result<std::vector<uint64_t>> byte_pair_encode_(
const std::string& piece,
const detail::TokenMap& encoder) const override;

// Override the virtual _byte_pair_merge method to use HF-specific BPE logic
std::vector<uint64_t> _byte_pair_merge(
const std::string& piece,
const detail::TokenMap& ranks,
std::function<uint64_t(uint64_t, uint64_t)> func) const override;

Normalizer::Ptr _normalizer;
PreTokenizer::Ptr _pretokenizer;
TokenDecoder::Ptr _decoder;
};
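To show the HFWord::merge_all loop above in isolation, a runnable toy version that repeatedly merges the adjacent pair whose concatenation has the lowest rank until no rankable pair remains. The word split and the rank table are invented for the example.

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

int main() {
  // Invented word split and merge ranks; lower rank merges first.
  std::vector<std::string> tokens = {"h", "e", "l", "l", "o"};
  std::unordered_map<std::string, uint64_t> ranks = {
      {"ll", 0}, {"he", 1}, {"llo", 2}, {"hello", 3}};

  while (tokens.size() > 1) {
    // Find the adjacent pair whose concatenation has the lowest rank.
    std::optional<std::pair<size_t, uint64_t>> best;  // (index, rank)
    for (size_t i = 0; i + 1 < tokens.size(); ++i) {
      auto it = ranks.find(tokens[i] + tokens[i + 1]);
      if (it != ranks.end() && (!best || it->second < best->second)) {
        best = std::make_pair(i, it->second);
      }
    }
    if (!best) break;  // no rankable pair left

    // Apply the merge in place, as HFWord::merge_all does.
    tokens[best->first] += tokens[best->first + 1];
    tokens.erase(tokens.begin() + best->first + 1);
  }

  for (const auto& t : tokens) std::cout << t << ' ';  // prints: hello
  std::cout << '\n';
}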
9 changes: 9 additions & 0 deletions include/pytorch/tokenizers/llama2c_tokenizer.h
@@ -28,6 +28,15 @@ class Llama2cTokenizer : public Tokenizer {
const override;

private:
inline Error _decode_verify(uint64_t token) const {
if (!initialized_) {
return Error::Uninitialized;
}
if (token >= vocab_size_) {
return Error::OutOfRange;
}
return Error::Ok;
}
std::unique_ptr<char*[]> vocab_ = nullptr;
std::unique_ptr<float[]> vocab_scores_ = nullptr;
std::unique_ptr<TokenIndex[]> sorted_vocab_ = nullptr;
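The _decode_verify guard above, restated in a runnable standalone form: validate initialization and bounds before indexing the vocabulary, returning an error code instead of crashing. ToyTokenizer and decode_verify here are illustrative stand-ins, not the library's API.

#include <cstdint>
#include <iostream>
#include <vector>

enum class Error { Ok, Uninitialized, OutOfRange };

struct ToyTokenizer {
  bool initialized_ = false;
  std::vector<const char*> vocab_;

  // Same shape as _decode_verify: check state first, then bounds.
  Error decode_verify(uint64_t token) const {
    if (!initialized_) return Error::Uninitialized;
    if (token >= vocab_.size()) return Error::OutOfRange;
    return Error::Ok;
  }
};

int main() {
  ToyTokenizer t{true, {"<s>", "hello"}};
  std::cout << (t.decode_verify(1) == Error::Ok) << '\n';          // 1
  std::cout << (t.decode_verify(9) == Error::OutOfRange) << '\n';  // 1
}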