From 15f0b54539cbc72b0b8c37eaeabaf086d9ca04a6 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 3 Dec 2024 21:57:53 -0800
Subject: [PATCH] Add base64.h

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 .github/workflows/pull.yml |   2 +-
 include/base64.h           | 195 +++++++++++++++++++++++++++++++++++++
 include/error.h            |  32 ++++++
 include/tiktoken.h         |  81 +++++++++++++++
 4 files changed, 309 insertions(+), 1 deletion(-)
 create mode 100644 include/base64.h
 create mode 100644 include/tiktoken.h

diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index ab17999..50ccff4 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       fail-fast: false
     with:
-      runner: linux.4xlarge
+      runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-clang12
       submodules: 'true'
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
diff --git a/include/base64.h b/include/base64.h
new file mode 100644
index 0000000..c1373bc
--- /dev/null
+++ b/include/base64.h
@@ -0,0 +1,195 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+// @lint-ignore-every LICENSELINT
+/**************************************************************************
+   Copyright (c) 2023 sewenew
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+ *************************************************************************/
+
+#pragma once
+
+#include <cstdint>
+#include <cstdio>
+#include <string>
+#include <string_view>
+
+#include "result.h"
+
+namespace base64 {
+
+using tokenizers::Error;
+using tokenizers::Result;
+
+Result<std::string> decode(const std::string_view &input);
+
+namespace detail {
+
+constexpr uint32_t DECODE_TABLE[] = {
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
+    255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255,
+    255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+    10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
+    25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33,
+    34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
+    49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255};
+
+inline Error validate(uint32_t v) {
+  if (v == 255) {
+    fprintf(stderr, "invalid char");
+    return Error::Base64DecodeFailure;
+  }
+  return Error::Ok;
+}
+
+inline Error decode(const std::string_view &input, std::string &output) {
+  if (input.size() != 4) {
+    fprintf(stderr, "input length must be 4, got %zu", input.size());
+    return Error::Base64DecodeFailure;
+  }
+
+  uint32_t val = 0;
+
+  uint8_t c = input[0];
+  auto v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = v;
+
+  c = input[1];
+  v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = (val << 6) | v;
+
+  c = input[2];
+  v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = (val << 6) | v;
+
+  c = input[3];
+  v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = (val << 6) | v;
+
+  output.push_back(static_cast<char>((val >> 16) & 0xFF));
+  output.push_back(static_cast<char>((val >> 8) & 0xFF));
+  output.push_back(static_cast<char>(val & 0xFF));
+  return Error::Ok;
+}
+
+inline Error decode_1_padding(const std::string_view &input,
+                              std::string &output) {
+  if (input.size() != 3) {
+    fprintf(stderr, "input length must be 3, got %zu", input.size());
+    return Error::Base64DecodeFailure;
+  }
+
+  uint32_t val = 0;
+
+  uint8_t c = input[0];
+  auto v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = v;
+
+  c = input[1];
+  v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = (val << 6) | v;
+
+  c = input[2];
+  v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = (val << 6) | v;
+
+  output.push_back(static_cast<char>((val >> 10) & 0xFF));
+  output.push_back(static_cast<char>((val >> 2) & 0xFF));
+  return Error::Ok;
+}
+
+inline Error decode_2_padding(const std::string_view &input,
+                              std::string &output) {
+  TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure);
+
+  uint32_t val = 0;
+
+  uint8_t c = input[0];
+  auto v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = v;
+
+  c = input[1];
+  v = DECODE_TABLE[c];
+  TK_CHECK_OK_OR_RETURN_ERROR(validate(v));
+  val = (val << 6) | v;
+
+  output.push_back(static_cast<char>((val >> 4) & 0xFF));
+  return Error::Ok;
+}
+
+} // namespace detail
+
+inline tokenizers::Result<std::string> decode(const std::string_view &input) {
+  if (input.empty()) {
+    fprintf(stderr, "empty input");
+    return Error::Base64DecodeFailure;
+  }
+
+  // Faster than `input.size() % 4`.
+  if ((input.size() & 3) != 0 || input.size() < 4) {
+    fprintf(stderr,
+            "input length must be at least 4 and a multiple of 4, got %zu",
+            input.size());
+    return Error::Base64DecodeFailure;
+  }
+
+  std::string output;
+  output.reserve(input.size() / 4 * 3);
+  auto idx = 0U;
+  for (; idx < input.size() - 4; idx += 4) {
+    TK_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
+  }
+
+  // Last 4 bytes. Might contain paddings.
+  if (input[idx + 3] == '=') {
+    if (input[idx + 2] == '=') {
+      // Two paddings.
+      TK_CHECK_OK_OR_RETURN_ERROR(
+          detail::decode_2_padding(input.substr(idx, 2), output));
+    } else {
+      // One padding.
+      TK_CHECK_OK_OR_RETURN_ERROR(
+          detail::decode_1_padding(input.substr(idx, 3), output));
+    }
+  } else {
+    // No padding.
+    TK_CHECK_OK_OR_RETURN_ERROR(detail::decode(input.substr(idx, 4), output));
+  }
+
+  return output;
+}
+} // namespace base64
diff --git a/include/error.h b/include/error.h
index 196a889..9510373 100644
--- a/include/error.h
+++ b/include/error.h
@@ -45,6 +45,38 @@ enum class Error : error_code_t {
 
   /// Encode failure.
   EncodeFailure = 0x05,
+
+  /// Base64 decode failure.
+  Base64DecodeFailure = 0x06,
 };
 
 } // namespace tokenizers
+
+/**
+ * If cond__ is false, return the specified Error
+ * from the current function, which must be of return type
+ * tokenizers::Error.
+ * TODO: Add logging support
+ * @param[in] cond__ The condition to be checked, asserted as true.
+ * @param[in] error__ Error enum value to return without the `Error::` prefix,
+ * like `InvalidArgument`.
+ */
+#define TK_CHECK_OR_RETURN_ERROR(cond__, error__)    \
+  {                                                  \
+    if (!(cond__)) {                                 \
+      return ::tokenizers::Error::error__;           \
+    }                                                \
+  }
+
+/**
+ * If error__ is not Error::Ok, return it from the current function,
+ * which must be of return type tokenizers::Error.
+ * TODO: Add logging support
+ * @param[in] error__ The tokenizers::Error expression to be checked.
+ */
+#define TK_CHECK_OK_OR_RETURN_ERROR(error__)         \
+  {                                                  \
+    if (error__ != ::tokenizers::Error::Ok) {        \
+      return error__;                                \
+    }                                                \
+  }
diff --git a/include/tiktoken.h b/include/tiktoken.h
new file mode 100644
index 0000000..399a1fa
--- /dev/null
+++ b/include/tiktoken.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// Tiktoken header
+// Used by OpenAI, adapted from https://github.com/sewenew/tokenizer
+#include "re2/re2.h"
+#include "tokenizer.h"
+#include <cstdint>
+
+#pragma once
+
+using Encoder = std::unordered_map<std::string, uint64_t>;
+using Decoder = std::unordered_map<uint64_t, std::string>;
+using Re2UPtr = std::unique_ptr<re2::RE2>;
+
+namespace tokenizers {
+
+class Tiktoken : public Tokenizer {
+public:
+  explicit Tiktoken();
+  ~Tiktoken() override;
+
+  Error load(const std::string &tokenizer_path) override;
+
+  Result<std::vector<uint64_t>> encode(const std::string &input, int8_t bos,
+                                       int8_t eos) const override;
+
+  Result<std::string> decode(uint64_t prev_token,
+                             uint64_t token) const override;
+
+private:
+  static inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
+    Encoder special_tokens;
+    special_tokens.emplace("<|begin_of_text|>", num_base_tokens++);
+    special_tokens.emplace("<|end_of_text|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_0|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_1|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_2|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_3|>", num_base_tokens++);
+    special_tokens.emplace("<|start_header_id|>", num_base_tokens++);
+    special_tokens.emplace("<|end_header_id|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_4|>", num_base_tokens++);
+    special_tokens.emplace("<|eot_id|>", num_base_tokens++);
+    for (auto i = 5; i < 251; ++i) {
+      special_tokens.emplace("<|reserved_special_token_" + std::to_string(i) +
+                                 "|>",
+                             num_base_tokens++);
+    }
+    return special_tokens;
+  }
+
+  template <typename T>
+  std::pair<std::optional<std::string>, re2::StringPiece>
+  _split_with_allowed_special_token(re2::StringPiece &input,
+                                    const T &allowed_special);
+
+  void _encode(re2::StringPiece &input, std::vector<uint64_t> &ret,
+               uint64_t &last_piece_token_len);
+
+  template <typename T>
+  std::pair<std::vector<uint64_t>, uint64_t>
+  _encode_with_special_token(const std::string &text, const T &allowed_special);
+
+  // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
+  const std::string _pattern =
+      R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
+  Encoder _encoder;
+  Encoder _special_token_encoder;
+  Decoder _decoder;
+  Decoder _special_token_decoder;
+
+  Re2UPtr _regex;
+  Re2UPtr _special_token_regex;
+};
+
+} // namespace tokenizers
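
For reference, a minimal usage sketch of the new base64::decode API together with the TK_CHECK_* macros (not part of the diff). It assumes Result<T> from result.h exposes error() and get(); print_token_bytes is a made-up helper name for illustration.

    // Illustrative only; assumes Result<T>::error()/get() from result.h.
    #include <cstdio>
    #include <string>

    #include "base64.h"
    #include "error.h"

    using tokenizers::Error;

    // Decode one base64-encoded vocab entry (as found in tiktoken vocab
    // files) and report how many raw bytes it expands to.
    inline Error print_token_bytes(const std::string &b64_token) {
      // Reject empty input up front; decode() would also catch this.
      TK_CHECK_OR_RETURN_ERROR(!b64_token.empty(), Base64DecodeFailure);

      auto decoded = base64::decode(b64_token); // tokenizers::Result<std::string>
      TK_CHECK_OK_OR_RETURN_ERROR(decoded.error()); // propagate any failure

      printf("%zu bytes\n", decoded.get().size());
      return Error::Ok;
    }

On success, decoded.get() holds the raw decoded bytes; any invalid character or bad length surfaces as Error::Base64DecodeFailure at the caller.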
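A similar sketch for the new Tiktoken class, again outside the diff. The bos/eos semantics (counts of BOS/EOS tokens to add), the Result accessors, and the vocab-file format are assumptions based on the existing Tokenizer interface; "tokenizer.model" is a placeholder path.

    // Illustrative only; path, bos/eos semantics, and Result accessors are
    // assumptions, not part of this patch.
    #include <cstdint>
    #include <cstdio>
    #include <string>

    #include "error.h"
    #include "tiktoken.h"

    inline tokenizers::Error run_tiktoken_demo() {
      tokenizers::Tiktoken tokenizer;

      // Assumed: load() reads a tiktoken-style vocab file of
      // "<base64 token> <rank>" lines, hence the base64.h dependency.
      const auto load_err = tokenizer.load("tokenizer.model");
      TK_CHECK_OK_OR_RETURN_ERROR(load_err);

      // Assumed: bos = 1 prepends one BOS token, eos = 0 appends no EOS token.
      auto tokens = tokenizer.encode("hello world", /*bos=*/1, /*eos=*/0);
      TK_CHECK_OK_OR_RETURN_ERROR(tokens.error());

      // decode() is piece-wise: it takes the previous token and the current one.
      std::string text;
      uint64_t prev = 0; // illustrative initial value for prev_token
      for (uint64_t tok : tokens.get()) {
        auto piece = tokenizer.decode(prev, tok);
        TK_CHECK_OK_OR_RETURN_ERROR(piece.error());
        text += piece.get();
        prev = tok;
      }
      printf("%s\n", text.c_str());
      return tokenizers::Error::Ok;
    }

The pairwise decode(prev_token, token) signature comes from the base Tokenizer interface, which is convenient for emitting output token by token.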