diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9d0c865..0728111 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,6 +21,9 @@ project(Tokenizers)
 
 option(TOKENIZERS_BUILD_TEST "Build tests" OFF)
 
+# Ignore weak attribute warning
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes")
+
 set(ABSL_ENABLE_INSTALL ON)
 set(ABSL_PROPAGATE_CXX_STD ON)
 set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE})
diff --git a/include/base64.h b/include/base64.h
index c1373bc..3dfebc7 100644
--- a/include/base64.h
+++ b/include/base64.h
@@ -69,10 +69,8 @@ inline Error validate(uint32_t v) {
 }
 
 inline Error decode(const std::string_view &input, std::string &output) {
-  if (input.size() != 4) {
-    fprintf(stderr, "input length must be 4, got %zu", input.size());
-    return Error::Base64DecodeFailure;
-  }
+  TK_CHECK_OR_RETURN_ERROR(input.size() == 4, Base64DecodeFailure,
+                           "input length must be 4, got %zu", input.size());
 
   uint32_t val = 0;
 
@@ -104,10 +102,8 @@ inline Error decode(const std::string_view &input, std::string &output) {
 
 inline Error decode_1_padding(const std::string_view &input,
                               std::string &output) {
-  if (input.size() != 3) {
-    fprintf(stderr, "input length must be 3, got %zu", input.size());
-    return Error::Base64DecodeFailure;
-  }
+  TK_CHECK_OR_RETURN_ERROR(input.size() == 3, Base64DecodeFailure,
+                           "input length must be 3, got %zu", input.size());
 
   uint32_t val = 0;
 
@@ -133,7 +129,8 @@ inline Error decode_1_padding(const std::string_view &input,
 
 inline Error decode_2_padding(const std::string_view &input,
                               std::string &output) {
-  TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure);
+  TK_CHECK_OR_RETURN_ERROR(input.size() == 2, Base64DecodeFailure,
+                           "input length must be 2, got %zu", input.size());
 
   uint32_t val = 0;
 
@@ -154,18 +151,13 @@ inline Error decode_2_padding(const std::string_view &input,
 } // namespace detail
 
 inline tokenizers::Result<std::string> decode(const std::string_view &input) {
-  if (input.empty()) {
-    fprintf(stderr, "empty input");
-    return Error::Base64DecodeFailure;
-  }
+  TK_CHECK_OR_RETURN_ERROR(!input.empty(), Base64DecodeFailure, "empty input");
 
   // Faster than `input.size() % 4`.
-  if ((input.size() & 3) != 0 || input.size() < 4) {
-    fprintf(stderr,
-            "input length must be larger than 4 and is multiple of 4, got %zu",
-            input.size());
-    return Error::Base64DecodeFailure;
-  }
+  TK_CHECK_OR_RETURN_ERROR(
+      (input.size() & 3) == 0 && input.size() >= 4, Base64DecodeFailure,
+      "input length must be larger than 4 and is multiple of 4, got %zu",
+      input.size());
 
   std::string output;
   output.reserve(input.size() / 4 * 3);
diff --git a/include/error.h b/include/error.h
index 9510373..0746c9b 100644
--- a/include/error.h
+++ b/include/error.h
@@ -13,6 +13,7 @@
 
 #pragma once
 
+#include "log.h"
 #include <stdint.h>
 
 namespace tokenizers {
@@ -59,11 +60,14 @@ enum class Error : error_code_t {
  * TODO: Add logging support
  * @param[in] cond__ The condition to be checked, asserted as true.
  * @param[in] error__ Error enum value to return without the `Error::` prefix,
- * like `InvalidArgument`.
+ * like `Base64DecodeFailure`.
+ * @param[in] message__ Format string for the log error message.
+ * @param[in] ... Optional additional arguments for the format string.
  */
-#define TK_CHECK_OR_RETURN_ERROR(cond__, error__) \
+#define TK_CHECK_OR_RETURN_ERROR(cond__, error__, message__, ...) \
   { \
     if (!(cond__)) { \
+      TK_LOG(Error, message__, ##__VA_ARGS__); \
       return ::tokenizers::Error::error__; \
     } \
   }
@@ -72,11 +76,80 @@ enum class Error : error_code_t {
  * If error__ is not Error::Ok, return the specified Error
  * TODO: Add logging support
  * @param[in] error__ Error enum value to return without the `Error::` prefix,
- * like `InvalidArgument`.
+ * like `Base64DecodeFailure`.
+ * @param[in] ... Optional format string for the log error message and its
+ * arguments.
  */
-#define TK_CHECK_OK_OR_RETURN_ERROR(error__) \
-  { \
-    if (error__ != ::tokenizers::Error::Ok) { \
-      return error__; \
+#define TK_CHECK_OK_OR_RETURN_ERROR(error__, ...) \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(error__, ##__VA_ARGS__)
+
+// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR(...) \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(__VA_ARGS__, 10, 9, 8, 7, 6, 5, \
+                                              4, 3, 2, 1) \
+  (__VA_ARGS__)
+
+/**
+ * Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
+ * This macro selects the correct version of
+ * TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR based on the number of arguments passed.
+ * It uses a trick with the preprocessor to count the number of arguments and
+ * then selects the appropriate macro.
+ *
+ * The macro expansion uses __VA_ARGS__ to accept any number of arguments and
+ * then appends them to TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_, followed by the
+ * count of arguments. The count is determined by the macro
+ * TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT which takes the arguments and
+ * passes them along with a sequence of numbers (10, ..., 1). The preprocessor
+ * then matches this sequence to the correct number of arguments provided.
+ *
+ * If two to ten arguments are passed,
+ * TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2 is selected, suitable for cases where
+ * an error code, a custom message and optional format arguments are provided.
+ * If only one argument is passed, TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1 is
+ * selected, which is used for cases with just an error code.
+ *
+ * Usage:
+ * TK_CHECK_OK_OR_RETURN_ERROR(error_code); // Calls v1
+ * TK_CHECK_OK_OR_RETURN_ERROR(error_code, "Error message", ...); // Calls v2
+ */
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_SELECT(_1, _2, _3, _4, _5, _6, \
+                                                    _7, _8, _9, _10, N, ...) \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_##N
+
+// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_1(error__) \
+  do { \
+    const auto tk_error__ = (error__); \
+    if (tk_error__ != ::tokenizers::Error::Ok) { \
+      return tk_error__; \
     } \
-  }
+  } while (0)
+
+// Internal only: Use TK_CHECK_OK_OR_RETURN_ERROR() instead.
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2(error__, message__, ...) \
+  do { \
+    const auto tk_error__ = (error__); \
+    if (tk_error__ != ::tokenizers::Error::Ok) { \
+      TK_LOG(Error, message__, ##__VA_ARGS__); \
+      return tk_error__; \
+    } \
+  } while (0)
+
+// Internal only: calls with 3 to 10 arguments share the _2 implementation.
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_3 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_4 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_5 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_6 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_7 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_8 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_9 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
+#define TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_10 \
+  TK_INTERNAL_CHECK_OK_OR_RETURN_ERROR_2
diff --git a/include/log.h b/include/log.h
new file mode 100644
index 0000000..839b764
--- /dev/null
+++ b/include/log.h
@@ -0,0 +1,278 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+/**
+ * @file
+ * Tokenizers logging API, adopted from ExecuTorch.
+ */
+
+#pragma once
+
+#include <cstdarg>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+
+// Set minimum log severity if compiler option is not provided.
+#ifndef TK_MIN_LOG_LEVEL
+#define TK_MIN_LOG_LEVEL Info
+#endif // !defined(TK_MIN_LOG_LEVEL)
+
+/*
+ * Enable logging by default if compiler option is not provided.
+ * This should facilitate less confusion for those developing Tokenizers.
+ */
+#ifndef TK_LOG_ENABLED
+#define TK_LOG_ENABLED 1
+#endif // !defined(TK_LOG_ENABLED)
+
+/**
+ * Annotation marking a function as printf-like, providing compiler support
+ * for format string argument checking.
+ */
+#ifdef _MSC_VER
+#include <sal.h>
+#define TK_PRINTFLIKE(_string_index, _va_index) _Printf_format_string_
+#else
+#define TK_PRINTFLIKE(_string_index, _va_index) \
+  __attribute__((format(printf, _string_index, _va_index)))
+#endif
+
+/// Define a C symbol with weak linkage.
+#ifdef _MSC_VER
+// There currently doesn't seem to be a great way to do this in Windows and
+// given that weak linkage is not really critical on Windows, we'll just leave
+// it as a stub.
+#define TK_WEAK
+#else
+#define TK_WEAK __attribute__((weak))
+#endif
+
+#ifndef __has_builtin
+#define __has_builtin(x) (0)
+#endif
+
+#if __has_builtin(__builtin_strrchr)
+/// Name of the source file without a directory string.
+#define TK_SHORT_FILENAME (__builtin_strrchr("/" __FILE__, '/') + 1)
+#else
+#define TK_SHORT_FILENAME __FILE__
+#endif
+
+#if __has_builtin(__builtin_LINE)
+/// Current line as an integer.
+#define TK_LINE __builtin_LINE()
+#else
+#define TK_LINE __LINE__
+#endif // __has_builtin(__builtin_LINE)
+
+#if __has_builtin(__builtin_FUNCTION)
+/// Name of the current function as a const char[].
+#define TK_FUNCTION __builtin_FUNCTION()
+#else
+#define TK_FUNCTION __FUNCTION__
+#endif // __has_builtin(__builtin_FUNCTION)
+
+/**
+ * Clients should neither define nor use this macro. Used to optionally declare
+ * the tk_pal_*() functions as weak symbols.
+ *
+ * This provides a way to both:
+ * - Include the header and define weak symbols (used by the internal default
+ *   implementations)
+ * - Include the header and define strong symbols (used by client overrides)
+ */
+#ifndef TK_INTERNAL_PLATFORM_WEAKNESS
+#define TK_INTERNAL_PLATFORM_WEAKNESS TK_WEAK
+#endif // !defined(TK_INTERNAL_PLATFORM_WEAKNESS)
+
+// TODO: making an assumption that we have stderr
+#define TK_LOG_OUTPUT_FILE stderr
+
+extern "C" {
+/**
+ * Severity level of a log message. Values must map to printable 7-bit ASCII
+ * uppercase letters.
+ */
+typedef enum {
+  kDebug = 'D',
+  kInfo = 'I',
+  kError = 'E',
+  kFatal = 'F',
+  kUnknown = '?', // Exception to the "uppercase letter" rule.
+} tk_pal_log_level_t;
+
+/**
+ * Emit a log message via platform output (serial port, console, etc).
+ *
+ * @param[in] level Severity level of the message. Must be a printable 7-bit
+ *     ASCII uppercase letter.
+ * @param[in] filename Name of the file that created the log event.
+ * @param[in] function Name of the function that created the log event.
+ * @param[in] line Line in the source file where the log event was created.
+ * @param[in] message Message string to log.
+ * @param[in] length Message string length.
+ */
+inline void TK_INTERNAL_PLATFORM_WEAKNESS tk_pal_emit_log_message(
+    tk_pal_log_level_t level, const char *filename, const char *function,
+    size_t line, const char *message, size_t length) {
+  // Use a format similar to glog and folly::logging, except:
+  // - Omit the timestamp, which is not supported yet
+  // - Don't include the thread ID, to avoid adding a threading dependency
+  // - Add the string "tokenizers:" to make the logs more searchable
+  //
+  // Clients who want to change the format or add other fields can override this
+  // weak implementation of tk_pal_emit_log_message.
+  fprintf(TK_LOG_OUTPUT_FILE, "%c tokenizers:%s:%zu] %s\n", level, filename,
+          line, message);
+  fflush(TK_LOG_OUTPUT_FILE);
+}
+
+} // extern "C"
+namespace tokenizers {
+
+/**
+ * Severity level of a log message. Must be ordered from lowest to highest
+ * severity.
+ */
+enum class LogLevel : uint8_t {
+  /**
+   * Log messages provided for highly granular debuggability.
+   *
+   * Log messages using this severity are unlikely to be compiled by default
+   * into most debug builds.
+   */
+  Debug,
+
+  /**
+   * Log messages providing information about the state of the system
+   * for debuggability.
+   */
+  Info,
+
+  /**
+   * Log messages about errors within Tokenizers during runtime.
+   */
+  Error,
+
+  /**
+   * Log messages that precede a fatal error. However, logging at this level
+   * does not perform the actual abort, something else needs to.
+   */
+  Fatal,
+
+  /**
+   * Number of supported log levels, with values in [0, NumLevels).
+   */
+  NumLevels,
+};
+
+namespace internal {
+
+/**
+ * Maps LogLevel values to tk_pal_log_level_t values.
+ *
+ * We don't share values because LogLevel values need to be ordered by severity,
+ * and tk_pal_log_level_t values need to be printable characters.
+ */
+inline constexpr tk_pal_log_level_t kLevelToPal[size_t(LogLevel::NumLevels)] = {
+    tk_pal_log_level_t::kDebug,
+    tk_pal_log_level_t::kInfo,
+    tk_pal_log_level_t::kError,
+    tk_pal_log_level_t::kFatal,
+};
+
+// TODO: add timestamp support
+
+/**
+ * Log a string message.
+ *
+ * Note: This is an internal function. Use the `TK_LOG` macro instead.
+ *
+ * @param[in] level Log severity level.
+ * @param[in] filename Name of the source file creating the log event.
+ * @param[in] function Name of the function creating the log event.
+ * @param[in] line Source file line of the caller.
+ * @param[in] format Format string.
+ * @param[in] args Variable argument list.
+ */
+TK_PRINTFLIKE(5, 0)
+inline void vlogf(LogLevel level, const char *filename, const char *function,
+                  size_t line, const char *format, va_list args) {
+  // Maximum length of a log message, including the terminating NUL.
+  static constexpr size_t kMaxLogMessageLength = 256;
+  char buf[kMaxLogMessageLength];
+  const int written = vsnprintf(buf, kMaxLogMessageLength, format, args);
+  size_t len = written < 0 ? 0 : static_cast<size_t>(written); // < 0: error
+  if (len >= kMaxLogMessageLength) { // Truncated: end the text with '$'.
+    buf[kMaxLogMessageLength - 2] = '$';
+    len = kMaxLogMessageLength - 1;
+  }
+  buf[len] = 0;
+
+  const tk_pal_log_level_t pal_level =
+      level < LogLevel::NumLevels ? kLevelToPal[size_t(level)]
+                                  : tk_pal_log_level_t::kUnknown;
+
+  tk_pal_emit_log_message(pal_level, filename, function, line, buf, len);
+}
+
+/**
+ * Log a string message.
+ *
+ * Note: This is an internal function. Use the `TK_LOG` macro instead.
+ *
+ * @param[in] level Log severity level.
+ * @param[in] filename Name of the source file creating the log event.
+ * @param[in] function Name of the function creating the log event.
+ * @param[in] line Source file line of the caller.
+ * @param[in] format Format string.
+ */
+TK_PRINTFLIKE(5, 6)
+inline void logf(LogLevel level, const char *filename, const char *function,
+                 size_t line, const char *format, ...) {
+#if TK_LOG_ENABLED
+  va_list args;
+  va_start(args, format);
+  internal::vlogf(level, filename, function, line, format, args);
+  va_end(args);
+#endif // TK_LOG_ENABLED
+}
+
+} // namespace internal
+
+} // namespace tokenizers
+
+#if TK_LOG_ENABLED
+
+/**
+ * Log a message at the given log severity level.
+ *
+ * @param[in] _level Log severity level.
+ * @param[in] _format Log message format string.
+ */
+#define TK_LOG(_level, _format, ...) \
+  do { \
+    const auto _log_level = ::tokenizers::LogLevel::_level; \
+    if (static_cast<uint32_t>(_log_level) >= \
+        static_cast<uint32_t>(::tokenizers::LogLevel::TK_MIN_LOG_LEVEL)) { \
+      ::tokenizers::internal::logf(_log_level, TK_SHORT_FILENAME, TK_FUNCTION, \
+                                   TK_LINE, _format, ##__VA_ARGS__); \
+    } \
+  } while (0)
+#else // TK_LOG_ENABLED
+
+/**
+ * Log a message at the given log severity level.
+ *
+ * @param[in] _level Log severity level.
+ * @param[in] _format Log message format string.
+ */
+#define TK_LOG(_level, _format, ...) ((void)0)
+
+#endif // TK_LOG_ENABLED