diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..3e4e48b --- /dev/null +++ b/.dockerignore @@ -0,0 +1 @@ +.gitignore \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.github/workflows/cmake-multi-platform.yml b/.github/workflows/cmake-multi-platform.yml new file mode 100644 index 0000000..24373c8 --- /dev/null +++ b/.github/workflows/cmake-multi-platform.yml @@ -0,0 +1,101 @@ +name: Multi-Platform CMake Build + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + build_type: [Release] + c_compiler: [gcc, clang, cl] + include: + - os: windows-latest + c_compiler: cl + cpp_compiler: cl + - os: ubuntu-latest + c_compiler: gcc + cpp_compiler: g++ + - os: ubuntu-latest + c_compiler: clang + cpp_compiler: clang++ + - os: macos-latest + c_compiler: clang + cpp_compiler: clang++ + exclude: + - os: windows-latest + c_compiler: gcc + - os: windows-latest + c_compiler: clang + - os: ubuntu-latest + c_compiler: cl + - os: macos-latest + c_compiler: gcc + - os: macos-latest + c_compiler: cl + + steps: + - uses: actions/checkout@v4 + + - name: Set reusable strings + id: strings + shell: bash + run: | + echo "build-output-dir=${{ github.workspace }}/build" >> "$GITHUB_OUTPUT" + + - name: Configure CMake + run: > + cmake -B ${{ steps.strings.outputs.build-output-dir }} + -DCMAKE_CXX_COMPILER=${{ matrix.cpp_compiler }} + -DCMAKE_C_COMPILER=${{ matrix.c_compiler }} + -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + -S ${{ github.workspace }} + + - name: Build C++ Runtime + run: cmake --build ${{ steps.strings.outputs.build-output-dir }} --config ${{ matrix.build_type }} -j4 + + - name: Prepare C++ Test Data + run: cp -r ${{ github.workspace }}/test_data ${{ steps.strings.outputs.build-output-dir }}/test/test_data + + - name: Run C++ Runtime Tests + working-directory: ${{ steps.strings.outputs.build-output-dir }} + run: ctest --build-config ${{ matrix.build_type }} --verbose + + - name: Setup Python Virtual Environment + shell: bash + run: | + python3 -m venv .venv + . .venv/bin/activate + pip install --upgrade pip + + - name: Build Python wheel + shell: bash + run: | + . .venv/bin/activate + cd ${{ github.workspace }}/python/magnetron_framework + pip wheel --verbose -w dist . + + - name: Install Python wheel + shell: bash + run: | + . .venv/bin/activate + pip install ${{ github.workspace }}/python/magnetron_framework/dist/*.whl + + - name: Install Pytest + shell: bash + run: | + . .venv/bin/activate + pip install pytest + + - name: Run Python Tests + shell: bash + run: | + . 
.venv/bin/activate
+          python -m pytest ${{ github.workspace }}/python/tests/tests.py
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a30158e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,34 @@
+CMakeLists.txt.user
+CMakeCache.txt
+CMakeFiles
+CMakeScripts
+Testing
+Makefile
+cmake_install.cmake
+install_manifest.txt
+compile_commands.json
+CTestTestfile.cmake
+_deps
+/cmake-build-debug
+/cmake-build-release
+/cmake-build-relwithdebinfo
+/.idea
+/.vscode
+/bin
+/python/.venv
+/python/.vscode
+/python/.idea
+/python/magnetron_framework/build/
+/python/magnetron_framework/dist/
+/python/magnetron_framework/magnetron.egg-info/
+/python/__pycache__
+/python/magnetron/__pycache__/
+/python/examples/__pycache__/
+/python/extra/__pycache__/
+/python/tests/__pycache__/
+**/.DS_Store
+.vs/
+out/
+.tmp
+CMakeSettings.json
+magnetron_chat/.idea
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..6520ba7
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,126 @@
+# (c) 2024 Mario "Neo" Sieg. <mario.sieg.64@gmail.com>
+
+cmake_minimum_required(VERSION 3.18)
+
+project(magnetron LANGUAGES C)
+
+message("Configuring magnetron project for ${CMAKE_SYSTEM_PROCESSOR}...")
+
+set(CMAKE_C_STANDARD 99)
+set(LIBRARY_OUTPUT_PATH "${CMAKE_BINARY_DIR}")
+set(EXECUTABLE_OUTPUT_PATH "${CMAKE_BINARY_DIR}")
+
+option(MAGNETRON_BUILD_TESTS "Build tests" ON)
+option(MAGNETRON_BUILD_BENCHMARKS "Build benchmarks" ON)
+option(MAGNETRON_BUILD_FUZZERS "Build fuzzers" OFF)
+
+# (CPU only) Enable SIMD math function approximations. Greatly increases performance. Try disabling if you encounter numerical instability. Does NOT enable -ffast-math or similar compiler flags.
+option(MAGNETRON_CPU_APPROX_MATH "Use faster but less precise math functions" ON)
+option(MAGNETRON_ENABLE_CUDA "Enable CUDA support" ON)
+if (APPLE)
+    option(MAGNETRON_ENABLE_ACCELERATE "Use Apple's Accelerate framework" ON)
+endif()
+
+if (${MAGNETRON_BUILD_TESTS})
+    enable_testing()
+endif()
+
+set(MAGNETRON_CUDA_COMPILER "/usr/local/cuda-12.6/bin/nvcc" CACHE STRING "Path to the CUDA compiler") # Set to your CUDA compiler path
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
+    set(IS_AMD64 TRUE)
+else()
+    set(IS_AMD64 FALSE)
+endif()
+if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(arm64)")
+    set(IS_ARM64 TRUE)
+else()
+    set(IS_ARM64 FALSE)
+endif()
+
+file(GLOB_RECURSE MAGNETRON_SOURCES magnetron/*.h magnetron/*.c magnetron/*.m)
+add_library(magnetron SHARED ${MAGNETRON_SOURCES})
+target_include_directories(magnetron PRIVATE extern)
+if(WIN32) # Windows specific config
+    target_compile_options(magnetron PRIVATE /W3 /Oi)
+    if (CMAKE_BUILD_TYPE STREQUAL "Release" OR CMAKE_BUILD_TYPE STREQUAL "MinSizeRel" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") # Enable optimizations for release builds
+        target_compile_options(magnetron PRIVATE /O2 /Oy)
+    endif()
+else() # POSIX specific config
+    target_link_libraries(magnetron m) # link math library
+    target_compile_options(magnetron PRIVATE
+        -Wall
+        -Werror
+        -std=gnu99
+        -Wno-gnu-zero-variadic-macro-arguments
+        -Wno-error=overflow
+        -Wno-error=unused-function
+        -fvisibility=hidden
+    )
+    if (CMAKE_C_COMPILER_ID STREQUAL "GNU")
+        target_compile_options(magnetron PRIVATE -Wno-error=format-truncation)
+    endif()
+    if (APPLE) # Apple specific configuration
+        if (${MAGNETRON_ENABLE_ACCELERATE})
+            find_library(ACCELERATE_FRAMEWORK Accelerate)
+            if (ACCELERATE_FRAMEWORK)
+                message(STATUS "Accelerate framework found")
+                target_compile_definitions(magnetron PRIVATE 
MAG_ACCELERATE) + target_compile_definitions(magnetron PRIVATE ACCELERATE_NEW_LAPACK) + target_compile_definitions(magnetron PRIVATE ACCELERATE_LAPACK_ILP64) + target_link_libraries(magnetron ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found, using fallback") + endif() + endif() + endif() + if (${MAGNETRON_CPU_APPROX_MATH}) + target_compile_definitions(magnetron PRIVATE MAG_APPROXMATH) + endif() + if(${IS_AMD64}) # x86-64 specific compilation options + target_compile_options(magnetron PRIVATE -msse2 -msse3 -mssse3 -msse4.1 -msse4.2 -mpclmul) + elseif(${IS_ARM64}) + target_compile_options(magnetron PRIVATE -march=armv8-a+simd) + endif() + if(CMAKE_BUILD_TYPE STREQUAL "Release" OR CMAKE_BUILD_TYPE STREQUAL "MinSizeRel" OR CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo") # Enable optimizations only for release builds + target_compile_options(magnetron PRIVATE -O3 -flto) + target_link_options(magnetron PRIVATE -flto) + endif() +endif() + +if (${MAGNETRON_ENABLE_CUDA}) + find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") + endif() + if (NOT DEFINED CMAKE_CUDA_COMPILER) + set(CMAKE_CUDA_COMPILER "${MAGNETRON_CUDA_COMPILER}") + endif() + set(CMAKE_CUDA_STANDARD 20) + enable_language(CUDA) + file(GLOB_RECURSE MAGNETRON_CUDA_SOURCES magnetron/*.cu magnetron/*.cuh) + add_library(magnetron_cuda SHARED ${MAGNETRON_CUDA_SOURCES}) + target_include_directories(magnetron_cuda PRIVATE extern) + target_link_libraries(magnetron_cuda CUDA::cudart CUDA::cuda_driver) + target_compile_definitions(magnetron_cuda PRIVATE MAG_ENABLE_CUDA) + target_compile_definitions(magnetron PRIVATE MAG_ENABLE_CUDA) + target_link_libraries(magnetron magnetron_cuda) + else() + message(WARNING "CUDA not found, disabling CUDA support") + set(MAGNETRON_ENABLE_CUDA OFF) + endif() +endif() + +if (${MAGNETRON_BUILD_TESTS}) + add_subdirectory(test) +endif() + +if (${MAGNETRON_BUILD_FUZZERS}) + add_subdirectory(fuzzer) +endif() + +if (${MAGNETRON_BUILD_BENCHMARKS}) + add_subdirectory(benchmark) +endif() + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..9398b23 --- /dev/null +++ b/LICENSE @@ -0,0 +1,205 @@ +==================================================================================================== +magnetron License - Apache License 2.0 +==================================================================================================== + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 Mario Sieg + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4537aae --- /dev/null +++ b/README.md @@ -0,0 +1,188 @@ +[![Contributors][contributors-shield]][contributors-url] +[![Forks][forks-shield]][forks-url] +[![Stargazers][stars-shield]][stars-url] +[![Issues][issues-shield]][issues-url] +[![MIT License][license-shield]][license-url] +[![LinkedIn][linkedin-shield]][linkedin-url] + +
+<a id="readme-top"></a>
+
+<div align="center">
+  <!-- project logo -->
+
+<h3 align="center">magnetron</h3>
+
+  <p align="center">
+    Minimalistic homemade PyTorch alternative, written in C99 and Python.
+    <br />
+    <a href="https://github.com/MarioSieg/magnetron"><strong>Explore the docs »</strong></a>
+    <br />
+    <a href="https://github.com/MarioSieg/magnetron">View Demo</a>
+    ·
+    <a href="https://github.com/MarioSieg/magnetron/issues">Report Bug</a>
+    ·
+    <a href="https://github.com/MarioSieg/magnetron/issues">Request Feature</a>
+  </p>
+</div>
+
+<details>
+  <summary>Table of Contents</summary>
+  <ol>
+    <li><a href="#about-the-project">About The Project</a></li>
+    <li>
+      <a href="#getting-started">Getting Started</a>
+      <ul>
+        <li><a href="#prerequisites">Prerequisites</a></li>
+        <li><a href="#installation">Installation</a></li>
+      </ul>
+    </li>
+    <li><a href="#usage">Usage</a></li>
+    <li><a href="#roadmap">Roadmap</a></li>
+    <li><a href="#contributing">Contributing</a></li>
+    <li><a href="#license">License</a></li>
+    <li><a href="#contact">Contact</a></li>
+    <li><a href="#similar-projects">Similar Projects</a></li>
+  </ol>
+</details>
+ +## About The Project + +![ScreenShot](media/xor.png) +This project started as a learning experience and a way to understand the inner workings of PyTorch and other deep learning frameworks.
+The goal is a minimalistic yet capable deep learning framework suitable for both research and production.
+It is written in C99 and Python, with the Python API as the primary user-facing interface.
+The project is still in its early stages and many features are missing, but the core functionality is already in place.
+ +

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+
+
+## Getting Started
+
+To get a local copy up and running, follow these simple steps.
+
+### Prerequisites
+* Linux, macOS or Windows
+* A C99 compiler (GCC, Clang or MSVC)
+* Python 3.6 or higher
+
+### Installation
+*A pip-installable package will be provided as soon as all core features are implemented.*
+1. Clone the repo.
+2. To use the C library, build it with CMake and link it into your project.
+3. To use the Python API, enter the `python/magnetron_framework` directory inside a virtual environment and run `install_wheel_local.sh` to build the wheel and install it locally.
+4. Run the XOR example `python/examples/simple/xor.py` to test the installation.

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+
+
+
+
+## Usage
+See the `python/examples` directory for examples of how to use the framework.
+For usage from C99, see the `test` directory in the root of the project.
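+
+For a first impression of the C API, the sketch below strings together the same calls the benchmark suite in this repository uses (`mag_ctx_create`, `mag_tensor_create_2d`, `mag_tensor_fill`, `mag_add`, `mag_tensor_decref`). The public header name is an assumption here, and error handling is omitted:
+
+```c
+#include <magnetron.h> /* assumed public header of the C library */
+
+int main(void) {
+    /* Create a CPU compute context. */
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+
+    /* Two 4x4 f32 tensors filled with constants. */
+    mag_tensor_t* a = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 4, 4);
+    mag_tensor_t* b = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 4, 4);
+    mag_tensor_fill(a, 1.0f);
+    mag_tensor_fill(b, 2.0f);
+
+    /* Elementwise addition returns a new tensor. */
+    mag_tensor_t* c = mag_add(a, b);
+
+    /* Tensors are reference counted; release them, then the context. */
+    mag_tensor_decref(a);
+    mag_tensor_decref(b);
+    mag_tensor_decref(c);
+    mag_ctx_destroy(ctx);
+    return 0;
+}
+```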

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+
+
+## Roadmap
+
+The goal is to implement training and inference for LLMs and other state-of-the-art models, while providing a simple and small codebase that is easy to understand and modify.
+
+- [X] 6-dimensional, linearized tensors (see the indexing sketch below)
+- [X] Dynamic computation graph
+- [X] Static computation graph
+- [X] Modern Python API
+- [X] CPU compute and optimization
+- [X] Compressed tensor file format
+- [X] Validation and friendly error messages
+- [X] Fast, custom memory allocators for CPU and GPU
+- [X] Automatic differentiation
+- [ ] Compute on GPU (CUDA)
+- [ ] Other data types (f16, bf16, int8)
+- [ ] Multithreaded CPU compute
+- [ ] Distributed training and inference
+- [ ] CPU and GPU kernel JIT compilation
+- [ ] Better examples with real-world models (LLMs and other state-of-the-art models)
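+
+To illustrate what "linearized" means in the first item above: a rank-6 tensor lives in one flat buffer, and a 6-D coordinate is mapped to a flat offset. This is a generic row-major sketch of that mapping, not necessarily the stride layout magnetron itself uses:
+
+```c
+#include <stdint.h>
+
+/* Flat offset of coordinate idx[0..5] in a row-major tensor of shape shape[0..5]. */
+static int64_t linear_index_6d(const int64_t shape[6], const int64_t idx[6]) {
+    int64_t off = 0;
+    for (int k = 0; k < 6; ++k) {
+        off = off * shape[k] + idx[k]; /* Horner-style flattening */
+    }
+    return off;
+}
+```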

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+ + +## Contributing + +Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. + +If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". +Don't forget to give the project a star! Thanks again! + +1. Fork the Project +2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) +3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`) +4. Push to the Branch (`git push origin feature/AmazingFeature`) +5. Open a Pull Request + +

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+
+
+## License
+
+Distributed under the Apache 2.0 License. See `LICENSE` for more information.

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+ + + +## Contact + +Mario Sieg - [@_mario_neo_](https://twitter.com/_mario_neo_) - mario.sieg.64@gmail.com + +Project Link: [https://github.com/MarioSieg/magnetron](https://github.com/MarioSieg/magnetron) + +Developed 2024 in Berlin, Germany. + +

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+
+
+## Similar Projects
+
+* [ggml](https://github.com/ggerganov/ggml)
+* [tinygrad](https://github.com/tinygrad/tinygrad)
+* [micrograd](https://github.com/karpathy/micrograd)

+<p align="right">(<a href="#readme-top">back to top</a>)</p>

+
+
+
+
+
+[contributors-shield]: https://img.shields.io/github/contributors/MarioSieg/magnetron.svg?style=for-the-badge
+[contributors-url]: https://github.com/MarioSieg/magnetron/graphs/contributors
+[forks-shield]: https://img.shields.io/github/forks/MarioSieg/magnetron.svg?style=for-the-badge
+[forks-url]: https://github.com/MarioSieg/magnetron/network/members
+[stars-shield]: https://img.shields.io/github/stars/MarioSieg/magnetron.svg?style=for-the-badge
+[stars-url]: https://github.com/MarioSieg/magnetron/stargazers
+[issues-shield]: https://img.shields.io/github/issues/MarioSieg/magnetron.svg?style=for-the-badge
+[issues-url]: https://github.com/MarioSieg/magnetron/issues
+[license-shield]: https://img.shields.io/github/license/MarioSieg/magnetron.svg?style=for-the-badge
+[license-url]: https://github.com/MarioSieg/magnetron/blob/master/LICENSE
+[linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=for-the-badge&logo=linkedin&colorB=555
+[linkedin-url]: https://linkedin.com/in/mario-sieg
diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt
new file mode 100644
index 0000000..b5467ab
--- /dev/null
+++ b/benchmark/CMakeLists.txt
@@ -0,0 +1,11 @@
+# (c) 2024 Mario "Neo" Sieg. <mario.sieg.64@gmail.com>
+
+enable_language(CXX)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+file(GLOB_RECURSE BENCHMARK_SOURCES *.cpp)
+add_executable(magnetron_benchmark ${BENCHMARK_SOURCES})
+target_link_libraries(magnetron_benchmark magnetron)
+target_include_directories(magnetron_benchmark PRIVATE ../magnetron)
+target_include_directories(magnetron_benchmark PRIVATE nanobench)
diff --git a/benchmark/benchmarks.cpp b/benchmark/benchmarks.cpp
new file mode 100644
index 0000000..7acd37f
--- /dev/null
+++ b/benchmark/benchmarks.cpp
@@ -0,0 +1,96 @@
+// (c) 2024 Mario "Neo" Sieg. <mario.sieg.64@gmail.com>
+
+// ON LINUX: Before running the benchmark, execute: linux_prepare_perf.sh to set up the system for performance measurements.
+
+#include <cstdint>
+#include <cstdlib>
+#include <functional>
+#include <string_view>
+#include <type_traits>
+
+#define ANKERL_NANOBENCH_IMPLEMENT
+#include <nanobench.h>
+
+#include "magnetron_internal.h"
+
+template <typename F, typename... Args>
+static auto run_bench(
+    std::string_view name,
+    std::string_view unit,
+    F&& callback,
+    Args&&... args
+) -> void;
+
+auto main() -> int {
+    mag_set_log_mode(true);
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); // Create context to print CPU and system info
+    mag_ctx_destroy(ctx);
+    mag_set_log_mode(false);
+
+    volatile std::size_t alloc_n = 400;
+
+    run_bench("Fixed Malloc Free", "malloc", [=](mag_ctx_t* ctx) -> void {
+        void* p = std::malloc(alloc_n);
+        ankerl::nanobench::doNotOptimizeAway(p);
+        std::free(p);
+    });
+
+    mag_fixed_intrusive_pool cache {};
+    mag_fixed_intrusive_pool_init(&cache, alloc_n, 8, 8192);
+    run_bench("Fixed Cached Pool Alloc Free", "malloc", [&cache](mag_ctx_t* ctx) -> void {
+        void* p = mag_fixed_intrusive_pool_malloc(&cache);
+        ankerl::nanobench::doNotOptimizeAway(p);
+        mag_fixed_intrusive_pool_free(&cache, p);
+    });
+    mag_fixed_intrusive_pool_destroy(&cache);
+
+    run_bench("Tensor CPU Allocation", "alloc", [](mag_ctx_t* ctx) -> void {
+        mag_tensor_t* A = mag_tensor_create_1d(ctx, MAG_DTYPE_F32, 8);
+        mag_tensor_t* B = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 8, 8);
+        mag_tensor_t* C = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 8, 8, 8);
+        mag_tensor_t* D = mag_tensor_create_4d(ctx, MAG_DTYPE_F32, 8, 8, 8, 8);
+        mag_tensor_t* E = mag_tensor_create_5d(ctx, MAG_DTYPE_F32, 8, 8, 8, 8, 8);
+        mag_tensor_t* F = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, 8, 8, 8, 8, 8, 8);
+        mag_tensor_decref(A);
+        mag_tensor_decref(B);
+        mag_tensor_decref(C);
+        mag_tensor_decref(D);
+        mag_tensor_decref(E);
+        mag_tensor_decref(F);
+    });
+
+    run_bench("Tensor Add", "add", [](mag_ctx_t* ctx) -> void {
+        constexpr std::int64_t N = 1024;
+
+        mag_tensor_t* A = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, N, N);
+        mag_tensor_fill(A, 1.0f);
+
+        mag_tensor_t* B = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, N, N);
+        mag_tensor_fill(B, 1.0f);
+
+        mag_tensor_t* C = mag_add(A, B);
+
+        mag_tensor_decref(A);
+        mag_tensor_decref(B);
+        mag_tensor_decref(C);
+    });
+}
+
+template <typename F, typename... Args>
+static auto run_bench(
+    std::string_view name,
+    std::string_view unit,
+    F&& callback,
+    Args&&... args
+) -> void {
+    static_assert(std::is_invocable_r_v<void, F, mag_ctx_t*, Args...>);
+    ankerl::nanobench::Bench bench {};
+    bench.title(name.data())
+        .unit(unit.data())
+        .minEpochIterations(10)
+        .relative(true);
+    bench.performanceCounters(true);
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    bench.run(name.data(), [&] {
+        std::invoke(callback, ctx, std::forward<Args>(args)...);
+    });
+    ankerl::nanobench::doNotOptimizeAway(ctx);
+    mag_ctx_destroy(ctx);
+}
diff --git a/benchmark/linux_prepare_perf.sh b/benchmark/linux_prepare_perf.sh
new file mode 100644
index 0000000..54540ee
--- /dev/null
+++ b/benchmark/linux_prepare_perf.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+# Run this script before running the benchmark
+
+sudo -S sysctl kernel.perf_event_paranoid=1
+sudo cpufreq-set -g performance
+cat /sys/devices/system/cpu/cpu*/cpufreq/scaling_governor
+cat /sys/devices/system/cpu/cpu0/cpufreq/scaling_available_governors
diff --git a/benchmark/nanobench/nanobench.h b/benchmark/nanobench/nanobench.h
new file mode 100644
index 0000000..d233dcc
--- /dev/null
+++ b/benchmark/nanobench/nanobench.h
@@ -0,0 +1,3484 @@
+// __ _ _______ __ _ _____ ______ _______ __ _ _______ _ _
+// | \ | |_____| | \ | | | |_____] |______ | \ | | |_____|
+// | \_| | | | \_| |_____| |_____] |______ | \_| |_____ | |
+//
+// Microbenchmark framework for C++11/14/17/20
+// https://github.com/martinus/nanobench
+//
+// Licensed under the MIT License <https://opensource.org/licenses/MIT>.
+// SPDX-License-Identifier: MIT +// Copyright (c) 2019-2023 Martin Leitner-Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ANKERL_NANOBENCH_H_INCLUDED +#define ANKERL_NANOBENCH_H_INCLUDED + +// see https://semver.org/ +#define ANKERL_NANOBENCH_VERSION_MAJOR 4 // incompatible API changes +#define ANKERL_NANOBENCH_VERSION_MINOR 3 // backwards-compatible changes +#define ANKERL_NANOBENCH_VERSION_PATCH 11 // backwards-compatible bug fixes + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// public facing api - as minimal as possible +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include // high_resolution_clock +#include // memcpy +#include // for std::ostream* custom output target in Config +#include // all names +#include // holds context information of results +#include // holds all results + +#define ANKERL_NANOBENCH(x) ANKERL_NANOBENCH_PRIVATE_##x() + +#define ANKERL_NANOBENCH_PRIVATE_CXX() __cplusplus +#define ANKERL_NANOBENCH_PRIVATE_CXX98() 199711L +#define ANKERL_NANOBENCH_PRIVATE_CXX11() 201103L +#define ANKERL_NANOBENCH_PRIVATE_CXX14() 201402L +#define ANKERL_NANOBENCH_PRIVATE_CXX17() 201703L + +#if ANKERL_NANOBENCH(CXX) >= ANKERL_NANOBENCH(CXX17) +# define ANKERL_NANOBENCH_PRIVATE_NODISCARD() [[nodiscard]] +#else +# define ANKERL_NANOBENCH_PRIVATE_NODISCARD() +#endif + +#if defined(__clang__) +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_PUSH() \ + _Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wpadded\"") +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_POP() _Pragma("clang diagnostic pop") +#else +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_PUSH() +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_PADDED_POP() +#endif + +#if defined(__GNUC__) +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_PUSH() _Pragma("GCC diagnostic push") _Pragma("GCC diagnostic ignored \"-Weffc++\"") +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_POP() _Pragma("GCC diagnostic pop") +#else +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_PUSH() +# define ANKERL_NANOBENCH_PRIVATE_IGNORE_EFFCPP_POP() +#endif + +#if defined(ANKERL_NANOBENCH_LOG_ENABLED) +# include +# define ANKERL_NANOBENCH_LOG(x) \ + do { \ + std::cout << __FUNCTION__ << "@" << __LINE__ << ": " << x << std::endl; \ + } while (0) +#else +# define ANKERL_NANOBENCH_LOG(x) \ + do { \ + } while (0) +#endif + +#define ANKERL_NANOBENCH_PRIVATE_PERF_COUNTERS() 0 +#if 
defined(__linux__) && !defined(ANKERL_NANOBENCH_DISABLE_PERF_COUNTERS) +# include +# if LINUX_VERSION_CODE >= KERNEL_VERSION(3, 3, 0) +// PERF_COUNT_HW_REF_CPU_CYCLES only available since kernel 3.3 +// PERF_FLAG_FD_CLOEXEC since kernel 3.14 +# undef ANKERL_NANOBENCH_PRIVATE_PERF_COUNTERS +# define ANKERL_NANOBENCH_PRIVATE_PERF_COUNTERS() 1 +# endif +#endif + +#if defined(__clang__) +# define ANKERL_NANOBENCH_NO_SANITIZE(...) __attribute__((no_sanitize(__VA_ARGS__))) +#else +# define ANKERL_NANOBENCH_NO_SANITIZE(...) +#endif + +#if defined(_MSC_VER) +# define ANKERL_NANOBENCH_PRIVATE_NOINLINE() __declspec(noinline) +#else +# define ANKERL_NANOBENCH_PRIVATE_NOINLINE() __attribute__((noinline)) +#endif + +// workaround missing "is_trivially_copyable" in g++ < 5.0 +// See https://stackoverflow.com/a/31798726/48181 +#if defined(__GNUC__) && __GNUC__ < 5 +# define ANKERL_NANOBENCH_IS_TRIVIALLY_COPYABLE(...) __has_trivial_copy(__VA_ARGS__) +#else +# define ANKERL_NANOBENCH_IS_TRIVIALLY_COPYABLE(...) std::is_trivially_copyable<__VA_ARGS__>::value +#endif + +// noexcept may be missing for std::string. +// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=58265 +#define ANKERL_NANOBENCH_PRIVATE_NOEXCEPT_STRING_MOVE() std::is_nothrow_move_assignable::value + +// declarations /////////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +using Clock = std::conditional::type; +class Bench; +struct Config; +class Result; +class Rng; +class BigO; + +/** + * @brief Renders output from a mustache-like template and benchmark results. + * + * The templating facility here is heavily inspired by [mustache - logic-less templates](https://mustache.github.io/). + * It adds a few more features that are necessary to get all of the captured data out of nanobench. Please read the + * excellent [mustache manual](https://mustache.github.io/mustache.5.html) to see what this is all about. + * + * nanobench output has two nested layers, *result* and *measurement*. Here is a hierarchy of the allowed tags: + * + * * `{{#result}}` Marks the begin of the result layer. Whatever comes after this will be instantiated as often as + * a benchmark result is available. Within it, you can use these tags: + * + * * `{{title}}` See Bench::title. + * + * * `{{name}}` Benchmark name, usually directly provided with Bench::run, but can also be set with Bench::name. + * + * * `{{unit}}` Unit, e.g. `byte`. Defaults to `op`, see Bench::unit. + * + * * `{{batch}}` Batch size, see Bench::batch. + * + * * `{{complexityN}}` Value used for asymptotic complexity calculation. See Bench::complexityN. + * + * * `{{epochs}}` Number of epochs, see Bench::epochs. + * + * * `{{clockResolution}}` Accuracy of the clock, i.e. what's the smallest time possible to measure with the clock. + * For modern systems, this can be around 20 ns. This value is automatically determined by nanobench at the first + * benchmark that is run, and used as a static variable throughout the application's runtime. + * + * * `{{clockResolutionMultiple}}` Configuration multiplier for `clockResolution`. See Bench::clockResolutionMultiple. + * This is the target runtime for each measurement (epoch). That means the more accurate your clock is, the faster + * will be the benchmark. Basing the measurement's runtime on the clock resolution is the main reason why nanobench is so fast. + * + * * `{{maxEpochTime}}` Configuration for a maximum time each measurement (epoch) is allowed to take. 
Note that at least + * a single iteration will be performed, even when that takes longer than maxEpochTime. See Bench::maxEpochTime. + * + * * `{{minEpochTime}}` Minimum epoch time, defaults to 1ms. See Bench::minEpochTime. + * + * * `{{minEpochIterations}}` See Bench::minEpochIterations. + * + * * `{{epochIterations}}` See Bench::epochIterations. + * + * * `{{warmup}}` Number of iterations used before measuring starts. See Bench::warmup. + * + * * `{{relative}}` True or false, depending on the setting you have used. See Bench::relative. + * + * * `{{context(variableName)}}` See Bench::context. + * + * Apart from these tags, it is also possible to use some mathematical operations on the measurement data. The operations + * are of the form `{{command(name)}}`. Currently `name` can be one of `elapsed`, `iterations`. If performance counters + * are available (currently only on current Linux systems), you also have `pagefaults`, `cpucycles`, + * `contextswitches`, `instructions`, `branchinstructions`, and `branchmisses`. All the measures (except `iterations`) are + * provided for a single iteration (so `elapsed` is the time a single iteration took). The following tags are available: + * + * * `{{median()}}` Calculate median of a measurement data set, e.g. `{{median(elapsed)}}`. + * + * * `{{average()}}` Average (mean) calculation. + * + * * `{{medianAbsolutePercentError()}}` Calculates MdAPE, the Median Absolute Percentage Error. The MdAPE is an excellent + * metric for the variation of measurements. It is more robust to outliers than the + * [Mean absolute percentage error (M-APE)](https://en.wikipedia.org/wiki/Mean_absolute_percentage_error). + * @f[ + * \mathrm{MdAPE}(e) = \mathrm{med}\{| \frac{e_i - \mathrm{med}\{e\}}{e_i}| \} + * @f] + * E.g. for *elapsed*: First, @f$ \mathrm{med}\{e\} @f$ calculates the median by sorting and then taking the middle element + * of all *elapsed* measurements. This is used to calculate the absolute percentage + * error to this median for each measurement, as in @f$ | \frac{e_i - \mathrm{med}\{e\}}{e_i}| @f$. All these results + * are sorted, and the middle value is chosen as the median absolute percent error. + * + * This measurement is a bit hard to interpret, but it is very robust against outliers. E.g. a value of 5% means that half of the + * measurements deviate less than 5% from the median, and the other deviate more than 5% from the median. + * + * * `{{sum()}}` Sum of all the measurements. E.g. `{{sum(iterations)}}` will give you the total number of iterations +* measured in this benchmark. + * + * * `{{minimum()}}` Minimum of all measurements. + * + * * `{{maximum()}}` Maximum of all measurements. + * + * * `{{sumProduct(, )}}` Calculates the sum of the products of corresponding measures: + * @f[ + * \mathrm{sumProduct}(a,b) = \sum_{i=1}^{n}a_i\cdot b_i + * @f] + * E.g. to calculate total runtime of the benchmark, you multiply iterations with elapsed time for each measurement, and + * sum these results up: + * `{{sumProduct(iterations, elapsed)}}`. + * + * * `{{#measurement}}` To access individual measurement results, open the begin tag for measurements. + * + * * `{{elapsed}}` Average elapsed wall clock time per iteration, in seconds. + * + * * `{{iterations}}` Number of iterations in the measurement. The number of iterations will fluctuate due + * to some applied randomness, to enhance accuracy. + * + * * `{{pagefaults}}` Average number of pagefaults per iteration. + * + * * `{{cpucycles}}` Average number of CPU cycles processed per iteration. 
+ * + * * `{{contextswitches}}` Average number of context switches per iteration. + * + * * `{{instructions}}` Average number of retired instructions per iteration. + * + * * `{{branchinstructions}}` Average number of branches executed per iteration. + * + * * `{{branchmisses}}` Average number of branches that were missed per iteration. + * + * * `{{/measurement}}` Ends the measurement tag. + * + * * `{{/result}}` Marks the end of the result layer. This is the end marker for the template part that will be instantiated + * for each benchmark result. + * + * + * For the layer tags *result* and *measurement* you additionally can use these special markers: + * + * * ``{{#-first}}`` - Begin marker of a template that will be instantiated *only for the first* entry in the layer. Use is only + * allowed between the begin and end marker of the layer. So between ``{{#result}}`` and ``{{/result}}``, or between + * ``{{#measurement}}`` and ``{{/measurement}}``. Finish the template with ``{{/-first}}``. + * + * * ``{{^-first}}`` - Begin marker of a template that will be instantiated *for each except the first* entry in the layer. This, + * this is basically the inversion of ``{{#-first}}``. Use is only allowed between the begin and end marker of the layer. + * So between ``{{#result}}`` and ``{{/result}}``, or between ``{{#measurement}}`` and ``{{/measurement}}``. + * + * * ``{{/-first}}`` - End marker for either ``{{#-first}}`` or ``{{^-first}}``. + * + * * ``{{#-last}}`` - Begin marker of a template that will be instantiated *only for the last* entry in the layer. Use is only + * allowed between the begin and end marker of the layer. So between ``{{#result}}`` and ``{{/result}}``, or between + * ``{{#measurement}}`` and ``{{/measurement}}``. Finish the template with ``{{/-last}}``. + * + * * ``{{^-last}}`` - Begin marker of a template that will be instantiated *for each except the last* entry in the layer. This, + * this is basically the inversion of ``{{#-last}}``. Use is only allowed between the begin and end marker of the layer. + * So between ``{{#result}}`` and ``{{/result}}``, or between ``{{#measurement}}`` and ``{{/measurement}}``. + * + * * ``{{/-last}}`` - End marker for either ``{{#-last}}`` or ``{{^-last}}``. + * + @verbatim embed:rst + + For an overview of all the possible data you can get out of nanobench, please see the tutorial at :ref:`tutorial-template-json`. + + The templates that ship with nanobench are: + + * :cpp:func:`templates::csv() ` + * :cpp:func:`templates::json() ` + * :cpp:func:`templates::htmlBoxplot() ` + * :cpp:func:`templates::pyperf() ` + + @endverbatim + * + * @param mustacheTemplate The template. + * @param bench Benchmark, containing all the results. + * @param out Output for the generated output. + */ +void render(char const* mustacheTemplate, Bench const& bench, std::ostream& out); +void render(std::string const& mustacheTemplate, Bench const& bench, std::ostream& out); + +/** + * Same as render(char const* mustacheTemplate, Bench const& bench, std::ostream& out), but for when + * you only have results available. + * + * @param mustacheTemplate The template. + * @param results All the results to be used for rendering. + * @param out Output for the generated output. + */ +void render(char const* mustacheTemplate, std::vector const& results, std::ostream& out); +void render(std::string const& mustacheTemplate, std::vector const& results, std::ostream& out); + +// Contains mustache-like templates +namespace templates { + +/*! 
+ @brief CSV data for the benchmark results. + + Generates a comma-separated values dataset. First line is the header, each following line is a summary of each benchmark run. + + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-csv` for an example. + @endverbatim + */ +char const* csv() noexcept; + +/*! + @brief HTML output that uses plotly to generate an interactive boxplot chart. See the tutorial for an example output. + + The output uses only the elapsed wall clock time, and displays each epoch as a single dot. + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-html` for an example. + @endverbatim + + @see also ankerl::nanobench::render() + */ +char const* htmlBoxplot() noexcept; + +/*! + @brief Output in pyperf compatible JSON format, which can be used for more analyzation. + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-pyperf` for an example how to further analyze the output. + @endverbatim + */ +char const* pyperf() noexcept; + +/*! + @brief Template to generate JSON data. + + The generated JSON data contains *all* data that has been generated. All times are as double values, in seconds. The output can get + quite large. + @verbatim embed:rst + See the tutorial at :ref:`tutorial-template-json` for an example. + @endverbatim + */ +char const* json() noexcept; + +} // namespace templates + +namespace detail { + +template +struct PerfCountSet; + +class IterationLogic; +class PerformanceCounters; + +#if ANKERL_NANOBENCH(PERF_COUNTERS) +class LinuxPerformanceCounters; +#endif + +} // namespace detail +} // namespace nanobench +} // namespace ankerl + +// definitions //////////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { +namespace detail { + +template +struct PerfCountSet { + T pageFaults{}; + T cpuCycles{}; + T contextSwitches{}; + T instructions{}; + T branchInstructions{}; + T branchMisses{}; +}; + +} // namespace detail + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +struct Config { + // actual benchmark config + std::string mBenchmarkTitle = "benchmark"; // NOLINT(misc-non-private-member-variables-in-classes) + std::string mBenchmarkName = "noname"; // NOLINT(misc-non-private-member-variables-in-classes) + std::string mUnit = "op"; // NOLINT(misc-non-private-member-variables-in-classes) + double mBatch = 1.0; // NOLINT(misc-non-private-member-variables-in-classes) + double mComplexityN = -1.0; // NOLINT(misc-non-private-member-variables-in-classes) + size_t mNumEpochs = 11; // NOLINT(misc-non-private-member-variables-in-classes) + size_t mClockResolutionMultiple = static_cast(1000); // NOLINT(misc-non-private-member-variables-in-classes) + std::chrono::nanoseconds mMaxEpochTime = std::chrono::milliseconds(100); // NOLINT(misc-non-private-member-variables-in-classes) + std::chrono::nanoseconds mMinEpochTime = std::chrono::milliseconds(1); // NOLINT(misc-non-private-member-variables-in-classes) + uint64_t mMinEpochIterations{1}; // NOLINT(misc-non-private-member-variables-in-classes) + // If not 0, run *exactly* these number of iterations per epoch. 
+ uint64_t mEpochIterations{0}; // NOLINT(misc-non-private-member-variables-in-classes) + uint64_t mWarmup = 0; // NOLINT(misc-non-private-member-variables-in-classes) + std::ostream* mOut = nullptr; // NOLINT(misc-non-private-member-variables-in-classes) + std::chrono::duration mTimeUnit = std::chrono::nanoseconds{1}; // NOLINT(misc-non-private-member-variables-in-classes) + std::string mTimeUnitName = "ns"; // NOLINT(misc-non-private-member-variables-in-classes) + bool mShowPerformanceCounters = true; // NOLINT(misc-non-private-member-variables-in-classes) + bool mIsRelative = false; // NOLINT(misc-non-private-member-variables-in-classes) + std::unordered_map mContext{}; // NOLINT(misc-non-private-member-variables-in-classes) + + Config(); + ~Config(); + Config& operator=(Config const& other); + Config& operator=(Config&& other) noexcept(ANKERL_NANOBENCH(NOEXCEPT_STRING_MOVE)); + Config(Config const& other); + Config(Config&& other) noexcept; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// Result returned after a benchmark has finished. Can be used as a baseline for relative(). +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class Result { +public: + enum class Measure : size_t { + elapsed, + iterations, + pagefaults, + cpucycles, + contextswitches, + instructions, + branchinstructions, + branchmisses, + _size + }; + + explicit Result(Config benchmarkConfig); + + ~Result(); + Result& operator=(Result const& other); + Result& operator=(Result&& other) noexcept(ANKERL_NANOBENCH(NOEXCEPT_STRING_MOVE)); + Result(Result const& other); + Result(Result&& other) noexcept; + + // adds new measurement results + // all values are scaled by iters (except iters...) + void add(Clock::duration totalElapsed, uint64_t iters, detail::PerformanceCounters const& pc); + + ANKERL_NANOBENCH(NODISCARD) Config const& config() const noexcept; + + ANKERL_NANOBENCH(NODISCARD) double median(Measure m) const; + ANKERL_NANOBENCH(NODISCARD) double medianAbsolutePercentError(Measure m) const; + ANKERL_NANOBENCH(NODISCARD) double average(Measure m) const; + ANKERL_NANOBENCH(NODISCARD) double sum(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double sumProduct(Measure m1, Measure m2) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double minimum(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double maximum(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) std::string const& context(char const* variableName) const; + ANKERL_NANOBENCH(NODISCARD) std::string const& context(std::string const& variableName) const; + + ANKERL_NANOBENCH(NODISCARD) bool has(Measure m) const noexcept; + ANKERL_NANOBENCH(NODISCARD) double get(size_t idx, Measure m) const; + ANKERL_NANOBENCH(NODISCARD) bool empty() const noexcept; + ANKERL_NANOBENCH(NODISCARD) size_t size() const noexcept; + + // Finds string, if not found, returns _size. + static Measure fromString(std::string const& str); + +private: + Config mConfig{}; + std::vector> mNameToMeasurements{}; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +/** + * An extremely fast random generator. Currently, this implements *RomuDuoJr*, developed by Mark Overton. Source: + * http://www.romu-random.org/ + * + * RomuDuoJr is extremely fast and provides reasonable good randomness. Not enough for large jobs, but definitely + * good enough for a benchmarking framework. + * + * * Estimated capacity: @f$ 2^{51} @f$ bytes + * * Register pressure: 4 + * * State size: 128 bits + * + * This random generator is a drop-in replacement for the generators supplied by ````. 
It is not + * cryptographically secure. It's intended purpose is to be very fast so that benchmarks that make use + * of randomness are not distorted too much by the random generator. + * + * Rng also provides a few non-standard helpers, optimized for speed. + */ +class Rng final { +public: + /** + * @brief This RNG provides 64bit randomness. + */ + using result_type = uint64_t; + + static constexpr uint64_t(min)(); + static constexpr uint64_t(max)(); + + /** + * As a safety precaution, we don't allow copying. Copying a PRNG would mean you would have two random generators that produce the + * same sequence, which is generally not what one wants. Instead create a new rng with the default constructor Rng(), which is + * automatically seeded from `std::random_device`. If you really need a copy, use `copy()`. + */ + Rng(Rng const&) = delete; + + /** + * Same as Rng(Rng const&), we don't allow assignment. If you need a new Rng create one with the default constructor Rng(). + */ + Rng& operator=(Rng const&) = delete; + + // moving is ok + Rng(Rng&&) noexcept = default; + Rng& operator=(Rng&&) noexcept = default; + ~Rng() noexcept = default; + + /** + * @brief Creates a new Random generator with random seed. + * + * Instead of a default seed (as the random generators from the STD), this properly seeds the random generator from + * `std::random_device`. It guarantees correct seeding. Note that seeding can be relatively slow, depending on the source of + * randomness used. So it is best to create a Rng once and use it for all your randomness purposes. + */ + Rng(); + + /*! + Creates a new Rng that is seeded with a specific seed. Each Rng created from the same seed will produce the same randomness + sequence. This can be useful for deterministic behavior. + + @verbatim embed:rst + .. note:: + + The random algorithm might change between nanobench releases. Whenever a faster and/or better random + generator becomes available, I will switch the implementation. + @endverbatim + + As per the Romu paper, this seeds the Rng with splitMix64 algorithm and performs 10 initial rounds for further mixing up of the + internal state. + + @param seed The 64bit seed. All values are allowed, even 0. + */ + explicit Rng(uint64_t seed) noexcept; + Rng(uint64_t x, uint64_t y) noexcept; + explicit Rng(std::vector const& data); + + /** + * Creates a copy of the Rng, thus the copy provides exactly the same random sequence as the original. + */ + ANKERL_NANOBENCH(NODISCARD) Rng copy() const noexcept; + + /** + * @brief Produces a 64bit random value. This should be very fast, thus it is marked as inline. In my benchmark, this is ~46 times + * faster than `std::default_random_engine` for producing 64bit random values. It seems that the fastest std contender is + * `std::mt19937_64`. Still, this RNG is 2-3 times as fast. + * + * @return uint64_t The next 64 bit random value. + */ + inline uint64_t operator()() noexcept; + + // This is slightly biased. See + + /** + * Generates a random number between 0 and range (excluding range). + * + * The algorithm only produces 32bit numbers, and is slightly biased. The effect is quite small unless your range is close to the + * maximum value of an integer. It is possible to correct the bias with rejection sampling (see + * [here](https://lemire.me/blog/2016/06/30/fast-random-shuffling/), but this is most likely irrelevant in practices for the + * purposes of this Rng. 
+ * + * See Daniel Lemire's blog post [A fast alternative to the modulo + * reduction](https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/) + * + * @param range Upper exclusive range. E.g a value of 3 will generate random numbers 0, 1, 2. + * @return uint32_t Generated random values in range [0, range(. + */ + inline uint32_t bounded(uint32_t range) noexcept; + + // random double in range [0, 1( + // see http://prng.di.unimi.it/ + + /** + * Provides a random uniform double value between 0 and 1. This uses the method described in [Generating uniform doubles in the + * unit interval](http://prng.di.unimi.it/), and is extremely fast. + * + * @return double Uniformly distributed double value in range [0,1(, excluding 1. + */ + inline double uniform01() noexcept; + + /** + * Shuffles all entries in the given container. Although this has a slight bias due to the implementation of bounded(), this is + * preferable to `std::shuffle` because it is over 5 times faster. See Daniel Lemire's blog post [Fast random + * shuffling](https://lemire.me/blog/2016/06/30/fast-random-shuffling/). + * + * @param container The whole container will be shuffled. + */ + template + void shuffle(Container& container) noexcept; + + /** + * Extracts the full state of the generator, e.g. for serialization. For this RNG this is just 2 values, but to stay API compatible + * with future implementations that potentially use more state, we use a vector. + * + * @return Vector containing the full state: + */ + ANKERL_NANOBENCH(NODISCARD) std::vector state() const; + +private: + static constexpr uint64_t rotl(uint64_t x, unsigned k) noexcept; + + uint64_t mX; + uint64_t mY; +}; + +/** + * @brief Main entry point to nanobench's benchmarking facility. + * + * It holds configuration and results from one or more benchmark runs. Usually it is used in a single line, where the object is + * constructed, configured, and then a benchmark is run. E.g. like this: + * + * ankerl::nanobench::Bench().unit("byte").batch(1000).run("random fluctuations", [&] { + * // here be the benchmark code + * }); + * + * In that example Bench() constructs the benchmark, it is then configured with unit() and batch(), and after configuration a + * benchmark is executed with run(). Once run() has finished, it prints the result to `std::cout`. It would also store the results + * in the Bench instance, but in this case the object is immediately destroyed so it's not available any more. + */ +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class Bench { +public: + /** + * @brief Creates a new benchmark for configuration and running of benchmarks. + */ + Bench(); + + Bench(Bench&& other) noexcept; + Bench& operator=(Bench&& other) noexcept(ANKERL_NANOBENCH(NOEXCEPT_STRING_MOVE)); + Bench(Bench const& other); + Bench& operator=(Bench const& other); + ~Bench() noexcept; + + /*! + @brief Repeatedly calls `op()` based on the configuration, and performs measurements. + + This call is marked with `noinline` to prevent the compiler to optimize beyond different benchmarks. This can have quite a big + effect on benchmark accuracy. + + @verbatim embed:rst + .. note:: + + Each call to your lambda must have a side effect that the compiler can't possibly optimize it away. E.g. add a result to an + externally defined number (like `x` in the above example), and finally call `doNotOptimizeAway` on the variables the compiler + must not remove. 
You can also use :cpp:func:`ankerl::nanobench::doNotOptimizeAway` directly in the lambda, but be aware that + this has a small overhead. + + @endverbatim + + @tparam Op The code to benchmark. + */ + template + ANKERL_NANOBENCH(NOINLINE) + Bench& run(char const* benchmarkName, Op&& op); + + template + ANKERL_NANOBENCH(NOINLINE) + Bench& run(std::string const& benchmarkName, Op&& op); + + /** + * @brief Same as run(char const* benchmarkName, Op op), but instead uses the previously set name. + * @tparam Op The code to benchmark. + */ + template + ANKERL_NANOBENCH(NOINLINE) + Bench& run(Op&& op); + + /** + * @brief Title of the benchmark, will be shown in the table header. Changing the title will start a new markdown table. + * + * @param benchmarkTitle The title of the benchmark. + */ + Bench& title(char const* benchmarkTitle); + Bench& title(std::string const& benchmarkTitle); + + /** + * @brief Gets the title of the benchmark + */ + ANKERL_NANOBENCH(NODISCARD) std::string const& title() const noexcept; + + /// Name of the benchmark, will be shown in the table row. + Bench& name(char const* benchmarkName); + Bench& name(std::string const& benchmarkName); + ANKERL_NANOBENCH(NODISCARD) std::string const& name() const noexcept; + + /** + * @brief Set context information. + * + * The information can be accessed using custom render templates via `{{context(variableName)}}`. + * Trying to render a variable that hasn't been set before raises an exception. + * Not included in (default) markdown table. + * + * @see clearContext, render + * + * @param variableName The name of the context variable. + * @param variableValue The value of the context variable. + */ + Bench& context(char const* variableName, char const* variableValue); + Bench& context(std::string const& variableName, std::string const& variableValue); + + /** + * @brief Reset context information. + * + * This may improve efficiency when using many context entries, + * or improve robustness by removing spurious context entries. + * + * @see context + */ + Bench& clearContext(); + + /** + * @brief Sets the batch size. + * + * E.g. number of processed byte, or some other metric for the size of the processed data in each iteration. If you benchmark + * hashing of a 1000 byte long string and want byte/sec as a result, you can specify 1000 as the batch size. + * + * @tparam T Any input type is internally cast to `double`. + * @param b batch size + */ + template + Bench& batch(T b) noexcept; + ANKERL_NANOBENCH(NODISCARD) double batch() const noexcept; + + /** + * @brief Sets the operation unit. + * + * Defaults to "op". Could be e.g. "byte" for string processing. This is used for the table header, e.g. to show `ns/byte`. Use + * singular (*byte*, not *bytes*). A change clears the currently collected results. + * + * @param unit The unit name. + */ + Bench& unit(char const* unit); + Bench& unit(std::string const& unit); + ANKERL_NANOBENCH(NODISCARD) std::string const& unit() const noexcept; + + /** + * @brief Sets the time unit to be used for the default output. + * + * Nanobench defaults to using ns (nanoseconds) as output in the markdown. For some benchmarks this is too coarse, so it is + * possible to configure this. E.g. use `timeUnit(1ms, "ms")` to show `ms/op` instead of `ns/op`. + * + * @param tu Time unit to display the results in, default is 1ns. 
+ * @param tuName Name for the time unit, default is "ns".
+ */
+    Bench& timeUnit(std::chrono::duration<double> const& tu, std::string const& tuName);
+    ANKERL_NANOBENCH(NODISCARD) std::string const& timeUnitName() const noexcept;
+    ANKERL_NANOBENCH(NODISCARD) std::chrono::duration<double> const& timeUnit() const noexcept;
+
+    /**
+     * @brief Sets the output stream where the resulting markdown table will be printed to.
+     *
+     * The default is `&std::cout`. You can disable all output by setting `nullptr`.
+     *
+     * @param outstream Pointer to the output stream, can be `nullptr`.
+     */
+    Bench& output(std::ostream* outstream) noexcept;
+    ANKERL_NANOBENCH(NODISCARD) std::ostream* output() const noexcept;
+
+    /**
+     * Modern processors have a very accurate clock, able to measure time differences as small as 20 nanoseconds. This is the main
+     * trick that makes nanobench so fast: we find out how accurate the clock is, then run the benchmark only so often that the
+     * clock's accuracy is good enough for accurate measurements.
+     *
+     * The default is to run one epoch for 1000 times the clock resolution. So for 20ns resolution and 11 epochs, this gives a total
+     * runtime of
+     *
+     * @f[
+     * 20ns * 1000 * 11 \approx 0.2ms
+     * @f]
+     *
+     * To be precise, nanobench adds 0-20% random noise to each evaluation. This prevents aliasing effects and further improves
+     * accuracy.
+     *
+     * Total runtime will be higher though: some initial time is needed to find out the target number of iterations for each epoch,
+     * and there is some overhead involved to start & stop timers, calculate the resulting statistics, and write the output.
+     *
+     * @param multiple Target multiple of the clock resolution. Usually 1000 is a good compromise between runtime and accuracy.
+     */
+    Bench& clockResolutionMultiple(size_t multiple) noexcept;
+    ANKERL_NANOBENCH(NODISCARD) size_t clockResolutionMultiple() const noexcept;
+
+    /**
+     * @brief Controls the number of epochs, the number of measurements to perform.
+     *
+     * The reported result will be the median of the evaluations of each epoch. The higher you choose this, the more deterministic
+     * the result will be, and outliers will be more easily removed. Also the `err%` will be more accurate the higher this number
+     * is. Note that the `err%` will not necessarily decrease when the number of epochs is increased. But it will be a more accurate
+     * representation of the benchmarked code's runtime stability.
+     *
+     * Choose the value wisely. In practice, 11 has been shown to be a reasonable choice between runtime performance and accuracy.
+     * This setting goes hand in hand with minEpochIterations() (or minEpochTime()). If you are more interested in *median* runtime,
+     * you might want to increase epochs(). If you are more interested in *mean* runtime, you might want to increase
+     * minEpochIterations() instead.
+     *
+     * @param numEpochs Number of epochs.
+     */
+    Bench& epochs(size_t numEpochs) noexcept;
+    ANKERL_NANOBENCH(NODISCARD) size_t epochs() const noexcept;
+
+    /**
+     * @brief Upper limit for the runtime of each epoch.
+     *
+     * As a safety precaution if the clock is not very accurate, we can set an upper limit for the maximum evaluation time per
+     * epoch. Default is 100ms. At least a single evaluation of the benchmark is performed.
+     *
+     * @see minEpochTime, minEpochIterations
+     *
+     * @param t Maximum target runtime for a single epoch.
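+     *
+     * A usage sketch (editor's illustration; the values are arbitrary and `op()` stands for any
+     * hypothetical function under test):
+     *
+     *     ankerl::nanobench::Bench()
+     *         .epochs(11)
+     *         .maxEpochTime(std::chrono::milliseconds{50})
+     *         .run("op", [] { op(); });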
+ */ + Bench& maxEpochTime(std::chrono::nanoseconds t) noexcept; + ANKERL_NANOBENCH(NODISCARD) std::chrono::nanoseconds maxEpochTime() const noexcept; + + /** + * @brief Minimum time each epoch should take. + * + * Default is 1ms, so we are mostly relying on clockResolutionMultiple(). In most cases this is exactly what you want. If you see + * that the evaluation is unreliable with a high `err%`, you can increase either minEpochTime() or minEpochIterations(). + * + * @see maxEpochTime, minEpochIterations + * + * @param t Minimum time each epoch should take. + */ + Bench& minEpochTime(std::chrono::nanoseconds t) noexcept; + ANKERL_NANOBENCH(NODISCARD) std::chrono::nanoseconds minEpochTime() const noexcept; + + /** + * @brief Sets the minimum number of iterations each epoch should take. + * + * Default is 1, and we rely on clockResolutionMultiple(). If the `err%` is high and you want a more smooth result, you might want + * to increase the minimum number of iterations, or increase the minEpochTime(). + * + * @see minEpochTime, maxEpochTime, minEpochIterations + * + * @param numIters Minimum number of iterations per epoch. + */ + Bench& minEpochIterations(uint64_t numIters) noexcept; + ANKERL_NANOBENCH(NODISCARD) uint64_t minEpochIterations() const noexcept; + + /** + * Sets exactly the number of iterations for each epoch. Ignores all other epoch limits. This forces nanobench to use exactly + * the given number of iterations for each epoch, not more and not less. Default is 0 (disabled). + * + * @param numIters Exact number of iterations to use. Set to 0 to disable. + */ + Bench& epochIterations(uint64_t numIters) noexcept; + ANKERL_NANOBENCH(NODISCARD) uint64_t epochIterations() const noexcept; + + /** + * @brief Sets a number of iterations that are initially performed without any measurements. + * + * Some benchmarks need a few evaluations to warm up caches / database / whatever access. Normally this should not be needed, since + * we show the median result so initial outliers will be filtered away automatically. If the warmup effect is large though, you + * might want to set it. Default is 0. + * + * @param numWarmupIters Number of warmup iterations. + */ + Bench& warmup(uint64_t numWarmupIters) noexcept; + ANKERL_NANOBENCH(NODISCARD) uint64_t warmup() const noexcept; + + /** + * @brief Marks the next run as the baseline. + * + * Call `relative(true)` to mark the run as the baseline. Successive runs will be compared to this run. It is calculated by + * + * @f[ + * 100\% * \frac{baseline}{runtime} + * @f] + * + * * 100% means it is exactly as fast as the baseline + * * >100% means it is faster than the baseline. E.g. 200% means the current run is twice as fast as the baseline. + * * <100% means it is slower than the baseline. E.g. 50% means it is twice as slow as the baseline. + * + * See the tutorial section "Comparing Results" for example usage. + * + * @param isRelativeEnabled True to enable processing + */ + Bench& relative(bool isRelativeEnabled) noexcept; + ANKERL_NANOBENCH(NODISCARD) bool relative() const noexcept; + + /** + * @brief Enables/disables performance counters. + * + * On Linux nanobench has a powerful feature to use performance counters. This enables counting of retired instructions, count + * number of branches, missed branches, etc. On default this is enabled, but you can disable it if you don't need that feature. + * + * @param showPerformanceCounters True to enable, false to disable. 
+ */ + Bench& performanceCounters(bool showPerformanceCounters) noexcept; + ANKERL_NANOBENCH(NODISCARD) bool performanceCounters() const noexcept; + + /** + * @brief Retrieves all benchmark results collected by the bench object so far. + * + * Each call to run() generates a Result that is stored within the Bench instance. This is mostly for advanced users who want to + * see all the nitty gritty details. + * + * @return All results collected so far. + */ + ANKERL_NANOBENCH(NODISCARD) std::vector const& results() const noexcept; + + /*! + @verbatim embed:rst + + Convenience shortcut to :cpp:func:`ankerl::nanobench::doNotOptimizeAway`. + + @endverbatim + */ + template + Bench& doNotOptimizeAway(Arg&& arg); + + /*! + @verbatim embed:rst + + Sets N for asymptotic complexity calculation, so it becomes possible to calculate `Big O + `_ from multiple benchmark evaluations. + + Use :cpp:func:`ankerl::nanobench::Bench::complexityBigO` when the evaluation has finished. See the tutorial + :ref:`asymptotic-complexity` for details. + + @endverbatim + + @tparam T Any type is cast to `double`. + @param n Length of N for the next benchmark run, so it is possible to calculate `bigO`. + */ + template + Bench& complexityN(T n) noexcept; + ANKERL_NANOBENCH(NODISCARD) double complexityN() const noexcept; + + /*! + Calculates [Big O](https://en.wikipedia.org/wiki/Big_O_notation>) of the results with all preconfigured complexity functions. + Currently these complexity functions are fitted into the benchmark results: + + @f$ \mathcal{O}(1) @f$, + @f$ \mathcal{O}(n) @f$, + @f$ \mathcal{O}(\log{}n) @f$, + @f$ \mathcal{O}(n\log{}n) @f$, + @f$ \mathcal{O}(n^2) @f$, + @f$ \mathcal{O}(n^3) @f$. + + If we e.g. evaluate the complexity of `std::sort`, this is the result of `std::cout << bench.complexityBigO()`: + + ``` + | coefficient | err% | complexity + |--------------:|-------:|------------ + | 5.08935e-09 | 2.6% | O(n log n) + | 6.10608e-08 | 8.0% | O(n) + | 1.29307e-11 | 47.2% | O(n^2) + | 2.48677e-15 | 69.6% | O(n^3) + | 9.88133e-06 | 132.3% | O(log n) + | 5.98793e-05 | 162.5% | O(1) + ``` + + So in this case @f$ \mathcal{O}(n\log{}n) @f$ provides the best approximation. + + @verbatim embed:rst + See the tutorial :ref:`asymptotic-complexity` for details. + @endverbatim + @return Evaluation results, which can be printed or otherwise inspected. + */ + std::vector complexityBigO() const; + + /** + * @brief Calculates bigO for a custom function. + * + * E.g. to calculate the mean squared error for @f$ \mathcal{O}(\log{}\log{}n) @f$, which is not part of the default set of + * complexityBigO(), you can do this: + * + * ``` + * auto logLogN = bench.complexityBigO("O(log log n)", [](double n) { + * return std::log2(std::log2(n)); + * }); + * ``` + * + * The resulting mean squared error can be printed with `std::cout << logLogN`. E.g. it prints something like this: + * + * ```text + * 2.46985e-05 * O(log log n), rms=1.48121 + * ``` + * + * @tparam Op Type of mapping operation. + * @param name Name for the function, e.g. "O(log log n)" + * @param op Op's operator() maps a `double` with the desired complexity function, e.g. `log2(log2(n))`. + * @return BigO Error calculation, which is streamable to std::cout. + */ + template + BigO complexityBigO(char const* name, Op op) const; + + template + BigO complexityBigO(std::string const& name, Op op) const; + + /*! + @verbatim embed:rst + + Convenience shortcut to :cpp:func:`ankerl::nanobench::render`. 
+
+      @endverbatim
+    */
+    Bench& render(char const* templateContent, std::ostream& os);
+    Bench& render(std::string const& templateContent, std::ostream& os);
+
+    Bench& config(Config const& benchmarkConfig);
+    ANKERL_NANOBENCH(NODISCARD) Config const& config() const noexcept;
+
+private:
+    Config mConfig{};
+    std::vector<Result> mResults{};
+};
+ANKERL_NANOBENCH(IGNORE_PADDED_POP)
+
+/**
+ * @brief Makes sure none of the given arguments are optimized away by the compiler.
+ *
+ * @tparam Arg Type of the argument that shouldn't be optimized away.
+ * @param arg The input that we mark as being used, even though we don't do anything with it.
+ */
+template <typename Arg>
+void doNotOptimizeAway(Arg&& arg);
+
+namespace detail {
+
+#if defined(_MSC_VER)
+void doNotOptimizeAwaySink(void const*);
+
+template <typename T>
+void doNotOptimizeAway(T const& val);
+
+#else
+
+// This assembly magic is directly from what Google Benchmark does. I have previously used what Facebook's folly does, but
+// that seemed to have compilation problems in some cases. Google Benchmark seems to be the most well tested anyway.
+// see https://github.com/google/benchmark/blob/v1.7.1/include/benchmark/benchmark.h#L443-L446
+template <typename T>
+void doNotOptimizeAway(T const& val) {
+    // NOLINTNEXTLINE(hicpp-no-assembler)
+    asm volatile("" : : "r,m"(val) : "memory");
+}
+
+template <typename T>
+void doNotOptimizeAway(T& val) {
+# if defined(__clang__)
+    // NOLINTNEXTLINE(hicpp-no-assembler)
+    asm volatile("" : "+r,m"(val) : : "memory");
+# else
+    // NOLINTNEXTLINE(hicpp-no-assembler)
+    asm volatile("" : "+m,r"(val) : : "memory");
+# endif
+}
+#endif
+
+// internally used, but visible because run() is templated.
+// Not movable/copyable, so we simply use a pointer instead of unique_ptr. This saves us from
+// having to include <memory>, and from the template instantiation overhead of unique_ptr, which is unfortunately quite significant.
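+// Editor's sketch of the protocol Bench::run() (defined further below) follows with this class:
+//
+//     detail::IterationLogic logic(bench);
+//     while (auto n = logic.numIters()) {       // returns 0 once all epochs are measured
+//         /* call op() n times, timing the whole batch */
+//         logic.add(elapsed, performanceCounters());
+//     }
+//     logic.moveResultTo(results);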
+ANKERL_NANOBENCH(IGNORE_EFFCPP_PUSH) +class IterationLogic { +public: + explicit IterationLogic(Bench const& bench); + IterationLogic(IterationLogic&&) = delete; + IterationLogic& operator=(IterationLogic&&) = delete; + IterationLogic(IterationLogic const&) = delete; + IterationLogic& operator=(IterationLogic const&) = delete; + ~IterationLogic(); + + ANKERL_NANOBENCH(NODISCARD) uint64_t numIters() const noexcept; + void add(std::chrono::nanoseconds elapsed, PerformanceCounters const& pc) noexcept; + void moveResultTo(std::vector& results) noexcept; + +private: + struct Impl; + Impl* mPimpl; +}; +ANKERL_NANOBENCH(IGNORE_EFFCPP_POP) + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class PerformanceCounters { +public: + PerformanceCounters(PerformanceCounters const&) = delete; + PerformanceCounters(PerformanceCounters&&) = delete; + PerformanceCounters& operator=(PerformanceCounters const&) = delete; + PerformanceCounters& operator=(PerformanceCounters&&) = delete; + + PerformanceCounters(); + ~PerformanceCounters(); + + void beginMeasure(); + void endMeasure(); + void updateResults(uint64_t numIters); + + ANKERL_NANOBENCH(NODISCARD) PerfCountSet const& val() const noexcept; + ANKERL_NANOBENCH(NODISCARD) PerfCountSet const& has() const noexcept; + +private: +#if ANKERL_NANOBENCH(PERF_COUNTERS) + LinuxPerformanceCounters* mPc = nullptr; +#endif + PerfCountSet mVal{}; + PerfCountSet mHas{}; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// Gets the singleton +PerformanceCounters& performanceCounters(); + +} // namespace detail + +class BigO { +public: + using RangeMeasure = std::vector>; + + template + static RangeMeasure mapRangeMeasure(RangeMeasure data, Op op) { + for (auto& rangeMeasure : data) { + rangeMeasure.first = op(rangeMeasure.first); + } + return data; + } + + static RangeMeasure collectRangeMeasure(std::vector const& results); + + template + BigO(char const* bigOName, RangeMeasure const& rangeMeasure, Op rangeToN) + : BigO(bigOName, mapRangeMeasure(rangeMeasure, rangeToN)) {} + + template + BigO(std::string bigOName, RangeMeasure const& rangeMeasure, Op rangeToN) + : BigO(std::move(bigOName), mapRangeMeasure(rangeMeasure, rangeToN)) {} + + BigO(char const* bigOName, RangeMeasure const& scaledRangeMeasure); + BigO(std::string bigOName, RangeMeasure const& scaledRangeMeasure); + ANKERL_NANOBENCH(NODISCARD) std::string const& name() const noexcept; + ANKERL_NANOBENCH(NODISCARD) double constant() const noexcept; + ANKERL_NANOBENCH(NODISCARD) double normalizedRootMeanSquare() const noexcept; + ANKERL_NANOBENCH(NODISCARD) bool operator<(BigO const& other) const noexcept; + +private: + std::string mName{}; + double mConstant{}; + double mNormalizedRootMeanSquare{}; +}; +std::ostream& operator<<(std::ostream& os, BigO const& bigO); +std::ostream& operator<<(std::ostream& os, std::vector const& bigOs); + +} // namespace nanobench +} // namespace ankerl + +// implementation ///////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +constexpr uint64_t(Rng::min)() { + return 0; +} + +constexpr uint64_t(Rng::max)() { + return (std::numeric_limits::max)(); +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined") +uint64_t Rng::operator()() noexcept { + auto x = mX; + + mX = UINT64_C(15241094284759029579) * mY; + mY = rotl(mY - x, 27); + + return x; +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined") +uint32_t Rng::bounded(uint32_t range) noexcept { + uint64_t const r32 = static_cast(operator()()); + auto multiresult = r32 * 
range;
+    return static_cast<uint32_t>(multiresult >> 32U);
+}
+
+double Rng::uniform01() noexcept {
+    auto i = (UINT64_C(0x3ff) << 52U) | (operator()() >> 12U);
+    // can't use a union in C++ here for type punning, it's undefined behavior.
+    // std::memcpy is optimized anyways.
+    double d{};
+    std::memcpy(&d, &i, sizeof(double));
+    return d - 1.0;
+}
+
+template <typename Container>
+void Rng::shuffle(Container& container) noexcept {
+    auto i = container.size();
+    while (i > 1U) {
+        using std::swap;
+        auto n = operator()();
+        // using decltype(i) instead of size_t to be compatible with containers with 32bit index (see #80)
+        auto b1 = static_cast<decltype(i)>((static_cast<uint32_t>(n) * static_cast<uint64_t>(i)) >> 32U);
+        swap(container[--i], container[b1]);
+
+        auto b2 = static_cast<decltype(i)>(((n >> 32U) * static_cast<uint64_t>(i)) >> 32U);
+        swap(container[--i], container[b2]);
+    }
+}
+
+ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined")
+constexpr uint64_t Rng::rotl(uint64_t x, unsigned k) noexcept {
+    return (x << k) | (x >> (64U - k));
+}
+
+template <typename Op>
+ANKERL_NANOBENCH_NO_SANITIZE("integer")
+Bench& Bench::run(Op&& op) {
+    // It is important that this method is kept short so the compiler can do better optimizations / inlining of op()
+    detail::IterationLogic iterationLogic(*this);
+    auto& pc = detail::performanceCounters();
+
+    while (auto n = iterationLogic.numIters()) {
+        pc.beginMeasure();
+        Clock::time_point const before = Clock::now();
+        while (n-- > 0) {
+            op();
+        }
+        Clock::time_point const after = Clock::now();
+        pc.endMeasure();
+        pc.updateResults(iterationLogic.numIters());
+        iterationLogic.add(after - before, pc);
+    }
+    iterationLogic.moveResultTo(mResults);
+    return *this;
+}
+
+// Performs all evaluations.
+template <typename Op>
+Bench& Bench::run(char const* benchmarkName, Op&& op) {
+    name(benchmarkName);
+    return run(std::forward<Op>(op));
+}
+
+template <typename Op>
+Bench& Bench::run(std::string const& benchmarkName, Op&& op) {
+    name(benchmarkName);
+    return run(std::forward<Op>(op));
+}
+
+template <typename Op>
+BigO Bench::complexityBigO(char const* benchmarkName, Op op) const {
+    return BigO(benchmarkName, BigO::collectRangeMeasure(mResults), op);
+}
+
+template <typename Op>
+BigO Bench::complexityBigO(std::string const& benchmarkName, Op op) const {
+    return BigO(benchmarkName, BigO::collectRangeMeasure(mResults), op);
+}
+
+// Set the batch size, e.g. number of processed bytes, or some other metric for the size of the processed data in each iteration.
+// Any argument is cast to double.
+template <typename T>
+Bench& Bench::batch(T b) noexcept {
+    mConfig.mBatch = static_cast<double>(b);
+    return *this;
+}
+
+// Sets the computation complexity of the next run. Any argument is cast to double.
+template <typename T>
+Bench& Bench::complexityN(T n) noexcept {
+    mConfig.mComplexityN = static_cast<double>(n);
+    return *this;
+}
+
+// Convenience: makes sure none of the given arguments are optimized away by the compiler.
+template <typename Arg>
+Bench& Bench::doNotOptimizeAway(Arg&& arg) {
+    detail::doNotOptimizeAway(std::forward<Arg>(arg));
+    return *this;
+}
+
+// Makes sure none of the given arguments are optimized away by the compiler.
+template <typename Arg>
+void doNotOptimizeAway(Arg&& arg) {
+    detail::doNotOptimizeAway(std::forward<Arg>(arg));
+}
+
+namespace detail {
+
+#if defined(_MSC_VER)
+template <typename T>
+void doNotOptimizeAway(T const& val) {
+    doNotOptimizeAwaySink(&val);
+}
+
+#endif
+
+} // namespace detail
+} // namespace nanobench
+} // namespace ankerl
+
+#if defined(ANKERL_NANOBENCH_IMPLEMENT)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+// implementation part - only visible in .cpp
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+# include <algorithm> // sort, reverse
+# include <atomic>    // compare_exchange_strong in loop overhead
+# include <cstdlib>   // getenv
+# include <cstring>   // strstr, strncmp
+# include <fstream>   // ifstream to parse proc files
+# include <iomanip>   // setw, setprecision
+# include <iostream>  // cout
+# include <numeric>   // accumulate
+# include <random>    // random_device
+# include <sstream>   // to_s in Number
+# include <stdexcept> // throw for rendering templates
+# include <tuple>     // std::tie
+# if defined(__linux__)
+#  include <unistd.h> // sysconf
+# endif
+# if ANKERL_NANOBENCH(PERF_COUNTERS)
+#  include <map> // map
+
+#  include <linux/perf_event.h>
+#  include <sys/ioctl.h>
+#  include <sys/syscall.h>
+# endif
+
+// declarations ///////////////////////////////////////////////////////////////////////////////////
+
+namespace ankerl {
+namespace nanobench {
+
+// helper stuff that is only intended to be used internally
+namespace detail {
+
+struct TableInfo;
+
+// formatting utilities
+namespace fmt {
+
+class NumSep;
+class StreamStateRestorer;
+class Number;
+class MarkDownColumn;
+class MarkDownCode;
+
+} // namespace fmt
+} // namespace detail
+} // namespace nanobench
+} // namespace ankerl
+
+// definitions ////////////////////////////////////////////////////////////////////////////////////
+
+namespace ankerl {
+namespace nanobench {
+
+uint64_t splitMix64(uint64_t& state) noexcept;
+
+namespace detail {
+
+// helpers to get double values
+template <typename T>
+inline double d(T t) noexcept {
+    return static_cast<double>(t);
+}
+inline double d(Clock::duration duration) noexcept {
+    return std::chrono::duration_cast<std::chrono::duration<double>>(duration).count();
+}
+
+// Calculates clock resolution once, and remembers the result
+inline Clock::duration clockResolution() noexcept;
+
+} // namespace detail
+
+namespace templates {
+
+char const* csv() noexcept {
+    return R"DELIM("title";"name";"unit";"batch";"elapsed";"error %";"instructions";"branches";"branch misses";"total"
+{{#result}}"{{title}}";"{{name}}";"{{unit}}";{{batch}};{{median(elapsed)}};{{medianAbsolutePercentError(elapsed)}};{{median(instructions)}};{{median(branchinstructions)}};{{median(branchmisses)}};{{sumProduct(iterations, elapsed)}}
+{{/result}})DELIM";
+}
+
+char const* htmlBoxplot() noexcept {
+    return R"DELIM(
+
+
+
+
+
+ + + +)DELIM"; +} + +char const* pyperf() noexcept { + return R"DELIM({ + "benchmarks": [ + { + "runs": [ + { + "values": [ +{{#measurement}} {{elapsed}}{{^-last}}, +{{/last}}{{/measurement}} + ] + } + ] + } + ], + "metadata": { + "loops": {{sum(iterations)}}, + "inner_loops": {{batch}}, + "name": "{{title}}", + "unit": "second" + }, + "version": "1.0" +})DELIM"; +} + +char const* json() noexcept { + return R"DELIM({ + "results": [ +{{#result}} { + "title": "{{title}}", + "name": "{{name}}", + "unit": "{{unit}}", + "batch": {{batch}}, + "complexityN": {{complexityN}}, + "epochs": {{epochs}}, + "clockResolution": {{clockResolution}}, + "clockResolutionMultiple": {{clockResolutionMultiple}}, + "maxEpochTime": {{maxEpochTime}}, + "minEpochTime": {{minEpochTime}}, + "minEpochIterations": {{minEpochIterations}}, + "epochIterations": {{epochIterations}}, + "warmup": {{warmup}}, + "relative": {{relative}}, + "median(elapsed)": {{median(elapsed)}}, + "medianAbsolutePercentError(elapsed)": {{medianAbsolutePercentError(elapsed)}}, + "median(instructions)": {{median(instructions)}}, + "medianAbsolutePercentError(instructions)": {{medianAbsolutePercentError(instructions)}}, + "median(cpucycles)": {{median(cpucycles)}}, + "median(contextswitches)": {{median(contextswitches)}}, + "median(pagefaults)": {{median(pagefaults)}}, + "median(branchinstructions)": {{median(branchinstructions)}}, + "median(branchmisses)": {{median(branchmisses)}}, + "totalTime": {{sumProduct(iterations, elapsed)}}, + "measurements": [ +{{#measurement}} { + "iterations": {{iterations}}, + "elapsed": {{elapsed}}, + "pagefaults": {{pagefaults}}, + "cpucycles": {{cpucycles}}, + "contextswitches": {{contextswitches}}, + "instructions": {{instructions}}, + "branchinstructions": {{branchinstructions}}, + "branchmisses": {{branchmisses}} + }{{^-last}},{{/-last}} +{{/measurement}} ] + }{{^-last}},{{/-last}} +{{/result}} ] +})DELIM"; +} + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +struct Node { + enum class Type { tag, content, section, inverted_section }; + + char const* begin; + char const* end; + std::vector children; + Type type; + + template + // NOLINTNEXTLINE(hicpp-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays) + bool operator==(char const (&str)[N]) const noexcept { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-array-to-pointer-decay) + return static_cast(std::distance(begin, end) + 1) == N && 0 == strncmp(str, begin, N - 1); + } +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// NOLINTNEXTLINE(misc-no-recursion) +static std::vector parseMustacheTemplate(char const** tpl) { + std::vector nodes; + + while (true) { + auto const* begin = std::strstr(*tpl, "{{"); + auto const* end = begin; + if (begin != nullptr) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + begin += 2; + end = std::strstr(begin, "}}"); + } + + if (begin == nullptr || end == nullptr) { + // nothing found, finish node + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + nodes.emplace_back(Node{*tpl, *tpl + std::strlen(*tpl), std::vector{}, Node::Type::content}); + return nodes; + } + + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + nodes.emplace_back(Node{*tpl, begin - 2, std::vector{}, Node::Type::content}); + + // we found a tag + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + *tpl = end + 2; + switch (*begin) { + case '/': + // finished! 
bail out + return nodes; + + case '#': + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + nodes.emplace_back(Node{begin + 1, end, parseMustacheTemplate(tpl), Node::Type::section}); + break; + + case '^': + // NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-pointer-arithmetic) + nodes.emplace_back(Node{begin + 1, end, parseMustacheTemplate(tpl), Node::Type::inverted_section}); + break; + + default: + nodes.emplace_back(Node{begin, end, std::vector{}, Node::Type::tag}); + break; + } + } +} + +static bool generateFirstLast(Node const& n, size_t idx, size_t size, std::ostream& out) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast(n.type)); + bool const matchFirst = n == "-first"; + bool const matchLast = n == "-last"; + if (!matchFirst && !matchLast) { + return false; + } + + bool doWrite = false; + if (n.type == Node::Type::section) { + doWrite = (matchFirst && idx == 0) || (matchLast && idx == size - 1); + } else if (n.type == Node::Type::inverted_section) { + doWrite = (matchFirst && idx != 0) || (matchLast && idx != size - 1); + } + + if (doWrite) { + for (auto const& child : n.children) { + if (child.type == Node::Type::content) { + out.write(child.begin, std::distance(child.begin, child.end)); + } + } + } + return true; +} + +static bool matchCmdArgs(std::string const& str, std::vector& matchResult) { + matchResult.clear(); + auto idxOpen = str.find('('); + auto idxClose = str.find(')', idxOpen); + if (idxClose == std::string::npos) { + return false; + } + + matchResult.emplace_back(str.substr(0, idxOpen)); + + // split by comma + matchResult.emplace_back(); + for (size_t i = idxOpen + 1; i != idxClose; ++i) { + if (str[i] == ' ' || str[i] == '\t') { + // skip whitespace + continue; + } + if (str[i] == ',') { + // got a comma => new string + matchResult.emplace_back(); + continue; + } + // no whitespace no comma, append + matchResult.back() += str[i]; + } + return true; +} + +static bool generateConfigTag(Node const& n, Config const& config, std::ostream& out) { + using detail::d; + + if (n == "title") { + out << config.mBenchmarkTitle; + return true; + } + if (n == "name") { + out << config.mBenchmarkName; + return true; + } + if (n == "unit") { + out << config.mUnit; + return true; + } + if (n == "batch") { + out << config.mBatch; + return true; + } + if (n == "complexityN") { + out << config.mComplexityN; + return true; + } + if (n == "epochs") { + out << config.mNumEpochs; + return true; + } + if (n == "clockResolution") { + out << d(detail::clockResolution()); + return true; + } + if (n == "clockResolutionMultiple") { + out << config.mClockResolutionMultiple; + return true; + } + if (n == "maxEpochTime") { + out << d(config.mMaxEpochTime); + return true; + } + if (n == "minEpochTime") { + out << d(config.mMinEpochTime); + return true; + } + if (n == "minEpochIterations") { + out << config.mMinEpochIterations; + return true; + } + if (n == "epochIterations") { + out << config.mEpochIterations; + return true; + } + if (n == "warmup") { + out << config.mWarmup; + return true; + } + if (n == "relative") { + out << config.mIsRelative; + return true; + } + return false; +} + +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +static std::ostream& generateResultTag(Node const& n, Result const& r, std::ostream& out) { + if (generateConfigTag(n, r.config(), out)) { + return out; + } + // match e.g. 
"median(elapsed)" + // g++ 4.8 doesn't implement std::regex :( + // static std::regex const regOpArg1("^([a-zA-Z]+)\\(([a-zA-Z]*)\\)$"); + // std::cmatch matchResult; + // if (std::regex_match(n.begin, n.end, matchResult, regOpArg1)) { + std::vector matchResult; + if (matchCmdArgs(std::string(n.begin, n.end), matchResult)) { + if (matchResult.size() == 2) { + if (matchResult[0] == "context") { + return out << r.context(matchResult[1]); + } + + auto m = Result::fromString(matchResult[1]); + if (m == Result::Measure::_size) { + return out << 0.0; + } + + if (matchResult[0] == "median") { + return out << r.median(m); + } + if (matchResult[0] == "average") { + return out << r.average(m); + } + if (matchResult[0] == "medianAbsolutePercentError") { + return out << r.medianAbsolutePercentError(m); + } + if (matchResult[0] == "sum") { + return out << r.sum(m); + } + if (matchResult[0] == "minimum") { + return out << r.minimum(m); + } + if (matchResult[0] == "maximum") { + return out << r.maximum(m); + } + } else if (matchResult.size() == 3) { + auto m1 = Result::fromString(matchResult[1]); + auto m2 = Result::fromString(matchResult[2]); + if (m1 == Result::Measure::_size || m2 == Result::Measure::_size) { + return out << 0.0; + } + + if (matchResult[0] == "sumProduct") { + return out << r.sumProduct(m1, m2); + } + } + } + + // match e.g. "sumProduct(elapsed, iterations)" + // static std::regex const regOpArg2("^([a-zA-Z]+)\\(([a-zA-Z]*)\\s*,\\s+([a-zA-Z]*)\\)$"); + + // nothing matches :( + throw std::runtime_error("command '" + std::string(n.begin, n.end) + "' not understood"); +} + +static void generateResultMeasurement(std::vector const& nodes, size_t idx, Result const& r, std::ostream& out) { + for (auto const& n : nodes) { + if (!generateFirstLast(n, idx, r.size(), out)) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast(n.type)); + switch (n.type) { + case Node::Type::content: + out.write(n.begin, std::distance(n.begin, n.end)); + break; + + case Node::Type::inverted_section: + throw std::runtime_error("got a inverted section inside measurement"); + + case Node::Type::section: + throw std::runtime_error("got a section inside measurement"); + + case Node::Type::tag: { + auto m = Result::fromString(std::string(n.begin, n.end)); + if (m == Result::Measure::_size || !r.has(m)) { + out << 0.0; + } else { + out << r.get(idx, m); + } + break; + } + } + } + } +} + +static void generateResult(std::vector const& nodes, size_t idx, std::vector const& results, std::ostream& out) { + auto const& r = results[idx]; + for (auto const& n : nodes) { + if (!generateFirstLast(n, idx, results.size(), out)) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast(n.type)); + switch (n.type) { + case Node::Type::content: + out.write(n.begin, std::distance(n.begin, n.end)); + break; + + case Node::Type::inverted_section: + throw std::runtime_error("got a inverted section inside result"); + + case Node::Type::section: + if (n == "measurement") { + for (size_t i = 0; i < r.size(); ++i) { + generateResultMeasurement(n.children, i, r, out); + } + } else { + throw std::runtime_error("got a section inside result"); + } + break; + + case Node::Type::tag: + generateResultTag(n, r, out); + break; + } + } + } +} + +} // namespace templates + +// helper stuff that only intended to be used internally +namespace detail { + +char const* getEnv(char const* name); +bool isEndlessRunning(std::string const& name); +bool isWarningsEnabled(); + +template +T parseFile(std::string const& filename, bool* fail); + +void 
gatherStabilityInformation(std::vector& warnings, std::vector& recommendations); +void printStabilityInformationOnce(std::ostream* outStream); + +// remembers the last table settings used. When it changes, a new table header is automatically written for the new entry. +uint64_t& singletonHeaderHash() noexcept; + +// determines resolution of the given clock. This is done by measuring multiple times and returning the minimum time difference. +Clock::duration calcClockResolution(size_t numEvaluations) noexcept; + +// formatting utilities +namespace fmt { + +// adds thousands separator to numbers +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class NumSep : public std::numpunct { +public: + explicit NumSep(char sep); + char do_thousands_sep() const override; + std::string do_grouping() const override; + +private: + char mSep; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// RAII to save & restore a stream's state +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class StreamStateRestorer { +public: + explicit StreamStateRestorer(std::ostream& s); + ~StreamStateRestorer(); + + // sets back all stream info that we remembered at construction + void restore(); + + // don't allow copying / moving + StreamStateRestorer(StreamStateRestorer const&) = delete; + StreamStateRestorer& operator=(StreamStateRestorer const&) = delete; + StreamStateRestorer(StreamStateRestorer&&) = delete; + StreamStateRestorer& operator=(StreamStateRestorer&&) = delete; + +private: + std::ostream& mStream; + std::locale mLocale; + std::streamsize const mPrecision; + std::streamsize const mWidth; + std::ostream::char_type const mFill; + std::ostream::fmtflags const mFmtFlags; +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +// Number formatter +class Number { +public: + Number(int width, int precision, double value); + Number(int width, int precision, int64_t value); + ANKERL_NANOBENCH(NODISCARD) std::string to_s() const; + +private: + friend std::ostream& operator<<(std::ostream& os, Number const& n); + std::ostream& write(std::ostream& os) const; + + int mWidth; + int mPrecision; + double mValue; +}; + +// helper replacement for std::to_string of signed/unsigned numbers so we are locale independent +std::string to_s(uint64_t n); + +std::ostream& operator<<(std::ostream& os, Number const& n); + +class MarkDownColumn { +public: + MarkDownColumn(int w, int prec, std::string tit, std::string suff, double val) noexcept; + ANKERL_NANOBENCH(NODISCARD) std::string title() const; + ANKERL_NANOBENCH(NODISCARD) std::string separator() const; + ANKERL_NANOBENCH(NODISCARD) std::string invalid() const; + ANKERL_NANOBENCH(NODISCARD) std::string value() const; + +private: + int mWidth; + int mPrecision; + std::string mTitle; + std::string mSuffix; + double mValue; +}; + +// Formats any text as markdown code, escaping backticks. 
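+// E.g. (editor's note, assuming the usual markdown escaping scheme): streaming MarkDownCode("a`b")
+// would print `a``b`, so benchmark names containing backticks stay intact in the markdown table.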
+class MarkDownCode { +public: + explicit MarkDownCode(std::string const& what); + +private: + friend std::ostream& operator<<(std::ostream& os, MarkDownCode const& mdCode); + std::ostream& write(std::ostream& os) const; + + std::string mWhat{}; +}; + +std::ostream& operator<<(std::ostream& os, MarkDownCode const& mdCode); + +} // namespace fmt +} // namespace detail +} // namespace nanobench +} // namespace ankerl + +// implementation ///////////////////////////////////////////////////////////////////////////////// + +namespace ankerl { +namespace nanobench { + +// NOLINTNEXTLINE(readability-function-cognitive-complexity) +void render(char const* mustacheTemplate, std::vector const& results, std::ostream& out) { + detail::fmt::StreamStateRestorer const restorer(out); + + out.precision(std::numeric_limits::digits10); + auto nodes = templates::parseMustacheTemplate(&mustacheTemplate); + + for (auto const& n : nodes) { + ANKERL_NANOBENCH_LOG("n.type=" << static_cast(n.type)); + switch (n.type) { + case templates::Node::Type::content: + out.write(n.begin, std::distance(n.begin, n.end)); + break; + + case templates::Node::Type::inverted_section: + throw std::runtime_error("unknown list '" + std::string(n.begin, n.end) + "'"); + + case templates::Node::Type::section: + if (n == "result") { + const size_t nbResults = results.size(); + for (size_t i = 0; i < nbResults; ++i) { + generateResult(n.children, i, results, out); + } + } else if (n == "measurement") { + if (results.size() != 1) { + throw std::runtime_error( + "render: can only use section 'measurement' here if there is a single result, but there are " + + detail::fmt::to_s(results.size())); + } + // when we only have a single result, we can immediately go into its measurement. + auto const& r = results.front(); + for (size_t i = 0; i < r.size(); ++i) { + generateResultMeasurement(n.children, i, r, out); + } + } else { + throw std::runtime_error("render: unknown section '" + std::string(n.begin, n.end) + "'"); + } + break; + + case templates::Node::Type::tag: + if (results.size() == 1) { + // result & config are both supported there + generateResultTag(n, results.front(), out); + } else { + // This just uses the last result's config. 
+ if (!generateConfigTag(n, results.back().config(), out)) { + throw std::runtime_error("unknown tag '" + std::string(n.begin, n.end) + "'"); + } + } + break; + } + } +} + +void render(std::string const& mustacheTemplate, std::vector const& results, std::ostream& out) { + render(mustacheTemplate.c_str(), results, out); +} + +void render(char const* mustacheTemplate, const Bench& bench, std::ostream& out) { + render(mustacheTemplate, bench.results(), out); +} + +void render(std::string const& mustacheTemplate, const Bench& bench, std::ostream& out) { + render(mustacheTemplate.c_str(), bench.results(), out); +} + +namespace detail { + +PerformanceCounters& performanceCounters() { +# if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wexit-time-destructors" +# endif + static PerformanceCounters pc; +# if defined(__clang__) +# pragma clang diagnostic pop +# endif + return pc; +} + +// Windows version of doNotOptimizeAway +// see https://github.com/google/benchmark/blob/v1.7.1/include/benchmark/benchmark.h#L514 +// see https://github.com/facebook/folly/blob/v2023.01.30.00/folly/lang/Hint-inl.h#L54-L58 +// see https://learn.microsoft.com/en-us/cpp/preprocessor/optimize +# if defined(_MSC_VER) +# pragma optimize("", off) +void doNotOptimizeAwaySink(void const*) {} +# pragma optimize("", on) +# endif + +template +T parseFile(std::string const& filename, bool* fail) { + std::ifstream fin(filename); // NOLINT(misc-const-correctness) + T num{}; + fin >> num; + if (fail != nullptr) { + *fail = fin.fail(); + } + return num; +} + +char const* getEnv(char const* name) { +# if defined(_MSC_VER) +# pragma warning(push) +# pragma warning(disable : 4996) // getenv': This function or variable may be unsafe. +# endif + return std::getenv(name); // NOLINT(concurrency-mt-unsafe) +# if defined(_MSC_VER) +# pragma warning(pop) +# endif +} + +bool isEndlessRunning(std::string const& name) { + auto const* const endless = getEnv("NANOBENCH_ENDLESS"); + return nullptr != endless && endless == name; +} + +// True when environment variable NANOBENCH_SUPPRESS_WARNINGS is either not set at all, or set to "0" +bool isWarningsEnabled() { + auto const* const suppression = getEnv("NANOBENCH_SUPPRESS_WARNINGS"); + return nullptr == suppression || suppression == std::string("0"); +} + +void gatherStabilityInformation(std::vector& warnings, std::vector& recommendations) { + warnings.clear(); + recommendations.clear(); + +# if defined(DEBUG) + warnings.emplace_back("DEBUG defined"); + bool const recommendCheckFlags = true; +# else + bool const recommendCheckFlags = false; +# endif + + bool recommendPyPerf = false; +# if defined(__linux__) + auto nprocs = sysconf(_SC_NPROCESSORS_CONF); + if (nprocs <= 0) { + warnings.emplace_back("couldn't figure out number of processors - no governor, turbo check possible"); + } else { + // check frequency scaling + for (long id = 0; id < nprocs; ++id) { + auto idStr = detail::fmt::to_s(static_cast(id)); + auto sysCpu = "/sys/devices/system/cpu/cpu" + idStr; + auto minFreq = parseFile(sysCpu + "/cpufreq/scaling_min_freq", nullptr); + auto maxFreq = parseFile(sysCpu + "/cpufreq/scaling_max_freq", nullptr); + if (minFreq != maxFreq) { + auto minMHz = d(minFreq) / 1000.0; + auto maxMHz = d(maxFreq) / 1000.0; + warnings.emplace_back("CPU frequency scaling enabled: CPU " + idStr + " between " + + detail::fmt::Number(1, 1, minMHz).to_s() + " and " + detail::fmt::Number(1, 1, maxMHz).to_s() + + " MHz"); + recommendPyPerf = true; + break; + } + } + + auto fail = 
false; + auto currentGovernor = parseFile("/sys/devices/system/cpu/cpu0/cpufreq/scaling_governor", &fail); + if (!fail && "performance" != currentGovernor) { + warnings.emplace_back("CPU governor is '" + currentGovernor + "' but should be 'performance'"); + recommendPyPerf = true; + } + + auto noTurbo = parseFile("/sys/devices/system/cpu/intel_pstate/no_turbo", &fail); + if (!fail && noTurbo == 0) { + warnings.emplace_back("Turbo is enabled, CPU frequency will fluctuate"); + recommendPyPerf = true; + } + } +# endif + + if (recommendCheckFlags) { + recommendations.emplace_back("Make sure you compile for Release"); + } + if (recommendPyPerf) { + recommendations.emplace_back("Use 'pyperf system tune' before benchmarking. See https://github.com/psf/pyperf"); + } +} + +void printStabilityInformationOnce(std::ostream* outStream) { + static bool shouldPrint = true; + if (shouldPrint && (nullptr != outStream) && isWarningsEnabled()) { + auto& os = *outStream; + shouldPrint = false; + std::vector warnings; + std::vector recommendations; + gatherStabilityInformation(warnings, recommendations); + if (warnings.empty()) { + return; + } + + os << "Warning, results might be unstable:" << std::endl; + for (auto const& w : warnings) { + os << "* " << w << std::endl; + } + + os << std::endl << "Recommendations" << std::endl; + for (auto const& r : recommendations) { + os << "* " << r << std::endl; + } + } +} + +// remembers the last table settings used. When it changes, a new table header is automatically written for the new entry. +uint64_t& singletonHeaderHash() noexcept { + static uint64_t sHeaderHash{}; + return sHeaderHash; +} + +ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined") +inline uint64_t hash_combine(uint64_t seed, uint64_t val) { + return seed ^ (val + UINT64_C(0x9e3779b9) + (seed << 6U) + (seed >> 2U)); +} + +// determines resolution of the given clock. This is done by measuring multiple times and returning the minimum time difference. 
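+// Editor's note, a worked example of how this feeds the per-epoch target (numbers from the docs above):
+// with a measured resolution of 20ns and clockResolutionMultiple() == 1000, each epoch aims for
+// 20ns * 1000 = 20us of runtime, clamped to [minEpochTime(), maxEpochTime()] in IterationLogic::Impl below.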
+Clock::duration calcClockResolution(size_t numEvaluations) noexcept { + auto bestDuration = Clock::duration::max(); + Clock::time_point tBegin; + Clock::time_point tEnd; + for (size_t i = 0; i < numEvaluations; ++i) { + tBegin = Clock::now(); + do { + tEnd = Clock::now(); + } while (tBegin == tEnd); + bestDuration = (std::min)(bestDuration, tEnd - tBegin); + } + return bestDuration; +} + +// Calculates clock resolution once, and remembers the result +Clock::duration clockResolution() noexcept { + static Clock::duration const sResolution = calcClockResolution(20); + return sResolution; +} + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +struct IterationLogic::Impl { + enum class State { warmup, upscaling_runtime, measuring, endless }; + + explicit Impl(Bench const& bench) + : mBench(bench) + , mResult(bench.config()) { + printStabilityInformationOnce(mBench.output()); + + // determine target runtime per epoch + mTargetRuntimePerEpoch = detail::clockResolution() * mBench.clockResolutionMultiple(); + if (mTargetRuntimePerEpoch > mBench.maxEpochTime()) { + mTargetRuntimePerEpoch = mBench.maxEpochTime(); + } + if (mTargetRuntimePerEpoch < mBench.minEpochTime()) { + mTargetRuntimePerEpoch = mBench.minEpochTime(); + } + + if (isEndlessRunning(mBench.name())) { + std::cerr << "NANOBENCH_ENDLESS set: running '" << mBench.name() << "' endlessly" << std::endl; + mNumIters = (std::numeric_limits::max)(); + mState = State::endless; + } else if (0 != mBench.warmup()) { + mNumIters = mBench.warmup(); + mState = State::warmup; + } else if (0 != mBench.epochIterations()) { + // exact number of iterations + mNumIters = mBench.epochIterations(); + mState = State::measuring; + } else { + mNumIters = mBench.minEpochIterations(); + mState = State::upscaling_runtime; + } + } + + // directly calculates new iters based on elapsed&iters, and adds a 10% noise. Makes sure we don't underflow. + ANKERL_NANOBENCH(NODISCARD) uint64_t calcBestNumIters(std::chrono::nanoseconds elapsed, uint64_t iters) noexcept { + auto doubleElapsed = d(elapsed); + auto doubleTargetRuntimePerEpoch = d(mTargetRuntimePerEpoch); + auto doubleNewIters = doubleTargetRuntimePerEpoch / doubleElapsed * d(iters); + + auto doubleMinEpochIters = d(mBench.minEpochIterations()); + if (doubleNewIters < doubleMinEpochIters) { + doubleNewIters = doubleMinEpochIters; + } + doubleNewIters *= 1.0 + 0.2 * mRng.uniform01(); + + // +0.5 for correct rounding when casting + // NOLINTNEXTLINE(bugprone-incorrect-roundings) + return static_cast(doubleNewIters + 0.5); + } + + ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined") void upscale(std::chrono::nanoseconds elapsed) { + if (elapsed * 10 < mTargetRuntimePerEpoch) { + // we are far below the target runtime. Multiply iterations by 10 (with overflow check) + if (mNumIters * 10 < mNumIters) { + // overflow :-( + showResult("iterations overflow. Maybe your code got optimized away?"); + mNumIters = 0; + return; + } + mNumIters *= 10; + } else { + mNumIters = calcBestNumIters(elapsed, mNumIters); + } + } + + void add(std::chrono::nanoseconds elapsed, PerformanceCounters const& pc) noexcept { +# if defined(ANKERL_NANOBENCH_LOG_ENABLED) + auto oldIters = mNumIters; +# endif + + switch (mState) { + case State::warmup: + if (isCloseEnoughForMeasurements(elapsed)) { + // if elapsed is close enough, we can skip upscaling and go right to measurements + // still, we don't add the result to the measurements. 
+ mState = State::measuring; + mNumIters = calcBestNumIters(elapsed, mNumIters); + } else { + // not close enough: switch to upscaling + mState = State::upscaling_runtime; + upscale(elapsed); + } + break; + + case State::upscaling_runtime: + if (isCloseEnoughForMeasurements(elapsed)) { + // if we are close enough, add measurement and switch to always measuring + mState = State::measuring; + mTotalElapsed += elapsed; + mTotalNumIters += mNumIters; + mResult.add(elapsed, mNumIters, pc); + mNumIters = calcBestNumIters(mTotalElapsed, mTotalNumIters); + } else { + upscale(elapsed); + } + break; + + case State::measuring: + // just add measurements - no questions asked. Even when runtime is low. But we can't ignore + // that fluctuation, or else we would bias the result + mTotalElapsed += elapsed; + mTotalNumIters += mNumIters; + mResult.add(elapsed, mNumIters, pc); + if (0 != mBench.epochIterations()) { + mNumIters = mBench.epochIterations(); + } else { + mNumIters = calcBestNumIters(mTotalElapsed, mTotalNumIters); + } + break; + + case State::endless: + mNumIters = (std::numeric_limits::max)(); + break; + } + + if (static_cast(mResult.size()) == mBench.epochs()) { + // we got all the results that we need, finish it + showResult(""); + mNumIters = 0; + } + + ANKERL_NANOBENCH_LOG(mBench.name() << ": " << detail::fmt::Number(20, 3, d(elapsed.count())) << " elapsed, " + << detail::fmt::Number(20, 3, d(mTargetRuntimePerEpoch.count())) << " target. oldIters=" + << oldIters << ", mNumIters=" << mNumIters << ", mState=" << static_cast(mState)); + } + + // NOLINTNEXTLINE(readability-function-cognitive-complexity) + void showResult(std::string const& errorMessage) const { + ANKERL_NANOBENCH_LOG(errorMessage); + + if (mBench.output() != nullptr) { + // prepare column data /////// + std::vector columns; + + auto rMedian = mResult.median(Result::Measure::elapsed); + + if (mBench.relative()) { + double d = 100.0; + if (!mBench.results().empty()) { + d = rMedian <= 0.0 ? 0.0 : mBench.results().front().median(Result::Measure::elapsed) / rMedian * 100.0; + } + columns.emplace_back(11, 1, "relative", "%", d); + } + + if (mBench.complexityN() > 0) { + columns.emplace_back(14, 0, "complexityN", "", mBench.complexityN()); + } + + columns.emplace_back(22, 2, mBench.timeUnitName() + "/" + mBench.unit(), "", + rMedian / (mBench.timeUnit().count() * mBench.batch())); + columns.emplace_back(22, 2, mBench.unit() + "/s", "", rMedian <= 0.0 ? 0.0 : mBench.batch() / rMedian); + + double const rErrorMedian = mResult.medianAbsolutePercentError(Result::Measure::elapsed); + columns.emplace_back(10, 1, "err%", "%", rErrorMedian * 100.0); + + double rInsMedian = -1.0; + if (mBench.performanceCounters() && mResult.has(Result::Measure::instructions)) { + rInsMedian = mResult.median(Result::Measure::instructions); + columns.emplace_back(18, 2, "ins/" + mBench.unit(), "", rInsMedian / mBench.batch()); + } + + double rCycMedian = -1.0; + if (mBench.performanceCounters() && mResult.has(Result::Measure::cpucycles)) { + rCycMedian = mResult.median(Result::Measure::cpucycles); + columns.emplace_back(18, 2, "cyc/" + mBench.unit(), "", rCycMedian / mBench.batch()); + } + if (rInsMedian > 0.0 && rCycMedian > 0.0) { + columns.emplace_back(9, 3, "IPC", "", rCycMedian <= 0.0 ? 
0.0 : rInsMedian / rCycMedian); + } + if (mBench.performanceCounters() && mResult.has(Result::Measure::branchinstructions)) { + double const rBraMedian = mResult.median(Result::Measure::branchinstructions); + columns.emplace_back(17, 2, "bra/" + mBench.unit(), "", rBraMedian / mBench.batch()); + if (mResult.has(Result::Measure::branchmisses)) { + double p = 0.0; + if (rBraMedian >= 1e-9) { + p = 100.0 * mResult.median(Result::Measure::branchmisses) / rBraMedian; + } + columns.emplace_back(10, 1, "miss%", "%", p); + } + } + + columns.emplace_back(12, 2, "total", "", mResult.sumProduct(Result::Measure::iterations, Result::Measure::elapsed)); + + // write everything + auto& os = *mBench.output(); + + // combine all elements that are relevant for printing the header + uint64_t hash = 0; + hash = hash_combine(std::hash{}(mBench.unit()), hash); + hash = hash_combine(std::hash{}(mBench.title()), hash); + hash = hash_combine(std::hash{}(mBench.timeUnitName()), hash); + hash = hash_combine(std::hash{}(mBench.timeUnit().count()), hash); + hash = hash_combine(std::hash{}(mBench.relative()), hash); + hash = hash_combine(std::hash{}(mBench.performanceCounters()), hash); + + if (hash != singletonHeaderHash()) { + singletonHeaderHash() = hash; + + // no result yet, print header + os << std::endl; + for (auto const& col : columns) { + os << col.title(); + } + os << "| " << mBench.title() << std::endl; + + for (auto const& col : columns) { + os << col.separator(); + } + os << "|:" << std::string(mBench.title().size() + 1U, '-') << std::endl; + } + + if (!errorMessage.empty()) { + for (auto const& col : columns) { + os << col.invalid(); + } + os << "| :boom: " << fmt::MarkDownCode(mBench.name()) << " (" << errorMessage << ')' << std::endl; + } else { + for (auto const& col : columns) { + os << col.value(); + } + os << "| "; + auto showUnstable = isWarningsEnabled() && rErrorMedian >= 0.05; + if (showUnstable) { + os << ":wavy_dash: "; + } + os << fmt::MarkDownCode(mBench.name()); + if (showUnstable) { + auto avgIters = d(mTotalNumIters) / d(mBench.epochs()); + // NOLINTNEXTLINE(bugprone-incorrect-roundings) + auto suggestedIters = static_cast(avgIters * 10 + 0.5); + + os << " (Unstable with ~" << detail::fmt::Number(1, 1, avgIters) + << " iters. Increase `minEpochIterations` to e.g. 
" << suggestedIters << ")"; + } + os << std::endl; + } + } + } + + ANKERL_NANOBENCH(NODISCARD) bool isCloseEnoughForMeasurements(std::chrono::nanoseconds elapsed) const noexcept { + return elapsed * 3 >= mTargetRuntimePerEpoch * 2; + } + + uint64_t mNumIters = 1; // NOLINT(misc-non-private-member-variables-in-classes) + Bench const& mBench; // NOLINT(misc-non-private-member-variables-in-classes) + std::chrono::nanoseconds mTargetRuntimePerEpoch{}; // NOLINT(misc-non-private-member-variables-in-classes) + Result mResult; // NOLINT(misc-non-private-member-variables-in-classes) + Rng mRng{123}; // NOLINT(misc-non-private-member-variables-in-classes) + std::chrono::nanoseconds mTotalElapsed{}; // NOLINT(misc-non-private-member-variables-in-classes) + uint64_t mTotalNumIters = 0; // NOLINT(misc-non-private-member-variables-in-classes) + State mState = State::upscaling_runtime; // NOLINT(misc-non-private-member-variables-in-classes) +}; +ANKERL_NANOBENCH(IGNORE_PADDED_POP) + +IterationLogic::IterationLogic(Bench const& bench) + : mPimpl(new Impl(bench)) {} + +IterationLogic::~IterationLogic() { + delete mPimpl; +} + +uint64_t IterationLogic::numIters() const noexcept { + ANKERL_NANOBENCH_LOG(mPimpl->mBench.name() << ": mNumIters=" << mPimpl->mNumIters); + return mPimpl->mNumIters; +} + +void IterationLogic::add(std::chrono::nanoseconds elapsed, PerformanceCounters const& pc) noexcept { + mPimpl->add(elapsed, pc); +} + +void IterationLogic::moveResultTo(std::vector& results) noexcept { + results.emplace_back(std::move(mPimpl->mResult)); +} + +# if ANKERL_NANOBENCH(PERF_COUNTERS) + +ANKERL_NANOBENCH(IGNORE_PADDED_PUSH) +class LinuxPerformanceCounters { +public: + struct Target { + Target(uint64_t* targetValue_, bool correctMeasuringOverhead_, bool correctLoopOverhead_) + : targetValue(targetValue_) + , correctMeasuringOverhead(correctMeasuringOverhead_) + , correctLoopOverhead(correctLoopOverhead_) {} + + uint64_t* targetValue{}; // NOLINT(misc-non-private-member-variables-in-classes) + bool correctMeasuringOverhead{}; // NOLINT(misc-non-private-member-variables-in-classes) + bool correctLoopOverhead{}; // NOLINT(misc-non-private-member-variables-in-classes) + }; + + LinuxPerformanceCounters() = default; + LinuxPerformanceCounters(LinuxPerformanceCounters const&) = delete; + LinuxPerformanceCounters(LinuxPerformanceCounters&&) = delete; + LinuxPerformanceCounters& operator=(LinuxPerformanceCounters const&) = delete; + LinuxPerformanceCounters& operator=(LinuxPerformanceCounters&&) = delete; + ~LinuxPerformanceCounters(); + + // quick operation + inline void start() {} + + inline void stop() {} + + bool monitor(perf_sw_ids swId, Target target); + bool monitor(perf_hw_id hwId, Target target); + + ANKERL_NANOBENCH(NODISCARD) bool hasError() const noexcept { + return mHasError; + } + + // Just reading data is faster than enable & disabling. + // we subtract data ourselves. 
+
+# if ANKERL_NANOBENCH(PERF_COUNTERS)
+
+ANKERL_NANOBENCH(IGNORE_PADDED_PUSH)
+class LinuxPerformanceCounters {
+public:
+    struct Target {
+        Target(uint64_t* targetValue_, bool correctMeasuringOverhead_, bool correctLoopOverhead_)
+            : targetValue(targetValue_)
+            , correctMeasuringOverhead(correctMeasuringOverhead_)
+            , correctLoopOverhead(correctLoopOverhead_) {}
+
+        uint64_t* targetValue{};         // NOLINT(misc-non-private-member-variables-in-classes)
+        bool correctMeasuringOverhead{}; // NOLINT(misc-non-private-member-variables-in-classes)
+        bool correctLoopOverhead{};      // NOLINT(misc-non-private-member-variables-in-classes)
+    };
+
+    LinuxPerformanceCounters() = default;
+    LinuxPerformanceCounters(LinuxPerformanceCounters const&) = delete;
+    LinuxPerformanceCounters(LinuxPerformanceCounters&&) = delete;
+    LinuxPerformanceCounters& operator=(LinuxPerformanceCounters const&) = delete;
+    LinuxPerformanceCounters& operator=(LinuxPerformanceCounters&&) = delete;
+    ~LinuxPerformanceCounters();
+
+    // quick operation
+    inline void start() {}
+
+    inline void stop() {}
+
+    bool monitor(perf_sw_ids swId, Target target);
+    bool monitor(perf_hw_id hwId, Target target);
+
+    ANKERL_NANOBENCH(NODISCARD) bool hasError() const noexcept {
+        return mHasError;
+    }
+
+    // Just reading data is faster than enable & disabling.
+    // we subtract data ourselves.
+    inline void beginMeasure() {
+        if (mHasError) {
+            return;
+        }
+
+        // NOLINTNEXTLINE(hicpp-signed-bitwise,cppcoreguidelines-pro-type-vararg)
+        mHasError = -1 == ioctl(mFd, PERF_EVENT_IOC_RESET, PERF_IOC_FLAG_GROUP);
+        if (mHasError) {
+            return;
+        }
+
+        // NOLINTNEXTLINE(hicpp-signed-bitwise,cppcoreguidelines-pro-type-vararg)
+        mHasError = -1 == ioctl(mFd, PERF_EVENT_IOC_ENABLE, PERF_IOC_FLAG_GROUP);
+    }
+
+    inline void endMeasure() {
+        if (mHasError) {
+            return;
+        }
+
+        // NOLINTNEXTLINE(hicpp-signed-bitwise,cppcoreguidelines-pro-type-vararg)
+        mHasError = (-1 == ioctl(mFd, PERF_EVENT_IOC_DISABLE, PERF_IOC_FLAG_GROUP));
+        if (mHasError) {
+            return;
+        }
+
+        auto const numBytes = sizeof(uint64_t) * mCounters.size();
+        auto ret = read(mFd, mCounters.data(), numBytes);
+        mHasError = ret != static_cast<ssize_t>(numBytes);
+    }
+
+    void updateResults(uint64_t numIters);
+
+    // rounded integer division
+    template <typename T>
+    static inline T divRounded(T a, T divisor) {
+        return (a + divisor / 2) / divisor;
+    }
+
+    ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined")
+    static inline uint32_t mix(uint32_t x) noexcept {
+        x ^= x << 13U;
+        x ^= x >> 17U;
+        x ^= x << 5U;
+        return x;
+    }
+
+    template <typename Op>
+    ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined")
+    void calibrate(Op&& op) {
+        // clear current calibration data
+        for (auto& v : mCalibratedOverhead) {
+            v = UINT64_C(0);
+        }
+
+        // create new calibration data
+        auto newCalibration = mCalibratedOverhead;
+        for (auto& v : newCalibration) {
+            v = (std::numeric_limits<uint64_t>::max)();
+        }
+        for (size_t iter = 0; iter < 100; ++iter) {
+            beginMeasure();
+            op();
+            endMeasure();
+            if (mHasError) {
+                return;
+            }
+
+            for (size_t i = 0; i < newCalibration.size(); ++i) {
+                auto diff = mCounters[i];
+                if (newCalibration[i] > diff) {
+                    newCalibration[i] = diff;
+                }
+            }
+        }
+
+        mCalibratedOverhead = std::move(newCalibration);
+
+        {
+            // calibrate loop overhead. For branches & instructions this makes sense, not so much for everything else like cycles.
+            // marsaglia's xorshift: mov, sal/shr, xor. Times 3.
+            // This has the nice property that the compiler doesn't seem to be able to optimize multiple calls any further.
+            // see https://godbolt.org/z/49RVQ5
+            uint64_t const numIters = 100000U + (std::random_device{}() & 3U);
+            uint64_t n = numIters;
+            uint32_t x = 1234567;
+
+            beginMeasure();
+            while (n-- > 0) {
+                x = mix(x);
+            }
+            endMeasure();
+            detail::doNotOptimizeAway(x);
+            auto measure1 = mCounters;
+
+            n = numIters;
+            beginMeasure();
+            while (n-- > 0) {
+                // we now run *twice* so we can easily calculate the overhead
+                x = mix(x);
+                x = mix(x);
+            }
+            endMeasure();
+            detail::doNotOptimizeAway(x);
+            auto measure2 = mCounters;
+
+            for (size_t i = 0; i < mCounters.size(); ++i) {
+                // factor 2 because we have two instructions per loop
+                auto m1 = measure1[i] > mCalibratedOverhead[i] ? measure1[i] - mCalibratedOverhead[i] : 0;
+                auto m2 = measure2[i] > mCalibratedOverhead[i] ? measure2[i] - mCalibratedOverhead[i] : 0;
+                auto overhead = m1 * 2 > m2 ? m1 * 2 - m2 : 0;
+
+                mLoopOverhead[i] = divRounded(overhead, numIters);
+            }
+        }
+    }
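+
+    // Why 2*m1 - m2 works: per pass, m1 measures loop plus one mix() and m2
+    // measures loop plus two mix(), both already reduced by the calibrated
+    // measuring overhead. Doubling m1 and subtracting m2 cancels the mix()
+    // cost and leaves the loop-only share, which divRounded() then
+    // normalizes per iteration.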
+
+private:
+    bool monitor(uint32_t type, uint64_t eventid, Target target);
+
+    std::map<uint64_t, Target> mIdToTarget{};
+
+    // start with minimum size of 3 for read_format
+    std::vector<uint64_t> mCounters{3};
+    std::vector<uint64_t> mCalibratedOverhead{3};
+    std::vector<uint64_t> mLoopOverhead{3};
+
+    uint64_t mTimeEnabledNanos = 0;
+    uint64_t mTimeRunningNanos = 0;
+    int mFd = -1;
+    bool mHasError = false;
+};
+ANKERL_NANOBENCH(IGNORE_PADDED_POP)
+
+LinuxPerformanceCounters::~LinuxPerformanceCounters() {
+    if (-1 != mFd) {
+        close(mFd);
+    }
+}
+
+bool LinuxPerformanceCounters::monitor(perf_sw_ids swId, LinuxPerformanceCounters::Target target) {
+    return monitor(PERF_TYPE_SOFTWARE, swId, target);
+}
+
+bool LinuxPerformanceCounters::monitor(perf_hw_id hwId, LinuxPerformanceCounters::Target target) {
+    return monitor(PERF_TYPE_HARDWARE, hwId, target);
+}
+
+// overflow is ok, it's checked
+ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined")
+void LinuxPerformanceCounters::updateResults(uint64_t numIters) {
+    // clear old data
+    for (auto& id_value : mIdToTarget) {
+        *id_value.second.targetValue = UINT64_C(0);
+    }
+
+    if (mHasError) {
+        return;
+    }
+
+    mTimeEnabledNanos = mCounters[1] - mCalibratedOverhead[1];
+    mTimeRunningNanos = mCounters[2] - mCalibratedOverhead[2];
+
+    for (uint64_t i = 0; i < mCounters[0]; ++i) {
+        auto idx = static_cast<size_t>(3 + i * 2 + 0);
+        auto id = mCounters[idx + 1U];
+
+        auto it = mIdToTarget.find(id);
+        if (it != mIdToTarget.end()) {
+
+            auto& tgt = it->second;
+            *tgt.targetValue = mCounters[idx];
+            if (tgt.correctMeasuringOverhead) {
+                if (*tgt.targetValue >= mCalibratedOverhead[idx]) {
+                    *tgt.targetValue -= mCalibratedOverhead[idx];
+                } else {
+                    *tgt.targetValue = 0U;
+                }
+            }
+            if (tgt.correctLoopOverhead) {
+                auto correctionVal = mLoopOverhead[idx] * numIters;
+                if (*tgt.targetValue >= correctionVal) {
+                    *tgt.targetValue -= correctionVal;
+                } else {
+                    *tgt.targetValue = 0U;
+                }
+            }
+        }
+    }
+}
+
+bool LinuxPerformanceCounters::monitor(uint32_t type, uint64_t eventid, Target target) {
+    *target.targetValue = (std::numeric_limits<uint64_t>::max)();
+    if (mHasError) {
+        return false;
+    }
+
+    auto pea = perf_event_attr();
+    std::memset(&pea, 0, sizeof(perf_event_attr));
+    pea.type = type;
+    pea.size = sizeof(perf_event_attr);
+    pea.config = eventid;
+    pea.disabled = 1; // start counter as disabled
+    pea.exclude_kernel = 1;
+    pea.exclude_hv = 1;
+
+    // NOLINTNEXTLINE(hicpp-signed-bitwise)
+    pea.read_format = PERF_FORMAT_GROUP | PERF_FORMAT_ID | PERF_FORMAT_TOTAL_TIME_ENABLED | PERF_FORMAT_TOTAL_TIME_RUNNING;
+
+    const int pid = 0;  // the current process
+    const int cpu = -1; // all CPUs
+# if defined(PERF_FLAG_FD_CLOEXEC) // since Linux 3.14
+    const unsigned long flags = PERF_FLAG_FD_CLOEXEC;
+# else
+    const unsigned long flags = 0;
+# endif
+
+    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-vararg)
+    auto fd = static_cast<int>(syscall(__NR_perf_event_open, &pea, pid, cpu, mFd, flags));
+    if (-1 == fd) {
+        return false;
+    }
+    if (-1 == mFd) {
+        // first call: set to fd, and use this from now on
+        mFd = fd;
+    }
+    uint64_t id = 0;
+    // NOLINTNEXTLINE(hicpp-signed-bitwise,cppcoreguidelines-pro-type-vararg)
+    if (-1 == ioctl(fd, PERF_EVENT_IOC_ID, &id)) {
+        // couldn't get id
+        return false;
+    }
+
+    // insert into map, rely on the fact that map's references are constant.
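+    // With the read_format chosen above, a read() on the group leader fd
+    // returns (see perf_event_open(2)):
+    //     { nr, time_enabled, time_running, { value, id } * nr }
+    // That layout is why the counter buffers hold 3 + 2*nr uint64_t entries,
+    // why mCounters[1]/[2] are the enabled/running times in updateResults(),
+    // and why values and ids alternate starting at index 3.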
+ mIdToTarget.emplace(id, target); + + // prepare readformat with the correct size (after the insert) + auto size = 3 + 2 * mIdToTarget.size(); + mCounters.resize(size); + mCalibratedOverhead.resize(size); + mLoopOverhead.resize(size); + + return true; +} + +PerformanceCounters::PerformanceCounters() + : mPc(new LinuxPerformanceCounters()) + , mVal() + , mHas() { + + // HW events + mHas.cpuCycles = mPc->monitor(PERF_COUNT_HW_REF_CPU_CYCLES, LinuxPerformanceCounters::Target(&mVal.cpuCycles, true, false)); + if (!mHas.cpuCycles) { + // Fallback to cycles counter, reference cycles not available in many systems. + mHas.cpuCycles = mPc->monitor(PERF_COUNT_HW_CPU_CYCLES, LinuxPerformanceCounters::Target(&mVal.cpuCycles, true, false)); + } + mHas.instructions = mPc->monitor(PERF_COUNT_HW_INSTRUCTIONS, LinuxPerformanceCounters::Target(&mVal.instructions, true, true)); + mHas.branchInstructions = + mPc->monitor(PERF_COUNT_HW_BRANCH_INSTRUCTIONS, LinuxPerformanceCounters::Target(&mVal.branchInstructions, true, false)); + mHas.branchMisses = mPc->monitor(PERF_COUNT_HW_BRANCH_MISSES, LinuxPerformanceCounters::Target(&mVal.branchMisses, true, false)); + // mHas.branchMisses = false; + + // SW events + mHas.pageFaults = mPc->monitor(PERF_COUNT_SW_PAGE_FAULTS, LinuxPerformanceCounters::Target(&mVal.pageFaults, true, false)); + mHas.contextSwitches = + mPc->monitor(PERF_COUNT_SW_CONTEXT_SWITCHES, LinuxPerformanceCounters::Target(&mVal.contextSwitches, true, false)); + + mPc->start(); + mPc->calibrate([] { + auto before = ankerl::nanobench::Clock::now(); + auto after = ankerl::nanobench::Clock::now(); + (void)before; + (void)after; + }); + + if (mPc->hasError()) { + // something failed, don't monitor anything. + mHas = PerfCountSet{}; + } +} + +PerformanceCounters::~PerformanceCounters() { + // no need to check for nullptr, delete nullptr has no effect + delete mPc; +} + +void PerformanceCounters::beginMeasure() { + mPc->beginMeasure(); +} + +void PerformanceCounters::endMeasure() { + mPc->endMeasure(); +} + +void PerformanceCounters::updateResults(uint64_t numIters) { + mPc->updateResults(numIters); +} + +# else + +PerformanceCounters::PerformanceCounters() = default; +PerformanceCounters::~PerformanceCounters() = default; +void PerformanceCounters::beginMeasure() {} +void PerformanceCounters::endMeasure() {} +void PerformanceCounters::updateResults(uint64_t) {} + +# endif + +ANKERL_NANOBENCH(NODISCARD) PerfCountSet const& PerformanceCounters::val() const noexcept { + return mVal; +} +ANKERL_NANOBENCH(NODISCARD) PerfCountSet const& PerformanceCounters::has() const noexcept { + return mHas; +} + +// formatting utilities +namespace fmt { + +// adds thousands separator to numbers +NumSep::NumSep(char sep) + : mSep(sep) {} + +char NumSep::do_thousands_sep() const { + return mSep; +} + +std::string NumSep::do_grouping() const { + return "\003"; +} + +// RAII to save & restore a stream's state +StreamStateRestorer::StreamStateRestorer(std::ostream& s) + : mStream(s) + , mLocale(s.getloc()) + , mPrecision(s.precision()) + , mWidth(s.width()) + , mFill(s.fill()) + , mFmtFlags(s.flags()) {} + +StreamStateRestorer::~StreamStateRestorer() { + restore(); +} + +// sets back all stream info that we remembered at construction +void StreamStateRestorer::restore() { + mStream.imbue(mLocale); + mStream.precision(mPrecision); + mStream.width(mWidth); + mStream.fill(mFill); + mStream.flags(mFmtFlags); +} + +Number::Number(int width, int precision, int64_t value) + : mWidth(width) + , mPrecision(precision) + , 
mValue(d(value)) {}
+
+Number::Number(int width, int precision, double value)
+    : mWidth(width)
+    , mPrecision(precision)
+    , mValue(value) {}
+
+std::ostream& Number::write(std::ostream& os) const {
+    StreamStateRestorer const restorer(os);
+    os.imbue(std::locale(os.getloc(), new NumSep(',')));
+    os << std::setw(mWidth) << std::setprecision(mPrecision) << std::fixed << mValue;
+    return os;
+}
+
+std::string Number::to_s() const {
+    std::stringstream ss;
+    write(ss);
+    return ss.str();
+}
+
+std::string to_s(uint64_t n) {
+    std::string str;
+    do {
+        str += static_cast<char>('0' + static_cast<char>(n % 10));
+        n /= 10;
+    } while (n != 0);
+    std::reverse(str.begin(), str.end());
+    return str;
+}
+
+std::ostream& operator<<(std::ostream& os, Number const& n) {
+    return n.write(os);
+}
+
+MarkDownColumn::MarkDownColumn(int w, int prec, std::string tit, std::string suff, double val) noexcept
+    : mWidth(w)
+    , mPrecision(prec)
+    , mTitle(std::move(tit))
+    , mSuffix(std::move(suff))
+    , mValue(val) {}
+
+std::string MarkDownColumn::title() const {
+    std::stringstream ss;
+    ss << '|' << std::setw(mWidth - 2) << std::right << mTitle << ' ';
+    return ss.str();
+}
+
+std::string MarkDownColumn::separator() const {
+    std::string sep(static_cast<size_t>(mWidth), '-');
+    sep.front() = '|';
+    sep.back() = ':';
+    return sep;
+}
+
+std::string MarkDownColumn::invalid() const {
+    std::string sep(static_cast<size_t>(mWidth), ' ');
+    sep.front() = '|';
+    sep[sep.size() - 2] = '-';
+    return sep;
+}
+
+std::string MarkDownColumn::value() const {
+    std::stringstream ss;
+    auto width = mWidth - 2 - static_cast<int>(mSuffix.size());
+    ss << '|' << Number(width, mPrecision, mValue) << mSuffix << ' ';
+    return ss.str();
+}
+
+// Formats any text as markdown code, escaping backticks.
+MarkDownCode::MarkDownCode(std::string const& what) {
+    mWhat.reserve(what.size() + 2);
+    mWhat.push_back('`');
+    for (char const c : what) {
+        mWhat.push_back(c);
+        if ('`' == c) {
+            mWhat.push_back('`');
+        }
+    }
+    mWhat.push_back('`');
+}
+
+std::ostream& MarkDownCode::write(std::ostream& os) const {
+    return os << mWhat;
+}
+
+std::ostream& operator<<(std::ostream& os, MarkDownCode const& mdCode) {
+    return mdCode.write(os);
+}
+} // namespace fmt
+} // namespace detail
+
+// provide implementation here so it's only generated once
+Config::Config() = default;
+Config::~Config() = default;
+Config& Config::operator=(Config const&) = default;
+Config& Config::operator=(Config&&) noexcept(ANKERL_NANOBENCH(NOEXCEPT_STRING_MOVE)) = default;
+Config::Config(Config const&) = default;
+Config::Config(Config&&) noexcept = default;
+
+// provide implementation here so it's only generated once
+Result::~Result() = default;
+Result& Result::operator=(Result const&) = default;
+Result& Result::operator=(Result&&) noexcept(ANKERL_NANOBENCH(NOEXCEPT_STRING_MOVE)) = default;
+Result::Result(Result const&) = default;
+Result::Result(Result&&) noexcept = default;
+
+namespace detail {
+template <typename T>
+inline constexpr typename std::underlying_type<T>::type u(T val) noexcept {
+    return static_cast<typename std::underlying_type<T>::type>(val);
+}
+} // namespace detail
+
+// Result returned after a benchmark has finished. Can be used as a baseline for relative().
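+// (With relative(true), the first Result in a table acts as the 100% baseline
+// that later entries are compared against.)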
+Result::Result(Config benchmarkConfig) + : mConfig(std::move(benchmarkConfig)) + , mNameToMeasurements{detail::u(Result::Measure::_size)} {} + +void Result::add(Clock::duration totalElapsed, uint64_t iters, detail::PerformanceCounters const& pc) { + using detail::d; + using detail::u; + + double const dIters = d(iters); + mNameToMeasurements[u(Result::Measure::iterations)].push_back(dIters); + + mNameToMeasurements[u(Result::Measure::elapsed)].push_back(d(totalElapsed) / dIters); + if (pc.has().pageFaults) { + mNameToMeasurements[u(Result::Measure::pagefaults)].push_back(d(pc.val().pageFaults) / dIters); + } + if (pc.has().cpuCycles) { + mNameToMeasurements[u(Result::Measure::cpucycles)].push_back(d(pc.val().cpuCycles) / dIters); + } + if (pc.has().contextSwitches) { + mNameToMeasurements[u(Result::Measure::contextswitches)].push_back(d(pc.val().contextSwitches) / dIters); + } + if (pc.has().instructions) { + mNameToMeasurements[u(Result::Measure::instructions)].push_back(d(pc.val().instructions) / dIters); + } + if (pc.has().branchInstructions) { + double branchInstructions = 0.0; + // correcting branches: remove branch introduced by the while (...) loop for each iteration. + if (pc.val().branchInstructions > iters + 1U) { + branchInstructions = d(pc.val().branchInstructions - (iters + 1U)); + } + mNameToMeasurements[u(Result::Measure::branchinstructions)].push_back(branchInstructions / dIters); + + if (pc.has().branchMisses) { + // correcting branch misses + double branchMisses = d(pc.val().branchMisses); + if (branchMisses > branchInstructions) { + // can't have branch misses when there were branches... + branchMisses = branchInstructions; + } + + // assuming at least one missed branch for the loop + branchMisses -= 1.0; + if (branchMisses < 1.0) { + branchMisses = 1.0; + } + mNameToMeasurements[u(Result::Measure::branchmisses)].push_back(branchMisses / dIters); + } + } +} + +Config const& Result::config() const noexcept { + return mConfig; +} + +inline double calcMedian(std::vector& data) { + if (data.empty()) { + return 0.0; + } + std::sort(data.begin(), data.end()); + + auto midIdx = data.size() / 2U; + if (1U == (data.size() & 1U)) { + return data[midIdx]; + } + return (data[midIdx - 1U] + data[midIdx]) / 2U; +} + +double Result::median(Measure m) const { + // create a copy so we can sort + auto data = mNameToMeasurements[detail::u(m)]; + return calcMedian(data); +} + +double Result::average(Measure m) const { + using detail::d; + auto const& data = mNameToMeasurements[detail::u(m)]; + if (data.empty()) { + return 0.0; + } + + // create a copy so we can sort + return sum(m) / d(data.size()); +} + +double Result::medianAbsolutePercentError(Measure m) const { + // create copy + auto data = mNameToMeasurements[detail::u(m)]; + + // calculates MdAPE which is the median of percentage error + // see https://support.numxl.com/hc/en-us/articles/115001223503-MdAPE-Median-Absolute-Percentage-Error + auto med = calcMedian(data); + + // transform the data to absolute error + for (auto& x : data) { + x = (x - med) / x; + if (x < 0) { + x = -x; + } + } + return calcMedian(data); +} + +double Result::sum(Measure m) const noexcept { + auto const& data = mNameToMeasurements[detail::u(m)]; + return std::accumulate(data.begin(), data.end(), 0.0); +} + +double Result::sumProduct(Measure m1, Measure m2) const noexcept { + auto const& data1 = mNameToMeasurements[detail::u(m1)]; + auto const& data2 = mNameToMeasurements[detail::u(m2)]; + + if (data1.size() != data2.size()) { + return 0.0; + } + + double 
result = 0.0; + for (size_t i = 0, s = data1.size(); i != s; ++i) { + result += data1[i] * data2[i]; + } + return result; +} + +bool Result::has(Measure m) const noexcept { + return !mNameToMeasurements[detail::u(m)].empty(); +} + +double Result::get(size_t idx, Measure m) const { + auto const& data = mNameToMeasurements[detail::u(m)]; + return data.at(idx); +} + +bool Result::empty() const noexcept { + return 0U == size(); +} + +size_t Result::size() const noexcept { + auto const& data = mNameToMeasurements[detail::u(Measure::elapsed)]; + return data.size(); +} + +double Result::minimum(Measure m) const noexcept { + auto const& data = mNameToMeasurements[detail::u(m)]; + if (data.empty()) { + return 0.0; + } + + // here its save to assume that at least one element is there + return *std::min_element(data.begin(), data.end()); +} + +double Result::maximum(Measure m) const noexcept { + auto const& data = mNameToMeasurements[detail::u(m)]; + if (data.empty()) { + return 0.0; + } + + // here its save to assume that at least one element is there + return *std::max_element(data.begin(), data.end()); +} + +std::string const& Result::context(char const* variableName) const { + return mConfig.mContext.at(variableName); +} + +std::string const& Result::context(std::string const& variableName) const { + return mConfig.mContext.at(variableName); +} + +Result::Measure Result::fromString(std::string const& str) { + if (str == "elapsed") { + return Measure::elapsed; + } + if (str == "iterations") { + return Measure::iterations; + } + if (str == "pagefaults") { + return Measure::pagefaults; + } + if (str == "cpucycles") { + return Measure::cpucycles; + } + if (str == "contextswitches") { + return Measure::contextswitches; + } + if (str == "instructions") { + return Measure::instructions; + } + if (str == "branchinstructions") { + return Measure::branchinstructions; + } + if (str == "branchmisses") { + return Measure::branchmisses; + } + // not found, return _size + return Measure::_size; +} + +// Configuration of a microbenchmark. +Bench::Bench() { + mConfig.mOut = &std::cout; +} + +Bench::Bench(Bench&&) noexcept = default; +Bench& Bench::operator=(Bench&&) noexcept(ANKERL_NANOBENCH(NOEXCEPT_STRING_MOVE)) = default; +Bench::Bench(Bench const&) = default; +Bench& Bench::operator=(Bench const&) = default; +Bench::~Bench() noexcept = default; + +double Bench::batch() const noexcept { + return mConfig.mBatch; +} + +double Bench::complexityN() const noexcept { + return mConfig.mComplexityN; +} + +// Set a baseline to compare it to. 100% it is exactly as fast as the baseline, >100% means it is faster than the baseline, <100% +// means it is slower than the baseline. +Bench& Bench::relative(bool isRelativeEnabled) noexcept { + mConfig.mIsRelative = isRelativeEnabled; + return *this; +} +bool Bench::relative() const noexcept { + return mConfig.mIsRelative; +} + +Bench& Bench::performanceCounters(bool showPerformanceCounters) noexcept { + mConfig.mShowPerformanceCounters = showPerformanceCounters; + return *this; +} +bool Bench::performanceCounters() const noexcept { + return mConfig.mShowPerformanceCounters; +} + +// Operation unit. Defaults to "op", could be e.g. "byte" for string processing. +// If u differs from currently set unit, the stored results will be cleared. +// Use singular (byte, not bytes). 
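+// Example: bench.unit("byte").batch(input.size()) then reports per-byte
+// throughput (batch() sets how many units one run of the op processes).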
+Bench& Bench::unit(char const* u) {
+    if (u != mConfig.mUnit) {
+        mResults.clear();
+    }
+    mConfig.mUnit = u;
+    return *this;
+}
+
+Bench& Bench::unit(std::string const& u) {
+    return unit(u.c_str());
+}
+
+std::string const& Bench::unit() const noexcept {
+    return mConfig.mUnit;
+}
+
+Bench& Bench::timeUnit(std::chrono::duration<double> const& tu, std::string const& tuName) {
+    mConfig.mTimeUnit = tu;
+    mConfig.mTimeUnitName = tuName;
+    return *this;
+}
+
+std::string const& Bench::timeUnitName() const noexcept {
+    return mConfig.mTimeUnitName;
+}
+
+std::chrono::duration<double> const& Bench::timeUnit() const noexcept {
+    return mConfig.mTimeUnit;
+}
+
+// If benchmarkTitle differs from currently set title, the stored results will be cleared.
+Bench& Bench::title(const char* benchmarkTitle) {
+    if (benchmarkTitle != mConfig.mBenchmarkTitle) {
+        mResults.clear();
+    }
+    mConfig.mBenchmarkTitle = benchmarkTitle;
+    return *this;
+}
+Bench& Bench::title(std::string const& benchmarkTitle) {
+    if (benchmarkTitle != mConfig.mBenchmarkTitle) {
+        mResults.clear();
+    }
+    mConfig.mBenchmarkTitle = benchmarkTitle;
+    return *this;
+}
+
+std::string const& Bench::title() const noexcept {
+    return mConfig.mBenchmarkTitle;
+}
+
+Bench& Bench::name(const char* benchmarkName) {
+    mConfig.mBenchmarkName = benchmarkName;
+    return *this;
+}
+
+Bench& Bench::name(std::string const& benchmarkName) {
+    mConfig.mBenchmarkName = benchmarkName;
+    return *this;
+}
+
+std::string const& Bench::name() const noexcept {
+    return mConfig.mBenchmarkName;
+}
+
+Bench& Bench::context(char const* variableName, char const* variableValue) {
+    mConfig.mContext[variableName] = variableValue;
+    return *this;
+}
+
+Bench& Bench::context(std::string const& variableName, std::string const& variableValue) {
+    mConfig.mContext[variableName] = variableValue;
+    return *this;
+}
+
+Bench& Bench::clearContext() {
+    mConfig.mContext.clear();
+    return *this;
+}
+
+// Number of epochs to evaluate. The reported result will be the median of evaluation of each epoch.
+Bench& Bench::epochs(size_t numEpochs) noexcept {
+    mConfig.mNumEpochs = numEpochs;
+    return *this;
+}
+size_t Bench::epochs() const noexcept {
+    return mConfig.mNumEpochs;
+}
+
+// Desired evaluation time is a multiple of clock resolution. Default is to be 1000 times above this measurement precision.
+Bench& Bench::clockResolutionMultiple(size_t multiple) noexcept {
+    mConfig.mClockResolutionMultiple = multiple;
+    return *this;
+}
+size_t Bench::clockResolutionMultiple() const noexcept {
+    return mConfig.mClockResolutionMultiple;
+}
+
+// Sets the maximum time each epoch should take. Default is 100ms.
+Bench& Bench::maxEpochTime(std::chrono::nanoseconds t) noexcept {
+    mConfig.mMaxEpochTime = t;
+    return *this;
+}
+std::chrono::nanoseconds Bench::maxEpochTime() const noexcept {
+    return mConfig.mMaxEpochTime;
+}
+
+// Sets the minimum time each epoch should take. Default is 0ms.
+Bench& Bench::minEpochTime(std::chrono::nanoseconds t) noexcept {
+    mConfig.mMinEpochTime = t;
+    return *this;
+}
+std::chrono::nanoseconds Bench::minEpochTime() const noexcept {
+    return mConfig.mMinEpochTime;
+}
+
+Bench& Bench::minEpochIterations(uint64_t numIters) noexcept {
+    mConfig.mMinEpochIterations = (numIters == 0) ? 1 : numIters;
+    return *this;
+}
+uint64_t Bench::minEpochIterations() const noexcept {
+    return mConfig.mMinEpochIterations;
+}
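Taken together, these knobs trade run time against measurement stability. A hedged sketch with illustrative values (`work()` is a hypothetical callable):

    ankerl::nanobench::Bench()
        .title("tuned")
        .epochs(23)                                  // more epochs, sturdier median
        .minEpochIterations(50000)                   // guarantee enough work per epoch
        .maxEpochTime(std::chrono::milliseconds(20)) // converts to nanoseconds implicitly
        .run("work", [&] { work(); });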
+Bench& Bench::epochIterations(uint64_t numIters) noexcept {
+    mConfig.mEpochIterations = numIters;
+    return *this;
+}
+uint64_t Bench::epochIterations() const noexcept {
+    return mConfig.mEpochIterations;
+}
+
+Bench& Bench::warmup(uint64_t numWarmupIters) noexcept {
+    mConfig.mWarmup = numWarmupIters;
+    return *this;
+}
+uint64_t Bench::warmup() const noexcept {
+    return mConfig.mWarmup;
+}
+
+Bench& Bench::config(Config const& benchmarkConfig) {
+    mConfig = benchmarkConfig;
+    return *this;
+}
+Config const& Bench::config() const noexcept {
+    return mConfig;
+}
+
+Bench& Bench::output(std::ostream* outstream) noexcept {
+    mConfig.mOut = outstream;
+    return *this;
+}
+
+ANKERL_NANOBENCH(NODISCARD) std::ostream* Bench::output() const noexcept {
+    return mConfig.mOut;
+}
+
+std::vector<Result> const& Bench::results() const noexcept {
+    return mResults;
+}
+
+Bench& Bench::render(char const* templateContent, std::ostream& os) {
+    ::ankerl::nanobench::render(templateContent, *this, os);
+    return *this;
+}
+
+Bench& Bench::render(std::string const& templateContent, std::ostream& os) {
+    ::ankerl::nanobench::render(templateContent, *this, os);
+    return *this;
+}
+
+std::vector<BigO> Bench::complexityBigO() const {
+    std::vector<BigO> bigOs;
+    auto rangeMeasure = BigO::collectRangeMeasure(mResults);
+    bigOs.emplace_back("O(1)", rangeMeasure, [](double) {
+        return 1.0;
+    });
+    bigOs.emplace_back("O(n)", rangeMeasure, [](double n) {
+        return n;
+    });
+    bigOs.emplace_back("O(log n)", rangeMeasure, [](double n) {
+        return std::log2(n);
+    });
+    bigOs.emplace_back("O(n log n)", rangeMeasure, [](double n) {
+        return n * std::log2(n);
+    });
+    bigOs.emplace_back("O(n^2)", rangeMeasure, [](double n) {
+        return n * n;
+    });
+    bigOs.emplace_back("O(n^3)", rangeMeasure, [](double n) {
+        return n * n * n;
+    });
+    std::sort(bigOs.begin(), bigOs.end());
+    return bigOs;
+}
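The fit needs several runs at different problem sizes, tagged via complexityN(). A sketch with illustrative sizes:

    ankerl::nanobench::Bench bench;
    for (std::size_t n = 16; n <= 16384; n *= 4) {
        std::vector<std::uint64_t> v(n, 1);
        bench.complexityN(n).run("accumulate", [&] {
            ankerl::nanobench::doNotOptimizeAway(
                std::accumulate(v.begin(), v.end(), std::uint64_t(0)));
        });
    }
    std::cout << bench.complexityBigO() << std::endl; // best fit (lowest rms) first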
+
+Rng::Rng()
+    : mX(0)
+    , mY(0) {
+    std::random_device rd;
+    std::uniform_int_distribution<uint64_t> dist;
+    do {
+        mX = dist(rd);
+        mY = dist(rd);
+    } while (mX == 0 && mY == 0);
+}
+
+ANKERL_NANOBENCH_NO_SANITIZE("integer", "undefined")
+uint64_t splitMix64(uint64_t& state) noexcept {
+    uint64_t z = (state += UINT64_C(0x9e3779b97f4a7c15));
+    z = (z ^ (z >> 30U)) * UINT64_C(0xbf58476d1ce4e5b9);
+    z = (z ^ (z >> 27U)) * UINT64_C(0x94d049bb133111eb);
+    return z ^ (z >> 31U);
+}
+
+// Seeded as described in romu paper (update april 2020)
+Rng::Rng(uint64_t seed) noexcept
+    : mX(splitMix64(seed))
+    , mY(splitMix64(seed)) {
+    for (size_t i = 0; i < 10; ++i) {
+        operator()();
+    }
+}
+
+// only internally used to copy the RNG.
+Rng::Rng(uint64_t x, uint64_t y) noexcept
+    : mX(x)
+    , mY(y) {}
+
+Rng Rng::copy() const noexcept {
+    return Rng{mX, mY};
+}
+
+Rng::Rng(std::vector<uint64_t> const& data)
+    : mX(0)
+    , mY(0) {
+    if (data.size() != 2) {
+        throw std::runtime_error("ankerl::nanobench::Rng::Rng: needed exactly 2 entries in data, but got " +
+                                 detail::fmt::to_s(data.size()));
+    }
+    mX = data[0];
+    mY = data[1];
+}
+
+std::vector<uint64_t> Rng::state() const {
+    std::vector<uint64_t> data(2);
+    data[0] = mX;
+    data[1] = mY;
+    return data;
+}
+
+BigO::RangeMeasure BigO::collectRangeMeasure(std::vector<Result> const& results) {
+    BigO::RangeMeasure rangeMeasure;
+    for (auto const& result : results) {
+        if (result.config().mComplexityN > 0.0) {
+            rangeMeasure.emplace_back(result.config().mComplexityN, result.median(Result::Measure::elapsed));
+        }
+    }
+    return rangeMeasure;
+}
+
+BigO::BigO(std::string bigOName, RangeMeasure const& rangeMeasure)
+    : mName(std::move(bigOName)) {
+
+    // estimate the constant factor
+    double sumRangeMeasure = 0.0;
+    double sumRangeRange = 0.0;
+
+    for (const auto& rm : rangeMeasure) {
+        sumRangeMeasure += rm.first * rm.second;
+        sumRangeRange += rm.first * rm.first;
+    }
+    mConstant = sumRangeMeasure / sumRangeRange;
+
+    // calculate root mean square
+    double err = 0.0;
+    double sumMeasure = 0.0;
+    for (const auto& rm : rangeMeasure) {
+        auto diff = mConstant * rm.first - rm.second;
+        err += diff * diff;
+
+        sumMeasure += rm.second;
+    }
+
+    auto n = detail::d(rangeMeasure.size());
+    auto mean = sumMeasure / n;
+    mNormalizedRootMeanSquare = std::sqrt(err / n) / mean;
+}
+
+BigO::BigO(const char* bigOName, RangeMeasure const& rangeMeasure)
+    : BigO(std::string(bigOName), rangeMeasure) {}
+
+std::string const& BigO::name() const noexcept {
+    return mName;
+}
+
+double BigO::constant() const noexcept {
+    return mConstant;
+}
+
+double BigO::normalizedRootMeanSquare() const noexcept {
+    return mNormalizedRootMeanSquare;
+}
+
+bool BigO::operator<(BigO const& other) const noexcept {
+    return std::tie(mNormalizedRootMeanSquare, mName) < std::tie(other.mNormalizedRootMeanSquare, other.mName);
+}
+
+std::ostream& operator<<(std::ostream& os, BigO const& bigO) {
+    return os << bigO.constant() << " * " << bigO.name() << ", rms=" << bigO.normalizedRootMeanSquare();
+}
+
+std::ostream& operator<<(std::ostream& os, std::vector<BigO> const& bigOs) {
+    detail::fmt::StreamStateRestorer const restorer(os);
+    os << std::endl << "|   coefficient |   err% | complexity" << std::endl << "|--------------:|-------:|------------" << std::endl;
+    for (auto const& bigO : bigOs) {
+        os << "|" << std::setw(14) << std::setprecision(7) << std::scientific << bigO.constant() << " ";
+        os << "|" << detail::fmt::Number(6, 1, bigO.normalizedRootMeanSquare() * 100.0) << "% ";
+        os << "| " << bigO.name();
+        os << std::endl;
+    }
+    return os;
+}
+
+} // namespace nanobench
+} // namespace ankerl
+
+#endif // ANKERL_NANOBENCH_IMPLEMENT
+#endif // ANKERL_NANOBENCH_H_INCLUDED
diff --git a/datasets/mnist/images-idx3-ubyte.gz b/datasets/mnist/images-idx3-ubyte.gz new file mode 100644 index 0000000..b50e4b6 Binary files /dev/null and b/datasets/mnist/images-idx3-ubyte.gz differ diff --git a/datasets/mnist/labels-idx1-ubyte.gz b/datasets/mnist/labels-idx1-ubyte.gz new file mode 100644 index 0000000..707a576 Binary files /dev/null and b/datasets/mnist/labels-idx1-ubyte.gz differ diff --git a/extern/stb_image.h b/extern/stb_image.h new file mode 100644 index 0000000..cbe247a --- /dev/null +++ b/extern/stb_image.h @@ -0,0 +1,7988 @@
+/* stb_image - v2.30 - public domain image
loader - http://nothings.org/stb + no warranty implied; use at your own risk + + Do this: + #define STB_IMAGE_IMPLEMENTATION + before you include this file in *one* C or C++ file to create the implementation. + + // i.e. it should look like this: + #include ... + #include ... + #include ... + #define STB_IMAGE_IMPLEMENTATION + #include "stb_image.h" + + You can #define STBI_ASSERT(x) before the #include to avoid using assert.h. + And #define STBI_MALLOC, STBI_REALLOC, and STBI_FREE to avoid using malloc,realloc,free + + + QUICK NOTES: + Primarily of interest to game developers and other people who can + avoid problematic images and only need the trivial interface + + JPEG baseline & progressive (12 bpc/arithmetic not supported, same as stock IJG lib) + PNG 1/2/4/8/16-bit-per-channel + + TGA (not sure what subset, if a subset) + BMP non-1bpp, non-RLE + PSD (composited view only, no extra channels, 8/16 bit-per-channel) + + GIF (*comp always reports as 4-channel) + HDR (radiance rgbE format) + PIC (Softimage PIC) + PNM (PPM and PGM binary only) + + Animated GIF still needs a proper API, but here's one way to do it: + http://gist.github.com/urraka/685d9a6340b26b830d49 + + - decode from memory or through FILE (define STBI_NO_STDIO to remove code) + - decode from arbitrary I/O callbacks + - SIMD acceleration on x86/x64 (SSE2) and ARM (NEON) + + Full documentation under "DOCUMENTATION" below. + + +LICENSE + + See end of file for license information. + +RECENT REVISION HISTORY: + + 2.30 (2024-05-31) avoid erroneous gcc warning + 2.29 (2023-05-xx) optimizations + 2.28 (2023-01-29) many error fixes, security errors, just tons of stuff + 2.27 (2021-07-11) document stbi_info better, 16-bit PNM support, bug fixes + 2.26 (2020-07-13) many minor fixes + 2.25 (2020-02-02) fix warnings + 2.24 (2020-02-02) fix warnings; thread-local failure_reason and flip_vertically + 2.23 (2019-08-11) fix clang static analysis warning + 2.22 (2019-03-04) gif fixes, fix warnings + 2.21 (2019-02-25) fix typo in comment + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.19 (2018-02-11) fix warning + 2.18 (2018-01-30) fix warnings + 2.17 (2018-01-29) bugfix, 1-bit BMP, 16-bitness query, fix warnings + 2.16 (2017-07-23) all functions have 16-bit variants; optimizations; bugfixes + 2.15 (2017-03-18) fix png-1,2,4; all Imagenet JPGs; no runtime SSE detection on GCC + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs + 2.13 (2016-12-04) experimental 16-bit API, only for PNG so far; fixes + 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes + 2.11 (2016-04-02) 16-bit PNGS; enable SSE2 in non-gcc x64 + RGB-format JPEG; remove white matting in PSD; + allocate large structures on the stack; + correct channel count for PNG & BMP + 2.10 (2016-01-22) avoid warning introduced in 2.09 + 2.09 (2016-01-16) 16-bit TGA; comments in PNM files; STBI_REALLOC_SIZED + + See end of file for full revision history. 
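The define-once pattern described above, spelled out (file names are illustrative):

    /* image_io.cpp - the single translation unit that owns the implementation */
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"

    /* every other file includes the header without the define */
    #include "stb_image.h"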
+ + + ============================ Contributors ========================= + + Image formats Extensions, features + Sean Barrett (jpeg, png, bmp) Jetro Lauha (stbi_info) + Nicolas Schulz (hdr, psd) Martin "SpartanJ" Golini (stbi_info) + Jonathan Dummer (tga) James "moose2000" Brown (iPhone PNG) + Jean-Marc Lienher (gif) Ben "Disch" Wenger (io callbacks) + Tom Seddon (pic) Omar Cornut (1/2/4-bit PNG) + Thatcher Ulrich (psd) Nicolas Guillemot (vertical flip) + Ken Miller (pgm, ppm) Richard Mitton (16-bit PSD) + github:urraka (animated gif) Junggon Kim (PNM comments) + Christopher Forseth (animated gif) Daniel Gibson (16-bit TGA) + socks-the-fox (16-bit PNG) + Jeremy Sawicki (handle all ImageNet JPGs) + Optimizations & bugfixes Mikhail Morozov (1-bit BMP) + Fabian "ryg" Giesen Anael Seghezzi (is-16-bit query) + Arseny Kapoulkine Simon Breuss (16-bit PNM) + John-Mark Allen + Carmelo J Fdez-Aguera + + Bug & warning fixes + Marc LeBlanc David Woo Guillaume George Martins Mozeiko + Christpher Lloyd Jerry Jansson Joseph Thomson Blazej Dariusz Roszkowski + Phil Jordan Dave Moore Roy Eltham + Hayaki Saito Nathan Reed Won Chun + Luke Graham Johan Duparc Nick Verigakis the Horde3D community + Thomas Ruf Ronny Chevalier github:rlyeh + Janez Zemva John Bartholomew Michal Cichon github:romigrou + Jonathan Blow Ken Hamada Tero Hanninen github:svdijk + Eugene Golushkov Laurent Gomila Cort Stratton github:snagar + Aruelien Pocheville Sergio Gonzalez Thibault Reuille github:Zelex + Cass Everitt Ryamond Barbiero github:grim210 + Paul Du Bois Engin Manap Aldo Culquicondor github:sammyhw + Philipp Wiesemann Dale Weiler Oriol Ferrer Mesia github:phprus + Josh Tobin Neil Bickford Matthew Gregan github:poppolopoppo + Julian Raschke Gregory Mullen Christian Floisand github:darealshinji + Baldur Karlsson Kevin Schmidt JR Smith github:Michaelangel007 + Brad Weinberger Matvey Cherevko github:mosra + Luca Sas Alexander Veselov Zack Middleton [reserved] + Ryan C. Gordon [reserved] [reserved] + DO NOT ADD YOUR NAME HERE + + Jacko Dirks + + To add your name to the credits, pick a random blank space in the middle and fill it. + 80% of merge conflicts on stb PRs are due to people adding their name at the end + of the credits. +*/ + +#ifndef STBI_INCLUDE_STB_IMAGE_H +#define STBI_INCLUDE_STB_IMAGE_H + +// DOCUMENTATION +// +// Limitations: +// - no 12-bit-per-channel JPEG +// - no JPEGs with arithmetic coding +// - GIF always returns *comp=4 +// +// Basic usage (see HDR discussion below for HDR usage): +// int x,y,n; +// unsigned char *data = stbi_load(filename, &x, &y, &n, 0); +// // ... process data if not NULL ... +// // ... x = width, y = height, n = # 8-bit components per pixel ... +// // ... replace '0' with '1'..'4' to force that many components per pixel +// // ... but 'n' will always be the number that it would have been if you said 0 +// stbi_image_free(data); +// +// Standard parameters: +// int *x -- outputs image width in pixels +// int *y -- outputs image height in pixels +// int *channels_in_file -- outputs # of image components in image file +// int desired_channels -- if non-zero, # of image components requested in result +// +// The return value from an image loader is an 'unsigned char *' which points +// to the pixel data, or NULL on an allocation failure or if the image is +// corrupt or invalid. The pixel data consists of *y scanlines of *x pixels, +// with each pixel consisting of N interleaved 8-bit components; the first +// pixel pointed to is top-left-most in the image. 
There is no padding between +// image scanlines or between pixels, regardless of format. The number of +// components N is 'desired_channels' if desired_channels is non-zero, or +// *channels_in_file otherwise. If desired_channels is non-zero, +// *channels_in_file has the number of components that _would_ have been +// output otherwise. E.g. if you set desired_channels to 4, you will always +// get RGBA output, but you can check *channels_in_file to see if it's trivially +// opaque because e.g. there were only 3 channels in the source image. +// +// An output image with N components has the following components interleaved +// in this order in each pixel: +// +// N=#comp components +// 1 grey +// 2 grey, alpha +// 3 red, green, blue +// 4 red, green, blue, alpha +// +// If image loading fails for any reason, the return value will be NULL, +// and *x, *y, *channels_in_file will be unchanged. The function +// stbi_failure_reason() can be queried for an extremely brief, end-user +// unfriendly explanation of why the load failed. Define STBI_NO_FAILURE_STRINGS +// to avoid compiling these strings at all, and STBI_FAILURE_USERMSG to get slightly +// more user-friendly ones. +// +// Paletted PNG, BMP, GIF, and PIC images are automatically depalettized. +// +// To query the width, height and component count of an image without having to +// decode the full file, you can use the stbi_info family of functions: +// +// int x,y,n,ok; +// ok = stbi_info(filename, &x, &y, &n); +// // returns ok=1 and sets x, y, n if image is a supported format, +// // 0 otherwise. +// +// Note that stb_image pervasively uses ints in its public API for sizes, +// including sizes of memory buffers. This is now part of the API and thus +// hard to change without causing breakage. As a result, the various image +// loaders all have certain limits on image size; these differ somewhat +// by format but generally boil down to either just under 2GB or just under +// 1GB. When the decoded image would be larger than this, stb_image decoding +// will fail. +// +// Additionally, stb_image will reject image files that have any of their +// dimensions set to a larger value than the configurable STBI_MAX_DIMENSIONS, +// which defaults to 2**24 = 16777216 pixels. Due to the above memory limit, +// the only way to have an image with such dimensions load correctly +// is for it to have a rather extreme aspect ratio. Either way, the +// assumption here is that such larger images are likely to be malformed +// or malicious. If you do need to load an image with individual dimensions +// larger than that, and it still fits in the overall size limit, you can +// #define STBI_MAX_DIMENSIONS on your own to be something larger. +// +// =========================================================================== +// +// UNICODE: +// +// If compiling for Windows and you wish to use Unicode filenames, compile +// with +// #define STBI_WINDOWS_UTF8 +// and pass utf8-encoded filenames. Call stbi_convert_wchar_to_utf8 to convert +// Windows wchar_t filenames to utf8. +// +// =========================================================================== +// +// Philosophy +// +// stb libraries are designed with the following priorities: +// +// 1. easy to use +// 2. easy to maintain +// 3. good performance +// +// Sometimes I let "good performance" creep up in priority over "easy to maintain", +// and for best performance I may provide less-easy-to-use APIs that give higher +// performance, in addition to the easy-to-use ones. 
Nevertheless, it's important +// to keep in mind that from the standpoint of you, a client of this library, +// all you care about is #1 and #3, and stb libraries DO NOT emphasize #3 above all. +// +// Some secondary priorities arise directly from the first two, some of which +// provide more explicit reasons why performance can't be emphasized. +// +// - Portable ("ease of use") +// - Small source code footprint ("easy to maintain") +// - No dependencies ("ease of use") +// +// =========================================================================== +// +// I/O callbacks +// +// I/O callbacks allow you to read from arbitrary sources, like packaged +// files or some other source. Data read from callbacks are processed +// through a small internal buffer (currently 128 bytes) to try to reduce +// overhead. +// +// The three functions you must define are "read" (reads some bytes of data), +// "skip" (skips some bytes of data), "eof" (reports if the stream is at the end). +// +// =========================================================================== +// +// SIMD support +// +// The JPEG decoder will try to automatically use SIMD kernels on x86 when +// supported by the compiler. For ARM Neon support, you must explicitly +// request it. +// +// (The old do-it-yourself SIMD API is no longer supported in the current +// code.) +// +// On x86, SSE2 will automatically be used when available based on a run-time +// test; if not, the generic C versions are used as a fall-back. On ARM targets, +// the typical path is to have separate builds for NEON and non-NEON devices +// (at least this is true for iOS and Android). Therefore, the NEON support is +// toggled by a build flag: define STBI_NEON to get NEON loops. +// +// If for some reason you do not want to use any of SIMD code, or if +// you have issues compiling it, you can disable it entirely by +// defining STBI_NO_SIMD. +// +// =========================================================================== +// +// HDR image support (disable by defining STBI_NO_HDR) +// +// stb_image supports loading HDR images in general, and currently the Radiance +// .HDR file format specifically. You can still load any file through the existing +// interface; if you attempt to load an HDR file, it will be automatically remapped +// to LDR, assuming gamma 2.2 and an arbitrary scale factor defaulting to 1; +// both of these constants can be reconfigured through this interface: +// +// stbi_hdr_to_ldr_gamma(2.2f); +// stbi_hdr_to_ldr_scale(1.0f); +// +// (note, do not use _inverse_ constants; stbi_image will invert them +// appropriately). 
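A short sketch of the LDR path through an HDR file, using the knobs just described (the file name is illustrative):

    stbi_hdr_to_ldr_gamma(2.2f); // the documented defaults, set explicitly
    stbi_hdr_to_ldr_scale(1.0f);

    int x = 0, y = 0, n = 0;
    unsigned char* pixels = stbi_load("probe.hdr", &x, &y, &n, 0); // HDR remapped to 8-bit
    if (pixels != NULL) {
        /* ... x*y pixels, n components each ... */
        stbi_image_free(pixels);
    }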
+// +// Additionally, there is a new, parallel interface for loading files as +// (linear) floats to preserve the full dynamic range: +// +// float *data = stbi_loadf(filename, &x, &y, &n, 0); +// +// If you load LDR images through this interface, those images will +// be promoted to floating point values, run through the inverse of +// constants corresponding to the above: +// +// stbi_ldr_to_hdr_scale(1.0f); +// stbi_ldr_to_hdr_gamma(2.2f); +// +// Finally, given a filename (or an open file or memory block--see header +// file for details) containing image data, you can query for the "most +// appropriate" interface to use (that is, whether the image is HDR or +// not), using: +// +// stbi_is_hdr(char *filename); +// +// =========================================================================== +// +// iPhone PNG support: +// +// We optionally support converting iPhone-formatted PNGs (which store +// premultiplied BGRA) back to RGB, even though they're internally encoded +// differently. To enable this conversion, call +// stbi_convert_iphone_png_to_rgb(1). +// +// Call stbi_set_unpremultiply_on_load(1) as well to force a divide per +// pixel to remove any premultiplied alpha *only* if the image file explicitly +// says there's premultiplied data (currently only happens in iPhone images, +// and only if iPhone convert-to-rgb processing is on). +// +// =========================================================================== +// +// ADDITIONAL CONFIGURATION +// +// - You can suppress implementation of any of the decoders to reduce +// your code footprint by #defining one or more of the following +// symbols before creating the implementation. +// +// STBI_NO_JPEG +// STBI_NO_PNG +// STBI_NO_BMP +// STBI_NO_PSD +// STBI_NO_TGA +// STBI_NO_GIF +// STBI_NO_HDR +// STBI_NO_PIC +// STBI_NO_PNM (.ppm and .pgm) +// +// - You can request *only* certain decoders and suppress all other ones +// (this will be more forward-compatible, as addition of new decoders +// doesn't require you to disable them explicitly): +// +// STBI_ONLY_JPEG +// STBI_ONLY_PNG +// STBI_ONLY_BMP +// STBI_ONLY_PSD +// STBI_ONLY_TGA +// STBI_ONLY_GIF +// STBI_ONLY_HDR +// STBI_ONLY_PIC +// STBI_ONLY_PNM (.ppm and .pgm) +// +// - If you use STBI_NO_PNG (or _ONLY_ without PNG), and you still +// want the zlib decoder to be available, #define STBI_SUPPORT_ZLIB +// +// - If you define STBI_MAX_DIMENSIONS, stb_image will reject images greater +// than that size (in either width or height) without further processing. +// This is to let programs in the wild set an upper bound to prevent +// denial-of-service attacks on untrusted data, as one could generate a +// valid image of gigantic dimensions and force stb_image to allocate a +// huge block of memory and spend disproportionate time decoding it. By +// default this is set to (1 << 24), which is 16777216, but that's still +// very big. 
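Combining the options above, a build that only needs PNG and wants a tighter dimension cap might configure the implementation like this (the limit is illustrative):

    #define STBI_ONLY_PNG                 // compile only the PNG decoder
    #define STBI_MAX_DIMENSIONS (1 << 20) // reject anything wider/taller than 1,048,576 px
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"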
+
+#ifndef STBI_NO_STDIO
+#include <stdio.h>
+#endif // STBI_NO_STDIO
+
+#define STBI_VERSION 1
+
+enum
+{
+   STBI_default = 0, // only used for desired_channels
+
+   STBI_grey       = 1,
+   STBI_grey_alpha = 2,
+   STBI_rgb        = 3,
+   STBI_rgb_alpha  = 4
+};
+
+#include <stdlib.h>
+typedef unsigned char stbi_uc;
+typedef unsigned short stbi_us;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifndef STBIDEF
+#ifdef STB_IMAGE_STATIC
+#define STBIDEF static
+#else
+#define STBIDEF extern
+#endif
+#endif
+
+//////////////////////////////////////////////////////////////////////////////
+//
+// PRIMARY API - works on images of any type
+//
+
+//
+// load image by filename, open file, or memory buffer
+//
+
+typedef struct
+{
+   int      (*read)  (void *user,char *data,int size);   // fill 'data' with 'size' bytes.  return number of bytes actually read
+   void     (*skip)  (void *user,int n);                 // skip the next 'n' bytes, or 'unget' the last -n bytes if negative
+   int      (*eof)   (void *user);                       // returns nonzero if we are at end of file/data
+} stbi_io_callbacks;
+
+////////////////////////////////////
+//
+// 8-bits-per-channel interface
+//
+
+STBIDEF stbi_uc *stbi_load_from_memory   (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels);
+STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels);
+
+#ifndef STBI_NO_STDIO
+STBIDEF stbi_uc *stbi_load            (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels);
+STBIDEF stbi_uc *stbi_load_from_file  (FILE *f, int *x, int *y, int *channels_in_file, int desired_channels);
+// for stbi_load_from_file, file pointer is left pointing immediately after image
+#endif
+
+#ifndef STBI_NO_GIF
+STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp);
+#endif
+
+#ifdef STBI_WINDOWS_UTF8
+STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
+#endif
+
+////////////////////////////////////
+//
+// 16-bits-per-channel interface
+//
+
+STBIDEF stbi_us *stbi_load_16_from_memory   (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels);
+STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels);
+
+#ifndef STBI_NO_STDIO
+STBIDEF stbi_us *stbi_load_16          (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels);
+STBIDEF stbi_us *stbi_load_from_file_16(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels);
+#endif
+
+////////////////////////////////////
+//
+// float-per-channel interface
+//
+#ifndef STBI_NO_LINEAR
+STBIDEF float *stbi_loadf_from_memory   (stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels);
+STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels);
+
+#ifndef STBI_NO_STDIO
+STBIDEF float *stbi_loadf          (char const *filename, int *x, int *y, int *channels_in_file, int desired_channels);
+STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *channels_in_file, int desired_channels);
+#endif
+#endif
+
+#ifndef STBI_NO_HDR
+STBIDEF void   stbi_hdr_to_ldr_gamma(float gamma);
+STBIDEF void   stbi_hdr_to_ldr_scale(float scale);
+#endif // STBI_NO_HDR
+
+#ifndef STBI_NO_LINEAR
+STBIDEF void   stbi_ldr_to_hdr_gamma(float gamma);
+STBIDEF void
stbi_ldr_to_hdr_scale(float scale); +#endif // STBI_NO_LINEAR + +// stbi_is_hdr is always defined, but always returns false if STBI_NO_HDR +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user); +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len); +#ifndef STBI_NO_STDIO +STBIDEF int stbi_is_hdr (char const *filename); +STBIDEF int stbi_is_hdr_from_file(FILE *f); +#endif // STBI_NO_STDIO + + +// get a VERY brief reason for failure +// on most compilers (and ALL modern mainstream compilers) this is threadsafe +STBIDEF const char *stbi_failure_reason (void); + +// free the loaded image -- this is just free() +STBIDEF void stbi_image_free (void *retval_from_stbi_load); + +// get image dimensions & components without fully decoding +STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const *buffer, int len); +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *clbk, void *user); + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_info (char const *filename, int *x, int *y, int *comp); +STBIDEF int stbi_info_from_file (FILE *f, int *x, int *y, int *comp); +STBIDEF int stbi_is_16_bit (char const *filename); +STBIDEF int stbi_is_16_bit_from_file(FILE *f); +#endif + + + +// for image formats that explicitly notate that they have premultiplied alpha, +// we just return the colors as stored in the file. set this flag to force +// unpremultiplication. results are undefined if the unpremultiply overflow. +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply); + +// indicate whether we should process iphone images back to canonical format, +// or just pass them through "as-is" +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert); + +// flip the image vertically, so the first pixel in the output array is the bottom left +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip); + +// as above, but only applies to images loaded on the thread that calls the function +// this function is only available if your compiler supports thread-local variables; +// calling it will fail to link if your compiler doesn't +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply); +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert); +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip); + +// ZLIB client - used by PNG, available for other purposes + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen); +STBIDEF char *stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header); +STBIDEF char *stbi_zlib_decode_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); + +STBIDEF char *stbi_zlib_decode_noheader_malloc(const char *buffer, int len, int *outlen); +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen); + + +#ifdef __cplusplus +} +#endif + +// +// +//// end header file ///////////////////////////////////////////////////// +#endif // STBI_INCLUDE_STB_IMAGE_H + +#ifdef STB_IMAGE_IMPLEMENTATION + +#if defined(STBI_ONLY_JPEG) || defined(STBI_ONLY_PNG) || 
defined(STBI_ONLY_BMP) \
+  || defined(STBI_ONLY_TGA) || defined(STBI_ONLY_GIF) || defined(STBI_ONLY_PSD) \
+  || defined(STBI_ONLY_HDR) || defined(STBI_ONLY_PIC) || defined(STBI_ONLY_PNM) \
+  || defined(STBI_ONLY_ZLIB)
+   #ifndef STBI_ONLY_JPEG
+   #define STBI_NO_JPEG
+   #endif
+   #ifndef STBI_ONLY_PNG
+   #define STBI_NO_PNG
+   #endif
+   #ifndef STBI_ONLY_BMP
+   #define STBI_NO_BMP
+   #endif
+   #ifndef STBI_ONLY_PSD
+   #define STBI_NO_PSD
+   #endif
+   #ifndef STBI_ONLY_TGA
+   #define STBI_NO_TGA
+   #endif
+   #ifndef STBI_ONLY_GIF
+   #define STBI_NO_GIF
+   #endif
+   #ifndef STBI_ONLY_HDR
+   #define STBI_NO_HDR
+   #endif
+   #ifndef STBI_ONLY_PIC
+   #define STBI_NO_PIC
+   #endif
+   #ifndef STBI_ONLY_PNM
+   #define STBI_NO_PNM
+   #endif
+#endif
+
+#if defined(STBI_NO_PNG) && !defined(STBI_SUPPORT_ZLIB) && !defined(STBI_NO_ZLIB)
+#define STBI_NO_ZLIB
+#endif
+
+
+#include <stdarg.h>
+#include <stddef.h> // ptrdiff_t on osx
+#include <stdlib.h>
+#include <string.h>
+#include <limits.h>
+
+#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR)
+#include <math.h>  // ldexp, pow
+#endif
+
+#ifndef STBI_NO_STDIO
+#include <stdio.h>
+#endif
+
+#ifndef STBI_ASSERT
+#include <assert.h>
+#define STBI_ASSERT(x) assert(x)
+#endif
+
+#ifdef __cplusplus
+#define STBI_EXTERN extern "C"
+#else
+#define STBI_EXTERN extern
+#endif
+
+
+#ifndef _MSC_VER
+   #ifdef __cplusplus
+   #define stbi_inline inline
+   #else
+   #define stbi_inline
+   #endif
+#else
+   #define stbi_inline __forceinline
+#endif
+
+#ifndef STBI_NO_THREAD_LOCALS
+   #if defined(__cplusplus) && __cplusplus >= 201103L
+      #define STBI_THREAD_LOCAL       thread_local
+   #elif defined(__GNUC__) && __GNUC__ < 5
+      #define STBI_THREAD_LOCAL       __thread
+   #elif defined(_MSC_VER)
+      #define STBI_THREAD_LOCAL       __declspec(thread)
+   #elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 201112L && !defined(__STDC_NO_THREADS__)
+      #define STBI_THREAD_LOCAL       _Thread_local
+   #endif
+
+   #ifndef STBI_THREAD_LOCAL
+      #if defined(__GNUC__)
+        #define STBI_THREAD_LOCAL       __thread
+      #endif
+   #endif
+#endif
+
+#if defined(_MSC_VER) || defined(__SYMBIAN32__)
+typedef unsigned short stbi__uint16;
+typedef   signed short stbi__int16;
+typedef unsigned int   stbi__uint32;
+typedef   signed int   stbi__int32;
+#else
+#include <stdint.h>
+typedef uint16_t stbi__uint16;
+typedef int16_t  stbi__int16;
+typedef uint32_t stbi__uint32;
+typedef int32_t  stbi__int32;
+#endif
+
+// should produce compiler error if size is wrong
+typedef unsigned char validate_uint32[sizeof(stbi__uint32)==4 ? 1 : -1];
+
+#ifdef _MSC_VER
+#define STBI_NOTUSED(v)  (void)(v)
+#else
+#define STBI_NOTUSED(v)  (void)sizeof(v)
+#endif
+
+#ifdef _MSC_VER
+#define STBI_HAS_LROTL
+#endif
+
+#ifdef STBI_HAS_LROTL
+   #define stbi_lrot(x,y)  _lrotl(x,y)
+#else
+   #define stbi_lrot(x,y)  (((x) << (y)) | ((x) >> (-(y) & 31)))
+#endif
+
+#if defined(STBI_MALLOC) && defined(STBI_FREE) && (defined(STBI_REALLOC) || defined(STBI_REALLOC_SIZED))
+// ok
+#elif !defined(STBI_MALLOC) && !defined(STBI_FREE) && !defined(STBI_REALLOC) && !defined(STBI_REALLOC_SIZED)
+// ok
+#else
+#error "Must define all or none of STBI_MALLOC, STBI_FREE, and STBI_REALLOC (or STBI_REALLOC_SIZED)."
+#endif
+
+#ifndef STBI_MALLOC
+#define STBI_MALLOC(sz)           malloc(sz)
+#define STBI_REALLOC(p,newsz)     realloc(p,newsz)
+#define STBI_FREE(p)              free(p)
+#endif
+
+#ifndef STBI_REALLOC_SIZED
+#define STBI_REALLOC_SIZED(p,oldsz,newsz) STBI_REALLOC(p,newsz)
+#endif
+
+// x86/x64 detection
+#if defined(__x86_64__) || defined(_M_X64)
+#define STBI__X64_TARGET
+#elif defined(__i386) || defined(_M_IX86)
+#define STBI__X86_TARGET
+#endif
+
+#if defined(__GNUC__) && defined(STBI__X86_TARGET) && !defined(__SSE2__) && !defined(STBI_NO_SIMD)
+// gcc doesn't support sse2 intrinsics unless you compile with -msse2,
+// which in turn means it gets to use SSE2 everywhere. This is unfortunate,
+// but previous attempts to provide the SSE2 functions with runtime
+// detection caused numerous issues. The way architecture extensions are
+// exposed in GCC/Clang is, sadly, not really suited for one-file libs.
+// New behavior: if compiled with -msse2, we use SSE2 without any
+// detection; if not, we don't use it at all.
+#define STBI_NO_SIMD
+#endif
+
+#if defined(__MINGW32__) && defined(STBI__X86_TARGET) && !defined(STBI_MINGW_ENABLE_SSE2) && !defined(STBI_NO_SIMD)
+// Note that __MINGW32__ doesn't actually mean 32-bit, so we have to avoid STBI__X64_TARGET
+//
+// 32-bit MinGW wants ESP to be 16-byte aligned, but this is not in the
+// Windows ABI and VC++ as well as Windows DLLs don't maintain that invariant.
+// As a result, enabling SSE2 on 32-bit MinGW is dangerous when not
+// simultaneously enabling "-mstackrealign".
+//
+// See https://github.com/nothings/stb/issues/81 for more information.
+//
+// So default to no SSE2 on 32-bit MinGW. If you've read this far and added
+// -mstackrealign to your build settings, feel free to #define STBI_MINGW_ENABLE_SSE2.
+#define STBI_NO_SIMD
+#endif
+
+#if !defined(STBI_NO_SIMD) && (defined(STBI__X86_TARGET) || defined(STBI__X64_TARGET))
+#define STBI_SSE2
+#include <emmintrin.h>
+
+#ifdef _MSC_VER
+
+#if _MSC_VER >= 1400  // not VC6
+#include <intrin.h> // __cpuid
+static int stbi__cpuid3(void)
+{
+   int info[4];
+   __cpuid(info,1);
+   return info[3];
+}
+#else
+static int stbi__cpuid3(void)
+{
+   int res;
+   __asm {
+      mov  eax,1
+      cpuid
+      mov  res,edx
+   }
+   return res;
+}
+#endif
+
+#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name
+
+#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2)
+static int stbi__sse2_available(void)
+{
+   int info3 = stbi__cpuid3();
+   return ((info3 >> 26) & 1) != 0;
+}
+#endif
+
+#else // assume GCC-style if not VC++
+#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16)))
+
+#if !defined(STBI_NO_JPEG) && defined(STBI_SSE2)
+static int stbi__sse2_available(void)
+{
+   // If we're even attempting to compile this on GCC/Clang, that means
+   // -msse2 is on, which means the compiler is allowed to use SSE2
+   // instructions at will, and so are we.
+ return 1; +} +#endif + +#endif +#endif + +// ARM NEON +#if defined(STBI_NO_SIMD) && defined(STBI_NEON) +#undef STBI_NEON +#endif + +#ifdef STBI_NEON +#include <arm_neon.h> +#ifdef _MSC_VER +#define STBI_SIMD_ALIGN(type, name) __declspec(align(16)) type name +#else +#define STBI_SIMD_ALIGN(type, name) type name __attribute__((aligned(16))) +#endif +#endif + +#ifndef STBI_SIMD_ALIGN +#define STBI_SIMD_ALIGN(type, name) type name +#endif + +#ifndef STBI_MAX_DIMENSIONS +#define STBI_MAX_DIMENSIONS (1 << 24) +#endif + +/////////////////////////////////////////////// +// +// stbi__context struct and start_xxx functions + +// stbi__context structure is our basic context used by all images, so it +// contains all the IO context, plus some basic image information +typedef struct +{ + stbi__uint32 img_x, img_y; + int img_n, img_out_n; + + stbi_io_callbacks io; + void *io_user_data; + + int read_from_callbacks; + int buflen; + stbi_uc buffer_start[128]; + int callback_already_read; + + stbi_uc *img_buffer, *img_buffer_end; + stbi_uc *img_buffer_original, *img_buffer_original_end; +} stbi__context; + + +static void stbi__refill_buffer(stbi__context *s); + +// initialize a memory-decode context +static void stbi__start_mem(stbi__context *s, stbi_uc const *buffer, int len) +{ + s->io.read = NULL; + s->read_from_callbacks = 0; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = (stbi_uc *) buffer; + s->img_buffer_end = s->img_buffer_original_end = (stbi_uc *) buffer+len; +} + +// initialize a callback-based context +static void stbi__start_callbacks(stbi__context *s, stbi_io_callbacks *c, void *user) +{ + s->io = *c; + s->io_user_data = user; + s->buflen = sizeof(s->buffer_start); + s->read_from_callbacks = 1; + s->callback_already_read = 0; + s->img_buffer = s->img_buffer_original = s->buffer_start; + stbi__refill_buffer(s); + s->img_buffer_original_end = s->img_buffer_end; +} + +#ifndef STBI_NO_STDIO + +static int stbi__stdio_read(void *user, char *data, int size) +{ + return (int) fread(data,1,size,(FILE*) user); +} + +static void stbi__stdio_skip(void *user, int n) +{ + int ch; + fseek((FILE*) user, n, SEEK_CUR); + ch = fgetc((FILE*) user); /* have to read a byte to reset feof()'s flag */ + if (ch != EOF) { + ungetc(ch, (FILE *) user); /* push byte back onto stream if valid.
*/ + } +} + +static int stbi__stdio_eof(void *user) +{ + return feof((FILE*) user) || ferror((FILE *) user); +} + +static stbi_io_callbacks stbi__stdio_callbacks = +{ + stbi__stdio_read, + stbi__stdio_skip, + stbi__stdio_eof, +}; + +static void stbi__start_file(stbi__context *s, FILE *f) +{ + stbi__start_callbacks(s, &stbi__stdio_callbacks, (void *) f); +} + +//static void stop_file(stbi__context *s) { } + +#endif // !STBI_NO_STDIO + +static void stbi__rewind(stbi__context *s) +{ + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; +} + +enum +{ + STBI_ORDER_RGB, + STBI_ORDER_BGR +}; + +typedef struct +{ + int bits_per_channel; + int num_channels; + int channel_order; +} stbi__result_info; + +#ifndef STBI_NO_JPEG +static int stbi__jpeg_test(stbi__context *s); +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PNG +static int stbi__png_test(stbi__context *s); +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__png_is16(stbi__context *s); +#endif + +#ifndef STBI_NO_BMP +static int stbi__bmp_test(stbi__context *s); +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_TGA +static int stbi__tga_test(stbi__context *s); +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context *s); +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc); +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp); +static int stbi__psd_is16(stbi__context *s); +#endif + +#ifndef STBI_NO_HDR +static int stbi__hdr_test(stbi__context *s); +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PIC +static int stbi__pic_test(stbi__context *s); +static void *stbi__pic_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_GIF +static int stbi__gif_test(stbi__context *s); +static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp); +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp); +#endif + +#ifndef STBI_NO_PNM +static int stbi__pnm_test(stbi__context *s); +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri); +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp); +static int 
stbi__pnm_is16(stbi__context *s); +#endif + +static +#ifdef STBI_THREAD_LOCAL +STBI_THREAD_LOCAL +#endif +const char *stbi__g_failure_reason; + +STBIDEF const char *stbi_failure_reason(void) +{ + return stbi__g_failure_reason; +} + +#ifndef STBI_NO_FAILURE_STRINGS +static int stbi__err(const char *str) +{ + stbi__g_failure_reason = str; + return 0; +} +#endif + +static void *stbi__malloc(size_t size) +{ + return STBI_MALLOC(size); +} + +// stb_image uses ints pervasively, including for offset calculations. +// therefore the largest decoded image size we can support with the +// current code, even on 64-bit targets, is INT_MAX. this is not a +// significant limitation for the intended use case. +// +// we do, however, need to make sure our size calculations don't +// overflow. hence a few helper functions for size calculations that +// multiply integers together, making sure that they're non-negative +// and no overflow occurs. + +// return 1 if the sum is valid, 0 on overflow. +// negative terms are considered invalid. +static int stbi__addsizes_valid(int a, int b) +{ + if (b < 0) return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INT_MAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; +} + +// returns 1 if the product is valid, 0 on overflow. +// negative factors are considered invalid. +static int stbi__mul2sizes_valid(int a, int b) +{ + if (a < 0 || b < 0) return 0; + if (b == 0) return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX/b; +} + +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +// returns 1 if "a*b + add" has no negative terms/factors and doesn't overflow +static int stbi__mad2sizes_valid(int a, int b, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__addsizes_valid(a*b, add); +} +#endif + +// returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow +static int stbi__mad3sizes_valid(int a, int b, int c, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__addsizes_valid(a*b*c, add); +} + +// returns 1 if "a*b*c*d + add" has no negative terms/factors and doesn't overflow +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) +static int stbi__mad4sizes_valid(int a, int b, int c, int d, int add) +{ + return stbi__mul2sizes_valid(a, b) && stbi__mul2sizes_valid(a*b, c) && + stbi__mul2sizes_valid(a*b*c, d) && stbi__addsizes_valid(a*b*c*d, add); +} +#endif + +#if !defined(STBI_NO_JPEG) || !defined(STBI_NO_PNG) || !defined(STBI_NO_TGA) || !defined(STBI_NO_HDR) +// mallocs with size overflow checking +static void *stbi__malloc_mad2(int a, int b, int add) +{ + if (!stbi__mad2sizes_valid(a, b, add)) return NULL; + return stbi__malloc(a*b + add); +} +#endif + +static void *stbi__malloc_mad3(int a, int b, int c, int add) +{ + if (!stbi__mad3sizes_valid(a, b, c, add)) return NULL; + return stbi__malloc(a*b*c + add); +} + +#if !defined(STBI_NO_LINEAR) || !defined(STBI_NO_HDR) || !defined(STBI_NO_PNM) +static void *stbi__malloc_mad4(int a, int b, int c, int d, int add) +{ + if (!stbi__mad4sizes_valid(a, b, c, d, add)) return NULL; + return stbi__malloc(a*b*c*d + add); +} +#endif + +// returns 1 if the sum of two signed ints is valid (between -2^31 and 2^31-1 inclusive), 0 on overflow.
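+// (worked example: a = 2, b = INT_MAX. Both are non-negative, so the test
+// below reduces to a <= INT_MAX - b, i.e. 2 <= 0, which fails -- correctly
+// reporting that the sum would overflow.)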
+static int stbi__addints_valid(int a, int b) +{ + if ((a >= 0) != (b >= 0)) return 1; // a and b have different signs, so no overflow + if (a < 0 && b < 0) return a >= INT_MIN - b; // same as a + b >= INT_MIN; INT_MIN - b cannot overflow since b < 0. + return a <= INT_MAX - b; +} + +// returns 1 if the product of two ints fits in a signed short, 0 on overflow. +static int stbi__mul2shorts_valid(int a, int b) +{ + if (b == 0 || b == -1) return 1; // multiplication by 0 is always 0; check for -1 so SHRT_MIN/b doesn't overflow + if ((a >= 0) == (b >= 0)) return a <= SHRT_MAX/b; // product is positive, so similar to mul2sizes_valid + if (b < 0) return a <= SHRT_MIN / b; // same as a * b >= SHRT_MIN + return a >= SHRT_MIN / b; +} + +// stbi__err - error +// stbi__errpf - error returning pointer to float +// stbi__errpuc - error returning pointer to unsigned char + +#ifdef STBI_NO_FAILURE_STRINGS + #define stbi__err(x,y) 0 +#elif defined(STBI_FAILURE_USERMSG) + #define stbi__err(x,y) stbi__err(y) +#else + #define stbi__err(x,y) stbi__err(x) +#endif + +#define stbi__errpf(x,y) ((float *)(size_t) (stbi__err(x,y)?NULL:NULL)) +#define stbi__errpuc(x,y) ((unsigned char *)(size_t) (stbi__err(x,y)?NULL:NULL)) + +STBIDEF void stbi_image_free(void *retval_from_stbi_load) +{ + STBI_FREE(retval_from_stbi_load); +} + +#ifndef STBI_NO_LINEAR +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp); +#endif + +#ifndef STBI_NO_HDR +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp); +#endif + +static int stbi__vertically_flip_on_load_global = 0; + +STBIDEF void stbi_set_flip_vertically_on_load(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_global = flag_true_if_should_flip; +} + +#ifndef STBI_THREAD_LOCAL +#define stbi__vertically_flip_on_load stbi__vertically_flip_on_load_global +#else +static STBI_THREAD_LOCAL int stbi__vertically_flip_on_load_local, stbi__vertically_flip_on_load_set; + +STBIDEF void stbi_set_flip_vertically_on_load_thread(int flag_true_if_should_flip) +{ + stbi__vertically_flip_on_load_local = flag_true_if_should_flip; + stbi__vertically_flip_on_load_set = 1; +} + +#define stbi__vertically_flip_on_load (stbi__vertically_flip_on_load_set \ + ? 
stbi__vertically_flip_on_load_local \ + : stbi__vertically_flip_on_load_global) +#endif // STBI_THREAD_LOCAL + +static void *stbi__load_main(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + memset(ri, 0, sizeof(*ri)); // make sure it's initialized if we add new fields + ri->bits_per_channel = 8; // default is 8 so most paths don't have to be changed + ri->channel_order = STBI_ORDER_RGB; // all current input & output are this, but this is here so we can add BGR order + ri->num_channels = 0; + + // test the formats with a very explicit header first (at least a FOURCC + // or distinctive magic number first) + #ifndef STBI_NO_PNG + if (stbi__png_test(s)) return stbi__png_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_BMP + if (stbi__bmp_test(s)) return stbi__bmp_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_GIF + if (stbi__gif_test(s)) return stbi__gif_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PSD + if (stbi__psd_test(s)) return stbi__psd_load(s,x,y,comp,req_comp, ri, bpc); + #else + STBI_NOTUSED(bpc); + #endif + #ifndef STBI_NO_PIC + if (stbi__pic_test(s)) return stbi__pic_load(s,x,y,comp,req_comp, ri); + #endif + + // then the formats that can end up attempting to load with just 1 or 2 + // bytes matching expectations; these are prone to false positives, so + // try them later + #ifndef STBI_NO_JPEG + if (stbi__jpeg_test(s)) return stbi__jpeg_load(s,x,y,comp,req_comp, ri); + #endif + #ifndef STBI_NO_PNM + if (stbi__pnm_test(s)) return stbi__pnm_load(s,x,y,comp,req_comp, ri); + #endif + + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + float *hdr = stbi__hdr_load(s, x,y,comp,req_comp, ri); + return stbi__hdr_to_ldr(hdr, *x, *y, req_comp ? req_comp : *comp); + } + #endif + + #ifndef STBI_NO_TGA + // test tga last because it's a crappy test! + if (stbi__tga_test(s)) + return stbi__tga_load(s,x,y,comp,req_comp, ri); + #endif + + return stbi__errpuc("unknown image type", "Image not of any known type, or corrupt"); +} + +static stbi_uc *stbi__convert_16_to_8(stbi__uint16 *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi_uc *reduced; + + reduced = (stbi_uc *) stbi__malloc(img_len); + if (reduced == NULL) return stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + reduced[i] = (stbi_uc)((orig[i] >> 8) & 0xFF); // top half of each byte is sufficient approx of 16->8 bit scaling + + STBI_FREE(orig); + return reduced; +} + +static stbi__uint16 *stbi__convert_8_to_16(stbi_uc *orig, int w, int h, int channels) +{ + int i; + int img_len = w * h * channels; + stbi__uint16 *enlarged; + + enlarged = (stbi__uint16 *) stbi__malloc(img_len*2); + if (enlarged == NULL) return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); + + for (i = 0; i < img_len; ++i) + enlarged[i] = (stbi__uint16)((orig[i] << 8) + orig[i]); // replicate to high and low byte, maps 0->0, 255->0xffff + + STBI_FREE(orig); + return enlarged; +} + +static void stbi__vertical_flip(void *image, int w, int h, int bytes_per_pixel) +{ + int row; + size_t bytes_per_row = (size_t)w * bytes_per_pixel; + stbi_uc temp[2048]; + stbi_uc *bytes = (stbi_uc *)image; + + for (row = 0; row < (h>>1); row++) { + stbi_uc *row0 = bytes + row*bytes_per_row; + stbi_uc *row1 = bytes + (h - row - 1)*bytes_per_row; + // swap row0 with row1 + size_t bytes_left = bytes_per_row; + while (bytes_left) { + size_t bytes_copy = (bytes_left < sizeof(temp)) ? 
bytes_left : sizeof(temp); + memcpy(temp, row0, bytes_copy); + memcpy(row0, row1, bytes_copy); + memcpy(row1, temp, bytes_copy); + row0 += bytes_copy; + row1 += bytes_copy; + bytes_left -= bytes_copy; + } + } +} + +#ifndef STBI_NO_GIF +static void stbi__vertical_flip_slices(void *image, int w, int h, int z, int bytes_per_pixel) +{ + int slice; + int slice_size = w * h * bytes_per_pixel; + + stbi_uc *bytes = (stbi_uc *)image; + for (slice = 0; slice < z; ++slice) { + stbi__vertical_flip(bytes, w, h, bytes_per_pixel); + bytes += slice_size; + } +} +#endif + +static unsigned char *stbi__load_and_postprocess_8bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 8); + + if (result == NULL) + return NULL; + + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + + if (ri.bits_per_channel != 8) { + result = stbi__convert_16_to_8((stbi__uint16 *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 8; + } + + // @TODO: move stbi__convert_format to here + + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi_uc)); + } + + return (unsigned char *) result; +} + +static stbi__uint16 *stbi__load_and_postprocess_16bit(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + stbi__result_info ri; + void *result = stbi__load_main(s, x, y, comp, req_comp, &ri, 16); + + if (result == NULL) + return NULL; + + // it is the responsibility of the loaders to make sure we get either 8 or 16 bit. + STBI_ASSERT(ri.bits_per_channel == 8 || ri.bits_per_channel == 16); + + if (ri.bits_per_channel != 16) { + result = stbi__convert_8_to_16((stbi_uc *) result, *x, *y, req_comp == 0 ? *comp : req_comp); + ri.bits_per_channel = 16; + } + + // @TODO: move stbi__convert_format16 to here + // @TODO: special case RGB-to-Y (and RGBA-to-YA) for 8-bit-to-16-bit case to keep more precision + + if (stbi__vertically_flip_on_load) { + int channels = req_comp ? req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(stbi__uint16)); + } + + return (stbi__uint16 *) result; +} + +#if !defined(STBI_NO_HDR) && !defined(STBI_NO_LINEAR) +static void stbi__float_postprocess(float *result, int *x, int *y, int *comp, int req_comp) +{ + if (stbi__vertically_flip_on_load && result != NULL) { + int channels = req_comp ? 
req_comp : *comp; + stbi__vertical_flip(result, *x, *y, channels * sizeof(float)); + } +} +#endif + +#ifndef STBI_NO_STDIO + +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) +STBI_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide); +STBI_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default); +#endif + +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) +STBIDEF int stbi_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input) +{ + return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL); +} +#endif + +static FILE *stbi__fopen(char const *filename, char const *mode) +{ + FILE *f; +#if defined(_WIN32) && defined(STBI_WINDOWS_UTF8) + wchar_t wMode[64]; + wchar_t wFilename[1024]; + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename))) + return 0; + + if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode))) + return 0; + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != _wfopen_s(&f, wFilename, wMode)) + f = 0; +#else + f = _wfopen(wFilename, wMode); +#endif + +#elif defined(_MSC_VER) && _MSC_VER >= 1400 + if (0 != fopen_s(&f, filename, mode)) + f=0; +#else + f = fopen(filename, mode); +#endif + return f; +} + + +STBIDEF stbi_uc *stbi_load(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + unsigned char *result; + if (!f) return stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; +} + +STBIDEF stbi_uc *stbi_load_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi__uint16 *stbi_load_from_file_16(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__uint16 *result; + stbi__context s; + stbi__start_file(&s,f); + result = stbi__load_and_postprocess_16bit(&s,x,y,comp,req_comp); + if (result) { + // need to 'unget' all the characters in the IO buffer + fseek(f, - (int) (s.img_buffer_end - s.img_buffer), SEEK_CUR); + } + return result; +} + +STBIDEF stbi_us *stbi_load_16(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + FILE *f = stbi__fopen(filename, "rb"); + stbi__uint16 *result; + if (!f) return (stbi_us *) stbi__errpuc("can't fopen", "Unable to open file"); + result = stbi_load_from_file_16(f,x,y,comp,req_comp); + fclose(f); + return result; +} + + +#endif //!STBI_NO_STDIO + +STBIDEF stbi_us *stbi_load_16_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); +} + +STBIDEF stbi_us *stbi_load_16_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *channels_in_file, int desired_channels) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *)clbk, user); + return 
stbi__load_and_postprocess_16bit(&s,x,y,channels_in_file,desired_channels); +} + +STBIDEF stbi_uc *stbi_load_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); +} + +STBIDEF stbi_uc *stbi_load_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__load_and_postprocess_8bit(&s,x,y,comp,req_comp); +} + +#ifndef STBI_NO_GIF +STBIDEF stbi_uc *stbi_load_gif_from_memory(stbi_uc const *buffer, int len, int **delays, int *x, int *y, int *z, int *comp, int req_comp) +{ + unsigned char *result; + stbi__context s; + stbi__start_mem(&s,buffer,len); + + result = (unsigned char*) stbi__load_gif_main(&s, delays, x, y, z, comp, req_comp); + if (stbi__vertically_flip_on_load) { + stbi__vertical_flip_slices( result, *x, *y, *z, *comp ); + } + + return result; +} +#endif + +#ifndef STBI_NO_LINEAR +static float *stbi__loadf_main(stbi__context *s, int *x, int *y, int *comp, int req_comp) +{ + unsigned char *data; + #ifndef STBI_NO_HDR + if (stbi__hdr_test(s)) { + stbi__result_info ri; + float *hdr_data = stbi__hdr_load(s,x,y,comp,req_comp, &ri); + if (hdr_data) + stbi__float_postprocess(hdr_data,x,y,comp,req_comp); + return hdr_data; + } + #endif + data = stbi__load_and_postprocess_8bit(s, x, y, comp, req_comp); + if (data) + return stbi__ldr_to_hdr(data, *x, *y, req_comp ? req_comp : *comp); + return stbi__errpf("unknown image type", "Image not of any known type, or corrupt"); +} + +STBIDEF float *stbi_loadf_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__loadf_main(&s,x,y,comp,req_comp); +} + +STBIDEF float *stbi_loadf_from_callbacks(stbi_io_callbacks const *clbk, void *user, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__loadf_main(&s,x,y,comp,req_comp); +} + +#ifndef STBI_NO_STDIO +STBIDEF float *stbi_loadf(char const *filename, int *x, int *y, int *comp, int req_comp) +{ + float *result; + FILE *f = stbi__fopen(filename, "rb"); + if (!f) return stbi__errpf("can't fopen", "Unable to open file"); + result = stbi_loadf_from_file(f,x,y,comp,req_comp); + fclose(f); + return result; +} + +STBIDEF float *stbi_loadf_from_file(FILE *f, int *x, int *y, int *comp, int req_comp) +{ + stbi__context s; + stbi__start_file(&s,f); + return stbi__loadf_main(&s,x,y,comp,req_comp); +} +#endif // !STBI_NO_STDIO + +#endif // !STBI_NO_LINEAR + +// these is-hdr-or-not is defined independent of whether STBI_NO_LINEAR is +// defined, for API simplicity; if STBI_NO_LINEAR is defined, it always +// reports false! 
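+// (usage sketch, assuming STBI_NO_LINEAR is not defined and a Radiance file
+// named "probe.hdr" exists -- the filename is hypothetical:
+//    int w, h, n;
+//    if (stbi_is_hdr("probe.hdr")) {
+//       float *pix = stbi_loadf("probe.hdr", &w, &h, &n, 0); // linear floats
+//       /* ... use the w*h*n float samples ... */
+//       stbi_image_free(pix);
+//    }
+// )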
+ +STBIDEF int stbi_is_hdr_from_memory(stbi_uc const *buffer, int len) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(buffer); + STBI_NOTUSED(len); + return 0; + #endif +} + +#ifndef STBI_NO_STDIO +STBIDEF int stbi_is_hdr (char const *filename) +{ + FILE *f = stbi__fopen(filename, "rb"); + int result=0; + if (f) { + result = stbi_is_hdr_from_file(f); + fclose(f); + } + return result; +} + +STBIDEF int stbi_is_hdr_from_file(FILE *f) +{ + #ifndef STBI_NO_HDR + long pos = ftell(f); + int res; + stbi__context s; + stbi__start_file(&s,f); + res = stbi__hdr_test(&s); + fseek(f, pos, SEEK_SET); + return res; + #else + STBI_NOTUSED(f); + return 0; + #endif +} +#endif // !STBI_NO_STDIO + +STBIDEF int stbi_is_hdr_from_callbacks(stbi_io_callbacks const *clbk, void *user) +{ + #ifndef STBI_NO_HDR + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) clbk, user); + return stbi__hdr_test(&s); + #else + STBI_NOTUSED(clbk); + STBI_NOTUSED(user); + return 0; + #endif +} + +#ifndef STBI_NO_LINEAR +static float stbi__l2h_gamma=2.2f, stbi__l2h_scale=1.0f; + +STBIDEF void stbi_ldr_to_hdr_gamma(float gamma) { stbi__l2h_gamma = gamma; } +STBIDEF void stbi_ldr_to_hdr_scale(float scale) { stbi__l2h_scale = scale; } +#endif + +static float stbi__h2l_gamma_i=1.0f/2.2f, stbi__h2l_scale_i=1.0f; + +STBIDEF void stbi_hdr_to_ldr_gamma(float gamma) { stbi__h2l_gamma_i = 1/gamma; } +STBIDEF void stbi_hdr_to_ldr_scale(float scale) { stbi__h2l_scale_i = 1/scale; } + + +////////////////////////////////////////////////////////////////////////////// +// +// Common code used by all image loaders +// + +enum +{ + STBI__SCAN_load=0, + STBI__SCAN_type, + STBI__SCAN_header +}; + +static void stbi__refill_buffer(stbi__context *s) +{ + int n = (s->io.read)(s->io_user_data,(char*)s->buffer_start,s->buflen); + s->callback_already_read += (int) (s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start+1; + *s->img_buffer = 0; + } else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } +} + +stbi_inline static stbi_uc stbi__get8(stbi__context *s) +{ + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + stbi__refill_buffer(s); + return *s->img_buffer++; + } + return 0; +} + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_HDR) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +stbi_inline static int stbi__at_eof(stbi__context *s) +{ + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) return 1; + } + + return s->img_buffer >= s->img_buffer_end; +} +#endif + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) +// nothing +#else +static void stbi__skip(stbi__context *s, int n) +{ + if (n == 0) return; // already there! 
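+   // a negative skip indicates corrupt data; rather than trying to seek
+   // backwards, park the pointer at the end of the current buffer.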
+ if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); + return; + } + } + s->img_buffer += n; +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_TGA) && defined(STBI_NO_HDR) && defined(STBI_NO_PNM) +// nothing +#else +static int stbi__getn(stbi__context *s, stbi_uc *buffer, int n) +{ + if (s->io.read) { + int blen = (int) (s->img_buffer_end - s->img_buffer); + if (blen < n) { + int res, count; + + memcpy(buffer, s->img_buffer, blen); + + count = (s->io.read)(s->io_user_data, (char*) buffer + blen, n - blen); + res = (count == (n-blen)); + s->img_buffer = s->img_buffer_end; + return res; + } + } + + if (s->img_buffer+n <= s->img_buffer_end) { + memcpy(buffer, s->img_buffer, n); + s->img_buffer += n; + return 1; + } else + return 0; +} +#endif + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +// nothing +#else +static int stbi__get16be(stbi__context *s) +{ + int z = stbi__get8(s); + return (z << 8) + stbi__get8(s); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) && defined(STBI_NO_PIC) +// nothing +#else +static stbi__uint32 stbi__get32be(stbi__context *s) +{ + stbi__uint32 z = stbi__get16be(s); + return (z << 16) + stbi__get16be(s); +} +#endif + +#if defined(STBI_NO_BMP) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) +// nothing +#else +static int stbi__get16le(stbi__context *s) +{ + int z = stbi__get8(s); + return z + (stbi__get8(s) << 8); +} +#endif + +#ifndef STBI_NO_BMP +static stbi__uint32 stbi__get32le(stbi__context *s) +{ + stbi__uint32 z = stbi__get16le(s); + z += (stbi__uint32)stbi__get16le(s) << 16; + return z; +} +#endif + +#define STBI__BYTECAST(x) ((stbi_uc) ((x) & 255)) // truncate int to byte without warnings + +#if defined(STBI_NO_JPEG) && defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +////////////////////////////////////////////////////////////////////////////// +// +// generic converter from built-in img_n to req_comp +// individual types do this automatically as much as possible (e.g. jpeg +// does all cases internally since it needs to colorspace convert anyway, +// and it never has alpha, so very few cases ). 
png can automatically +// interleave an alpha=255 channel, but falls back to this for other cases +// +// assume data buffer is malloced, so malloc a new one and free that one +// only failure mode is malloc failing + +static stbi_uc stbi__compute_y(int r, int g, int b) +{ + return (stbi_uc) (((r*77) + (g*150) + (29*b)) >> 8); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_BMP) && defined(STBI_NO_PSD) && defined(STBI_NO_TGA) && defined(STBI_NO_GIF) && defined(STBI_NO_PIC) && defined(STBI_NO_PNM) +// nothing +#else +static unsigned char *stbi__convert_format(unsigned char *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + unsigned char *good; + + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (unsigned char *) stbi__malloc_mad3(req_comp, x, y, 0); + if (good == NULL) { + STBI_FREE(data); + return stbi__errpuc("outofmem", "Out of memory"); + } + + for (j=0; j < (int) y; ++j) { + unsigned char *src = data + j * x * img_n ; + unsigned char *dest = good + j * x * req_comp; + + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=255; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=255; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=255; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = 255; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } + + STBI_FREE(data); + return good; +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) +// nothing +#else +static stbi__uint16 stbi__compute_y_16(int r, int g, int b) +{ + return (stbi__uint16) (((r*77) + (g*150) + (29*b)) >> 8); +} +#endif + +#if defined(STBI_NO_PNG) && defined(STBI_NO_PSD) +// nothing +#else +static stbi__uint16 *stbi__convert_format16(stbi__uint16 *data, int img_n, int req_comp, unsigned int x, unsigned int y) +{ + int i,j; + stbi__uint16 *good; + + if (req_comp == img_n) return data; + STBI_ASSERT(req_comp >= 1 && req_comp <= 4); + + good = (stbi__uint16 *) stbi__malloc(req_comp * x * y * 2); + if (good == NULL) { + STBI_FREE(data); + return (stbi__uint16 *) stbi__errpuc("outofmem", "Out of memory"); + } + + for (j=0; j < (int) y; ++j) { + stbi__uint16 *src = data + j * x * img_n ; + stbi__uint16 *dest = good + j * x * req_comp; + + #define STBI__COMBO(a,b) ((a)*8+(b)) + #define STBI__CASE(a,b) case STBI__COMBO(a,b): for(i=x-1; i >= 0; --i, src += a, dest += b) + // convert source image with img_n components to one with req_comp components; + // avoid switch per pixel, 
so use switch per scanline and massive macros + switch (STBI__COMBO(img_n, req_comp)) { + STBI__CASE(1,2) { dest[0]=src[0]; dest[1]=0xffff; } break; + STBI__CASE(1,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(1,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=0xffff; } break; + STBI__CASE(2,1) { dest[0]=src[0]; } break; + STBI__CASE(2,3) { dest[0]=dest[1]=dest[2]=src[0]; } break; + STBI__CASE(2,4) { dest[0]=dest[1]=dest[2]=src[0]; dest[3]=src[1]; } break; + STBI__CASE(3,4) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2];dest[3]=0xffff; } break; + STBI__CASE(3,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(3,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = 0xffff; } break; + STBI__CASE(4,1) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); } break; + STBI__CASE(4,2) { dest[0]=stbi__compute_y_16(src[0],src[1],src[2]); dest[1] = src[3]; } break; + STBI__CASE(4,3) { dest[0]=src[0];dest[1]=src[1];dest[2]=src[2]; } break; + default: STBI_ASSERT(0); STBI_FREE(data); STBI_FREE(good); return (stbi__uint16*) stbi__errpuc("unsupported", "Unsupported format conversion"); + } + #undef STBI__CASE + } + + STBI_FREE(data); + return good; +} +#endif + +#ifndef STBI_NO_LINEAR +static float *stbi__ldr_to_hdr(stbi_uc *data, int x, int y, int comp) +{ + int i,k,n; + float *output; + if (!data) return NULL; + output = (float *) stbi__malloc_mad4(x, y, comp, sizeof(float), 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpf("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + output[i*comp + k] = (float) (pow(data[i*comp+k]/255.0f, stbi__l2h_gamma) * stbi__l2h_scale); + } + } + if (n < comp) { + for (i=0; i < x*y; ++i) { + output[i*comp + n] = data[i*comp + n]/255.0f; + } + } + STBI_FREE(data); + return output; +} +#endif + +#ifndef STBI_NO_HDR +#define stbi__float2int(x) ((int) (x)) +static stbi_uc *stbi__hdr_to_ldr(float *data, int x, int y, int comp) +{ + int i,k,n; + stbi_uc *output; + if (!data) return NULL; + output = (stbi_uc *) stbi__malloc_mad3(x, y, comp, 0); + if (output == NULL) { STBI_FREE(data); return stbi__errpuc("outofmem", "Out of memory"); } + // compute number of non-alpha components + if (comp & 1) n = comp; else n = comp-1; + for (i=0; i < x*y; ++i) { + for (k=0; k < n; ++k) { + float z = (float) pow(data[i*comp+k]*stbi__h2l_scale_i, stbi__h2l_gamma_i) * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + if (k < comp) { + float z = data[i*comp+k] * 255 + 0.5f; + if (z < 0) z = 0; + if (z > 255) z = 255; + output[i*comp + k] = (stbi_uc) stbi__float2int(z); + } + } + STBI_FREE(data); + return output; +} +#endif + +////////////////////////////////////////////////////////////////////////////// +// +// "baseline" JPEG/JFIF decoder +// +// simple implementation +// - doesn't support delayed output of y-dimension +// - simple interface (only one output format: 8-bit interleaved RGB) +// - doesn't try to recover corrupt jpegs +// - doesn't allow partial loading, loading multiple at once +// - still fast on x86 (copying globals into locals doesn't help x86) +// - allocates lots of intermediate memory (full size of all components) +// - non-interleaved case requires this anyway +// - allows good upsampling (see next) +// high-quality +// - upsampled channels are bilinearly interpolated, even across blocks +// - quality integer IDCT derived from IJG's 
'slow' +// performance +// - fast huffman; reasonable integer IDCT +// - some SIMD kernels for common paths on targets with SSE2/NEON +// - uses a lot of intermediate memory, could cache poorly + +#ifndef STBI_NO_JPEG + +// huffman decoding acceleration +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + +typedef struct +{ + stbi_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + stbi__uint16 code[256]; + stbi_uc values[256]; + stbi_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' +} stbi__huffman; + +typedef struct +{ + stbi__context *s; + stbi__huffman huff_dc[4]; + stbi__huffman huff_ac[4]; + stbi__uint16 dequant[4][64]; + stbi__int16 fast_ac[4][1 << FAST_BITS]; + +// sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + +// definition of jpeg image component + struct + { + int id; + int h,v; + int tq; + int hd,ha; + int dc_pred; + + int x,y,w2,h2; + stbi_uc *data; + void *raw_data, *raw_coeff; + stbi_uc *linebuf; + short *coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + stbi__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + +// kernels + void (*idct_block_kernel)(stbi_uc *out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step); + stbi_uc *(*resample_row_hv_2_kernel)(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs); +} stbi__jpeg; + +static int stbi__build_huffman(stbi__huffman *h, int *count) +{ + int i,j,k=0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i=0; i < 16; ++i) { + for (j=0; j < count[i]; ++j) { + h->size[k++] = (stbi_uc) (i+1); + if(k >= 257) return stbi__err("bad size list","Corrupt JPEG"); + } + } + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for(j=1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (stbi__uint16) (code++); + if (code-1 >= (1u << j)) return stbi__err("bad code lengths","Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16-j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i=0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS-s); + int m = 1 << (FAST_BITS-s); + for (j=0; j < m; ++j) { + h->fast[c+j] = (stbi_uc) i; + } + } + } + return 1; +} + +// build a table that decodes both magnitude and value of small ACs in +// one go. 
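+// each nonzero fast_ac entry packs three fields into one int16 -- see the
+// ((k * 256) + (run * 16) + (len + magbits)) expression below:
+//   bits 8..15: the decoded, sign-extended coefficient value (-128..127)
+//   bits 4..7:  the length of the zero-run that precedes it
+//   bits 0..3:  total bits consumed (huffman code plus magnitude bits)
+// a zero entry means "not accelerated; take the general decode path".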
+static void stbi__build_fast_ac(stbi__int16 *fast_ac, stbi__huffman *h) +{ + int i; + for (i=0; i < (1 << FAST_BITS); ++i) { + stbi_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (stbi__int16) ((k * 256) + (run * 16) + (len + magbits)); + } + } + } +} + +static void stbi__grow_buffer_unsafe(stbi__jpeg *j) +{ + do { + unsigned int b = j->nomore ? 0 : stbi__get8(j->s); + if (b == 0xff) { + int c = stbi__get8(j->s); + while (c == 0xff) c = stbi__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char) c; + j->nomore = 1; + return; + } + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); +} + +// (1 << n) - 1 +static const stbi__uint32 stbi__bmask[17]={0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535}; + +// decode a jpeg huffman value from the bitstream +stbi_inline static int stbi__jpeg_huff_decode(stbi__jpeg *j, stbi__huffman *h) +{ + unsigned int temp; + int c,k; + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) + return -1; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k=FAST_BITS+1 ; ; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & stbi__bmask[k]) + h->delta[k]; + if(c < 0 || c >= 256) // symbol id out of bounds! 
+ return -1; + STBI_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & stbi__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; +} + +// bias[n] = (-1<<n) + 1 +static const int stbi__jbias[16] = {0,-1,-3,-7,-15,-31,-63,-127,-255,-511,-1023,-2047,-4095,-8191,-16383,-32767}; + +// combined JPEG 'receive' and JPEG 'extend', since baseline +// always extends everything it receives. +stbi_inline static int stbi__extend_receive(stbi__jpeg *j, int n) +{ + unsigned int k; + int sgn; + if (j->code_bits < n) stbi__grow_buffer_unsafe(j); + if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s instead of continuing + + sgn = j->code_buffer >> 31; // sign bit always in MSB; 0 if MSB clear (positive), 1 if MSB set (negative) + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k + (stbi__jbias[n] & (sgn - 1)); +} + +// get some unsigned bits +stbi_inline static int stbi__jpeg_get_bits(stbi__jpeg *j, int n) +{ + unsigned int k; + if (j->code_bits < n) stbi__grow_buffer_unsafe(j); + if (j->code_bits < n) return 0; // ran out of bits from stream, return 0s instead of continuing + k = stbi_lrot(j->code_buffer, n); + j->code_buffer = k & ~stbi__bmask[n]; + k &= stbi__bmask[n]; + j->code_bits -= n; + return k; +} + +stbi_inline static int stbi__jpeg_get_bit(stbi__jpeg *j) +{ + unsigned int k; + if (j->code_bits < 1) stbi__grow_buffer_unsafe(j); + if (j->code_bits < 1) return 0; // ran out of bits from stream, return 0s instead of continuing + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; +} + +// given a value that's at position X in the zigzag stream, +// where does it appear in the 8x8 matrix coded as row-major? +static const stbi_uc stbi__jpeg_dezigzag[64+15] = +{ + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, + 63, 63, 63, 63, 63, 63, 63 +}; + +// decode one 64-entry block-- +static int stbi__jpeg_decode_block(stbi__jpeg *j, short data[64], stbi__huffman *hdc, stbi__huffman *hac, stbi__int16 *fac, int b, stbi__uint16 *dequant) +{ + int diff,dc,k; + int t; + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("bad huffman code","Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data,0,64*sizeof(data[0])); + + diff = t ?
stbi__extend_receive(j, t) : 0; + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta","Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, dequant[0])) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * dequant[zig]); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) break; // end block + k += 16; + } else { + k += r; + // decode into unzigzag'd location + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * dequant[zig]); + } + } + } while (k < 64); + return 1; +} + +static int stbi__jpeg_decode_block_prog_dc(stbi__jpeg *j, short data[64], stbi__huffman *hdc, int b) +{ + int diff,dc; + int t; + if (j->spec_end != 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data,0,64*sizeof(data[0])); // 0 all the ac values now + t = stbi__jpeg_huff_decode(j, hdc); + if (t < 0 || t > 15) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? 
stbi__extend_receive(j, t) : 0; + + if (!stbi__addints_valid(j->img_comp[b].dc_pred, diff)) return stbi__err("bad delta", "Corrupt JPEG"); + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + if (!stbi__mul2shorts_valid(dc, 1 << j->succ_low)) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + data[0] = (short) (dc * (1 << j->succ_low)); + } else { + // refinement scan for DC coefficient + if (stbi__jpeg_get_bit(j)) + data[0] += (short) (1 << j->succ_low); + } + return 1; +} + +// @OPTIMIZE: store non-zigzagged during the decode passes, +// and only de-zigzag when dequantizing +static int stbi__jpeg_decode_block_prog_ac(stbi__jpeg *j, short data[64], stbi__huffman *hac, stbi__int16 *fac) +{ + int k; + if (j->spec_start == 0) return stbi__err("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } + + k = j->spec_start; + do { + unsigned int zig; + int c,r,s; + if (j->code_bits < 16) stbi__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS)-1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + if (s > j->code_bits) return stbi__err("bad huffman code", "Combined length longer than code bits available"); + j->code_buffer <<= s; + j->code_bits -= s; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) ((r >> 8) * (1 << shift)); + } else { + int rs = stbi__jpeg_huff_decode(j, hac); + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } else { + k += r; + zig = stbi__jpeg_dezigzag[k++]; + data[zig] = (short) (stbi__extend_receive(j,s) * (1 << shift)); + } + } + } while (k <= j->spec_end); + } else { + // refinement scan for these AC coefficients + + short bit = (short) (1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short *p = &data[stbi__jpeg_dezigzag[k]]; + if (*p != 0) + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } else { + k = j->spec_start; + do { + int r,s; + int rs = stbi__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + if (rs < 0) return stbi__err("bad huffman code","Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += stbi__jpeg_get_bits(j, r); + r = 64; // force end of block + } else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } else { + if (s != 1) return stbi__err("bad huffman code", "Corrupt JPEG"); + // sign bit + if (stbi__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } + + // advance by r + while (k <= j->spec_end) { + short *p = &data[stbi__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (stbi__jpeg_get_bit(j)) + if ((*p & bit)==0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } else { + if (r == 0) { + *p = (short) s; + break; + } + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; +} + +// take a -128..127 value and stbi__clamp it and convert to 0..255 +stbi_inline static stbi_uc stbi__clamp(int x) +{ + // trick to use a single test to catch both cases + if ((unsigned int) x > 255) { 
+ if (x < 0) return 0; + if (x > 255) return 255; + } + return (stbi_uc) x; +} + +#define stbi__f2f(x) ((int) (((x) * 4096 + 0.5))) +#define stbi__fsh(x) ((x) * 4096) + +// derived from jidctint -- DCT_ISLOW +#define STBI__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ + int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2+p3) * stbi__f2f(0.5411961f); \ + t2 = p1 + p3*stbi__f2f(-1.847759065f); \ + t3 = p1 + p2*stbi__f2f( 0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = stbi__fsh(p2+p3); \ + t1 = stbi__fsh(p2-p3); \ + x0 = t0+t3; \ + x3 = t0-t3; \ + x1 = t1+t2; \ + x2 = t1-t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0+t2; \ + p4 = t1+t3; \ + p1 = t0+t3; \ + p2 = t1+t2; \ + p5 = (p3+p4)*stbi__f2f( 1.175875602f); \ + t0 = t0*stbi__f2f( 0.298631336f); \ + t1 = t1*stbi__f2f( 2.053119869f); \ + t2 = t2*stbi__f2f( 3.072711026f); \ + t3 = t3*stbi__f2f( 1.501321110f); \ + p1 = p5 + p1*stbi__f2f(-0.899976223f); \ + p2 = p5 + p2*stbi__f2f(-2.562915447f); \ + p3 = p3*stbi__f2f(-1.961570560f); \ + p4 = p4*stbi__f2f(-0.390180644f); \ + t3 += p1+p4; \ + t2 += p2+p3; \ + t1 += p2+p4; \ + t0 += p1+p3; + +static void stbi__idct_block(stbi_uc *out, int out_stride, short data[64]) +{ + int i,val[64],*v=val; + stbi_uc *o; + short *d = data; + + // columns + for (i=0; i < 8; ++i,++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[ 8]==0 && d[16]==0 && d[24]==0 && d[32]==0 + && d[40]==0 && d[48]==0 && d[56]==0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0]*4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } else { + STBI__IDCT_1D(d[ 0],d[ 8],d[16],d[24],d[32],d[40],d[48],d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; x1 += 512; x2 += 512; x3 += 512; + v[ 0] = (x0+t3) >> 10; + v[56] = (x0-t3) >> 10; + v[ 8] = (x1+t2) >> 10; + v[48] = (x1-t2) >> 10; + v[16] = (x2+t1) >> 10; + v[40] = (x2-t1) >> 10; + v[24] = (x3+t0) >> 10; + v[32] = (x3-t0) >> 10; + } + } + + for (i=0, v=val, o=out; i < 8; ++i,v+=8,o+=out_stride) { + // no fast case since the first 1D IDCT spread components out + STBI__IDCT_1D(v[0],v[1],v[2],v[3],v[4],v[5],v[6],v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128<<17); + x1 += 65536 + (128<<17); + x2 += 65536 + (128<<17); + x3 += 65536 + (128<<17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = stbi__clamp((x0+t3) >> 17); + o[7] = stbi__clamp((x0-t3) >> 17); + o[1] = stbi__clamp((x1+t2) >> 17); + o[6] = stbi__clamp((x1-t2) >> 17); + o[2] = stbi__clamp((x2+t1) >> 17); + o[5] = stbi__clamp((x2-t1) >> 17); + o[3] = stbi__clamp((x3+t0) >> 17); + o[4] = stbi__clamp((x3-t0) >> 17); + } +} + +#ifdef STBI_SSE2 +// sse2 integer IDCT. not the fastest possible implementation but it +// produces bit-identical results to the generic C version so it's +// fully "transparent". 
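+// overall shape of the kernel: one 1-D IDCT pass down the columns, a 16-bit
+// 8x8 transpose, the same 1-D pass across the rows, then pack to unsigned
+// 8-bit and transpose back while storing -- see the dct_pass and
+// dct_interleave macros below.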
+static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + + // dot product constant: even elems=x, odd elems=y + #define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) + + // out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) + // out(1) = c1[even]*x + c1[odd]*y + #define dct_rot(out0,out1, x,y,c0,c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + + // out = in << 12 (in 16-bit, out 32-bit) + #define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) + + // wide add + #define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) + + // wide sub + #define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + + // butterfly a/b, add bias, then shift by "s" and pack + #define dct_bfly32o(out0, out1, a,b,bias,s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + + // 8-bit interleave step (for transposes) + #define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + + // 16-bit interleave step (for transposes) + #define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + + #define dct_pass(bias,shift) \ + { \ + /* even part */ \ + dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ + dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0,row7, x0,x7,bias,shift); \ + dct_bfly32o(row1,row6, x1,x6,bias,shift); \ + dct_bfly32o(row2,row5, x2,x5,bias,shift); \ + dct_bfly32o(row3,row4, x3,x4,bias,shift); \ + } + + __m128i rot0_0 = dct_const(stbi__f2f(0.5411961f), stbi__f2f(0.5411961f) + stbi__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(stbi__f2f(0.5411961f) + stbi__f2f( 0.765366865f), stbi__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(stbi__f2f(1.175875602f) + stbi__f2f(-0.899976223f), stbi__f2f(1.175875602f)); + __m128i rot1_1 = dct_const(stbi__f2f(1.175875602f), stbi__f2f(1.175875602f) + stbi__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(stbi__f2f(-1.961570560f) + stbi__f2f( 0.298631336f), 
stbi__f2f(-1.961570560f)); + __m128i rot2_1 = dct_const(stbi__f2f(-1.961570560f), stbi__f2f(-1.961570560f) + stbi__f2f( 3.072711026f)); + __m128i rot3_0 = dct_const(stbi__f2f(-0.390180644f) + stbi__f2f( 2.053119869f), stbi__f2f(-0.390180644f)); + __m128i rot3_1 = dct_const(stbi__f2f(-0.390180644f), stbi__f2f(-0.390180644f) + stbi__f2f( 1.501321110f)); + + // rounding biases in column/row passes, see stbi__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128<<17)); + + // load + row0 = _mm_load_si128((const __m128i *) (data + 0*8)); + row1 = _mm_load_si128((const __m128i *) (data + 1*8)); + row2 = _mm_load_si128((const __m128i *) (data + 2*8)); + row3 = _mm_load_si128((const __m128i *) (data + 3*8)); + row4 = _mm_load_si128((const __m128i *) (data + 4*8)); + row5 = _mm_load_si128((const __m128i *) (data + 5*8)); + row6 = _mm_load_si128((const __m128i *) (data + 6*8)); + row7 = _mm_load_si128((const __m128i *) (data + 7*8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i *) out, p0); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p2); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p1); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i *) out, p3); out += out_stride; + _mm_storel_epi64((__m128i *) out, _mm_shuffle_epi32(p3, 0x4e)); + } + +#undef dct_const +#undef dct_rot +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_interleave8 +#undef dct_interleave16 +#undef dct_pass +} + +#endif // STBI_SSE2 + +#ifdef STBI_NEON + +// NEON integer IDCT. should produce bit-identical +// results to the generic C version. 
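+// (Editor's note, not part of stb_image.) Bit-identical results are possible
+// because NEON's rounding narrowing shift vrshrn_n_s32(x, n) computes
+// (x + (1 << (n-1))) >> n, exactly the add-half-then-shift that the scalar and
+// SSE2 paths spell out (e.g. x0 += 512; ... >> 10 in the column pass). A
+// scalar model of the identity, guarded out of the build:
+#if 0
+#include <assert.h>
+static int rounding_shift(int x, int n) { return (x + (1 << (n - 1))) >> n; }
+int main(void)
+{
+   int x;
+   for (x = 0; x < (1 << 20); ++x)
+      assert(((x + 512) >> 10) == rounding_shift(x, 10));
+   return 0;
+}
+#endif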
+static void stbi__idct_simd(stbi_uc *out, int out_stride, short data[64]) +{ + int16x8_t row0, row1, row2, row3, row4, row5, row6, row7; + + int16x4_t rot0_0 = vdup_n_s16(stbi__f2f(0.5411961f)); + int16x4_t rot0_1 = vdup_n_s16(stbi__f2f(-1.847759065f)); + int16x4_t rot0_2 = vdup_n_s16(stbi__f2f( 0.765366865f)); + int16x4_t rot1_0 = vdup_n_s16(stbi__f2f( 1.175875602f)); + int16x4_t rot1_1 = vdup_n_s16(stbi__f2f(-0.899976223f)); + int16x4_t rot1_2 = vdup_n_s16(stbi__f2f(-2.562915447f)); + int16x4_t rot2_0 = vdup_n_s16(stbi__f2f(-1.961570560f)); + int16x4_t rot2_1 = vdup_n_s16(stbi__f2f(-0.390180644f)); + int16x4_t rot3_0 = vdup_n_s16(stbi__f2f( 0.298631336f)); + int16x4_t rot3_1 = vdup_n_s16(stbi__f2f( 2.053119869f)); + int16x4_t rot3_2 = vdup_n_s16(stbi__f2f( 3.072711026f)); + int16x4_t rot3_3 = vdup_n_s16(stbi__f2f( 1.501321110f)); + +#define dct_long_mul(out, inq, coeff) \ + int32x4_t out##_l = vmull_s16(vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmull_s16(vget_high_s16(inq), coeff) + +#define dct_long_mac(out, acc, inq, coeff) \ + int32x4_t out##_l = vmlal_s16(acc##_l, vget_low_s16(inq), coeff); \ + int32x4_t out##_h = vmlal_s16(acc##_h, vget_high_s16(inq), coeff) + +#define dct_widen(out, inq) \ + int32x4_t out##_l = vshll_n_s16(vget_low_s16(inq), 12); \ + int32x4_t out##_h = vshll_n_s16(vget_high_s16(inq), 12) + +// wide add +#define dct_wadd(out, a, b) \ + int32x4_t out##_l = vaddq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vaddq_s32(a##_h, b##_h) + +// wide sub +#define dct_wsub(out, a, b) \ + int32x4_t out##_l = vsubq_s32(a##_l, b##_l); \ + int32x4_t out##_h = vsubq_s32(a##_h, b##_h) + +// butterfly a/b, then shift using "shiftop" by "s" and pack +#define dct_bfly32o(out0,out1, a,b,shiftop,s) \ + { \ + dct_wadd(sum, a, b); \ + dct_wsub(dif, a, b); \ + out0 = vcombine_s16(shiftop(sum_l, s), shiftop(sum_h, s)); \ + out1 = vcombine_s16(shiftop(dif_l, s), shiftop(dif_h, s)); \ + } + +#define dct_pass(shiftop, shift) \ + { \ + /* even part */ \ + int16x8_t sum26 = vaddq_s16(row2, row6); \ + dct_long_mul(p1e, sum26, rot0_0); \ + dct_long_mac(t2e, p1e, row6, rot0_1); \ + dct_long_mac(t3e, p1e, row2, rot0_2); \ + int16x8_t sum04 = vaddq_s16(row0, row4); \ + int16x8_t dif04 = vsubq_s16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + int16x8_t sum15 = vaddq_s16(row1, row5); \ + int16x8_t sum17 = vaddq_s16(row1, row7); \ + int16x8_t sum35 = vaddq_s16(row3, row5); \ + int16x8_t sum37 = vaddq_s16(row3, row7); \ + int16x8_t sumodd = vaddq_s16(sum17, sum35); \ + dct_long_mul(p5o, sumodd, rot1_0); \ + dct_long_mac(p1o, p5o, sum17, rot1_1); \ + dct_long_mac(p2o, p5o, sum35, rot1_2); \ + dct_long_mul(p3o, sum37, rot2_0); \ + dct_long_mul(p4o, sum15, rot2_1); \ + dct_wadd(sump13o, p1o, p3o); \ + dct_wadd(sump24o, p2o, p4o); \ + dct_wadd(sump23o, p2o, p3o); \ + dct_wadd(sump14o, p1o, p4o); \ + dct_long_mac(x4, sump13o, row7, rot3_0); \ + dct_long_mac(x5, sump24o, row5, rot3_1); \ + dct_long_mac(x6, sump23o, row3, rot3_2); \ + dct_long_mac(x7, sump14o, row1, rot3_3); \ + dct_bfly32o(row0,row7, x0,x7,shiftop,shift); \ + dct_bfly32o(row1,row6, x1,x6,shiftop,shift); \ + dct_bfly32o(row2,row5, x2,x5,shiftop,shift); \ + dct_bfly32o(row3,row4, x3,x4,shiftop,shift); \ + } + + // load + row0 = vld1q_s16(data + 0*8); + row1 = vld1q_s16(data + 1*8); + row2 = vld1q_s16(data + 2*8); + row3 = vld1q_s16(data + 3*8); + row4 = vld1q_s16(data + 4*8); + row5 = 
vld1q_s16(data + 5*8); + row6 = vld1q_s16(data + 6*8); + row7 = vld1q_s16(data + 7*8); + + // add DC bias + row0 = vaddq_s16(row0, vsetq_lane_s16(1024, vdupq_n_s16(0), 0)); + + // column pass + dct_pass(vrshrn_n_s32, 10); + + // 16bit 8x8 transpose + { +// these three map to a single VTRN.16, VTRN.32, and VSWP, respectively. +// whether compilers actually get this is another story, sadly. +#define dct_trn16(x, y) { int16x8x2_t t = vtrnq_s16(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn32(x, y) { int32x4x2_t t = vtrnq_s32(vreinterpretq_s32_s16(x), vreinterpretq_s32_s16(y)); x = vreinterpretq_s16_s32(t.val[0]); y = vreinterpretq_s16_s32(t.val[1]); } +#define dct_trn64(x, y) { int16x8_t x0 = x; int16x8_t y0 = y; x = vcombine_s16(vget_low_s16(x0), vget_low_s16(y0)); y = vcombine_s16(vget_high_s16(x0), vget_high_s16(y0)); } + + // pass 1 + dct_trn16(row0, row1); // a0b0a2b2a4b4a6b6 + dct_trn16(row2, row3); + dct_trn16(row4, row5); + dct_trn16(row6, row7); + + // pass 2 + dct_trn32(row0, row2); // a0b0c0d0a4b4c4d4 + dct_trn32(row1, row3); + dct_trn32(row4, row6); + dct_trn32(row5, row7); + + // pass 3 + dct_trn64(row0, row4); // a0b0c0d0e0f0g0h0 + dct_trn64(row1, row5); + dct_trn64(row2, row6); + dct_trn64(row3, row7); + +#undef dct_trn16 +#undef dct_trn32 +#undef dct_trn64 + } + + // row pass + // vrshrn_n_s32 only supports shifts up to 16, we need + // 17. so do a non-rounding shift of 16 first then follow + // up with a rounding shift by 1. + dct_pass(vshrn_n_s32, 16); + + { + // pack and round + uint8x8_t p0 = vqrshrun_n_s16(row0, 1); + uint8x8_t p1 = vqrshrun_n_s16(row1, 1); + uint8x8_t p2 = vqrshrun_n_s16(row2, 1); + uint8x8_t p3 = vqrshrun_n_s16(row3, 1); + uint8x8_t p4 = vqrshrun_n_s16(row4, 1); + uint8x8_t p5 = vqrshrun_n_s16(row5, 1); + uint8x8_t p6 = vqrshrun_n_s16(row6, 1); + uint8x8_t p7 = vqrshrun_n_s16(row7, 1); + + // again, these can translate into one instruction, but often don't. +#define dct_trn8_8(x, y) { uint8x8x2_t t = vtrn_u8(x, y); x = t.val[0]; y = t.val[1]; } +#define dct_trn8_16(x, y) { uint16x4x2_t t = vtrn_u16(vreinterpret_u16_u8(x), vreinterpret_u16_u8(y)); x = vreinterpret_u8_u16(t.val[0]); y = vreinterpret_u8_u16(t.val[1]); } +#define dct_trn8_32(x, y) { uint32x2x2_t t = vtrn_u32(vreinterpret_u32_u8(x), vreinterpret_u32_u8(y)); x = vreinterpret_u8_u32(t.val[0]); y = vreinterpret_u8_u32(t.val[1]); } + + // sadly can't use interleaved stores here since we only write + // 8 bytes to each scan line! + + // 8x8 8-bit transpose pass 1 + dct_trn8_8(p0, p1); + dct_trn8_8(p2, p3); + dct_trn8_8(p4, p5); + dct_trn8_8(p6, p7); + + // pass 2 + dct_trn8_16(p0, p2); + dct_trn8_16(p1, p3); + dct_trn8_16(p4, p6); + dct_trn8_16(p5, p7); + + // pass 3 + dct_trn8_32(p0, p4); + dct_trn8_32(p1, p5); + dct_trn8_32(p2, p6); + dct_trn8_32(p3, p7); + + // store + vst1_u8(out, p0); out += out_stride; + vst1_u8(out, p1); out += out_stride; + vst1_u8(out, p2); out += out_stride; + vst1_u8(out, p3); out += out_stride; + vst1_u8(out, p4); out += out_stride; + vst1_u8(out, p5); out += out_stride; + vst1_u8(out, p6); out += out_stride; + vst1_u8(out, p7); + +#undef dct_trn8_8 +#undef dct_trn8_16 +#undef dct_trn8_32 + } + +#undef dct_long_mul +#undef dct_long_mac +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_pass +} + +#endif // STBI_NEON + +#define STBI__MARKER_none 0xff +// if there's a pending marker from the entropy stream, return that +// otherwise, fetch from the stream and get a marker. 
if there's no
+// marker, return 0xff, which is never a valid marker value
+static stbi_uc stbi__get_marker(stbi__jpeg *j)
+{
+   stbi_uc x;
+   if (j->marker != STBI__MARKER_none) { x = j->marker; j->marker = STBI__MARKER_none; return x; }
+   x = stbi__get8(j->s);
+   if (x != 0xff) return STBI__MARKER_none;
+   while (x == 0xff)
+      x = stbi__get8(j->s); // consume repeated 0xff fill bytes
+   return x;
+}
+
+// in each scan, we'll have scan_n components, and the order
+// of the components is specified by order[]
+#define STBI__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7)
+
+// after a restart interval, stbi__jpeg_reset the entropy decoder and
+// the dc prediction
+static void stbi__jpeg_reset(stbi__jpeg *j)
+{
+   j->code_bits = 0;
+   j->code_buffer = 0;
+   j->nomore = 0;
+   j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0;
+   j->marker = STBI__MARKER_none;
+   j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff;
+   j->eob_run = 0;
+   // no more than 1<<31 MCUs if no restart_interval? that's plenty safe,
+   // since we don't even allow 1<<30 pixels
+}
+
+static int stbi__parse_entropy_coded_data(stbi__jpeg *z)
+{
+   stbi__jpeg_reset(z);
+   if (!z->progressive) {
+      if (z->scan_n == 1) {
+         int i,j;
+         STBI_SIMD_ALIGN(short, data[64]);
+         int n = z->order[0];
+         // non-interleaved data, we just need to process one block at a time,
+         // in trivial scanline order
+         // number of blocks to do just depends on how many actual "pixels" this
+         // component has, independent of interleaved MCU blocking and such
+         int w = (z->img_comp[n].x+7) >> 3;
+         int h = (z->img_comp[n].y+7) >> 3;
+         for (j=0; j < h; ++j) {
+            for (i=0; i < w; ++i) {
+               int ha = z->img_comp[n].ha;
+               if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0;
+               z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data);
+               // every data block is an MCU, so count down the restart interval
+               if (--z->todo <= 0) {
+                  if (z->code_bits < 24) stbi__grow_buffer_unsafe(z);
+                  // if it's NOT a restart, then just bail, so we get corrupt data
+                  // rather than no data
+                  if (!STBI__RESTART(z->marker)) return 1;
+                  stbi__jpeg_reset(z);
+               }
+            }
+         }
+         return 1;
+      } else { // interleaved
+         int i,j,k,x,y;
+         STBI_SIMD_ALIGN(short, data[64]);
+         for (j=0; j < z->img_mcu_y; ++j) {
+            for (i=0; i < z->img_mcu_x; ++i) {
+               // scan an interleaved mcu...
process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x)*8; + int y2 = (j*z->img_comp[n].v + y)*8; + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block(z, data, z->huff_dc+z->img_comp[n].hd, z->huff_ac+ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*y2+x2, z->img_comp[n].w2, data); + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } else { + if (z->scan_n == 1) { + int i,j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } else { + int ha = z->img_comp[n].ha; + if (!stbi__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } else { // interleaved + int i,j,k,x,y; + for (j=0; j < z->img_mcu_y; ++j) { + for (i=0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... 
process scan_n components in order + for (k=0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y=0; y < z->img_comp[n].v; ++y) { + for (x=0; x < z->img_comp[n].h; ++x) { + int x2 = (i*z->img_comp[n].h + x); + int y2 = (j*z->img_comp[n].v + y); + short *data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!stbi__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) stbi__grow_buffer_unsafe(z); + if (!STBI__RESTART(z->marker)) return 1; + stbi__jpeg_reset(z); + } + } + } + return 1; + } + } +} + +static void stbi__jpeg_dequantize(short *data, stbi__uint16 *dequant) +{ + int i; + for (i=0; i < 64; ++i) + data[i] *= dequant[i]; +} + +static void stbi__jpeg_finish(stbi__jpeg *z) +{ + if (z->progressive) { + // dequantize and idct the data + int i,j,n; + for (n=0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x+7) >> 3; + int h = (z->img_comp[n].y+7) >> 3; + for (j=0; j < h; ++j) { + for (i=0; i < w; ++i) { + short *data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + stbi__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data+z->img_comp[n].w2*j*8+i*8, z->img_comp[n].w2, data); + } + } + } + } +} + +static int stbi__process_marker(stbi__jpeg *z, int m) +{ + int L; + switch (m) { + case STBI__MARKER_none: // no marker found + return stbi__err("expected marker","Corrupt JPEG"); + + case 0xDD: // DRI - specify restart interval + if (stbi__get16be(z->s) != 4) return stbi__err("bad DRI len","Corrupt JPEG"); + z->restart_interval = stbi__get16be(z->s); + return 1; + + case 0xDB: // DQT - define quantization table + L = stbi__get16be(z->s)-2; + while (L > 0) { + int q = stbi__get8(z->s); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15,i; + if (p != 0 && p != 1) return stbi__err("bad DQT type","Corrupt JPEG"); + if (t > 3) return stbi__err("bad DQT table","Corrupt JPEG"); + + for (i=0; i < 64; ++i) + z->dequant[t][stbi__jpeg_dezigzag[i]] = (stbi__uint16)(sixteen ? stbi__get16be(z->s) : stbi__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L==0; + + case 0xC4: // DHT - define huffman table + L = stbi__get16be(z->s)-2; + while (L > 0) { + stbi_uc *v; + int sizes[16],i,n=0; + int q = stbi__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) return stbi__err("bad DHT header","Corrupt JPEG"); + for (i=0; i < 16; ++i) { + sizes[i] = stbi__get8(z->s); + n += sizes[i]; + } + if(n > 256) return stbi__err("bad DHT header","Corrupt JPEG"); // Loop over i < n would write past end of values! 
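+            // (Editor's note.) A DHT table payload is one class/id byte, 16
+            // bytes of per-length code counts (read into sizes[] above), then
+            // the n symbol values -- hence L -= 17 here and L -= n once the
+            // values have been consumed.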
+ L -= 17; + if (tc == 0) { + if (!stbi__build_huffman(z->huff_dc+th, sizes)) return 0; + v = z->huff_dc[th].values; + } else { + if (!stbi__build_huffman(z->huff_ac+th, sizes)) return 0; + v = z->huff_ac[th].values; + } + for (i=0; i < n; ++i) + v[i] = stbi__get8(z->s); + if (tc != 0) + stbi__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L==0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = stbi__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return stbi__err("bad COM len","Corrupt JPEG"); + else + return stbi__err("bad APP len","Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = {'J','F','I','F','\0'}; + int ok = 1; + int i; + for (i=0; i < 5; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = {'A','d','o','b','e','\0'}; + int ok = 1; + int i; + for (i=0; i < 6; ++i) + if (stbi__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + stbi__get8(z->s); // version + stbi__get16be(z->s); // flags0 + stbi__get16be(z->s); // flags1 + z->app14_color_transform = stbi__get8(z->s); // color transform + L -= 6; + } + } + + stbi__skip(z->s, L); + return 1; + } + + return stbi__err("unknown marker","Corrupt JPEG"); +} + +// after we see SOS +static int stbi__process_scan_header(stbi__jpeg *z) +{ + int i; + int Ls = stbi__get16be(z->s); + z->scan_n = stbi__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int) z->s->img_n) return stbi__err("bad SOS component count","Corrupt JPEG"); + if (Ls != 6+2*z->scan_n) return stbi__err("bad SOS len","Corrupt JPEG"); + for (i=0; i < z->scan_n; ++i) { + int id = stbi__get8(z->s), which; + int q = stbi__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) return 0; // no match + z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return stbi__err("bad DC huff","Corrupt JPEG"); + z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return stbi__err("bad AC huff","Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = stbi__get8(z->s); + z->spec_end = stbi__get8(z->s); // should be 63, but might be 0 + aa = stbi__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return stbi__err("bad SOS", "Corrupt JPEG"); + } else { + if (z->spec_start != 0) return stbi__err("bad SOS","Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) return stbi__err("bad SOS","Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; +} + +static int stbi__free_jpeg_components(stbi__jpeg *z, int ncomp, int why) +{ + int i; + for (i=0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + STBI_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; + z->img_comp[i].data = NULL; + } + if (z->img_comp[i].raw_coeff) { + STBI_FREE(z->img_comp[i].raw_coeff); + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + STBI_FREE(z->img_comp[i].linebuf); + z->img_comp[i].linebuf = NULL; + } + } + return why; +} + +static int stbi__process_frame_header(stbi__jpeg *z, int scan) +{ + stbi__context *s = z->s; + int Lf,p,i,q, h_max=1,v_max=1,c; + Lf = stbi__get16be(s); if (Lf < 11) 
return stbi__err("bad SOF len","Corrupt JPEG"); // JPEG + p = stbi__get8(s); if (p != 8) return stbi__err("only 8-bit","JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = stbi__get16be(s); if (s->img_y == 0) return stbi__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = stbi__get16be(s); if (s->img_x == 0) return stbi__err("0 width","Corrupt JPEG"); // JPEG requires + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + c = stbi__get8(s); + if (c != 3 && c != 1 && c != 4) return stbi__err("bad component count","Corrupt JPEG"); + s->img_n = c; + for (i=0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8+3*s->img_n) return stbi__err("bad SOF len","Corrupt JPEG"); + + z->rgb = 0; + for (i=0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = { 'R', 'G', 'B' }; + z->img_comp[i].id = stbi__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = stbi__get8(s); + z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return stbi__err("bad H","Corrupt JPEG"); + z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return stbi__err("bad V","Corrupt JPEG"); + z->img_comp[i].tq = stbi__get8(s); if (z->img_comp[i].tq > 3) return stbi__err("bad TQ","Corrupt JPEG"); + } + + if (scan != STBI__SCAN_load) return 1; + + if (!stbi__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return stbi__err("too large", "Image too large to decode"); + + for (i=0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; + } + + // check that plane subsampling factors are integer ratios; our resamplers can't deal with fractional ratios + // and I've never seen a non-corrupted JPEG file actually use them + for (i=0; i < s->img_n; ++i) { + if (h_max % z->img_comp[i].h != 0) return stbi__err("bad H","Corrupt JPEG"); + if (v_max % z->img_comp[i].v != 0) return stbi__err("bad V","Corrupt JPEG"); + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w-1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h-1) / z->img_mcu_h; + + for (i=0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max-1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max-1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. 
a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) + // so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = stbi__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = (stbi_uc*) (((size_t) z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = stbi__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return stbi__free_jpeg_components(z, i+1, stbi__err("outofmem", "Out of memory")); + z->img_comp[i].coeff = (short*) (((size_t) z->img_comp[i].raw_coeff + 15) & ~15); + } + } + + return 1; +} + +// use comparisons since in some cases we handle more than one case (e.g. SOF) +#define stbi__DNL(x) ((x) == 0xdc) +#define stbi__SOI(x) ((x) == 0xd8) +#define stbi__EOI(x) ((x) == 0xd9) +#define stbi__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define stbi__SOS(x) ((x) == 0xda) + +#define stbi__SOF_progressive(x) ((x) == 0xc2) + +static int stbi__decode_jpeg_header(stbi__jpeg *z, int scan) +{ + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = STBI__MARKER_none; // initialize cached marker to empty + m = stbi__get_marker(z); + if (!stbi__SOI(m)) return stbi__err("no SOI","Corrupt JPEG"); + if (scan == STBI__SCAN_type) return 1; + m = stbi__get_marker(z); + while (!stbi__SOF(m)) { + if (!stbi__process_marker(z,m)) return 0; + m = stbi__get_marker(z); + while (m == STBI__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (stbi__at_eof(z->s)) return stbi__err("no SOF", "Corrupt JPEG"); + m = stbi__get_marker(z); + } + } + z->progressive = stbi__SOF_progressive(m); + if (!stbi__process_frame_header(z, scan)) return 0; + return 1; +} + +static stbi_uc stbi__skip_jpeg_junk_at_end(stbi__jpeg *j) +{ + // some JPEGs have junk at end, skip over it but if we find what looks + // like a valid marker, resume there + while (!stbi__at_eof(j->s)) { + stbi_uc x = stbi__get8(j->s); + while (x == 0xff) { // might be a marker + if (stbi__at_eof(j->s)) return STBI__MARKER_none; + x = stbi__get8(j->s); + if (x != 0x00 && x != 0xff) { + // not a stuffed zero or lead-in to another marker, looks + // like an actual marker, return it + return x; + } + // stuffed zero has x=0 now which ends the loop, meaning we go + // back to regular scan loop. + // repeated 0xff keeps trying to read the next byte of the marker. 
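+         // (Editor's note.) Concretely: "FF 00" is a stuffed zero -- an
+         // escaped 0xFF data byte -- so we drop back to the outer scan, while
+         // a run such as "FF FF FF D9" skips the 0xFF fill bytes and returns
+         // the marker 0xD9.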
+ } + } + return STBI__MARKER_none; +} + +// decode image to YCbCr format +static int stbi__decode_jpeg_image(stbi__jpeg *j) +{ + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!stbi__decode_jpeg_header(j, STBI__SCAN_load)) return 0; + m = stbi__get_marker(j); + while (!stbi__EOI(m)) { + if (stbi__SOS(m)) { + if (!stbi__process_scan_header(j)) return 0; + if (!stbi__parse_entropy_coded_data(j)) return 0; + if (j->marker == STBI__MARKER_none ) { + j->marker = stbi__skip_jpeg_junk_at_end(j); + // if we reach eof without hitting a marker, stbi__get_marker() below will fail and we'll eventually return 0 + } + m = stbi__get_marker(j); + if (STBI__RESTART(m)) + m = stbi__get_marker(j); + } else if (stbi__DNL(m)) { + int Ld = stbi__get16be(j->s); + stbi__uint32 NL = stbi__get16be(j->s); + if (Ld != 4) return stbi__err("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) return stbi__err("bad DNL height", "Corrupt JPEG"); + m = stbi__get_marker(j); + } else { + if (!stbi__process_marker(j, m)) return 1; + m = stbi__get_marker(j); + } + } + if (j->progressive) + stbi__jpeg_finish(j); + return 1; +} + +// static jfif-centered resampling (across block boundaries) + +typedef stbi_uc *(*resample_row_func)(stbi_uc *out, stbi_uc *in0, stbi_uc *in1, + int w, int hs); + +#define stbi__div4(x) ((stbi_uc) ((x) >> 2)) + +static stbi_uc *resample_row_1(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + STBI_NOTUSED(out); + STBI_NOTUSED(in_far); + STBI_NOTUSED(w); + STBI_NOTUSED(hs); + return in_near; +} + +static stbi_uc* stbi__resample_row_v_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples vertically for every one in input + int i; + STBI_NOTUSED(hs); + for (i=0; i < w; ++i) + out[i] = stbi__div4(3*in_near[i] + in_far[i] + 2); + return out; +} + +static stbi_uc* stbi__resample_row_h_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate two samples horizontally for every one in input + int i; + stbi_uc *input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = stbi__div4(input[0]*3 + input[1] + 2); + for (i=1; i < w-1; ++i) { + int n = 3*input[i]+2; + out[i*2+0] = stbi__div4(n+input[i-1]); + out[i*2+1] = stbi__div4(n+input[i+1]); + } + out[i*2+0] = stbi__div4(input[w-2]*3 + input[w-1] + 2); + out[i*2+1] = input[w-1]; + + STBI_NOTUSED(in_far); + STBI_NOTUSED(hs); + + return out; +} + +#define stbi__div16(x) ((stbi_uc) ((x) >> 4)) + +static stbi_uc *stbi__resample_row_hv_2(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i,t0,t1; + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3*in_near[0] + in_far[0]; + out[0] = stbi__div4(t1+2); + for (i=1; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); + + STBI_NOTUSED(hs); + + return out; +} + +#if defined(STBI_SSE2) || defined(STBI_NEON) +static stbi_uc *stbi__resample_row_hv_2_simd(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // need to generate 2x2 samples for every one in input + int i=0,t0,t1; + + if (w == 1) { + out[0] = out[1] = stbi__div4(3*in_near[0] + in_far[0] + 
2); + return out; + } + + t1 = 3*in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w-1) & ~7); i += 8) { +#if defined(STBI_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i *) (in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i *) (in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = _mm_insert_epi16(nxt0, 3*in_near[i+8] + in_far[i+8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i *) (out + i*2), outv); +#elif defined(STBI_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = vsetq_lane_s16(3*in_near[i+8] + in_far[i+8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. 
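+      // (Editor's note.) i.e. 3*cur + prev == 4*cur + (prev - cur), so one
+      // shift (vshlq_n_s16 below) and one subtract replace the multiply, and
+      // the shifted term is shared by both phases; the +8 rounding bias and
+      // the >>4 descale are folded into vqrshrun_n_s16(..., 4) at the end.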
+ int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i*2, o); +#endif + + // "previous" value for next iter + t1 = 3*in_near[i+7] + in_far[i+7]; + } + + t0 = t1; + t1 = 3*in_near[i] + in_far[i]; + out[i*2] = stbi__div16(3*t1 + t0 + 8); + + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3*in_near[i]+in_far[i]; + out[i*2-1] = stbi__div16(3*t0 + t1 + 8); + out[i*2 ] = stbi__div16(3*t1 + t0 + 8); + } + out[w*2-1] = stbi__div4(t1+2); + + STBI_NOTUSED(hs); + + return out; +} +#endif + +static stbi_uc *stbi__resample_row_generic(stbi_uc *out, stbi_uc *in_near, stbi_uc *in_far, int w, int hs) +{ + // resample with nearest-neighbor + int i,j; + STBI_NOTUSED(in_far); + for (i=0; i < w; ++i) + for (j=0; j < hs; ++j) + out[i*hs+j] = in_near[i]; + return out; +} + +// this is a reduced-precision calculation of YCbCr-to-RGB introduced +// to make sure the code produces the same results in both SIMD and scalar +#define stbi__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) +static void stbi__YCbCr_to_RGB_row(stbi_uc *out, const stbi_uc *y, const stbi_uc *pcb, const stbi_uc *pcr, int count, int step) +{ + int i; + for (i=0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + (cr*-stbi__float2fixed(0.71414f)) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } +} + +#if defined(STBI_SSE2) || defined(STBI_NEON) +static void stbi__YCbCr_to_RGB_simd(stbi_uc *out, stbi_uc const *y, stbi_uc const *pcb, stbi_uc const *pcr, int count, int step) +{ + int i = 0; + +#ifdef STBI_SSE2 + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. 
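+      // (Editor's note on the fixed-point layout.) The constants are
+      // round(K*4096) in int16. Unpacking y against the 128 bias byte gives
+      // (y<<8)+128; after >>4 that is 16*y + 8, i.e. 4 fraction bits with the
+      // rounding half built in. Chroma is unpacked as (c<<8), and
+      // _mm_mulhi_epi16 keeps the high 16 bits of the 32-bit product, so
+      // round(K*4096)*(c<<8) >> 16 is approximately 16*K*c -- again 4 fraction
+      // bits, dropped by the final >>4.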
+ __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16( (short) ( 1.40200f*4096.0f+0.5f)); + __m128i cr_const1 = _mm_set1_epi16( - (short) ( 0.71414f*4096.0f+0.5f)); + __m128i cb_const0 = _mm_set1_epi16( - (short) ( 0.34414f*4096.0f+0.5f)); + __m128i cb_const1 = _mm_set1_epi16( (short) ( 1.77200f*4096.0f+0.5f)); + __m128i y_bias = _mm_set1_epi8((char) (unsigned char) 128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i+7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i *) (y+i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i *) (pcr+i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i *) (pcb+i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i *) (out + 0), o0); + _mm_storeu_si128((__m128i *) (out + 16), o1); + out += 32; + } + } +#endif + +#ifdef STBI_NEON + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. 
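+      // (Editor's note.) Same fixed-point scheme, via NEON's doubling
+      // multiply: vqdmulhq_s16(a,b) computes (2*a*b) >> 16, so with chroma
+      // widened to c<<7 and constants round(K*4096), the product
+      // (2 * (c<<7) * round(K*4096)) >> 16 is approximately 16*K*c; luma is
+      // widened to y<<4. The +8 round and >>4 descale happen in
+      // vqrshrun_n_s16(..., 4) when converting back to bytes.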
+ uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16( (short) ( 1.40200f*4096.0f+0.5f)); + int16x8_t cr_const1 = vdupq_n_s16( - (short) ( 0.71414f*4096.0f+0.5f)); + int16x8_t cb_const0 = vdupq_n_s16( - (short) ( 0.34414f*4096.0f+0.5f)); + int16x8_t cb_const1 = vdupq_n_s16( (short) ( 1.77200f*4096.0f+0.5f)); + + for (; i+7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8*4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1<<19); // rounding + int r,g,b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr* stbi__float2fixed(1.40200f); + g = y_fixed + cr*-stbi__float2fixed(0.71414f) + ((cb*-stbi__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb* stbi__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned) r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned) g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned) b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (stbi_uc)r; + out[1] = (stbi_uc)g; + out[2] = (stbi_uc)b; + out[3] = 255; + out += step; + } +} +#endif + +// set up the kernels +static void stbi__setup_jpeg(stbi__jpeg *j) +{ + j->idct_block_kernel = stbi__idct_block; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2; + +#ifdef STBI_SSE2 + if (stbi__sse2_available()) { + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; + } +#endif + +#ifdef STBI_NEON + j->idct_block_kernel = stbi__idct_simd; + j->YCbCr_to_RGB_kernel = stbi__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = stbi__resample_row_hv_2_simd; +#endif +} + +// clean up the temporary component buffers +static void stbi__cleanup_jpeg(stbi__jpeg *j) +{ + stbi__free_jpeg_components(j, j->s->img_n, 0); +} + +typedef struct +{ + resample_row_func resample; + stbi_uc *line0,*line1; + int hs,vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on +} stbi__resample; + +// fast 0..255 * 0..255 => 0..255 rounded multiplication +static stbi_uc stbi__blinn_8x8(stbi_uc x, stbi_uc y) +{ + unsigned int t = x*y + 128; + return (stbi_uc) ((t + (t >>8)) >> 8); +} + +static stbi_uc *load_jpeg_image(stbi__jpeg *z, int *out_x, int *out_y, int *comp, int req_comp) +{ + int n, decode_n, is_rgb; + z->s->img_n = 0; // make 
stbi__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!stbi__decode_jpeg_image(z)) { stbi__cleanup_jpeg(z); return NULL; } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // nothing to do if no components requested; check this now to avoid + // accessing uninitialized coutput[0] later + if (decode_n <= 0) { stbi__cleanup_jpeg(z); return NULL; } + + // resample and color-convert + { + int k; + unsigned int i,j; + stbi_uc *output; + stbi_uc *coutput[4] = { NULL, NULL, NULL, NULL }; + + stbi__resample res_comp[4]; + + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (stbi_uc *) stbi__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs-1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) r->resample = stbi__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) r->resample = stbi__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; + else r->resample = stbi__resample_row_generic; + } + + // can't error after this so, this is safe + output = (stbi_uc *) stbi__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { stbi__cleanup_jpeg(z); return stbi__errpuc("outofmem", "Out of memory"); } + + // now go ahead and resample + for (j=0; j < z->s->img_y; ++j) { + stbi_uc *out = output + n * z->s->img_x * j; + for (k=0; k < decode_n; ++k) { + stbi__resample *r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = r->resample(z->img_comp[k].linebuf, + y_bot ? r->line1 : r->line0, + y_bot ? 
r->line0 : r->line1, + r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + stbi_uc *y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(coutput[0][i], m); + out[1] = stbi__blinn_8x8(coutput[1][i], m); + out[2] = stbi__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; + } + } else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + out[0] = stbi__blinn_8x8(255 - out[0], m); + out[1] = stbi__blinn_8x8(255 - out[1], m); + out[2] = stbi__blinn_8x8(255 - out[2], m); + out += n; + } + } else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } else + for (i=0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } else { + if (is_rgb) { + if (n == 1) + for (i=0; i < z->s->img_x; ++i) + *out++ = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i=0; i < z->s->img_x; ++i, out += 2) { + out[0] = stbi__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i=0; i < z->s->img_x; ++i) { + stbi_uc m = coutput[3][i]; + stbi_uc r = stbi__blinn_8x8(coutput[0][i], m); + stbi_uc g = stbi__blinn_8x8(coutput[1][i], m); + stbi_uc b = stbi__blinn_8x8(coutput[2][i], m); + out[0] = stbi__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i=0; i < z->s->img_x; ++i) { + out[0] = stbi__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } else { + stbi_uc *y = coutput[0]; + if (n == 1) + for (i=0; i < z->s->img_x; ++i) out[i] = y[i]; + else + for (i=0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } + } + } + stbi__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) *comp = z->s->img_n >= 3 ? 
3 : 1; // report original components, not output + return output; + } +} + +static void *stbi__jpeg_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + unsigned char* result; + stbi__jpeg* j = (stbi__jpeg*) stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__errpuc("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + STBI_NOTUSED(ri); + j->s = s; + stbi__setup_jpeg(j); + result = load_jpeg_image(j, x,y,comp,req_comp); + STBI_FREE(j); + return result; +} + +static int stbi__jpeg_test(stbi__context *s) +{ + int r; + stbi__jpeg* j = (stbi__jpeg*)stbi__malloc(sizeof(stbi__jpeg)); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + stbi__setup_jpeg(j); + r = stbi__decode_jpeg_header(j, STBI__SCAN_type); + stbi__rewind(s); + STBI_FREE(j); + return r; +} + +static int stbi__jpeg_info_raw(stbi__jpeg *j, int *x, int *y, int *comp) +{ + if (!stbi__decode_jpeg_header(j, STBI__SCAN_header)) { + stbi__rewind( j->s ); + return 0; + } + if (x) *x = j->s->img_x; + if (y) *y = j->s->img_y; + if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; +} + +static int stbi__jpeg_info(stbi__context *s, int *x, int *y, int *comp) +{ + int result; + stbi__jpeg* j = (stbi__jpeg*) (stbi__malloc(sizeof(stbi__jpeg))); + if (!j) return stbi__err("outofmem", "Out of memory"); + memset(j, 0, sizeof(stbi__jpeg)); + j->s = s; + result = stbi__jpeg_info_raw(j, x, y, comp); + STBI_FREE(j); + return result; +} +#endif + +// public domain zlib decode v0.2 Sean Barrett 2006-11-18 +// simple implementation +// - all input must be provided in an upfront buffer +// - all output is written to a single output buffer (can malloc/realloc) +// performance +// - fast huffman + +#ifndef STBI_NO_ZLIB + +// fast-way is faster to check than jpeg huffman, but slow way is slower +#define STBI__ZFAST_BITS 9 // accelerate all cases in default tables +#define STBI__ZFAST_MASK ((1 << STBI__ZFAST_BITS) - 1) +#define STBI__ZNSYMS 288 // number of symbols in literal/length alphabet + +// zlib-style huffman encoding +// (jpegs packs from left, zlib from right, so can't share code) +typedef struct +{ + stbi__uint16 fast[1 << STBI__ZFAST_BITS]; + stbi__uint16 firstcode[16]; + int maxcode[17]; + stbi__uint16 firstsymbol[16]; + stbi_uc size[STBI__ZNSYMS]; + stbi__uint16 value[STBI__ZNSYMS]; +} stbi__zhuffman; + +stbi_inline static int stbi__bitreverse16(int n) +{ + n = ((n & 0xAAAA) >> 1) | ((n & 0x5555) << 1); + n = ((n & 0xCCCC) >> 2) | ((n & 0x3333) << 2); + n = ((n & 0xF0F0) >> 4) | ((n & 0x0F0F) << 4); + n = ((n & 0xFF00) >> 8) | ((n & 0x00FF) << 8); + return n; +} + +stbi_inline static int stbi__bit_reverse(int v, int bits) +{ + STBI_ASSERT(bits <= 16); + // to bit reverse n bits, reverse 16 and shift + // e.g. 
11 bits, bit reverse and shift away 5 + return stbi__bitreverse16(v) >> (16-bits); +} + +static int stbi__zbuild_huffman(stbi__zhuffman *z, const stbi_uc *sizelist, int num) +{ + int i,k=0; + int code, next_code[16], sizes[17]; + + // DEFLATE spec for generating codes + memset(sizes, 0, sizeof(sizes)); + memset(z->fast, 0, sizeof(z->fast)); + for (i=0; i < num; ++i) + ++sizes[sizelist[i]]; + sizes[0] = 0; + for (i=1; i < 16; ++i) + if (sizes[i] > (1 << i)) + return stbi__err("bad sizes", "Corrupt PNG"); + code = 0; + for (i=1; i < 16; ++i) { + next_code[i] = code; + z->firstcode[i] = (stbi__uint16) code; + z->firstsymbol[i] = (stbi__uint16) k; + code = (code + sizes[i]); + if (sizes[i]) + if (code-1 >= (1 << i)) return stbi__err("bad codelengths","Corrupt PNG"); + z->maxcode[i] = code << (16-i); // preshift for inner loop + code <<= 1; + k += sizes[i]; + } + z->maxcode[16] = 0x10000; // sentinel + for (i=0; i < num; ++i) { + int s = sizelist[i]; + if (s) { + int c = next_code[s] - z->firstcode[s] + z->firstsymbol[s]; + stbi__uint16 fastv = (stbi__uint16) ((s << 9) | i); + z->size [c] = (stbi_uc ) s; + z->value[c] = (stbi__uint16) i; + if (s <= STBI__ZFAST_BITS) { + int j = stbi__bit_reverse(next_code[s],s); + while (j < (1 << STBI__ZFAST_BITS)) { + z->fast[j] = fastv; + j += (1 << s); + } + } + ++next_code[s]; + } + } + return 1; +} + +// zlib-from-memory implementation for PNG reading +// because PNG allows splitting the zlib stream arbitrarily, +// and it's annoying structurally to have PNG call ZLIB call PNG, +// we require PNG read all the IDATs and combine them into a single +// memory buffer + +typedef struct +{ + stbi_uc *zbuffer, *zbuffer_end; + int num_bits; + int hit_zeof_once; + stbi__uint32 code_buffer; + + char *zout; + char *zout_start; + char *zout_end; + int z_expandable; + + stbi__zhuffman z_length, z_distance; +} stbi__zbuf; + +stbi_inline static int stbi__zeof(stbi__zbuf *z) +{ + return (z->zbuffer >= z->zbuffer_end); +} + +stbi_inline static stbi_uc stbi__zget8(stbi__zbuf *z) +{ + return stbi__zeof(z) ? 0 : *z->zbuffer++; +} + +static void stbi__fill_bits(stbi__zbuf *z) +{ + do { + if (z->code_buffer >= (1U << z->num_bits)) { + z->zbuffer = z->zbuffer_end; /* treat this as EOF so we fail. */ + return; + } + z->code_buffer |= (unsigned int) stbi__zget8(z) << z->num_bits; + z->num_bits += 8; + } while (z->num_bits <= 24); +} + +stbi_inline static unsigned int stbi__zreceive(stbi__zbuf *z, int n) +{ + unsigned int k; + if (z->num_bits < n) stbi__fill_bits(z); + k = z->code_buffer & ((1 << n) - 1); + z->code_buffer >>= n; + z->num_bits -= n; + return k; +} + +static int stbi__zhuffman_decode_slowpath(stbi__zbuf *a, stbi__zhuffman *z) +{ + int b,s,k; + // not resolved by fast table, so compute it the slow way + // use jpeg approach, which requires MSbits at top + k = stbi__bit_reverse(a->code_buffer, 16); + for (s=STBI__ZFAST_BITS+1; ; ++s) + if (k < z->maxcode[s]) + break; + if (s >= 16) return -1; // invalid code! + // code size is s, so: + b = (k >> (16-s)) - z->firstcode[s] + z->firstsymbol[s]; + if (b >= STBI__ZNSYMS) return -1; // some data was corrupt somewhere! + if (z->size[b] != s) return -1; // was originally an assert, but report failure instead. 
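+   // (Editor's note.) b is the canonical-code symbol index: the top s bits of
+   // the reversed buffer, minus the first code of length s, plus the index of
+   // the first symbol of that length; the two checks above reject indices and
+   // lengths that no well-formed table can produce.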
+   a->code_buffer >>= s;
+   a->num_bits -= s;
+   return z->value[b];
+}
+
+stbi_inline static int stbi__zhuffman_decode(stbi__zbuf *a, stbi__zhuffman *z)
+{
+   int b,s;
+   if (a->num_bits < 16) {
+      if (stbi__zeof(a)) {
+         if (!a->hit_zeof_once) {
+            // This is the first time we hit eof, insert 16 extra padding bits
+            // to allow us to keep going; if we actually consume any of them
+            // though, that is invalid data. This is caught later.
+            a->hit_zeof_once = 1;
+            a->num_bits += 16; // add 16 implicit zero bits
+         } else {
+            // We already inserted our extra 16 padding bits and are again
+            // out, this stream is actually prematurely terminated.
+            return -1;
+         }
+      } else {
+         stbi__fill_bits(a);
+      }
+   }
+   b = z->fast[a->code_buffer & STBI__ZFAST_MASK];
+   if (b) {
+      s = b >> 9;
+      a->code_buffer >>= s;
+      a->num_bits -= s;
+      return b & 511;
+   }
+   return stbi__zhuffman_decode_slowpath(a, z);
+}
+
+static int stbi__zexpand(stbi__zbuf *z, char *zout, int n) // need to make room for n bytes
+{
+   char *q;
+   unsigned int cur, limit, old_limit;
+   z->zout = zout;
+   if (!z->z_expandable) return stbi__err("output buffer limit","Corrupt PNG");
+   cur = (unsigned int) (z->zout - z->zout_start);
+   limit = old_limit = (unsigned) (z->zout_end - z->zout_start);
+   if (UINT_MAX - cur < (unsigned) n) return stbi__err("outofmem", "Out of memory");
+   while (cur + n > limit) {
+      if(limit > UINT_MAX / 2) return stbi__err("outofmem", "Out of memory");
+      limit *= 2;
+   }
+   q = (char *) STBI_REALLOC_SIZED(z->zout_start, old_limit, limit);
+   STBI_NOTUSED(old_limit);
+   if (q == NULL) return stbi__err("outofmem", "Out of memory");
+   z->zout_start = q;
+   z->zout = q + cur;
+   z->zout_end = q + limit;
+   return 1;
+}
+
+static const int stbi__zlength_base[31] = {
+   3,4,5,6,7,8,9,10,11,13,
+   15,17,19,23,27,31,35,43,51,59,
+   67,83,99,115,131,163,195,227,258,0,0 };
+
+static const int stbi__zlength_extra[31]=
+{ 0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0,0,0 };
+
+static const int stbi__zdist_base[32] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,
+257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0};
+
+static const int stbi__zdist_extra[32] =
+{ 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13};
+
+static int stbi__parse_huffman_block(stbi__zbuf *a)
+{
+   char *zout = a->zout;
+   for(;;) {
+      int z = stbi__zhuffman_decode(a, &a->z_length);
+      if (z < 256) {
+         if (z < 0) return stbi__err("bad huffman code","Corrupt PNG"); // error in huffman codes
+         if (zout >= a->zout_end) {
+            if (!stbi__zexpand(a, zout, 1)) return 0;
+            zout = a->zout;
+         }
+         *zout++ = (char) z;
+      } else {
+         stbi_uc *p;
+         int len,dist;
+         if (z == 256) {
+            a->zout = zout;
+            if (a->hit_zeof_once && a->num_bits < 16) {
+               // The first time we hit zeof, we inserted 16 extra zero bits into our bit
+               // buffer so the decoder can just do its speculative decoding. But if we
+               // actually consumed any of those bits (which is the case when num_bits < 16),
+               // the stream actually read past the end so it is malformed.
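+               // (Editor's note.) num_bits < 16 here means at least one of the
+               // 16 synthetic zero bits appended at EOF was consumed while
+               // decoding, i.e. the stream ended before its end-of-block code.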
+ return stbi__err("unexpected end","Corrupt PNG"); + } + return 1; + } + if (z >= 286) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, length codes 286 and 287 must not appear in compressed data + z -= 257; + len = stbi__zlength_base[z]; + if (stbi__zlength_extra[z]) len += stbi__zreceive(a, stbi__zlength_extra[z]); + z = stbi__zhuffman_decode(a, &a->z_distance); + if (z < 0 || z >= 30) return stbi__err("bad huffman code","Corrupt PNG"); // per DEFLATE, distance codes 30 and 31 must not appear in compressed data + dist = stbi__zdist_base[z]; + if (stbi__zdist_extra[z]) dist += stbi__zreceive(a, stbi__zdist_extra[z]); + if (zout - a->zout_start < dist) return stbi__err("bad dist","Corrupt PNG"); + if (len > a->zout_end - zout) { + if (!stbi__zexpand(a, zout, len)) return 0; + zout = a->zout; + } + p = (stbi_uc *) (zout - dist); + if (dist == 1) { // run of one byte; common in images. + stbi_uc v = *p; + if (len) { do *zout++ = v; while (--len); } + } else { + if (len) { do *zout++ = *p++; while (--len); } + } + } + } +} + +static int stbi__compute_huffman_codes(stbi__zbuf *a) +{ + static const stbi_uc length_dezigzag[19] = { 16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15 }; + stbi__zhuffman z_codelength; + stbi_uc lencodes[286+32+137];//padding for maximum single op + stbi_uc codelength_sizes[19]; + int i,n; + + int hlit = stbi__zreceive(a,5) + 257; + int hdist = stbi__zreceive(a,5) + 1; + int hclen = stbi__zreceive(a,4) + 4; + int ntot = hlit + hdist; + + memset(codelength_sizes, 0, sizeof(codelength_sizes)); + for (i=0; i < hclen; ++i) { + int s = stbi__zreceive(a,3); + codelength_sizes[length_dezigzag[i]] = (stbi_uc) s; + } + if (!stbi__zbuild_huffman(&z_codelength, codelength_sizes, 19)) return 0; + + n = 0; + while (n < ntot) { + int c = stbi__zhuffman_decode(a, &z_codelength); + if (c < 0 || c >= 19) return stbi__err("bad codelengths", "Corrupt PNG"); + if (c < 16) + lencodes[n++] = (stbi_uc) c; + else { + stbi_uc fill = 0; + if (c == 16) { + c = stbi__zreceive(a,2)+3; + if (n == 0) return stbi__err("bad codelengths", "Corrupt PNG"); + fill = lencodes[n-1]; + } else if (c == 17) { + c = stbi__zreceive(a,3)+3; + } else if (c == 18) { + c = stbi__zreceive(a,7)+11; + } else { + return stbi__err("bad codelengths", "Corrupt PNG"); + } + if (ntot - n < c) return stbi__err("bad codelengths", "Corrupt PNG"); + memset(lencodes+n, fill, c); + n += c; + } + } + if (n != ntot) return stbi__err("bad codelengths","Corrupt PNG"); + if (!stbi__zbuild_huffman(&a->z_length, lencodes, hlit)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, lencodes+hlit, hdist)) return 0; + return 1; +} + +static int stbi__parse_uncompressed_block(stbi__zbuf *a) +{ + stbi_uc header[4]; + int len,nlen,k; + if (a->num_bits & 7) + stbi__zreceive(a, a->num_bits & 7); // discard + // drain the bit-packed data into header + k = 0; + while (a->num_bits > 0) { + header[k++] = (stbi_uc) (a->code_buffer & 255); // suppress MSVC run-time check + a->code_buffer >>= 8; + a->num_bits -= 8; + } + if (a->num_bits < 0) return stbi__err("zlib corrupt","Corrupt PNG"); + // now fill header the normal way + while (k < 4) + header[k++] = stbi__zget8(a); + len = header[1] * 256 + header[0]; + nlen = header[3] * 256 + header[2]; + if (nlen != (len ^ 0xffff)) return stbi__err("zlib corrupt","Corrupt PNG"); + if (a->zbuffer + len > a->zbuffer_end) return stbi__err("read past buffer","Corrupt PNG"); + if (a->zout + len > a->zout_end) + if (!stbi__zexpand(a, a->zout, len)) return 0; + memcpy(a->zout, a->zbuffer, len); + 
a->zbuffer += len; + a->zout += len; + return 1; +} + +static int stbi__parse_zlib_header(stbi__zbuf *a) +{ + int cmf = stbi__zget8(a); + int cm = cmf & 15; + /* int cinfo = cmf >> 4; */ + int flg = stbi__zget8(a); + if (stbi__zeof(a)) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if ((cmf*256+flg) % 31 != 0) return stbi__err("bad zlib header","Corrupt PNG"); // zlib spec + if (flg & 32) return stbi__err("no preset dict","Corrupt PNG"); // preset dictionary not allowed in png + if (cm != 8) return stbi__err("bad compression","Corrupt PNG"); // DEFLATE required for png + // window = 1 << (8 + cinfo)... but who cares, we fully buffer output + return 1; +} + +static const stbi_uc stbi__zdefault_length[STBI__ZNSYMS] = +{ + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, + 8,8,8,8,8,8,8,8,8,8,8,8,8,8,8,8, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, 9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, + 7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, 7,7,7,7,7,7,7,7,8,8,8,8,8,8,8,8 +}; +static const stbi_uc stbi__zdefault_distance[32] = +{ + 5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5,5 +}; +/* +Init algorithm: +{ + int i; // use <= to match clearly with spec + for (i=0; i <= 143; ++i) stbi__zdefault_length[i] = 8; + for ( ; i <= 255; ++i) stbi__zdefault_length[i] = 9; + for ( ; i <= 279; ++i) stbi__zdefault_length[i] = 7; + for ( ; i <= 287; ++i) stbi__zdefault_length[i] = 8; + + for (i=0; i <= 31; ++i) stbi__zdefault_distance[i] = 5; +} +*/ + +static int stbi__parse_zlib(stbi__zbuf *a, int parse_header) +{ + int final, type; + if (parse_header) + if (!stbi__parse_zlib_header(a)) return 0; + a->num_bits = 0; + a->code_buffer = 0; + a->hit_zeof_once = 0; + do { + final = stbi__zreceive(a,1); + type = stbi__zreceive(a,2); + if (type == 0) { + if (!stbi__parse_uncompressed_block(a)) return 0; + } else if (type == 3) { + return 0; + } else { + if (type == 1) { + // use fixed code lengths + if (!stbi__zbuild_huffman(&a->z_length , stbi__zdefault_length , STBI__ZNSYMS)) return 0; + if (!stbi__zbuild_huffman(&a->z_distance, stbi__zdefault_distance, 32)) return 0; + } else { + if (!stbi__compute_huffman_codes(a)) return 0; + } + if (!stbi__parse_huffman_block(a)) return 0; + } + } while (!final); + return 1; +} + +static int stbi__do_zlib(stbi__zbuf *a, char *obuf, int olen, int exp, int parse_header) +{ + a->zout_start = obuf; + a->zout = obuf; + a->zout_end = obuf + olen; + a->z_expandable = exp; + + return stbi__parse_zlib(a, parse_header); +} + +STBIDEF char *stbi_zlib_decode_malloc_guesssize(const char *buffer, int len, int initial_size, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, 1)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF char *stbi_zlib_decode_malloc(char const *buffer, int len, int *outlen) +{ + return stbi_zlib_decode_malloc_guesssize(buffer, len, 16384, outlen); +} + +STBIDEF char 
*stbi_zlib_decode_malloc_guesssize_headerflag(const char *buffer, int len, int initial_size, int *outlen, int parse_header) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(initial_size); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer + len; + if (stbi__do_zlib(&a, p, initial_size, 1, parse_header)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_buffer(char *obuffer, int olen, char const *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 1)) + return (int) (a.zout - a.zout_start); + else + return -1; +} + +STBIDEF char *stbi_zlib_decode_noheader_malloc(char const *buffer, int len, int *outlen) +{ + stbi__zbuf a; + char *p = (char *) stbi__malloc(16384); + if (p == NULL) return NULL; + a.zbuffer = (stbi_uc *) buffer; + a.zbuffer_end = (stbi_uc *) buffer+len; + if (stbi__do_zlib(&a, p, 16384, 1, 0)) { + if (outlen) *outlen = (int) (a.zout - a.zout_start); + return a.zout_start; + } else { + STBI_FREE(a.zout_start); + return NULL; + } +} + +STBIDEF int stbi_zlib_decode_noheader_buffer(char *obuffer, int olen, const char *ibuffer, int ilen) +{ + stbi__zbuf a; + a.zbuffer = (stbi_uc *) ibuffer; + a.zbuffer_end = (stbi_uc *) ibuffer + ilen; + if (stbi__do_zlib(&a, obuffer, olen, 0, 0)) + return (int) (a.zout - a.zout_start); + else + return -1; +} +#endif + +// public domain "baseline" PNG decoder v0.10 Sean Barrett 2006-11-18 +// simple implementation +// - only 8-bit samples +// - no CRC checking +// - allocates lots of intermediate memory +// - avoids problem of streaming data between subsystems +// - avoids explicit window management +// performance +// - uses stb_zlib, a PD zlib implementation with fast huffman decoding + +#ifndef STBI_NO_PNG +typedef struct +{ + stbi__uint32 length; + stbi__uint32 type; +} stbi__pngchunk; + +static stbi__pngchunk stbi__get_chunk_header(stbi__context *s) +{ + stbi__pngchunk c; + c.length = stbi__get32be(s); + c.type = stbi__get32be(s); + return c; +} + +static int stbi__check_png_header(stbi__context *s) +{ + static const stbi_uc png_sig[8] = { 137,80,78,71,13,10,26,10 }; + int i; + for (i=0; i < 8; ++i) + if (stbi__get8(s) != png_sig[i]) return stbi__err("bad png sig","Not a PNG"); + return 1; +} + +typedef struct +{ + stbi__context *s; + stbi_uc *idata, *expanded, *out; + int depth; +} stbi__png; + + +enum { + STBI__F_none=0, + STBI__F_sub=1, + STBI__F_up=2, + STBI__F_avg=3, + STBI__F_paeth=4, + // synthetic filter used for first scanline to avoid needing a dummy row of 0s + STBI__F_avg_first +}; + +static stbi_uc first_row_filter[5] = +{ + STBI__F_none, + STBI__F_sub, + STBI__F_none, + STBI__F_avg_first, + STBI__F_sub // Paeth with b=c=0 turns out to be equivalent to sub +}; + +static int stbi__paeth(int a, int b, int c) +{ + // This formulation looks very different from the reference in the PNG spec, but is + // actually equivalent and has favorable data dependencies and admits straightforward + // generation of branch-free code, which helps performance significantly. + int thresh = c*3 - (a + b); + int lo = a < b ? a : b; + int hi = a < b ? b : a; + int t0 = (hi <= thresh) ? lo : c; + int t1 = (thresh <= lo) ? 
hi : t0; + return t1; +} + +static const stbi_uc stbi__depth_scale_table[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; + +// adds an extra all-255 alpha channel +// dest == src is legal +// img_n must be 1 or 3 +static void stbi__create_png_alpha_expand8(stbi_uc *dest, stbi_uc *src, stbi__uint32 x, int img_n) +{ + int i; + // must process data backwards since we allow dest==src + if (img_n == 1) { + for (i=x-1; i >= 0; --i) { + dest[i*2+1] = 255; + dest[i*2+0] = src[i]; + } + } else { + STBI_ASSERT(img_n == 3); + for (i=x-1; i >= 0; --i) { + dest[i*4+3] = 255; + dest[i*4+2] = src[i*3+2]; + dest[i*4+1] = src[i*3+1]; + dest[i*4+0] = src[i*3+0]; + } + } +} + +// create the png data from post-deflated data +static int stbi__create_png_image_raw(stbi__png *a, stbi_uc *raw, stbi__uint32 raw_len, int out_n, stbi__uint32 x, stbi__uint32 y, int depth, int color) +{ + int bytes = (depth == 16 ? 2 : 1); + stbi__context *s = a->s; + stbi__uint32 i,j,stride = x*out_n*bytes; + stbi__uint32 img_len, img_width_bytes; + stbi_uc *filter_buf; + int all_ok = 1; + int k; + int img_n = s->img_n; // copy it into a local for later + + int output_bytes = out_n*bytes; + int filter_bytes = img_n*bytes; + int width = x; + + STBI_ASSERT(out_n == s->img_n || out_n == s->img_n+1); + a->out = (stbi_uc *) stbi__malloc_mad3(x, y, output_bytes, 0); // extra bytes to write off the end into + if (!a->out) return stbi__err("outofmem", "Out of memory"); + + // note: error exits here don't need to clean up a->out individually, + // stbi__do_png always does on error. + if (!stbi__mad3sizes_valid(img_n, x, depth, 7)) return stbi__err("too large", "Corrupt PNG"); + img_width_bytes = (((img_n * x * depth) + 7) >> 3); + if (!stbi__mad2sizes_valid(img_width_bytes, y, img_width_bytes)) return stbi__err("too large", "Corrupt PNG"); + img_len = (img_width_bytes + 1) * y; + + // we used to check for exact match between raw_len and img_len on non-interlaced PNGs, + // but issue #276 reported a PNG in the wild that had extra data at the end (all zeros), + // so just check for raw_len < img_len always. + if (raw_len < img_len) return stbi__err("not enough pixels","Corrupt PNG"); + + // Allocate two scan lines worth of filter workspace buffer. 
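+ // The two halves alternate: row j is unfiltered into half (j & 1) while + // half (~j & 1) still holds row j-1, so only one prior scanline needs to + // be kept around instead of the whole filtered image.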
+ filter_buf = (stbi_uc *) stbi__malloc_mad2(img_width_bytes, 2, 0); + if (!filter_buf) return stbi__err("outofmem", "Out of memory"); + + // Filtering for low-bit-depth images + if (depth < 8) { + filter_bytes = 1; + width = img_width_bytes; + } + + for (j=0; j < y; ++j) { + // cur/prior filter buffers alternate + stbi_uc *cur = filter_buf + (j & 1)*img_width_bytes; + stbi_uc *prior = filter_buf + (~j & 1)*img_width_bytes; + stbi_uc *dest = a->out + stride*j; + int nk = width * filter_bytes; + int filter = *raw++; + + // check filter type + if (filter > 4) { + all_ok = stbi__err("invalid filter","Corrupt PNG"); + break; + } + + // if first row, use special filter that doesn't sample previous row + if (j == 0) filter = first_row_filter[filter]; + + // perform actual filtering + switch (filter) { + case STBI__F_none: + memcpy(cur, raw, nk); + break; + case STBI__F_sub: + memcpy(cur, raw, filter_bytes); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + cur[k-filter_bytes]); + break; + case STBI__F_up: + for (k = 0; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); + break; + case STBI__F_avg: + for (k = 0; k < filter_bytes; ++k) + cur[k] = STBI__BYTECAST(raw[k] + (prior[k]>>1)); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + ((prior[k] + cur[k-filter_bytes])>>1)); + break; + case STBI__F_paeth: + for (k = 0; k < filter_bytes; ++k) + cur[k] = STBI__BYTECAST(raw[k] + prior[k]); // prior[k] == stbi__paeth(0,prior[k],0) + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + stbi__paeth(cur[k-filter_bytes], prior[k], prior[k-filter_bytes])); + break; + case STBI__F_avg_first: + memcpy(cur, raw, filter_bytes); + for (k = filter_bytes; k < nk; ++k) + cur[k] = STBI__BYTECAST(raw[k] + (cur[k-filter_bytes] >> 1)); + break; + } + + raw += nk; + + // expand decoded bits in cur to dest, also adding an extra alpha channel if desired + if (depth < 8) { + stbi_uc scale = (color == 0) ? 
stbi__depth_scale_table[depth] : 1; // scale grayscale values to 0..255 range + stbi_uc *in = cur; + stbi_uc *out = dest; + stbi_uc inb = 0; + stbi__uint32 nsmp = x*img_n; + + // expand bits to bytes first + if (depth == 4) { + for (i=0; i < nsmp; ++i) { + if ((i & 1) == 0) inb = *in++; + *out++ = scale * (inb >> 4); + inb <<= 4; + } + } else if (depth == 2) { + for (i=0; i < nsmp; ++i) { + if ((i & 3) == 0) inb = *in++; + *out++ = scale * (inb >> 6); + inb <<= 2; + } + } else { + STBI_ASSERT(depth == 1); + for (i=0; i < nsmp; ++i) { + if ((i & 7) == 0) inb = *in++; + *out++ = scale * (inb >> 7); + inb <<= 1; + } + } + + // insert alpha=255 values if desired + if (img_n != out_n) + stbi__create_png_alpha_expand8(dest, dest, x, img_n); + } else if (depth == 8) { + if (img_n == out_n) + memcpy(dest, cur, x*img_n); + else + stbi__create_png_alpha_expand8(dest, cur, x, img_n); + } else if (depth == 16) { + // convert the image data from big-endian to platform-native + stbi__uint16 *dest16 = (stbi__uint16*)dest; + stbi__uint32 nsmp = x*img_n; + + if (img_n == out_n) { + for (i = 0; i < nsmp; ++i, ++dest16, cur += 2) + *dest16 = (cur[0] << 8) | cur[1]; + } else { + STBI_ASSERT(img_n+1 == out_n); + if (img_n == 1) { + for (i = 0; i < x; ++i, dest16 += 2, cur += 2) { + dest16[0] = (cur[0] << 8) | cur[1]; + dest16[1] = 0xffff; + } + } else { + STBI_ASSERT(img_n == 3); + for (i = 0; i < x; ++i, dest16 += 4, cur += 6) { + dest16[0] = (cur[0] << 8) | cur[1]; + dest16[1] = (cur[2] << 8) | cur[3]; + dest16[2] = (cur[4] << 8) | cur[5]; + dest16[3] = 0xffff; + } + } + } + } + } + + STBI_FREE(filter_buf); + if (!all_ok) return 0; + + return 1; +} + +static int stbi__create_png_image(stbi__png *a, stbi_uc *image_data, stbi__uint32 image_data_len, int out_n, int depth, int color, int interlaced) +{ + int bytes = (depth == 16 ? 2 : 1); + int out_bytes = out_n * bytes; + stbi_uc *final; + int p; + if (!interlaced) + return stbi__create_png_image_raw(a, image_data, image_data_len, out_n, a->s->img_x, a->s->img_y, depth, color); + + // de-interlacing + final = (stbi_uc *) stbi__malloc_mad3(a->s->img_x, a->s->img_y, out_bytes, 0); + if (!final) return stbi__err("outofmem", "Out of memory"); + for (p=0; p < 7; ++p) { + int xorig[] = { 0,4,0,2,0,1,0 }; + int yorig[] = { 0,0,4,0,2,0,1 }; + int xspc[] = { 8,8,4,4,2,2,1 }; + int yspc[] = { 8,8,8,4,4,2,2 }; + int i,j,x,y; + // pass1_x[4] = 0, pass1_x[5] = 1, pass1_x[12] = 1 + x = (a->s->img_x - xorig[p] + xspc[p]-1) / xspc[p]; + y = (a->s->img_y - yorig[p] + yspc[p]-1) / yspc[p]; + if (x && y) { + stbi__uint32 img_len = ((((a->s->img_n * x * depth) + 7) >> 3) + 1) * y; + if (!stbi__create_png_image_raw(a, image_data, image_data_len, out_n, x, y, depth, color)) { + STBI_FREE(final); + return 0; + } + for (j=0; j < y; ++j) { + for (i=0; i < x; ++i) { + int out_y = j*yspc[p]+yorig[p]; + int out_x = i*xspc[p]+xorig[p]; + memcpy(final + out_y*a->s->img_x*out_bytes + out_x*out_bytes, + a->out + (j*x+i)*out_bytes, out_bytes); + } + } + STBI_FREE(a->out); + image_data += img_len; + image_data_len -= img_len; + } + } + a->out = final; + + return 1; +} + +static int stbi__compute_transparency(stbi__png *z, stbi_uc tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + // compute color-based transparency, assuming we've + // already got 255 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i=0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 
0 : 255); + p += 2; + } + } else { + for (i=0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__compute_transparency16(stbi__png *z, stbi__uint16 tc[3], int out_n) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi__uint16 *p = (stbi__uint16*) z->out; + + // compute color-based transparency, assuming we've + // already got 65535 as the alpha value in the output + STBI_ASSERT(out_n == 2 || out_n == 4); + + if (out_n == 2) { + for (i = 0; i < pixel_count; ++i) { + p[1] = (p[0] == tc[0] ? 0 : 65535); + p += 2; + } + } else { + for (i = 0; i < pixel_count; ++i) { + if (p[0] == tc[0] && p[1] == tc[1] && p[2] == tc[2]) + p[3] = 0; + p += 4; + } + } + return 1; +} + +static int stbi__expand_png_palette(stbi__png *a, stbi_uc *palette, int len, int pal_img_n) +{ + stbi__uint32 i, pixel_count = a->s->img_x * a->s->img_y; + stbi_uc *p, *temp_out, *orig = a->out; + + p = (stbi_uc *) stbi__malloc_mad2(pixel_count, pal_img_n, 0); + if (p == NULL) return stbi__err("outofmem", "Out of memory"); + + // between here and free(out) below, exiting would leak + temp_out = p; + + if (pal_img_n == 3) { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p += 3; + } + } else { + for (i=0; i < pixel_count; ++i) { + int n = orig[i]*4; + p[0] = palette[n ]; + p[1] = palette[n+1]; + p[2] = palette[n+2]; + p[3] = palette[n+3]; + p += 4; + } + } + STBI_FREE(a->out); + a->out = temp_out; + + STBI_NOTUSED(len); + + return 1; +} + +static int stbi__unpremultiply_on_load_global = 0; +static int stbi__de_iphone_flag_global = 0; + +STBIDEF void stbi_set_unpremultiply_on_load(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_global = flag_true_if_should_unpremultiply; +} + +STBIDEF void stbi_convert_iphone_png_to_rgb(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_global = flag_true_if_should_convert; +} + +#ifndef STBI_THREAD_LOCAL +#define stbi__unpremultiply_on_load stbi__unpremultiply_on_load_global +#define stbi__de_iphone_flag stbi__de_iphone_flag_global +#else +static STBI_THREAD_LOCAL int stbi__unpremultiply_on_load_local, stbi__unpremultiply_on_load_set; +static STBI_THREAD_LOCAL int stbi__de_iphone_flag_local, stbi__de_iphone_flag_set; + +STBIDEF void stbi_set_unpremultiply_on_load_thread(int flag_true_if_should_unpremultiply) +{ + stbi__unpremultiply_on_load_local = flag_true_if_should_unpremultiply; + stbi__unpremultiply_on_load_set = 1; +} + +STBIDEF void stbi_convert_iphone_png_to_rgb_thread(int flag_true_if_should_convert) +{ + stbi__de_iphone_flag_local = flag_true_if_should_convert; + stbi__de_iphone_flag_set = 1; +} + +#define stbi__unpremultiply_on_load (stbi__unpremultiply_on_load_set \ + ? stbi__unpremultiply_on_load_local \ + : stbi__unpremultiply_on_load_global) +#define stbi__de_iphone_flag (stbi__de_iphone_flag_set \ + ?
stbi__de_iphone_flag_local \ + : stbi__de_iphone_flag_global) +#endif // STBI_THREAD_LOCAL + +static void stbi__de_iphone(stbi__png *z) +{ + stbi__context *s = z->s; + stbi__uint32 i, pixel_count = s->img_x * s->img_y; + stbi_uc *p = z->out; + + if (s->img_out_n == 3) { // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 3; + } + } else { + STBI_ASSERT(s->img_out_n == 4); + if (stbi__unpremultiply_on_load) { + // convert bgr to rgb and unpremultiply + for (i=0; i < pixel_count; ++i) { + stbi_uc a = p[3]; + stbi_uc t = p[0]; + if (a) { + stbi_uc half = a / 2; + p[0] = (p[2] * 255 + half) / a; + p[1] = (p[1] * 255 + half) / a; + p[2] = ( t * 255 + half) / a; + } else { + p[0] = p[2]; + p[2] = t; + } + p += 4; + } + } else { + // convert bgr to rgb + for (i=0; i < pixel_count; ++i) { + stbi_uc t = p[0]; + p[0] = p[2]; + p[2] = t; + p += 4; + } + } + } +} + +#define STBI__PNG_TYPE(a,b,c,d) (((unsigned) (a) << 24) + ((unsigned) (b) << 16) + ((unsigned) (c) << 8) + (unsigned) (d)) + +static int stbi__parse_png_file(stbi__png *z, int scan, int req_comp) +{ + stbi_uc palette[1024], pal_img_n=0; + stbi_uc has_trans=0, tc[3]={0}; + stbi__uint16 tc16[3]; + stbi__uint32 ioff=0, idata_limit=0, i, pal_len=0; + int first=1,k,interlace=0, color=0, is_iphone=0; + stbi__context *s = z->s; + + z->expanded = NULL; + z->idata = NULL; + z->out = NULL; + + if (!stbi__check_png_header(s)) return 0; + + if (scan == STBI__SCAN_type) return 1; + + for (;;) { + stbi__pngchunk c = stbi__get_chunk_header(s); + switch (c.type) { + case STBI__PNG_TYPE('C','g','B','I'): + is_iphone = 1; + stbi__skip(s, c.length); + break; + case STBI__PNG_TYPE('I','H','D','R'): { + int comp,filter; + if (!first) return stbi__err("multiple IHDR","Corrupt PNG"); + first = 0; + if (c.length != 13) return stbi__err("bad IHDR len","Corrupt PNG"); + s->img_x = stbi__get32be(s); + s->img_y = stbi__get32be(s); + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + z->depth = stbi__get8(s); if (z->depth != 1 && z->depth != 2 && z->depth != 4 && z->depth != 8 && z->depth != 16) return stbi__err("1/2/4/8/16-bit only","PNG not supported: 1/2/4/8/16-bit only"); + color = stbi__get8(s); if (color > 6) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3 && z->depth == 16) return stbi__err("bad ctype","Corrupt PNG"); + if (color == 3) pal_img_n = 3; else if (color & 1) return stbi__err("bad ctype","Corrupt PNG"); + comp = stbi__get8(s); if (comp) return stbi__err("bad comp method","Corrupt PNG"); + filter= stbi__get8(s); if (filter) return stbi__err("bad filter method","Corrupt PNG"); + interlace = stbi__get8(s); if (interlace>1) return stbi__err("bad interlace method","Corrupt PNG"); + if (!s->img_x || !s->img_y) return stbi__err("0-pixel image","Corrupt PNG"); + if (!pal_img_n) { + s->img_n = (color & 2 ? 3 : 1) + (color & 4 ? 1 : 0); + if ((1 << 30) / s->img_x / s->img_n < s->img_y) return stbi__err("too large", "Image too large to decode"); + } else { + // if paletted, then pal_n is our final components, and + // img_n is # components to decompress/filter. 
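+ // (the indices are unfiltered as one byte per pixel and expanded to the + // final 3/4 channels afterwards, so the size check below uses 4, the + // worst-case component count of the expanded output)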
+ s->img_n = 1; + if ((1 << 30) / s->img_x / 4 < s->img_y) return stbi__err("too large","Corrupt PNG"); + } + // even with SCAN_header, have to scan to see if we have a tRNS + break; + } + + case STBI__PNG_TYPE('P','L','T','E'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (c.length > 256*3) return stbi__err("invalid PLTE","Corrupt PNG"); + pal_len = c.length / 3; + if (pal_len * 3 != c.length) return stbi__err("invalid PLTE","Corrupt PNG"); + for (i=0; i < pal_len; ++i) { + palette[i*4+0] = stbi__get8(s); + palette[i*4+1] = stbi__get8(s); + palette[i*4+2] = stbi__get8(s); + palette[i*4+3] = 255; + } + break; + } + + case STBI__PNG_TYPE('t','R','N','S'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (z->idata) return stbi__err("tRNS after IDAT","Corrupt PNG"); + if (pal_img_n) { + if (scan == STBI__SCAN_header) { s->img_n = 4; return 1; } + if (pal_len == 0) return stbi__err("tRNS before PLTE","Corrupt PNG"); + if (c.length > pal_len) return stbi__err("bad tRNS len","Corrupt PNG"); + pal_img_n = 4; + for (i=0; i < c.length; ++i) + palette[i*4+3] = stbi__get8(s); + } else { + if (!(s->img_n & 1)) return stbi__err("tRNS with alpha","Corrupt PNG"); + if (c.length != (stbi__uint32) s->img_n*2) return stbi__err("bad tRNS len","Corrupt PNG"); + has_trans = 1; + // non-paletted with tRNS = constant alpha. if header-scanning, we can stop now. + if (scan == STBI__SCAN_header) { ++s->img_n; return 1; } + if (z->depth == 16) { + for (k = 0; k < s->img_n && k < 3; ++k) // extra loop test to suppress false GCC warning + tc16[k] = (stbi__uint16)stbi__get16be(s); // copy the values as-is + } else { + for (k = 0; k < s->img_n && k < 3; ++k) + tc[k] = (stbi_uc)(stbi__get16be(s) & 255) * stbi__depth_scale_table[z->depth]; // non 8-bit images will be larger + } + } + break; + } + + case STBI__PNG_TYPE('I','D','A','T'): { + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (pal_img_n && !pal_len) return stbi__err("no PLTE","Corrupt PNG"); + if (scan == STBI__SCAN_header) { + // header scan definitely stops at first IDAT + if (pal_img_n) + s->img_n = pal_img_n; + return 1; + } + if (c.length > (1u << 30)) return stbi__err("IDAT size limit", "IDAT section larger than 2^30 bytes"); + if ((int)(ioff + c.length) < (int)ioff) return 0; + if (ioff + c.length > idata_limit) { + stbi__uint32 idata_limit_old = idata_limit; + stbi_uc *p; + if (idata_limit == 0) idata_limit = c.length > 4096 ? 
c.length : 4096; + while (ioff + c.length > idata_limit) + idata_limit *= 2; + STBI_NOTUSED(idata_limit_old); + p = (stbi_uc *) STBI_REALLOC_SIZED(z->idata, idata_limit_old, idata_limit); if (p == NULL) return stbi__err("outofmem", "Out of memory"); + z->idata = p; + } + if (!stbi__getn(s, z->idata+ioff,c.length)) return stbi__err("outofdata","Corrupt PNG"); + ioff += c.length; + break; + } + + case STBI__PNG_TYPE('I','E','N','D'): { + stbi__uint32 raw_len, bpl; + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if (scan != STBI__SCAN_load) return 1; + if (z->idata == NULL) return stbi__err("no IDAT","Corrupt PNG"); + // initial guess for decoded data size to avoid unnecessary reallocs + bpl = (s->img_x * z->depth + 7) / 8; // bytes per line, per component + raw_len = bpl * s->img_y * s->img_n /* pixels */ + s->img_y /* filter mode per row */; + z->expanded = (stbi_uc *) stbi_zlib_decode_malloc_guesssize_headerflag((char *) z->idata, ioff, raw_len, (int *) &raw_len, !is_iphone); + if (z->expanded == NULL) return 0; // zlib should set error + STBI_FREE(z->idata); z->idata = NULL; + if ((req_comp == s->img_n+1 && req_comp != 3 && !pal_img_n) || has_trans) + s->img_out_n = s->img_n+1; + else + s->img_out_n = s->img_n; + if (!stbi__create_png_image(z, z->expanded, raw_len, s->img_out_n, z->depth, color, interlace)) return 0; + if (has_trans) { + if (z->depth == 16) { + if (!stbi__compute_transparency16(z, tc16, s->img_out_n)) return 0; + } else { + if (!stbi__compute_transparency(z, tc, s->img_out_n)) return 0; + } + } + if (is_iphone && stbi__de_iphone_flag && s->img_out_n > 2) + stbi__de_iphone(z); + if (pal_img_n) { + // pal_img_n == 3 or 4 + s->img_n = pal_img_n; // record the actual colors we had + s->img_out_n = pal_img_n; + if (req_comp >= 3) s->img_out_n = req_comp; + if (!stbi__expand_png_palette(z, palette, pal_len, s->img_out_n)) + return 0; + } else if (has_trans) { + // non-paletted image with tRNS -> source image has (constant) alpha + ++s->img_n; + } + STBI_FREE(z->expanded); z->expanded = NULL; + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + return 1; + } + + default: + // if critical, fail + if (first) return stbi__err("first not IHDR", "Corrupt PNG"); + if ((c.type & (1 << 29)) == 0) { + #ifndef STBI_NO_FAILURE_STRINGS + // not threadsafe + static char invalid_chunk[] = "XXXX PNG chunk not known"; + invalid_chunk[0] = STBI__BYTECAST(c.type >> 24); + invalid_chunk[1] = STBI__BYTECAST(c.type >> 16); + invalid_chunk[2] = STBI__BYTECAST(c.type >> 8); + invalid_chunk[3] = STBI__BYTECAST(c.type >> 0); + #endif + return stbi__err(invalid_chunk, "PNG not supported: unknown PNG chunk type"); + } + stbi__skip(s, c.length); + break; + } + // end of PNG chunk, read and skip CRC + stbi__get32be(s); + } +} + +static void *stbi__do_png(stbi__png *p, int *x, int *y, int *n, int req_comp, stbi__result_info *ri) +{ + void *result=NULL; + if (req_comp < 0 || req_comp > 4) return stbi__errpuc("bad req_comp", "Internal error"); + if (stbi__parse_png_file(p, STBI__SCAN_load, req_comp)) { + if (p->depth <= 8) + ri->bits_per_channel = 8; + else if (p->depth == 16) + ri->bits_per_channel = 16; + else + return stbi__errpuc("bad bits_per_channel", "PNG not supported: unsupported color depth"); + result = p->out; + p->out = NULL; + if (req_comp && req_comp != p->s->img_out_n) { + if (ri->bits_per_channel == 8) + result = stbi__convert_format((unsigned char *) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + else + result = stbi__convert_format16((stbi__uint16 
*) result, p->s->img_out_n, req_comp, p->s->img_x, p->s->img_y); + p->s->img_out_n = req_comp; + if (result == NULL) return result; + } + *x = p->s->img_x; + *y = p->s->img_y; + if (n) *n = p->s->img_n; + } + STBI_FREE(p->out); p->out = NULL; + STBI_FREE(p->expanded); p->expanded = NULL; + STBI_FREE(p->idata); p->idata = NULL; + + return result; +} + +static void *stbi__png_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi__png p; + p.s = s; + return stbi__do_png(&p, x,y,comp,req_comp, ri); +} + +static int stbi__png_test(stbi__context *s) +{ + int r; + r = stbi__check_png_header(s); + stbi__rewind(s); + return r; +} + +static int stbi__png_info_raw(stbi__png *p, int *x, int *y, int *comp) +{ + if (!stbi__parse_png_file(p, STBI__SCAN_header, 0)) { + stbi__rewind( p->s ); + return 0; + } + if (x) *x = p->s->img_x; + if (y) *y = p->s->img_y; + if (comp) *comp = p->s->img_n; + return 1; +} + +static int stbi__png_info(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__png p; + p.s = s; + return stbi__png_info_raw(&p, x, y, comp); +} + +static int stbi__png_is16(stbi__context *s) +{ + stbi__png p; + p.s = s; + if (!stbi__png_info_raw(&p, NULL, NULL, NULL)) + return 0; + if (p.depth != 16) { + stbi__rewind(p.s); + return 0; + } + return 1; +} +#endif + +// Microsoft/Windows BMP image + +#ifndef STBI_NO_BMP +static int stbi__bmp_test_raw(stbi__context *s) +{ + int r; + int sz; + if (stbi__get8(s) != 'B') return 0; + if (stbi__get8(s) != 'M') return 0; + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + stbi__get32le(s); // discard data offset + sz = stbi__get32le(s); + r = (sz == 12 || sz == 40 || sz == 56 || sz == 108 || sz == 124); + return r; +} + +static int stbi__bmp_test(stbi__context *s) +{ + int r = stbi__bmp_test_raw(s); + stbi__rewind(s); + return r; +} + + +// returns 0..31 for the highest set bit +static int stbi__high_bit(unsigned int z) +{ + int n=0; + if (z == 0) return -1; + if (z >= 0x10000) { n += 16; z >>= 16; } + if (z >= 0x00100) { n += 8; z >>= 8; } + if (z >= 0x00010) { n += 4; z >>= 4; } + if (z >= 0x00004) { n += 2; z >>= 2; } + if (z >= 0x00002) { n += 1;/* >>= 1;*/ } + return n; +} + +static int stbi__bitcount(unsigned int a) +{ + a = (a & 0x55555555) + ((a >> 1) & 0x55555555); // max 2 + a = (a & 0x33333333) + ((a >> 2) & 0x33333333); // max 4 + a = (a + (a >> 4)) & 0x0f0f0f0f; // max 8 per 4, now 8 bits + a = (a + (a >> 8)); // max 16 per 8 bits + a = (a + (a >> 16)); // max 32 per 8 bits + return a & 0xff; +} + +// extract an arbitrarily-aligned N-bit value (N=bits) +// from v, and then make it 8-bits long and fractionally +// extend it to full range.
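+// e.g. for a 5-bit channel (BMP 555/565): v is 0..31, and mul_table[5]=0x21 +// with shift_table[5]=2 computes (v*33)>>2, which is exactly bit replication +// (v<<3)|(v>>2), mapping 0 -> 0 and 31 -> 255.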
+static int stbi__shiftsigned(unsigned int v, int shift, int bits) +{ + static unsigned int mul_table[9] = { + 0, + 0xff/*0b11111111*/, 0x55/*0b01010101*/, 0x49/*0b01001001*/, 0x11/*0b00010001*/, + 0x21/*0b00100001*/, 0x41/*0b01000001*/, 0x81/*0b10000001*/, 0x01/*0b00000001*/, + }; + static unsigned int shift_table[9] = { + 0, 0,0,1,0,2,4,6,0, + }; + if (shift < 0) + v <<= -shift; + else + v >>= shift; + STBI_ASSERT(v < 256); + v >>= (8-bits); + STBI_ASSERT(bits >= 0 && bits <= 8); + return (int) ((unsigned) v * mul_table[bits]) >> shift_table[bits]; +} + +typedef struct +{ + int bpp, offset, hsz; + unsigned int mr,mg,mb,ma, all_a; + int extra_read; +} stbi__bmp_data; + +static int stbi__bmp_set_mask_defaults(stbi__bmp_data *info, int compress) +{ + // BI_BITFIELDS specifies masks explicitly, don't override + if (compress == 3) + return 1; + + if (compress == 0) { + if (info->bpp == 16) { + info->mr = 31u << 10; + info->mg = 31u << 5; + info->mb = 31u << 0; + } else if (info->bpp == 32) { + info->mr = 0xffu << 16; + info->mg = 0xffu << 8; + info->mb = 0xffu << 0; + info->ma = 0xffu << 24; + info->all_a = 0; // if all_a is 0 at end, then we loaded alpha channel but it was all 0 + } else { + // otherwise, use defaults, which is all-0 + info->mr = info->mg = info->mb = info->ma = 0; + } + return 1; + } + return 0; // error +} + +static void *stbi__bmp_parse_header(stbi__context *s, stbi__bmp_data *info) +{ + int hsz; + if (stbi__get8(s) != 'B' || stbi__get8(s) != 'M') return stbi__errpuc("not BMP", "Corrupt BMP"); + stbi__get32le(s); // discard filesize + stbi__get16le(s); // discard reserved + stbi__get16le(s); // discard reserved + info->offset = stbi__get32le(s); + info->hsz = hsz = stbi__get32le(s); + info->mr = info->mg = info->mb = info->ma = 0; + info->extra_read = 14; + + if (info->offset < 0) return stbi__errpuc("bad BMP", "bad BMP"); + + if (hsz != 12 && hsz != 40 && hsz != 56 && hsz != 108 && hsz != 124) return stbi__errpuc("unknown BMP", "BMP type not supported: unknown"); + if (hsz == 12) { + s->img_x = stbi__get16le(s); + s->img_y = stbi__get16le(s); + } else { + s->img_x = stbi__get32le(s); + s->img_y = stbi__get32le(s); + } + if (stbi__get16le(s) != 1) return stbi__errpuc("bad BMP", "bad BMP"); + info->bpp = stbi__get16le(s); + if (hsz != 12) { + int compress = stbi__get32le(s); + if (compress == 1 || compress == 2) return stbi__errpuc("BMP RLE", "BMP type not supported: RLE"); + if (compress >= 4) return stbi__errpuc("BMP JPEG/PNG", "BMP type not supported: unsupported compression"); // this includes PNG/JPEG modes + if (compress == 3 && info->bpp != 16 && info->bpp != 32) return stbi__errpuc("bad BMP", "bad BMP"); // bitfields requires 16 or 32 bits/pixel + stbi__get32le(s); // discard sizeof + stbi__get32le(s); // discard hres + stbi__get32le(s); // discard vres + stbi__get32le(s); // discard colorsused + stbi__get32le(s); // discard max important + if (hsz == 40 || hsz == 56) { + if (hsz == 56) { + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + stbi__get32le(s); + } + if (info->bpp == 16 || info->bpp == 32) { + if (compress == 0) { + stbi__bmp_set_mask_defaults(info, compress); + } else if (compress == 3) { + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->extra_read += 12; + // not documented, but generated by photoshop and handled by mspaint + if (info->mr == info->mg && info->mg == info->mb) { + // ?!?!? 
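+ // identical masks for all three channels cannot describe a usable pixel + // layout, so reject the file rather than guess at one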
+ return stbi__errpuc("bad BMP", "bad BMP"); + } + } else + return stbi__errpuc("bad BMP", "bad BMP"); + } + } else { + // V4/V5 header + int i; + if (hsz != 108 && hsz != 124) + return stbi__errpuc("bad BMP", "bad BMP"); + info->mr = stbi__get32le(s); + info->mg = stbi__get32le(s); + info->mb = stbi__get32le(s); + info->ma = stbi__get32le(s); + if (compress != 3) // override mr/mg/mb unless in BI_BITFIELDS mode, as per docs + stbi__bmp_set_mask_defaults(info, compress); + stbi__get32le(s); // discard color space + for (i=0; i < 12; ++i) + stbi__get32le(s); // discard color space parameters + if (hsz == 124) { + stbi__get32le(s); // discard rendering intent + stbi__get32le(s); // discard offset of profile data + stbi__get32le(s); // discard size of profile data + stbi__get32le(s); // discard reserved + } + } + } + return (void *) 1; +} + + +static void *stbi__bmp_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + unsigned int mr=0,mg=0,mb=0,ma=0, all_a; + stbi_uc pal[256][4]; + int psize=0,i,j,width; + int flip_vertically, pad, target; + stbi__bmp_data info; + STBI_NOTUSED(ri); + + info.all_a = 255; + if (stbi__bmp_parse_header(s, &info) == NULL) + return NULL; // error code already set + + flip_vertically = ((int) s->img_y) > 0; + s->img_y = abs((int) s->img_y); + + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + mr = info.mr; + mg = info.mg; + mb = info.mb; + ma = info.ma; + all_a = info.all_a; + + if (info.hsz == 12) { + if (info.bpp < 24) + psize = (info.offset - info.extra_read - 24) / 3; + } else { + if (info.bpp < 16) + psize = (info.offset - info.extra_read - info.hsz) >> 2; + } + if (psize == 0) { + // accept some number of extra bytes after the header, but if the offset points either to before + // the header ends or implies a large amount of extra data, reject the file as malformed + int bytes_read_so_far = s->callback_already_read + (int)(s->img_buffer - s->img_buffer_original); + int header_limit = 1024; // max we actually read is below 256 bytes currently. + int extra_data_limit = 256*4; // what ordinarily goes here is a palette; 256 entries*4 bytes is its max size. + if (bytes_read_so_far <= 0 || bytes_read_so_far > header_limit) { + return stbi__errpuc("bad header", "Corrupt BMP"); + } + // we established that bytes_read_so_far is positive and sensible. + // the first half of this test rejects offsets that are either too small positives, or + // negative, and guarantees that info.offset >= bytes_read_so_far > 0. this in turn + // ensures the number computed in the second half of the test can't overflow. + if (info.offset < bytes_read_so_far || info.offset - bytes_read_so_far > extra_data_limit) { + return stbi__errpuc("bad offset", "Corrupt BMP"); + } else { + stbi__skip(s, info.offset - bytes_read_so_far); + } + } + + if (info.bpp == 24 && ma == 0xff000000) + s->img_n = 3; + else + s->img_n = ma ? 
4 : 3; + if (req_comp && req_comp >= 3) // we can directly decode 3 or 4 + target = req_comp; + else + target = s->img_n; // if they want monochrome, we'll post-convert + + // sanity-check size + if (!stbi__mad3sizes_valid(target, s->img_x, s->img_y, 0)) + return stbi__errpuc("too large", "Corrupt BMP"); + + out = (stbi_uc *) stbi__malloc_mad3(target, s->img_x, s->img_y, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (info.bpp < 16) { + int z=0; + if (psize == 0 || psize > 256) { STBI_FREE(out); return stbi__errpuc("invalid", "Corrupt BMP"); } + for (i=0; i < psize; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + if (info.hsz != 12) stbi__get8(s); + pal[i][3] = 255; + } + stbi__skip(s, info.offset - info.extra_read - info.hsz - psize * (info.hsz == 12 ? 3 : 4)); + if (info.bpp == 1) width = (s->img_x + 7) >> 3; + else if (info.bpp == 4) width = (s->img_x + 1) >> 1; + else if (info.bpp == 8) width = s->img_x; + else { STBI_FREE(out); return stbi__errpuc("bad bpp", "Corrupt BMP"); } + pad = (-width)&3; + if (info.bpp == 1) { + for (j=0; j < (int) s->img_y; ++j) { + int bit_offset = 7, v = stbi__get8(s); + for (i=0; i < (int) s->img_x; ++i) { + int color = (v>>bit_offset)&0x1; + out[z++] = pal[color][0]; + out[z++] = pal[color][1]; + out[z++] = pal[color][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + if((--bit_offset) < 0) { + bit_offset = 7; + v = stbi__get8(s); + } + } + stbi__skip(s, pad); + } + } else { + for (j=0; j < (int) s->img_y; ++j) { + for (i=0; i < (int) s->img_x; i += 2) { + int v=stbi__get8(s),v2=0; + if (info.bpp == 4) { + v2 = v & 15; + v >>= 4; + } + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; + if (i+1 == (int) s->img_x) break; + v = (info.bpp == 8) ? stbi__get8(s) : v2; + out[z++] = pal[v][0]; + out[z++] = pal[v][1]; + out[z++] = pal[v][2]; + if (target == 4) out[z++] = 255; + } + stbi__skip(s, pad); + } + } + } else { + int rshift=0,gshift=0,bshift=0,ashift=0,rcount=0,gcount=0,bcount=0,acount=0; + int z = 0; + int easy=0; + stbi__skip(s, info.offset - info.extra_read - info.hsz); + if (info.bpp == 24) width = 3 * s->img_x; + else if (info.bpp == 16) width = 2*s->img_x; + else /* bpp = 32 and pad = 0 */ width=0; + pad = (-width) & 3; + if (info.bpp == 24) { + easy = 1; + } else if (info.bpp == 32) { + if (mb == 0xff && mg == 0xff00 && mr == 0x00ff0000 && ma == 0xff000000) + easy = 2; + } + if (!easy) { + if (!mr || !mg || !mb) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + // right shift amt to put high bit in position #7 + rshift = stbi__high_bit(mr)-7; rcount = stbi__bitcount(mr); + gshift = stbi__high_bit(mg)-7; gcount = stbi__bitcount(mg); + bshift = stbi__high_bit(mb)-7; bcount = stbi__bitcount(mb); + ashift = stbi__high_bit(ma)-7; acount = stbi__bitcount(ma); + if (rcount > 8 || gcount > 8 || bcount > 8 || acount > 8) { STBI_FREE(out); return stbi__errpuc("bad masks", "Corrupt BMP"); } + } + for (j=0; j < (int) s->img_y; ++j) { + if (easy) { + for (i=0; i < (int) s->img_x; ++i) { + unsigned char a; + out[z+2] = stbi__get8(s); + out[z+1] = stbi__get8(s); + out[z+0] = stbi__get8(s); + z += 3; + a = (easy == 2 ? stbi__get8(s) : 255); + all_a |= a; + if (target == 4) out[z++] = a; + } + } else { + int bpp = info.bpp; + for (i=0; i < (int) s->img_x; ++i) { + stbi__uint32 v = (bpp == 16 ? 
(stbi__uint32) stbi__get16le(s) : stbi__get32le(s)); + unsigned int a; + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mr, rshift, rcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mg, gshift, gcount)); + out[z++] = STBI__BYTECAST(stbi__shiftsigned(v & mb, bshift, bcount)); + a = (ma ? stbi__shiftsigned(v & ma, ashift, acount) : 255); + all_a |= a; + if (target == 4) out[z++] = STBI__BYTECAST(a); + } + } + stbi__skip(s, pad); + } + } + + // if alpha channel is all 0s, replace with all 255s + if (target == 4 && all_a == 0) + for (i=4*s->img_x*s->img_y-1; i >= 0; i -= 4) + out[i] = 255; + + if (flip_vertically) { + stbi_uc t; + for (j=0; j < (int) s->img_y>>1; ++j) { + stbi_uc *p1 = out + j *s->img_x*target; + stbi_uc *p2 = out + (s->img_y-1-j)*s->img_x*target; + for (i=0; i < (int) s->img_x*target; ++i) { + t = p1[i]; p1[i] = p2[i]; p2[i] = t; + } + } + } + + if (req_comp && req_comp != target) { + out = stbi__convert_format(out, target, req_comp, s->img_x, s->img_y); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + return out; +} +#endif + +// Targa Truevision - TGA +// by Jonathan Dummer +#ifndef STBI_NO_TGA +// returns STBI_rgb or whatever, 0 on error +static int stbi__tga_get_comp(int bits_per_pixel, int is_grey, int* is_rgb16) +{ + // only RGB or RGBA (incl. 16bit) or grey allowed + if (is_rgb16) *is_rgb16 = 0; + switch(bits_per_pixel) { + case 8: return STBI_grey; + case 16: if(is_grey) return STBI_grey_alpha; + // fallthrough + case 15: if(is_rgb16) *is_rgb16 = 1; + return STBI_rgb; + case 24: // fallthrough + case 32: return bits_per_pixel/8; + default: return 0; + } +} + +static int stbi__tga_info(stbi__context *s, int *x, int *y, int *comp) +{ + int tga_w, tga_h, tga_comp, tga_image_type, tga_bits_per_pixel, tga_colormap_bpp; + int sz, tga_colormap_type; + stbi__get8(s); // discard Offset + tga_colormap_type = stbi__get8(s); // colormap type + if( tga_colormap_type > 1 ) { + stbi__rewind(s); + return 0; // only RGB or indexed allowed + } + tga_image_type = stbi__get8(s); // image type + if ( tga_colormap_type == 1 ) { // colormapped (paletted) image + if (tga_image_type != 1 && tga_image_type != 9) { + stbi__rewind(s); + return 0; + } + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) { + stbi__rewind(s); + return 0; + } + stbi__skip(s,4); // skip image x and y origin + tga_colormap_bpp = sz; + } else { // "normal" image w/o colormap - only RGB or grey allowed, +/- RLE + if ( (tga_image_type != 2) && (tga_image_type != 3) && (tga_image_type != 10) && (tga_image_type != 11) ) { + stbi__rewind(s); + return 0; // only RGB or grey allowed, +/- RLE + } + stbi__skip(s,9); // skip colormap specification and image x/y origin + tga_colormap_bpp = 0; + } + tga_w = stbi__get16le(s); + if( tga_w < 1 ) { + stbi__rewind(s); + return 0; // test width + } + tga_h = stbi__get16le(s); + if( tga_h < 1 ) { + stbi__rewind(s); + return 0; // test height + } + tga_bits_per_pixel = stbi__get8(s); // bits per pixel + stbi__get8(s); // ignore alpha bits + if (tga_colormap_bpp != 0) { + if((tga_bits_per_pixel != 8) && (tga_bits_per_pixel != 16)) { + // when using a colormap, tga_bits_per_pixel is the size of the indexes + // I don't think anything but 8 or 16bit indexes makes sense + stbi__rewind(s); + return 0; + } + tga_comp = 
stbi__tga_get_comp(tga_colormap_bpp, 0, NULL); + } else { + tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3) || (tga_image_type == 11), NULL); + } + if(!tga_comp) { + stbi__rewind(s); + return 0; + } + if (x) *x = tga_w; + if (y) *y = tga_h; + if (comp) *comp = tga_comp; + return 1; // seems to have passed everything +} + +static int stbi__tga_test(stbi__context *s) +{ + int res = 0; + int sz, tga_color_type; + stbi__get8(s); // discard Offset + tga_color_type = stbi__get8(s); // color type + if ( tga_color_type > 1 ) goto errorEnd; // only RGB or indexed allowed + sz = stbi__get8(s); // image type + if ( tga_color_type == 1 ) { // colormapped (paletted) image + if (sz != 1 && sz != 9) goto errorEnd; // colortype 1 demands image type 1 or 9 + stbi__skip(s,4); // skip index of first colormap entry and number of entries + sz = stbi__get8(s); // check bits per palette color entry + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + stbi__skip(s,4); // skip image x and y origin + } else { // "normal" image w/o colormap + if ( (sz != 2) && (sz != 3) && (sz != 10) && (sz != 11) ) goto errorEnd; // only RGB or grey allowed, +/- RLE + stbi__skip(s,9); // skip colormap specification and image x/y origin + } + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test width + if ( stbi__get16le(s) < 1 ) goto errorEnd; // test height + sz = stbi__get8(s); // bits per pixel + if ( (tga_color_type == 1) && (sz != 8) && (sz != 16) ) goto errorEnd; // for colormapped images, bpp is size of an index + if ( (sz != 8) && (sz != 15) && (sz != 16) && (sz != 24) && (sz != 32) ) goto errorEnd; + + res = 1; // if we got this far, everything's good and we can return 1 instead of 0 + +errorEnd: + stbi__rewind(s); + return res; +} + +// read 16bit value and convert to 24bit RGB +static void stbi__tga_read_rgb16(stbi__context *s, stbi_uc* out) +{ + stbi__uint16 px = (stbi__uint16)stbi__get16le(s); + stbi__uint16 fiveBitMask = 31; + // we have 3 channels with 5bits each + int r = (px >> 10) & fiveBitMask; + int g = (px >> 5) & fiveBitMask; + int b = px & fiveBitMask; + // Note that this saves the data in RGB(A) order, so it doesn't need to be swapped later + out[0] = (stbi_uc)((r * 255)/31); + out[1] = (stbi_uc)((g * 255)/31); + out[2] = (stbi_uc)((b * 255)/31); + + // some people claim that the most significant bit might be used for alpha + // (possibly if an alpha-bit is set in the "image descriptor byte") + // but that only made 16bit test images completely translucent.. + // so let's treat all 15 and 16bit TGAs as RGB with no alpha. +} + +static void *stbi__tga_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + // read in the TGA header stuff + int tga_offset = stbi__get8(s); + int tga_indexed = stbi__get8(s); + int tga_image_type = stbi__get8(s); + int tga_is_RLE = 0; + int tga_palette_start = stbi__get16le(s); + int tga_palette_len = stbi__get16le(s); + int tga_palette_bits = stbi__get8(s); + int tga_x_origin = stbi__get16le(s); + int tga_y_origin = stbi__get16le(s); + int tga_width = stbi__get16le(s); + int tga_height = stbi__get16le(s); + int tga_bits_per_pixel = stbi__get8(s); + int tga_comp, tga_rgb16=0; + int tga_inverted = stbi__get8(s); + // int tga_alpha_bits = tga_inverted & 15; // the 4 lowest bits - unused (useless?) 
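+ // (in the TGA image descriptor byte just read, bits 0-3 are the alpha + // depth and bit 5 is the top-to-bottom origin flag tested below)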
+ // image data + unsigned char *tga_data; + unsigned char *tga_palette = NULL; + int i, j; + unsigned char raw_data[4] = {0}; + int RLE_count = 0; + int RLE_repeating = 0; + int read_next_pixel = 1; + STBI_NOTUSED(ri); + STBI_NOTUSED(tga_x_origin); // @TODO + STBI_NOTUSED(tga_y_origin); // @TODO + + if (tga_height > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (tga_width > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + // do a tiny bit of processing + if ( tga_image_type >= 8 ) + { + tga_image_type -= 8; + tga_is_RLE = 1; + } + tga_inverted = 1 - ((tga_inverted >> 5) & 1); + + // If I'm paletted, then I'll use the number of bits from the palette + if ( tga_indexed ) tga_comp = stbi__tga_get_comp(tga_palette_bits, 0, &tga_rgb16); + else tga_comp = stbi__tga_get_comp(tga_bits_per_pixel, (tga_image_type == 3), &tga_rgb16); + + if(!tga_comp) // shouldn't really happen, stbi__tga_test() should have ensured basic consistency + return stbi__errpuc("bad format", "Can't find out TGA pixelformat"); + + // tga info + *x = tga_width; + *y = tga_height; + if (comp) *comp = tga_comp; + + if (!stbi__mad3sizes_valid(tga_width, tga_height, tga_comp, 0)) + return stbi__errpuc("too large", "Corrupt TGA"); + + tga_data = (unsigned char*)stbi__malloc_mad3(tga_width, tga_height, tga_comp, 0); + if (!tga_data) return stbi__errpuc("outofmem", "Out of memory"); + + // skip to the data's starting position (offset usually = 0) + stbi__skip(s, tga_offset ); + + if ( !tga_indexed && !tga_is_RLE && !tga_rgb16 ) { + for (i=0; i < tga_height; ++i) { + int row = tga_inverted ? tga_height -i - 1 : i; + stbi_uc *tga_row = tga_data + row*tga_width*tga_comp; + stbi__getn(s, tga_row, tga_width * tga_comp); + } + } else { + // do I need to load a palette? + if ( tga_indexed) + { + if (tga_palette_len == 0) { /* you have to have at least one entry! */ + STBI_FREE(tga_data); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + + // any data to skip? (offset usually = 0) + stbi__skip(s, tga_palette_start ); + // load the palette + tga_palette = (unsigned char*)stbi__malloc_mad2(tga_palette_len, tga_comp, 0); + if (!tga_palette) { + STBI_FREE(tga_data); + return stbi__errpuc("outofmem", "Out of memory"); + } + if (tga_rgb16) { + stbi_uc *pal_entry = tga_palette; + STBI_ASSERT(tga_comp == STBI_rgb); + for (i=0; i < tga_palette_len; ++i) { + stbi__tga_read_rgb16(s, pal_entry); + pal_entry += tga_comp; + } + } else if (!stbi__getn(s, tga_palette, tga_palette_len * tga_comp)) { + STBI_FREE(tga_data); + STBI_FREE(tga_palette); + return stbi__errpuc("bad palette", "Corrupt TGA"); + } + } + // load the data + for (i=0; i < tga_width * tga_height; ++i) + { + // if I'm in RLE mode, do I need to get a RLE chunk? + if ( tga_is_RLE ) + { + if ( RLE_count == 0 ) + { + // yep, get the next byte as a RLE command + int RLE_cmd = stbi__get8(s); + RLE_count = 1 + (RLE_cmd & 127); + RLE_repeating = RLE_cmd >> 7; + read_next_pixel = 1; + } else if ( !RLE_repeating ) + { + read_next_pixel = 1; + } + } else + { + read_next_pixel = 1; + } + // OK, if I need to read a pixel, do it now + if ( read_next_pixel ) + { + // load however much data we did have + if ( tga_indexed ) + { + // read in index, then perform the lookup + int pal_idx = (tga_bits_per_pixel == 8) ?
stbi__get8(s) : stbi__get16le(s); + if ( pal_idx >= tga_palette_len ) { + // invalid index + pal_idx = 0; + } + pal_idx *= tga_comp; + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = tga_palette[pal_idx+j]; + } + } else if(tga_rgb16) { + STBI_ASSERT(tga_comp == STBI_rgb); + stbi__tga_read_rgb16(s, raw_data); + } else { + // read in the data raw + for (j = 0; j < tga_comp; ++j) { + raw_data[j] = stbi__get8(s); + } + } + // clear the reading flag for the next pixel + read_next_pixel = 0; + } // end of reading a pixel + + // copy data + for (j = 0; j < tga_comp; ++j) + tga_data[i*tga_comp+j] = raw_data[j]; + + // in case we're in RLE mode, keep counting down + --RLE_count; + } + // do I need to invert the image? + if ( tga_inverted ) + { + for (j = 0; j*2 < tga_height; ++j) + { + int index1 = j * tga_width * tga_comp; + int index2 = (tga_height - 1 - j) * tga_width * tga_comp; + for (i = tga_width * tga_comp; i > 0; --i) + { + unsigned char temp = tga_data[index1]; + tga_data[index1] = tga_data[index2]; + tga_data[index2] = temp; + ++index1; + ++index2; + } + } + } + // clear my palette, if I had one + if ( tga_palette != NULL ) + { + STBI_FREE( tga_palette ); + } + } + + // swap RGB - if the source data was RGB16, it already is in the right order + if (tga_comp >= 3 && !tga_rgb16) + { + unsigned char* tga_pixel = tga_data; + for (i=0; i < tga_width * tga_height; ++i) + { + unsigned char temp = tga_pixel[0]; + tga_pixel[0] = tga_pixel[2]; + tga_pixel[2] = temp; + tga_pixel += tga_comp; + } + } + + // convert to target component count + if (req_comp && req_comp != tga_comp) + tga_data = stbi__convert_format(tga_data, tga_comp, req_comp, tga_width, tga_height); + + // the things I do to get rid of an error message, and yet keep + // Microsoft's C compilers happy... [8^( + tga_palette_start = tga_palette_len = tga_palette_bits = + tga_x_origin = tga_y_origin = 0; + STBI_NOTUSED(tga_palette_start); + // OK, done + return tga_data; +} +#endif + +// ************************************************************************************************* +// Photoshop PSD loader -- PD by Thatcher Ulrich, integration by Nicolas Schulz, tweaked by STB + +#ifndef STBI_NO_PSD +static int stbi__psd_test(stbi__context *s) +{ + int r = (stbi__get32be(s) == 0x38425053); + stbi__rewind(s); + return r; +} + +static int stbi__psd_decode_rle(stbi__context *s, stbi_uc *p, int pixelCount) +{ + int count, nleft, len; + + count = 0; + while ((nleft = pixelCount - count) > 0) { + len = stbi__get8(s); + if (len == 128) { + // No-op. + } else if (len < 128) { + // Copy next len+1 bytes literally. + len++; + if (len > nleft) return 0; // corrupt data + count += len; + while (len) { + *p = stbi__get8(s); + p += 4; + len--; + } + } else if (len > 128) { + stbi_uc val; + // Next -len+1 bytes in the dest are replicated from next source byte. + // (Interpret len as a negative 8-bit int.) + len = 257 - len; + if (len > nleft) return 0; // corrupt data + val = stbi__get8(s); + count += len; + while (len) { + *p = val; + p += 4; + len--; + } + } + } + + return 1; +} + +static void *stbi__psd_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri, int bpc) +{ + int pixelCount; + int channelCount, compression; + int channel, i; + int bitdepth; + int w,h; + stbi_uc *out; + STBI_NOTUSED(ri); + + // Check identifier + if (stbi__get32be(s) != 0x38425053) // "8BPS" + return stbi__errpuc("not PSD", "Corrupt PSD image"); + + // Check file type version. 
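+ // Version 1 is the classic PSD format; version 2 is the large-document + // PSB variant, which this loader does not handle.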
+ if (stbi__get16be(s) != 1) + return stbi__errpuc("wrong version", "Unsupported version of PSD image"); + + // Skip 6 reserved bytes. + stbi__skip(s, 6 ); + + // Read the number of channels (R, G, B, A, etc). + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) + return stbi__errpuc("wrong channel count", "Unsupported number of channels in PSD image"); + + // Read the rows and columns of the image. + h = stbi__get32be(s); + w = stbi__get32be(s); + + if (h > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (w > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + // Make sure the depth is 8 bits. + bitdepth = stbi__get16be(s); + if (bitdepth != 8 && bitdepth != 16) + return stbi__errpuc("unsupported bit depth", "PSD bit depth is not 8 or 16 bit"); + + // Make sure the color mode is RGB. + // Valid options are: + // 0: Bitmap + // 1: Grayscale + // 2: Indexed color + // 3: RGB color + // 4: CMYK color + // 7: Multichannel + // 8: Duotone + // 9: Lab color + if (stbi__get16be(s) != 3) + return stbi__errpuc("wrong color format", "PSD is not in RGB color format"); + + // Skip the Mode Data. (It's the palette for indexed color; other info for other modes.) + stbi__skip(s,stbi__get32be(s) ); + + // Skip the image resources. (resolution, pen tool paths, etc) + stbi__skip(s, stbi__get32be(s) ); + + // Skip the reserved data. + stbi__skip(s, stbi__get32be(s) ); + + // Find out if the data is compressed. + // Known values: + // 0: no compression + // 1: RLE compressed + compression = stbi__get16be(s); + if (compression > 1) + return stbi__errpuc("bad compression", "PSD has an unknown compression format"); + + // Check size + if (!stbi__mad3sizes_valid(4, w, h, 0)) + return stbi__errpuc("too large", "Corrupt PSD"); + + // Create the destination image. + + if (!compression && bitdepth == 16 && bpc == 16) { + out = (stbi_uc *) stbi__malloc_mad3(8, w, h, 0); + ri->bits_per_channel = 16; + } else + out = (stbi_uc *) stbi__malloc(4 * w*h); + + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + pixelCount = w*h; + + // Initialize the data to zero. + //memset( out, 0, pixelCount * 4 ); + + // Finally, the image data. + if (compression) { + // RLE as used by .PSD and .TIFF + // Loop until you get the number of unpacked bytes you are expecting: + // Read the next source byte into n. + // If n is between 0 and 127 inclusive, copy the next n+1 bytes literally. + // Else if n is between -127 and -1 inclusive, copy the next byte -n+1 times. + // Else if n is 128, noop. + // Endloop + + // The RLE-compressed data is preceded by a 2-byte data count for each row in the data, + // which we're going to just skip. + stbi__skip(s, h * channelCount * 2 ); + + // Read the RLE data by channel. + for (channel = 0; channel < 4; channel++) { + stbi_uc *p; + + p = out+channel; + if (channel >= channelCount) { + // Fill this channel with default data. + for (i = 0; i < pixelCount; i++, p += 4) + *p = (channel == 3 ? 255 : 0); + } else { + // Read the RLE data. + if (!stbi__psd_decode_rle(s, p, pixelCount)) { + STBI_FREE(out); + return stbi__errpuc("corrupt", "bad RLE data"); + } + } + } + + } else { + // We're at the raw image data. It's each channel in order (Red, Green, Blue, Alpha, ...) + // where each channel consists of an 8-bit (or 16-bit) value for each pixel in the image. + + // Read the data by channel. 
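+ // The data is planar: all samples of one channel, then all of the next. + // Output stays interleaved with stride 4; when reducing 16-bit input to + // 8-bit output, the high byte of each big-endian sample is kept.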
+ for (channel = 0; channel < 4; channel++) { + if (channel >= channelCount) { + // Fill this channel with default data. + if (bitdepth == 16 && bpc == 16) { + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + stbi__uint16 val = channel == 3 ? 65535 : 0; + for (i = 0; i < pixelCount; i++, q += 4) + *q = val; + } else { + stbi_uc *p = out+channel; + stbi_uc val = channel == 3 ? 255 : 0; + for (i = 0; i < pixelCount; i++, p += 4) + *p = val; + } + } else { + if (ri->bits_per_channel == 16) { // output bpc + stbi__uint16 *q = ((stbi__uint16 *) out) + channel; + for (i = 0; i < pixelCount; i++, q += 4) + *q = (stbi__uint16) stbi__get16be(s); + } else { + stbi_uc *p = out+channel; + if (bitdepth == 16) { // input bpc + for (i = 0; i < pixelCount; i++, p += 4) + *p = (stbi_uc) (stbi__get16be(s) >> 8); + } else { + for (i = 0; i < pixelCount; i++, p += 4) + *p = stbi__get8(s); + } + } + } + } + } + + // remove weird white matte from PSD + if (channelCount >= 4) { + if (ri->bits_per_channel == 16) { + for (i=0; i < w*h; ++i) { + stbi__uint16 *pixel = (stbi__uint16 *) out + 4*i; + if (pixel[3] != 0 && pixel[3] != 65535) { + float a = pixel[3] / 65535.0f; + float ra = 1.0f / a; + float inv_a = 65535.0f * (1 - ra); + pixel[0] = (stbi__uint16) (pixel[0]*ra + inv_a); + pixel[1] = (stbi__uint16) (pixel[1]*ra + inv_a); + pixel[2] = (stbi__uint16) (pixel[2]*ra + inv_a); + } + } + } else { + for (i=0; i < w*h; ++i) { + unsigned char *pixel = out + 4*i; + if (pixel[3] != 0 && pixel[3] != 255) { + float a = pixel[3] / 255.0f; + float ra = 1.0f / a; + float inv_a = 255.0f * (1 - ra); + pixel[0] = (unsigned char) (pixel[0]*ra + inv_a); + pixel[1] = (unsigned char) (pixel[1]*ra + inv_a); + pixel[2] = (unsigned char) (pixel[2]*ra + inv_a); + } + } + } + } + + // convert to desired output format + if (req_comp && req_comp != 4) { + if (ri->bits_per_channel == 16) + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, 4, req_comp, w, h); + else + out = stbi__convert_format(out, 4, req_comp, w, h); + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + + if (comp) *comp = 4; + *y = h; + *x = w; + + return out; +} +#endif + +// ************************************************************************************************* +// Softimage PIC loader +// by Tom Seddon +// +// See http://softimage.wiki.softimage.com/index.php/INFO:_PIC_file_format +// See http://ozviz.wasp.uwa.edu.au/~pbourke/dataformats/softimagepic/ + +#ifndef STBI_NO_PIC +static int stbi__pic_is4(stbi__context *s,const char *str) +{ + int i; + for (i=0; i<4; ++i) + if (stbi__get8(s) != (stbi_uc)str[i]) + return 0; + + return 1; +} + +static int stbi__pic_test_core(stbi__context *s) +{ + int i; + + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) + return 0; + + for(i=0;i<84;++i) + stbi__get8(s); + + if (!stbi__pic_is4(s,"PICT")) + return 0; + + return 1; +} + +typedef struct +{ + stbi_uc size,type,channel; +} stbi__pic_packet; + +static stbi_uc *stbi__readval(stbi__context *s, int channel, stbi_uc *dest) +{ + int mask=0x80, i; + + for (i=0; i<4; ++i, mask>>=1) { + if (channel & mask) { + if (stbi__at_eof(s)) return stbi__errpuc("bad file","PIC file too short"); + dest[i]=stbi__get8(s); + } + } + + return dest; +} + +static void stbi__copyval(int channel,stbi_uc *dest,const stbi_uc *src) +{ + int mask=0x80,i; + + for (i=0;i<4; ++i, mask>>=1) + if (channel&mask) + dest[i]=src[i]; +} + +static stbi_uc *stbi__pic_load_core(stbi__context *s,int width,int height,int *comp, stbi_uc *result) +{ + int 
act_comp=0,num_packets=0,y,chained;
+   stbi__pic_packet packets[10];
+
+   // this will (should...) cater for even some bizarre stuff like having data
+   // for the same channel in multiple packets.
+   do {
+      stbi__pic_packet *packet;
+
+      if (num_packets==sizeof(packets)/sizeof(packets[0]))
+         return stbi__errpuc("bad format","too many packets");
+
+      packet = &packets[num_packets++];
+
+      chained = stbi__get8(s);
+      packet->size    = stbi__get8(s);
+      packet->type    = stbi__get8(s);
+      packet->channel = stbi__get8(s);
+
+      act_comp |= packet->channel;
+
+      if (stbi__at_eof(s))    return stbi__errpuc("bad file","file too short (reading packets)");
+      if (packet->size != 8)  return stbi__errpuc("bad format","packet isn't 8bpp");
+   } while (chained);
+
+   *comp = (act_comp & 0x10 ? 4 : 3); // has alpha channel?
+
+   for(y=0; y<height; ++y) {
+      int packet_idx;
+
+      for(packet_idx=0; packet_idx < num_packets; ++packet_idx) {
+         stbi__pic_packet *packet = &packets[packet_idx];
+         stbi_uc *dest = result+y*width*4;
+
+         switch (packet->type) {
+            default:
+               return stbi__errpuc("bad format","packet has bad compression type");
+
+            case 0: {//uncompressed
+               int x;
+
+               for(x=0;x<width;++x, dest+=4)
+                  if (!stbi__readval(s,packet->channel,dest))
+                     return 0;
+               break;
+            }
+
+            case 1://Pure RLE
+               {
+                  int left=width, i;
+
+                  while (left>0) {
+                     stbi_uc count,value[4];
+
+                     count=stbi__get8(s);
+                     if (stbi__at_eof(s))   return stbi__errpuc("bad file","file too short (pure read count)");
+
+                     if (count > left)
+                        count = (stbi_uc) left;
+
+                     if (!stbi__readval(s,packet->channel,value))  return 0;
+
+                     for(i=0; i<count; ++i,dest+=4)
+                        stbi__copyval(packet->channel,dest,value);
+                     left -= count;
+                  }
+               }
+               break;
+
+            case 2: {//Mixed RLE
+               int left=width;
+               while (left>0) {
+                  int count = stbi__get8(s), i;
+                  if (stbi__at_eof(s))  return stbi__errpuc("bad file","file too short (mixed read count)");
+
+                  if (count >= 128) { // Repeated
+                     stbi_uc value[4];
+
+                     if (count==128)
+                        count = stbi__get16be(s);
+                     else
+                        count -= 127;
+                     if (count > left)
+                        return stbi__errpuc("bad file","scanline overrun");
+
+                     if (!stbi__readval(s,packet->channel,value))
+                        return 0;
+
+                     for(i=0;i<count;++i, dest += 4)
+                        stbi__copyval(packet->channel,dest,value);
+                  } else { // Raw
+                     ++count;
+                     if (count>left) return stbi__errpuc("bad file","scanline overrun");
+
+                     for(i=0;i<count;++i, dest+=4)
+                        if (!stbi__readval(s,packet->channel,dest))
+                           return 0;
+                  }
+                  left-=count;
+               }
+               break;
+            }
+         }
+      }
+   }
+
+   return result;
+}
+
+static void *stbi__pic_load(stbi__context *s,int *px,int *py,int *comp,int req_comp, stbi__result_info *ri)
+{
+   stbi_uc *result;
+   int i, x,y, internal_comp;
+   STBI_NOTUSED(ri);
+
+   if (!comp) comp = &internal_comp;
+
+   for (i=0; i<92; ++i)
+      stbi__get8(s);
+
+   x = stbi__get16be(s);
+   y = stbi__get16be(s);
+
+   if (y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)");
+   if (x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)");
+
+   if (stbi__at_eof(s))  return stbi__errpuc("bad file","file too short (pic header)");
+   if (!stbi__mad3sizes_valid(x, y, 4, 0))  return stbi__errpuc("too large", "PIC image too large to decode");
+
+   stbi__get32be(s); //skip `ratio'
+   stbi__get16be(s); //skip `fields'
+   stbi__get16be(s); //skip `pad'
+
+   // intermediate buffer is RGBA
+   result = (stbi_uc *) stbi__malloc_mad3(x, y, 4, 0);
+   if (!result) return stbi__errpuc("outofmem", "Out of memory");
+   memset(result, 0xff, x*y*4);
+
+   if (!stbi__pic_load_core(s,x,y,comp, result)) {
+      STBI_FREE(result);
+      result=0;
+   }
+   *px = x;
+   *py = y;
+   if (req_comp == 0) req_comp = *comp;
+   result=stbi__convert_format(result,4,req_comp,x,y);
+
+   return result;
+}
+
+static int stbi__pic_test(stbi__context *s)
+{
+   int r = stbi__pic_test_core(s);
+   stbi__rewind(s);
+   return r;
+}
+#endif
+
+// *************************************************************************************************
+// GIF loader -- public domain by
Jean-Marc Lienher -- simplified/shrunk by stb + +#ifndef STBI_NO_GIF +typedef struct +{ + stbi__int16 prefix; + stbi_uc first; + stbi_uc suffix; +} stbi__gif_lzw; + +typedef struct +{ + int w,h; + stbi_uc *out; // output buffer (always 4 components) + stbi_uc *background; // The current "background" as far as a gif is concerned + stbi_uc *history; + int flags, bgindex, ratio, transparent, eflags; + stbi_uc pal[256][4]; + stbi_uc lpal[256][4]; + stbi__gif_lzw codes[8192]; + stbi_uc *color_table; + int parse, step; + int lflags; + int start_x, start_y; + int max_x, max_y; + int cur_x, cur_y; + int line_size; + int delay; +} stbi__gif; + +static int stbi__gif_test_raw(stbi__context *s) +{ + int sz; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') return 0; + sz = stbi__get8(s); + if (sz != '9' && sz != '7') return 0; + if (stbi__get8(s) != 'a') return 0; + return 1; +} + +static int stbi__gif_test(stbi__context *s) +{ + int r = stbi__gif_test_raw(s); + stbi__rewind(s); + return r; +} + +static void stbi__gif_parse_colortable(stbi__context *s, stbi_uc pal[256][4], int num_entries, int transp) +{ + int i; + for (i=0; i < num_entries; ++i) { + pal[i][2] = stbi__get8(s); + pal[i][1] = stbi__get8(s); + pal[i][0] = stbi__get8(s); + pal[i][3] = transp == i ? 0 : 255; + } +} + +static int stbi__gif_header(stbi__context *s, stbi__gif *g, int *comp, int is_info) +{ + stbi_uc version; + if (stbi__get8(s) != 'G' || stbi__get8(s) != 'I' || stbi__get8(s) != 'F' || stbi__get8(s) != '8') + return stbi__err("not GIF", "Corrupt GIF"); + + version = stbi__get8(s); + if (version != '7' && version != '9') return stbi__err("not GIF", "Corrupt GIF"); + if (stbi__get8(s) != 'a') return stbi__err("not GIF", "Corrupt GIF"); + + stbi__g_failure_reason = ""; + g->w = stbi__get16le(s); + g->h = stbi__get16le(s); + g->flags = stbi__get8(s); + g->bgindex = stbi__get8(s); + g->ratio = stbi__get8(s); + g->transparent = -1; + + if (g->w > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + if (g->h > STBI_MAX_DIMENSIONS) return stbi__err("too large","Very large image (corrupt?)"); + + if (comp != 0) *comp = 4; // can't actually tell whether it's 3 or 4 until we parse the comments + + if (is_info) return 1; + + if (g->flags & 0x80) + stbi__gif_parse_colortable(s,g->pal, 2 << (g->flags & 7), -1); + + return 1; +} + +static int stbi__gif_info_raw(stbi__context *s, int *x, int *y, int *comp) +{ + stbi__gif* g = (stbi__gif*) stbi__malloc(sizeof(stbi__gif)); + if (!g) return stbi__err("outofmem", "Out of memory"); + if (!stbi__gif_header(s, g, comp, 1)) { + STBI_FREE(g); + stbi__rewind( s ); + return 0; + } + if (x) *x = g->w; + if (y) *y = g->h; + STBI_FREE(g); + return 1; +} + +static void stbi__out_gif_code(stbi__gif *g, stbi__uint16 code) +{ + stbi_uc *p, *c; + int idx; + + // recurse to decode the prefixes, since the linked-list is backwards, + // and working backwards through an interleaved image would be nasty + if (g->codes[code].prefix >= 0) + stbi__out_gif_code(g, g->codes[code].prefix); + + if (g->cur_y >= g->max_y) return; + + idx = g->cur_x + g->cur_y; + p = &g->out[idx]; + g->history[idx / 4] = 1; + + c = &g->color_table[g->codes[code].suffix * 4]; + if (c[3] > 128) { // don't render transparent pixels; + p[0] = c[2]; + p[1] = c[1]; + p[2] = c[0]; + p[3] = c[3]; + } + g->cur_x += 4; + + if (g->cur_x >= g->max_x) { + g->cur_x = g->start_x; + g->cur_y += g->step; + + while (g->cur_y >= g->max_y && g->parse > 0) { + g->step = (1 << 
g->parse) * g->line_size; + g->cur_y = g->start_y + (g->step >> 1); + --g->parse; + } + } +} + +static stbi_uc *stbi__process_gif_raster(stbi__context *s, stbi__gif *g) +{ + stbi_uc lzw_cs; + stbi__int32 len, init_code; + stbi__uint32 first; + stbi__int32 codesize, codemask, avail, oldcode, bits, valid_bits, clear; + stbi__gif_lzw *p; + + lzw_cs = stbi__get8(s); + if (lzw_cs > 12) return NULL; + clear = 1 << lzw_cs; + first = 1; + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + bits = 0; + valid_bits = 0; + for (init_code = 0; init_code < clear; init_code++) { + g->codes[init_code].prefix = -1; + g->codes[init_code].first = (stbi_uc) init_code; + g->codes[init_code].suffix = (stbi_uc) init_code; + } + + // support no starting clear code + avail = clear+2; + oldcode = -1; + + len = 0; + for(;;) { + if (valid_bits < codesize) { + if (len == 0) { + len = stbi__get8(s); // start new block + if (len == 0) + return g->out; + } + --len; + bits |= (stbi__int32) stbi__get8(s) << valid_bits; + valid_bits += 8; + } else { + stbi__int32 code = bits & codemask; + bits >>= codesize; + valid_bits -= codesize; + // @OPTIMIZE: is there some way we can accelerate the non-clear path? + if (code == clear) { // clear code + codesize = lzw_cs + 1; + codemask = (1 << codesize) - 1; + avail = clear + 2; + oldcode = -1; + first = 0; + } else if (code == clear + 1) { // end of stream code + stbi__skip(s, len); + while ((len = stbi__get8(s)) > 0) + stbi__skip(s,len); + return g->out; + } else if (code <= avail) { + if (first) { + return stbi__errpuc("no clear code", "Corrupt GIF"); + } + + if (oldcode >= 0) { + p = &g->codes[avail++]; + if (avail > 8192) { + return stbi__errpuc("too many codes", "Corrupt GIF"); + } + + p->prefix = (stbi__int16) oldcode; + p->first = g->codes[oldcode].first; + p->suffix = (code == avail) ? p->first : g->codes[code].first; + } else if (code == avail) + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + + stbi__out_gif_code(g, (stbi__uint16) code); + + if ((avail & codemask) == 0 && avail <= 0x0FFF) { + codesize++; + codemask = (1 << codesize) - 1; + } + + oldcode = code; + } else { + return stbi__errpuc("illegal code in raster", "Corrupt GIF"); + } + } + } +} + +// this function is designed to support animated gifs, although stb_image doesn't support it +// two back is the image from two frames ago, used for a very specific disposal format +static stbi_uc *stbi__gif_load_next(stbi__context *s, stbi__gif *g, int *comp, int req_comp, stbi_uc *two_back) +{ + int dispose; + int first_frame; + int pi; + int pcount; + STBI_NOTUSED(req_comp); + + // on first frame, any non-written pixels get the background colour (non-transparent) + first_frame = 0; + if (g->out == 0) { + if (!stbi__gif_header(s, g, comp,0)) return 0; // stbi__g_failure_reason set by stbi__gif_header + if (!stbi__mad3sizes_valid(4, g->w, g->h, 0)) + return stbi__errpuc("too large", "GIF image is too large"); + pcount = g->w * g->h; + g->out = (stbi_uc *) stbi__malloc(4 * pcount); + g->background = (stbi_uc *) stbi__malloc(4 * pcount); + g->history = (stbi_uc *) stbi__malloc(pcount); + if (!g->out || !g->background || !g->history) + return stbi__errpuc("outofmem", "Out of memory"); + + // image is treated as "transparent" at the start - ie, nothing overwrites the current background; + // background colour is only used for pixels that are not rendered first frame, after that "background" + // color refers to the color that was there the previous frame. 
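+      // Buffer roles: out holds the composited RGBA frame, background holds the
+      // RGBA state a dispose==2 frame reverts to, and history marks which pixels
+      // the current frame actually wrote.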
+      memset(g->out, 0x00, 4 * pcount);
+      memset(g->background, 0x00, 4 * pcount); // state of the background (starts transparent)
+      memset(g->history, 0x00, pcount);        // pixels that were affected previous frame
+      first_frame = 1;
+   } else {
+      // second frame - how do we dispose of the previous one?
+      dispose = (g->eflags & 0x1C) >> 2;
+      pcount = g->w * g->h;
+
+      if ((dispose == 3) && (two_back == 0)) {
+         dispose = 2; // if I don't have an image to revert back to, default to the old background
+      }
+
+      if (dispose == 3) { // use previous graphic
+         for (pi = 0; pi < pcount; ++pi) {
+            if (g->history[pi]) {
+               memcpy( &g->out[pi * 4], &two_back[pi * 4], 4 );
+            }
+         }
+      } else if (dispose == 2) {
+         // restore what was changed last frame to background before that frame;
+         for (pi = 0; pi < pcount; ++pi) {
+            if (g->history[pi]) {
+               memcpy( &g->out[pi * 4], &g->background[pi * 4], 4 );
+            }
+         }
+      } else {
+         // This is a non-disposal case either way, so just
+         // leave the pixels as is, and they will become the new background
+         // 1: do not dispose
+         // 0: not specified.
+      }
+
+      // background is what out is after the undoing of the previous frame;
+      memcpy( g->background, g->out, 4 * g->w * g->h );
+   }
+
+   // clear my history;
+   memset( g->history, 0x00, g->w * g->h ); // pixels that were affected previous frame
+
+   for (;;) {
+      int tag = stbi__get8(s);
+      switch (tag) {
+      case 0x2C: /* Image Descriptor */
+      {
+         stbi__int32 x, y, w, h;
+         stbi_uc *o;
+
+         x = stbi__get16le(s);
+         y = stbi__get16le(s);
+         w = stbi__get16le(s);
+         h = stbi__get16le(s);
+         if (((x + w) > (g->w)) || ((y + h) > (g->h)))
+            return stbi__errpuc("bad Image Descriptor", "Corrupt GIF");
+
+         g->line_size = g->w * 4;
+         g->start_x = x * 4;
+         g->start_y = y * g->line_size;
+         g->max_x   = g->start_x + w * 4;
+         g->max_y   = g->start_y + h * g->line_size;
+         g->cur_x   = g->start_x;
+         g->cur_y   = g->start_y;
+
+         // if the width of the specified rectangle is 0, that means
+         // we may not see *any* pixels or the image is malformed;
+         // to make sure this is caught, move the current y down to
+         // max_y (which is what out_gif_code checks).
+         if (w == 0)
+            g->cur_y = g->max_y;
+
+         g->lflags = stbi__get8(s);
+
+         if (g->lflags & 0x40) {
+            g->step = 8 * g->line_size; // first interlaced spacing
+            g->parse = 3;
+         } else {
+            g->step = g->line_size;
+            g->parse = 0;
+         }
+
+         if (g->lflags & 0x80) {
+            stbi__gif_parse_colortable(s,g->lpal, 2 << (g->lflags & 7), g->eflags & 0x01 ? g->transparent : -1);
+            g->color_table = (stbi_uc *) g->lpal;
+         } else if (g->flags & 0x80) {
+            g->color_table = (stbi_uc *) g->pal;
+         } else
+            return stbi__errpuc("missing color table", "Corrupt GIF");
+
+         o = stbi__process_gif_raster(s, g);
+         if (!o) return NULL;
+
+         // if this was the first frame,
+         pcount = g->w * g->h;
+         if (first_frame && (g->bgindex > 0)) {
+            // if first frame, any pixel not drawn to gets the background color
+            for (pi = 0; pi < pcount; ++pi) {
+               if (g->history[pi] == 0) {
+                  g->pal[g->bgindex][3] = 255; // just in case it was made transparent, undo that; it will be reset next frame if need be;
+                  memcpy( &g->out[pi * 4], &g->pal[g->bgindex], 4 );
+               }
+            }
+         }
+
+         return o;
+      }
+
+      case 0x21: // Comment Extension.
+      {
+         int len;
+         int ext = stbi__get8(s);
+         if (ext == 0xF9) { // Graphic Control Extension.
+            len = stbi__get8(s);
+            if (len == 4) {
+               g->eflags = stbi__get8(s);
+               g->delay = 10 * stbi__get16le(s); // delay - 1/100th of a second, saving as 1/1000ths.
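+               // eflags is the GIF89a "packed fields" byte: bit 0 is the
+               // transparent-color flag (tested below) and bits 2..4 select the
+               // disposal method applied before the next frame is composited.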
+
+               // unset old transparent
+               if (g->transparent >= 0) {
+                  g->pal[g->transparent][3] = 255;
+               }
+               if (g->eflags & 0x01) {
+                  g->transparent = stbi__get8(s);
+                  if (g->transparent >= 0) {
+                     g->pal[g->transparent][3] = 0;
+                  }
+               } else {
+                  // don't need transparent
+                  stbi__skip(s, 1);
+                  g->transparent = -1;
+               }
+            } else {
+               stbi__skip(s, len);
+               break;
+            }
+         }
+         while ((len = stbi__get8(s)) != 0) {
+            stbi__skip(s, len);
+         }
+         break;
+      }
+
+      case 0x3B: // gif stream termination code
+         return (stbi_uc *) s; // using '1' causes warning on some compilers
+
+      default:
+         return stbi__errpuc("unknown code", "Corrupt GIF");
+      }
+   }
+}
+
+static void *stbi__load_gif_main_outofmem(stbi__gif *g, stbi_uc *out, int **delays)
+{
+   STBI_FREE(g->out);
+   STBI_FREE(g->history);
+   STBI_FREE(g->background);
+
+   if (out) STBI_FREE(out);
+   if (delays && *delays) STBI_FREE(*delays);
+   return stbi__errpuc("outofmem", "Out of memory");
+}
+
+static void *stbi__load_gif_main(stbi__context *s, int **delays, int *x, int *y, int *z, int *comp, int req_comp)
+{
+   if (stbi__gif_test(s)) {
+      int layers = 0;
+      stbi_uc *u = 0;
+      stbi_uc *out = 0;
+      stbi_uc *two_back = 0;
+      stbi__gif g;
+      int stride;
+      int out_size = 0;
+      int delays_size = 0;
+
+      STBI_NOTUSED(out_size);
+      STBI_NOTUSED(delays_size);
+
+      memset(&g, 0, sizeof(g));
+      if (delays) {
+         *delays = 0;
+      }
+
+      do {
+         u = stbi__gif_load_next(s, &g, comp, req_comp, two_back);
+         if (u == (stbi_uc *) s) u = 0;  // end of animated gif marker
+
+         if (u) {
+            *x = g.w;
+            *y = g.h;
+            ++layers;
+            stride = g.w * g.h * 4;
+
+            if (out) {
+               void *tmp = (stbi_uc*) STBI_REALLOC_SIZED( out, out_size, layers * stride );
+               if (!tmp)
+                  return stbi__load_gif_main_outofmem(&g, out, delays);
+               else {
+                  out = (stbi_uc*) tmp;
+                  out_size = layers * stride;
+               }
+
+               if (delays) {
+                  int *new_delays = (int*) STBI_REALLOC_SIZED( *delays, delays_size, sizeof(int) * layers );
+                  if (!new_delays)
+                     return stbi__load_gif_main_outofmem(&g, out, delays);
+                  *delays = new_delays;
+                  delays_size = layers * sizeof(int);
+               }
+            } else {
+               out = (stbi_uc*)stbi__malloc( layers * stride );
+               if (!out)
+                  return stbi__load_gif_main_outofmem(&g, out, delays);
+               out_size = layers * stride;
+               if (delays) {
+                  *delays = (int*) stbi__malloc( layers * sizeof(int) );
+                  if (!*delays)
+                     return stbi__load_gif_main_outofmem(&g, out, delays);
+                  delays_size = layers * sizeof(int);
+               }
+            }
+            memcpy( out + ((layers - 1) * stride), u, stride );
+            if (layers >= 2) {
+               two_back = out + (layers - 2) * stride; // the frame before the one just stored
+            }
+
+            if (delays) {
+               (*delays)[layers - 1U] = g.delay;
+            }
+         }
+      } while (u != 0);
+
+      // free temp buffer;
+      STBI_FREE(g.out);
+      STBI_FREE(g.history);
+      STBI_FREE(g.background);
+
+      // do the final conversion after loading everything;
+      if (req_comp && req_comp != 4)
+         out = stbi__convert_format(out, 4, req_comp, layers * g.w, g.h);
+
+      *z = layers;
+      return out;
+   } else {
+      return stbi__errpuc("not GIF", "Image was not a GIF type.");
+   }
+}
+
+static void *stbi__gif_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri)
+{
+   stbi_uc *u = 0;
+   stbi__gif g;
+   memset(&g, 0, sizeof(g));
+   STBI_NOTUSED(ri);
+
+   u = stbi__gif_load_next(s, &g, comp, req_comp, 0);
+   if (u == (stbi_uc *) s) u = 0;  // end of animated gif marker
+   if (u) {
+      *x = g.w;
+      *y = g.h;
+
+      // moved conversion to after successful load so that the same
+      // can be done for multiple frames.
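+      // (stbi__convert_format frees its input and returns NULL on failure, so
+      // nothing leaks on this path.)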
+ if (req_comp && req_comp != 4) + u = stbi__convert_format(u, 4, req_comp, g.w, g.h); + } else if (g.out) { + // if there was an error and we allocated an image buffer, free it! + STBI_FREE(g.out); + } + + // free buffers needed for multiple frame loading; + STBI_FREE(g.history); + STBI_FREE(g.background); + + return u; +} + +static int stbi__gif_info(stbi__context *s, int *x, int *y, int *comp) +{ + return stbi__gif_info_raw(s,x,y,comp); +} +#endif + +// ************************************************************************************************* +// Radiance RGBE HDR loader +// originally by Nicolas Schulz +#ifndef STBI_NO_HDR +static int stbi__hdr_test_core(stbi__context *s, const char *signature) +{ + int i; + for (i=0; signature[i]; ++i) + if (stbi__get8(s) != signature[i]) + return 0; + stbi__rewind(s); + return 1; +} + +static int stbi__hdr_test(stbi__context* s) +{ + int r = stbi__hdr_test_core(s, "#?RADIANCE\n"); + stbi__rewind(s); + if(!r) { + r = stbi__hdr_test_core(s, "#?RGBE\n"); + stbi__rewind(s); + } + return r; +} + +#define STBI__HDR_BUFLEN 1024 +static char *stbi__hdr_gettoken(stbi__context *z, char *buffer) +{ + int len=0; + char c = '\0'; + + c = (char) stbi__get8(z); + + while (!stbi__at_eof(z) && c != '\n') { + buffer[len++] = c; + if (len == STBI__HDR_BUFLEN-1) { + // flush to end of line + while (!stbi__at_eof(z) && stbi__get8(z) != '\n') + ; + break; + } + c = (char) stbi__get8(z); + } + + buffer[len] = 0; + return buffer; +} + +static void stbi__hdr_convert(float *output, stbi_uc *input, int req_comp) +{ + if ( input[3] != 0 ) { + float f1; + // Exponent + f1 = (float) ldexp(1.0f, input[3] - (int)(128 + 8)); + if (req_comp <= 2) + output[0] = (input[0] + input[1] + input[2]) * f1 / 3; + else { + output[0] = input[0] * f1; + output[1] = input[1] * f1; + output[2] = input[2] * f1; + } + if (req_comp == 2) output[1] = 1; + if (req_comp == 4) output[3] = 1; + } else { + switch (req_comp) { + case 4: output[3] = 1; /* fallthrough */ + case 3: output[0] = output[1] = output[2] = 0; + break; + case 2: output[1] = 1; /* fallthrough */ + case 1: output[0] = 0; + break; + } + } +} + +static float *stbi__hdr_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + char buffer[STBI__HDR_BUFLEN]; + char *token; + int valid = 0; + int width, height; + stbi_uc *scanline; + float *hdr_data; + int len; + unsigned char count, value; + int i, j, k, c1,c2, z; + const char *headerToken; + STBI_NOTUSED(ri); + + // Check identifier + headerToken = stbi__hdr_gettoken(s,buffer); + if (strcmp(headerToken, "#?RADIANCE") != 0 && strcmp(headerToken, "#?RGBE") != 0) + return stbi__errpf("not HDR", "Corrupt HDR image"); + + // Parse header + for(;;) { + token = stbi__hdr_gettoken(s,buffer); + if (token[0] == 0) break; + if (strcmp(token, "FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } + + if (!valid) return stbi__errpf("unsupported format", "Unsupported HDR format"); + + // Parse width and height + // can't use sscanf() if we're not using stdio! 
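+   // The dimension line is expected in the most common orientation,
+   // "-Y <height> +X <width>"; any other axis order is rejected as unsupported.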
+   token = stbi__hdr_gettoken(s,buffer);
+   if (strncmp(token, "-Y ", 3))  return stbi__errpf("unsupported data layout", "Unsupported HDR format");
+   token += 3;
+   height = (int) strtol(token, &token, 10);
+   while (*token == ' ') ++token;
+   if (strncmp(token, "+X ", 3))  return stbi__errpf("unsupported data layout", "Unsupported HDR format");
+   token += 3;
+   width = (int) strtol(token, NULL, 10);
+
+   if (height > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)");
+   if (width > STBI_MAX_DIMENSIONS) return stbi__errpf("too large","Very large image (corrupt?)");
+
+   *x = width;
+   *y = height;
+
+   if (comp) *comp = 3;
+   if (req_comp == 0) req_comp = 3;
+
+   if (!stbi__mad4sizes_valid(width, height, req_comp, sizeof(float), 0))
+      return stbi__errpf("too large", "HDR image is too large");
+
+   // Read data
+   hdr_data = (float *) stbi__malloc_mad4(width, height, req_comp, sizeof(float), 0);
+   if (!hdr_data)
+      return stbi__errpf("outofmem", "Out of memory");
+
+   // Load image data
+   // image data is stored as some number of scan lines
+   if ( width < 8 || width >= 32768) {
+      // Read flat data
+      for (j=0; j < height; ++j) {
+         for (i=0; i < width; ++i) {
+            stbi_uc rgbe[4];
+           main_decode_loop:
+            stbi__getn(s, rgbe, 4);
+            stbi__hdr_convert(hdr_data + j * width * req_comp + i * req_comp, rgbe, req_comp);
+         }
+      }
+   } else {
+      // Read RLE-encoded data
+      scanline = NULL;
+
+      for (j = 0; j < height; ++j) {
+         c1 = stbi__get8(s);
+         c2 = stbi__get8(s);
+         len = stbi__get8(s);
+         if (c1 != 2 || c2 != 2 || (len & 0x80)) {
+            // not run-length encoded, so we have to actually use THIS data as a decoded
+            // pixel (note this can't be a valid pixel--one of RGB must be >= 128)
+            stbi_uc rgbe[4];
+            rgbe[0] = (stbi_uc) c1;
+            rgbe[1] = (stbi_uc) c2;
+            rgbe[2] = (stbi_uc) len;
+            rgbe[3] = (stbi_uc) stbi__get8(s);
+            stbi__hdr_convert(hdr_data, rgbe, req_comp);
+            i = 1;
+            j = 0;
+            STBI_FREE(scanline);
+            goto main_decode_loop; // yes, this makes no sense
+         }
+         len <<= 8;
+         len |= stbi__get8(s);
+         if (len != width) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("invalid decoded scanline length", "corrupt HDR"); }
+         if (scanline == NULL) {
+            scanline = (stbi_uc *) stbi__malloc_mad2(width, 4, 0);
+            if (!scanline) {
+               STBI_FREE(hdr_data);
+               return stbi__errpf("outofmem", "Out of memory");
+            }
+         }
+
+         for (k = 0; k < 4; ++k) {
+            int nleft;
+            i = 0;
+            while ((nleft = width - i) > 0) {
+               count = stbi__get8(s);
+               if (count > 128) {
+                  // Run
+                  value = stbi__get8(s);
+                  count -= 128;
+                  if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); }
+                  for (z = 0; z < count; ++z)
+                     scanline[i++ * 4 + k] = value;
+               } else {
+                  // Dump
+                  if ((count == 0) || (count > nleft)) { STBI_FREE(hdr_data); STBI_FREE(scanline); return stbi__errpf("corrupt", "bad RLE data in HDR"); }
+                  for (z = 0; z < count; ++z)
+                     scanline[i++ * 4 + k] = stbi__get8(s);
+               }
+            }
+         }
+         for (i=0; i < width; ++i)
+            stbi__hdr_convert(hdr_data+(j*width + i)*req_comp, scanline + i*4, req_comp);
+      }
+      if (scanline)
+         STBI_FREE(scanline);
+   }
+
+   return hdr_data;
+}
+
+static int stbi__hdr_info(stbi__context *s, int *x, int *y, int *comp)
+{
+   char buffer[STBI__HDR_BUFLEN];
+   char *token;
+   int valid = 0;
+   int dummy;
+
+   if (!x) x = &dummy;
+   if (!y) y = &dummy;
+   if (!comp) comp = &dummy;
+
+   if (stbi__hdr_test(s) == 0) {
+      stbi__rewind( s );
+      return 0;
+   }
+
+   for(;;) {
+      token = stbi__hdr_gettoken(s,buffer);
+      if (token[0] == 0) break;
+      if (strcmp(token,
"FORMAT=32-bit_rle_rgbe") == 0) valid = 1; + } + + if (!valid) { + stbi__rewind( s ); + return 0; + } + token = stbi__hdr_gettoken(s,buffer); + if (strncmp(token, "-Y ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *y = (int) strtol(token, &token, 10); + while (*token == ' ') ++token; + if (strncmp(token, "+X ", 3)) { + stbi__rewind( s ); + return 0; + } + token += 3; + *x = (int) strtol(token, NULL, 10); + *comp = 3; + return 1; +} +#endif // STBI_NO_HDR + +#ifndef STBI_NO_BMP +static int stbi__bmp_info(stbi__context *s, int *x, int *y, int *comp) +{ + void *p; + stbi__bmp_data info; + + info.all_a = 255; + p = stbi__bmp_parse_header(s, &info); + if (p == NULL) { + stbi__rewind( s ); + return 0; + } + if (x) *x = s->img_x; + if (y) *y = s->img_y; + if (comp) { + if (info.bpp == 24 && info.ma == 0xff000000) + *comp = 3; + else + *comp = info.ma ? 4 : 3; + } + return 1; +} +#endif + +#ifndef STBI_NO_PSD +static int stbi__psd_info(stbi__context *s, int *x, int *y, int *comp) +{ + int channelCount, dummy, depth; + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + *y = stbi__get32be(s); + *x = stbi__get32be(s); + depth = stbi__get16be(s); + if (depth != 8 && depth != 16) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 3) { + stbi__rewind( s ); + return 0; + } + *comp = 4; + return 1; +} + +static int stbi__psd_is16(stbi__context *s) +{ + int channelCount, depth; + if (stbi__get32be(s) != 0x38425053) { + stbi__rewind( s ); + return 0; + } + if (stbi__get16be(s) != 1) { + stbi__rewind( s ); + return 0; + } + stbi__skip(s, 6); + channelCount = stbi__get16be(s); + if (channelCount < 0 || channelCount > 16) { + stbi__rewind( s ); + return 0; + } + STBI_NOTUSED(stbi__get32be(s)); + STBI_NOTUSED(stbi__get32be(s)); + depth = stbi__get16be(s); + if (depth != 16) { + stbi__rewind( s ); + return 0; + } + return 1; +} +#endif + +#ifndef STBI_NO_PIC +static int stbi__pic_info(stbi__context *s, int *x, int *y, int *comp) +{ + int act_comp=0,num_packets=0,chained,dummy; + stbi__pic_packet packets[10]; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + if (!stbi__pic_is4(s,"\x53\x80\xF6\x34")) { + stbi__rewind(s); + return 0; + } + + stbi__skip(s, 88); + + *x = stbi__get16be(s); + *y = stbi__get16be(s); + if (stbi__at_eof(s)) { + stbi__rewind( s); + return 0; + } + if ( (*x) != 0 && (1 << 28) / (*x) < (*y)) { + stbi__rewind( s ); + return 0; + } + + stbi__skip(s, 8); + + do { + stbi__pic_packet *packet; + + if (num_packets==sizeof(packets)/sizeof(packets[0])) + return 0; + + packet = &packets[num_packets++]; + chained = stbi__get8(s); + packet->size = stbi__get8(s); + packet->type = stbi__get8(s); + packet->channel = stbi__get8(s); + act_comp |= packet->channel; + + if (stbi__at_eof(s)) { + stbi__rewind( s ); + return 0; + } + if (packet->size != 8) { + stbi__rewind( s ); + return 0; + } + } while (chained); + + *comp = (act_comp & 0x10 ? 
4 : 3); + + return 1; +} +#endif + +// ************************************************************************************************* +// Portable Gray Map and Portable Pixel Map loader +// by Ken Miller +// +// PGM: http://netpbm.sourceforge.net/doc/pgm.html +// PPM: http://netpbm.sourceforge.net/doc/ppm.html +// +// Known limitations: +// Does not support comments in the header section +// Does not support ASCII image data (formats P2 and P3) + +#ifndef STBI_NO_PNM + +static int stbi__pnm_test(stbi__context *s) +{ + char p, t; + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind( s ); + return 0; + } + return 1; +} + +static void *stbi__pnm_load(stbi__context *s, int *x, int *y, int *comp, int req_comp, stbi__result_info *ri) +{ + stbi_uc *out; + STBI_NOTUSED(ri); + + ri->bits_per_channel = stbi__pnm_info(s, (int *)&s->img_x, (int *)&s->img_y, (int *)&s->img_n); + if (ri->bits_per_channel == 0) + return 0; + + if (s->img_y > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + if (s->img_x > STBI_MAX_DIMENSIONS) return stbi__errpuc("too large","Very large image (corrupt?)"); + + *x = s->img_x; + *y = s->img_y; + if (comp) *comp = s->img_n; + + if (!stbi__mad4sizes_valid(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0)) + return stbi__errpuc("too large", "PNM too large"); + + out = (stbi_uc *) stbi__malloc_mad4(s->img_n, s->img_x, s->img_y, ri->bits_per_channel / 8, 0); + if (!out) return stbi__errpuc("outofmem", "Out of memory"); + if (!stbi__getn(s, out, s->img_n * s->img_x * s->img_y * (ri->bits_per_channel / 8))) { + STBI_FREE(out); + return stbi__errpuc("bad PNM", "PNM file truncated"); + } + + if (req_comp && req_comp != s->img_n) { + if (ri->bits_per_channel == 16) { + out = (stbi_uc *) stbi__convert_format16((stbi__uint16 *) out, s->img_n, req_comp, s->img_x, s->img_y); + } else { + out = stbi__convert_format(out, s->img_n, req_comp, s->img_x, s->img_y); + } + if (out == NULL) return out; // stbi__convert_format frees input on failure + } + return out; +} + +static int stbi__pnm_isspace(char c) +{ + return c == ' ' || c == '\t' || c == '\n' || c == '\v' || c == '\f' || c == '\r'; +} + +static void stbi__pnm_skip_whitespace(stbi__context *s, char *c) +{ + for (;;) { + while (!stbi__at_eof(s) && stbi__pnm_isspace(*c)) + *c = (char) stbi__get8(s); + + if (stbi__at_eof(s) || *c != '#') + break; + + while (!stbi__at_eof(s) && *c != '\n' && *c != '\r' ) + *c = (char) stbi__get8(s); + } +} + +static int stbi__pnm_isdigit(char c) +{ + return c >= '0' && c <= '9'; +} + +static int stbi__pnm_getinteger(stbi__context *s, char *c) +{ + int value = 0; + + while (!stbi__at_eof(s) && stbi__pnm_isdigit(*c)) { + value = value*10 + (*c - '0'); + *c = (char) stbi__get8(s); + if((value > 214748364) || (value == 214748364 && *c > '7')) + return stbi__err("integer parse overflow", "Parsing an integer in the PPM header overflowed a 32-bit int"); + } + + return value; +} + +static int stbi__pnm_info(stbi__context *s, int *x, int *y, int *comp) +{ + int maxv, dummy; + char c, p, t; + + if (!x) x = &dummy; + if (!y) y = &dummy; + if (!comp) comp = &dummy; + + stbi__rewind(s); + + // Get identifier + p = (char) stbi__get8(s); + t = (char) stbi__get8(s); + if (p != 'P' || (t != '5' && t != '6')) { + stbi__rewind(s); + return 0; + } + + *comp = (t == '6') ? 
3 : 1; // '5' is 1-component .pgm; '6' is 3-component .ppm
+
+   c = (char) stbi__get8(s);
+   stbi__pnm_skip_whitespace(s, &c);
+
+   *x = stbi__pnm_getinteger(s, &c); // read width
+   if(*x == 0)
+      return stbi__err("invalid width", "PPM image header had zero or overflowing width");
+   stbi__pnm_skip_whitespace(s, &c);
+
+   *y = stbi__pnm_getinteger(s, &c); // read height
+   if (*y == 0)
+      return stbi__err("invalid height", "PPM image header had zero or overflowing height");
+   stbi__pnm_skip_whitespace(s, &c);
+
+   maxv = stbi__pnm_getinteger(s, &c);  // read max value
+   if (maxv > 65535)
+      return stbi__err("max value > 65535", "PPM image supports only 8-bit and 16-bit images");
+   else if (maxv > 255)
+      return 16;
+   else
+      return 8;
+}
+
+static int stbi__pnm_is16(stbi__context *s)
+{
+   if (stbi__pnm_info(s, NULL, NULL, NULL) == 16)
+      return 1;
+   return 0;
+}
+#endif
+
+static int stbi__info_main(stbi__context *s, int *x, int *y, int *comp)
+{
+   #ifndef STBI_NO_JPEG
+   if (stbi__jpeg_info(s, x, y, comp)) return 1;
+   #endif
+
+   #ifndef STBI_NO_PNG
+   if (stbi__png_info(s, x, y, comp))  return 1;
+   #endif
+
+   #ifndef STBI_NO_GIF
+   if (stbi__gif_info(s, x, y, comp))  return 1;
+   #endif
+
+   #ifndef STBI_NO_BMP
+   if (stbi__bmp_info(s, x, y, comp))  return 1;
+   #endif
+
+   #ifndef STBI_NO_PSD
+   if (stbi__psd_info(s, x, y, comp))  return 1;
+   #endif
+
+   #ifndef STBI_NO_PIC
+   if (stbi__pic_info(s, x, y, comp))  return 1;
+   #endif
+
+   #ifndef STBI_NO_PNM
+   if (stbi__pnm_info(s, x, y, comp))  return 1;
+   #endif
+
+   #ifndef STBI_NO_HDR
+   if (stbi__hdr_info(s, x, y, comp))  return 1;
+   #endif
+
+   // test tga last because it's a crappy test!
+   #ifndef STBI_NO_TGA
+   if (stbi__tga_info(s, x, y, comp))
+       return 1;
+   #endif
+   return stbi__err("unknown image type", "Image not of any known type, or corrupt");
+}
+
+static int stbi__is_16_main(stbi__context *s)
+{
+   #ifndef STBI_NO_PNG
+   if (stbi__png_is16(s))  return 1;
+   #endif
+
+   #ifndef STBI_NO_PSD
+   if (stbi__psd_is16(s))  return 1;
+   #endif
+
+   #ifndef STBI_NO_PNM
+   if (stbi__pnm_is16(s))  return 1;
+   #endif
+   return 0;
+}
+
+#ifndef STBI_NO_STDIO
+STBIDEF int stbi_info(char const *filename, int *x, int *y, int *comp)
+{
+    FILE *f = stbi__fopen(filename, "rb");
+    int result;
+    if (!f) return stbi__err("can't fopen", "Unable to open file");
+    result = stbi_info_from_file(f, x, y, comp);
+    fclose(f);
+    return result;
+}
+
+STBIDEF int stbi_info_from_file(FILE *f, int *x, int *y, int *comp)
+{
+   int r;
+   stbi__context s;
+   long pos = ftell(f);
+   stbi__start_file(&s, f);
+   r = stbi__info_main(&s,x,y,comp);
+   fseek(f,pos,SEEK_SET);
+   return r;
+}
+
+STBIDEF int stbi_is_16_bit(char const *filename)
+{
+    FILE *f = stbi__fopen(filename, "rb");
+    int result;
+    if (!f) return stbi__err("can't fopen", "Unable to open file");
+    result = stbi_is_16_bit_from_file(f);
+    fclose(f);
+    return result;
+}
+
+STBIDEF int stbi_is_16_bit_from_file(FILE *f)
+{
+   int r;
+   stbi__context s;
+   long pos = ftell(f);
+   stbi__start_file(&s, f);
+   r = stbi__is_16_main(&s);
+   fseek(f,pos,SEEK_SET);
+   return r;
+}
+#endif // !STBI_NO_STDIO
+
+STBIDEF int stbi_info_from_memory(stbi_uc const *buffer, int len, int *x, int *y, int *comp)
+{
+   stbi__context s;
+   stbi__start_mem(&s,buffer,len);
+   return stbi__info_main(&s,x,y,comp);
+}
+
+STBIDEF int stbi_info_from_callbacks(stbi_io_callbacks const *c, void *user, int *x, int *y, int *comp)
+{
+   stbi__context s;
+   stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user);
+   return stbi__info_main(&s,x,y,comp);
+}
+
+STBIDEF int stbi_is_16_bit_from_memory(stbi_uc const
*buffer, int len) +{ + stbi__context s; + stbi__start_mem(&s,buffer,len); + return stbi__is_16_main(&s); +} + +STBIDEF int stbi_is_16_bit_from_callbacks(stbi_io_callbacks const *c, void *user) +{ + stbi__context s; + stbi__start_callbacks(&s, (stbi_io_callbacks *) c, user); + return stbi__is_16_main(&s); +} + +#endif // STB_IMAGE_IMPLEMENTATION + +/* + revision history: + 2.20 (2019-02-07) support utf8 filenames in Windows; fix warnings and platform ifdefs + 2.19 (2018-02-11) fix warning + 2.18 (2018-01-30) fix warnings + 2.17 (2018-01-29) change sbti__shiftsigned to avoid clang -O2 bug + 1-bit BMP + *_is_16_bit api + avoid warnings + 2.16 (2017-07-23) all functions have 16-bit variants; + STBI_NO_STDIO works again; + compilation fixes; + fix rounding in unpremultiply; + optimize vertical flip; + disable raw_len validation; + documentation fixes + 2.15 (2017-03-18) fix png-1,2,4 bug; now all Imagenet JPGs decode; + warning fixes; disable run-time SSE detection on gcc; + uniform handling of optional "return" values; + thread-safe initialization of zlib tables + 2.14 (2017-03-03) remove deprecated STBI_JPEG_OLD; fixes for Imagenet JPGs + 2.13 (2016-11-29) add 16-bit API, only supported for PNG right now + 2.12 (2016-04-02) fix typo in 2.11 PSD fix that caused crashes + 2.11 (2016-04-02) allocate large structures on the stack + remove white matting for transparent PSD + fix reported channel count for PNG & BMP + re-enable SSE2 in non-gcc 64-bit + support RGB-formatted JPEG + read 16-bit PNGs (only as 8-bit) + 2.10 (2016-01-22) avoid warning introduced in 2.09 by STBI_REALLOC_SIZED + 2.09 (2016-01-16) allow comments in PNM files + 16-bit-per-pixel TGA (not bit-per-component) + info() for TGA could break due to .hdr handling + info() for BMP to shares code instead of sloppy parse + can use STBI_REALLOC_SIZED if allocator doesn't support realloc + code cleanup + 2.08 (2015-09-13) fix to 2.07 cleanup, reading RGB PSD as RGBA + 2.07 (2015-09-13) fix compiler warnings + partial animated GIF support + limited 16-bpc PSD support + #ifdef unused functions + bug with < 92 byte PIC,PNM,HDR,TGA + 2.06 (2015-04-19) fix bug where PSD returns wrong '*comp' value + 2.05 (2015-04-19) fix bug in progressive JPEG handling, fix warning + 2.04 (2015-04-15) try to re-enable SIMD on MinGW 64-bit + 2.03 (2015-04-12) extra corruption checking (mmozeiko) + stbi_set_flip_vertically_on_load (nguillemot) + fix NEON support; fix mingw support + 2.02 (2015-01-19) fix incorrect assert, fix warning + 2.01 (2015-01-17) fix various warnings; suppress SIMD on gcc 32-bit without -msse2 + 2.00b (2014-12-25) fix STBI_MALLOC in progressive JPEG + 2.00 (2014-12-25) optimize JPG, including x86 SSE2 & NEON SIMD (ryg) + progressive JPEG (stb) + PGM/PPM support (Ken Miller) + STBI_MALLOC,STBI_REALLOC,STBI_FREE + GIF bugfix -- seemingly never worked + STBI_NO_*, STBI_ONLY_* + 1.48 (2014-12-14) fix incorrectly-named assert() + 1.47 (2014-12-14) 1/2/4-bit PNG support, both direct and paletted (Omar Cornut & stb) + optimize PNG (ryg) + fix bug in interlaced PNG with user-specified channel count (stb) + 1.46 (2014-08-26) + fix broken tRNS chunk (colorkey-style transparency) in non-paletted PNG + 1.45 (2014-08-16) + fix MSVC-ARM internal compiler error by wrapping malloc + 1.44 (2014-08-07) + various warning fixes from Ronny Chevalier + 1.43 (2014-07-15) + fix MSVC-only compiler problem in code changed in 1.42 + 1.42 (2014-07-09) + don't define _CRT_SECURE_NO_WARNINGS (affects user code) + fixes to stbi__cleanup_jpeg path + added STBI_ASSERT 
to avoid requiring assert.h + 1.41 (2014-06-25) + fix search&replace from 1.36 that messed up comments/error messages + 1.40 (2014-06-22) + fix gcc struct-initialization warning + 1.39 (2014-06-15) + fix to TGA optimization when req_comp != number of components in TGA; + fix to GIF loading because BMP wasn't rewinding (whoops, no GIFs in my test suite) + add support for BMP version 5 (more ignored fields) + 1.38 (2014-06-06) + suppress MSVC warnings on integer casts truncating values + fix accidental rename of 'skip' field of I/O + 1.37 (2014-06-04) + remove duplicate typedef + 1.36 (2014-06-03) + convert to header file single-file library + if de-iphone isn't set, load iphone images color-swapped instead of returning NULL + 1.35 (2014-05-27) + various warnings + fix broken STBI_SIMD path + fix bug where stbi_load_from_file no longer left file pointer in correct place + fix broken non-easy path for 32-bit BMP (possibly never used) + TGA optimization by Arseny Kapoulkine + 1.34 (unknown) + use STBI_NOTUSED in stbi__resample_row_generic(), fix one more leak in tga failure case + 1.33 (2011-07-14) + make stbi_is_hdr work in STBI_NO_HDR (as specified), minor compiler-friendly improvements + 1.32 (2011-07-13) + support for "info" function for all supported filetypes (SpartanJ) + 1.31 (2011-06-20) + a few more leak fixes, bug in PNG handling (SpartanJ) + 1.30 (2011-06-11) + added ability to load files via callbacks to accomidate custom input streams (Ben Wenger) + removed deprecated format-specific test/load functions + removed support for installable file formats (stbi_loader) -- would have been broken for IO callbacks anyway + error cases in bmp and tga give messages and don't leak (Raymond Barbiero, grisha) + fix inefficiency in decoding 32-bit BMP (David Woo) + 1.29 (2010-08-16) + various warning fixes from Aurelien Pocheville + 1.28 (2010-08-01) + fix bug in GIF palette transparency (SpartanJ) + 1.27 (2010-08-01) + cast-to-stbi_uc to fix warnings + 1.26 (2010-07-24) + fix bug in file buffering for PNG reported by SpartanJ + 1.25 (2010-07-17) + refix trans_data warning (Won Chun) + 1.24 (2010-07-12) + perf improvements reading from files on platforms with lock-heavy fgetc() + minor perf improvements for jpeg + deprecated type-specific functions so we'll get feedback if they're needed + attempt to fix trans_data warning (Won Chun) + 1.23 fixed bug in iPhone support + 1.22 (2010-07-10) + removed image *writing* support + stbi_info support from Jetro Lauha + GIF support from Jean-Marc Lienher + iPhone PNG-extensions from James Brown + warning-fixes from Nicolas Schulz and Janez Zemva (i.stbi__err. Janez (U+017D)emva) + 1.21 fix use of 'stbi_uc' in header (reported by jon blow) + 1.20 added support for Softimage PIC, by Tom Seddon + 1.19 bug in interlaced PNG corruption check (found by ryg) + 1.18 (2008-08-02) + fix a threading bug (local mutable static) + 1.17 support interlaced PNG + 1.16 major bugfix - stbi__convert_format converted one too many pixels + 1.15 initialize some fields for thread safety + 1.14 fix threadsafe conversion bug + header-file-only version (#define STBI_HEADER_FILE_ONLY before including) + 1.13 threadsafe + 1.12 const qualifiers in the API + 1.11 Support installable IDCT, colorspace conversion routines + 1.10 Fixes for 64-bit (don't use "unsigned long") + optimized upsampling by Fabian "ryg" Giesen + 1.09 Fix format-conversion for PSD code (bad global variables!) 
+ 1.08 Thatcher Ulrich's PSD code integrated by Nicolas Schulz + 1.07 attempt to fix C++ warning/errors again + 1.06 attempt to fix C++ warning/errors again + 1.05 fix TGA loading to return correct *comp and use good luminance calc + 1.04 default float alpha is 1, not 255; use 'void *' for stbi_image_free + 1.03 bugfixes to STBI_NO_STDIO, STBI_NO_HDR + 1.02 support for (subset of) HDR files, float interface for preferred access to them + 1.01 fix bug: possible bug in handling right-side up bmps... not sure + fix bug: the stbi__bmp_load() and stbi__tga_load() functions didn't work at all + 1.00 interface to zlib that skips zlib header + 0.99 correct handling of alpha in palette + 0.98 TGA loader by lonesock; dynamically add loaders (untested) + 0.97 jpeg errors on too large a file; also catch another malloc failure + 0.96 fix detection of invalid v value - particleman@mollyrocket forum + 0.95 during header scan, seek to markers in case of padding + 0.94 STBI_NO_STDIO to disable stdio usage; rename all #defines the same + 0.93 handle jpegtran output; verbose errors + 0.92 read 4,8,16,24,32-bit BMP files of several formats + 0.91 output 24-bit Windows 3.0 BMP files + 0.90 fix a few more warnings; bump version number to approach 1.0 + 0.61 bugfixes due to Marc LeBlanc, Christopher Lloyd + 0.60 fix compiling as c++ + 0.59 fix warnings: merge Dave Moore's -Wall fixes + 0.58 fix bug: zlib uncompressed mode len/nlen was wrong endian + 0.57 fix bug: jpg last huffman symbol before marker was >9 bits but less than 16 available + 0.56 fix bug: zlib uncompressed mode len vs. nlen + 0.55 fix bug: restart_interval not initialized to 0 + 0.54 allow NULL for 'int *comp' + 0.53 fix bug in png 3->4; speedup png decoding + 0.52 png handles req_comp=3,4 directly; minor cleanup; jpeg comments + 0.51 obey req_comp requests, 1-component jpegs return as 1-component, + on 'test' only check type, not whether we support this variant + 0.50 (2006-11-19) + first released version +*/ + + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. 
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +------------------------------------------------------------------------------ +*/ diff --git a/extern/stb_image_write.h b/extern/stb_image_write.h new file mode 100644 index 0000000..e4b32ed --- /dev/null +++ b/extern/stb_image_write.h @@ -0,0 +1,1724 @@ +/* stb_image_write - v1.16 - public domain - http://nothings.org/stb + writes out PNG/BMP/TGA/JPEG/HDR images to C stdio - Sean Barrett 2010-2015 + no warranty implied; use at your own risk + + Before #including, + + #define STB_IMAGE_WRITE_IMPLEMENTATION + + in the file that you want to have the implementation. + + Will probably not work correctly with strict-aliasing optimizations. + +ABOUT: + + This header file is a library for writing images to C stdio or a callback. + + The PNG output is not optimal; it is 20-50% larger than the file + written by a decent optimizing implementation; though providing a custom + zlib compress function (see STBIW_ZLIB_COMPRESS) can mitigate that. + This library is designed for source code compactness and simplicity, + not optimal image file size or run-time performance. + +BUILDING: + + You can #define STBIW_ASSERT(x) before the #include to avoid using assert.h. + You can #define STBIW_MALLOC(), STBIW_REALLOC(), and STBIW_FREE() to replace + malloc,realloc,free. + You can #define STBIW_MEMMOVE() to replace memmove() + You can #define STBIW_ZLIB_COMPRESS to use a custom zlib-style compress function + for PNG compression (instead of the builtin one), it must have the following signature: + unsigned char * my_compress(unsigned char *data, int data_len, int *out_len, int quality); + The returned data will be freed with STBIW_FREE() (free() by default), + so it must be heap allocated with STBIW_MALLOC() (malloc() by default), + +UNICODE: + + If compiling for Windows and you wish to use Unicode filenames, compile + with + #define STBIW_WINDOWS_UTF8 + and pass utf8-encoded filenames. Call stbiw_convert_wchar_to_utf8 to convert + Windows wchar_t filenames to utf8. 
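+
+   A minimal sketch (the buffer size, 'w', 'h', and 'pixels' here are
+   illustrative placeholders, not part of the API):
+
+      char name_utf8[1024];
+      if (stbiw_convert_wchar_to_utf8(name_utf8, sizeof(name_utf8), L"image.png"))
+         stbi_write_png(name_utf8, w, h, 4, pixels, w * 4);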
+ +USAGE: + + There are five functions, one for each image file format: + + int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes); + int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); + int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); + int stbi_write_jpg(char const *filename, int w, int h, int comp, const void *data, int quality); + int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); + + void stbi_flip_vertically_on_write(int flag); // flag is non-zero to flip data vertically + + There are also five equivalent functions that use an arbitrary write function. You are + expected to open/close your file-equivalent before and after calling these: + + int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes); + int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); + int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data); + int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data); + int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality); + + where the callback is: + void stbi_write_func(void *context, void *data, int size); + + You can configure it with these global variables: + int stbi_write_tga_with_rle; // defaults to true; set to 0 to disable RLE + int stbi_write_png_compression_level; // defaults to 8; set to higher for more compression + int stbi_write_force_png_filter; // defaults to -1; set to 0..5 to force a filter mode + + + You can define STBI_WRITE_NO_STDIO to disable the file variant of these + functions, so the library will not use stdio.h at all. However, this will + also disable HDR writing, because it requires stdio for formatted output. + + Each function returns 0 on failure and non-0 on success. + + The functions create an image file defined by the parameters. The image + is a rectangle of pixels stored from left-to-right, top-to-bottom. + Each pixel contains 'comp' channels of data stored interleaved with 8-bits + per channel, in the following order: 1=Y, 2=YA, 3=RGB, 4=RGBA. (Y is + monochrome color.) The rectangle is 'w' pixels wide and 'h' pixels tall. + The *data pointer points to the first byte of the top-left-most pixel. + For PNG, "stride_in_bytes" is the distance in bytes from the first byte of + a row of pixels to the first byte of the next row of pixels. + + PNG creates output files with the same number of components as the input. + The BMP format expands Y to RGB in the file format and does not + output alpha. + + PNG supports writing rectangles of data even when the bytes storing rows of + data are not consecutive in memory (e.g. sub-rectangles of a larger image), + by supplying the stride between the beginning of adjacent rows. The other + formats do not. (Thus you cannot write a native-format BMP through the BMP + writer, both because it is in BGR order and because it may have padding + at the end of the line.) + + PNG allows you to set the deflate compression level by setting the global + variable 'stbi_write_png_compression_level' (it defaults to 8). + + HDR expects linear float data. 
Since the format is always 32-bit rgb(e)
+   data, alpha (if provided) is discarded, and for monochrome data it is
+   replicated across all three channels.
+
+   TGA supports RLE or non-RLE compressed data. To use non-RLE-compressed
+   data, set the global variable 'stbi_write_tga_with_rle' to 0.
+
+   JPEG does ignore alpha channels in input data; quality is between 1 and 100.
+   Higher quality looks better but results in a bigger image.
+   JPEG baseline (no JPEG progressive).
+
+CREDITS:
+
+
+   Sean Barrett           -    PNG/BMP/TGA
+   Baldur Karlsson        -    HDR
+   Jean-Sebastien Guay    -    TGA monochrome
+   Tim Kelsey             -    misc enhancements
+   Alan Hickman           -    TGA RLE
+   Emmanuel Julien        -    initial file IO callback implementation
+   Jon Olick              -    original jo_jpeg.cpp code
+   Daniel Gibson          -    integrate JPEG, allow external zlib
+   Aarni Koskela          -    allow choosing PNG filter
+
+   bugfixes:
+      github:Chribba
+      Guillaume Chereau
+      github:jry2
+      github:romigrou
+      Sergio Gonzalez
+      Jonas Karlsson
+      Filip Wasil
+      Thatcher Ulrich
+      github:poppolopoppo
+      Patrick Boettcher
+      github:xeekworx
+      Cap Petschulat
+      Simon Rodriguez
+      Ivan Tikhonov
+      github:ignotion
+      Adam Schackart
+      Andrew Kensler
+
+LICENSE
+
+  See end of file for license information.
+
+*/
+
+#ifndef INCLUDE_STB_IMAGE_WRITE_H
+#define INCLUDE_STB_IMAGE_WRITE_H
+
+#include <stdlib.h>
+
+// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'
+#ifndef STBIWDEF
+#ifdef STB_IMAGE_WRITE_STATIC
+#define STBIWDEF  static
+#else
+#ifdef __cplusplus
+#define STBIWDEF  extern "C"
+#else
+#define STBIWDEF  extern
+#endif
+#endif
+#endif
+
+#ifndef STB_IMAGE_WRITE_STATIC  // C++ forbids static forward declarations
+STBIWDEF int stbi_write_tga_with_rle;
+STBIWDEF int stbi_write_png_compression_level;
+STBIWDEF int stbi_write_force_png_filter;
+#endif
+
+#ifndef STBI_WRITE_NO_STDIO
+STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void  *data, int stride_in_bytes);
+STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void  *data);
+STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void  *data);
+STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
+STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void  *data, int quality);
+
+#ifdef STBIW_WINDOWS_UTF8
+STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
+#endif
+#endif
+
+typedef void stbi_write_func(void *context, void *data, int size);
+
+STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data, int stride_in_bytes);
+STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data);
+STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void  *data);
+STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
+STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);
+
+STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);
+
+#endif//INCLUDE_STB_IMAGE_WRITE_H
+
+#ifdef STB_IMAGE_WRITE_IMPLEMENTATION
+
+#ifdef _WIN32
+   #ifndef _CRT_SECURE_NO_WARNINGS
+   #define _CRT_SECURE_NO_WARNINGS
+   #endif
+   #ifndef _CRT_NONSTDC_NO_DEPRECATE
+   #define _CRT_NONSTDC_NO_DEPRECATE
+   #endif
+#endif
+
+#ifndef STBI_WRITE_NO_STDIO
+#include <stdio.h>
+#endif // STBI_WRITE_NO_STDIO

*/

#ifndef INCLUDE_STB_IMAGE_WRITE_H
#define INCLUDE_STB_IMAGE_WRITE_H

#include <stdlib.h>

// if STB_IMAGE_WRITE_STATIC causes problems, try defining STBIWDEF to 'inline' or 'static inline'
#ifndef STBIWDEF
#ifdef STB_IMAGE_WRITE_STATIC
#define STBIWDEF  static
#else
#ifdef __cplusplus
#define STBIWDEF  extern "C"
#else
#define STBIWDEF  extern
#endif
#endif
#endif

#ifndef STB_IMAGE_WRITE_STATIC  // C++ forbids static forward declarations
STBIWDEF int stbi_write_tga_with_rle;
STBIWDEF int stbi_write_png_compression_level;
STBIWDEF int stbi_write_force_png_filter;
#endif

#ifndef STBI_WRITE_NO_STDIO
STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const void *data, int stride_in_bytes);
STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data);
STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data);
STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data);
STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality);

#ifdef STBIW_WINDOWS_UTF8
STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input);
#endif
#endif

typedef void stbi_write_func(void *context, void *data, int size);

STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data, int stride_in_bytes);
STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const void *data);
STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int w, int h, int comp, const float *data);
STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality);

STBIWDEF void stbi_flip_vertically_on_write(int flip_boolean);

#endif//INCLUDE_STB_IMAGE_WRITE_H

#ifdef STB_IMAGE_WRITE_IMPLEMENTATION

#ifdef _WIN32
   #ifndef _CRT_SECURE_NO_WARNINGS
   #define _CRT_SECURE_NO_WARNINGS
   #endif
   #ifndef _CRT_NONSTDC_NO_DEPRECATE
   #define _CRT_NONSTDC_NO_DEPRECATE
   #endif
#endif

#ifndef STBI_WRITE_NO_STDIO
#include <stdio.h>
#endif // STBI_WRITE_NO_STDIO

#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>

#if defined(STBIW_MALLOC) && defined(STBIW_FREE) && (defined(STBIW_REALLOC) || defined(STBIW_REALLOC_SIZED))
// ok
#elif !defined(STBIW_MALLOC) && !defined(STBIW_FREE) && !defined(STBIW_REALLOC) && !defined(STBIW_REALLOC_SIZED)
// ok
#else
#error "Must define all or none of STBIW_MALLOC, STBIW_FREE, and STBIW_REALLOC (or STBIW_REALLOC_SIZED)."
#endif

#ifndef STBIW_MALLOC
#define STBIW_MALLOC(sz)        malloc(sz)
#define STBIW_REALLOC(p,newsz)  realloc(p,newsz)
#define STBIW_FREE(p)           free(p)
#endif

#ifndef STBIW_REALLOC_SIZED
#define STBIW_REALLOC_SIZED(p,oldsz,newsz) STBIW_REALLOC(p,newsz)
#endif


#ifndef STBIW_MEMMOVE
#define STBIW_MEMMOVE(a,b,sz) memmove(a,b,sz)
#endif


#ifndef STBIW_ASSERT
#include <assert.h>
#define STBIW_ASSERT(x) assert(x)
#endif

#define STBIW_UCHAR(x) (unsigned char) ((x) & 0xff)

#ifdef STB_IMAGE_WRITE_STATIC
static int stbi_write_png_compression_level = 8;
static int stbi_write_tga_with_rle = 1;
static int stbi_write_force_png_filter = -1;
#else
int stbi_write_png_compression_level = 8;
int stbi_write_tga_with_rle = 1;
int stbi_write_force_png_filter = -1;
#endif

static int stbi__flip_vertically_on_write = 0;

STBIWDEF void stbi_flip_vertically_on_write(int flag)
{
   stbi__flip_vertically_on_write = flag;
}

typedef struct
{
   stbi_write_func *func;
   void *context;
   unsigned char buffer[64];
   int buf_used;
} stbi__write_context;

// initialize a callback-based context
static void stbi__start_write_callbacks(stbi__write_context *s, stbi_write_func *c, void *context)
{
   s->func    = c;
   s->context = context;
}

#ifndef STBI_WRITE_NO_STDIO

static void stbi__stdio_write(void *context, void *data, int size)
{
   fwrite(data,1,size,(FILE*) context);
}

#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
#ifdef __cplusplus
#define STBIW_EXTERN extern "C"
#else
#define STBIW_EXTERN extern
#endif
STBIW_EXTERN __declspec(dllimport) int __stdcall MultiByteToWideChar(unsigned int cp, unsigned long flags, const char *str, int cbmb, wchar_t *widestr, int cchwide);
STBIW_EXTERN __declspec(dllimport) int __stdcall WideCharToMultiByte(unsigned int cp, unsigned long flags, const wchar_t *widestr, int cchwide, char *str, int cbmb, const char *defchar, int *used_default);

STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input)
{
   return WideCharToMultiByte(65001 /* UTF8 */, 0, input, -1, buffer, (int) bufferlen, NULL, NULL);
}
#endif

static FILE *stbiw__fopen(char const *filename, char const *mode)
{
   FILE *f;
#if defined(_WIN32) && defined(STBIW_WINDOWS_UTF8)
   wchar_t wMode[64];
   wchar_t wFilename[1024];
   if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, filename, -1, wFilename, sizeof(wFilename)/sizeof(*wFilename)))
      return 0;

   if (0 == MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, wMode, sizeof(wMode)/sizeof(*wMode)))
      return 0;

#if defined(_MSC_VER) && _MSC_VER >= 1400
   if (0 != _wfopen_s(&f, wFilename, wMode))
      f = 0;
#else
   f = _wfopen(wFilename, wMode);
#endif

#elif defined(_MSC_VER) && _MSC_VER >= 1400
   if (0 != fopen_s(&f, filename, mode))
      f=0;
#else
   f = fopen(filename, mode);
#endif
   return f;
}

static int stbi__start_write_file(stbi__write_context *s, const char *filename)
{
   FILE *f = stbiw__fopen(filename, "wb");
   stbi__start_write_callbacks(s, stbi__stdio_write, (void *) f);
   return f != NULL;
}

static void stbi__end_write_file(stbi__write_context *s)
{
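   /* stbi__start_write_file() stored the FILE* as the write-callback context,
      so closing the context closes the file */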
fclose((FILE *)s->context); +} + +#endif // !STBI_WRITE_NO_STDIO + +typedef unsigned int stbiw_uint32; +typedef int stb_image_write_test[sizeof(stbiw_uint32)==4 ? 1 : -1]; + +static void stbiw__writefv(stbi__write_context *s, const char *fmt, va_list v) +{ + while (*fmt) { + switch (*fmt++) { + case ' ': break; + case '1': { unsigned char x = STBIW_UCHAR(va_arg(v, int)); + s->func(s->context,&x,1); + break; } + case '2': { int x = va_arg(v,int); + unsigned char b[2]; + b[0] = STBIW_UCHAR(x); + b[1] = STBIW_UCHAR(x>>8); + s->func(s->context,b,2); + break; } + case '4': { stbiw_uint32 x = va_arg(v,int); + unsigned char b[4]; + b[0]=STBIW_UCHAR(x); + b[1]=STBIW_UCHAR(x>>8); + b[2]=STBIW_UCHAR(x>>16); + b[3]=STBIW_UCHAR(x>>24); + s->func(s->context,b,4); + break; } + default: + STBIW_ASSERT(0); + return; + } + } +} + +static void stbiw__writef(stbi__write_context *s, const char *fmt, ...) +{ + va_list v; + va_start(v, fmt); + stbiw__writefv(s, fmt, v); + va_end(v); +} + +static void stbiw__write_flush(stbi__write_context *s) +{ + if (s->buf_used) { + s->func(s->context, &s->buffer, s->buf_used); + s->buf_used = 0; + } +} + +static void stbiw__putc(stbi__write_context *s, unsigned char c) +{ + s->func(s->context, &c, 1); +} + +static void stbiw__write1(stbi__write_context *s, unsigned char a) +{ + if ((size_t)s->buf_used + 1 > sizeof(s->buffer)) + stbiw__write_flush(s); + s->buffer[s->buf_used++] = a; +} + +static void stbiw__write3(stbi__write_context *s, unsigned char a, unsigned char b, unsigned char c) +{ + int n; + if ((size_t)s->buf_used + 3 > sizeof(s->buffer)) + stbiw__write_flush(s); + n = s->buf_used; + s->buf_used = n+3; + s->buffer[n+0] = a; + s->buffer[n+1] = b; + s->buffer[n+2] = c; +} + +static void stbiw__write_pixel(stbi__write_context *s, int rgb_dir, int comp, int write_alpha, int expand_mono, unsigned char *d) +{ + unsigned char bg[3] = { 255, 0, 255}, px[3]; + int k; + + if (write_alpha < 0) + stbiw__write1(s, d[comp - 1]); + + switch (comp) { + case 2: // 2 pixels = mono + alpha, alpha is written separately, so same as 1-channel case + case 1: + if (expand_mono) + stbiw__write3(s, d[0], d[0], d[0]); // monochrome bmp + else + stbiw__write1(s, d[0]); // monochrome TGA + break; + case 4: + if (!write_alpha) { + // composite against pink background + for (k = 0; k < 3; ++k) + px[k] = bg[k] + ((d[k] - bg[k]) * d[3]) / 255; + stbiw__write3(s, px[1 - rgb_dir], px[1], px[1 + rgb_dir]); + break; + } + /* FALLTHROUGH */ + case 3: + stbiw__write3(s, d[1 - rgb_dir], d[1], d[1 + rgb_dir]); + break; + } + if (write_alpha > 0) + stbiw__write1(s, d[comp - 1]); +} + +static void stbiw__write_pixels(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, void *data, int write_alpha, int scanline_pad, int expand_mono) +{ + stbiw_uint32 zero = 0; + int i,j, j_end; + + if (y <= 0) + return; + + if (stbi__flip_vertically_on_write) + vdir *= -1; + + if (vdir < 0) { + j_end = -1; j = y-1; + } else { + j_end = y; j = 0; + } + + for (; j != j_end; j += vdir) { + for (i=0; i < x; ++i) { + unsigned char *d = (unsigned char *) data + (j*x+i)*comp; + stbiw__write_pixel(s, rgb_dir, comp, write_alpha, expand_mono, d); + } + stbiw__write_flush(s); + s->func(s->context, &zero, scanline_pad); + } +} + +static int stbiw__outfile(stbi__write_context *s, int rgb_dir, int vdir, int x, int y, int comp, int expand_mono, void *data, int alpha, int pad, const char *fmt, ...) 
+{ + if (y < 0 || x < 0) { + return 0; + } else { + va_list v; + va_start(v, fmt); + stbiw__writefv(s, fmt, v); + va_end(v); + stbiw__write_pixels(s,rgb_dir,vdir,x,y,comp,data,alpha,pad, expand_mono); + return 1; + } +} + +static int stbi_write_bmp_core(stbi__write_context *s, int x, int y, int comp, const void *data) +{ + if (comp != 4) { + // write RGB bitmap + int pad = (-x*3) & 3; + return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *) data,0,pad, + "11 4 22 4" "4 44 22 444444", + 'B', 'M', 14+40+(x*3+pad)*y, 0,0, 14+40, // file header + 40, x,y, 1,24, 0,0,0,0,0,0); // bitmap header + } else { + // RGBA bitmaps need a v4 header + // use BI_BITFIELDS mode with 32bpp and alpha mask + // (straight BI_RGB with alpha mask doesn't work in most readers) + return stbiw__outfile(s,-1,-1,x,y,comp,1,(void *)data,1,0, + "11 4 22 4" "4 44 22 444444 4444 4 444 444 444 444", + 'B', 'M', 14+108+x*y*4, 0, 0, 14+108, // file header + 108, x,y, 1,32, 3,0,0,0,0,0, 0xff0000,0xff00,0xff,0xff000000u, 0, 0,0,0, 0,0,0, 0,0,0, 0,0,0); // bitmap V4 header + } +} + +STBIWDEF int stbi_write_bmp_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_bmp_core(&s, x, y, comp, data); +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_bmp(char const *filename, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_bmp_core(&s, x, y, comp, data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif //!STBI_WRITE_NO_STDIO + +static int stbi_write_tga_core(stbi__write_context *s, int x, int y, int comp, void *data) +{ + int has_alpha = (comp == 2 || comp == 4); + int colorbytes = has_alpha ? comp-1 : comp; + int format = colorbytes < 2 ? 
3 : 2; // 3 color channels (RGB/RGBA) = 2, 1 color channel (Y/YA) = 3 + + if (y < 0 || x < 0) + return 0; + + if (!stbi_write_tga_with_rle) { + return stbiw__outfile(s, -1, -1, x, y, comp, 0, (void *) data, has_alpha, 0, + "111 221 2222 11", 0, 0, format, 0, 0, 0, 0, 0, x, y, (colorbytes + has_alpha) * 8, has_alpha * 8); + } else { + int i,j,k; + int jend, jdir; + + stbiw__writef(s, "111 221 2222 11", 0,0,format+8, 0,0,0, 0,0,x,y, (colorbytes + has_alpha) * 8, has_alpha * 8); + + if (stbi__flip_vertically_on_write) { + j = 0; + jend = y; + jdir = 1; + } else { + j = y-1; + jend = -1; + jdir = -1; + } + for (; j != jend; j += jdir) { + unsigned char *row = (unsigned char *) data + j * x * comp; + int len; + + for (i = 0; i < x; i += len) { + unsigned char *begin = row + i * comp; + int diff = 1; + len = 1; + + if (i < x - 1) { + ++len; + diff = memcmp(begin, row + (i + 1) * comp, comp); + if (diff) { + const unsigned char *prev = begin; + for (k = i + 2; k < x && len < 128; ++k) { + if (memcmp(prev, row + k * comp, comp)) { + prev += comp; + ++len; + } else { + --len; + break; + } + } + } else { + for (k = i + 2; k < x && len < 128; ++k) { + if (!memcmp(begin, row + k * comp, comp)) { + ++len; + } else { + break; + } + } + } + } + + if (diff) { + unsigned char header = STBIW_UCHAR(len - 1); + stbiw__write1(s, header); + for (k = 0; k < len; ++k) { + stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin + k * comp); + } + } else { + unsigned char header = STBIW_UCHAR(len - 129); + stbiw__write1(s, header); + stbiw__write_pixel(s, -1, comp, has_alpha, 0, begin); + } + } + } + stbiw__write_flush(s); + } + return 1; +} + +STBIWDEF int stbi_write_tga_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_tga_core(&s, x, y, comp, (void *) data); +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_tga(char const *filename, int x, int y, int comp, const void *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_tga_core(&s, x, y, comp, (void *) data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif + +// ************************************************************************************************* +// Radiance RGBE HDR writer +// by Baldur Karlsson + +#define stbiw__max(a, b) ((a) > (b) ? 
(a) : (b)) + +#ifndef STBI_WRITE_NO_STDIO + +static void stbiw__linear_to_rgbe(unsigned char *rgbe, float *linear) +{ + int exponent; + float maxcomp = stbiw__max(linear[0], stbiw__max(linear[1], linear[2])); + + if (maxcomp < 1e-32f) { + rgbe[0] = rgbe[1] = rgbe[2] = rgbe[3] = 0; + } else { + float normalize = (float) frexp(maxcomp, &exponent) * 256.0f/maxcomp; + + rgbe[0] = (unsigned char)(linear[0] * normalize); + rgbe[1] = (unsigned char)(linear[1] * normalize); + rgbe[2] = (unsigned char)(linear[2] * normalize); + rgbe[3] = (unsigned char)(exponent + 128); + } +} + +static void stbiw__write_run_data(stbi__write_context *s, int length, unsigned char databyte) +{ + unsigned char lengthbyte = STBIW_UCHAR(length+128); + STBIW_ASSERT(length+128 <= 255); + s->func(s->context, &lengthbyte, 1); + s->func(s->context, &databyte, 1); +} + +static void stbiw__write_dump_data(stbi__write_context *s, int length, unsigned char *data) +{ + unsigned char lengthbyte = STBIW_UCHAR(length); + STBIW_ASSERT(length <= 128); // inconsistent with spec but consistent with official code + s->func(s->context, &lengthbyte, 1); + s->func(s->context, data, length); +} + +static void stbiw__write_hdr_scanline(stbi__write_context *s, int width, int ncomp, unsigned char *scratch, float *scanline) +{ + unsigned char scanlineheader[4] = { 2, 2, 0, 0 }; + unsigned char rgbe[4]; + float linear[3]; + int x; + + scanlineheader[2] = (width&0xff00)>>8; + scanlineheader[3] = (width&0x00ff); + + /* skip RLE for images too small or large */ + if (width < 8 || width >= 32768) { + for (x=0; x < width; x++) { + switch (ncomp) { + case 4: /* fallthrough */ + case 3: linear[2] = scanline[x*ncomp + 2]; + linear[1] = scanline[x*ncomp + 1]; + linear[0] = scanline[x*ncomp + 0]; + break; + default: + linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; + break; + } + stbiw__linear_to_rgbe(rgbe, linear); + s->func(s->context, rgbe, 4); + } + } else { + int c,r; + /* encode into scratch buffer */ + for (x=0; x < width; x++) { + switch(ncomp) { + case 4: /* fallthrough */ + case 3: linear[2] = scanline[x*ncomp + 2]; + linear[1] = scanline[x*ncomp + 1]; + linear[0] = scanline[x*ncomp + 0]; + break; + default: + linear[0] = linear[1] = linear[2] = scanline[x*ncomp + 0]; + break; + } + stbiw__linear_to_rgbe(rgbe, linear); + scratch[x + width*0] = rgbe[0]; + scratch[x + width*1] = rgbe[1]; + scratch[x + width*2] = rgbe[2]; + scratch[x + width*3] = rgbe[3]; + } + + s->func(s->context, scanlineheader, 4); + + /* RLE each component separately */ + for (c=0; c < 4; c++) { + unsigned char *comp = &scratch[width*c]; + + x = 0; + while (x < width) { + // find first run + r = x; + while (r+2 < width) { + if (comp[r] == comp[r+1] && comp[r] == comp[r+2]) + break; + ++r; + } + if (r+2 >= width) + r = width; + // dump up to first run + while (x < r) { + int len = r-x; + if (len > 128) len = 128; + stbiw__write_dump_data(s, len, &comp[x]); + x += len; + } + // if there's a run, output it + if (r+2 < width) { // same test as what we break out of in search loop, so only true if we break'd + // find next byte after run + while (r < width && comp[r] == comp[x]) + ++r; + // output run up to r + while (x < r) { + int len = r-x; + if (len > 127) len = 127; + stbiw__write_run_data(s, len, comp[x]); + x += len; + } + } + } + } + } +} + +static int stbi_write_hdr_core(stbi__write_context *s, int x, int y, int comp, float *data) +{ + if (y <= 0 || x <= 0 || data == NULL) + return 0; + else { + // Each component is stored separately. 
Allocate scratch space for full output scanline. + unsigned char *scratch = (unsigned char *) STBIW_MALLOC(x*4); + int i, len; + char buffer[128]; + char header[] = "#?RADIANCE\n# Written by stb_image_write.h\nFORMAT=32-bit_rle_rgbe\n"; + s->func(s->context, header, sizeof(header)-1); + +#ifdef __STDC_LIB_EXT1__ + len = sprintf_s(buffer, sizeof(buffer), "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); +#else + len = sprintf(buffer, "EXPOSURE= 1.0000000000000\n\n-Y %d +X %d\n", y, x); +#endif + s->func(s->context, buffer, len); + + for(i=0; i < y; i++) + stbiw__write_hdr_scanline(s, x, comp, scratch, data + comp*x*(stbi__flip_vertically_on_write ? y-1-i : i)); + STBIW_FREE(scratch); + return 1; + } +} + +STBIWDEF int stbi_write_hdr_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const float *data) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_hdr_core(&s, x, y, comp, (float *) data); +} + +STBIWDEF int stbi_write_hdr(char const *filename, int x, int y, int comp, const float *data) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_hdr_core(&s, x, y, comp, (float *) data); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif // STBI_WRITE_NO_STDIO + + +////////////////////////////////////////////////////////////////////////////// +// +// PNG writer +// + +#ifndef STBIW_ZLIB_COMPRESS +// stretchy buffer; stbiw__sbpush() == vector<>::push_back() -- stbiw__sbcount() == vector<>::size() +#define stbiw__sbraw(a) ((int *) (void *) (a) - 2) +#define stbiw__sbm(a) stbiw__sbraw(a)[0] +#define stbiw__sbn(a) stbiw__sbraw(a)[1] + +#define stbiw__sbneedgrow(a,n) ((a)==0 || stbiw__sbn(a)+n >= stbiw__sbm(a)) +#define stbiw__sbmaybegrow(a,n) (stbiw__sbneedgrow(a,(n)) ? stbiw__sbgrow(a,n) : 0) +#define stbiw__sbgrow(a,n) stbiw__sbgrowf((void **) &(a), (n), sizeof(*(a))) + +#define stbiw__sbpush(a, v) (stbiw__sbmaybegrow(a,1), (a)[stbiw__sbn(a)++] = (v)) +#define stbiw__sbcount(a) ((a) ? stbiw__sbn(a) : 0) +#define stbiw__sbfree(a) ((a) ? STBIW_FREE(stbiw__sbraw(a)),0 : 0) + +static void *stbiw__sbgrowf(void **arr, int increment, int itemsize) +{ + int m = *arr ? 2*stbiw__sbm(*arr)+increment : increment+1; + void *p = STBIW_REALLOC_SIZED(*arr ? stbiw__sbraw(*arr) : 0, *arr ? 
(stbiw__sbm(*arr)*itemsize + sizeof(int)*2) : 0, itemsize * m + sizeof(int)*2); + STBIW_ASSERT(p); + if (p) { + if (!*arr) ((int *) p)[1] = 0; + *arr = (void *) ((int *) p + 2); + stbiw__sbm(*arr) = m; + } + return *arr; +} + +static unsigned char *stbiw__zlib_flushf(unsigned char *data, unsigned int *bitbuffer, int *bitcount) +{ + while (*bitcount >= 8) { + stbiw__sbpush(data, STBIW_UCHAR(*bitbuffer)); + *bitbuffer >>= 8; + *bitcount -= 8; + } + return data; +} + +static int stbiw__zlib_bitrev(int code, int codebits) +{ + int res=0; + while (codebits--) { + res = (res << 1) | (code & 1); + code >>= 1; + } + return res; +} + +static unsigned int stbiw__zlib_countm(unsigned char *a, unsigned char *b, int limit) +{ + int i; + for (i=0; i < limit && i < 258; ++i) + if (a[i] != b[i]) break; + return i; +} + +static unsigned int stbiw__zhash(unsigned char *data) +{ + stbiw_uint32 hash = data[0] + (data[1] << 8) + (data[2] << 16); + hash ^= hash << 3; + hash += hash >> 5; + hash ^= hash << 4; + hash += hash >> 17; + hash ^= hash << 25; + hash += hash >> 6; + return hash; +} + +#define stbiw__zlib_flush() (out = stbiw__zlib_flushf(out, &bitbuf, &bitcount)) +#define stbiw__zlib_add(code,codebits) \ + (bitbuf |= (code) << bitcount, bitcount += (codebits), stbiw__zlib_flush()) +#define stbiw__zlib_huffa(b,c) stbiw__zlib_add(stbiw__zlib_bitrev(b,c),c) +// default huffman tables +#define stbiw__zlib_huff1(n) stbiw__zlib_huffa(0x30 + (n), 8) +#define stbiw__zlib_huff2(n) stbiw__zlib_huffa(0x190 + (n)-144, 9) +#define stbiw__zlib_huff3(n) stbiw__zlib_huffa(0 + (n)-256,7) +#define stbiw__zlib_huff4(n) stbiw__zlib_huffa(0xc0 + (n)-280,8) +#define stbiw__zlib_huff(n) ((n) <= 143 ? stbiw__zlib_huff1(n) : (n) <= 255 ? stbiw__zlib_huff2(n) : (n) <= 279 ? stbiw__zlib_huff3(n) : stbiw__zlib_huff4(n)) +#define stbiw__zlib_huffb(n) ((n) <= 143 ? 
stbiw__zlib_huff1(n) : stbiw__zlib_huff2(n)) + +#define stbiw__ZHASH 16384 + +#endif // STBIW_ZLIB_COMPRESS + +STBIWDEF unsigned char * stbi_zlib_compress(unsigned char *data, int data_len, int *out_len, int quality) +{ +#ifdef STBIW_ZLIB_COMPRESS + // user provided a zlib compress implementation, use that + return STBIW_ZLIB_COMPRESS(data, data_len, out_len, quality); +#else // use builtin + static unsigned short lengthc[] = { 3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258, 259 }; + static unsigned char lengtheb[]= { 0,0,0,0,0,0,0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 }; + static unsigned short distc[] = { 1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577, 32768 }; + static unsigned char disteb[] = { 0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13 }; + unsigned int bitbuf=0; + int i,j, bitcount=0; + unsigned char *out = NULL; + unsigned char ***hash_table = (unsigned char***) STBIW_MALLOC(stbiw__ZHASH * sizeof(unsigned char**)); + if (hash_table == NULL) + return NULL; + if (quality < 5) quality = 5; + + stbiw__sbpush(out, 0x78); // DEFLATE 32K window + stbiw__sbpush(out, 0x5e); // FLEVEL = 1 + stbiw__zlib_add(1,1); // BFINAL = 1 + stbiw__zlib_add(1,2); // BTYPE = 1 -- fixed huffman + + for (i=0; i < stbiw__ZHASH; ++i) + hash_table[i] = NULL; + + i=0; + while (i < data_len-3) { + // hash next 3 bytes of data to be compressed + int h = stbiw__zhash(data+i)&(stbiw__ZHASH-1), best=3; + unsigned char *bestloc = 0; + unsigned char **hlist = hash_table[h]; + int n = stbiw__sbcount(hlist); + for (j=0; j < n; ++j) { + if (hlist[j]-data > i-32768) { // if entry lies within window + int d = stbiw__zlib_countm(hlist[j], data+i, data_len-i); + if (d >= best) { best=d; bestloc=hlist[j]; } + } + } + // when hash table entry is too long, delete half the entries + if (hash_table[h] && stbiw__sbn(hash_table[h]) == 2*quality) { + STBIW_MEMMOVE(hash_table[h], hash_table[h]+quality, sizeof(hash_table[h][0])*quality); + stbiw__sbn(hash_table[h]) = quality; + } + stbiw__sbpush(hash_table[h],data+i); + + if (bestloc) { + // "lazy matching" - check match at *next* byte, and if it's better, do cur byte as literal + h = stbiw__zhash(data+i+1)&(stbiw__ZHASH-1); + hlist = hash_table[h]; + n = stbiw__sbcount(hlist); + for (j=0; j < n; ++j) { + if (hlist[j]-data > i-32767) { + int e = stbiw__zlib_countm(hlist[j], data+i+1, data_len-i-1); + if (e > best) { // if next match is better, bail on current match + bestloc = NULL; + break; + } + } + } + } + + if (bestloc) { + int d = (int) (data+i - bestloc); // distance back + STBIW_ASSERT(d <= 32767 && best <= 258); + for (j=0; best > lengthc[j+1]-1; ++j); + stbiw__zlib_huff(j+257); + if (lengtheb[j]) stbiw__zlib_add(best - lengthc[j], lengtheb[j]); + for (j=0; d > distc[j+1]-1; ++j); + stbiw__zlib_add(stbiw__zlib_bitrev(j,5),5); + if (disteb[j]) stbiw__zlib_add(d - distc[j], disteb[j]); + i += best; + } else { + stbiw__zlib_huffb(data[i]); + ++i; + } + } + // write out final bytes + for (;i < data_len; ++i) + stbiw__zlib_huffb(data[i]); + stbiw__zlib_huff(256); // end of block + // pad with 0 bits to byte boundary + while (bitcount) + stbiw__zlib_add(0,1); + + for (i=0; i < stbiw__ZHASH; ++i) + (void) stbiw__sbfree(hash_table[i]); + STBIW_FREE(hash_table); + + // store uncompressed instead if compression was worse + if (stbiw__sbn(out) > data_len + 2 + ((data_len+32766)/32767)*5) { + stbiw__sbn(out) = 2; // truncate to DEFLATE 
32K window and FLEVEL = 1 + for (j = 0; j < data_len;) { + int blocklen = data_len - j; + if (blocklen > 32767) blocklen = 32767; + stbiw__sbpush(out, data_len - j == blocklen); // BFINAL = ?, BTYPE = 0 -- no compression + stbiw__sbpush(out, STBIW_UCHAR(blocklen)); // LEN + stbiw__sbpush(out, STBIW_UCHAR(blocklen >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(~blocklen)); // NLEN + stbiw__sbpush(out, STBIW_UCHAR(~blocklen >> 8)); + memcpy(out+stbiw__sbn(out), data+j, blocklen); + stbiw__sbn(out) += blocklen; + j += blocklen; + } + } + + { + // compute adler32 on input + unsigned int s1=1, s2=0; + int blocklen = (int) (data_len % 5552); + j=0; + while (j < data_len) { + for (i=0; i < blocklen; ++i) { s1 += data[j+i]; s2 += s1; } + s1 %= 65521; s2 %= 65521; + j += blocklen; + blocklen = 5552; + } + stbiw__sbpush(out, STBIW_UCHAR(s2 >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(s2)); + stbiw__sbpush(out, STBIW_UCHAR(s1 >> 8)); + stbiw__sbpush(out, STBIW_UCHAR(s1)); + } + *out_len = stbiw__sbn(out); + // make returned pointer freeable + STBIW_MEMMOVE(stbiw__sbraw(out), out, *out_len); + return (unsigned char *) stbiw__sbraw(out); +#endif // STBIW_ZLIB_COMPRESS +} + +static unsigned int stbiw__crc32(unsigned char *buffer, int len) +{ +#ifdef STBIW_CRC32 + return STBIW_CRC32(buffer, len); +#else + static unsigned int crc_table[256] = + { + 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3, + 0x0eDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, + 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, + 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, + 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, + 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, + 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F, + 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, + 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, + 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01, + 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, + 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, + 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, + 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, + 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F, + 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, 0xB7BD5C3B, 0xC0BA6CAD, + 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, + 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, + 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, + 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, + 0xD6D6A3E8, 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, + 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79, + 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 
0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, + 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, + 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713, + 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, + 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, + 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, + 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, + 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, + 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF, + 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D + }; + + unsigned int crc = ~0u; + int i; + for (i=0; i < len; ++i) + crc = (crc >> 8) ^ crc_table[buffer[i] ^ (crc & 0xff)]; + return ~crc; +#endif +} + +#define stbiw__wpng4(o,a,b,c,d) ((o)[0]=STBIW_UCHAR(a),(o)[1]=STBIW_UCHAR(b),(o)[2]=STBIW_UCHAR(c),(o)[3]=STBIW_UCHAR(d),(o)+=4) +#define stbiw__wp32(data,v) stbiw__wpng4(data, (v)>>24,(v)>>16,(v)>>8,(v)); +#define stbiw__wptag(data,s) stbiw__wpng4(data, s[0],s[1],s[2],s[3]) + +static void stbiw__wpcrc(unsigned char **data, int len) +{ + unsigned int crc = stbiw__crc32(*data - len - 4, len+4); + stbiw__wp32(*data, crc); +} + +static unsigned char stbiw__paeth(int a, int b, int c) +{ + int p = a + b - c, pa = abs(p-a), pb = abs(p-b), pc = abs(p-c); + if (pa <= pb && pa <= pc) return STBIW_UCHAR(a); + if (pb <= pc) return STBIW_UCHAR(b); + return STBIW_UCHAR(c); +} + +// @OPTIMIZE: provide an option that always forces left-predict or paeth predict +static void stbiw__encode_png_line(unsigned char *pixels, int stride_bytes, int width, int height, int y, int n, int filter_type, signed char *line_buffer) +{ + static int mapping[] = { 0,1,2,3,4 }; + static int firstmap[] = { 0,1,0,5,6 }; + int *mymap = (y != 0) ? mapping : firstmap; + int i; + int type = mymap[filter_type]; + unsigned char *z = pixels + stride_bytes * (stbi__flip_vertically_on_write ? height-1-y : y); + int signed_stride = stbi__flip_vertically_on_write ? 
-stride_bytes : stride_bytes; + + if (type==0) { + memcpy(line_buffer, z, width*n); + return; + } + + // first loop isn't optimized since it's just one pixel + for (i = 0; i < n; ++i) { + switch (type) { + case 1: line_buffer[i] = z[i]; break; + case 2: line_buffer[i] = z[i] - z[i-signed_stride]; break; + case 3: line_buffer[i] = z[i] - (z[i-signed_stride]>>1); break; + case 4: line_buffer[i] = (signed char) (z[i] - stbiw__paeth(0,z[i-signed_stride],0)); break; + case 5: line_buffer[i] = z[i]; break; + case 6: line_buffer[i] = z[i]; break; + } + } + switch (type) { + case 1: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-n]; break; + case 2: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - z[i-signed_stride]; break; + case 3: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - ((z[i-n] + z[i-signed_stride])>>1); break; + case 4: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], z[i-signed_stride], z[i-signed_stride-n]); break; + case 5: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - (z[i-n]>>1); break; + case 6: for (i=n; i < width*n; ++i) line_buffer[i] = z[i] - stbiw__paeth(z[i-n], 0,0); break; + } +} + +STBIWDEF unsigned char *stbi_write_png_to_mem(const unsigned char *pixels, int stride_bytes, int x, int y, int n, int *out_len) +{ + int force_filter = stbi_write_force_png_filter; + int ctype[5] = { -1, 0, 4, 2, 6 }; + unsigned char sig[8] = { 137,80,78,71,13,10,26,10 }; + unsigned char *out,*o, *filt, *zlib; + signed char *line_buffer; + int j,zlen; + + if (stride_bytes == 0) + stride_bytes = x * n; + + if (force_filter >= 5) { + force_filter = -1; + } + + filt = (unsigned char *) STBIW_MALLOC((x*n+1) * y); if (!filt) return 0; + line_buffer = (signed char *) STBIW_MALLOC(x * n); if (!line_buffer) { STBIW_FREE(filt); return 0; } + for (j=0; j < y; ++j) { + int filter_type; + if (force_filter > -1) { + filter_type = force_filter; + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, force_filter, line_buffer); + } else { // Estimate the best filter by running through all of them: + int best_filter = 0, best_filter_val = 0x7fffffff, est, i; + for (filter_type = 0; filter_type < 5; filter_type++) { + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, filter_type, line_buffer); + + // Estimate the entropy of the line using this filter; the less, the better. 
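         // (A cheap proxy for entropy: e.g. filtered bytes {3, -1, 0, 2} score
         // |3|+|-1|+|0|+|2| = 6; the filter with the smallest sum of absolute
         // residuals tends to DEFLATE best.)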
+ est = 0; + for (i = 0; i < x*n; ++i) { + est += abs((signed char) line_buffer[i]); + } + if (est < best_filter_val) { + best_filter_val = est; + best_filter = filter_type; + } + } + if (filter_type != best_filter) { // If the last iteration already got us the best filter, don't redo it + stbiw__encode_png_line((unsigned char*)(pixels), stride_bytes, x, y, j, n, best_filter, line_buffer); + filter_type = best_filter; + } + } + // when we get here, filter_type contains the filter type, and line_buffer contains the data + filt[j*(x*n+1)] = (unsigned char) filter_type; + STBIW_MEMMOVE(filt+j*(x*n+1)+1, line_buffer, x*n); + } + STBIW_FREE(line_buffer); + zlib = stbi_zlib_compress(filt, y*( x*n+1), &zlen, stbi_write_png_compression_level); + STBIW_FREE(filt); + if (!zlib) return 0; + + // each tag requires 12 bytes of overhead + out = (unsigned char *) STBIW_MALLOC(8 + 12+13 + 12+zlen + 12); + if (!out) return 0; + *out_len = 8 + 12+13 + 12+zlen + 12; + + o=out; + STBIW_MEMMOVE(o,sig,8); o+= 8; + stbiw__wp32(o, 13); // header length + stbiw__wptag(o, "IHDR"); + stbiw__wp32(o, x); + stbiw__wp32(o, y); + *o++ = 8; + *o++ = STBIW_UCHAR(ctype[n]); + *o++ = 0; + *o++ = 0; + *o++ = 0; + stbiw__wpcrc(&o,13); + + stbiw__wp32(o, zlen); + stbiw__wptag(o, "IDAT"); + STBIW_MEMMOVE(o, zlib, zlen); + o += zlen; + STBIW_FREE(zlib); + stbiw__wpcrc(&o, zlen); + + stbiw__wp32(o,0); + stbiw__wptag(o, "IEND"); + stbiw__wpcrc(&o,0); + + STBIW_ASSERT(o == out + *out_len); + + return out; +} + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_png(char const *filename, int x, int y, int comp, const void *data, int stride_bytes) +{ + FILE *f; + int len; + unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); + if (png == NULL) return 0; + + f = stbiw__fopen(filename, "wb"); + if (!f) { STBIW_FREE(png); return 0; } + fwrite(png, 1, len, f); + fclose(f); + STBIW_FREE(png); + return 1; +} +#endif + +STBIWDEF int stbi_write_png_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int stride_bytes) +{ + int len; + unsigned char *png = stbi_write_png_to_mem((const unsigned char *) data, stride_bytes, x, y, comp, &len); + if (png == NULL) return 0; + func(context, png, len); + STBIW_FREE(png); + return 1; +} + + +/* *************************************************************************** + * + * JPEG writer + * + * This is based on Jon Olick's jo_jpeg.cpp: + * public domain Simple, Minimalistic JPEG writer - http://www.jonolick.com/code.html + */ + +static const unsigned char stbiw__jpg_ZigZag[] = { 0,1,5,6,14,15,27,28,2,4,7,13,16,26,29,42,3,8,12,17,25,30,41,43,9,11,18, + 24,31,40,44,53,10,19,23,32,39,45,52,54,20,22,33,38,46,51,55,60,21,34,37,47,50,56,59,61,35,36,48,49,57,58,62,63 }; + +static void stbiw__jpg_writeBits(stbi__write_context *s, int *bitBufP, int *bitCntP, const unsigned short *bs) { + int bitBuf = *bitBufP, bitCnt = *bitCntP; + bitCnt += bs[1]; + bitBuf |= bs[0] << (24 - bitCnt); + while(bitCnt >= 8) { + unsigned char c = (bitBuf >> 16) & 255; + stbiw__putc(s, c); + if(c == 255) { + stbiw__putc(s, 0); + } + bitBuf <<= 8; + bitCnt -= 8; + } + *bitBufP = bitBuf; + *bitCntP = bitCnt; +} + +static void stbiw__jpg_DCT(float *d0p, float *d1p, float *d2p, float *d3p, float *d4p, float *d5p, float *d6p, float *d7p) { + float d0 = *d0p, d1 = *d1p, d2 = *d2p, d3 = *d3p, d4 = *d4p, d5 = *d5p, d6 = *d6p, d7 = *d7p; + float z1, z2, z3, z4, z5, z11, z13; + + float tmp0 = d0 + d7; + float tmp7 = d0 - d7; + float tmp1 = d1 + d6; + 
   float tmp6 = d1 - d6;
   float tmp2 = d2 + d5;
   float tmp5 = d2 - d5;
   float tmp3 = d3 + d4;
   float tmp4 = d3 - d4;

   // Even part
   float tmp10 = tmp0 + tmp3;   // phase 2
   float tmp13 = tmp0 - tmp3;
   float tmp11 = tmp1 + tmp2;
   float tmp12 = tmp1 - tmp2;

   d0 = tmp10 + tmp11;       // phase 3
   d4 = tmp10 - tmp11;

   z1 = (tmp12 + tmp13) * 0.707106781f; // c4
   d2 = tmp13 + z1;       // phase 5
   d6 = tmp13 - z1;

   // Odd part
   tmp10 = tmp4 + tmp5;       // phase 2
   tmp11 = tmp5 + tmp6;
   tmp12 = tmp6 + tmp7;

   // The rotator is modified from fig 4-8 to avoid extra negations.
   z5 = (tmp10 - tmp12) * 0.382683433f; // c6
   z2 = tmp10 * 0.541196100f + z5; // c2-c6
   z4 = tmp12 * 1.306562965f + z5; // c2+c6
   z3 = tmp11 * 0.707106781f; // c4

   z11 = tmp7 + z3;      // phase 5
   z13 = tmp7 - z3;

   *d5p = z13 + z2;      // phase 6
   *d3p = z13 - z2;
   *d1p = z11 + z4;
   *d7p = z11 - z4;

   *d0p = d0;  *d2p = d2;  *d4p = d4;  *d6p = d6;
}

static void stbiw__jpg_calcBits(int val, unsigned short bits[2]) {
   int tmp1 = val < 0 ? -val : val;
   val = val < 0 ? val-1 : val;
   bits[1] = 1;
   while(tmp1 >>= 1) {
      ++bits[1];
   }
   bits[0] = val & ((1<<bits[1])-1);
}

static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt, float *CDU, int du_stride, float *fdtbl, int DC, const unsigned short HTDC[256][2], const unsigned short HTAC[256][2]) {
   const unsigned short EOB[2] = { HTAC[0x00][0], HTAC[0x00][1] };
   const unsigned short M16zeroes[2] = { HTAC[0xF0][0], HTAC[0xF0][1] };
   int dataOff, i, j, n, diff, end0pos, x, y;
   int DU[64];

   // DCT rows
   for(dataOff=0, n=du_stride*8; dataOff<n; dataOff+=du_stride) {
      stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+1], &CDU[dataOff+2], &CDU[dataOff+3], &CDU[dataOff+4], &CDU[dataOff+5], &CDU[dataOff+6], &CDU[dataOff+7]);
   }
   // DCT columns
   for(dataOff=0; dataOff<8; ++dataOff) {
      stbiw__jpg_DCT(&CDU[dataOff], &CDU[dataOff+du_stride], &CDU[dataOff+du_stride*2], &CDU[dataOff+du_stride*3], &CDU[dataOff+du_stride*4], &CDU[dataOff+du_stride*5], &CDU[dataOff+du_stride*6], &CDU[dataOff+du_stride*7]);
   }
   // Quantize/descale/zigzag the coefficients
   for(y = 0, j=0; y < 8; ++y) {
      for(x = 0; x < 8; ++x,++j) {
         float v;
         i = y*du_stride+x;
         v = CDU[i]*fdtbl[j];
         // DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? ceilf(v - 0.5f) : floorf(v + 0.5f));
         // ceilf() and floorf() are C99, not C89, so roll our own.
         DU[stbiw__jpg_ZigZag[j]] = (int)(v < 0 ? v - 0.5f : v + 0.5f);
      }
   }

   // Encode DC
   diff = DU[0] - DC;
   if (diff == 0) {
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[0]);
   } else {
      unsigned short bits[2];
      stbiw__jpg_calcBits(diff, bits);
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTDC[bits[1]]);
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
   }
   // Encode ACs
   end0pos = 63;
   for(; (end0pos>0)&&(DU[end0pos]==0); --end0pos) {
   }
   // end0pos = first element in reverse order !=0
   if(end0pos == 0) {
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
      return DU[0];
   }
   for(i = 1; i <= end0pos; ++i) {
      int startpos = i;
      int nrzeroes;
      unsigned short bits[2];
      for (; DU[i]==0 && i<=end0pos; ++i) {
      }
      nrzeroes = i-startpos;
      if ( nrzeroes >= 16 ) {
         int lng = nrzeroes>>4;
         int nrmarker;
         for (nrmarker=1; nrmarker <= lng; ++nrmarker)
            stbiw__jpg_writeBits(s, bitBuf, bitCnt, M16zeroes);
         nrzeroes &= 15;
      }
      stbiw__jpg_calcBits(DU[i], bits);
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, HTAC[(nrzeroes<<4)+bits[1]]);
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, bits);
   }
   if(end0pos != 63) {
      stbiw__jpg_writeBits(s, bitBuf, bitCnt, EOB);
   }
   return DU[0];
}

static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) {
   // Constants that don't pollute global namespace
   static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0};
   static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
   static const unsigned char std_ac_luminance_nrcodes[] = {0,0,2,1,3,3,2,4,3,5,5,4,4,0,0,1,0x7d};
   static const unsigned char std_ac_luminance_values[] = {
      0x01,0x02,0x03,0x00,0x04,0x11,0x05,0x12,0x21,0x31,0x41,0x06,0x13,0x51,0x61,0x07,0x22,0x71,0x14,0x32,0x81,0x91,0xa1,0x08,
      0x23,0x42,0xb1,0xc1,0x15,0x52,0xd1,0xf0,0x24,0x33,0x62,0x72,0x82,0x09,0x0a,0x16,0x17,0x18,0x19,0x1a,0x25,0x26,0x27,0x28,
      0x29,0x2a,0x34,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58,0x59,
      0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x83,0x84,0x85,0x86,0x87,0x88,0x89,
      0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4,0xb5,0xb6,
      0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda,0xe1,0xe2,
      0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa
   };
   static const unsigned char std_dc_chrominance_nrcodes[] = {0,0,3,1,1,1,1,1,1,1,1,1,0,0,0,0,0};
   static const unsigned char std_dc_chrominance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11};
   static const unsigned char std_ac_chrominance_nrcodes[] =
{0,0,2,1,2,4,4,3,4,7,5,4,4,0,1,2,0x77}; + static const unsigned char std_ac_chrominance_values[] = { + 0x00,0x01,0x02,0x03,0x11,0x04,0x05,0x21,0x31,0x06,0x12,0x41,0x51,0x07,0x61,0x71,0x13,0x22,0x32,0x81,0x08,0x14,0x42,0x91, + 0xa1,0xb1,0xc1,0x09,0x23,0x33,0x52,0xf0,0x15,0x62,0x72,0xd1,0x0a,0x16,0x24,0x34,0xe1,0x25,0xf1,0x17,0x18,0x19,0x1a,0x26, + 0x27,0x28,0x29,0x2a,0x35,0x36,0x37,0x38,0x39,0x3a,0x43,0x44,0x45,0x46,0x47,0x48,0x49,0x4a,0x53,0x54,0x55,0x56,0x57,0x58, + 0x59,0x5a,0x63,0x64,0x65,0x66,0x67,0x68,0x69,0x6a,0x73,0x74,0x75,0x76,0x77,0x78,0x79,0x7a,0x82,0x83,0x84,0x85,0x86,0x87, + 0x88,0x89,0x8a,0x92,0x93,0x94,0x95,0x96,0x97,0x98,0x99,0x9a,0xa2,0xa3,0xa4,0xa5,0xa6,0xa7,0xa8,0xa9,0xaa,0xb2,0xb3,0xb4, + 0xb5,0xb6,0xb7,0xb8,0xb9,0xba,0xc2,0xc3,0xc4,0xc5,0xc6,0xc7,0xc8,0xc9,0xca,0xd2,0xd3,0xd4,0xd5,0xd6,0xd7,0xd8,0xd9,0xda, + 0xe2,0xe3,0xe4,0xe5,0xe6,0xe7,0xe8,0xe9,0xea,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7,0xf8,0xf9,0xfa + }; + // Huffman tables + static const unsigned short YDC_HT[256][2] = { {0,2},{2,3},{3,3},{4,3},{5,3},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9}}; + static const unsigned short UVDC_HT[256][2] = { {0,2},{1,2},{2,2},{6,3},{14,4},{30,5},{62,6},{126,7},{254,8},{510,9},{1022,10},{2046,11}}; + static const unsigned short YAC_HT[256][2] = { + {10,4},{0,2},{1,2},{4,3},{11,4},{26,5},{120,7},{248,8},{1014,10},{65410,16},{65411,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {12,4},{27,5},{121,7},{502,9},{2038,11},{65412,16},{65413,16},{65414,16},{65415,16},{65416,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {28,5},{249,8},{1015,10},{4084,12},{65417,16},{65418,16},{65419,16},{65420,16},{65421,16},{65422,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {58,6},{503,9},{4085,12},{65423,16},{65424,16},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {59,6},{1016,10},{65430,16},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {122,7},{2039,11},{65438,16},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {123,7},{4086,12},{65446,16},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {250,8},{4087,12},{65454,16},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {504,9},{32704,15},{65462,16},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {505,9},{65470,16},{65471,16},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {506,9},{65479,16},{65480,16},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1017,10},{65488,16},{65489,16},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1018,10},{65497,16},{65498,16},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2040,11},{65506,16},{65507,16},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {65515,16},{65516,16},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{0,0},{0,0},{0,0},{0,0},{0,0}, + 
{2041,11},{65525,16},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} + }; + static const unsigned short UVAC_HT[256][2] = { + {0,2},{1,2},{4,3},{10,4},{24,5},{25,5},{56,6},{120,7},{500,9},{1014,10},{4084,12},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {11,4},{57,6},{246,8},{501,9},{2038,11},{4085,12},{65416,16},{65417,16},{65418,16},{65419,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {26,5},{247,8},{1015,10},{4086,12},{32706,15},{65420,16},{65421,16},{65422,16},{65423,16},{65424,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {27,5},{248,8},{1016,10},{4087,12},{65425,16},{65426,16},{65427,16},{65428,16},{65429,16},{65430,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {58,6},{502,9},{65431,16},{65432,16},{65433,16},{65434,16},{65435,16},{65436,16},{65437,16},{65438,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {59,6},{1017,10},{65439,16},{65440,16},{65441,16},{65442,16},{65443,16},{65444,16},{65445,16},{65446,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {121,7},{2039,11},{65447,16},{65448,16},{65449,16},{65450,16},{65451,16},{65452,16},{65453,16},{65454,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {122,7},{2040,11},{65455,16},{65456,16},{65457,16},{65458,16},{65459,16},{65460,16},{65461,16},{65462,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {249,8},{65463,16},{65464,16},{65465,16},{65466,16},{65467,16},{65468,16},{65469,16},{65470,16},{65471,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {503,9},{65472,16},{65473,16},{65474,16},{65475,16},{65476,16},{65477,16},{65478,16},{65479,16},{65480,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {504,9},{65481,16},{65482,16},{65483,16},{65484,16},{65485,16},{65486,16},{65487,16},{65488,16},{65489,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {505,9},{65490,16},{65491,16},{65492,16},{65493,16},{65494,16},{65495,16},{65496,16},{65497,16},{65498,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {506,9},{65499,16},{65500,16},{65501,16},{65502,16},{65503,16},{65504,16},{65505,16},{65506,16},{65507,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {2041,11},{65508,16},{65509,16},{65510,16},{65511,16},{65512,16},{65513,16},{65514,16},{65515,16},{65516,16},{0,0},{0,0},{0,0},{0,0},{0,0},{0,0}, + {16352,14},{65517,16},{65518,16},{65519,16},{65520,16},{65521,16},{65522,16},{65523,16},{65524,16},{65525,16},{0,0},{0,0},{0,0},{0,0},{0,0}, + {1018,10},{32707,15},{65526,16},{65527,16},{65528,16},{65529,16},{65530,16},{65531,16},{65532,16},{65533,16},{65534,16},{0,0},{0,0},{0,0},{0,0},{0,0} + }; + static const int YQT[] = {16,11,10,16,24,40,51,61,12,12,14,19,26,58,60,55,14,13,16,24,40,57,69,56,14,17,22,29,51,87,80,62,18,22, + 37,56,68,109,103,77,24,35,55,64,81,104,113,92,49,64,78,87,103,121,120,101,72,92,95,98,112,100,103,99}; + static const int UVQT[] = {17,18,24,47,99,99,99,99,18,21,26,66,99,99,99,99,24,26,56,99,99,99,99,99,47,66,99,99,99,99,99,99, + 99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99,99}; + static const float aasf[] = { 1.0f * 2.828427125f, 1.387039845f * 2.828427125f, 1.306562965f * 2.828427125f, 1.175875602f * 2.828427125f, + 1.0f * 2.828427125f, 0.785694958f * 2.828427125f, 0.541196100f * 2.828427125f, 0.275899379f * 2.828427125f }; + + int row, col, i, k, subsample; + float fdtbl_Y[64], fdtbl_UV[64]; + unsigned char YTable[64], UVTable[64]; + + if(!data || !width || !height || comp > 4 || comp < 1) { + return 0; + } + + quality = quality ? quality : 90; + subsample = quality <= 90 ? 1 : 0; + quality = quality < 1 ? 1 : quality > 100 ? 
100 : quality; + quality = quality < 50 ? 5000 / quality : 200 - quality * 2; + + for(i = 0; i < 64; ++i) { + int uvti, yti = (YQT[i]*quality+50)/100; + YTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (yti < 1 ? 1 : yti > 255 ? 255 : yti); + uvti = (UVQT[i]*quality+50)/100; + UVTable[stbiw__jpg_ZigZag[i]] = (unsigned char) (uvti < 1 ? 1 : uvti > 255 ? 255 : uvti); + } + + for(row = 0, k = 0; row < 8; ++row) { + for(col = 0; col < 8; ++col, ++k) { + fdtbl_Y[k] = 1 / (YTable [stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); + fdtbl_UV[k] = 1 / (UVTable[stbiw__jpg_ZigZag[k]] * aasf[row] * aasf[col]); + } + } + + // Write Headers + { + static const unsigned char head0[] = { 0xFF,0xD8,0xFF,0xE0,0,0x10,'J','F','I','F',0,1,1,0,0,1,0,1,0,0,0xFF,0xDB,0,0x84,0 }; + static const unsigned char head2[] = { 0xFF,0xDA,0,0xC,3,1,0,2,0x11,3,0x11,0,0x3F,0 }; + const unsigned char head1[] = { 0xFF,0xC0,0,0x11,8,(unsigned char)(height>>8),STBIW_UCHAR(height),(unsigned char)(width>>8),STBIW_UCHAR(width), + 3,1,(unsigned char)(subsample?0x22:0x11),0,2,0x11,1,3,0x11,1,0xFF,0xC4,0x01,0xA2,0 }; + s->func(s->context, (void*)head0, sizeof(head0)); + s->func(s->context, (void*)YTable, sizeof(YTable)); + stbiw__putc(s, 1); + s->func(s->context, UVTable, sizeof(UVTable)); + s->func(s->context, (void*)head1, sizeof(head1)); + s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); + s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); + stbiw__putc(s, 0x10); // HTYACinfo + s->func(s->context, (void*)(std_ac_luminance_nrcodes+1), sizeof(std_ac_luminance_nrcodes)-1); + s->func(s->context, (void*)std_ac_luminance_values, sizeof(std_ac_luminance_values)); + stbiw__putc(s, 1); // HTUDCinfo + s->func(s->context, (void*)(std_dc_chrominance_nrcodes+1), sizeof(std_dc_chrominance_nrcodes)-1); + s->func(s->context, (void*)std_dc_chrominance_values, sizeof(std_dc_chrominance_values)); + stbiw__putc(s, 0x11); // HTUACinfo + s->func(s->context, (void*)(std_ac_chrominance_nrcodes+1), sizeof(std_ac_chrominance_nrcodes)-1); + s->func(s->context, (void*)std_ac_chrominance_values, sizeof(std_ac_chrominance_values)); + s->func(s->context, (void*)head2, sizeof(head2)); + } + + // Encode 8x8 macroblocks + { + static const unsigned short fillBits[] = {0x7F, 7}; + int DCY=0, DCU=0, DCV=0; + int bitBuf=0, bitCnt=0; + // comp == 2 is grey+alpha (alpha is ignored) + int ofsG = comp > 2 ? 1 : 0, ofsB = comp > 2 ? 2 : 0; + const unsigned char *dataR = (const unsigned char *)data; + const unsigned char *dataG = dataR + ofsG; + const unsigned char *dataB = dataR + ofsB; + int x, y, pos; + if(subsample) { + for(y = 0; y < height; y += 16) { + for(x = 0; x < width; x += 16) { + float Y[256], U[256], V[256]; + for(row = y, pos = 0; row < y+16; ++row) { + // row >= height => use last input row + int clamped_row = (row < height) ? row : height - 1; + int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; + for(col = x; col < x+16; ++col, ++pos) { + // if col >= width => use pixel from last input column + int p = base_p + ((col < width) ? 
col : (width-1))*comp; + float r = dataR[p], g = dataG[p], b = dataB[p]; + Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; + U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; + V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; + } + } + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+0, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+8, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+128, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y+136, 16, fdtbl_Y, DCY, YDC_HT, YAC_HT); + + // subsample U,V + { + float subU[64], subV[64]; + int yy, xx; + for(yy = 0, pos = 0; yy < 8; ++yy) { + for(xx = 0; xx < 8; ++xx, ++pos) { + int j = yy*32+xx*2; + subU[pos] = (U[j+0] + U[j+1] + U[j+16] + U[j+17]) * 0.25f; + subV[pos] = (V[j+0] + V[j+1] + V[j+16] + V[j+17]) * 0.25f; + } + } + DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subU, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); + DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, subV, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); + } + } + } + } else { + for(y = 0; y < height; y += 8) { + for(x = 0; x < width; x += 8) { + float Y[64], U[64], V[64]; + for(row = y, pos = 0; row < y+8; ++row) { + // row >= height => use last input row + int clamped_row = (row < height) ? row : height - 1; + int base_p = (stbi__flip_vertically_on_write ? (height-1-clamped_row) : clamped_row)*width*comp; + for(col = x; col < x+8; ++col, ++pos) { + // if col >= width => use pixel from last input column + int p = base_p + ((col < width) ? col : (width-1))*comp; + float r = dataR[p], g = dataG[p], b = dataB[p]; + Y[pos]= +0.29900f*r + 0.58700f*g + 0.11400f*b - 128; + U[pos]= -0.16874f*r - 0.33126f*g + 0.50000f*b; + V[pos]= +0.50000f*r - 0.41869f*g - 0.08131f*b; + } + } + + DCY = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, Y, 8, fdtbl_Y, DCY, YDC_HT, YAC_HT); + DCU = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, U, 8, fdtbl_UV, DCU, UVDC_HT, UVAC_HT); + DCV = stbiw__jpg_processDU(s, &bitBuf, &bitCnt, V, 8, fdtbl_UV, DCV, UVDC_HT, UVAC_HT); + } + } + } + + // Do the bit alignment of the EOI marker + stbiw__jpg_writeBits(s, &bitBuf, &bitCnt, fillBits); + } + + // EOI + stbiw__putc(s, 0xFF); + stbiw__putc(s, 0xD9); + + return 1; +} + +STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, int y, int comp, const void *data, int quality) +{ + stbi__write_context s = { 0 }; + stbi__start_write_callbacks(&s, func, context); + return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); +} + + +#ifndef STBI_WRITE_NO_STDIO +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) +{ + stbi__write_context s = { 0 }; + if (stbi__start_write_file(&s,filename)) { + int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); + stbi__end_write_file(&s); + return r; + } else + return 0; +} +#endif + +#endif // STB_IMAGE_WRITE_IMPLEMENTATION + +/* Revision history + 1.16 (2021-07-11) + make Deflate code emit uncompressed blocks when it would otherwise expand + support writing BMPs with alpha channel + 1.15 (2020-07-13) unknown + 1.14 (2020-02-02) updated JPEG writer to downsample chroma channels + 1.13 + 1.12 + 1.11 (2019-08-11) + + 1.10 (2019-02-07) + support utf8 filenames in Windows; fix warnings and platform ifdefs + 1.09 (2018-02-11) + fix typo in zlib quality API, improve STB_I_W_STATIC in C++ + 1.08 (2018-01-29) + add stbi__flip_vertically_on_write, external zlib, zlib quality, choose PNG filter + 1.07 (2017-07-24) + 
doc fix + 1.06 (2017-07-23) + writing JPEG (using Jon Olick's code) + 1.05 ??? + 1.04 (2017-03-03) + monochrome BMP expansion + 1.03 ??? + 1.02 (2016-04-02) + avoid allocating large structures on the stack + 1.01 (2016-01-16) + STBIW_REALLOC_SIZED: support allocators with no realloc support + avoid race-condition in crc initialization + minor compile issues + 1.00 (2015-09-14) + installable file IO function + 0.99 (2015-09-13) + warning fixes; TGA rle support + 0.98 (2015-04-08) + added STBIW_MALLOC, STBIW_ASSERT etc + 0.97 (2015-01-18) + fixed HDR asserts, rewrote HDR rle logic + 0.96 (2015-01-17) + add HDR output + fix monochrome BMP + 0.95 (2014-08-17) + add monochrome TGA output + 0.94 (2014-05-31) + rename private functions to avoid conflicts with stb_image.h + 0.93 (2014-05-27) + warning fixes + 0.92 (2010-08-01) + casts to unsigned char to fix warnings + 0.91 (2010-07-17) + first public release + 0.90 first internal release +*/ + +/* +------------------------------------------------------------------------------ +This software is available under 2 licenses -- choose whichever you prefer. +------------------------------------------------------------------------------ +ALTERNATIVE A - MIT License +Copyright (c) 2017 Sean Barrett +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +------------------------------------------------------------------------------ +ALTERNATIVE B - Public Domain (www.unlicense.org) +This is free and unencumbered software released into the public domain. +Anyone is free to copy, modify, publish, use, compile, sell, or distribute this +software, either in source code form or as a compiled binary, for any purpose, +commercial or non-commercial, and by any means. +In jurisdictions that recognize copyright laws, the author or authors of this +software dedicate any and all copyright interest in the software to the public +domain. We make this dedication for the benefit of the public at large and to +the detriment of our heirs and successors. We intend this dedication to be an +overt act of relinquishment in perpetuity of all present and future rights to +this software under copyright law. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
------------------------------------------------------------------------------
*/
diff --git a/fuzzer/CMakeLists.txt b/fuzzer/CMakeLists.txt
new file mode 100644
index 0000000..c153f53
--- /dev/null
+++ b/fuzzer/CMakeLists.txt
@@ -0,0 +1,16 @@
+# (c) 2024 Mario "Neo" Sieg.
+
+enable_language(CXX)
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# The fuzzer requires Clang as compiler (libFuzzer)
+if (NOT CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
+    message(FATAL_ERROR "Fuzzer requires Clang as compiler")
+endif()
+
+add_executable(msml_fuzzer fuzzer.cpp)
+target_compile_options(msml_fuzzer PRIVATE -fsanitize=fuzzer -O3 -march=native)
+target_link_options(msml_fuzzer PRIVATE -fsanitize=fuzzer)
+target_link_libraries(msml_fuzzer msml)
+target_include_directories(msml_fuzzer PRIVATE ../msml)
diff --git a/fuzzer/fuzzer.cpp b/fuzzer/fuzzer.cpp
new file mode 100644
index 0000000..097b668
--- /dev/null
+++ b/fuzzer/fuzzer.cpp
@@ -0,0 +1,20 @@
+// (c) 2024 Mario "Neo" Sieg.
+
+// Command line options:
+// -jobs=64 -workers=64 -max_len=16384 -rss_limit_mb=16384 -max_total_time=3600 -exact_artifact_path="bin/fuzz"
+// 3600 seconds = 1 hour
+
+#include <cstdint>   // std::uint8_t, std::int8_t, std::size_t
+#include <memory>    // std::unique_ptr
+
+#include <magnetron.hpp> // C++ API header; the exact name was lost in this copy and is assumed here
+
+extern "C" [[nodiscard]] auto mag__sto_read_buffered(mag_ctx_t* ctx, const std::uint8_t* buf, std::size_t size, std::size_t* out_n_tensors) -> mag_tensor_t**; // Imported from msml.c
+
+extern "C" auto LLVMFuzzerTestOneInput(const std::int8_t* data, std::size_t size) -> int {
+    static std::unique_ptr<magnetron::ctx> ctx {new magnetron::ctx{}};
+    std::size_t n_tensors = 0;
+    [[maybe_unused]]
+    mag_tensor_t** volatile tensors = mag__sto_read_buffered(**ctx, reinterpret_cast<const std::uint8_t*>(data), size, &n_tensors);
+    return !tensors ? -1 : 0;
+}
\ No newline at end of file
diff --git a/magnetron/magnetron.c b/magnetron/magnetron.c
new file mode 100644
index 0000000..d282473
--- /dev/null
+++ b/magnetron/magnetron.c
@@ -0,0 +1,3965 @@
+/* (c) 2024 Mario "Neo" Sieg. */
+
+/*
+** ### To add a new operation:
+** 1. Add the operation to the mag_op_def macro, which defines all operations, with all information needed.
+** 2. Write a validation routine, or use an existing one (e.g. 'mag_validate_op_binary').
+** 3. Add the validation routine to the 'routines' table in 'mag_op_get_validator_routine', at the op index.
+** 4. Write a result tensor constructor routine or use an existing one (e.g. 'mag_result_constructor_routine_isomorph').
+** 5. Add the result tensor constructor routine to the 'routines' table in 'mag_op_get_result_constructor_routine', at the op index.
+** 6. Write a BLAS computation routine.
+** 7. Add the BLAS computation routine to the 'dispatch_lut' table in 'mag_blas_compute_dispatch_table_default', at the op index.
+*/
+
+/*
+** Here ⊕ denotes a binary or unary operator.
+** Operators can have two forms; an operator may support either or both, which must be specified in mag_op_def:
+**
+** 1. R = A ⊕ B
+**    The result is a new tensor with the shape of A, containing the element-wise result of A ⊕ B, where A and B are tensors.
+**
+** 2. R = A ⊕= B
+**    The result is a view tensor with the shape of A, containing the element-wise result of A ⊕= B, where A and B are tensors. (Saves one allocation.)
+*/
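+/*
+** For illustration (hypothetical function names, the concrete operator API is
+** defined further below in this file): with tensors A and B, form 1 might look
+** like R = mag_add(A, B), which allocates a fresh result tensor, while form 2
+** would be an in-place variant such as mag_add_(A, B), which returns a view
+** over A's storage and therefore saves one allocation.
+*/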
+**    (Saves one allocation.)
+*/
+
+#include "magnetron.h"
+#include "magnetron_internal.h"
+
+/* NOTE: the include targets below were lost in transit; the list is reconstructed from usage in this file. */
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#ifdef _MSC_VER
+#define _USE_MATH_DEFINES
+#endif
+#include <math.h>
+#include <string.h>
+#include <time.h>
+#include <stdint.h>
+#include <inttypes.h>
+
+#ifdef _WIN32
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#elif defined(__APPLE__)
+/* Apple include targets reconstructed (assumed set): */
+#include <mach/mach.h>
+#include <mach/mach_time.h>
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#include <unistd.h>
+#else
+#include <unistd.h>
+#endif
+
+#ifdef NDEBUG
+#define MAG_LOG_DEFAULT_ENABLE 0
+#else
+#define MAG_LOG_DEFAULT_ENABLE 1
+#endif
+bool mag_log_enabled = MAG_LOG_DEFAULT_ENABLE; /* Read from multiple threads, allowed to be written from main thread once at start. */
+#undef MAG_LOG_DEFAULT_ENABLE
+
+void mag_set_log_mode(bool enabled) {
+    mag_log_enabled = enabled;
+}
+
+MAG_NORET MAG_COLDPROC MAG_EXPORT void mag_panic(const char* msg, ...) {
+    fprintf(stdout, "%s", MAG_CC_RED);
+    va_list args;
+    va_start(args, msg);
+    vfprintf(stdout, msg, args);
+    va_end(args);
+    fprintf(stdout, "%s", MAG_CC_RESET);
+    fputc('\n', stdout);
+    fflush(stdout);
+    abort();
+}
+
+static void* mag_os_alloc_stub(void* blk, size_t size) {
+    if (!size) {
+        free(blk);
+        return NULL;
+    } else if(!blk) {
+        blk = malloc(size);
+        mag_assert(blk, "Failed to allocate %.03fKiB memory", (double)size/(double)(1<<10));
+        return blk;
+    } else {
+        void* block = realloc(blk, size);
+        mag_assert(block, "Failed to reallocate %.03fKiB memory", (double)size/(double)(1<<10)); /* Assert the realloc result, not the stale input pointer. */
+        return block;
+    }
+}
+
+void* (*mag_alloc)(void* blk, size_t size) = &mag_os_alloc_stub;
+
+void* mag_alloc_aligned(size_t size, size_t align) {
+    mag_assert(align && !(align&(align-1)), "Alignment must be power of 2: %zu", align); /* Alignment must be a power of 2 */
+    void* p = (*mag_alloc)(NULL, size+sizeof(void*)+align-1);
+    uintptr_t pp = ((uintptr_t)p+sizeof(void*)+align-1)&-align;
+    ((void**)pp)[-1] = p;
+    return (void*)pp;
+}
+
+void mag_free_aligned(void* blk) {
+    (*mag_alloc)(((void**)blk)[-1], 0);
+}
+
+/* Include STB libraries and override their allocator with ours.
+*/
+#define STBI_STATIC
+#define STBI_MALLOC(sz) ((*mag_alloc)(NULL, (sz)))
+#define STBI_FREE(ptr) ((*mag_alloc)((ptr), 0))
+#define STBI_REALLOC(ptr, sz) ((*mag_alloc)((ptr), (sz)))
+#define STBIW_MALLOC(sz) ((*mag_alloc)(NULL, (sz)))
+#define STBIW_FREE(ptr) ((*mag_alloc)((ptr), 0))
+#define STBIW_REALLOC(ptr, sz) ((*mag_alloc)((ptr), (sz)))
+#define STB_IMAGE_IMPLEMENTATION
+#include <stb_image.h>       /* Include target reconstructed; resolved via the 'extern' include directory. */
+#define STB_IMAGE_WRITE_IMPLEMENTATION
+#include <stb_image_write.h> /* Include target reconstructed. */
+
+#if defined(__x86_64__) || defined(_M_X64)
+#define MAG_X86_64_CPUID_0H 0
+#define MAG_X86_64_CPUID_1H 1
+#define MAG_X86_64_CPUID_2H 2
+#define MAG_X86_64_CPUID_7H 3
+#define MAG_X86_64_CPUID_80000001H 4
+#define MAG_X86_64_CPUID_80000007H 5
+#define MAG_X86_64_CPUID_16H 6
+#define MAG_X86_64_CPUID_7H_1H 7
+#define MAG_X86_64_CPUID_EAX 0
+#define MAG_X86_64_CPUID_EBX 1
+#define MAG_X86_64_CPUID_ECX 2
+#define MAG_X86_64_CPUID_EDX 3
+
+#define mag_x86_64_feature_def(_, __) /* Enumerator | CPUID Leaf | Register | Bit Index */\
+    _(AVX                   , 1H, ECX, 28)__\
+    _(AVX2                  , 7H, EBX, 5)__\
+    _(AVXVNNI               , 7H_1H, EAX, 4)__\
+    _(AVXVNNIINT8           , 7H_1H, EDX, 4)__\
+    _(AVXVNNIINT16          , 7H_1H, EDX, 10)__\
+    _(AVX512BW              , 7H, EBX, 30)__\
+    _(AVX512CD              , 7H, EBX, 28)__\
+    _(AVX512DQ              , 7H, EBX, 17)__\
+    _(AVX512ER              , 7H, EBX, 27)__\
+    _(AVX512F               , 7H, EBX, 16)__\
+    _(AVX512IFMA            , 7H, EBX, 21)__\
+    _(AVX512PF              , 7H, EBX, 26)__\
+    _(AVX512VBMI            , 7H, ECX, 1)__\
+    _(AVX512VL              , 7H, EBX, 31)__\
+    _(AVX512_4FMAPS         , 7H, EDX, 3)__\
+    _(AVX512_4VNNIW         , 7H, EDX, 2)__\
+    _(AVX512_FP16           , 7H, EDX, 23)__\
+    _(AVX512_BF16           , 7H_1H, EAX, 5)__\
+    _(AVX512_BITALG         , 7H, ECX, 12)__\
+    _(AVX512_VBMI2          , 7H, ECX, 6)__\
+    _(AVX512_VNNI           , 7H, ECX, 11)__\
+    _(AVX512_VP2INTERSECT   , 7H, EDX, 8)__\
+    _(AVX512_VPOPCNTDQ      , 7H, ECX, 14)__\
+    _(BMI                   , 7H, EBX, 3)__\
+    _(BMI2                  , 7H, EBX, 8)__\
+    _(F16C                  , 1H, ECX, 29)__\
+    _(FMA                   , 1H, ECX, 12)__\
+    _(FPU                   , 1H, EDX, 0)__\
+    _(GFNI                  , 7H, ECX, 8)__\
+    _(IA64                  , 1H, EDX, 30)__\
+    _(MMX                   , 1H, EDX, 23)__\
+    _(OSXSAVE               , 1H, ECX, 27)__\
+    _(PCLMUL                , 1H, ECX, 1)__\
+    _(RDRND                 , 1H, ECX, 30)__\
+    _(RDSEED                , 7H, EBX, 18)__\
+    _(RDTSCP                , 80000001H, EDX, 27)__\
+    _(SHA                   , 7H, EBX, 29)__\
+    _(SSE                   , 1H, EDX, 25)__\
+    _(SSE2                  , 1H, EDX, 26)__\
+    _(SSE3                  , 1H, ECX, 0)__\
+    _(SSE4_1                , 1H, ECX, 19)__\
+    _(SSE4_2                , 1H, ECX, 20)__\
+    _(SSSE3                 , 1H, ECX, 9)__\
+    _(VAES                  , 7H, ECX, 9)__\
+    _(VME                   , 1H, EDX, 1)__\
+    _(VMX                   , 1H, ECX, 5)__\
+    _(VPCLMULQDQ            , 7H, ECX, 10)__\
+    _(XSAVE                 , 1H, ECX, 26)__\
+    _(HYBRID_CPU            , 7H, EDX, 15)__
+
+#define _(enumerator, leaf, reg, bit) MAG_X86_64_FEATURE_##enumerator
+typedef enum mag_x86_64_feature_t {
+    mag_x86_64_feature_def(_, MAG_SEP)
+    MAG_X86_64_FEATURE__COUNT
+} mag_x86_64_feature_t;
+#undef _
+#define _(enumerator, leaf, reg, bit) #enumerator
+static const char* const mag_x86_64_feature_names[MAG_X86_64_FEATURE__COUNT] = {
+    mag_x86_64_feature_def(_, MAG_SEP)
+};
+#undef _
+#define _(enumerator, leaf, reg, bit) (0xff&MAG_X86_64_CPUID_##leaf)
+static const uint8_t mag_x86_64_feature_leaves[MAG_X86_64_FEATURE__COUNT] = {
+    mag_x86_64_feature_def(_, MAG_SEP)
+};
+#undef _
+#define _(enumerator, leaf, reg, bit) (0xff&MAG_X86_64_CPUID_##reg)
+static const uint8_t mag_x86_64_feature_regs[MAG_X86_64_FEATURE__COUNT] = {
+    mag_x86_64_feature_def(_, MAG_SEP)
+};
+#undef _
+#define _(enumerator, leaf, reg, bit) (1u<<(bit))
+static const uint32_t msml__x86_64_feature_masks[MAG_X86_64_FEATURE__COUNT] = {
+    mag_x86_64_feature_def(_, MAG_SEP)
+};
+#undef _
+#undef mag_x86_64_feature_def
+#endif
+
+void* (*mag_get_alloc_fn(void))(void* blk, size_t size) {
+    return mag_alloc;
+}
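mag_alloc_aligned over-allocates by sizeof(void*)+align-1 bytes, rounds the address up to the requested alignment, and stashes the raw base pointer in the slot just below the aligned block so mag_free_aligned can recover it. A minimal usage sketch, assuming the translation unit links against magnetron (only the two helpers defined above are used):

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

extern void* mag_alloc_aligned(size_t size, size_t align);
extern void mag_free_aligned(void* blk);

int main(void) {
    void* p = mag_alloc_aligned(256, 64);               /* 256 B, aligned to a 64 B cache line. */
    printf("aligned: %d\n", ((uintptr_t)p & 63) == 0);  /* Always 1: the address was rounded up to a multiple of 64. */
    mag_free_aligned(p);                                /* Frees via the base pointer stored at ((void**)p)[-1]. */
    return 0;
}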
+
+void mag_set_alloc_fn(void* (*alloc)(void* blk, size_t size)) {
+    mag_assert2(alloc);
+    mag_alloc = alloc;
+}
+
+void mag_humanize_memory_size(size_t n, double* out, const char** unit) {
+    if (n < (1<<10)) {
+        *out = (double)n;
+        *unit = "B";
+    } else if (n < (1<<20)) {
+        *out = (double)n/(double)(1<<10);
+        *unit = "KiB";
+    } else if (n < (1<<30)) {
+        *out = (double)n/(double)(1<<20);
+        *unit = "MiB";
+    } else {
+        *out = (double)n/(double)(1<<30);
+        *unit = "GiB";
+    }
+}
+
+static void MAG_COLDPROC mag_print_separator(FILE* f) {
+    f = f ? f : stdout;
+    char sep[100+1];
+    for (size_t i=0; i < (sizeof(sep)/sizeof(*sep))-1; ++i) sep[i] = '-';
+    sep[sizeof(sep)/sizeof(*sep)-1] = '\0';
+    fprintf(f, "%s\n", sep);
+}
+
+#define MAG_FMT_DIM_BUF_SIZE ((21+4)*MAG_MAX_DIMS)
+static void mag_fmt_dims(char (*buf)[MAG_FMT_DIM_BUF_SIZE], const int64_t (*dims)[MAG_MAX_DIMS], int64_t rank) {
+    mag_static_assert(MAG_MAX_DIMS == 6);
+    memset(*buf, 0, sizeof(*buf));
+    char* p = *buf;
+    mag_assert2(p+rank*21+3 < *buf+MAG_FMT_DIM_BUF_SIZE);
+    *p++ = '(';
+    for (int64_t i=0; i < rank; ++i) {
+        p += snprintf(p, 21, "%" PRIi64, (*dims)[i]);
+        if (i < rank-1) {
+            *p++ = ',';
+            *p++ = ' ';
+        }
+    }
+    *p++ = ')';
+    *p = '\0';
+}
+
+#ifdef _WIN32
+#include <wchar.h> /* Include target reconstructed (assumed): wide-char support for the UTF-8 fopen shim below. */
+extern __declspec(dllimport) int __stdcall MultiByteToWideChar(
+    unsigned int cp,
+    unsigned long flags,
+    const char* str,
+    int cbmb,
+    wchar_t* widestr,
+    int cchwide
+);
+extern __declspec(dllimport) int __stdcall WideCharToMultiByte(
+    unsigned int cp,
+    unsigned long flags,
+    const wchar_t* widestr,
+    int cchwide,
+    char* str,
+    int cbmb,
+    const char* defchar,
+    int* used_default
+);
+#endif
+
+static FILE* mag_fopen(const char* file, const char* mode) {
+    mag_assert(file && *file && mode && *mode, "Invalid file name or mode");
+    FILE* f = NULL;
+    #ifdef _WIN32
+    wchar_t w_mode[64];
+    wchar_t w_file[1024];
+    if (MultiByteToWideChar(65001 /* UTF8 */, 0, file, -1, w_file, sizeof(w_file)/sizeof(*w_file)) == 0) return NULL;
+    if (MultiByteToWideChar(65001 /* UTF8 */, 0, mode, -1, w_mode, sizeof(w_mode)/sizeof(*w_mode)) == 0) return NULL;
+    #if defined(_MSC_VER) && _MSC_VER >= 1400
+    if (_wfopen_s(&f, w_file, w_mode) != 0)
+        return NULL;
+    #else
+    f = _wfopen(w_file, w_mode);
+    #endif
+    #elif defined(_MSC_VER) && _MSC_VER >= 1400
+    if (fopen_s(&f, file, mode) != 0) return NULL; /* 'file' is the parameter name; 'filename' was undeclared. */
+    #else
+    f = fopen(file, mode);
+    #endif
+    return f;
+}
+
+static inline uintptr_t mag_thread_id(void) {
+    uintptr_t tid;
+    #if defined(_MSC_VER) && defined(_M_X64)
+    tid = __readgsqword(48);
+    #elif defined(_MSC_VER) && defined(_M_IX86)
+    tid = __readfsdword(24);
+    #elif defined(_MSC_VER) && defined(_M_ARM64)
+    tid = __getReg(18);
+    #elif defined(__i386__)
+    __asm__ __volatile__("movl %%gs:0, %0" : "=r" (tid)); /* x86-32 WIN32 uses %GS */
+    #elif defined(__MACH__) && defined(__x86_64__)
+    __asm__ __volatile__("movq %%gs:0, %0" : "=r" (tid)); /* x86-64 OSX uses %GS */
+    #elif defined(__x86_64__)
+    __asm__ __volatile__("movq %%fs:0, %0" : "=r" (tid)); /* x86-64 Linux and BSD use %FS */
+    #elif defined(__arm__)
+    __asm__ __volatile__("mrc p15, 0, %0, c13, c0, 3\nbic %0, %0, #3" : "=r" (tid));
+    #elif defined(__aarch64__) && defined(__APPLE__)
+    __asm__ __volatile__("mrs %0, tpidrro_el0" : "=r" (tid));
+    #elif defined(__aarch64__)
+    __asm__ __volatile__("mrs %0, tpidr_el0" : "=r" (tid));
+    #elif defined(__powerpc64__)
+    #ifdef __clang__
+    tid = (uintptr_t)__builtin_thread_pointer();
+    #else
+    register uintptr_t tp __asm__ ("r13");
+    __asm__ __volatile__("" : "=r" (tp));
+    tid = tp;
+    #endif
+    #elif defined(__powerpc__)
+    #ifdef __clang__
+    tid = (uintptr_t)__builtin_thread_pointer();
+    #else
+    register uintptr_t tp __asm__ ("r2");
+    __asm__ __volatile__("" : "=r" (tp));
+    tid = tp;
+    #endif
+    #elif defined(__s390__) && defined(__GNUC__)
+    tid = (uintptr_t)__builtin_thread_pointer();
+    #elif defined(__riscv)
+    #ifdef __clang__
+    tid = (uintptr_t)__builtin_thread_pointer();
+    #else
+    __asm__ ("mv %0, tp" : "=r" (tid));
+    #endif
+    #else
+    #error "Unsupported magnetron platform"
+    #endif
+    return tid;
+}
+
+static uint64_t mag_hpc_clock_ns(void) { /* High precision clock in nanoseconds. */
+    #ifdef _WIN32
+    static LONGLONG t_freq;
+    static LONGLONG t_boot;
+    static bool t_init = false;
+    if (!t_init) { /* Reduce chance of integer overflow when uptime is high. */
+        LARGE_INTEGER li;
+        QueryPerformanceFrequency(&li);
+        t_freq = li.QuadPart;
+        QueryPerformanceCounter(&li);
+        t_boot = li.QuadPart;
+        t_init = true;
+    }
+    LARGE_INTEGER li;
+    QueryPerformanceCounter(&li);
+    return ((li.QuadPart - t_boot)*1000000000) / t_freq;
+    #else
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return (uint64_t)ts.tv_sec*1000000000 + (uint64_t)ts.tv_nsec;
+    #endif
+}
+static uint64_t mag_hpc_clock_elapsed_ns(uint64_t start) { /* High precision clock elapsed time in nanoseconds. */
+    return (uint64_t)llabs((int64_t)mag_hpc_clock_ns() - (int64_t)start);
+}
+static double mag_hpc_clock_elapsed_ms(uint64_t start) { /* High precision clock elapsed time in milliseconds. */
+    return (double)mag_hpc_clock_elapsed_ns(start) / 1e6;
+}
+#define mag_clock_cycles() ((uint64_t)clock())
+#define mag_cycles_per_ms() ((uint64_t)CLOCKS_PER_SEC/1000)
+
+typedef uint32_t mag_bitset_t;
+mag_static_assert(sizeof(mag_bitset_t) == 4);
+#define mag_bitset_size(n) (((n)+((4<<3)-1))>>5)
+#define mag_bitset_get(sets, i) (!!(sets[(i)>>5]&(1u<<((i)&((4<<3)-1)))))
+#define mag_bitset_set(sets, i) (sets[(i)>>5]|=(1u<<((i)&((4<<3)-1))))
+#define mag_bitset_clear(sets, i) (sets[(i)>>5]&=~(1u<<((i)&((4<<3)-1))))
+#define mag_bitset_toggle(sets, i) (sets[(i)>>5]^=(1u<<((i)&((4<<3)-1))))
+
+#if defined(__aarch64__) && defined(__ARM_FEATURE_CRC32) && defined(__ARM_FEATURE_CRYPTO)
+static uint64x2_t MAG_AINLINE mag_clmul_lo_e(uint64x2_t a, uint64x2_t b, uint64x2_t c) {
+    register uint64x2_t r;
+    __asm__ __volatile__(
+        "pmull %0.1q, %2.1d, %3.1d\n"
+        "eor %0.16b, %0.16b, %1.16b\n"
+        : "=w"(r), "+w"(c) : "w"(a), "w"(b)
+    );
+    return r;
+}
+static uint64x2_t MAG_AINLINE mag_clmul_hi_e(uint64x2_t a, uint64x2_t b, uint64x2_t c) {
+    register uint64x2_t r;
+    __asm__ __volatile__(
+        "pmull2 %0.1q, %2.2d, %3.2d\n"
+        "eor %0.16b, %0.16b, %1.16b\n"
+        : "=w"(r), "+w"(c) : "w"(a), "w"(b)
+    );
+    return r;
+}
+#elif defined(__x86_64__) || defined(_M_X64)
+static uint32_t mag_xnmodp(uint64_t n) { /* x^n mod P, in log(n) time */
+    uint64_t stack = ~(uint64_t)1;
+    uint32_t low;
+    for (; n > 191; n = (n>>1) - 16) stack = (stack<<1) + (n & 1);
+    stack = ~stack;
+    uint32_t acc = 0x80000000 >> (n & 31);
+    for (n >>= 5; n; --n) acc = _mm_crc32_u32(acc, 0);
+    while (low = stack & 1, stack >>= 1) {
+        __m128i x = _mm_cvtsi32_si128(acc);
+        uint64_t y = _mm_cvtsi128_si64(_mm_clmulepi64_si128(x, x, 0));
+        acc = _mm_crc32_u64(0, y << low);
+    }
+    return acc;
+}
+static __m128i MAG_AINLINE mag_clmul_scalar(uint32_t a, uint32_t b) {
+    return _mm_clmulepi64_si128(_mm_cvtsi32_si128(a), _mm_cvtsi32_si128(b), 0);
+}
+static __m128i MAG_AINLINE mag_crc_shift(uint32_t crc, size_t sz) {
+    return mag_clmul_scalar(crc, mag_xnmodp((sz<<3) - 33));
+}
+#endif
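The bitset macros above pack boolean flags 32 per word; index i lives in word i>>5 at bit i&31. A self-contained sketch of how they are used:

#include <stdint.h>
#include <stdio.h>

typedef uint32_t mag_bitset_t;
#define mag_bitset_size(n) (((n)+((4<<3)-1))>>5)                    /* Words needed for n flags (round up). */
#define mag_bitset_get(sets, i) (!!(sets[(i)>>5]&(1u<<((i)&((4<<3)-1)))))
#define mag_bitset_set(sets, i) (sets[(i)>>5]|=(1u<<((i)&((4<<3)-1))))

int main(void) {
    mag_bitset_t bits[mag_bitset_size(100)] = {0}; /* 100 flags -> 4 words. */
    mag_bitset_set(bits, 42);
    printf("%d %d\n", mag_bitset_get(bits, 42), mag_bitset_get(bits, 43)); /* Prints: 1 0 */
    return 0;
}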
+ +static uint32_t mag_crc32c(const void* buffer, size_t size) { /* Compute CRC32 checksum with CRC32c polynomial. */ + if (mag_unlikely(!buffer || !size)) return 0; + const uint8_t* buf = (const uint8_t*)buffer; + #if defined(__aarch64__) && defined(__ARM_FEATURE_CRC32) && defined(__ARM_FEATURE_CRYPTO) + uint32_t crc = ~0; + for (; size && ((uintptr_t)buf & 7); --size) crc = __crc32cb(crc, *buf++); + if (((uintptr_t)buf & 8) && size >= 8) { + crc = __crc32cd(crc, *(const uint64_t*)buf); + buf += 8; + size -= 8; + } + if (size >= 192) { /* First vector chunk. */ + uint64x2_t x0 = vld1q_u64((const uint64_t*)buf), y0; + uint64x2_t x1 = vld1q_u64((const uint64_t*)(buf+16)), y1; + uint64x2_t x2 = vld1q_u64((const uint64_t*)(buf+32)), y2; + uint64x2_t x3 = vld1q_u64((const uint64_t*)(buf+48)), y3; + uint64x2_t x4 = vld1q_u64((const uint64_t*)(buf+64)), y4; + uint64x2_t x5 = vld1q_u64((const uint64_t*)(buf+80)), y5; + uint64x2_t x6 = vld1q_u64((const uint64_t*)(buf+96)), y6; + uint64x2_t x7 = vld1q_u64((const uint64_t*)(buf+112)), y7; + uint64x2_t x8 = vld1q_u64((const uint64_t*)(buf+128)), y8; + uint64x2_t x9 = vld1q_u64((const uint64_t*)(buf+144)), y9; + uint64x2_t x10 = vld1q_u64((const uint64_t*)(buf+160)), y10; + uint64x2_t x11 = vld1q_u64((const uint64_t*)(buf+176)), y11; + uint64x2_t k; + { static const uint64_t MAG_ALIGN(16) k_[] = {0xa87ab8a8, 0xab7aff2a}; k = vld1q_u64(k_); } + x0 = veorq_u64((uint64x2_t){crc, 0}, x0); + buf += 192; + size -= 192; + while (size >= 192) { /* Work loop. */ + y0 = mag_clmul_lo_e(x0, k, vld1q_u64((const uint64_t*)buf)), x0 = mag_clmul_hi_e(x0, k, y0); + y1 = mag_clmul_lo_e(x1, k, vld1q_u64((const uint64_t*)(buf+16))), x1 = mag_clmul_hi_e(x1, k, y1); + y2 = mag_clmul_lo_e(x2, k, vld1q_u64((const uint64_t*)(buf+32))), x2 = mag_clmul_hi_e(x2, k, y2); + y3 = mag_clmul_lo_e(x3, k, vld1q_u64((const uint64_t*)(buf+48))), x3 = mag_clmul_hi_e(x3, k, y3); + y4 = mag_clmul_lo_e(x4, k, vld1q_u64((const uint64_t*)(buf+64))), x4 = mag_clmul_hi_e(x4, k, y4); + y5 = mag_clmul_lo_e(x5, k, vld1q_u64((const uint64_t*)(buf+80))), x5 = mag_clmul_hi_e(x5, k, y5); + y6 = mag_clmul_lo_e(x6, k, vld1q_u64((const uint64_t*)(buf+96))), x6 = mag_clmul_hi_e(x6, k, y6); + y7 = mag_clmul_lo_e(x7, k, vld1q_u64((const uint64_t*)(buf+112))), x7 = mag_clmul_hi_e(x7, k, y7); + y8 = mag_clmul_lo_e(x8, k, vld1q_u64((const uint64_t*)(buf+128))), x8 = mag_clmul_hi_e(x8, k, y8); + y9 = mag_clmul_lo_e(x9, k, vld1q_u64((const uint64_t*)(buf+144))), x9 = mag_clmul_hi_e(x9, k, y9); + y10 = mag_clmul_lo_e(x10, k, vld1q_u64((const uint64_t*)(buf+160))), x10 = mag_clmul_hi_e(x10, k, y10); + y11 = mag_clmul_lo_e(x11, k, vld1q_u64((const uint64_t*)(buf+176))), x11 = mag_clmul_hi_e(x11, k, y11); + buf += 192; + size -= 192; + } + /* Reduce x0 ... x11 to just x0. 
*/ + { static const uint64_t MAG_ALIGN(16) k_[] = {0xf20c0dfe, 0x493c7d27}; k = vld1q_u64(k_); } + y0 = mag_clmul_lo_e(x0, k, x1), x0 = mag_clmul_hi_e(x0, k, y0); + y2 = mag_clmul_lo_e(x2, k, x3), x2 = mag_clmul_hi_e(x2, k, y2); + y4 = mag_clmul_lo_e(x4, k, x5), x4 = mag_clmul_hi_e(x4, k, y4); + y6 = mag_clmul_lo_e(x6, k, x7), x6 = mag_clmul_hi_e(x6, k, y6); + y8 = mag_clmul_lo_e(x8, k, x9), x8 = mag_clmul_hi_e(x8, k, y8); + y10 = mag_clmul_lo_e(x10, k, x11), x10 = mag_clmul_hi_e(x10, k, y10); + { static const uint64_t MAG_ALIGN(16) k_[] = {0x3da6d0cb, 0xba4fc28e}; k = vld1q_u64(k_); } + y0 = mag_clmul_lo_e(x0, k, x2), x0 = mag_clmul_hi_e(x0, k, y0); + y4 = mag_clmul_lo_e(x4, k, x6), x4 = mag_clmul_hi_e(x4, k, y4); + y8 = mag_clmul_lo_e(x8, k, x10), x8 = mag_clmul_hi_e(x8, k, y8); + { static const uint64_t MAG_ALIGN(16) k_[] = {0x740eef02, 0x9e4addf8}; k = vld1q_u64(k_); } + y0 = mag_clmul_lo_e(x0, k, x4), x0 = mag_clmul_hi_e(x0, k, y0); + x4 = x8; + y0 = mag_clmul_lo_e(x0, k, x4), x0 = mag_clmul_hi_e(x0, k, y0); + /* Reduce 128 bits to 32 bits, and multiply by x^32. */ + crc = __crc32cd(0, vgetq_lane_u64(x0, 0)); + crc = __crc32cd(crc, vgetq_lane_u64(x0, 1)); + } + for (; size >= 8; buf += 8, size -= 8) crc = __crc32cd(crc, *(const uint64_t*)buf); + for (; size; --size) crc = __crc32cb(crc, *buf++); + return ~crc; + #elif defined(__x86_64__) || defined(_M_X64) + uint32_t crc = ~0; + for (; size && ((uintptr_t)buf & 7); --size) crc = _mm_crc32_u8(crc, *buf++); + if (size >= 32) { + size_t klen = ((size - 8) / 24)<<3; + uint32_t crc1 = 0; + uint32_t crc2 = 0; + /* Main loop. */ + do { + crc = _mm_crc32_u64(crc, *(const uint64_t*)buf); + crc1 = _mm_crc32_u64(crc1, *(const uint64_t*)(buf + klen)); + crc2 = _mm_crc32_u64(crc2, *(const uint64_t*)(buf + (klen<<1))); + buf += 8; + size -= 24; + } while (size >= 32); + __m128i vc0 = mag_crc_shift(crc, (klen << 1) + 8); + __m128i vc1 = mag_crc_shift(crc1, klen + 8); + uint64_t vc = _mm_extract_epi64(_mm_xor_si128(vc0, vc1), 0); + /* Final 8 bytes. 
*/ + buf += klen<<1; + crc = crc2; + crc = _mm_crc32_u64(crc, *(const uint64_t*)buf ^ vc), buf += 8; + size -= 8; + } + for (; size >= 8; buf += 8, size -= 8) crc = _mm_crc32_u64(crc, *(const uint64_t*)buf); + for (; size; --size) crc = _mm_crc32_u8(crc, *buf++); + return ~crc; + #else + static const uint32_t crc_lut[256] = { + 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, 0xc79a971f, 0x35f1141c, + 0x26a1e7e8, 0xd4ca64eb, 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, + 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, 0x105ec76f, 0xe235446c, + 0xf165b798, 0x030e349b, 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, + 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, 0x5d1d08bf, 0xaf768bbc, + 0xbc267848, 0x4e4dfb4b, 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, + 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, 0xaa64d611, 0x580f5512, + 0x4b5fa6e6, 0xb93425e5, 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, + 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, 0xf779deae, 0x05125dad, + 0x1642ae59, 0xe4292d5a, 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, + 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, 0x417b1dbc, 0xb3109ebf, + 0xa0406d4b, 0x522bee48, 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, + 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, 0x0c38d26c, 0xfe53516f, + 0xed03a29b, 0x1f682198, 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, + 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, 0xdbfc821c, 0x2997011f, + 0x3ac7f2eb, 0xc8ac71e8, 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, + 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, 0xa65c047d, 0x5437877e, + 0x4767748a, 0xb50cf789, 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, + 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, 0x7198540d, 0x83f3d70e, + 0x90a324fa, 0x62c8a7f9, 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, + 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, 0x3cdb9bdd, 0xceb018de, + 0xdde0eb2a, 0x2f8b6829, 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, + 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, 0x082f63b7, 0xfa44e0b4, + 0xe9141340, 0x1b7f9043, 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, + 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, 0x55326b08, 0xa759e80b, + 0xb4091bff, 0x466298fc, 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, + 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, 0xa24bb5a6, 0x502036a5, + 0x4370c551, 0xb11b4652, 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, + 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, 0xef087a76, 0x1d63f975, + 0x0e330a81, 0xfc588982, 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, + 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, 0x38cc2a06, 0xcaa7a905, + 0xd9f75af1, 0x2b9cd9f2, 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, + 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, 0x0417b1db, 0xf67c32d8, + 0xe52cc12c, 0x1747422f, 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, + 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, 0xd3d3e1ab, 0x21b862a8, + 0x32e8915c, 0xc083125f, 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, + 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, 0x9e902e7b, 0x6cfbad78, + 0x7fab5e8c, 0x8dc0dd8f, 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, + 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, 0x69e9f0d5, 0x9b8273d6, + 0x88d28022, 0x7ab90321, 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, + 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, 0x34f4f86a, 0xc69f7b69, + 0xd5cf889d, 0x27a40b9e, 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, + 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351 + }; + uint32_t crc = ~0u; + for (size_t i=0; i < size; ++i) + crc = 
(crc >> 8) ^ crc_lut[buf[i] ^ (crc & 0xff)];
+    return ~crc;
+    #endif
+}
+
+typedef enum mag_format_type {
+    MAG_FMT_EOF, MAG_FMT_ERR, MAG_FMT_LIT, MAG_FMT_INT,
+    MAG_FMT_UINT, MAG_FMT_NUM, MAG_FMT_STR, MAG_FMT_CHAR,
+    MAG_FMT_PTR
+} mag_format_type; /* Format types for formatted output */
+
+typedef uint32_t mag_format_flags; /* Flags for formatting output */
+
+/* Format flags */
+#define MAG_FMT_F_LEFT 0x0100 /* Left-align the output */
+#define MAG_FMT_F_PLUS 0x0200 /* Prefix positive numbers with a plus sign */
+#define MAG_FMT_F_ZERO 0x0400 /* Pad with zeros instead of spaces */
+#define MAG_FMT_F_SPACE 0x0800 /* Prefix a space for positive numbers */
+#define MAG_FMT_F_ALT 0x1000 /* Alternate format flag */
+#define MAG_FMT_F_UPPER 0x2000 /* Use uppercase letters for hex output */
+
+/* Format subtypes (bits reused) */
+#define MAG_FMT_T_HEX 0x0010 /* Hexadecimal format for unsigned integers */
+#define MAG_FMT_T_OCT 0x0020 /* Octal format for unsigned integers */
+#define MAG_FMT_T_FP_A 0x0000 /* 'a' format for floating-point numbers */
+#define MAG_FMT_T_FP_E 0x0010 /* 'e' format for floating-point numbers */
+#define MAG_FMT_T_FP_F 0x0020 /* 'f' format for floating-point numbers */
+#define MAG_FMT_T_FP_G 0x0030 /* 'g' format for floating-point numbers */
+#define MAG_FMT_T_QUOTED 0x0010 /* Quoted string format */
+
+#define MAG_FMT_SH_WIDTH 16 /* Shift width for formatting */
+#define MAG_FMT_SH_PREC 24 /* Shift precision for formatting */
+#define MAG_FMT_TYPE(sf) ((mag_format_type)((sf) & 15)) /* Extract format type */
+#define MAG_FMT_WIDTH(sf) (((sf) >> MAG_FMT_SH_WIDTH) & 255u) /* Extract width */
+#define MAG_FMT_PREC(sf) ((((sf) >> MAG_FMT_SH_PREC) & 255u) - 1u) /* Extract precision */
+#define MAG_FMT_FP(sf) (((sf) >> 4) & 3) /* Extract floating-point format */
+
+/* Formats for conversion characters */
+#define MAG_FMT_A (MAG_FMT_NUM|MAG_FMT_T_FP_A) /* 'a' format */
+#define MAG_FMT_C (MAG_FMT_CHAR) /* 'c' format */
+#define MAG_FMT_D (MAG_FMT_INT) /* 'd' format */
+#define MAG_FMT_E (MAG_FMT_NUM|MAG_FMT_T_FP_E) /* 'e' format */
+#define MAG_FMT_F (MAG_FMT_NUM|MAG_FMT_T_FP_F) /* 'f' format */
+#define MAG_FMT_G (MAG_FMT_NUM|MAG_FMT_T_FP_G) /* 'g' format */
+#define MAG_FMT_I MAG_FMT_D /* 'i' format (same as 'd') */
+#define MAG_FMT_O (MAG_FMT_UINT|MAG_FMT_T_OCT) /* 'o' format */
+#define MAG_FMT_P (MAG_FMT_PTR) /* 'p' format */
+#define MAG_FMT_Q (MAG_FMT_STR|MAG_FMT_T_QUOTED) /* Quoted string */
+#define MAG_FMT_S (MAG_FMT_STR) /* 's' format */
+#define MAG_FMT_U (MAG_FMT_UINT) /* 'u' format */
+#define MAG_FMT_X (MAG_FMT_UINT|MAG_FMT_T_HEX) /* 'x' format */
+#define MAG_FMT_G14 (MAG_FMT_G | ((14+1) << MAG_FMT_SH_PREC)) /* 'g' format with precision 14 */
+
+static char* mag_fmt_f64(mag_format_flags sf, double n, char* p);
+
+static bool MAG_AINLINE mag_imull64_ov(int64_t a, int64_t b, int64_t* out) { /* Performs c = a*b with overflow checking. Returns true on overflow, else false. */
+    #ifdef _MSC_VER
+    #ifdef _M_ARM64
+    int64_t high = __mulh(a, b); /* Signed high half of the product; the sign check below requires it (__umulh is unsigned). */
+    *out = a*b;
+    return high != (*out>>63);
+    #else
+    int64_t high;
+    int64_t low = _mul128(a, b, &high);
+    int64_t sign = low >> 63;
+    *out = low;
+    return high != sign;
+    #endif
+    #else
+    #if __SIZEOF_LONG_LONG__ == 8 && __SIZEOF_LONG__ == 8
+    return __builtin_smulll_overflow(a, b, (long long*)out);
+    #else
+    return __builtin_smull_overflow(a, b, out);
+    #endif
+    #endif
+}
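On GCC/Clang the overflow check compiles to a single multiply plus a flag test. A standalone sketch of the same technique using the generic __builtin_mul_overflow:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Same idea as mag_imull64_ov above; the generic builtin picks the right width. */
static bool imul64_ov(int64_t a, int64_t b, int64_t* out) {
    return __builtin_mul_overflow(a, b, out); /* true on overflow, product in *out otherwise */
}

int main(void) {
    int64_t r;
    printf("%d\n", imul64_ov(INT64_MAX, 2, &r));     /* 1: 2*INT64_MAX does not fit in 64 bits */
    printf("%d\n", imul64_ov(1ll<<31, 1ll<<31, &r)); /* 0: 2^62 fits, r == 1ll<<62 */
    return 0;
}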
+
+/* Generate n uniform random floats within [min, max]. */
+static void mag_prng_generate_n(mag_ctx_t* ctx, float* out_gen, int64_t out_n, float min, float max) {
+    float rescale_uniform = max - min;
+    switch (ctx->prng_algorithm) {
+        case MAG_PRNG_MERSENNE_TWISTER: {
+            uint32_t* rem = &ctx->prng.mersenne.remaining;
+            uint32_t* next = &ctx->prng.mersenne.next;
+            uint32_t* state = ctx->prng.mersenne.state;
+            for (int64_t ii=0; ii < out_n; ++ii) {
+                if (--*rem <= 0) {
+                    *rem = 624;
+                    *next = 0;
+                    uint32_t y, i;
+                    for (i = 0; i < 624-397; ++i) {
+                        y = (state[i] & 0x80000000u) | (state[i+1] & 0x7fffffffu);
+                        state[i] = state[i+397] ^ (y>>1) ^ ((y&1) ? 0x9908b0dfu : 0); /* MT19937 XORs the twist constant when the LSB of y is set. */
+                    }
+                    for (; i < 624-1; ++i) {
+                        y = (state[i] & 0x80000000u) | (state[i+1] & 0x7fffffffu);
+                        state[i] = state[i + (397-624)] ^ (y>>1) ^ ((y&1) ? 0x9908b0dfu : 0);
+                    }
+                    y = (state[624-1] & 0x80000000u) | (*state & 0x7fffffffu);
+                    state[624-1] = state[397-1] ^ (y>>1) ^ ((y&1) ? 0x9908b0dfu : 0);
+                }
+                uint32_t y = state[(*next)++];
+                y ^= y >> 11;
+                y ^= (y << 7) & 0x9d2c5680;
+                y ^= (y << 15) & 0xefc60000;
+                y ^= y >> 18;
+                out_gen[ii] = min + rescale_uniform * (1.f/(float)(1<<23)*((float)(y>>9) + 0.5f));
+            }
+        } break;
+        case MAG_PRNG_PCG: {
+            uint64_t* state = &ctx->prng.pcg.state;
+            uint64_t* inc = &ctx->prng.pcg.inc;
+            for (int64_t ii=0; ii < out_n; ++ii) {
+                uint64_t prev = *state;
+                *state = prev*6364136223846793005ull + *inc;
+                uint32_t mixed = ((prev>>18u) ^ prev) >> 27u;
+                uint32_t rot = prev >> 59u;
+                uint32_t y = (mixed>>rot) | (mixed << ((-rot)&31));
+                out_gen[ii] = min + rescale_uniform * (1.f/(float)(1<<23)*((float)(y>>9) + 0.5f));
+            }
+        } break;
+        default:
+            mag_panic("Unknown PRNG algorithm: %d", ctx->prng_algorithm);
+    }
+}
+
+static void mag_prng_init(mag_ctx_t* ctx, uint64_t seed) {
+    seed = seed ? seed : 0x853c49e6748fea9bull ^ (uintptr_t)ctx ^ (uintptr_t)&ctx; /* Default seed. */
+    switch (ctx->prng_algorithm) {
+        case MAG_PRNG_MERSENNE_TWISTER: {
+            uint32_t* state = ctx->prng.mersenne.state;
+            *state = (uint32_t)seed;
+            for (size_t i=1; i < 624; ++i)
+                state[i] = ((state[i-1] ^ (state[i-1] >> 30))*1812433253 + i) & ~0u;
+            ctx->prng.mersenne.next = 0;
+            ctx->prng.mersenne.remaining = 1;
+        } break;
+        case MAG_PRNG_PCG: {
+            ctx->prng.pcg.state = seed ^ 0x853c49e6748fea9bull;
+            ctx->prng.pcg.inc = 0xda3e39cb94b95bdbull;
+        } break;
+        default:
+            mag_panic("Unknown PRNG algorithm: %d", ctx->prng_algorithm);
+    }
+}
+
+#if defined(__x86_64__) || defined(_M_X64)
+static bool mag_ctx_x86_64_cpu_has_feature(const mag_ctx_t* ctx, mag_x86_64_feature_t feature) {
+    const uint8_t (*leafs)[49] = &mag_x86_64_feature_leaves;
+    const uint8_t (*regs)[49] = &mag_x86_64_feature_regs;
+    const uint32_t (*features)[8][4] = &ctx->sys.x86_64_cpu_features;
+    const uint32_t (*masks)[49] = &msml__x86_64_feature_masks;
+    return (*features)[(*leafs)[feature]][(*regs)[feature]] & (*masks)[feature];
+}
+#endif
+
+static void mag_system_host_info_query(mag_ctx_t* ctx); /* Query host system information. */
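The PCG branch is the standard PCG-XSH-RR step; each 32-bit output's top 23 bits are mapped to a float via (y>>9 + 0.5) * 2^-23, which yields a uniform value strictly inside (0, 1) before rescaling to [min, max]. A standalone sketch of that pipeline (constants taken from the code above):

#include <stdint.h>
#include <stdio.h>

/* Standalone PCG-XSH-RR step, mirroring the MAG_PRNG_PCG branch. */
static uint32_t pcg32(uint64_t* state, uint64_t inc) {
    uint64_t prev = *state;
    *state = prev*6364136223846793005ull + inc;       /* LCG advance */
    uint32_t mixed = (uint32_t)(((prev>>18u) ^ prev) >> 27u); /* xorshift-high */
    uint32_t rot = (uint32_t)(prev >> 59u);                   /* random rotate */
    return (mixed>>rot) | (mixed << ((-rot)&31));
}

int main(void) {
    uint64_t state = 0x853c49e6748fea9bull, inc = 0xda3e39cb94b95bdbull;
    for (int i=0; i < 4; ++i) {
        uint32_t y = pcg32(&state, inc);
        float f = 1.f/(float)(1<<23)*((float)(y>>9) + 0.5f); /* 23-bit mantissa -> (0, 1) */
        printf("%f\n", f);
    }
    return 0;
}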
+static void mag_system_host_info_dump(mag_ctx_t* ctx) {
+    mag_log_info("OS/Kernel: %s", ctx->sys.os_name);
+    const char* cpu_arch = "?";
+    #if defined(__x86_64__) || defined(_M_X64)
+    cpu_arch = "x86-64";
+    #elif defined(__aarch64__) || defined(_M_ARM64)
+    cpu_arch = "aarch64";
+    #else
+    #error "Unknown CPU arch"
+    #endif
+    mag_log_info("CPU (%s): %s, Virtual Cores: %u, Physical Cores: %u, Sockets: %u", cpu_arch, ctx->sys.cpu_name, ctx->sys.cpu_virtual_cores, ctx->sys.cpu_physical_cores, ctx->sys.cpu_sockets);
+    #if defined(__x86_64__) || defined(_M_X64) /* Print CPU features for x86-64 platforms. */
+    if (mag_log_enabled) {
+        printf("CPU Features:");
+        for (uint32_t i=0, k=0; i < MAG_X86_64_FEATURE__COUNT; ++i) {
+            if (mag_ctx_x86_64_cpu_has_feature(ctx, i)) {
+                if ((k++ & 7) == 0) printf("\n\t");
+                printf("%s ", mag_x86_64_feature_names[i]);
+            }
+        }
+        putchar('\n');
+    }
+    #endif
+    double mem_total, mem_free, mem_used;
+    const char* mem_unit_total, *mem_unit_free, *mem_unit_used;
+    mag_humanize_memory_size(ctx->sys.phys_mem_total, &mem_total, &mem_unit_total);
+    mag_humanize_memory_size(ctx->sys.phys_mem_free, &mem_free, &mem_unit_free);
+    mag_humanize_memory_size((size_t)llabs((int64_t)ctx->sys.phys_mem_total-(int64_t)ctx->sys.phys_mem_free), &mem_used, &mem_unit_used);
+    double mem_used_percent = fabs((double)(ctx->sys.phys_mem_total-ctx->sys.phys_mem_free))/(double)ctx->sys.phys_mem_total*100.0;
+    mag_log_info("Physical Machine Memory: %.03f %s, Free: %.03f %s, Used: %.03f %s (%.02f%%)", mem_total, mem_unit_total, mem_free, mem_unit_free, mem_used, mem_unit_used, mem_used_percent);
+}
+
+/* Default image loader/saver implementation. */
+static uint8_t* mag_default_image_load_impl(const char*, uint32_t(*)[3], mag_color_channels_t);
+static void mag_default_image_load_free_fn_impl(uint8_t*);
+static bool mag_default_image_save_impl(const char*, const uint8_t*, const uint32_t(*)[3]);
+
+static MAG_COLDPROC void mag_ctx_dump_compiler_info(void) {
+    const char* compiler_name = "Unknown";
+    int compiler_version_major = 0, compiler_version_minor = 0;
+    #ifdef __clang__
+    compiler_name = "Clang";
+    compiler_version_major = __clang_major__;
+    compiler_version_minor = __clang_minor__;
+    #elif defined(__GNUC__)
+    compiler_name = "GCC";
+    compiler_version_major = __GNUC__;
+    compiler_version_minor = __GNUC_MINOR__;
+    #elif defined(_MSC_VER)
+    compiler_name = "MSVC";
+    compiler_version_major = _MSC_VER / 100;
+    compiler_version_minor = _MSC_VER % 100;
+    #endif
+    mag_log_info("magnetron v.%d.%d - " __DATE__ " " __TIME__ " - %s %d.%d", mag_version_major(MAG_VERSION), mag_version_minor(MAG_VERSION), compiler_name, compiler_version_major, compiler_version_minor);
+}
+
+mag_ctx_t* mag_ctx_create(mag_compute_device_type_t device) {
+    mag_log_info("Creating magnetron context...");
+
+    uint64_t time_stamp_start = mag_hpc_clock_ns();
+    mag_ctx_dump_compiler_info(); /* Dump compiler info. */
+
+    /* Initialize context with default values or from context info. */
+    mag_ctx_t* ctx = (mag_ctx_t*)(*mag_alloc)(NULL, sizeof(*ctx)); /* Allocate context. */
+    memset(ctx, 0, sizeof(*ctx));
+
+    /* Init allocators */
+    mag_fixed_intrusive_pool_init(&ctx->tensor_pool, sizeof(mag_tensor_t), __alignof(mag_tensor_t), 4096);
+
+    ctx->tr_id = mag_thread_id(); /* Get thread ID. */
+
+    /* Query and print host system information. */
+    mag_system_host_info_query(ctx);
+    mag_system_host_info_dump(ctx);
+
+    /* Configure configurable media processors */
+    ctx->image_load_fn = &mag_default_image_load_impl;
+    ctx->image_load_free_fn = &mag_default_image_load_free_fn_impl;
+    ctx->image_save_fn = &mag_default_image_save_impl;
+
+    /* Initialize PRNG state. */
+    ctx->prng_algorithm = MAG_PRNG_MERSENNE_TWISTER;
+    mag_prng_init(ctx, ctx->tr_id^(uintptr_t)ctx^(uintptr_t)&ctx); /* Initialize PRNG state. */
+
+    /* Create selected compute device. */
+    ctx->exec_mode = MAG_EXEC_MODE_EAGER;
+    ctx->device_type = device;
+    ctx->device = mag_init_dynamic_device(ctx, &ctx->device_type);
+    mag_log_info("Compute device: %s", ctx->device->name);
+
+    /* Print context initialization time. */
+    mag_log_info("magnetron context initialized in %.05f ms", mag_hpc_clock_elapsed_ms(time_stamp_start));
+    return ctx;
+}
+
+static void mag_tensor_destroy(mag_tensor_t* t);
+
+void mag_ctx_destroy(mag_ctx_t* ctx) {
+#if MAG_SANITIZE_RC /* Check for leaked tensors in RC tracking list and print them */
+    mag_tensor_node_t** head = &ctx->rc_tracked;
+    mag_tensor_node_t* curr = *head;
+    uint32_t nleaked = 0;
+    for (; curr; curr = curr->next, ++nleaked) {
+        mag_log_error("Leaked tensor detected: %p, RCS: %u, RCW: %u, CTOR: %s", curr->tensor, curr->tensor->rcb.rc_strong, curr->tensor->rcb.rc_weak, curr->tensor->rcb.dtor ? "Y" : "N");
+        mag_tensor_print(curr->tensor, true, false);
+    }
+    if (nleaked) mag_log_error("Leaked tensors detected: %u", nleaked);
+    else mag_log_info("No leaked tensors detected.");
+    curr = *head;
+    while (curr) { /* Free tracking list */
+        mag_tensor_node_t* tmp = curr;
+        curr = curr->next;
+        mag_tensor_destroy(tmp->tensor);
+        (*mag_alloc)(tmp, 0);
+    }
+    *head = NULL;
+#endif
+    mag_fixed_intrusive_pool_print_info(&ctx->tensor_pool, "Tensor Hull Pool");
+    mag_fixed_intrusive_pool_destroy(&ctx->tensor_pool);
+    mag_destroy_dynamic_device(ctx->device); ctx->device = NULL;
+    memset(ctx, 0, sizeof(*ctx));
+    (*mag_alloc)(ctx, 0);
+    ctx = NULL;
+    mag_log_info("magnetron context destroyed.");
+}
+
+mag_exec_mode_t mag_ctx_get_exec_mode(const mag_ctx_t* ctx) { return ctx->exec_mode; }
+
+void mag_ctx_set_exec_mode(mag_ctx_t* ctx, mag_exec_mode_t mode) {
+    ctx->exec_mode = mode;
+    mag_log_info("Execution mode set to: %s", mode == MAG_EXEC_MODE_EAGER ? "Eager" : "Deferred");
+}
+
+mag_prng_algorithm_t mag_ctx_get_prng_algorithm(const mag_ctx_t* ctx) { return ctx->prng_algorithm; }
+
+void mag_ctx_set_prng_algorithm(mag_ctx_t* ctx, mag_prng_algorithm_t algorithm, uint64_t seed) {
+    ctx->prng_algorithm = algorithm;
+    mag_prng_init(ctx, seed); /* Reinitialize PRNG state with new seed. */
+}
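A minimal consumer-side sketch of the context lifecycle shown above; only APIs defined in this file are used, and the header name matches the file's own include:

#include <stdio.h>
#include "magnetron.h"

int main(void) {
    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); /* Logs compiler/host info and picks the CPU backend. */
    mag_ctx_set_prng_algorithm(ctx, MAG_PRNG_PCG, 42);            /* Reseed with PCG instead of the Mersenne Twister default. */
    printf("Device: %s\n", mag_ctx_get_compute_device_name(ctx));
    mag_ctx_destroy(ctx); /* Also reports leaked tensors when MAG_SANITIZE_RC is enabled. */
    return 0;
}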
+
+mag_compute_device_type_t mag_ctx_get_compute_device_type(const mag_ctx_t* ctx) { return ctx->device_type; }
+const char* mag_ctx_get_compute_device_name(const mag_ctx_t* ctx) { return ctx->device->name; }
+const char* mag_ctx_get_os_name(const mag_ctx_t* ctx) { return ctx->sys.os_name; }
+const char* mag_ctx_get_cpu_name(const mag_ctx_t* ctx) { return ctx->sys.cpu_name; }
+uint32_t mag_ctx_get_cpu_virtual_cores(const mag_ctx_t* ctx) { return ctx->sys.cpu_virtual_cores; }
+uint32_t mag_ctx_get_cpu_physical_cores(const mag_ctx_t* ctx) { return ctx->sys.cpu_physical_cores; }
+uint32_t mag_ctx_get_cpu_sockets(const mag_ctx_t* ctx) { return ctx->sys.cpu_sockets; }
+uint64_t mag_ctx_get_physical_memory_total(const mag_ctx_t* ctx) { return ctx->sys.phys_mem_total; }
+uint64_t mag_ctx_get_physical_memory_free(const mag_ctx_t* ctx) { return ctx->sys.phys_mem_free; }
+bool mag_ctx_is_numa_system(const mag_ctx_t* ctx) { return false; /* TODO */ }
+size_t mag_ctx_get_total_tensors_created(const mag_ctx_t* ctx) { return 0; /* TODO */ }
+
+static mag_intrusive_chunk* mag_fixed_pool_chunk_new(size_t block_size, size_t block_align, size_t blocks_per_chunk) {
+    size_t cap = blocks_per_chunk*block_size;
+    uintptr_t size = 0;
+    mag_pincr((void**)&size, sizeof(mag_intrusive_chunk), __alignof(mag_intrusive_chunk));
+    mag_pincr((void**)&size, cap, block_align);
+    void* base = (*mag_alloc)(NULL, size), *pos = base;
+    mag_intrusive_chunk* chunk = mag_pincr(&pos, sizeof(mag_intrusive_chunk), __alignof(mag_intrusive_chunk));
+    uint8_t* bot = mag_pincr(&pos, cap, block_align);
+    *chunk = (mag_intrusive_chunk) {
+        .bot = bot,
+        .top = bot+cap,
+        .next = NULL
+    };
+    return chunk;
+}
+
+void mag_fixed_intrusive_pool_init(mag_fixed_intrusive_pool* pool, size_t block_size, size_t block_align, size_t blocks_per_chunk) {
+    mag_assert2(block_size && blocks_per_chunk);
+    mag_intrusive_chunk* chunk = mag_fixed_pool_chunk_new(block_size, block_align, blocks_per_chunk);
+    *pool = (mag_fixed_intrusive_pool) {
+        .block_size = block_size,
+        .block_align = block_align,
+        .blocks_per_chunk = blocks_per_chunk,
+        .chunks = chunk,
+        .chunk_head = chunk,
+        .free_list = NULL,
+        .num_freelist_hits = 0,
+        .num_pool_hits = 0,
+        .num_chunks = 1,
+        .num_allocs = 0
+    };
+}
+
+void* mag_fixed_intrusive_pool_malloc(mag_fixed_intrusive_pool* pool) {
+    ++pool->num_allocs;
+    if (mag_likely(pool->free_list)) { /* 1. Try to pop from free_list (fastest path) */
+        ++pool->num_freelist_hits;
+        void* blk = pool->free_list;
+        pool->free_list = *(void**)blk; /* Next free block is stored at block [0..sizeof(void*)-1] */
+        return blk;
+    }
+    mag_intrusive_chunk* chunk = pool->chunk_head;
+    mag_assert2(chunk);
+    uint8_t* top = chunk->top-pool->block_size;
+    if (mag_likely(top >= chunk->bot)) { /* 2. Allocate from the last pool if possible (fast path) */
+        ++pool->num_pool_hits;
+        chunk->top = top;
+        return top;
+    }
+    mag_intrusive_chunk* new_chunk = mag_fixed_pool_chunk_new(pool->block_size, pool->block_align, pool->blocks_per_chunk); /* 3. Current chunk is exhausted, allocate new (slow path) */
+    chunk->next = new_chunk;
+    pool->chunk_head = new_chunk;
+    new_chunk->top -= pool->block_size;
+    ++pool->num_chunks;
+    return new_chunk->top;
+}
+
+void mag_fixed_intrusive_pool_free(mag_fixed_intrusive_pool* pool, void* blk) { /* Push chunk into free list */
+    *(void**)blk = pool->free_list;
+    pool->free_list = blk;
+}
+
+void mag_fixed_intrusive_pool_destroy(mag_fixed_intrusive_pool* pool) {
+    mag_intrusive_chunk* chunk = pool->chunks;
+    while (chunk) {
+        mag_intrusive_chunk* next = chunk->next;
+        (*mag_alloc)(chunk, 0);
+        chunk = next;
+    }
+    memset(pool, 0, sizeof(*pool));
+}
+
+MAG_COLDPROC void mag_fixed_intrusive_pool_print_info(mag_fixed_intrusive_pool* pool, const char* name) {
+    mag_log_info("Fixed Intrusive Pool: %s", name);
+    mag_log_info(
+        "\tBlock Size: %zu B, Block Align: %zu B, Blocks Per Chunk: %zu B",
+        pool->block_size,
+        pool->block_align,
+        pool->blocks_per_chunk
+    );
+    mag_log_info(
+        "\tChunks: %zu, Allocs: %zu, Freelist Hits: %zu, Num Pool Hits: %zu",
+        (size_t)pool->num_chunks,
+        (size_t)pool->num_allocs,
+        (size_t)pool->num_freelist_hits,
+        (size_t)pool->num_pool_hits
+    );
+    double mem_alloced, pool_mem;
+    const char* mem_unit_alloced, *mem_unit_pool;
+    mag_humanize_memory_size(pool->num_chunks*pool->blocks_per_chunk*pool->block_size, &mem_alloced, &mem_unit_alloced);
+    mag_humanize_memory_size(pool->num_allocs*pool->block_size, &pool_mem, &mem_unit_pool);
+    mag_log_info("\t Real Mem Allocated: %.03f %s, Total Pool Mem %.03f %s", mem_alloced, mem_unit_alloced, pool_mem, mem_unit_pool);
+}
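A usage sketch of the pool: allocation first pops the intrusive free list, then bump-allocates from the current chunk, and only grows a new chunk when exhausted. Assumes the pool type and functions come from magnetron_internal.h:

#include <stddef.h>
#include "magnetron_internal.h" /* Assumed home of mag_fixed_intrusive_pool and friends. */

/* Any fixed-size record works; blocks must be at least sizeof(void*) so the
   free list can thread its next pointer through freed blocks. */
typedef struct my_node { struct my_node* next; double payload; } my_node;

void pool_demo(void) {
    mag_fixed_intrusive_pool pool;
    mag_fixed_intrusive_pool_init(&pool, sizeof(my_node), __alignof(my_node), 4096);
    my_node* a = mag_fixed_intrusive_pool_malloc(&pool); /* Carved from the current chunk (pool hit). */
    mag_fixed_intrusive_pool_free(&pool, a);             /* Pushed onto the intrusive free list. */
    my_node* b = mag_fixed_intrusive_pool_malloc(&pool); /* Recycled: b == a (freelist hit). */
    (void)b;
    mag_fixed_intrusive_pool_print_info(&pool, "demo");  /* Reports 2 allocs, 1 freelist hit, 1 pool hit. */
    mag_fixed_intrusive_pool_destroy(&pool);
}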
+
+void mag_ctx_profile_start_recording(mag_ctx_t* ctx) {
+    if (ctx->profiler_enabled) return;
+    memset(ctx->op_perf_mons_total, 0, sizeof(ctx->op_perf_mons_total));
+    ctx->profiler_enabled = true;
+}
+
+typedef struct mag_op_perf_record_t {
+    mag_op_perf_info_t perf;
+    mag_op_t op;
+} mag_op_perf_record_t;
+
+static int mag_cmp_perf_info(const void* x, const void* y) {
+    const mag_op_perf_record_t* op1 = (const mag_op_perf_record_t *)x;
+    const mag_op_perf_record_t* op2 = (const mag_op_perf_record_t *)y;
+    if (op1->perf.elapsed_ns_acc < op2->perf.elapsed_ns_acc) return 1;
+    if (op1->perf.elapsed_ns_acc > op2->perf.elapsed_ns_acc) return -1;
+    return 0;
+}
+
+void mag_ctx_profile_stop_recording(mag_ctx_t* ctx, const char* export_csv_file) {
+    mag_assert(ctx->profiler_enabled, "Profiler must be enabled to generate report");
+    ctx->profiler_enabled = false;
+    bool csv = export_csv_file && *export_csv_file;
+    if (!csv) {
+        mag_print_separator(stdout);
+        printf("OS/Kernel: %s\n", ctx->sys.os_name);
+        printf("CPU: %s, Virtual Cores: %u, Physical Cores: %u, Sockets: %u\n", ctx->sys.cpu_name, ctx->sys.cpu_virtual_cores, ctx->sys.cpu_physical_cores, ctx->sys.cpu_sockets);
+        #if defined(__x86_64__) || defined(_M_X64) /* Print CPU features for x86-64 platforms. */
+        printf("CPU Features:");
+        for (unsigned i=0, k=0; i < MAG_X86_64_FEATURE__COUNT; ++i) {
+            if (mag_ctx_x86_64_cpu_has_feature(ctx, i)) {
+                if (k++ % 8 == 0) printf("\n\t");
+                printf("%s ", mag_x86_64_feature_names[i]);
+            }
+        }
+        putchar('\n');
+        #endif
+        double mem_total, mem_free, mem_used;
+        const char* mem_unit_total, *mem_unit_free, *mem_unit_used;
+        mag_humanize_memory_size(ctx->sys.phys_mem_total, &mem_total, &mem_unit_total);
+        mag_humanize_memory_size(ctx->sys.phys_mem_free, &mem_free, &mem_unit_free);
+        mag_humanize_memory_size((size_t)llabs((int64_t)ctx->sys.phys_mem_total-(int64_t)ctx->sys.phys_mem_free), &mem_used, &mem_unit_used);
+        double mem_used_percent = fabs((double)(ctx->sys.phys_mem_total-ctx->sys.phys_mem_free))/(double)ctx->sys.phys_mem_total*100.0;
+        printf("Physical memory: %.03f %s, Free: %.03f %s, Used: %.03f %s (%.02f%%)\n", mem_total, mem_unit_total, mem_free, mem_unit_free, mem_used, mem_unit_used, mem_used_percent);
+        mag_print_separator(stdout);
+        printf("%16s %16s %16s %16s %16s\n", "Operation", "Executions", "Usage (%)", "AVG Time (μs)", "Total Time (μs)");
+    }
+    mag_op_perf_record_t sorted[MAG_OP__NUM];
+    uint64_t exec_total = 0;
+    for (mag_op_t op=MAG_OP_NOP; op < MAG_OP__NUM; ++op) { /* Convert to sortable record. */
+        sorted[op].op = op;
+        sorted[op].perf = ctx->op_perf_mons_total[op];
+        exec_total += sorted[op].perf.n_execs;
+    }
+    if (mag_unlikely(!exec_total) && !csv) {
+        printf("\n! No operations profiled. Enable profiler and execute any operation to see results.\n");
+        mag_print_separator(stdout);
+        return;
+    }
+    qsort(sorted, MAG_OP__NUM, sizeof(*sorted), &mag_cmp_perf_info); /* Quicksort by time descending. */
+    FILE* f = NULL;
+    if (csv) {
+        f = mag_fopen(export_csv_file, "wt");
+        mag_assert(f, "Failed to open CSV file: %s", export_csv_file);
+        fprintf(f, "Operation,Executions,Usage,AVG Time,Total Time\n"); /* CSV Header */
+    }
+    for (mag_op_t i=MAG_OP_NOP; i < MAG_OP__NUM; ++i) { /* Format sorted performance data */
+        const mag_op_perf_record_t* info = sorted+i;
+        const mag_op_perf_info_t* perf = &info->perf;
+        if (!perf->n_execs) continue; /* Op never executed.
*/ + const char* op_name = mag_op_meta_of(info->op)->mnemonic; + double perc_exec = (double)perf->n_execs/(double)exec_total * 100.0; + char perc_exec_str[64]; + snprintf(perc_exec_str, sizeof(perc_exec_str), "%.1f", perc_exec); + double avg_time = (double)perf->elapsed_ns_acc/1e3/(double)perf->n_execs; + char avg_time_str[64]; + snprintf(avg_time_str, sizeof(avg_time_str), "%f", avg_time); + double tot_time = (double)perf->elapsed_ns_acc/1e3; + char tot_time_str[64]; + snprintf(tot_time_str, sizeof(tot_time_str), "%f", tot_time); + if (csv) { + fprintf(f, "%s,%" PRIu64 ",%s,%s,%s\n", op_name, perf->n_execs, perc_exec_str, avg_time_str, tot_time_str); + } else { + printf("%16s %16" PRIu64 " %16s%16s%16s\n", op_name, perf->n_execs, perc_exec_str, avg_time_str, tot_time_str); + } + } + if (csv) fclose(f); + else { + putchar('\n'); + printf("Total operations profiled: %" PRIu64 "\n", exec_total); + mag_print_separator(stdout); + } +} + +uint32_t mag_pack_color_u8(uint8_t r, uint8_t g, uint8_t b) { return ((uint32_t)r<<16)|((uint32_t)g<<8)|(uint32_t)b; } +uint32_t mag_pack_color_f32(float r, float g, float b) { + return (((uint32_t)(r*255.0f)&255)<<16)|(((uint32_t)(g*255.0f)&255)<<8)|((uint32_t)(b*255.0f)&255); +} + +const char* mag_device_type_get_name(mag_compute_device_type_t op) { + static const char* const names[MAG_COMPUTE_DEVICE_TYPE__NUM] = { + [MAG_COMPUTE_DEVICE_TYPE_CPU] = "CPU", + [MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA] = "CUDA GPU", + }; + return names[op]; +} + +const mag_dtype_meta_t* mag_dtype_meta_of(mag_dtype_t type) { + static const mag_dtype_meta_t infos[MAG_DTYPE__NUM] = { + [MAG_DTYPE_F32] = { + sizeof(float), + "f32" + }, + }; + return &infos[type]; +} + +/* +** validation error print template +** +** +printf("SHORT ERROR DESCRIPTION" + "ERROR: Failed to execute operation: %s.\n" + " - Input Tensor 1 '%s': MISMATCHES\n" + " - Input Tensor 2 '%s': MISMATCHES\n" + " Hint: ANY HINT FOR USR." +); +*/ + +static bool mag_validate_inputs(mag_op_t op, mag_tensor_t** inputs, uint32_t numin) { + const mag_op_meta_t* meta = mag_op_meta_of(op); + if (mag_unlikely(meta->argcount != numin || numin > MAG_MAX_INPUT_TENSORS)) { + mag_print_separator(stderr); + fprintf(stderr, + "Failed to execute operation: %s.\n" + "ERROR: Operation requires %u input tensors, but %u were provided.\n" + " Hint: Ensure the correct number of input tensors are provided.\n", + meta->mnemonic, meta->argcount, numin + ); + mag_print_separator(stderr); + fputc('\n', stderr); + fflush(stderr); + return false; + } + for (uint32_t i=0; i < meta->argcount; ++i) { + if (mag_unlikely(!inputs[i])) { + mag_print_separator(stderr); + fprintf(stderr, + "Failed to execute operation: %s.\n" + "ERROR: Input tensor %u is NULL.\n" + " Hint: Ensure all input tensors are valid and non-NULL.\n", + meta->mnemonic, i + ); + mag_print_separator(stderr); + fputc('\n', stderr); + fflush(stderr); + return false; + } + } + return true; +} + +static bool mag_validate_op_params(mag_op_t op, const mag_op_param_t* params, uint32_t numparams) { + const mag_op_meta_t* meta = mag_op_meta_of(op); + if (!meta->paramcount) return true; /* No parameters to validate. 
*/
+    if (mag_unlikely(meta->paramcount != numparams || numparams > MAG_MAX_OP_PARAMS)) {
+        mag_print_separator(stderr);
+        fprintf(stderr,
+            "Failed to execute operation: %s.\n"
+            "ERROR: Operation requires %u parameters, but %u were provided.\n"
+            "    Hint: Ensure the correct number of operation parameters are provided.\n",
+            meta->mnemonic, meta->paramcount, numparams
+        );
+        mag_print_separator(stderr);
+        fputc('\n', stderr);
+        fflush(stderr);
+        return false;
+    }
+    if (mag_unlikely(!params)) {
+        mag_print_separator(stderr);
+        fprintf(stderr,
+            "Failed to execute operation: %s.\n"
+            "ERROR: Operation parameters are NULL.\n"
+            "    Hint: Ensure all operation parameters are valid and non-NULL.\n",
+            meta->mnemonic
+        );
+        mag_print_separator(stderr);
+        fputc('\n', stderr);
+        fflush(stderr);
+        return false;
+    }
+    for (uint32_t i=0; i < meta->paramcount; ++i) {
+        if (mag_unlikely(params[i].type != meta->param_types[i])) {
+            mag_print_separator(stderr);
+            fprintf(stderr,
+                "Failed to execute operation: %s.\n"
+                "ERROR: Operation parameter %u type mismatch.\n"
+                "    - Expected type id: %d\n"
+                "    - Provided type id: %d\n"
+                "    Hint: Ensure the correct parameter types are provided.\n",
+                meta->mnemonic, i, meta->param_types[i], params[i].type
+            );
+            mag_print_separator(stderr);
+            fputc('\n', stderr);
+            fflush(stderr);
+            return false; /* A parameter type mismatch must fail validation. */
+        }
+    }
+    return true;
+}
+
+static bool mag_validate_shape_eq(mag_op_t op, const mag_tensor_t* a, const mag_tensor_t* b) {
+    const mag_op_meta_t* meta = mag_op_meta_of(op);
+    if (mag_likely(mag_tensor_is_shape_eq(a, b))) return true;
+    mag_print_separator(stderr);
+    char shape_1[MAG_FMT_DIM_BUF_SIZE];
+    char shape_2[MAG_FMT_DIM_BUF_SIZE];
+    mag_fmt_dims(&shape_1, &a->shape, a->rank);
+    mag_fmt_dims(&shape_2, &b->shape, b->rank);
+    fprintf(stderr,
+        "Failed to execute operation: %s.\n"
+        "ERROR: Input tensor shapes must be equal.\n"
+        "    - Input Tensor 1 '%s' Shape: %s\n"
+        "    - Input Tensor 2 '%s' Shape: %s\n"
+        "    Hint: Adjust tensor shapes using transposition or permutation.\n",
+        meta->mnemonic,
+        a->name, shape_1,
+        b->name, shape_2
+    );
+    mag_print_separator(stderr);
+    fputc('\n', stderr);
+    fflush(stderr);
+    return false;
+}
+
+static bool mag_validate_shape_broadcastable(mag_op_t op, const mag_tensor_t* a, const mag_tensor_t* b) { /* Check if tensor shapes are broadcast-able. (b into a) */
+    const mag_op_meta_t* meta = mag_op_meta_of(op);
+    if (mag_likely(mag_tensor_can_broadcast(b, a))) return true;
+    mag_print_separator(stderr);
+    char shape_1[MAG_FMT_DIM_BUF_SIZE];
+    char shape_2[MAG_FMT_DIM_BUF_SIZE];
+    mag_fmt_dims(&shape_1, &a->shape, a->rank);
+    mag_fmt_dims(&shape_2, &b->shape, b->rank);
+    char broadcast_able_str[MAG_MAX_DIMS*2+4+1] = {0};
+    char* p = broadcast_able_str;
+    *p++ = '[';
+    for (uint32_t i=0; i < MAG_MAX_DIMS; ++i) {
+        *p++ = a->shape[i] % b->shape[i] == 0 ? 'Y' : 'N';
+        *p++ = i < MAG_MAX_DIMS-1 ? ',' : ']';
+    }
+    *p = '\0';
+    fprintf(stderr,
+        "Failed to execute operation: %s.\n"
+        "ERROR: Input tensor shapes must be broadcast-able.\n"
+        "    - Input Tensor 1 '%s' Shape: %s\n"
+        "    - Input Tensor 2 '%s' Shape: %s\n"
+        "    Broadcast-able: %s\n"
+        "    Hint: Adjust tensor shapes using transposition or permutation.\n",
+        meta->mnemonic,
+        a->name, shape_1,
+        b->name, shape_2,
+        broadcast_able_str
+    );
+    mag_print_separator(stderr);
+    fputc('\n', stderr);
+    fflush(stderr);
+    return false;
+}
+
+#define mag_validate_expr_gen(expr, message, ...) \
+    if (mag_unlikely(!(expr))) { \
+        if (1) { \
+            mag_log_error(message, ## __VA_ARGS__); \
+        } \
+        return false; \
+    }
+
+static bool mag_validate_op_unary(mag_op_t op, mag_tensor_t* result, mag_tensor_t** inputs, const mag_op_param_t* params) {
+    if (mag_unlikely(!mag_validate_shape_eq(op, result, inputs[0]))) return false;
+    return true;
+}
+
+static bool mag_validate_op_binary(mag_op_t op, mag_tensor_t* result, mag_tensor_t** inputs, const mag_op_param_t* params) {
+    if (mag_unlikely(!mag_validate_shape_eq(op, result, inputs[0]))) return false;
+    if (mag_unlikely(!mag_validate_shape_broadcastable(op, inputs[0], inputs[1]))) return false;
+    mag_validate_expr_gen(mag_tensor_is_contiguous(result), "Result must be contiguous.");
+    mag_validate_expr_gen(mag_tensor_is_contiguous(inputs[0]), "First tensor must be contiguous.");
+    return true;
+}
+
+static bool mag_validate_op_transpose(mag_op_t op, mag_tensor_t* result, mag_tensor_t** inputs, const mag_op_param_t* params) {
+    return true;
+}
+
+static bool mag_validate_op_scalar(mag_op_t op, mag_tensor_t* result, mag_tensor_t** inputs, const mag_op_param_t* params) {
+    mag_validate_expr_gen(mag_tensor_is_contiguous(inputs[0]), "Mean"); /* TODO */
+    mag_validate_expr_gen(result->shape[0] == 1, "Mean");
+    mag_validate_expr_gen(result->shape[1] == inputs[0]->shape[1], "Mean");
+    mag_validate_expr_gen(result->shape[2] == inputs[0]->shape[2], "Mean");
+    mag_validate_expr_gen(result->shape[3] == inputs[0]->shape[3], "Mean");
+    return true;
+}
+
+static bool mag_validate_op_matmul(mag_op_t op, mag_tensor_t* result, mag_tensor_t** inputs, const mag_op_param_t* params) {
+    mag_validate_expr_gen(inputs[0]->shape[1] == inputs[1]->shape[0], "Input tensor shapes must be compatible for matrix multiplication.");
+    mag_validate_expr_gen(mag_tensor_is_contiguous(inputs[0]), "First tensor must be contiguous.");
+    mag_validate_expr_gen(mag_tensor_is_contiguous(inputs[1]), "Second tensor must be contiguous.");
+    mag_validate_expr_gen(inputs[1]->shape[2] % inputs[0]->shape[2] == 0, "Result tensor shape mismatch.");
+    mag_validate_expr_gen(inputs[1]->shape[3] % inputs[0]->shape[3] == 0, "Result tensor shape mismatch.");
+    return true;
+}
+
+static mag_tensor_t* mag_tensor_create(mag_ctx_t* ctx, mag_dtype_t type, const int64_t* dims, int64_t rank, mag_tensor_t* view, size_t view_offs);
+
+static mag_tensor_t* mag_result_constructor_routine_isomorph(mag_tensor_t** inputs, const mag_op_param_t* params) {
+    (void)params;
+    return mag_tensor_create(inputs[0]->ctx, inputs[0]->dtype, inputs[0]->shape, inputs[0]->rank, NULL, 0);
+}
+
+static mag_tensor_t* mag_result_constructor_routine_view(mag_tensor_t** inputs, const mag_op_param_t* params) {
+    (void)params;
+    return mag_tensor_create(inputs[0]->ctx, inputs[0]->dtype, inputs[0]->shape, MAG_MAX_DIMS, inputs[0], 0);
+}
+
+static mag_tensor_t* mag_result_constructor_routine_scalar(mag_tensor_t** inputs, const mag_op_param_t* params) {
+    int64_t shape[MAG_MAX_DIMS];
+    *shape = 1;
+    #pragma GCC unroll 5
+    for (uint32_t i=1; i < MAG_MAX_DIMS; ++i)
+        shape[i] = inputs[0]->shape[i];
+    return mag_tensor_create(inputs[0]->ctx, inputs[0]->dtype, shape, MAG_MAX_DIMS, NULL, 0);
+}
+
+static mag_tensor_t* mag_result_constructor_routine_transposed(mag_tensor_t** inputs, const mag_op_param_t* params) {
+    mag_tensor_t* transposed = mag_result_constructor_routine_view(inputs, params);
+    mag_swap(int64_t, transposed->shape[0], transposed->shape[1]);
+    mag_swap(int64_t, transposed->strides[0], transposed->strides[1]);
+    return transposed;
+}
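The transposed constructor returns a view and merely swaps shape[0]/shape[1] and strides[0]/strides[1]; no data moves. A self-contained sketch of why the stride swap is sufficient for a row-major buffer:

#include <stdint.h>
#include <stdio.h>

int main(void) {
    float data[2*3] = {1, 2, 3, 4, 5, 6};                 /* 2x3, row-major: strides = {3, 1} */
    int64_t shape[2] = {2, 3}, strides[2] = {3, 1};
    /* Transposed view: 3x2 with strides {1, 3} -- same buffer, zero copies. */
    int64_t tshape[2] = {shape[1], shape[0]}, tstrides[2] = {strides[1], strides[0]};
    for (int64_t i=0; i < tshape[0]; ++i) {
        for (int64_t j=0; j < tshape[1]; ++j)
            printf("%.0f ", data[i*tstrides[0] + j*tstrides[1]]);
        putchar('\n');
    }
    return 0; /* Prints: 1 4 / 2 5 / 3 6 */
}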
+
+static mag_tensor_t* mag_result_constructor_routine_permuted(mag_tensor_t** inputs, const mag_op_param_t* params) {
+    mag_assert2(params != NULL); /* TODO */
+    mag_tensor_t* permuted = mag_result_constructor_routine_view(inputs, params);
+    uint32_t axes[MAG_MAX_DIMS];
+    for (uint32_t i = 0; i < MAG_MAX_DIMS; ++i) /* Unpack axes */
+        axes[i] = params[i].x.u32;
+    for (uint32_t i = 0; i < MAG_MAX_DIMS; ++i) { /* Check that all axes are unique */
+        for (uint32_t j = i+1; j < MAG_MAX_DIMS; ++j)
+            mag_assert(axes[i] != axes[j], "Axes must be unique: %u != %u", axes[i], axes[j]); /* %u: the axes are uint32_t. */
+    }
+    for (uint32_t i=0; i < MAG_MAX_DIMS; ++i) { /* Permute shape and strides */
+        mag_assert2(axes[i] < MAG_MAX_DIMS); /* axes[i] is unsigned, so only the upper bound can be violated. */
+        permuted->shape[axes[i]] = inputs[0]->shape[i];
+        permuted->strides[axes[i]] = inputs[0]->strides[i];
+    }
+    return permuted;
+}
+
+static mag_tensor_t* mag_result_constructor_routine_matmul(mag_tensor_t** inputs, const mag_op_param_t* params) { /* MxR = MxN * NxR */
+    (void)params;
+    int64_t shape[MAG_MAX_DIMS];
+    shape[0] = inputs[0]->shape[0]; /* M */
+    shape[1] = inputs[1]->shape[1]; /* R */
+    return mag_tensor_create(inputs[0]->ctx, MAG_DTYPE_F32, shape, 2, NULL, 0);
+}
+
+const mag_op_meta_t* mag_op_meta_of(mag_op_t type) {
+    static const mag_op_meta_t infos[MAG_OP__NUM] = {
+        [MAG_OP_NOP] = {
+            .mnemonic = "nop",
+            .argcount = 0,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = NULL,
+            .validator = NULL
+        },
+        [MAG_OP_CLONE] = {
+            .mnemonic = "clone",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_isomorph,
+            .validator = &mag_validate_op_unary
+        },
+        [MAG_OP_VIEW] = {
+            .mnemonic = "view",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_view,
+            .validator = &mag_validate_op_unary
+        },
+        [MAG_OP_TRANSPOSE] = {
+            .mnemonic = "transpose",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_transposed,
+            .validator = &mag_validate_op_transpose
+        },
+        [MAG_OP_PERMUTE] = {
+            .mnemonic = "permute",
+            .argcount = 1,
+            .paramcount = MAG_MAX_DIMS,
+            .param_types = {
+                MAG_OP_TPARAM_U32,
+                MAG_OP_TPARAM_U32,
+                MAG_OP_TPARAM_U32,
+                MAG_OP_TPARAM_U32,
+                MAG_OP_TPARAM_U32,
+                MAG_OP_TPARAM_U32,
+            },
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_permuted,
+            .validator = &mag_validate_op_transpose
+        },
+        [MAG_OP_MEAN] = {
+            .mnemonic = "mean",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_scalar,
+            .validator = &mag_validate_op_scalar
+        },
+        [MAG_OP_MIN] = {
+            .mnemonic = "min",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_scalar,
+            .validator = &mag_validate_op_scalar
+        },
+        [MAG_OP_MAX] = {
+            .mnemonic = "max",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_scalar,
+            .validator = &mag_validate_op_scalar
+        },
+        [MAG_OP_SUM] = {
+            .mnemonic = "sum",
+            .argcount = 1,
+            .paramcount = 0,
+            .param_types = {MAG_OP_TPARAM_NONE},
+            .inplace = false,
+            .r_alloc = &mag_result_constructor_routine_scalar,
+            .validator = &mag_validate_op_scalar
+        },
+        [MAG_OP_ABS] = {
+            .mnemonic = "abs",
.argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_NEG] = { + .mnemonic = "neg", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_LOG] = { + .mnemonic = "log", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SQR] = { + .mnemonic = "sqr", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SQRT] = { + .mnemonic = "sqrt", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SIN] = { + .mnemonic = "sin", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_COS] = { + .mnemonic = "cos", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_STEP] = { + .mnemonic = "step", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SOFTMAX] = { + .mnemonic = "softmax", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SOFTMAX_DV] = { + .mnemonic = "softmax_dv", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SIGMOID] = { + .mnemonic = "sigmoid", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SIGMOID_DV] = { + .mnemonic = "sigmoid_dv", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_HARD_SIGMOID] = { + .mnemonic = "hard_sigmoid", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SILU] = { + .mnemonic = "silu", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SILU_DV] = { + .mnemonic = "silu_dv", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_TANH] = { + .mnemonic = 
"tanh", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_TANH_DV] = { + .mnemonic = "tanh_dv", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_RELU] = { + .mnemonic = "relu", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_RELU_DV] = { + .mnemonic = "relu_dv", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_GELU] = { + .mnemonic = "gelu", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_GELU_DV] = { + .mnemonic = "gelu_dv", + .argcount = 1, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_ADD] = { + .mnemonic = "add", + .argcount = 2, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_binary + }, + [MAG_OP_SUB] = { + .mnemonic = "sub", + .argcount = 2, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_binary + }, + [MAG_OP_MUL] = { + .mnemonic = "mul", + .argcount = 2, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_binary + }, + [MAG_OP_DIV] = { + .mnemonic = "div", + .argcount = 2, + .paramcount = 0, + .param_types = {MAG_OP_TPARAM_NONE}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_binary + }, + [MAG_OP_ADDS] = { + .mnemonic = "adds", + .argcount = 1, + .paramcount = 1, + .param_types = {MAG_OP_TPARAM_F32}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_SUBS] = { + .mnemonic = "subs", + .argcount = 1, + .paramcount = 1, + .param_types = {MAG_OP_TPARAM_F32}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_MULS] = { + .mnemonic = "muls", + .argcount = 1, + .paramcount = 1, + .param_types = {MAG_OP_TPARAM_F32}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_DIVS] = { + .mnemonic = "divs", + .argcount = 1, + .paramcount = 1, + .param_types = {MAG_OP_TPARAM_F32}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_isomorph, + .validator = &mag_validate_op_unary + }, + [MAG_OP_MATMUL] = { + .mnemonic = "matmul", + .argcount = 2, + .paramcount = 0, + .param_types = {}, + .inplace = true, + .r_alloc = &mag_result_constructor_routine_matmul, + .validator = &mag_validate_op_matmul + } + }; + return infos+type; +} + +#undef mag_validate_inputs + +int64_t 
mag_tensor_data_size(const mag_tensor_t* t) { return t->numel*mag_dtype_meta_of(t->dtype)->size; } +int64_t mag_tensor_numel(const mag_tensor_t* t) { return t->numel; } +int64_t mag_tensor_num_rows(const mag_tensor_t* t) { + mag_static_assert(MAG_MAX_DIMS == 6); + mag_load_local_storage_group(t, d, shape); + return d1*d2*d3*d4*d5; +} +int64_t mag_tensor_num_cols(const mag_tensor_t* t) { return *t->shape; } + +#if MAG_SANITIZE_RC + static void mag_tensor_sanitize_dtor(mag_tensor_t* t) { + (void)t; + } +#endif + +static mag_tensor_t* mag_tensor_create(mag_ctx_t* ctx, mag_dtype_t type, const int64_t* dims, int64_t rank, mag_tensor_t* view, size_t view_offs) { + uintptr_t tr_id = mag_thread_id(); + mag_assert(tr_id == ctx->tr_id, "%" PRIx64 " != %" PRIx64 " Tensor must be created on the same thread as the context.", tr_id, ctx->tr_id); + mag_assert(dims != NULL && rank >= 0 && rank <= MAG_MAX_DIMS, "Rank must be within [0, %d]", MAG_MAX_DIMS); + mag_assert2(view_offs == 0); /* NYI. TODO */ + if (view) { + if (view->view_uplink) { /* Traverse view chain and accumulate offset */ + view_offs += view->view_offs; + view = view->view_uplink; + } + mag_tensor_incref(view); /* Increment view tensor strong RC */ + } + int64_t dts = mag_dtype_meta_of(type)->size; + int64_t numel = 1; + for (int64_t i=0; i < rank; ++i) /* Calculate buffer size and check for overflow. */ + mag_assert2(dims[i] > 0 && !mag_imull64_ov(dims[i], numel, &numel)); /* Overflow in buffer size. Max: INT64_MAX. Reduce dimensions. */ + int64_t numbytes = numel*dts; + mag_assert2(!view || !numbytes || numbytes + view_offs <= mag_tensor_data_size(view)); /* Slice must be within viewed tensor data range. */ /* Allocate memory for tensor struct on CPU RAM. */ + mag_tensor_t* t = mag_fixed_intrusive_pool_malloc(&ctx->tensor_pool); + memset(t, 0, sizeof(*t)); + *t = (mag_tensor_t) { + .rcb = { + .rc_strong = 0, + .rc_weak = 0, + #if MAG_SANITIZE_RC + .dtor = &mag_tensor_sanitize_dtor + #endif + }, + .ctx = ctx, + .rank = rank, + .shape = {0}, + .strides = {0}, + .dtype = type, + .storage = {0}, + .numel = numel, + .flags = view ? MAG_TFLAG_VIEW : MAG_TFLAG_OWNER, + .op = MAG_OP_NOP, + .op_inputs = {0}, + .op_params = {{0}}, + .view_uplink = view, + .view_offs = view_offs, + .grad = NULL, + .pmon = {0}, + .name = "", + .ud = NULL + }; + mag_tensor_incref(t); /* First strong RC=1 */ + /* Allocate device memory */ + mag_compute_device_t* dvc = ctx->device; + void (*allocator)(mag_compute_device_t*, mag_storage_buffer_t*, size_t, size_t) = dvc->alloc_storage; + if (view) t->storage = view->storage; /* Reference memory from view */ + else (*allocator)(dvc, &t->storage, numbytes, dts); /* Allocate new device memory */ + #pragma GCC unroll 6 + for (uint32_t i=0; i < MAG_MAX_DIMS; ++i) /* Copy dimensions and set unused to identity. */ + t->shape[i] = i < rank ? dims[i] : 1; + *t->strides = 1; + #pragma GCC unroll 5 + for (uint32_t i=1; i < MAG_MAX_DIMS; ++i) /* Calculate strides and check for overflow.
*/ + mag_assert2(!mag_imull64_ov(t->strides[i-1], t->shape[i-1], t->strides+i)); +#if MAG_SANITIZE_RC /* If tensor RC sanitize is enabled, insert into tracking list */ + mag_tensor_node_t** head = &ctx->rc_tracked; + mag_tensor_node_t* node = (*mag_alloc)(NULL, sizeof(*node)); + *node = (mag_tensor_node_t) { + .tensor = t, + .next = *head + }; + *head = node; +#endif + return t; +} + +static void mag_tensor_destroy(mag_tensor_t* t) { + mag_ctx_t* ctx = t->ctx; +#if MAG_SANITIZE_RC /* If tensor RC sanitize is enabled, invoke destructor and erase from tracking list */ + void (*dtor)(mag_tensor_t*) = t->rcb.dtor; /* Invoke Debug destructor. */ + if (dtor) (*dtor)(t); + mag_tensor_node_t** head = &ctx->rc_tracked; + if (*head) { + mag_tensor_node_t* curr = *head, *prev = NULL; + if (curr->tensor == t) { /* Head itself holds key */ + *head = curr->next; + (*mag_alloc)(curr, 0); + } else { + while (curr && curr->tensor != t) { /* Find node */ + prev = curr; + curr = curr->next; + } + if (curr) { /* Found node */ + prev->next = curr->next; + (*mag_alloc)(curr, 0); /* Free node */ + } + } + } +#endif + if (t->flags & MAG_TFLAG_OWNER) { /* Free device memory if tensor owns it. */ + mag_compute_device_t* dvc = t->ctx->device; + void (*dtor)(mag_compute_device_t*, mag_storage_buffer_t*) = dvc->free_storage; + (*dtor)(dvc, &t->storage); + } + mag_fixed_intrusive_pool_free(&ctx->tensor_pool, t); +} + +void mag_tensor_incref(mag_tensor_t* t) { + mag_assert2(t->rcb.rc_strong < UINT32_MAX); + ++t->rcb.rc_strong; +} + +bool mag_tensor_decref(mag_tensor_t* t) { + if (t->view_uplink) { /* If tensor is a view, decrement base RC and free tensor chain */ + mag_tensor_decref(t->view_uplink); + } + if (!--t->rcb.rc_strong) { /* Strong RC reaches zero, destroy. */ + mag_tensor_destroy(t); + return true; + } + return false; +} + +mag_tensor_t* mag_tensor_create_1d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1) { + return mag_tensor_create(ctx, type, (int64_t[]) {d1}, 1, NULL, 0); +} + +mag_tensor_t* mag_tensor_create_2d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2) { + return mag_tensor_create(ctx, type, (int64_t[]) {d1, d2}, 2, NULL, 0); +} + +mag_tensor_t* mag_tensor_create_3d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3) { + return mag_tensor_create(ctx, type, (int64_t[]) {d1, d2, d3}, 3, NULL, 0); +} + +mag_tensor_t* mag_tensor_create_4d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4) { + return mag_tensor_create(ctx, type, (int64_t[]) {d1, d2, d3, d4}, 4, NULL, 0); +} + +mag_tensor_t* mag_tensor_create_5d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5) { + return mag_tensor_create(ctx, type, (int64_t[]) {d1, d2, d3, d4, d5}, 5, NULL, 0); +} + +mag_tensor_t* mag_tensor_create_6d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5, int64_t d6) { + return mag_tensor_create(ctx, type, (int64_t[]) {d1, d2, d3, d4, d5, d6}, 6, NULL, 0); +} + +static void MAG_HOTPROC mag_op_exec(mag_tensor_t* R, mag_compute_device_t* dvc, mag_graph_eval_order_t ord) { + mag_perf_mon_t* pmon = &R->pmon; + mag_op_perf_info_t (*pmon_ops)[MAG_OP__NUM] = &R->ctx->op_perf_mons_total; + mag_op_perf_info_t* pmon_op = (*pmon_ops)+R->op; + uint64_t start = R->ctx->profiler_enabled ? mag_hpc_clock_ns() : 0; /* Profiling monitoring */ + void (*exec)(mag_compute_device_t*, mag_tensor_t*) = ord == MAG_GRAPH_EVAL_ORDER_FORWARD ? 
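+    /* Reference-counting sketch (using the constructors above; illustrative only):
+           mag_tensor_t* t = mag_tensor_create_1d(ctx, MAG_DTYPE_F32, 4);   RC=1
+           mag_tensor_incref(t);                                            RC=2
+           mag_tensor_decref(t);                                            RC=1
+           mag_tensor_decref(t);                                            RC=0, destroyed
+       Views also hold a strong reference on their base tensor, so dropping a
+       view decrements the base as well. */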
dvc->eager_exec_fwd : dvc->eager_exec_bwd; + (*exec)(dvc, R); /* Dispatch to backend. */ + if (!R->ctx->profiler_enabled) return; /* Profiling disabled. */ + pmon->elapsed_ns = mag_hpc_clock_elapsed_ns(start); + pmon->elapsed_ns_acc += pmon->elapsed_ns; + pmon_op->elapsed_ns_acc += pmon->elapsed_ns; + ++pmon->n_execs; + ++pmon_op->n_execs; +} + +static mag_tensor_t* MAG_HOTPROC mag_tensor_operator( + mag_ctx_t* ctx, + mag_op_t op, + bool inplace, + mag_tensor_t** inputs, + uint32_t numin, + const mag_op_param_t* params, + uint32_t numparams +) { + /* Validate inputs and params first */ + mag_assert2(op != MAG_OP_NOP); + mag_assert(inputs && mag_validate_inputs(op, inputs, numin), "Invalid input tensors for operation %s.", mag_op_meta_of(op)->mnemonic); + mag_assert(mag_validate_op_params(op, params, numparams), "Invalid parameters for operation %s.", mag_op_meta_of(op)->mnemonic); + + const mag_op_meta_t* meta = mag_op_meta_of(op); + mag_graph_eval_order_t gra = MAG_GRA_FWD; /* TODO */ + mag_tensor_t* (*r_alloc)(mag_tensor_t**, const mag_op_param_t*) = meta->r_alloc; + bool (*validate_op)(mag_op_t, mag_tensor_t*, mag_tensor_t**, const mag_op_param_t*) = meta->validator; + mag_tensor_t* R = (inplace && numin && meta->inplace) /* Inplace requested? */ + ? mag_tensor_create(ctx, (*inputs)->dtype, (*inputs)->shape, (*inputs)->rank, *inputs, 0) /* View R <- X for inplace aliasing op. */ + : (*r_alloc)(inputs, params); /* Construct new result tensor. */ + if (mag_unlikely(!(*validate_op)(op, R, inputs, params))) return NULL; /* Validation failed. */ + mag_tensor_t* grad = NULL; /* ∇ᵦL = ∂L/∂B - Upper gradient tensor. */ /* TODO */ + if (gra == MAG_GRA_BWD && grad) { + R->grad = R->grad /* ∇ₐL = ∑ᵢ (∂L/∂Bᵢ) ⋅ (∂Bᵢ/∂A) - Chain rule accumulate. */ + ? mag_add(R->grad, grad) /* ∇ₐL <- ∇ₐL + ∇ᵦL */ + : mag_clone(grad); /* ∇ₐL <- ∂L/∂B */ + } + mag_assert2(R->op == MAG_OP_NOP); + R->op = op; /* Set operation for deferred execution mode. */ + for (uint32_t i=0; i < numin; ++i) { /* Set input tensors and flags. */ + R->op_inputs[i] = inputs[i]; + } + if (params) memcpy(R->op_params, params, sizeof(*params)*numparams); /* Copy all operation parameters, not just the first. */ + if (ctx->exec_mode == MAG_EXEC_MODE_EAGER) { /* In eager execution mode, we execute immediately. */ + mag_op_exec(R, ctx->device, gra); /* Execute the operation immediately.
*/ + } + return R; +} + +mag_tensor_t* mag_clone(mag_tensor_t* x) { + return mag_tensor_operator(x->ctx, MAG_OP_CLONE, false, &x, 1, NULL, 0); +} + +mag_tensor_t* mag_view(mag_tensor_t* x) { + return mag_tensor_operator(x->ctx, MAG_OP_VIEW, false, &x, 1, NULL, 0); +} + +mag_tensor_t* mag_transpose(mag_tensor_t* x) { + return mag_tensor_operator(x->ctx, MAG_OP_TRANSPOSE, false, &x, 1, NULL, 0); +} + +mag_tensor_t* mag_permute(mag_tensor_t* x, uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3, uint32_t d4, uint32_t d5) { + mag_op_param_t params[MAG_MAX_OP_PARAMS] = { + {.type=MAG_OP_TPARAM_U32, .x.u32=d0}, + {.type=MAG_OP_TPARAM_U32, .x.u32=d1}, + {.type=MAG_OP_TPARAM_U32, .x.u32=d2}, + {.type=MAG_OP_TPARAM_U32, .x.u32=d3}, + {.type=MAG_OP_TPARAM_U32, .x.u32=d4}, + {.type=MAG_OP_TPARAM_U32, .x.u32=d5} + }; + return mag_tensor_operator(x->ctx, MAG_OP_PERMUTE, false, &x, 1, params, sizeof(params)/sizeof(*params)); +} + +mag_tensor_t* mag_mean(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_MEAN, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_min(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_MIN, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_max(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_MAX, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_sum(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SUM, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_abs(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_ABS, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_abs_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_ABS, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_neg(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_NEG, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_neg_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_NEG, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_log(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_LOG, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_log_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_LOG, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_sqr(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SQR, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_sqr_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SQR, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_sqrt(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SQRT, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_sqrt_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SQRT, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_sin(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SIN, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_sin_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SIN, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_cos(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_COS, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_cos_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_COS, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_step(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_STEP, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_step_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_STEP, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_softmax(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SOFTMAX, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_softmax_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SOFTMAX, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_softmax_dv(mag_tensor_t* x) { return 
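+    /* Convention sketch: mag_<op> allocates a fresh result, while the trailing
+       underscore variant mag_<op>_ requests in-place execution through a view
+       of its input, e.g.
+           mag_tensor_t* y = mag_relu(x);     new tensor
+           mag_tensor_t* z = mag_relu_(x);    z aliases x's storage
+       Both returned handles are owned by the caller and must be decref'd. */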
mag_tensor_operator(x->ctx, MAG_OP_SOFTMAX_DV, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_softmax_dv_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SOFTMAX_DV, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_sigmoid(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SIGMOID, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_sigmoid_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SIGMOID, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_sigmoid_dv(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SIGMOID_DV, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_sigmoid_dv_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SIGMOID_DV, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_hard_sigmoid(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_HARD_SIGMOID, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_hard_sigmoid_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_HARD_SIGMOID, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_silu(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SILU, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_silu_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SILU, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_silu_dv(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SILU_DV, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_silu_dv_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_SILU_DV, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_tanh(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_TANH, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_tanh_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_TANH, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_tanh_dv(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_TANH_DV, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_tanh_dv_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_TANH_DV, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_relu(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_RELU, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_relu_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_RELU, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_relu_dv(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_RELU_DV, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_relu_dv_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_RELU_DV, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_gelu(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_GELU, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_gelu_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_GELU, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_gelu_dv(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_GELU_DV, false, &x, 1, NULL, 0); } +mag_tensor_t* mag_gelu_dv_(mag_tensor_t* x) { return mag_tensor_operator(x->ctx, MAG_OP_GELU_DV, true, &x, 1, NULL, 0); } +mag_tensor_t* mag_add(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_ADD, false, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_add_(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_ADD, true, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_sub(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_SUB, false, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_sub_(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_SUB, true, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_mul(mag_tensor_t* x, 
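+/* Binary entry points combine two tensors on the same context, e.g. (a sketch):
+       mag_tensor_t* c = mag_add(a, b);     c holds a + b, shaped like a
+       mag_tensor_t* d = mag_mul_(a, b);    a is scaled by b in place
+   Operand shapes must satisfy the binary validator. */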
mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_MUL, false, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_mul_(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_MUL, true, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_div(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_DIV, false, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } +mag_tensor_t* mag_div_(mag_tensor_t* x, mag_tensor_t* y) { return mag_tensor_operator(x->ctx, MAG_OP_DIV, true, (mag_tensor_t*[]){x, y}, 2, NULL, 0); } + +mag_tensor_t* mag_adds(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_ADDS, false, &x, 1, &param, 1); +} + +mag_tensor_t* mag_adds_(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_ADDS, true, &x, 1, &param, 1); +} + +mag_tensor_t* mag_subs(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_SUBS, false, &x, 1, &param, 1); +} + +mag_tensor_t* mag_subs_(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_SUBS, true, &x, 1, &param, 1); +} + +mag_tensor_t* mag_muls(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_MULS, false, &x, 1, &param, 1); +} + +mag_tensor_t* mag_muls_(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_MULS, true, &x, 1, &param, 1); +} + +mag_tensor_t* mag_divs(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_DIVS, false, &x, 1, &param, 1); +} + +mag_tensor_t* mag_divs_(mag_tensor_t* x, float xi) { + mag_op_param_t param = {.type=MAG_OP_TPARAM_F32, .x.f32=xi}; + return mag_tensor_operator(x->ctx, MAG_OP_DIVS, true, &x, 1, &param, 1); +} + +mag_tensor_t* mag_matmul(mag_tensor_t* x, mag_tensor_t* y) { + return mag_tensor_operator(x->ctx, MAG_OP_MATMUL, false, (mag_tensor_t*[]){x, y}, 2, NULL, 0); +} + +static MAG_AINLINE void mag_tensor_virtual_to_physical_index(const mag_tensor_t* t, int64_t v_idx, int64_t(*p_idx)[MAG_MAX_DIMS]) { + mag_static_assert(MAG_MAX_DIMS == 6); + mag_load_local_storage_group(t, d, shape); + (*p_idx)[5] = v_idx / (d4*d3*d2*d1*d0); + (*p_idx)[4] = (v_idx - (*p_idx)[5]*d4*d3*d2*d1*d0) / (d3*d2*d1*d0); + (*p_idx)[3] = (v_idx - (*p_idx)[5]*d4*d3*d2*d1*d0 - (*p_idx)[4]*d3*d2*d1*d0) / (d2*d1*d0); + (*p_idx)[2] = (v_idx - (*p_idx)[5]*d4*d3*d2*d1*d0 - (*p_idx)[4]*d3*d2*d1*d0 - (*p_idx)[3]*d2*d1*d0) / (d1*d0); + (*p_idx)[1] = (v_idx - (*p_idx)[5]*d4*d3*d2*d1*d0 - (*p_idx)[4]*d3*d2*d1*d0 - (*p_idx)[3]*d2*d1*d0 - (*p_idx)[2]*d1*d0) / d0; + (*p_idx)[0] = v_idx - (*p_idx)[5]*d4*d3*d2*d1*d0 - (*p_idx)[4]*d3*d2*d1*d0 - (*p_idx)[3]*d2*d1*d0 - (*p_idx)[2]*d1*d0 - (*p_idx)[1]*d0; +} + +static MAG_AINLINE int64_t mag_tensor_physical_to_virtual_index(const mag_tensor_t* t, const int64_t (*p_idx)[MAG_MAX_DIMS]) { + mag_static_assert(MAG_MAX_DIMS == 6); + mag_load_local_storage_group(t, s, strides); + mag_load_local_storage_group_arr(*p_idx, i); + return s0*i0 + s1*i1 + s2*i2 + s3*i3 + s4*i4 + s5*i5; +} + +mag_tensor_t* mag_tensor_get_arg(const mag_tensor_t* t, size_t slot) { + mag_assert(slot < MAG_MAX_INPUT_TENSORS, "Slot must be within [0,
%d)", MAG_MAX_INPUT_TENSORS); + return t->op_inputs[slot]; +} + +void mag_tensor_set_arg(mag_tensor_t* t, size_t slot, mag_tensor_t* arg) { + mag_assert(slot < MAG_MAX_INPUT_TENSORS, "Slot must be within [0, %d)", MAG_MAX_INPUT_TENSORS); + mag_assert(t->op_inputs[slot] == NULL, "Argument at slot #%zu already set", slot); + t->op_inputs[slot] = arg; +} + +void mag_tensor_copy_buffer_from(mag_tensor_t* t, const void* data, size_t size) { + mag_assert(size == (size_t) mag_tensor_data_size(t), "Buffer size mismatch: %zu != %lld", size, mag_tensor_data_size(t)); + mag_storage_buffer_t* sto = &t->storage; + (*sto->cpy_host_device)(sto, 0, data, size); +} + +void mag_tensor_fill(mag_tensor_t* t, float x) { + if (x == 0.0f) { + mag_storage_buffer_t* sto = &t->storage; + (*sto->set)(sto, 0, 0); /* Zero out the buffer. */ + return; + } + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + + switch (t->dtype) { + case MAG_DTYPE_F32: { + int64_t n = mag_tensor_numel(t); + float* buf = (float*)t->storage.base; + for (int64_t i=0; i < n; ++i) buf[i] = x; + } break; + default: mag_panic("Unsupported DType: %d", t->dtype); + } +} + +void mag_tensor_fill_random_uniform(mag_tensor_t* t, float min, float max) { + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + switch (t->dtype) { + case MAG_DTYPE_F32: { + int64_t n = mag_tensor_numel(t); + float* buf = (float*)t->storage.base; + mag_prng_generate_n(t->ctx, buf, n, min, max); /* Generate uniform random numbers. */ + } break; + default: mag_panic("Unsupported DType: %d", t->dtype); + } +} + +void mag_tensor_fill_random_normal(mag_tensor_t* t, float mean, float stddev) { + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + switch (t->dtype) { + case MAG_DTYPE_F32: { + int64_t n = mag_tensor_numel(t); + mag_assert((n & 1) == 0, "Number of elements must be even"); + float* buf = (float*)t->storage.base; + mag_prng_generate_n(t->ctx, buf, n, 0.0f, 1.0f); /* Generate uniform random numbers. */ + for (int64_t i=0; i < n; i += 2) { /* Map uniform to normal distribution using Box-Muller transform. 
*/ + float* u1 = buf+i; + float* u2 = buf+i+1; + float mag = stddev*sqrtf(-2.0f*logf(*u1)); + float y0 = mag*cosf((float)(2.0*M_PI)*(*u2)) + mean; + float y1 = mag*sinf((float)(2.0*M_PI)*(*u2)) + mean; + *u1 = y0; + *u2 = y1; + } + } break; + default: mag_panic("Unsupported DType: %d", t->dtype); + } +} + +uint64_t mag_tensor_get_packed_refcounts(const mag_tensor_t* t) { + return (uint64_t)t->rcb.rc_strong|((uint64_t)t->rcb.rc_weak << 32); +} + +void mag_tensor_retain(mag_tensor_t* t) { + ++t->rcb.rc_strong; +} + +size_t mag_tensor_get_memory_usage(const mag_tensor_t* t) { + return sizeof(*t) + mag_tensor_data_size(t); +} + +static void mag_print_tensor_recursive(FILE* f, const mag_tensor_t* t, int64_t (*idx)[MAG_MAX_DIMS], const int64_t (*stri)[MAG_MAX_DIMS], int64_t curr_dim, int64_t total_dims, int indent) { + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + int64_t dim_size = t->shape[curr_dim]; + mag_load_local_storage_group_arr(*stri, s); + if (curr_dim == total_dims - 1) { + fprintf(f, "%*s[", indent, ""); + for (int64_t i = 0; i < dim_size; ++i) { + (*idx)[curr_dim] = i; + mag_load_local_storage_group_arr(*idx, i); + float val = *((const float*)t->storage.base + i0*s0 + i1*s1 + i2*s2 + i3*s3 + i4*s4 + i5*s5); + char fmt_buf[128]; + *mag_fmt_f64(MAG_FMT_G14, (double)val, fmt_buf) = '\0'; + fprintf(f, "%s", fmt_buf); + if (i < dim_size - 1) { + fprintf(f, " "); + } + } + fprintf(f, "]"); + } else { + fprintf(f, "%*s[\n", indent, ""); + for (int64_t i = 0; i < dim_size; ++i) { + (*idx)[curr_dim] = i; + mag_print_tensor_recursive(f, t, idx, stri, curr_dim + 1, total_dims, indent + 1); + if (i < dim_size - 1) { + fprintf(f, ",\n"); + } else { + fprintf(f, "\n"); + } + } + fprintf(f, "%*s]\n", indent, ""); + } +} + + +void mag_tensor_print(const mag_tensor_t* t, bool with_header, bool with_data) { + mag_assert(t->dtype == MAG_DTYPE_F32, "Tensor must be F32"); + mag_assert2(with_header || with_data); + mag_load_local_storage_group(t, x_d, shape); + mag_load_local_storage_group(t, x_s, strides); + FILE* f = stdout; + if (with_header) { + double buf_size_cvt = 0.0; + const char* buf_size_unit = NULL; + mag_humanize_memory_size(mag_tensor_get_memory_usage(t), &buf_size_cvt, &buf_size_unit); + char shape[MAG_FMT_DIM_BUF_SIZE]; + char strides[MAG_FMT_DIM_BUF_SIZE]; + mag_fmt_dims(&shape, &t->shape, t->rank); + mag_fmt_dims(&strides, &t->strides, MAG_MAX_DIMS); + static const char* flag_abbrs = "OVGE"; + mag_assert2(strlen(flag_abbrs) == MAG_TFLAG_LEN); + char flags[MAG_TFLAG_LEN+1] = {0}; + for (uint32_t i=0, k=0; i < MAG_TFLAG_LEN; ++i) + if (t->flags & (1 << i)) + flags[k++] = flag_abbrs[i]; + flags[MAG_TFLAG_LEN] = '\0'; + fprintf(f, "Tensor '%s', DType: %s, Rank: %" PRIi64 ", Elements: %" PRIi64 ", Shape: %s, Strides: %s, Mem: %.03f %s, Flags: %s (%x)\n", + t->name, + mag_dtype_meta_of(t->dtype)->name, + t->rank, + mag_tensor_numel(t), + shape, + strides, + buf_size_cvt, + buf_size_unit, + flags, + t->flags + ); + } + if (with_data) { + int64_t strides[MAG_MAX_DIMS]; + strides[MAG_MAX_DIMS-1] = 1; + for (int32_t i = MAG_MAX_DIMS-2; i >= 0; --i) // TODO: Fix this + strides[i] = strides[i+1] * t->shape[i+1]; + int64_t idx[MAG_MAX_DIMS] = {0}; + mag_print_tensor_recursive(f, t, &idx, &strides, 0, t->rank, 0); + } +} + +void mag_tensor_set_name(mag_tensor_t* t, const char* name) { + strncpy(t->name, name, MAG_MAX_TENSOR_NAME_LEN); + t->name[MAG_MAX_TENSOR_NAME_LEN-1] = '\0'; +} + +void mag_tensor_fmt_name(mag_tensor_t* t, const char* fmt, ...) 
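+/* e.g. mag_tensor_fmt_name(t, "layer%d.weight", i); the output is truncated by
+   vsnprintf to sizeof(t->name) including the terminator. */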
{ + va_list args; + va_start(args, fmt); + vsnprintf(t->name, sizeof(t->name), fmt, args); + va_end(args); +} + +const char* mag_tensor_get_name(const mag_tensor_t* t) { + return t->name; +} + +int64_t mag_tensor_rank(const mag_tensor_t* t) { return t->rank; } +const int64_t* mag_tensor_shape(const mag_tensor_t* t) { return t->shape; } +const int64_t* mag_tensor_strides(const mag_tensor_t* t) { return t->strides; } +mag_dtype_t mag_tensor_dtype(const mag_tensor_t* t) { return t->dtype; } + +void* mag_tensor_data_ptr(const mag_tensor_t* t) { + return (void*)t->storage.base; +} + +bool mag_tensor_is_scalar(const mag_tensor_t* t) { + #pragma GCC unroll 6 + for (uint32_t i=0; i < MAG_MAX_DIMS; ++i) + if (t->shape[i] != 1) + return false; + return true; +} + +bool mag_tensor_is_vector(const mag_tensor_t* t) { + #pragma GCC unroll 5 + for (uint32_t i=1; i < MAG_MAX_DIMS; ++i) + if (t->shape[i] != 1) + return false; + return true; +} + +bool mag_tensor_is_matrix(const mag_tensor_t* t) { + #pragma GCC unroll 4 + for (uint32_t i=2; i < MAG_MAX_DIMS; ++i) + if (t->shape[i] != 1) + return false; + return true; +} + +bool mag_tensor_is_volume(const mag_tensor_t* t) { + #pragma GCC unroll 3 + for (uint32_t i=3; i < MAG_MAX_DIMS; ++i) + if (t->shape[i] != 1) + return false; + return true; +} + +bool mag_tensor_is_shape_eq(const mag_tensor_t* a, const mag_tensor_t* b) { + return memcmp(a->shape, b->shape, sizeof(a->shape)) == 0; +} + +bool mag_tensor_are_strides_eq(const mag_tensor_t* a, const mag_tensor_t* b) { + return memcmp(a->strides, b->strides, sizeof(a->strides)) == 0; +} + +bool mag_tensor_can_broadcast(const mag_tensor_t* a, const mag_tensor_t* b) { + #pragma GCC unroll 6 + for (uint32_t i=0; i < MAG_MAX_DIMS; ++i) + if ((b->shape[i] % a->shape[i]) != 0) + return false; + return true; +} + +bool mag_tensor_is_transposed(const mag_tensor_t* t) { return t->strides[0] > t->strides[1]; } + +bool mag_tensor_is_permuted(const mag_tensor_t* t) { + #pragma GCC unroll 5 + for (uint32_t i=0; i < MAG_MAX_DIMS-1; ++i) + if (t->strides[i] > t->strides[i+1]) + return true; + return false; +} + +bool mag_tensor_is_contiguous(const mag_tensor_t* t) { + return *t->strides == 1; +} + +float mag_tensor_get_scalar_physical_index(mag_tensor_t* t, int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5) { + mag_static_assert(MAG_MAX_DIMS == 6); + mag_load_local_storage_group(t, s, strides); + switch (t->dtype) { + case MAG_DTYPE_F32: { + float r; + mag_storage_buffer_t* sto = &t->storage; + (*sto->cpy_device_host)(sto, sizeof(r)*(d0*s0 + d1*s1 + d2*s2 + d3*s3 + d4*s4 + d5*s5), &r, sizeof(r)); + return r; + } + default: mag_panic("Unsupported data type: %s", mag_dtype_meta_of(t->dtype)->name); + } +} + +void mag_tensor_set_scalar_physical_index(mag_tensor_t* t, int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5, float x) { + mag_static_assert(MAG_MAX_DIMS == 6); + mag_load_local_storage_group(t, s, strides); + switch (t->dtype) { + case MAG_DTYPE_F32: { + mag_storage_buffer_t* sto = &t->storage; + (*sto->cpy_host_device)(sto, sizeof(x)*(d0*s0 + d1*s1 + d2*s2 + d3*s3 + d4*s4 + d5*s5), &x, sizeof(x)); + } break; + default: mag_panic("Unsupported data type: %s", mag_dtype_meta_of(t->dtype)->name); + } +} + +float mag_tensor_get_scalar_virtual_index(mag_tensor_t* t, int64_t v_idx) { + if (!mag_tensor_is_contiguous(t)) { + int64_t pidx[MAG_MAX_DIMS]; + mag_tensor_virtual_to_physical_index(t, v_idx, &pidx); + return mag_tensor_get_scalar_physical_index(t, pidx[0], pidx[1], pidx[2], 
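+        /* Worked example: for a contiguous 2x3 tensor (d0=2, d1=3, strides
+           {1,2,6,6,6,6}) virtual index 4 unpacks to p = {0,2,0,0,0,0}, which
+           maps back to 1*0 + 2*2 = 4, the same linear offset. */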
pidx[3], pidx[4], pidx[5]); + } + switch (t->dtype) { + case MAG_DTYPE_F32: { + float r; + mag_storage_buffer_t* sto = &t->storage; + (*sto->cpy_device_host)(sto, sizeof(r)*v_idx, &r, sizeof(r)); + return r; + } + default: + mag_panic("Unsupported data type: %s", mag_dtype_meta_of(t->dtype)->name); + } +} + +void mag_tensor_set_scalar_virtual_index(mag_tensor_t* t, int64_t v_idx, float x) { + if (!mag_tensor_is_contiguous(t)) { + int64_t pidx[MAG_MAX_DIMS]; + mag_tensor_virtual_to_physical_index(t, v_idx, &pidx); + mag_tensor_set_scalar_physical_index(t, pidx[0], pidx[1], pidx[2], pidx[3], pidx[4], pidx[5], x); + return; + } + switch (t->dtype) { + case MAG_DTYPE_F32: { + mag_storage_buffer_t* sto = &t->storage; + (*sto->cpy_host_device)(sto, sizeof(x)*v_idx, &x, sizeof(x)); + } break; + default: + mag_panic("Unsupported data type: %s", mag_dtype_meta_of(t->dtype)->name); + } +} + +bool mag_tensor_eq(const mag_tensor_t* a, const mag_tensor_t* b) { + if (a->dtype != b->dtype) return false; + if (a->rank != b->rank) return false; + if (memcmp(a->shape, b->shape, sizeof(a->shape)) != 0) return false; + if (a->numel != b->numel) return false; + /*int64_t n = mag_tensor_num_elements(a); TODO + switch (a->dtype) { + case MAG_DTYPE_F32: { + const float* buf_a = (const float*)a->buf; + const float* buf_b = (const float*)b->buf; + for (int64_t i = 0; i < n; ++i) { + if (buf_a[i] != buf_b[i]) { + return false; + } + } + } break; + default: mag_panic("Unsupported data type: %s", mag_dtype_info_of(a->dtype)->name); + }*/ + return false; +} + +bool mag_tensor_is_close(const mag_tensor_t* a, const mag_tensor_t* b, float eps, double* percent_eq) { + if (a->dtype != b->dtype) return false; + if (a->rank != b->rank) return false; + if (memcmp(a->shape, b->shape, sizeof(a->shape)) != 0) return false; + if (a->numel != b->numel) return false; + /*eps = eps < 0.0f ? 
FLT_EPSILON : eps; TODO + int64_t n = mag_tensor_num_elements(a); + int64_t n_eq = 0; + switch (a->dtype) { + case MAG_DTYPE_F32: { + const float* buf_a = (const float*)a->buf; + const float* buf_b = (const float*)b->buf; + for (int64_t i = 0; i < n; ++i) |x - y| <= ε ∀ x, y ∈ A, B + if (fabsf(buf_a[i] - buf_b[i]) <= eps) ++n_eq; + } break; + default: mag_panic("Unsupported data type: %s", mag_dtype_info_of(a->dtype)->name); + } + if (percent_eq) *percent_eq = (double)n_eq / (double)n * 100.0; + return n_eq == n;*/ + return false; +} + +void mag_tensor_img_draw_box(mag_tensor_t* t, int32_t x1, int32_t y1, int32_t x2, int32_t y2, int32_t wi, uint32_t rgb) { + mag_assert(t->rank == 3, "Tensor must be 3D image tensor"); + mag_assert2(x2 > x1 && y2 > y1 && x1 > 0 && y1 > 0 && x2 > 0 && y2 > 0); + float* buf = mag_tensor_data_ptr(t); + int32_t w = (int32_t)mag_tensor_image_width(t); + int32_t h = (int32_t)mag_tensor_image_height(t); + int32_t c = (int32_t)mag_tensor_image_channels(t); + mag_assert2(w && h && c == 3); + float r = (float)((rgb>>16)&0xff) / 255.0f; + float g = (float)((rgb>>8)&0xff) / 255.0f; + float b = (float)(rgb&0xff) / 255.0f; + wi = mag_xmax(1, wi); + for (int32_t i=0; i < wi; ++i) { + int32_t xx1 = x1+i; + int32_t yy1 = y1+i; + int32_t xx2 = x2-i; + int32_t yy2 = y2-i; + if (mag_unlikely(xx1 >= w)) xx1 = w-1; + if (mag_unlikely(xx2 >= w)) xx2 = w-1; + if (mag_unlikely(yy1 >= h)) yy1 = h-1; + if (mag_unlikely(yy2 >= h)) yy2 = h-1; + for (int32_t j=xx1; j <= xx2; ++j) { + float* r1 = buf + j + yy1*w + 0*w*h; + float* r2 = buf + j + yy2*w + 0*w*h; + float* g1 = buf + j + yy1*w + 1*w*h; + float* g2 = buf + j + yy2*w + 1*w*h; + float* b1 = buf + j + yy1*w + 2*w*h; + float* b2 = buf + j + yy2*w + 2*w*h; + mag_bnd_chk(r1, buf, mag_tensor_data_size(t)); + mag_bnd_chk(r2, buf, mag_tensor_data_size(t)); + mag_bnd_chk(g1, buf, mag_tensor_data_size(t)); + mag_bnd_chk(g2, buf, mag_tensor_data_size(t)); + mag_bnd_chk(b1, buf, mag_tensor_data_size(t)); + mag_bnd_chk(b2, buf, mag_tensor_data_size(t)); + *r1 = *r2 = r; + *g1 = *g2 = g; + *b1 = *b2 = b; + } + for (int32_t j = yy1; j <= yy2; ++j) { + float* r1 = buf + xx1 + j*w + 0*w*h; + float* r2 = buf + xx2 + j*w + 0*w*h; + float* g1 = buf + xx1 + j*w + 1*w*h; + float* g2 = buf + xx2 + j*w + 1*w*h; + float* b1 = buf + xx1 + j*w + 2*w*h; + float* b2 = buf + xx2 + j*w + 2*w*h; + mag_bnd_chk(r1, buf, mag_tensor_data_size(t)); + mag_bnd_chk(r2, buf, mag_tensor_data_size(t)); + mag_bnd_chk(g1, buf, mag_tensor_data_size(t)); + mag_bnd_chk(g2, buf, mag_tensor_data_size(t)); + mag_bnd_chk(b1, buf, mag_tensor_data_size(t)); + mag_bnd_chk(b2, buf, mag_tensor_data_size(t)); + *r1 = *r2 = r; + *g1 = *g2 = g; + *b1 = *b2 = b; + } + } +} + +static bool mag_glyph(uint32_t c, uint32_t x, uint32_t y) { + c -= 33, --x; + if (mag_unlikely(c > 93 || x > 6 || y > 13)) return false; + uint32_t i = 98*c + 7*y + x; + return (("0@P01248@00120000P49B0000000000000:DXlW2UoDX@10008@h;IR4n@R000015A000000000000000000000000h?00" + "04@010000000000000000l0bGX@aL10000124XcX@Q25j300000000?Q248@8?000008@Pl5:DX@aL10" + "000000`CX@o24`70000`AP01N48@P0100000000l5:DX@aL12T70124XcX@Q25:40000@P0P348@P01>" + "00000240HP01248@P0a101248@T47B4940000HP01248@P01L00000000oBVS248@P000000000l48P7" + "@Pn0000048@`31248@030000000P@Q25:DP0124`1001248@P01248@P0007@P01" + "24`@P01R30000000S9S10000000"[i/6]-'0')>>(i%6))&1; +} + +void mag_tensor_img_draw_text(mag_tensor_t* t, int32_t x, int32_t y, int32_t size, uint32_t rgb, const char* txt) { /* TODO: Implement font scaling, size is ignored 
currently */ + mag_assert(t->rank == 3, "Tensor must be a 3D image tensor"); + mag_assert2(x >= 0 && y >= 0 && size >= 8 && txt && *txt); + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + float* buf = (float*)t->storage.base; + int32_t w = (int32_t)mag_tensor_image_width(t); + int32_t h = (int32_t)mag_tensor_image_height(t); + int32_t c = (int32_t)mag_tensor_image_channels(t); + mag_assert2(w && h && c == 3); + float* pr = buf; + float* pg = buf + w*h; + float* pb = buf + w*h*2; + float r = (float)((rgb>>16)&0xff) / 255.0f; + float g = (float)((rgb>>8)&0xff) / 255.0f; + float b = (float)(rgb&0xff) / 255.0f; + int32_t ly = y; + for (int32_t lx = x; *txt; lx = (*txt == '\n' ? x : lx+8), ly = (*txt == '\n' ? ly+14 : ly), txt++) { + if (mag_unlikely(!isprint(*txt))) continue; + for (int32_t yy = 0; yy < 14; ++yy) { + for (int32_t xx = 0; xx < 8; ++xx) { + if (!mag_glyph(*txt, xx, yy)) continue; + int32_t px = lx + xx; + int32_t py = ly + yy; + if (mag_unlikely(px >= w || py >= h)) continue; + int32_t ii = py*w + px; + pr[ii] = r; + pg[ii] = g; + pb[ii] = b; + } + } + } +} + +mag_ctx_t* mag_tensor_get_ctx(const mag_tensor_t* t) { return t->ctx; } +void* mag_tensor_get_user_data(const mag_tensor_t* t) { return t->ud; } +void mag_tensor_set_user_data(mag_tensor_t* t, void* ud) { t->ud = ud; } + +#ifdef __APPLE__ + static bool mag_sysctl_mib01(uint8_t (*out)[256], size_t* o_len, int mib0, int mib1) { /* Get sysctl data */ + memset(out, 0, sizeof(*out)); + *o_len = 0; + int name[2] = {mib0, mib1}; + size_t len = 0; + if (mag_unlikely(sysctl(name, sizeof(name) / sizeof(*name), NULL, &len, NULL, 0))) return false; /* Get length */ + if (mag_unlikely(len >= sizeof(*out))) return false; /* Buffer too small */ + if (mag_unlikely(sysctl(name, sizeof(name) / sizeof(*name), *out, &len, NULL, 0))) return false; /* Get data */ + *o_len = len; + return true; + } + static bool mag_sysctl_key(uint8_t (*out)[256], size_t* o_len, const char* key) { /* Get sysctl data */ + memset(out, 0, sizeof(*out)); + *o_len = 0; + size_t len = 0; + if (mag_unlikely(sysctlbyname(key, NULL, &len, NULL, 0))) return false; /* Get length */ + if (mag_unlikely(len >= sizeof(*out))) return false; /* Buffer too small */ + if (mag_unlikely(sysctlbyname(key, *out, &len, NULL, 0))) return false; /* Get data */ + *o_len = len; + return true; + } + static uint64_t mag_sysctl_unpack_int(const uint8_t (*in)[256], size_t len) { /* Unpack sysctl data */ + switch (len) { + case sizeof(uint16_t): { uint16_t r; memcpy(&r, *in, sizeof(r)); return r; } + case sizeof(uint32_t): { uint32_t r; memcpy(&r, *in, sizeof(r)); return r; } + case sizeof(uint64_t): { uint64_t r; memcpy(&r, *in, sizeof(r)); return r; } + default: return 0; + } + } +#else + static bool mag_cpuinfo_parse_value(const char* key, char (*out)[128]) { + FILE* cpuinfo = mag_fopen("/proc/cpuinfo", "rt"); + if (mag_unlikely(!cpuinfo)) return false; + size_t key_len = strlen(key); + char line[128]; + while (fgets(line, sizeof(line), cpuinfo)) { + size_t line_len = strlen(line); + if (line_len > 0 && line[line_len-1] == '\n') line[line_len-1] = '\0'; + if (strncmp(line, key, key_len) == 0 && (isspace((unsigned char)line[key_len]) || line[key_len] == ':')) { + char* colon = strchr(line, ':'); + if (!colon) continue; + char* value = colon+1; + while (isspace((unsigned char)*value)) ++value; + char* end = value + strlen(value); + for (; end > value && isspace((unsigned char)*(end-1)); --end); + *end = '\0'; + size_t value_len = llabs(end-value); + if 
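+        /* e.g. a /proc/cpuinfo line "model name\t: AMD Ryzen 9 5950X" yields
+           the value "AMD Ryzen 9 5950X": the key is matched at the line start,
+           the colon is skipped and surrounding whitespace is trimmed. */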
(mag_unlikely(!value_len || value_len >= sizeof(*out))) { + fclose(cpuinfo); + return false; + } + snprintf(*out, sizeof(*out), "%s", value); + fclose(cpuinfo); + return true; + } + } + fclose(cpuinfo); + return false; + } + static uint64_t mag_parse_meminfo_value(const char* line) { + const char *p = strchr(line, ':'); + if (mag_unlikely(!p)) return 0; + ++p; + p += strspn(p, " \t"); + errno = 0; + char* end; + uint64_t value = strtoull(p, &end, 10); + if (mag_unlikely(errno != 0 || p == end)) return 0; + return value<<10; + } +#endif + +#ifdef __linux__ +static void mag_trim_quotes(char* in) { + if (in == NULL || *in == '\0') return; + size_t len = strlen(in); + if (in[len - 1] == '"') { + in[len - 1] = '\0'; + len--; + } + if (in[0] == '"') { + memmove(in, in + 1, len); + } +} +#endif + +static void MAG_COLDPROC mag_system_host_info_query_os_name(char (*out_os_name)[128]) { /* Get OS name */ + #ifdef _WIN32 + + #elif defined(__APPLE__) + size_t len; + uint8_t tmp[256]; + if (mag_likely(mag_sysctl_mib01(&tmp, &len, CTL_KERN, KERN_VERSION) && len && *tmp)) { + char* colon = strchr((const char*)tmp, ':'); + if (colon) *colon = '\0'; + snprintf(*out_os_name, sizeof(*out_os_name), "%s", (const char*)tmp); + } + #elif defined (__linux__) + FILE* f = mag_fopen("/etc/os-release", "r"); + if (!f) { + f = mag_fopen("/usr/lib/os-release", "r"); + if (!f) { + f = mag_fopen("/etc/lsb-release", "r"); + if (mag_unlikely(!f)) return; + char line[256]; + while (fgets(line, sizeof(line), f) != NULL) { + size_t len = strlen(line); + if (len > 0 && line[len-1] == '\n') line[len-1] = '\0'; + if (strncmp(line, "DISTRIB_ID", sizeof("DISTRIB_ID")-1) == 0) { + char* equals_sign = strchr(line, '='); + if (equals_sign && *(equals_sign+1)) { + strncpy(*out_os_name, equals_sign+1, sizeof(*out_os_name)-1); + (*out_os_name)[sizeof(*out_os_name)-1] = '\0'; + } + } else if (strncmp(line, "DISTRIB_DESCRIPTION", sizeof("DISTRIB_DESCRIPTION")-1) == 0) { + char* equals_sign = strchr(line, '='); + if (equals_sign && *(equals_sign+1)) { + char* start_quote = strchr(equals_sign+1, '"'); + if (start_quote) { + char* end_quote = strchr(start_quote+1, '"'); + if (end_quote) { + size_t desc_len = end_quote-start_quote-1; + if (desc_len >= sizeof(*out_os_name)) desc_len = sizeof(*out_os_name)-1; + strncpy(*out_os_name, start_quote+1, desc_len); + (*out_os_name)[desc_len] = '\0'; + } else { + strncpy(*out_os_name, start_quote+1, sizeof(*out_os_name)-1); + (*out_os_name)[sizeof(*out_os_name)-1] = '\0'; + } + } else { + strncpy(*out_os_name, equals_sign+1, sizeof(*out_os_name)-1); + (*out_os_name)[sizeof(*out_os_name)-1] = '\0'; + } + } + } + } + fclose(f); + return; + } + } + char line[256]; + while (fgets(line, sizeof(line), f) != NULL) { + size_t len = strlen(line); + if (len > 0 && line[len-1] == '\n') line[len-1] = '\0'; + if (strncmp(line, "NAME", sizeof("NAME")-1) == 0) { + char* equals_sign = strchr(line, '='); + if (equals_sign && *(equals_sign+1)) { + strncpy(*out_os_name, equals_sign + 1, sizeof(*out_os_name)-1); + (*out_os_name)[sizeof(*out_os_name)-1] = '\0'; + } + } else if (strncmp(line, "PRETTY_NAME", sizeof("PRETTY_NAME")-1) == 0) { + char* equals_sign = strchr(line, '='); + if (equals_sign && *(equals_sign+1)) { + strncpy(*out_os_name, equals_sign+1, sizeof(*out_os_name)-1); + (*out_os_name)[sizeof(*out_os_name)-1] = '\0'; + } + } + } + fclose(f); + mag_trim_quotes(*out_os_name); + #endif +} + +static void MAG_COLDPROC mag_system_host_info_query_cpu_name(char (*out_cpu_name)[128]) { /* Get CPU name */ + #ifdef 
_WIN32 + HKEY key; + if (mag_unlikely(RegOpenKeyExA(HKEY_LOCAL_MACHINE, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", 0, KEY_READ, &key))) return; + char tmp[64+1] = {0}; + DWORD len = sizeof(tmp); + if (mag_unlikely(RegQueryValueExA(key, "ProcessorNameString", NULL, NULL, (LPBYTE)tmp, &len))) return; + if (mag_likely(strlen(tmp))) tmp[strlen(tmp)-1] = '\0'; + snprintf(*out_cpu_name, sizeof(*out_cpu_name), "%s", tmp); + #elif defined(__APPLE__) + size_t len; + uint8_t tmp[256]; + if (mag_likely(mag_sysctl_key(&tmp, &len, "machdep.cpu.brand_string") && len && *tmp)) + snprintf(*out_cpu_name, sizeof(*out_cpu_name), "%s", (const char*)tmp); + #else + char cpu_name[128]; + if (mag_likely((mag_cpuinfo_parse_value("model name", &cpu_name) && *cpu_name) || (mag_cpuinfo_parse_value("Model", &cpu_name) && *cpu_name))) + snprintf(*out_cpu_name, sizeof(*out_cpu_name), "%s", cpu_name); + #endif +} + +static void MAG_COLDPROC mag_system_host_info_query_cpu_cores(uint32_t* out_virtual, uint32_t* out_physical, uint32_t* out_sockets) { /* Get CPU virtual (logical) cores. */ + #ifdef _WIN32 + DWORD size = 0; + GetLogicalProcessorInformation(NULL, &size); + if (mag_unlikely(!size)) return; + SYSTEM_LOGICAL_PROCESSOR_INFORMATION* info = (*mag_alloc)(NULL, size); + if (mag_unlikely(!GetLogicalProcessorInformation(info, &size))) goto end; + for (DWORD i=0; i < size/sizeof(*info); ++i) { + switch (info[i].Relationship) { + default: continue; + case RelationProcessorPackage: ++*out_sockets; continue; + case RelationProcessorCore: { + ++*out_physical; + uintptr_t m = (uintptr_t)info[i].ProcessorMask; + m = m - ((m>>1) & 0x5555555555555555); + m = (m & 0x3333333333333333) + ((m>>2) & 0x3333333333333333); + *out_virtual += (((m + (m>>4)) & 0xf0f0f0f0f0f0f0f) * 0x101010101010101)>>56; + } continue; + } + } + end: (*mag_alloc)(info, 0); + #elif defined(__APPLE__) + uint8_t tmp[256]; + size_t len; + if (mag_likely(mag_sysctl_key(&tmp, &len, "machdep.cpu.thread_count") && len)) + *out_virtual = mag_sysctl_unpack_int(&tmp, len); + if (mag_likely(mag_sysctl_key(&tmp, &len, "machdep.cpu.core_count") && len)) + *out_physical = mag_sysctl_unpack_int(&tmp, len); + if (mag_likely(mag_sysctl_key(&tmp, &len, "hw.packages") && len)) + *out_sockets = mag_sysctl_unpack_int(&tmp, len); + #else + long nprocs = sysconf(_SC_NPROCESSORS_ONLN); + FILE* cpuinfo = mag_fopen("/proc/cpuinfo", "r"); + if (mag_unlikely(!cpuinfo)) return; + uint32_t physical_ids[MAG_MAX_CPUS]; + uint32_t core_ids[MAG_MAX_CPUS]; + uint32_t package_ids[MAG_MAX_CPUS]; + uint32_t cpu_count = 0; + uint32_t package_count = 0; + uint32_t current_physical_id = 0; + uint32_t current_core_id = 0; + bool got_physical_id = false; + bool got_core_id = false; + char line[256]; + while (fgets(line, sizeof(line), cpuinfo) != NULL) { + if (strncmp(line, "physical id", sizeof("physical id")-1) == 0) { + char* ptr = strchr(line, ':'); + if (ptr) { + ++ptr; + for (; *ptr && !isdigit((unsigned char)*ptr); ++ptr); + if (*ptr) { current_physical_id = (uint32_t)strtoul(ptr, NULL, 10); got_physical_id = true; } + } + } else if (strncmp(line, "core id", sizeof("core id")-1) == 0) { + char* ptr = strchr(line, ':'); + if (ptr) { + ++ptr; + for (; *ptr && !isdigit((unsigned char)*ptr); ++ptr); + if (*ptr) { current_core_id = (uint32_t)strtoul(ptr, NULL, 10); got_core_id = true; } + } + } else if (*line == '\n') { + if (got_physical_id && got_core_id) { + bool is_unique = true; + for (int32_t i = 0; i < cpu_count; ++i) if (physical_ids[i] == current_physical_id && core_ids[i] == 
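+            /* Dedup sketch: two hyperthread siblings report the same
+               (physical id, core id) pair, so they count as one physical core;
+               the logical count comes from sysconf(_SC_NPROCESSORS_ONLN) above. */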
current_core_id) { is_unique = false; break; } + if (is_unique) { + if (cpu_count < MAG_MAX_CPUS) { + physical_ids[cpu_count] = current_physical_id; + core_ids[cpu_count] = current_core_id; + ++cpu_count; + } else break; + } + is_unique = true; + for (int32_t i = 0; i < package_count; ++i) if (package_ids[i] == current_physical_id) { is_unique = false; break; } + if (is_unique) { + if (package_count < MAG_MAX_CPUS) package_ids[package_count++] = current_physical_id; + else break; + } + } + got_physical_id = false; + got_core_id = false; + } + } + fclose(cpuinfo); + *out_virtual = nprocs > 0 ? (uint32_t)nprocs : 0; + if (!cpu_count && *out_virtual) cpu_count = *out_virtual; + *out_physical = mag_xmax(1, cpu_count); + *out_virtual = nprocs > 0 ? (uint32_t)nprocs : *out_physical; + *out_sockets = mag_xmax(1, package_count); + #endif +} + +static void MAG_COLDPROC mag_system_host_info_query_memory(uint64_t* out_phys_mem_total, uint64_t* out_phys_mem_free) { /* Get physical memory */ + #ifdef _WIN32 + MEMORYSTATUSEX mem; + mem.dwLength = sizeof(mem); + if (mag_likely(GlobalMemoryStatusEx(&mem))) { + *out_phys_mem_total = mem.ullTotalPhys; + *out_phys_mem_free = mem.ullAvailPhys; + } + #elif defined(__APPLE__) + uint8_t tmp[256]; + size_t len; + if (mag_likely(mag_sysctl_mib01(&tmp, &len, CTL_HW, HW_MEMSIZE) && len)) + *out_phys_mem_total = mag_sysctl_unpack_int(&tmp, len); + struct vm_statistics64 stats; + natural_t count = HOST_VM_INFO64_COUNT; + if (mag_likely(host_statistics64(mach_host_self(), HOST_VM_INFO64, (host_info64_t)(&stats), &count) == KERN_SUCCESS)) + *out_phys_mem_free = stats.free_count * getpagesize(); + #else + FILE* meminfo = mag_fopen("/proc/meminfo", "r"); + if (mag_unlikely(!meminfo)) return; + char line[256]; + while (fgets(line, sizeof(line), meminfo)) { + if (strncmp(line, "MemTotal:", sizeof("MemTotal:")-1) == 0) + *out_phys_mem_total = mag_parse_meminfo_value(line); + else if (strncmp(line, "MemAvailable:", sizeof("MemAvailable:")-1) == 0) + *out_phys_mem_free = mag_parse_meminfo_value(line); + } + fclose(meminfo); + #endif +} + +#if defined(__x86_64__) || defined(_M_X64) + static void mag_cpuid(uint32_t leaf, int32_t sub, uint32_t* oeax, uint32_t* oebx, uint32_t* oecx, uint32_t* oedx) { + #ifdef _MSC_VER + int regs[4]; + if (sub != -1) __cpuidex(regs, leaf, sub); + else __cpuid(regs, leaf); + *oeax = regs[0], *oebx = regs[1], *oecx = regs[2], *oedx = regs[3]; + #else + uint32_t eax, ebx, ecx, edx; + if (sub != -1) __cpuid_count(leaf, sub, eax, ebx, ecx, edx); + else __cpuid(leaf, eax, ebx, ecx, edx); + *oeax = eax, *oebx = ebx, *oecx = ecx, *oedx = edx; + #endif + } + static uint64_t MAG_AINLINE mag_xgetbv(void) { /* Query extended control register value. 
*/ + #ifdef _MSC_VER + return _xgetbv(0); + #else + uint32_t lo, hi; + __asm__ __volatile__("xgetbv\n\t" : "=a" (lo), "=d" (hi) : "c" (0)); + return (uint64_t)lo | ((uint64_t)hi << 32); + #endif + } + #define mag_cpy_regs(id) \ + (*features)[MAG_X86_64_CPUID_##id][MAG_X86_64_CPUID_EAX] = eax; \ + (*features)[MAG_X86_64_CPUID_##id][MAG_X86_64_CPUID_EBX] = ebx; \ + (*features)[MAG_X86_64_CPUID_##id][MAG_X86_64_CPUID_ECX] = ecx; \ + (*features)[MAG_X86_64_CPUID_##id][MAG_X86_64_CPUID_EDX] = edx + static void MAG_COLDPROC mag_system_info_query_x86_64_cpu_features(uint32_t (*features)[8][4]) { + uint32_t eax=0, ebx=0, ecx=0, edx=0; + uint32_t max_basic_leaf, max_extended_leaf; + mag_cpuid(0, -1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(0H); + max_basic_leaf = eax; + mag_cpuid(0x80000000u, -1, &eax, &ebx, &ecx, &edx); + max_extended_leaf = eax; + if (max_basic_leaf >= 1u) { + mag_cpuid(1, -1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(1H); + } + if (max_basic_leaf >= 2u) { + mag_cpuid(2u, -1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(2H); + } + if (max_basic_leaf >= 7u) { + mag_cpuid(7u, 0, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(7H); + } + if (max_basic_leaf >= 7u) { + mag_cpuid(7u, 1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(7H_1H); + } + if (max_basic_leaf >= 0x16u) { + mag_cpuid(0x16u, -1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(16H); + } + if (max_extended_leaf >= 0x80000001u) { + mag_cpuid(0x80000001u, -1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(80000001H); + } + if (max_extended_leaf >= 0x80000007u) { + mag_cpuid(0x80000007u, -1, &eax, &ebx, &ecx, &edx); + mag_cpy_regs(80000007H); + } + bool cpu_avx_support = ((*features)[MAG_X86_64_CPUID_1H][MAG_X86_64_CPUID_ECX] & 0x10000000u) != 0; + bool cpu_osxsave_support = ((*features)[MAG_X86_64_CPUID_1H][MAG_X86_64_CPUID_ECX] & 0x8000000u) != 0; + if (cpu_avx_support && cpu_osxsave_support) { + uint64_t xcr0 = mag_xgetbv(); + if ((xcr0 & 0x6) != 0x6u) { + (*features)[MAG_X86_64_CPUID_1H][MAG_X86_64_CPUID_ECX] &= ~0x10000000u; /* Clear AVX */ + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_EBX] &= ~0x20u; /* Clear AVX2 */ + } + if ((xcr0 & 0xe0) != 0xe0u) { /* OS does not support AVX-512, clear AVX512 */ + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_EBX] &= ~0xdc230000u; + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_ECX] &= ~0x5842u; + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_EDX] &= ~0x10cu; + (*features)[MAG_X86_64_CPUID_7H_1H][MAG_X86_64_CPUID_EAX] &= ~0x20u; + } + } else { + (*features)[MAG_X86_64_CPUID_1H][MAG_X86_64_CPUID_ECX] &= ~0x10000000u; /* Clear AVX */ + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_EBX] &= ~0x20u; /* Clear AVX2 */ + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_EBX] &= ~0xdc230000u; /* Clear AVX512 */ + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_ECX] &= ~0x5842u; /* Clear AVX512 */ + (*features)[MAG_X86_64_CPUID_7H][MAG_X86_64_CPUID_EDX] &= ~0x10cu; /* Clear AVX512 */ + (*features)[MAG_X86_64_CPUID_7H_1H][MAG_X86_64_CPUID_EAX] &= ~0x20u; /* Clear AVX512 */ + } + } + #undef mag_cpy_regs +#endif + +static void MAG_COLDPROC mag_system_host_info_query(mag_ctx_t* ctx) { + mag_system_host_info_query_os_name(&ctx->sys.os_name); + mag_system_host_info_query_cpu_name(&ctx->sys.cpu_name); + mag_system_host_info_query_cpu_cores(&ctx->sys.cpu_virtual_cores, &ctx->sys.cpu_physical_cores, &ctx->sys.cpu_sockets); + mag_system_host_info_query_memory(&ctx->sys.phys_mem_total, &ctx->sys.phys_mem_free); + #if defined(__x86_64__) || defined(_M_X64) + 
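+    /* XCR0 sketch for the masking above: bits 0x6 gate XMM/YMM state and bits
+       0xe0 gate opmask/ZMM state; e.g. xcr0 == 0x07 keeps AVX usable but clears
+       the AVX-512 feature bits, since 0x07 & 0xe0 != 0xe0. */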
mag_system_info_query_x86_64_cpu_features(&ctx->sys.x86_64_cpu_features); + #endif + if (mag_unlikely(!*ctx->sys.os_name)) snprintf(ctx->sys.os_name, sizeof(ctx->sys.os_name), "Unknown"); + if (mag_unlikely(!*ctx->sys.cpu_name)) snprintf(ctx->sys.cpu_name, sizeof(ctx->sys.cpu_name), "Unknown"); +} + +static MAG_AINLINE void mag_sto_write_u32_le(uint8_t** p, uint32_t x) { + x = mag_bswap32(x); + memcpy(*p, &x, sizeof(x)); + *p += sizeof(x); +} + +static MAG_AINLINE void mag_sto_write_u64_le(uint8_t** p, uint64_t x) { + x = mag_bswap64(x); + memcpy(*p, &x, sizeof(x)); + *p += sizeof(x); +} + +static MAG_AINLINE uint32_t mag_sto_read_u32_le(const uint8_t** p) { + uint32_t x; + memcpy(&x, *p, sizeof(x)); + x = mag_bswap32(x); + *p += sizeof(x); + return x; +} + +static MAG_AINLINE uint64_t mag_sto_read_u64_le(const uint8_t** p) { + uint64_t x; + memcpy(&x, *p, sizeof(x)); + x = mag_bswap64(x); + *p += sizeof(x); + return x; +} + +#define MAG_STO_MAGIC "magtron!" +mag_static_assert(sizeof(MAG_STO_MAGIC)-1 == sizeof(uint64_t)); +#define MAG_STO_FILE_HEADER_SIZE ((sizeof(MAG_STO_MAGIC)-1) + sizeof(uint32_t)*3) +#define MAG_STO_TENSOR_HEADER_SIZE (MAG_MAX_TENSOR_NAME_LEN + sizeof(uint32_t) + sizeof(int64_t)*MAG_MAX_DIMS) +mag_static_assert(MAG_MAX_TENSOR_NAME_LEN % 8 == 0); +mag_static_assert(MAG_DTYPE__NUM <= 0xff); +mag_static_assert(MAG_MAX_DIMS <= 0xff); +#define mag_sto_sanitize(exp, ret) do { if (mag_unlikely(!(exp))) { mag_log_error("magnetron storage sanitize error: " #exp); return (ret); } } while (0) + +static bool mag_sto_write_file_header( /* Write file header - file header must be same in every version. */ + uint8_t** p, + const uint8_t* end, + uint32_t version, + uint32_t num_tensors, + uint32_t ud +) { + mag_sto_sanitize(*p + MAG_STO_FILE_HEADER_SIZE < end, false); + const uint8_t* start = *p; + uint64_t mag_magic; + memcpy(&mag_magic, MAG_STO_MAGIC, sizeof(mag_magic)); + mag_sto_write_u64_le(p, mag_magic); + mag_sto_write_u32_le(p, version); + mag_sto_write_u32_le(p, num_tensors); + mag_sto_write_u32_le(p, ud); + return mag_likely(*p - start == MAG_STO_FILE_HEADER_SIZE); +} + +static bool mag_sto_read_file_header( /* Read file header - file header must be same in every version. 
*/ + const uint8_t** p, + const uint8_t* end, + uint32_t* version, + uint32_t* num_tensors, + uint32_t* ud +) { + mag_sto_sanitize(*p + MAG_STO_FILE_HEADER_SIZE < end, false); + const uint8_t* start = *p; + uint64_t mag_magic = mag_sto_read_u64_le(p); + mag_sto_sanitize(memcmp(&mag_magic, MAG_STO_MAGIC, sizeof(mag_magic)) == 0, false); + *version = mag_sto_read_u32_le(p); + *num_tensors = mag_sto_read_u32_le(p); + *ud = mag_sto_read_u32_le(p); + return mag_likely(*p - start == MAG_STO_FILE_HEADER_SIZE); +} + +static bool mag_sto_write_tensor_header( + uint8_t** p, + const uint8_t* end, + uint32_t version, + const char (*name)[MAG_MAX_TENSOR_NAME_LEN], + mag_tensor_flags_t flags, + mag_dtype_t dtype, + int64_t rank, + const int64_t (*shape)[MAG_MAX_DIMS] +) { + mag_sto_sanitize(*p + MAG_STO_TENSOR_HEADER_SIZE < end, false); + const uint8_t* start = *p; + switch (version) { + case 1: { + uint64_t name_u64[sizeof(*name)/sizeof(uint64_t)]; + memcpy(name_u64, *name, sizeof(*name)); + for (size_t i=0; i < sizeof(name_u64)/sizeof(*name_u64); ++i) /* Write name as multiple u64 */ + mag_sto_write_u64_le(p, name_u64[i]); + uint32_t aux = 0; /* Pack small fields into aux field */ + aux |= (flags & 0xff) << 16; + aux |= (dtype & 0xff) << 8; + aux |= (rank & 0xff); + mag_sto_write_u32_le(p, aux); /* Write aux field */ + for (size_t i=0; i < MAG_MAX_DIMS; ++i) { /* Write shape */ + mag_sto_sanitize((*shape)[i] >= 1 && (*shape)[i] < INT64_MAX, false); + mag_sto_write_u64_le(p, (uint64_t)(*shape)[i]); + } + } break; + default: return false; + } + return mag_likely(*p - start == MAG_STO_TENSOR_HEADER_SIZE); +} + +static bool mag_sto_read_tensor_header( + const uint8_t** p, + const uint8_t* end, + uint32_t version, + char (*name)[MAG_MAX_TENSOR_NAME_LEN], + mag_tensor_flags_t* flags, + mag_dtype_t* dtype, + int64_t* rank, + int64_t (*shape)[MAG_MAX_DIMS] +) { + mag_sto_sanitize(*p + MAG_STO_TENSOR_HEADER_SIZE < end, false); + const uint8_t* start = *p; + switch (version) { + case 1: { + uint64_t name_u64[sizeof(*name)/sizeof(uint64_t)]; + for (size_t i=0; i < sizeof(name_u64)/sizeof(*name_u64); ++i) /* Read name as multiple u64 */ + name_u64[i] = mag_sto_read_u64_le(p); + memcpy(name, name_u64, sizeof(*name)); + (*name)[sizeof(*name)-1] = '\0'; + uint32_t aux = mag_sto_read_u32_le(p); /* Read aux field */ + *flags = (mag_tensor_flags_t)((aux >> 16) & 0xff); + *dtype = (mag_dtype_t)((aux >> 8) & 0xff); + *rank = (int64_t)(aux & 0xff); + mag_sto_sanitize((*flags & ~((1u<<MAG_TFLAG_LEN)-1)) == 0, false); + mag_sto_sanitize(*dtype >= 0 && *dtype < MAG_DTYPE__NUM, false); + mag_sto_sanitize(*rank >= 1 && *rank <= MAG_MAX_DIMS, false); + for (size_t i=0; i < MAG_MAX_DIMS; ++i) { /* Read shape */ + uint64_t u64 = mag_sto_read_u64_le(p); + mag_sto_sanitize(u64 >= 1 && u64 <= (uint64_t)INT64_MAX, false); + (*shape)[i] = (int64_t)u64; + } + mag_sto_sanitize(INT64_MAX/(*shape)[1] > (*shape)[0], false); /* Check for shape overflow */ + mag_sto_sanitize(INT64_MAX/(*shape)[2] > (*shape)[0]*(*shape)[1], false); + mag_sto_sanitize(INT64_MAX/(*shape)[3] > (*shape)[0]*(*shape)[1]*(*shape)[2], false); + mag_sto_sanitize(INT64_MAX/(*shape)[4] > (*shape)[0]*(*shape)[1]*(*shape)[2]*(*shape)[3], false); + mag_sto_sanitize(INT64_MAX/(*shape)[5] > (*shape)[0]*(*shape)[1]*(*shape)[2]*(*shape)[3]*(*shape)[4], false); + } break; + default: return false; + } + return mag_likely(*p - start == MAG_STO_TENSOR_HEADER_SIZE); +} + +static bool mag_sto_write_tensor_data( + uint8_t** p, + const uint8_t* end, + uint32_t version, + mag_dtype_t dtype, + const void* data, + int64_t size +) { +
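+    /* Storage layout sketch (version 1): the payload is the raw tensor buffer,
+       so a 2x2 F32 tensor contributes exactly 16 data bytes after its header;
+       multi-byte header fields go through the *_le helpers above. */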
mag_sto_sanitize(size > 0, false); + mag_sto_sanitize(*p + size <= end, false); + const uint8_t* start = *p; + switch (version) { + case 1: { + memcpy(*p, data, size); + *p += size; + } break; + default: return false; + } + return mag_likely(*p - start == size); +} + +static bool mag_sto_read_tensor_data( + const uint8_t** p, + const uint8_t* end, + uint32_t version, + mag_dtype_t dtype, + void* data, + int64_t size +) { + mag_sto_sanitize(size > 0, false); + mag_sto_sanitize(*p + size <= end, false); + const uint8_t* start = *p; + switch (version) { + case 1: { + memcpy(data, *p, size); /* TODO: endianess conversion */ + *p += size; + } break; + default: return false; + } + return mag_likely(*p - start == size); +} + +static size_t mag_accumulate_data_size(mag_dtype_t dtype, const int64_t (*shape)[MAG_MAX_DIMS]) { + size_t size = mag_dtype_meta_of(dtype)->size; + #pragma GCC unroll 6 + for (size_t i=0; i < MAG_MAX_DIMS; ++i) size *= (size_t)mag_xmax(1, (*shape)[i]); + return size; +} + +static size_t mag_sto_total_size(const mag_tensor_t** tensors, size_t n) { + size_t total = MAG_STO_FILE_HEADER_SIZE; + for (size_t i=0; i < n; ++i) { + total += MAG_STO_TENSOR_HEADER_SIZE; + total += mag_accumulate_data_size(tensors[i]->dtype, &tensors[i]->shape); + } + mag_assert((total & 3) == 0, "Unaligned storage size: %zu", total); + return total; +} + +static uint8_t* mag_sto_write_buffered(const mag_tensor_t** tensors, size_t n_tensors, size_t* out_size, uint32_t version) { + if (mag_unlikely(!tensors || !n_tensors || n_tensors > UINT32_MAX || !out_size || !version || version > MAG_STORAGE_VERSION)) return NULL; /* Check input */ + *out_size = mag_sto_total_size(tensors, n_tensors); + uint8_t* base = (uint8_t*)(*mag_alloc)(NULL, *out_size ); /* Allocate buffer */ + uint8_t* needle = base; + const uint8_t* end = base + *out_size ; + if (mag_unlikely(!mag_sto_write_file_header(&needle, end, version, (uint32_t)n_tensors, 0))) goto error; /* Write file header */ + for (size_t i=0; i < n_tensors; ++i) { /* Write tensor headers */ + const mag_tensor_t* t = tensors[i]; + mag_assert2(t != NULL); + if (mag_unlikely(!mag_sto_write_tensor_header( + &needle, + end, + version, + &t->name, + t->flags, + t->dtype, + t->rank, + &t->shape + ))) goto error; + } + mag_assert2(needle - base == MAG_STO_FILE_HEADER_SIZE + n_tensors*MAG_STO_TENSOR_HEADER_SIZE); /* Check written data size */ + for (size_t i=0; i < n_tensors; ++i) { /* Write tensor data */ + const mag_tensor_t* t = tensors[i]; + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + if (mag_unlikely(!mag_sto_write_tensor_data(&needle, end, version, t->dtype, (const void*)t->storage.base, mag_tensor_data_size(t)))) goto error; /* Write data */ + } + return base; + error: /* Error handling */ + (*mag_alloc)(base, 0); + return NULL; +} + +MAG_EXPORT mag_tensor_t** mag_sto_read_buffered(mag_ctx_t* ctx, const uint8_t* buf, size_t size, uint32_t* out_n_tensors, uint32_t* out_version) { /* Load stored tensors from buffer. Function is exported for fuzzing test. 
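+** A container is parsed in three passes: file header, then all tensor headers,
+** then all tensor payloads in the same order. A libFuzzer harness could drive
+** this entry point roughly as follows (sketch; only the magnetron calls are
+** taken from this API, the harness boilerplate is illustrative):
+**
+**   int LLVMFuzzerTestOneInput(const uint8_t* data, size_t len) {
+**       mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+**       uint32_t n = 0, version = 0;
+**       mag_tensor_t** tensors = mag_sto_read_buffered(ctx, data, len, &n, &version);
+**       if (tensors) {              .. must reject, never crash on, malformed input
+**           for (uint32_t i = 0; i < n; ++i) mag_tensor_decref(tensors[i]);
+**           (*mag_get_alloc_fn())(tensors, 0);
+**       }
+**       mag_ctx_destroy(ctx);
+**       return 0;
+**   }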
*/ + if (mag_unlikely(!ctx || !buf || !out_n_tensors || !out_version || size <= MAG_STO_FILE_HEADER_SIZE + MAG_STO_TENSOR_HEADER_SIZE + 1)) return NULL; /* Check input */ + const uint8_t* needle = buf; + const uint8_t* end = buf + size; + uint32_t n_tensors; + uint32_t ud; + if (mag_unlikely(!mag_sto_read_file_header(&needle, end, out_version, &n_tensors, &ud))) return NULL; /* Read file header */ + if (mag_unlikely(!*out_version || *out_version > MAG_VERSION)) return NULL; + if (mag_unlikely(!n_tensors)) return NULL; + mag_tensor_t** tensors = (mag_tensor_t**)(*mag_alloc)(NULL, n_tensors*sizeof(*tensors)); /* Allocate return tensor array */ + for (size_t i=0; i < n_tensors; ++i) { /* Read tensor headers */ + char name[MAG_MAX_TENSOR_NAME_LEN] = {0}; + mag_tensor_flags_t flags = 0; + mag_dtype_t dtype = 0; + int64_t rank = 0; + int64_t shape[MAG_MAX_DIMS] = {0}; + if (mag_unlikely(!mag_sto_read_tensor_header(&needle, end, *out_version, &name, &flags, &dtype, &rank, &shape))) goto error; /* Read tensor header */ + mag_tensor_t* t = mag_tensor_create(ctx, dtype, shape, rank, NULL, 0); /* Create placeholder tensor */ + mag_tensor_fmt_name(t, "%s", name); + t->flags = flags; + tensors[i] = t; + } + for (size_t i=0; i < n_tensors; ++i) { /* Read tensor data */ + mag_tensor_t* t = tensors[i]; + mag_assert2(t->ctx->device_type == MAG_COMPUTE_DEVICE_TYPE_CPU); + size_t data_size = mag_accumulate_data_size(t->dtype, &t->shape); + mag_assert2(needle + data_size <= end && data_size == mag_tensor_data_size(t)); + if (mag_unlikely(!mag_sto_read_tensor_data(&needle, end, *out_version, t->dtype, (void*)t->storage.base, data_size))) goto error; /* Read data into tensor's buffer */ + } + *out_n_tensors = n_tensors; + return tensors; + error: + (*mag_alloc)(tensors, 0); + return NULL; +} + +static bool mag_sto_has_mag_ext(const char* file) { /* Check if file path has magnetron extension. */ + if (mag_unlikely(!file || strlen(file) < sizeof(MAG_STORAGE_EXT))) return false; + const char* dot = strrchr(file, '.'); + return dot && !strcmp(dot, MAG_STORAGE_EXT); +} + +void mag_tensor_save(const mag_tensor_t* t, const char* file) { + mag_assert(mag_sto_has_mag_ext(file), "Invalid file extension: %s", file); + FILE* f = mag_fopen(file, "wb"); /* Open file */ + mag_assert(f, "Failed to open file stream: %s", file); + uint32_t version = MAG_STORAGE_VERSION; + size_t n_tensors = 1; + size_t n_bytes = 0; + uint8_t* ser = mag_sto_write_buffered(&t, n_tensors, &n_bytes, version); /* Serialize tensor */ + mag_assert(ser && n_bytes, "Failed to serialize tensor to file: %s", file); /* Check serialization */ + mag_assert(fwrite(ser, 1, n_bytes, f) == n_bytes, "Failed to write %zu bytes to file: %s", n_bytes, file); /* Write to file */ + (*mag_alloc)(ser, 0); /* Free buffer */ + fflush(f); + fclose(f); + double mem; + const char* unit; + mag_humanize_memory_size(n_bytes, &mem, &unit); + mag_log_info("Saved %zu tensor%s to file: %s, %.03f %s written, storage v.%u", n_tensors, n_tensors > 1 ? 
"s" : "", file, mem, unit, version); +} + +mag_tensor_t* mag_tensor_load(mag_ctx_t* ctx, const char* file) { + mag_assert(mag_sto_has_mag_ext(file), "Invalid file extension: %s", file); + FILE* f = mag_fopen(file, "rb"); /* Open file */ + mag_assert(f, "Failed to open file stream: %s", file); + mag_assert2(fseek(f, 0, SEEK_END) == 0); /* Seek to end */ + long n_bytes = ftell(f); /* Get file size */ + mag_assert(n_bytes > MAG_STO_FILE_HEADER_SIZE + MAG_STO_TENSOR_HEADER_SIZE + 1, "Malformed file size"); /* Check file size */ + mag_assert2(fseek(f, 0, SEEK_SET) == 0); /* Seek to start */ + uint8_t* buf = (uint8_t*)(*mag_alloc)(NULL, n_bytes); /* Allocate buffer */ + mag_assert(fread(buf, 1, n_bytes, f) == n_bytes, "Failed to read %zu bytes from file: %s", n_bytes, file); /* Read while file into buffer */ + fclose(f), f = NULL; /* Close file */ + uint32_t n_tensors = 0, version = 0; + mag_tensor_t** tensors = mag_sto_read_buffered(ctx, buf, n_bytes, &n_tensors, &version); /* Deserialize tensors */ + mag_assert(version > 0 && version <= MAG_VERSION, "Unsupported storage version: %u", version); /* Check version */ + mag_assert(tensors && n_tensors > 0, "Failed to load tensor from file: %s", file); + mag_tensor_t* target = *tensors; + (*mag_alloc)(buf, 0); /* Free buffer */ + (*mag_alloc)(tensors, 0); /* Free tensor array */ + double mem; + const char* unit; + mag_humanize_memory_size(n_bytes, &mem, &unit); + mag_log_info("Loaded %u tensor%s from file: %s, %.03f %s read, storage v.%u", n_tensors, n_tensors > 1 ? "s" : "", file, mem, unit, version); + return target; +} + +mag_tensor_t* mag_tensor_load_image(mag_ctx_t* ctx, const char* file, mag_color_channels_t channels, uint32_t resize_w, uint32_t resize_h) { + uint8_t* (*loader)(const char*, uint32_t(*)[3], mag_color_channels_t) = ctx->image_load_fn; + void (*load_free)(uint8_t*) = ctx->image_load_free_fn; + mag_assert(loader && load_free, "Image loader not set"); + uint32_t whc[3] = {0}; + uint8_t* src = (*loader)(file, &whc, channels); + mag_assert(src, "Failed to load tensor from image: '%s'", file); + if (resize_w && resize_h) { /* Resize requested. */ + float* ori = (*mag_alloc)(NULL, whc[2]*whc[1]*whc[0]*sizeof(*ori)); + for (int64_t k=0; k < whc[2]; ++k) { /* Convert from interleaved to planar representation. 
*/ + for (int64_t j=0; j < whc[1]; ++j) { + for (int64_t i=0; i < whc[0]; ++i) { + ori[i + whc[0]*j + whc[0]*whc[1]*k] = (float)src[k + whc[2]*i + whc[2]*whc[0]*j] / 255.0f; /* Normalize pixel values to [0, 1] */ + } + } + } + mag_tensor_t* t = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, whc[2], resize_h, resize_w); + float* dst = mag_tensor_data_ptr(t); + float* part = (*mag_alloc)(NULL, whc[2] * whc[1] * resize_w * sizeof(*part)); + float ws = (float)(whc[0] - 1)/(float)(resize_w - 1); + float hs = (float)(whc[1] - 1)/(float)(resize_h - 1); + for (uint32_t k = 0; k < whc[2]; ++k){ + for (uint32_t r = 0; r < whc[1]; ++r) { + for (uint32_t c = 0; c < resize_w; ++c) { + float val = 0; + if (c == resize_w - 1 || whc[0] == 1) { + val = ori[k*(whc[0])*(whc[1]) + r*(whc[0]) + (whc[0] - 1)]; + } else { + float sx = (float)c*ws; + uint32_t ix = (uint32_t)sx; + float dx = sx - (float)ix; + val = (1-dx) * (ori[k*(whc[0])*(whc[1]) + r*(whc[0]) + ix]) + dx*(ori[k*(whc[0])*(whc[1]) + r*(whc[0]) + (ix + 1)]); + } + part[k * resize_w * (whc[1]) + r * resize_w + c] = val; + } + } + } + for (uint32_t k = 0; k < whc[2]; ++k) { + for (uint32_t r = 0; r < resize_h; ++r) { + float sy = (float)r*hs; + uint32_t iy = (uint32_t)sy; + float dy = sy - (float)iy; + for (uint32_t c = 0; c < resize_w; ++c) { + float val = (1-dy)*(part[k * resize_w * whc[1] + iy * resize_w + c]); + dst[k * resize_w * resize_h + r * resize_w + c] = val; + } + if (r == resize_h - 1 || whc[1] == 1) continue; + for (uint32_t c = 0; c < resize_w; ++c) { + float val = dy*(part[k * resize_w * (whc[1]) + (iy + 1) * resize_w + c]); + dst[k * resize_w * resize_h + r * resize_w + c] += val; + } + } + } + (*mag_alloc)(ori, 0); + (*mag_alloc)(part, 0); + mag_assert(resize_w * resize_h * whc[2] == mag_tensor_numel(t), "Buffer size mismatch: %zu != %zu", resize_w * resize_h * whc[2], (size_t)mag_tensor_numel(t)); + (*load_free)(src); + mag_log_info("Loaded and resized tensor from image: %s, %u x %u x %u", file, resize_w, resize_h, whc[2]); + return t; + } else { + mag_tensor_t* t = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, whc[2], whc[1], whc[0]); + float* dst = mag_tensor_data_ptr(t); + for (int64_t k = 0; k < whc[2]; ++k) { /* Convert from interleaved to planar representation. 
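+           For channel k, row j, column i the mapping is
+           dst[i + W*j + W*H*k] = src[k + C*i + C*W*j] / 255.0f,
+           i.e. interleaved (y, x, c) order becomes planar (c, y, x) order.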
*/ + for (int64_t j = 0; j < whc[1]; ++j) { + for (int64_t i = 0; i < whc[0]; ++i) { + dst[i + whc[0]*j + whc[0]*whc[1]*k] = (float)src[k + whc[2]*i + whc[2]*whc[0]*j] / 255.0f; /* Normalize pixel values to [0, 1] */ + } + } + } + mag_assert(whc[0]*whc[1]*whc[2] == mag_tensor_numel(t), "Buffer size mismatch: %zu != %zu", whc[0]*whc[1]*whc[2], (size_t)mag_tensor_numel(t)); + (*load_free)(src); + mag_log_info("Loaded tensor from image: %s, %u x %u x %u", file, whc[0], whc[1], whc[2]); + return t; + } +} + +void mag_tensor_save_image(const mag_tensor_t* t, const char* file) { + bool (*saver)(const char*, const uint8_t*, const uint32_t(*)[3]) = t->ctx->image_save_fn; + mag_assert(saver, "Image saver not set"); + int64_t rank = mag_tensor_rank(t); + mag_assert(rank == 3, "Tensor rank must be 3, but is: %" PRIi64, (size_t)rank); + int64_t w = mag_tensor_image_width(t); + int64_t h = mag_tensor_image_height(t); + int64_t c = mag_tensor_image_channels(t); + mag_assert(c == 1 || c == 3 || c == 4, "Invalid number of channels: %zu", (size_t)c); + mag_assert(w*h*c == mag_tensor_numel(t), "Buffer size mismatch: %zu != %zu", w*h*c, (size_t)mag_tensor_numel(t)); + uint8_t* dst = (*mag_alloc)(NULL, w*h*c); /* Allocate memory for image data */ + const float* src = mag_tensor_data_ptr(t); + for (int64_t k = 0; k < c; ++k) /* Convert from planar to interleaved format. */ + for (int64_t i = 0; i < w*h; ++i) + dst[i*c + k] = (uint8_t)(src[i + k*w*h]*255.0f); + const uint32_t whc[3] = {(uint32_t)w,(uint32_t)h,(uint32_t)c}; + mag_assert((*saver)(file, dst, &whc), "Failed to save tensor to image: %s", file); + (*mag_alloc)(dst, 0); /* Free image data */ + mag_log_info("Saved tensor to image: %s, width: %d, height: %d, channels: %d", file, (int)w, (int)h, (int)c); +} + +/* Rescale factors to push the exponent of a number towards zero. */ +#define rescale_exponents(P, N) \ + P(308), P(289), P(270), P(250), P(231), P(212), P(193), P(173), P(154), \ + P(135), P(115), P(96), P(77), P(58), P(38), P(0), P(0), P(0), N(39), N(58), \ + N(77), N(96), N(116), N(135), N(154), N(174), N(193), N(212), N(231), \ + N(251), N(270), N(289) +#define one_e_p(X) 1e+0 ## X +#define one_e_n(X) 1e-0 ## X +static const int16_t mag_rescale_e[] = { rescale_exponents(-, +) }; +static const double mag_rescale_n[] = { rescale_exponents(one_e_p, one_e_n) }; +#undef one_e_n +#undef one_e_p + +/* +** For p in range -70 through 57, this table encodes pairs (m, e) such that +** 4*2^p <= (uint8_t)m*10^e, and is the smallest value for which this holds. 
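+** The table is consumed in pairs: entry 2*i is m (as a raw byte) and entry
+** 2*i+1 is e. mag_fmt_f64 below indexes it with eidx*2 and feeds the pair to
+** nd_add_m10e, stepping the already-converted -2ulp case up to the +2ulp case
+** when verifying that exponent rescaling did not cross a rounding boundary.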
+*/
+static const int8_t mag_four_ulp_m_e[] = {
+    34, -21, 68, -21, 14, -20, 28, -20, 55, -20, 2, -19, 3, -19, 5, -19, 9, -19,
+    -82, -18, 35, -18, 7, -17, -117, -17, 28, -17, 56, -17, 112, -16, -33, -16,
+    45, -16, 89, -16, -78, -15, 36, -15, 72, -15, -113, -14, 29, -14, 57, -14,
+    114, -13, -28, -13, 46, -13, 91, -12, -74, -12, 37, -12, 73, -12, 15, -11, 3,
+    -11, 59, -11, 2, -10, 3, -10, 5, -10, 1, -9, -69, -9, 38, -9, 75, -9, 15, -7,
+    3, -7, 6, -7, 12, -6, -17, -7, 48, -7, 96, -7, -65, -6, 39, -6, 77, -6, -103,
+    -5, 31, -5, 62, -5, 123, -4, -11, -4, 49, -4, 98, -4, -60, -3, 4, -2, 79, -3,
+    16, -2, 32, -2, 63, -2, 2, -1, 25, 0, 5, 1, 1, 2, 2, 2, 4, 2, 8, 2, 16, 2,
+    32, 2, 64, 2, -128, 2, 26, 2, 52, 2, 103, 3, -51, 3, 41, 4, 82, 4, -92, 4,
+    33, 4, 66, 4, -124, 5, 27, 5, 53, 5, 105, 6, 21, 6, 42, 6, 84, 6, 17, 7, 34,
+    7, 68, 7, 2, 8, 3, 8, 6, 8, 108, 9, -41, 9, 43, 10, 86, 9, -84, 10, 35, 10,
+    69, 10, -118, 11, 28, 11, 55, 12, 11, 13, 22, 13, 44, 13, 88, 13, -80, 13,
+    36, 13, 71, 13, -115, 14, 29, 14, 57, 14, 113, 15, -30, 15, 46, 15, 91, 15,
+    19, 16, 37, 16, 73, 16, 2, 17, 3, 17, 6, 17
+};
+
+/* min(2^32-1, 10^e-1) for e in range 0 through 10 */
+static const uint32_t mag_ndigits_dec_threshold[] = {
+    0, 9U, 99U, 999U, 9999U, 99999U, 999999U,
+    9999999U, 99999999U, 999999999U, 0xffffffffU
+};
+
+/* Compute the number of digits in the decimal representation of x. */
+static size_t mag_ndigits_dec(uint32_t x) {
+    size_t t = ((mag_fls(x | 1) * 77) >> 8) + 1; /* 2^8/77 is roughly log2(10) */
+    return t + (x > mag_ndigits_dec_threshold[t]);
+}
+
+#define wint_r(x, sh, sc) { uint32_t d = (x*(((1<<sh)+sc-1)/sc))>>sh; x -= d*sc; *p++ = (char)('0'+d); }
+static char* mag_wuint9(char* p, uint32_t u) {
+    uint32_t v = u / 10000, w;
+    u -= v * 10000;
+    w = v / 10000;
+    v -= w * 10000;
+    *p++ = (char)('0'+w);
+    wint_r(v, 23, 1000)
+    wint_r(v, 12, 100)
+    wint_r(v, 10, 10)
+    *p++ = (char)('0'+v);
+    wint_r(u, 23, 1000)
+    wint_r(u, 12, 100)
+    wint_r(u, 10, 10)
+    *p++ = (char)('0'+u);
+    return p;
+}
+#undef wint_r
+
+#define wint_r(x, sh, sc) { uint32_t d = (x*(((1<<sh)+sc-1)/sc))>>sh; x -= d*sc; *p++ = (char)('0'+d); }
+static char* mag_wint(char* p, int32_t k) {
+    uint32_t u = (uint32_t)k;
+    if (k < 0) { u = ~u+1u; *p++ = '-'; }
+    if (u < 10000) {
+        if (u < 10) goto dig1;
+        if (u < 100) goto dig2;
+        if (u < 1000) goto dig3;
+    } else {
+        uint32_t v = u / 10000; u -= v * 10000;
+        if (v < 10000) {
+            if (v < 10) goto dig5;
+            if (v < 100) goto dig6;
+            if (v < 1000) goto dig7;
+        } else {
+            uint32_t w = v / 10000; v -= w * 10000;
+            if (w >= 10) wint_r(w, 10, 10)
+            *p++ = (char)('0'+w);
+        }
+        wint_r(v, 23, 1000)
+        dig7: wint_r(v, 12, 100)
+        dig6: wint_r(v, 10, 10)
+        dig5: *p++ = (char)('0'+v);
+    }
+    wint_r(u, 23, 1000)
+    dig3: wint_r(u, 12, 100)
+    dig2: wint_r(u, 10, 10)
+    dig1: *p++ = (char)('0'+u);
+    return p;
+}
+#undef wint_r
+
+/* -- Extended precision arithmetic --------------------------------------- */
+
+/*
+** The "nd" format is a fixed-precision decimal representation for numbers. It
+** consists of up to 64 uint32_t values, with each uint32_t storing a value
+** in the range [0, 1e9). A number in "nd" format consists of three variables:
+**
+**  uint32_t nd[64];
+**  uint32_t ndlo;
+**  uint32_t ndhi;
+**
+** The integral part of the number is stored in nd[0 ... ndhi], the value of
+** which is sum{i in [0, ndhi] | nd[i] * 10^(9*i)}. If the fractional part of
+** the number is zero, ndlo is zero. Otherwise, the fractional part is stored
+** in nd[ndlo ...
63], the value of which is taken to be +** sum{i in [ndlo, 63] | nd[i] * 10^(9*(i-64))}. +** +** If the array part had 128 elements rather than 64, then every double would +** have an exact representation in "nd" format. With 64 elements, all integral +** doubles have an exact representation, and all non-integral doubles have +** enough digits to make both %.99e and %.99f do the right thing. +*/ +#define MAG_ND_MUL2K_MAX_SHIFT 29 +#define MAG_ND_MUL2K_DIV1E9(val) ((uint32_t)((val) / 1000000000)) + +/* Multiply nd by 2^k and add carry_in (ndlo is assumed to be zero). */ +static uint32_t nd_mul2k(uint32_t* nd, uint32_t ndhi, uint32_t k, uint32_t carry_in, mag_format_flags sf) { + uint32_t i, ndlo = 0, start = 1; + /* Performance hacks. */ + if (k > MAG_ND_MUL2K_MAX_SHIFT*2 && MAG_FMT_FP(sf) != MAG_FMT_FP(MAG_FMT_T_FP_F)) { + start = ndhi - (MAG_FMT_PREC(sf) + 17) / 8; + } + /* Real logic. */ + while (k >= MAG_ND_MUL2K_MAX_SHIFT) { + for (i = ndlo; i <= ndhi; i++) { + uint64_t val = ((uint64_t)nd[i] << MAG_ND_MUL2K_MAX_SHIFT) | carry_in; + carry_in = MAG_ND_MUL2K_DIV1E9(val); + nd[i] = (uint32_t)val - carry_in * 1000000000; + } + if (carry_in) { + nd[++ndhi] = carry_in; carry_in = 0; + if (start++ == ndlo) ++ndlo; + } + k -= MAG_ND_MUL2K_MAX_SHIFT; + } + if (k) { + for (i = ndlo; i <= ndhi; i++) { + uint64_t val = ((uint64_t)nd[i] << k) | carry_in; + carry_in = MAG_ND_MUL2K_DIV1E9(val); + nd[i] = (uint32_t)val - carry_in * 1000000000; + } + if (carry_in) nd[++ndhi] = carry_in; + } + return ndhi; +} + +/* Divide nd by 2^k (ndlo is assumed to be zero). */ +static uint32_t nd_div2k(uint32_t* nd, uint32_t ndhi, uint32_t k, mag_format_flags sf) { + uint32_t ndlo = 0, stop1 = ~0, stop2 = ~0; + /* Performance hacks. */ + if (!ndhi) { + if (!nd[0]) { + return 0; + } else { + uint32_t s = mag_ffs(nd[0]); + if (s >= k) { nd[0] >>= k; return 0; } + nd[0] >>= s; k -= s; + } + } + if (k > 18) { + if (MAG_FMT_FP(sf) == MAG_FMT_FP(MAG_FMT_T_FP_F)) { + stop1 = 63 - (int32_t)MAG_FMT_PREC(sf) / 9; + } else { + int32_t floorlog2 = ndhi * 29 + mag_fls(nd[ndhi]) - k; + int32_t floorlog10 = (int32_t)(floorlog2 * 0.30102999566398114); + stop1 = 62 + (floorlog10 - (int32_t)MAG_FMT_PREC(sf)) / 9; + stop2 = 61 + ndhi - (int32_t)MAG_FMT_PREC(sf) / 8; + } + } + /* Real logic. */ + while (k >= 9) { + uint32_t i = ndhi, carry = 0; + for (;;) { + uint32_t val = nd[i]; + nd[i] = (val >> 9) + carry; + carry = (val & 0x1ff) * 1953125; + if (i == ndlo) break; + i = (i - 1) & 0x3f; + } + if (ndlo != stop1 && ndlo != stop2) { + if (carry) { ndlo = (ndlo - 1) & 0x3f; nd[ndlo] = carry; } + if (!nd[ndhi]) { ndhi = (ndhi - 1) & 0x3f; stop2--; } + } else if (!nd[ndhi]) { + if (ndhi != ndlo) { ndhi = (ndhi - 1) & 0x3f; stop2--; } + else return ndlo; + } + k -= 9; + } + if (k) { + uint32_t mask = (1U << k) - 1, mul = 1000000000 >> k, i = ndhi, carry = 0; + for (;;) { + uint32_t val = nd[i]; + nd[i] = (val >> k) + carry; + carry = (val & mask) * mul; + if (i == ndlo) break; + i = (i - 1) & 0x3f; + } + if (carry) { ndlo = (ndlo - 1) & 0x3f; nd[ndlo] = carry; } + } + return ndlo; +} + +/* Add m*10^e to nd (assumes ndlo <= e/9 <= ndhi and 0 <= m <= 9). 
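+** The digit m lands in base-1e9 chunk i = e/9 at in-chunk decimal position
+** e mod 9, i.e. carry = m * 10^(e mod 9); note that
+** mag_ndigits_dec_threshold[t] + 1 == 10^t. Any overflow past 1e9 then
+** ripples into higher chunks, growing ndhi by at most one.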
*/ +static uint32_t nd_add_m10e(uint32_t* nd, uint32_t ndhi, uint8_t m, int32_t e) { + uint32_t i, carry; + if (e >= 0) { + i = (uint32_t)e/9; + carry = m * (mag_ndigits_dec_threshold[e - (int32_t)i*9] + 1); + } else { + int32_t f = (e-8)/9; + i = (uint32_t)(64 + f); + carry = m * (mag_ndigits_dec_threshold[e - f*9] + 1); + } + for (;;) { + uint32_t val = nd[i] + carry; + if (mag_unlikely(val >= 1000000000)) { + val -= 1000000000; + nd[i] = val; + if (mag_unlikely(i == ndhi)) { + ndhi = (ndhi + 1) & 0x3f; + nd[ndhi] = 1; + break; + } + carry = 1; + i = (i + 1) & 0x3f; + } else { + nd[i] = val; + break; + } + } + return ndhi; +} + +static bool nd_similar(uint32_t* nd, uint32_t ndhi, uint32_t* ref, size_t hilen, size_t prec) { + char nd9[9], ref9[9]; + if (hilen <= prec) { + if (mag_unlikely(nd[ndhi] != *ref)) return 0; + prec -= hilen; ref--; ndhi = (ndhi - 1) & 0x3f; + if (prec >= 9) { + if (mag_unlikely(nd[ndhi] != *ref)) return 0; + prec -= 9; ref--; ndhi = (ndhi - 1) & 0x3f; + } + } else { + prec -= hilen - 9; + } + mag_assert(prec < 9, "bad precision %d", prec); + mag_wuint9(nd9, nd[ndhi]); + mag_wuint9(ref9, *ref); + return !memcmp(nd9, ref9, prec) && (nd9[prec] < '5') == (ref9[prec] < '5'); +} + +/* Format f64 according to format flags. */ +static char* mag_fmt_f64(mag_format_flags sf, double n, char* p) { + size_t width = MAG_FMT_WIDTH(sf), prec = MAG_FMT_PREC(sf), len; + union { + uint64_t u64; + double n; + struct { /* TODO: make endian aware */ + uint32_t lo, hi; + } u32; + } t = {.n = n}; + if (mag_unlikely((t.u32.hi << 1) >= 0xffe00000)) { + /* Handle non-finite values uniformly for %a, %e, %f, %g. */ + int32_t prefix = 0, ch = (sf & MAG_FMT_F_UPPER) ? 0x202020 : 0; + if (((t.u32.hi & 0x000fffff) | t.u32.lo) != 0) { + ch ^= ('n' << 16) | ('a' << 8) | 'n'; + if ((sf & MAG_FMT_F_SPACE)) prefix = ' '; + } else { + ch ^= ('i' << 16) | ('n' << 8) | 'f'; + if ((t.u32.hi & 0x80000000)) prefix = '-'; + else if ((sf & MAG_FMT_F_PLUS)) prefix = '+'; + else if ((sf & MAG_FMT_F_SPACE)) prefix = ' '; + } + len = 3 + (prefix != 0); + if (!(sf & MAG_FMT_F_LEFT)) while (width-- > len) *p++ = ' '; + if (prefix) *p++ = prefix; + *p++ = (char)(ch >> 16); *p++ = (char)(ch >> 8); *p++ = (char)ch; + } else if (MAG_FMT_FP(sf) == MAG_FMT_FP(MAG_FMT_T_FP_A)) { + /* %a */ + const char* hexdig = (sf & MAG_FMT_F_UPPER) ? "0123456789ABCDEFPX" : "0123456789abcdefpx"; + int32_t e = (t.u32.hi >> 20) & 0x7ff; + char prefix = 0, eprefix = '+'; + if (t.u32.hi & 0x80000000) prefix = '-'; + else if ((sf & MAG_FMT_F_PLUS)) prefix = '+'; + else if ((sf & MAG_FMT_F_SPACE)) prefix = ' '; + t.u32.hi &= 0xfffff; + if (e) { + t.u32.hi |= 0x100000; + e -= 1023; + } else if (t.u32.lo | t.u32.hi) { + /* Non-zero denormal - normalise it. */ + uint32_t shift = t.u32.hi ? 20-mag_fls(t.u32.hi) : 52-mag_fls(t.u32.lo); + e = -1022 - shift; + t.u64 <<= shift; + } + /* abs(n) == t.u64 * 2^(e - 52) */ + /* If n != 0, bit 52 of t.u64 is set, and is the highest set bit. */ + if ((int32_t)prec < 0) { + /* Default precision: use smallest precision giving exact result. */ + prec = t.u32.lo ? 13-mag_ffs(t.u32.lo)/4 : 5-mag_ffs(t.u32.hi|0x100000)/4; + } else if (prec < 13) { + /* Precision is sufficiently low as to maybe require rounding. 
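+       The mantissa holds 13 hex digits after the leading one; adding
+       1 << (51 - prec*4) below is half a unit in the last kept hex digit,
+       which rounds to nearest before the fraction is truncated to prec digits.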
*/ + t.u64 += (((uint64_t)1) << (51 - prec*4)); + } + if (e < 0) { + eprefix = '-'; + e = -e; + } + len = 5 + mag_ndigits_dec((uint32_t)e) + prec + (prefix != 0) + + ((prec | (sf & MAG_FMT_F_ALT)) != 0); + if (!(sf & (MAG_FMT_F_LEFT | MAG_FMT_F_ZERO))) { + while (width-- > len) *p++ = ' '; + } + if (prefix) *p++ = prefix; + *p++ = '0'; + *p++ = hexdig[17]; /* x or X */ + if ((sf & (MAG_FMT_F_LEFT | MAG_FMT_F_ZERO)) == MAG_FMT_F_ZERO) { + while (width-- > len) *p++ = '0'; + } + *p++ = '0' + (t.u32.hi >> 20); /* Usually '1', sometimes '0' or '2'. */ + if ((prec | (sf & MAG_FMT_F_ALT))) { + /* Emit fractional part. */ + char* q = p + 1 + prec; + *p = '.'; + if (prec < 13) t.u64 >>= (52 - prec*4); + else while (prec > 13) p[prec--] = '0'; + while (prec) { p[prec--] = hexdig[t.u64 & 15]; t.u64 >>= 4; } + p = q; + } + *p++ = hexdig[16]; /* p or P */ + *p++ = eprefix; /* + or - */ + p = mag_wint(p, e); + } else { + /* %e or %f or %g - begin by converting n to "nd" format. */ + uint32_t nd[64]; + uint32_t ndhi = 0, ndlo, i; + int32_t e = (int32_t)(t.u32.hi >> 20) & 0x7ff, ndebias = 0; + char prefix = 0, *q; + if (t.u32.hi & 0x80000000) prefix = '-'; + else if ((sf & MAG_FMT_F_PLUS)) prefix = '+'; + else if ((sf & MAG_FMT_F_SPACE)) prefix = ' '; + prec += ((int32_t)prec >> 31) & 7; /* Default precision is 6. */ + if (MAG_FMT_FP(sf) == MAG_FMT_FP(MAG_FMT_T_FP_G)) { + /* %g - decrement precision if non-zero (to make it like %e). */ + prec--; + prec ^= (uint32_t)((int32_t)prec >> 31); + } + if ((sf & MAG_FMT_T_FP_E) && prec < 14 && n != 0) { + /* Precision is sufficiently low that rescaling will probably work. */ + if ((ndebias = mag_rescale_e[e >> 6])) { + t.n = n * mag_rescale_n[e >> 6]; + if (mag_unlikely(!e)) t.n *= 1e10, ndebias -= 10; + t.u64 -= 2; /* Convert 2ulp below (later we convert 2ulp above). */ + nd[0] = 0x100000 | (t.u32.hi & 0xfffff); + e = ((int32_t)(t.u32.hi >> 20) & 0x7ff) - 1075 - (MAG_ND_MUL2K_MAX_SHIFT < 29); + goto load_t_lo; rescale_failed: + t.n = n; + e = (int32_t)(t.u32.hi >> 20) & 0x7ff; + ndebias = 0; + ndhi = 0; + } + } + nd[0] = t.u32.hi & 0xfffff; + if (e == 0) e++; else nd[0] |= 0x100000; + e -= 1043; + if (t.u32.lo) { + e -= 32 + (MAG_ND_MUL2K_MAX_SHIFT < 29); load_t_lo: +#if MAG_ND_MUL2K_MAX_SHIFT >= 29 + nd[0] = (nd[0]<<3) | (t.u32.lo>>29); + ndhi = nd_mul2k(nd, ndhi, 29, t.u32.lo & 0x1fffffff, sf); +#elif MAG_ND_MUL2K_MAX_SHIFT >= 11 + ndhi = nd_mul2k(nd, ndhi, 11, t.u32.lo>>21, sf); + ndhi = nd_mul2k(nd, ndhi, 11, (t.u32.lo>>10) & 0x7ff, sf); + ndhi = nd_mul2k(nd, ndhi, 11, (t.u32.lo<<1) & 0x7ff, sf); +#else +#error "MAG_ND_MUL2K_MAX_SHIFT not big enough" +#endif + } + if (e >= 0) { + ndhi = nd_mul2k(nd, ndhi, (uint32_t)e, 0, sf); + ndlo = 0; + } else { + ndlo = nd_div2k(nd, ndhi, (uint32_t)-e, sf); + if (ndhi && !nd[ndhi]) ndhi--; + } + /* |n| == nd * 10^ndebias (for slightly loose interpretation of ==) */ + if ((sf & MAG_FMT_T_FP_E)) { + /* %e or %g - assume %e and start by calculating nd's exponent (nde). */ + char eprefix = '+'; + int32_t nde = -1; + size_t hilen; + if (ndlo && !nd[ndhi]) { + ndhi = 64; do {} while (!nd[--ndhi]); + nde -= 64 * 9; + } + hilen = mag_ndigits_dec(nd[ndhi]); + nde += (int32_t)(ndhi * 9 + hilen); + if (ndebias) { + /* + ** Rescaling was performed, but this introduced some error, and might + ** have pushed us across a rounding boundary. We check whether this + ** error affected the result by introducing even more error (2ulp in + ** either direction), and seeing whether a rounding boundary was + ** crossed. 
Having already converted the -2ulp case, we save off its + ** most significant digits, convert the +2ulp case, and compare them. + */ + int32_t eidx = e + 70 + (MAG_ND_MUL2K_MAX_SHIFT < 29) + + (t.u32.lo >= 0xfffffffe && !(~t.u32.hi << 12)); + const int8_t *m_e = mag_four_ulp_m_e + eidx * 2; + mag_assert(0 <= eidx && eidx < 128, "bad eidx %d", eidx); + nd[33] = nd[ndhi]; + nd[32] = nd[(ndhi - 1) & 0x3f]; + nd[31] = nd[(ndhi - 2) & 0x3f]; + nd_add_m10e(nd, ndhi, (uint8_t)*m_e, m_e[1]); + if (mag_unlikely(!nd_similar(nd, ndhi, nd + 33, hilen, prec + 1))) { + goto rescale_failed; + } + } + if ((int32_t)(prec - nde) < (0x3f & -(int32_t)ndlo) * 9) { + /* Precision is sufficiently low as to maybe require rounding. */ + ndhi = nd_add_m10e(nd, ndhi, 5, (int32_t)nde - prec - 1); + nde += (hilen != mag_ndigits_dec(nd[ndhi])); + } + nde += ndebias; + if ((sf & MAG_FMT_T_FP_F)) { + /* %g */ + if ((int32_t)prec >= nde && nde >= -4) { + if (nde < 0) ndhi = 0; + prec -= nde; + goto g_format_like_f; + } else if (!(sf & MAG_FMT_F_ALT) && prec && width > 5) { + /* Decrease precision in order to strip trailing zeroes. */ + char tail[9]; + uint32_t maxprec = hilen - 1 + ((ndhi - ndlo) & 0x3f) * 9; + if (prec >= maxprec) prec = maxprec; + else ndlo = (ndhi - (((int32_t)(prec - hilen) + 9) / 9)) & 0x3f; + i = prec - hilen - (((ndhi - ndlo) & 0x3f) * 9) + 10; + mag_wuint9(tail, nd[ndlo]); + while (prec && tail[--i] == '0') { + prec--; + if (!i) { + if (ndlo == ndhi) { prec = 0; break; } + ndlo = (ndlo + 1) & 0x3f; + mag_wuint9(tail, nd[ndlo]); + i = 9; + } + } + } + } + if (nde < 0) { + /* Make nde non-negative. */ + eprefix = '-'; + nde = -nde; + } + len = 3 + prec + (prefix != 0) + mag_ndigits_dec((uint32_t)nde) + (nde < 10) + + ((prec | (sf & MAG_FMT_F_ALT)) != 0); + if (!(sf & (MAG_FMT_F_LEFT | MAG_FMT_F_ZERO))) { + while (width-- > len) *p++ = ' '; + } + if (prefix) *p++ = prefix; + if ((sf & (MAG_FMT_F_LEFT | MAG_FMT_F_ZERO)) == MAG_FMT_F_ZERO) { + while (width-- > len) *p++ = '0'; + } + q = mag_wint(p + 1, nd[ndhi]); + p[0] = p[1]; /* Put leading digit in the correct place. */ + if ((prec | (sf & MAG_FMT_F_ALT))) { + /* Emit fractional part. */ + p[1] = '.'; p += 2; + prec -= (size_t)(q - p); p = q; /* Account for digits already emitted. */ + /* Then emit chunks of 9 digits (this may emit 8 digits too many). */ + for (i = ndhi; (int32_t)prec > 0 && i != ndlo; prec -= 9) { + i = (i - 1) & 0x3f; + p = mag_wuint9(p, nd[i]); + } + if ((sf & MAG_FMT_T_FP_F) && !(sf & MAG_FMT_F_ALT)) { + /* %g (and not %#g) - strip trailing zeroes. */ + p += (int32_t)prec & ((int32_t)prec >> 31); + while (p[-1] == '0') p--; + if (p[-1] == '.') p--; + } else { + /* %e (or %#g) - emit trailing zeroes. */ + while ((int32_t)prec > 0) { *p++ = '0'; prec--; } + p += (int32_t)prec; + } + } else { + p++; + } + *p++ = (sf & MAG_FMT_F_UPPER) ? 'E' : 'e'; + *p++ = eprefix; /* + or - */ + if (nde < 10) *p++ = '0'; /* Always at least two digits of exponent. */ + p = mag_wint(p, nde); + } else { + /* %f (or, shortly, %g in %f style) */ + if (prec < (size_t)(0x3f & -(int32_t)ndlo) * 9) { + /* Precision is sufficiently low as to maybe require rounding. */ + ndhi = nd_add_m10e(nd, ndhi, 5, 0 - prec - 1); + } + g_format_like_f: + if ((sf & MAG_FMT_T_FP_E) && !(sf & MAG_FMT_F_ALT) && prec && width) { + /* Decrease precision in order to strip trailing zeroes. */ + if (ndlo) { + /* nd has a fractional part; we need to look at its digits. 
*/ + char tail[9]; + uint32_t maxprec = (64 - ndlo) * 9; + if (prec >= maxprec) prec = maxprec; + else ndlo = 64 - (prec + 8) / 9; + i = prec - ((63 - ndlo) * 9); + mag_wuint9(tail, nd[ndlo]); + while (prec && tail[--i] == '0') { + prec--; + if (!i) { + if (ndlo == 63) { prec = 0; break; } + mag_wuint9(tail, nd[++ndlo]); + i = 9; + } + } + } else { + /* nd has no fractional part, so precision goes straight to zero. */ + prec = 0; + } + } + len = ndhi * 9 + mag_ndigits_dec(nd[ndhi]) + prec + (prefix != 0) + + ((prec | (sf & MAG_FMT_F_ALT)) != 0); + if (!(sf & (MAG_FMT_F_LEFT | MAG_FMT_F_ZERO))) { + while (width-- > len) *p++ = ' '; + } + if (prefix) *p++ = prefix; + if ((sf & (MAG_FMT_F_LEFT | MAG_FMT_F_ZERO)) == MAG_FMT_F_ZERO) { + while (width-- > len) *p++ = '0'; + } + /* Emit integer part. */ + p = mag_wint(p, nd[ndhi]); + i = ndhi; + while (i) p = mag_wuint9(p, nd[--i]); + if ((prec | (sf & MAG_FMT_F_ALT))) { + /* Emit fractional part. */ + *p++ = '.'; + /* Emit chunks of 9 digits (this may emit 8 digits too many). */ + while ((int32_t)prec > 0 && i != ndlo) { + i = (i - 1) & 0x3f; + p = mag_wuint9(p, nd[i]); + prec -= 9; + } + if ((sf & MAG_FMT_T_FP_E) && !(sf & MAG_FMT_F_ALT)) { + /* %g (and not %#g) - strip trailing zeroes. */ + p += (int32_t)prec & ((int32_t)prec >> 31); + while (p[-1] == '0') p--; + if (p[-1] == '.') p--; + } else { + /* %f (or %#g) - emit trailing zeroes. */ + while ((int32_t)prec > 0) { *p++ = '0'; prec--; } + p += (int32_t)prec; + } + } + } + } + if ((sf & MAG_FMT_F_LEFT)) while (width-- > len) *p++ = ' '; + return p; +} + +static uint8_t* mag_default_image_load_impl(const char* file, uint32_t(*whc)[3], mag_color_channels_t channels) { + mag_assert2(file && *file && whc); + int w, h, c, dc; + switch (channels) { + default: dc = STBI_default; break; + case MAG_COLOR_CHANNELS_GRAY: dc = STBI_grey; break; + case MAG_COLOR_CHANNELS_GRAY_A: dc = STBI_grey_alpha; break; + case MAG_COLOR_CHANNELS_RGB: dc = STBI_rgb; break; + case MAG_COLOR_CHANNELS_RGBA: dc = STBI_rgb_alpha; break; + } + uint8_t* buf = stbi_load(file, &w, &h, &c, dc); + if (mag_unlikely(!buf || !w || !h || !c || (c != 1 && c != 3 && c != 4))) return NULL; + (*whc)[0] = (uint32_t)w; + (*whc)[1] = (uint32_t)h; + (*whc)[2] = (uint32_t)c; + return buf; +} + +static void mag_default_image_load_free_fn_impl(uint8_t* p) { + stbi_image_free(p); +} + +static bool mag_default_image_save_impl(const char* file, const uint8_t* buf, const uint32_t(*whc)[3]) { + mag_assert2(file && *file && buf && whc); + return stbi_write_jpg(file, (int)(*whc)[0], (int)(*whc)[1], (int)(*whc)[2], buf, 100) != 0; +} diff --git a/magnetron/magnetron.h b/magnetron/magnetron.h new file mode 100644 index 0000000..76c3ef7 --- /dev/null +++ b/magnetron/magnetron.h @@ -0,0 +1,445 @@ +/* (c) 2024 Mario "Neo" Sieg. */ + +#ifndef MAGNETRON_H +#define MAGNETRON_H + +/* Compile time config macros */ +#define MAG_CFG_X86_64_FAST_MATH 1 /* Use fast math for x86_64 by setting mxcsr control register. */ +#define MAG_INTRIN 1 /* Use platform and compiler specific intrinsics for performance. */ +#define MAG_BOUNDS_CHECK 0 /* Enable bounds checking for BLAS routines. */ +#define MAG_SANITIZE_RC 0 /* Enable runtime checks for debugging of reference counted tensors. 
*/ +#define MAG_EXPORT_DLL + +#if !defined(NDEBUG) && !MAG_BOUNDS_CHECK +#undef MAG_BOUNDS_CHECK +#define MAG_BOUNDS_CHECK 1 +#endif + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define MAG_DEFAULT_CHUNK_SIZE (1ull<<30) /* Default size of memory chunk in bytes. 1 GiB */ +#define MAG_DEFAULT_CHUNK_CAP 128 /* Default capacity of memory chunk */ +#define MAG_MAX_DIMS 6 /* Maximum number of dimensions for a tensor */ +#define MAG_MAX_TENSOR_NAME_LEN 64 /* Maximum length for tensor name */ +#define MAG_MAX_INPUT_TENSORS 2 /* Maximum number of input tensors for an operation */ +#define MAG_MAX_OP_PARAMS 6 /* Maximum number of parameters for an operation */ + +#ifndef MAG_EXPORT +# ifdef MAG_EXPORT_DLL +# ifdef _MSC_VER +# define MAG_EXPORT __declspec(dllexport) +# else +# define MAG_EXPORT __attribute__((visibility("default"))) +# endif +# else +# define MAG_EXPORT +# endif +#endif + +#define mag_version_pack(major, minor) ((uint32_t)((((major)&0xff)<<8)+((minor)&0xff))) +#define mag_version_major(version) (((version)>>8)&0xff) +#define mag_version_minor(version) ((version)&0xff) +#define MAG_VERSION mag_version_pack(0, 1) /* magnetron library version. */ +#define MAG_STORAGE_VERSION 1 /* magnetron storage format version. */ + +#define mag_assert_name2(name, line) name ## line +#define mag_assert_name(line) mag_assert_name2(_assert_, line) +#define mag_static_assert(expr) extern void mag_assert_name(__LINE__)(bool STATIC_ASSERTION_FAILED[((expr)?1:-1)]) + +typedef enum mag_compute_device_type_t { + MAG_COMPUTE_DEVICE_TYPE_CPU = 0, /* CPU compute device */ + MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA = 1, /* CUDA GPU compute device */ + + MAG_COMPUTE_DEVICE_TYPE__NUM +} mag_compute_device_type_t; +extern MAG_EXPORT const char* mag_device_type_get_name(mag_compute_device_type_t op); + +typedef enum mag_exec_mode_t { + MAG_EXEC_MODE_EAGER = 0, /* Execute operations immediately. (Dynamic computation graph, like PyTorch). */ + MAG_EXEC_MODE_DEFERRED = 1, /* Build computation graph and execute later. (Static computation graph, like TensorFlow 1.0). */ + + MAG_EXEC_MODE__NUM +} mag_exec_mode_t; + +typedef enum mag_prng_algorithm_t { + MAG_PRNG_MERSENNE_TWISTER = 0, /* Mersenne Twister PRNG */ + MAG_PRNG_PCG = 1, /* Permuted Congruential Generator PRNG */ + + MAG_PRNG__NUM +} mag_prng_algorithm_t; + +typedef enum mag_color_channels_t { + MAG_COLOR_CHANNELS_AUTO, /* Automatically detect number of color channels */ + MAG_COLOR_CHANNELS_GRAY, /* Grayscale F32 */ + MAG_COLOR_CHANNELS_GRAY_A,/* Grayscale F32 + Alpha F32 */ + MAG_COLOR_CHANNELS_RGB, /* R32G32B32 */ + MAG_COLOR_CHANNELS_RGBA, /* R32G32B32A32 */ + + MAG_COLOR_CHANNELS__NUM +} mag_color_channels_t; + +extern MAG_EXPORT void* (*mag_get_alloc_fn(void))(void* blk, size_t size); /* Get global allocator. */ +extern MAG_EXPORT void mag_set_alloc_fn(void* (*alloc)(void* blk, size_t size)); /* Set global allocator. */ +extern MAG_EXPORT void mag_set_log_mode(bool enabled); /* Enable/disable logging. */ + +typedef uint32_t mag_char32_t; + +typedef struct mag_ctx_t mag_ctx_t; /* Opaque context type for managing memory pools */ + +extern MAG_EXPORT mag_ctx_t* mag_ctx_create(mag_compute_device_type_t device); /* Create context with default config, and only specificy device. 
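+Typical lifecycle (sketch):
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    ...create tensors, run operations...
+    mag_ctx_destroy(ctx);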
*/ +extern MAG_EXPORT mag_exec_mode_t mag_ctx_get_exec_mode(const mag_ctx_t* ctx); /* Get execution mode */ +extern MAG_EXPORT void mag_ctx_set_exec_mode(mag_ctx_t* ctx, mag_exec_mode_t mode); /* Set execution mode */ +extern MAG_EXPORT mag_prng_algorithm_t mag_ctx_get_prng_algorithm(const mag_ctx_t* ctx); /* Get PRNG algorithm */ +extern MAG_EXPORT void mag_ctx_set_prng_algorithm(mag_ctx_t* ctx, mag_prng_algorithm_t algorithm, uint64_t seed); /* Set PRNG algorithm */ +extern MAG_EXPORT mag_compute_device_type_t mag_ctx_get_compute_device_type(const mag_ctx_t* ctx); /* Get compute device type */ +extern MAG_EXPORT const char* mag_ctx_get_compute_device_name(const mag_ctx_t* ctx); /* Get the name of the compute device */ +extern MAG_EXPORT const char* mag_ctx_get_os_name(const mag_ctx_t* ctx); /* Get the name of the operating system */ +extern MAG_EXPORT const char* mag_ctx_get_cpu_name(const mag_ctx_t* ctx); /* Get the name of the CPU */ +extern MAG_EXPORT uint32_t mag_ctx_get_cpu_virtual_cores(const mag_ctx_t* ctx); /* Get the number of virtual cores */ +extern MAG_EXPORT uint32_t mag_ctx_get_cpu_physical_cores(const mag_ctx_t* ctx); /* Get the number of physical cores */ +extern MAG_EXPORT uint32_t mag_ctx_get_cpu_sockets(const mag_ctx_t* ctx); /* Get the number of CPU sockets */ +extern MAG_EXPORT uint64_t mag_ctx_get_physical_memory_total(const mag_ctx_t* ctx); /* Get the total physical memory in bytes */ +extern MAG_EXPORT uint64_t mag_ctx_get_physical_memory_free(const mag_ctx_t* ctx); /* Get the free physical memory in bytes */ +extern MAG_EXPORT bool mag_ctx_is_numa_system(const mag_ctx_t* ctx); /* Check if the system is NUMA */ +extern MAG_EXPORT size_t mag_ctx_get_total_tensors_created(const mag_ctx_t* ctx); /* Get total tensors created. (Including views) */ +extern MAG_EXPORT void mag_ctx_profile_start_recording(mag_ctx_t* ctx); /* Start profiling */ +extern MAG_EXPORT void mag_ctx_profile_stop_recording(mag_ctx_t* ctx, const char* export_csv_file); /* Reset profiling data */ +extern MAG_EXPORT void mag_ctx_destroy(mag_ctx_t* ctx); /* Destroy context and free memory */ + +/** + * @brief Multidimensional tensor of arbitrary rank and data type. + * The tensor is reference counted and can be shared between multiple tensors. + * Rule of Thumb for Reference Counting: + * - If you only use the reference temporarily and do not store it, no need to adjust the reference count. + * - If you store the reference (e.g., in a data structure), increase the reference count when storing and decrease it when removing. + * The rank is > 0 and <= MAG_MAX_DIMS. The shape of the tensor is an array of dimensions of size MAG_MAX_DIMS. + * Is a node in a static or dynamic computation graph, depending on the context execution mode. 
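+ * Example (sketch; `cache` is an illustrative user-side structure):
+ *   cache->t = t;
+ *   mag_tensor_incref(t);        Storing the reference -> increment.
+ *   ...
+ *   mag_tensor_decref(cache->t); Removing it again -> decrement.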
+ */ +typedef struct mag_tensor_t mag_tensor_t; + +typedef enum mag_dtype_t { + MAG_DTYPE_F32, /* 32-bit floating-point data type */ + MAG_DTYPE__NUM /* Total number of data types */ +} mag_dtype_t; +mag_static_assert(MAG_DTYPE__NUM <= 0xff); + +typedef struct mag_dtype_meta_t { + int64_t size; /* Size of the data type in bytes */ + const char* name; /* Name of the data type */ +} mag_dtype_meta_t; +extern MAG_EXPORT const mag_dtype_meta_t* mag_dtype_meta_of(mag_dtype_t type); + +#define MAG_SEP , + +typedef enum mag_op_t { + MAG_OP_NOP, + MAG_OP_CLONE, + MAG_OP_VIEW, + MAG_OP_TRANSPOSE, + MAG_OP_PERMUTE, + MAG_OP_MEAN, + MAG_OP_MIN, + MAG_OP_MAX, + MAG_OP_SUM, + MAG_OP_ABS, + MAG_OP_NEG, + MAG_OP_LOG, + MAG_OP_SQR, + MAG_OP_SQRT, + MAG_OP_SIN, + MAG_OP_COS, + MAG_OP_STEP, + MAG_OP_SOFTMAX, + MAG_OP_SOFTMAX_DV, + MAG_OP_SIGMOID, + MAG_OP_SIGMOID_DV, + MAG_OP_HARD_SIGMOID, + MAG_OP_SILU, + MAG_OP_SILU_DV, + MAG_OP_TANH, + MAG_OP_TANH_DV, + MAG_OP_RELU, + MAG_OP_RELU_DV, + MAG_OP_GELU, + MAG_OP_GELU_DV, + MAG_OP_ADD, + MAG_OP_SUB, + MAG_OP_MUL, + MAG_OP_DIV, + MAG_OP_ADDS, + MAG_OP_SUBS, + MAG_OP_MULS, + MAG_OP_DIVS, + MAG_OP_MATMUL, + MAG_OP__NUM +} mag_op_t; +mag_static_assert(MAG_OP_NOP == 0); +mag_static_assert(MAG_OP_MATMUL+1 == MAG_OP__NUM); +mag_static_assert(MAG_OP__NUM <= 0xff); + +typedef enum mag_op_param_type_t { + MAG_OP_TPARAM_NONE = 0, + MAG_OP_TPARAM_F32 = 1, + MAG_OP_TPARAM_I32 = 2, + MAG_OP_TPARAM_U32 = 3, + + MAG_OP_TPARAM__NUM +} mag_op_param_type_t; + +typedef struct mag_op_param_t { + mag_op_param_type_t type : 8; /* Parameter type */ + union { + float f32; + int32_t i32; + uint32_t u32; + } x; +} mag_op_param_t; + +typedef struct mag_op_meta_t { + const char* mnemonic; /* Operation mnemonic */ + uint8_t argcount; /* Number of arguments */ + uint8_t paramcount; /* Number of parameters */ + mag_op_param_type_t param_types[MAG_MAX_OP_PARAMS]; /* Parameter types */ + bool inplace; /* Supports inplace execution */ + mag_tensor_t* (*r_alloc)(mag_tensor_t**, const mag_op_param_t*); + bool (*validator)(mag_op_t, mag_tensor_t*, mag_tensor_t**, const mag_op_param_t*); +} mag_op_meta_t; +extern MAG_EXPORT const mag_op_meta_t* mag_op_meta_of(mag_op_t type); + +extern MAG_EXPORT uint32_t mag_pack_color_u8(uint8_t r, uint8_t g, uint8_t b); +extern MAG_EXPORT uint32_t mag_pack_color_f32(float r, float g, float b); + +typedef enum mag_graph_eval_order_t { + MAG_GRAPH_EVAL_ORDER_FORWARD = 0, /* Evaluate graph from left to right */ + MAG_GRAPH_EVAL_ORDER_REVERSE = 1 /* Evaluate graph from right to left */ +} mag_graph_eval_order_t; + +/** + * @brief Create a new 1-dimensional tensor. + * Data is uninitialized, should be filled with values before using it. + * @param ctx Context to create the tensor in. Must not be NULL. + * @param type Data type of the tensor. Must be a valid mag_dtype_t. + * @param d1 Size of the first dimension. Must be > 0 and < INT64_MAX. + * @returns New tensor. Is never NULL. + */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_create_1d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1); + +/** + * @brief Create a new 2-dimensional tensor. + * Data is uninitialized, should be filled with values before using it. + * @param ctx Context to create the tensor in. Must not be NULL. + * @param type Data type of the tensor. Must be a valid mag_dtype_t. + * @param d1 Size of the first dimension. Must be > 0 and < INT64_MAX. + * @param d2 Size of the second dimension. Must be > 0 and < INT64_MAX. + * @returns New tensor. Is never NULL. 
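+ * Example (sketch): a 3x4 F32 matrix, zero-filled before first use:
+ *   mag_tensor_t* m = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 3, 4);
+ *   mag_tensor_fill(m, 0.0f);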
+ */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_create_2d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2); + +/** + * @brief Create a new 3-dimensional tensor. + * Data is uninitialized, should be filled with values before using it. + * @param ctx Context to create the tensor in. Must not be NULL. + * @param type Data type of the tensor. Must be a valid mag_dtype_t. + * @param d1 Size of the first dimension. Must be > 0 and < INT64_MAX. + * @param d2 Size of the second dimension. Must be > 0 and < INT64_MAX. + * @param d3 Size of the third dimension. Must be > 0 and < INT64_MAX. + * @returns New tensor. Is never NULL. + */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_create_3d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3); + +/** + * @brief Create a new 4-dimensional tensor. + * Data is uninitialized, should be filled with values before using it. + * @param ctx Context to create the tensor in. Must not be NULL. + * @param type Data type of the tensor. Must be a valid mag_dtype_t. + * @param d1 Size of the first dimension. Must be > 0 and < INT64_MAX. + * @param d2 Size of the second dimension. Must be > 0 and < INT64_MAX. + * @param d3 Size of the third dimension. Must be > 0 and < INT64_MAX. + * @param d4 Size of the fourth dimension. Must be > 0 and < INT64_MAX. + * @returns New tensor. Is never NULL. + */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_create_4d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4); + +/** + * @brief Create a new 5-dimensional tensor. + * Data is uninitialized, should be filled with values before using it. + * @param ctx Context to create the tensor in. Must not be NULL. + * @param type Data type of the tensor. Must be a valid mag_dtype_t. + * @param d1 Size of the first dimension. Must be > 0 and < INT64_MAX. + * @param d2 Size of the second dimension. Must be > 0 and < INT64_MAX. + * @param d3 Size of the third dimension. Must be > 0 and < INT64_MAX. + * @param d4 Size of the fourth dimension. Must be > 0 and < INT64_MAX. + * @param d5 Size of the fifth dimension. Must be > 0 and < INT64_MAX. + * @returns New tensor. Is never NULL. + */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_create_5d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5); + +/** + * @brief Create a new 6-dimensional tensor. + * Data is uninitialized, should be filled with values before using it. + * @param ctx Context to create the tensor in. Must not be NULL. + * @param type Data type of the tensor. Must be a valid mag_dtype_t. + * @param d1 Size of the first dimension. Must be > 0 and < INT64_MAX. + * @param d2 Size of the second dimension. Must be > 0 and < INT64_MAX. + * @param d3 Size of the third dimension. Must be > 0 and < INT64_MAX. + * @param d4 Size of the fourth dimension. Must be > 0 and < INT64_MAX. + * @param d5 Size of the fifth dimension. Must be > 0 and < INT64_MAX. + * @param d6 Size of the sixth dimension. Must be > 0 and < INT64_MAX. + * @returns New tensor. Is never NULL. + */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_create_6d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5, int64_t d6); + +/* Tensor operation functions. 
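+   Naming convention: mag_relu(x) allocates and returns a new result tensor,
+   while the trailing-underscore form mag_relu_(x) is the in-place variant
+   operating on x itself, matching the `inplace` capability flag in
+   mag_op_meta_t. The same pairing applies to the other unary and binary
+   operators below.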
*/ + +extern MAG_EXPORT mag_tensor_t* mag_clone(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_view(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_transpose(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_permute(mag_tensor_t* x, uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3, uint32_t d4, uint32_t d5); +extern MAG_EXPORT mag_tensor_t* mag_mean(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_min(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_max(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sum(mag_tensor_t* x); + +extern MAG_EXPORT mag_tensor_t* mag_abs(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_abs_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_neg(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_neg_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_log(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_log_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sqr(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sqr_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sqrt(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sqrt_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sin(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sin_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_cos(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_cos_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_step(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_step_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_softmax(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_softmax_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_softmax_dv(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_softmax_dv_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sigmoid(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sigmoid_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sigmoid_dv(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_sigmoid_dv_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_hard_sigmoid(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_hard_sigmoid_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_silu(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_silu_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_silu_dv(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_silu_dv_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_tanh(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_tanh_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_tanh_dv(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_tanh_dv_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_relu(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_relu_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_relu_dv(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_relu_dv_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_gelu(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_gelu_(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_gelu_dv(mag_tensor_t* x); +extern MAG_EXPORT mag_tensor_t* mag_gelu_dv_(mag_tensor_t* x); + +extern MAG_EXPORT mag_tensor_t* mag_add(mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_add_(mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_sub( mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_sub_(mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_mul( mag_tensor_t* x, 
mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_mul_(mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_div(mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_div_(mag_tensor_t* x, mag_tensor_t* y); +extern MAG_EXPORT mag_tensor_t* mag_adds(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_adds_(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_subs( mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_subs_(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_muls(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_muls_(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_divs(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_divs_(mag_tensor_t* x, float xi); +extern MAG_EXPORT mag_tensor_t* mag_matmul(mag_tensor_t* a, mag_tensor_t* b); + +/** + * @brief Increment reference count of tensor. + * Increment the strong reference count of the tensor. The tensor is not destroyed until the strong reference count reaches zero. + * Rule of Thumb for Reference Counting: + * - If you only use the reference temporarily and do not store it, no need to adjust the reference count. + * - If you store the reference (e.g., in a data structure), increase the reference count when storing and decrease it when removing. + * @param t Tensor. Must not be NULL. + */ +extern MAG_EXPORT void mag_tensor_incref(mag_tensor_t* t); + +/** + * @brief Decrement reference count of tensor. + * Decrement the strong reference count of the tensor. The tensor is destroyed when the strong reference count reaches zero. + * Rule of Thumb for Reference Counting: + * - If you only use the reference temporarily and do not store it, no need to adjust the reference count. + * - If you store the reference (e.g., in a data structure), increase the reference count when storing and decrease it when removing. + * @param t Tensor. Must not be NULL. + * @returns True if the tensor was destroyed, false if the tensor is still alive. + */ +extern MAG_EXPORT bool mag_tensor_decref(mag_tensor_t* t); + +extern MAG_EXPORT void mag_tensor_copy_buffer_from(mag_tensor_t* t, const void* data, size_t size); /* Copy data into tensor buffer */ +extern MAG_EXPORT void mag_tensor_fill(mag_tensor_t* t, float x); /* Set all tensor elements to a specific value */ +extern MAG_EXPORT void mag_tensor_fill_random_uniform(mag_tensor_t* t, float min, float max); /* Fill tensor with random values from uniform distribution within [min, max] */ +extern MAG_EXPORT void mag_tensor_fill_random_normal(mag_tensor_t* t, float mean, float stddev); /* Fill tensor with random values from the normal distribution. */ + +extern MAG_EXPORT uint64_t mag_tensor_get_packed_refcounts(const mag_tensor_t* t); /* Return strong refcount is loword, weak refcount is hiword. */ +extern MAG_EXPORT void mag_tensor_retain(mag_tensor_t* t); /* Increment refcount */ +extern MAG_EXPORT size_t mag_tensor_get_memory_usage(const mag_tensor_t* t); /* Return memory used by this tensor in bytes. 
*/ +extern MAG_EXPORT void mag_tensor_print(const mag_tensor_t* t, bool with_header, bool with_data); /* Print tensor info (with or without data) */ +extern MAG_EXPORT void mag_tensor_set_name(mag_tensor_t* t, const char* name); /* Set the name of the tensor */ +extern MAG_EXPORT void mag_tensor_fmt_name(mag_tensor_t* t, const char* fmt, ...); /* Format the name of the tensor */ +extern MAG_EXPORT const char* mag_tensor_get_name(const mag_tensor_t* t); /* Get the name of the tensor */ +extern MAG_EXPORT int64_t mag_tensor_rank(const mag_tensor_t* t); /* Get the rank (number of dimensions) of the tensor */ +extern MAG_EXPORT const int64_t* mag_tensor_shape(const mag_tensor_t* t); /* Get the dimensions of the tensor */ +extern MAG_EXPORT const int64_t* mag_tensor_strides(const mag_tensor_t* t); /* Get the strides of the tensor */ +extern MAG_EXPORT mag_dtype_t mag_tensor_dtype(const mag_tensor_t* t); /* Get the data type of the tensor */ +extern MAG_EXPORT void* mag_tensor_data_ptr(const mag_tensor_t* t); /* Get the tensor raw buffer pointer. Might pointer to GPU or any other device memory. */ +extern MAG_EXPORT int64_t mag_tensor_data_size(const mag_tensor_t* t); /* Get the size of the tensor buffer in bytes. */ +extern MAG_EXPORT int64_t mag_tensor_numel(const mag_tensor_t* t); /* Get the total amount of elements in the tensor. */ +extern MAG_EXPORT int64_t mag_tensor_num_rows(const mag_tensor_t* t); /* Get the number of rows (for 2D tensors) */ +extern MAG_EXPORT int64_t mag_tensor_num_cols(const mag_tensor_t* t); /* Get the number of columns (for 2D tensors) */ +extern MAG_EXPORT bool mag_tensor_is_scalar(const mag_tensor_t* t); /* Check if the tensor is a scalar */ +extern MAG_EXPORT bool mag_tensor_is_vector(const mag_tensor_t* t); /* Check if the tensor is a vector */ +extern MAG_EXPORT bool mag_tensor_is_matrix(const mag_tensor_t* t); /* Check if the tensor is a matrix */ +extern MAG_EXPORT bool mag_tensor_is_volume(const mag_tensor_t* t); /* Check if the tensor is higher-order (3D or more) */ +extern MAG_EXPORT bool mag_tensor_is_shape_eq(const mag_tensor_t* a, const mag_tensor_t* b); /* Checks if a and b have the same shape. */ +extern MAG_EXPORT bool mag_tensor_are_strides_eq(const mag_tensor_t* a, const mag_tensor_t* b); /* Checks if a and b have the same strides. */ +extern MAG_EXPORT bool mag_tensor_can_broadcast(const mag_tensor_t* a, const mag_tensor_t* b); /* Checks if b can be broadcasted into a. 
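+   Broadcast semantics follow the usual dimension-wise rule: each dimension
+   of b must match the corresponding dimension of a or be 1, e.g. a (1, N)
+   bias row can be broadcast across an (M, N) matrix.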
*/ +extern MAG_EXPORT bool mag_tensor_is_transposed(const mag_tensor_t* t); /* Check if the tensor is transposed */ +extern MAG_EXPORT bool mag_tensor_is_permuted(const mag_tensor_t* t); /* Check if the tensor is permuted */ +extern MAG_EXPORT bool mag_tensor_is_contiguous(const mag_tensor_t* t); /* Check if the tensor memory is contiguous */ +extern MAG_EXPORT float mag_tensor_get_scalar_physical_index(mag_tensor_t* t, int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5); /* Get scalar value at physical index */ +extern MAG_EXPORT void mag_tensor_set_scalar_physical_index(mag_tensor_t* t, int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5, float x); /* Set scalar value at physical index */ +extern MAG_EXPORT float mag_tensor_get_scalar_virtual_index(mag_tensor_t* t, int64_t v_idx); /* Get scalar value at virtual index */ +extern MAG_EXPORT void mag_tensor_set_scalar_virtual_index(mag_tensor_t* t, int64_t v_idx, float x); /* Set scalar value at virtual index */ +extern MAG_EXPORT bool mag_tensor_eq(const mag_tensor_t* a, const mag_tensor_t* b); /* Check if two tensors are equal without epsilon. */ +extern MAG_EXPORT bool mag_tensor_is_close(const mag_tensor_t* a, const mag_tensor_t* b, float eps, double* percent_eq); /* Check if two tensors are equal with epsilon and percentage in equality. Set eps to < 0 to use machine epsilon. */ +extern MAG_EXPORT void mag_tensor_img_draw_box(mag_tensor_t* t, int32_t x1, int32_t y1, int32_t x2, int32_t y2, int32_t wi, uint32_t rgb); +extern MAG_EXPORT void mag_tensor_img_draw_text(mag_tensor_t* t, int32_t x, int32_t y, int32_t size, uint32_t rgb, const char* txt); /* Draw text on image tensor */ +extern MAG_EXPORT mag_ctx_t* mag_tensor_get_ctx(const mag_tensor_t* t); /* Get the context of the tensor */ +extern MAG_EXPORT void* mag_tensor_get_user_data(const mag_tensor_t* t); /* Get the user data of the tensor */ +extern MAG_EXPORT void mag_tensor_set_user_data(mag_tensor_t* t, void* ud); /* Set the user data of the tensor */ +extern MAG_EXPORT void mag_tensor_save(const mag_tensor_t* t, const char* file); /* Save tensor to magnetron binary file. */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_load(mag_ctx_t* ctx, const char* file); /* Load tensor from magnetron binary file. */ +extern MAG_EXPORT mag_tensor_t* mag_tensor_load_image(mag_ctx_t* ctx, const char* file, mag_color_channels_t channels, uint32_t resize_w, uint32_t resize_h); /* Create a tensor from an image file. */ +extern MAG_EXPORT void mag_tensor_save_image(const mag_tensor_t* t, const char* file); /* Save tensor data as an image */ +#define mag_tensor_image_width(tensor) (mag_tensor_shape(tensor)[2]) /* Get image width from tensor */ +#define mag_tensor_image_height(tensor) (mag_tensor_shape(tensor)[1]) /* Get image height from tensor */ +#define mag_tensor_image_channels(tensor) (mag_tensor_shape(tensor)[0]) /* Get image channels from tensor */ + +#ifdef __cplusplus +} +#endif +#endif diff --git a/magnetron/magnetron_cpu.c b/magnetron/magnetron_cpu.c new file mode 100644 index 0000000..8ab0266 --- /dev/null +++ b/magnetron/magnetron_cpu.c @@ -0,0 +1,2344 @@ +/* (c) 2024 Mario "Neo" Sieg. 
*/ + +#include "magnetron_internal.h" + +#include +#include + +#if defined(__APPLE__) && defined(MAG_ACCELERATE) +#include +#endif + +typedef struct mag_cpu_thread_local_ctx_t { + mag_ctx_t* ctx; + int64_t thread_num; + int64_t thread_idx; +} mag_cpu_thread_local_ctx_t; + +#define mag_f32p(t) ((const float*)(t)->storage.base) +#define mag_f32p_mut(t) ((float*)(t)->storage.base) + +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + +static float32x4_t mag_simd_expf(float32x4_t x) { /* exp(x) : ℝ -> (0, ∞), x |-> e^x. Error = 1.45358 + 0.5 ulps. x > 88.38 -> INF, x < -103.97 -> 0 */ + float32x4_t r = vdupq_n_f32(0x1.8p23f); + float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); + float32x4_t n = vsubq_f32(z, r); + float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, vdupq_n_f32(0x1.7f7d1cp-20f)); + uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); + float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); + uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); + float32x4_t u = vmulq_f32(b, b); + float32x4_t j = vfmaq_f32( + vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), + vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), + vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); + if (!vpaddd_u64(vreinterpretq_u64_u32(c))) return vfmaq_f32(k, j, k); + uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); + float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); + float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); + return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), + vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); +} + +static float32x4_t mag_simd_tanh(float32x4_t x) { /* tanh' : ℝ -> (-1, 1), x |-> 1 / ((cosh x)^2) */ + float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t neg_one = vdupq_n_f32(-1.0f); + float32x4_t two = vdupq_n_f32(2.0f); + float32x4_t neg_two = vdupq_n_f32(-2.0f); + float32x4_t a = vmulq_f32(neg_two, x); + float32x4_t b = mag_simd_expf(a); + float32x4_t c = vaddq_f32(one, b); + float32x4_t inv = vrecpeq_f32(c); + inv = vmulq_f32(vrecpsq_f32(c, inv), inv); /* Newton–Raphson method */ + inv = vmulq_f32(vrecpsq_f32(c, inv), inv); /* Newton–Raphson method */ + return vaddq_f32(neg_one, vmulq_f32(two, inv)); +} + +static void mag_simd_sincos(float32x4_t x, float32x4_t *osin, float32x4_t *ocos) { + uint32x4_t sign_mask_sin = vcltq_f32(x, vdupq_n_f32(0)); + x = vabsq_f32(x); + float32x4_t y = vmulq_f32(x, vdupq_n_f32(1.27323954473516f)); + uint32x4_t emm2 = vcvtq_u32_f32(y); + emm2 = vaddq_u32(emm2, vdupq_n_u32(1)); + emm2 = vandq_u32(emm2, vdupq_n_u32(~1)); + y = vcvtq_f32_u32(emm2); + uint32x4_t poly_mask = vtstq_u32(emm2, vdupq_n_u32(2)); + x = vmlaq_f32(x, y, vdupq_n_f32(-0.78515625f)); + x = vmlaq_f32(x, y, vdupq_n_f32(-2.4187564849853515625e-4f)); + x = vmlaq_f32(x, y, vdupq_n_f32(-3.77489497744594108e-8f)); + sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4))); + uint32x4_t sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4)); + float32x4_t z = vmulq_f32(x, x); + float32x4_t y1, y2; + y1 = vmlaq_f32(vdupq_n_f32(-1.388731625493765e-003f), z, vdupq_n_f32(2.443315711809948e-005f)); + y2 = vmlaq_f32(vdupq_n_f32(8.3321608736e-3f), z, vdupq_n_f32(-1.9515295891e-4f)); + y1 = vmlaq_f32(vdupq_n_f32(4.166664568298827e-002f), y1, z); + y2 = vmlaq_f32(vdupq_n_f32(-1.6666654611e-1f), y2, z); + y1 = 
vmulq_f32(y1, z); + y2 = vmulq_f32(y2, z); + y1 = vmulq_f32(y1, z); + y1 = vmlsq_f32(y1, z, vdupq_n_f32(0.5f)); + y2 = vmlaq_f32(x, y2, x); + y1 = vaddq_f32(y1, vdupq_n_f32(1)); + float32x4_t ys = vbslq_f32(poly_mask, y1, y2); + float32x4_t yc = vbslq_f32(poly_mask, y2, y1); + *osin = vbslq_f32(sign_mask_sin, vnegq_f32(ys), ys); + *ocos = vbslq_f32(sign_mask_cos, yc, vnegq_f32(yc)); +} + +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) + +static __m512 mag_simd_expf(const __m512 x) { /* exp(x) : ℝ -> (0, ∞), x |-> e^x. Error = 1.45358 + 0.5 ulps. x > 88.38 -> INF, x < -103.97 -> 0 */ + __m512 r = _mm512_set1_ps(0x1.8p23f); + __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); + __m512 n = _mm512_sub_ps(z, r); + __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); + __mmask16 d = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); + __m512 u = _mm512_mul_ps(b, b); + __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, _mm512_set1_ps(0x1.573e2ep-5f)), u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, _mm512_set1_ps(0x1.fffdb6p-2f))), u, _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)) + ); + __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) return res; + __m512 zero = _mm512_setzero_ps(); + __m512 alt = _mm512_mask_blend_ps(_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); +} + +static __m512 mag_simd_tanh(__m512 x) { /* tanh' : ℝ -> (-1, 1), x |-> 1 / ((cosh x)^2) */ + __m512 one = _mm512_set1_ps(1.0f); + __m512 neg_one = _mm512_set1_ps(-1.0f); + __m512 two = _mm512_set1_ps(2.0f); + __m512 neg_two = _mm512_set1_ps(-2.0f); + __m512 a = _mm512_mul_ps(neg_two, x); + __m512 b = mag_simd_expf(a); + __m512 c = _mm512_add_ps(one, b); + __m512 inv = _mm512_rcp14_ps(c); + inv = _mm512_mul_ps(_mm512_rcp14_ps(_mm512_mul_ps(c, inv)), inv); /* Newton–Raphson method */ + inv = _mm512_mul_ps(_mm512_rcp14_ps(_mm512_mul_ps(c, inv)), inv); /* Newton–Raphson method */ + return _mm512_fmadd_ps(two, inv, neg_one); +} + +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) + +static __m256 mag_simd_expf(const __m256 x) { /* exp(x) : ℝ -> (0, ∞), x |-> e^x. Error = 1.45358 + 0.5 ulps. 
x > 88.38 -> INF, x < -103.97 -> 0 */ + __m256 r = _mm256_set1_ps(0x1.8p23f); + __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); + __m256 n = _mm256_sub_ps(z, r); + __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),_mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); + __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); + __m256 k = _mm256_castsi256_ps(_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); + __m256i c = _mm256_castps_si256(_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), _mm256_set1_ps(126), _CMP_GT_OQ)); + __m256 u = _mm256_mul_ps(b, b); + __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,_mm256_set1_ps(0x1.573e2ep-5f)), u,_mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,_mm256_set1_ps(0x1.fffdb6p-2f))),u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); + if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) return _mm256_fmadd_ps(j, k, k); + __m256i g = _mm256_and_si256(_mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),_mm256_set1_epi32(0x82000000u)); + __m256 s1 = _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); + __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); + __m256i d = _mm256_castps_si256(_mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), _mm256_set1_ps(192), _CMP_GT_OQ)); + return _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), + _mm256_andnot_ps( + _mm256_castsi256_ps(d), + _mm256_or_ps( + _mm256_and_ps(_mm256_castsi256_ps(c), + _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), + _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))) + ); +} + +static __m256 mag_simd_tanh(__m256 x) { /* tanh' : ℝ -> (-1, 1), x |-> 1 / ((cosh x)^2) */ + __m256 one = _mm256_set1_ps(1.0f); + __m256 neg_one = _mm256_set1_ps(-1.0f); + __m256 two = _mm256_set1_ps(2.0f); + __m256 neg_two = _mm256_set1_ps(-2.0f); + __m256 a = _mm256_mul_ps(neg_two, x); + __m256 b = mag_simd_expf(a); + __m256 c = _mm256_add_ps(one, b); + __m256 inv = _mm256_rcp_ps(c); + inv = _mm256_mul_ps(_mm256_rcp_ps(_mm256_mul_ps(c, inv)), inv); /* Newton–Raphson method */ + inv = _mm256_mul_ps(_mm256_rcp_ps(_mm256_mul_ps(c, inv)), inv); /* Newton–Raphson method */ + return _mm256_fmadd_ps(two, inv, neg_one); +} + +#elif MAG_APPROXMATH && defined(__SSE2__) +static __m128 mag_simd_expf(const __m128 x) { /* exp(x) : ℝ -> (0, ∞), x |-> e^x. Error = 1.45358 + 0.5 ulps. 
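+
+/* All mag_simd_expf variants share one scheme: n = round(x*log2(e)) is formed
+   via the 0x1.8p23f rounding trick, the reduced argument b = x - n*ln(2) uses
+   a two-term (hi/lo) split of ln(2), a short polynomial approximates e^b, and
+   the result is rescaled by 2^n through integer exponent arithmetic; lanes
+   with |n| out of range take the masked overflow/underflow path. */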
+static __m128 mag_simd_expf(const __m128 x) { /* exp(x) : ℝ -> (0, ∞), x |-> e^x. Error = 1.45358 + 0.5 ulps. x > 88.38 -> INF, x < -103.97 -> 0 */
+    __m128 r = _mm_set1_ps(0x1.8p23f);
+    __m128 z = _mm_add_ps(_mm_mul_ps(x, _mm_set1_ps(0x1.715476p+0f)), r);
+    __m128 n = _mm_sub_ps(z, r);
+    __m128 b = _mm_sub_ps(_mm_sub_ps(x, _mm_mul_ps(n, _mm_set1_ps(0x1.62e4p-1f))), _mm_mul_ps(n, _mm_set1_ps(0x1.7f7d1cp-20f)));
+    __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
+    __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
+    __m128i c = _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
+    __m128 u = _mm_mul_ps(b, b);
+    __m128 j = _mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_add_ps(_mm_mul_ps(_mm_set1_ps(0x1.0e4020p-7f), b), _mm_set1_ps(0x1.573e2ep-5f)), u),
+        _mm_add_ps(_mm_mul_ps(_mm_set1_ps(0x1.555e66p-3f), b), _mm_set1_ps(0x1.fffdb6p-2f))), u),
+        _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
+    if (!_mm_movemask_epi8(c)) return _mm_add_ps(_mm_mul_ps(j, k), k);
+    __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), _mm_set1_epi32(0x82000000u));
+    __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
+    __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
+    __m128i d = _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
+    return _mm_or_ps(
+        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
+        _mm_andnot_ps(_mm_castsi128_ps(d),
+            _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(_mm_add_ps(_mm_mul_ps(s2, j), s2), s1)),
+            _mm_andnot_ps(_mm_castsi128_ps(c), _mm_add_ps(_mm_mul_ps(k, j), k))))
+    );
+}
+
+static __m128 mag_simd_tanh(__m128 x) { /* tanh : ℝ -> (-1, 1), x |-> tanh x */
+    __m128 one = _mm_set1_ps(1.0f);
+    __m128 neg_one = _mm_set1_ps(-1.0f);
+    __m128 two = _mm_set1_ps(2.0f);
+    __m128 neg_two = _mm_set1_ps(-2.0f);
+    __m128 a = _mm_mul_ps(neg_two, x);
+    __m128 b = mag_simd_expf(a);
+    __m128 c = _mm_add_ps(one, b);
+    __m128 inv = _mm_rcp_ps(c);
+    inv = _mm_mul_ps(_mm_rcp_ps(_mm_mul_ps(c, inv)), inv); /* Newton–Raphson method */
+    inv = _mm_mul_ps(_mm_rcp_ps(_mm_mul_ps(c, inv)), inv); /* Newton–Raphson method */
+    return _mm_add_ps(neg_one, _mm_mul_ps(two, inv));
+}
+
+static void mag_simd_sincos(__m128 x, __m128 *osin, __m128 *ocos) {
+    __m128 sign_mask_sin_ps = _mm_cmplt_ps(x, _mm_set1_ps(0.0f));
+    __m128i sign_mask_sin = _mm_castps_si128(sign_mask_sin_ps);
+    x = _mm_and_ps(x, _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)));
+    __m128 y = _mm_mul_ps(x, _mm_set1_ps(1.27323954473516f));
+    __m128i emm2 = _mm_cvtps_epi32(y);
+    emm2 = _mm_add_epi32(emm2, _mm_set1_epi32(1));
+    emm2 = _mm_and_si128(emm2, _mm_set1_epi32(~1));
+    y = _mm_cvtepi32_ps(emm2);
+    __m128i poly_mask = _mm_cmpeq_epi32(emm2, _mm_set1_epi32(2));
+    x = _mm_add_ps(x, _mm_mul_ps(y, _mm_set1_ps(-0.78515625f)));
+    x = _mm_add_ps(x, _mm_mul_ps(y, _mm_set1_ps(-2.4187564849853515625e-4f)));
+    x = _mm_add_ps(x, _mm_mul_ps(y, _mm_set1_ps(-3.77489497744594108e-8f)));
+    __m128i tmp = _mm_cmpeq_epi32(emm2, _mm_set1_epi32(4));
+    sign_mask_sin = _mm_xor_si128(sign_mask_sin, tmp);
+    __m128i sign_mask_cos = _mm_cmpeq_epi32(_mm_sub_epi32(emm2, _mm_set1_epi32(2)), _mm_set1_epi32(4));
+    __m128 z = _mm_mul_ps(x, x);
+    __m128 y1 = _mm_add_ps(_mm_set1_ps(-1.388731625493765e-003f), _mm_mul_ps(z, _mm_set1_ps(2.443315711809948e-005f)));
+    __m128 y2 = _mm_add_ps(_mm_set1_ps(8.3321608736e-3f), _mm_mul_ps(z, _mm_set1_ps(-1.9515295891e-4f)));
+    y1 = _mm_add_ps(_mm_set1_ps(4.166664568298827e-002f), _mm_mul_ps(y1, z));
+    y2 = _mm_add_ps(_mm_set1_ps(-1.6666654611e-1f), _mm_mul_ps(y2, z));
+    y1 =
_mm_mul_ps(y1, z); + y2 = _mm_mul_ps(y2, z); + y1 = _mm_mul_ps(y1, z); + y1 = _mm_sub_ps(y1, _mm_mul_ps(z, _mm_set1_ps(0.5f))); + y2 = _mm_add_ps(x, _mm_mul_ps(y2, x)); + y1 = _mm_add_ps(y1, _mm_set1_ps(1.0f)); + __m128 poly_mask_ps = _mm_castsi128_ps(poly_mask); + __m128 ys = _mm_or_ps(_mm_and_ps(poly_mask_ps, y1), _mm_andnot_ps(poly_mask_ps, y2)); + __m128 yc = _mm_or_ps(_mm_and_ps(poly_mask_ps, y2), _mm_andnot_ps(poly_mask_ps, y1)); + __m128 sign_mask_sin_ps2 = _mm_castsi128_ps(sign_mask_sin); + __m128 neg_ys = _mm_sub_ps(_mm_setzero_ps(), ys); + __m128 osin_ps = _mm_or_ps(_mm_and_ps(sign_mask_sin_ps2, neg_ys), _mm_andnot_ps(sign_mask_sin_ps2, ys)); + __m128 sign_mask_cos_ps = _mm_castsi128_ps(sign_mask_cos); + __m128 neg_yc = _mm_sub_ps(_mm_setzero_ps(), yc); + __m128 ocos_ps = _mm_or_ps(_mm_and_ps(sign_mask_cos_ps, yc), _mm_andnot_ps(sign_mask_cos_ps, neg_yc)); + *osin = osin_ps; + *ocos = ocos_ps; +} + +#endif + +static void MAG_HOTPROC mag_vadd_f32( + const int64_t n, + float* const o, + const float* const x, + const float* const y +) { +#ifdef MAG_ACCELERATE + vDSP_vadd(y, 1, x, 1, o, 1, n); +#else + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] + y[i]; + } +#endif +} + +static void MAG_HOTPROC mag_vsub_f32( + const int64_t n, + float* const o, + const float* const x, + const float* const y +) { +#ifdef MAG_ACCELERATE + vDSP_vsub(y, 1, x, 1, o, 1, n); +#else + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] - y[i]; + } +#endif +} + +static void MAG_HOTPROC mag_vmul_f32( + const int64_t n, + float* const o, + const float* const x, + const float* const y +) { +#ifdef MAG_ACCELERATE + vDSP_vmul(y, 1, x, 1, o, 1, n); +#else + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] * y[i]; + } +#endif +} + +static void MAG_HOTPROC mag_vdiv_f32( + const int64_t n, + float* const o, + const float* const x, + const float* const y +) { +#ifdef MAG_ACCELERATE + vDSP_vdiv(y, 1, x, 1, o, 1, n); +#else + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] / y[i]; + } +#endif +} + +static void MAG_HOTPROC mag_vadds_f32( + const int64_t n, + float* const o, + const float* const x, + const float y +) { + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] + y; + } +} + +static void MAG_HOTPROC mag_vsubs_f32( + const int64_t n, + float* const o, + const float* const x, + const float y +) { + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] - y; + } +} + +static void MAG_HOTPROC mag_vmuls_f32( + const int64_t n, + float* const o, + const float* const x, + const float y +) { + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] * y; + } +} + +static void MAG_HOTPROC mag_vdivs_f32( + const int64_t n, + float* const o, + const float* const x, + const float y +) { + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] / y; + } +} + +static float MAG_UNUSED MAG_HOTPROC mag_vdot_f32( + const int64_t n, + const float* const x, + const float* const y +) { +#if (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + const int64_t k = n & -16; + float32x4_t acc[4] = {vdupq_n_f32(0)}; + float32x4_t vx[4]; + float32x4_t vy[4]; + for (int64_t i=0; i < k; i += 16) { /* Process STEP elements at a time */ + vx[0] = vld1q_f32(x+i+(0<<2)); + vy[0] = vld1q_f32(y+i+(0<<2)); + acc[0] = vfmaq_f32(acc[0], vx[0], vy[0]); + vx[1] = vld1q_f32(x+i+(1<<2)); + vy[1] = vld1q_f32(y+i+(1<<2)); + acc[1] = vfmaq_f32(acc[1], vx[1], vy[1]); + vx[2] = vld1q_f32(x+i+(2<<2)); + vy[2] = vld1q_f32(y+i+(2<<2)); + acc[2] = vfmaq_f32(acc[2], vx[2], vy[2]); + vx[3] = vld1q_f32(x+i+(3<<2)); + vy[3] = vld1q_f32(y+i+(3<<2)); + acc[3] = vfmaq_f32(acc[3], vx[3], vy[3]); + } 
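+    /* Four independent accumulators break the FMA dependency chain so that
+       several fused multiply-adds stay in flight per iteration; they are
+       folded pairwise below before the final horizontal reduction. */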
+ acc[1] = vaddq_f32(acc[1], acc[3]); /* Fold acc[1] += acc[3] */ + *acc = vaddq_f32(*acc, acc[2]); /* Fold acc[0] += acc[2] */ + *acc = vaddq_f32(*acc, acc[1]); /* Fold acc[0] += acc[1] */ + float sum = vaddvq_f32(*acc); /* Reduce to scalar with horizontal sum. */ + for (int64_t i=k; i < n; ++i) sum += x[i]*y[i]; /* Process leftovers scalar-wise */ + return sum; +#elif defined(__AVX512F__) && defined(__FMA__) + const int64_t k = n & -64; + __m512 acc[4] = {_mm512_setzero_ps()}; + __m512 vx[4]; + __m512 vy[4]; + for (int64_t i=0; i < k; i += 64) { + vx[0] = _mm512_loadu_ps(x+i+(0<<4)); + vy[0] = _mm512_loadu_ps(y+i+(0<<4)); + acc[0] = _mm512_fmadd_ps(vx[0], vy[0], acc[0]); + vx[1] = _mm512_loadu_ps(x+i+(1<<4)); + vy[1] = _mm512_loadu_ps(y+i+(1<<4)); + acc[1] = _mm512_fmadd_ps(vx[1], vy[1], acc[1]); + vx[2] = _mm512_loadu_ps(x+i+(2<<4)); + vy[2] = _mm512_loadu_ps(y+i+(2<<4)); + acc[2] = _mm512_fmadd_ps(vx[2], vy[2], acc[2]); + vx[3] = _mm512_loadu_ps(x+i+(3<<4)); + vy[3] = _mm512_loadu_ps(y+i+(3<<4)); + acc[3] = _mm512_fmadd_ps(vx[3], vy[3], acc[3]); + } + acc[1] = _mm512_add_ps(acc[1], acc[3]); + *acc = _mm512_add_ps(*acc, acc[2]); + *acc = _mm512_add_ps(*acc, acc[1]); + float sum = _mm512_reduce_add_ps(*acc); + for (int64_t i=k; i < n; ++i) sum += x[i]*y[i]; /* Process leftovers scalar-wise */ + return sum; +#elif defined(__AVX__) && defined(__FMA__) + const int64_t k = n & -32; + __m256 acc[4] = {_mm256_setzero_ps()}; + __m256 vx[4]; + __m256 vy[4]; + for (int64_t i=0; i < k; i += 32) { + vx[0] = _mm256_loadu_ps(x+i+(0<<3)); + vy[0] = _mm256_loadu_ps(y+i+(0<<3)); + acc[0] = _mm256_fmadd_ps(vx[0], vy[0], acc[0]); + vx[1] = _mm256_loadu_ps(x+i+(1<<3)); + vy[1] = _mm256_loadu_ps(y+i+(1<<3)); + acc[1] = _mm256_fmadd_ps(vx[1], vy[1], acc[1]); + vx[2] = _mm256_loadu_ps(x+i+(2<<3)); + vy[2] = _mm256_loadu_ps(y+i+(2<<3)); + acc[2] = _mm256_fmadd_ps(vx[2], vy[2], acc[2]); + vx[3] = _mm256_loadu_ps(x+i+(3<<3)); + vy[3] = _mm256_loadu_ps(y+i+(3<<3)); + acc[3] = _mm256_fmadd_ps(vx[3], vy[3], acc[3]); + } + acc[1] = _mm256_add_ps(acc[1], acc[3]); + *acc = _mm256_add_ps(*acc, acc[2]); + *acc = _mm256_add_ps(*acc, acc[1]); + __m128 v0 = _mm_add_ps(_mm256_castps256_ps128(*acc), _mm256_extractf128_ps(*acc, 1)); + v0 = _mm_hadd_ps(v0, v0); + v0 = _mm_hadd_ps(v0, v0); + float sum = _mm_cvtss_f32(v0); + for (int64_t i=k; i < n; ++i) sum += x[i]*y[i]; /* Process leftovers scalar-wise */ + return sum; +#elif defined(__SSE2__) + const int64_t k = n & -16; + __m128 acc[4] = {_mm_setzero_ps()}; + __m128 vx[4]; + __m128 vy[4]; + for (int64_t i=0; i < k; i += 16) { + vx[0] = _mm_loadu_ps(x+i+(0<<2)); + vy[0] = _mm_loadu_ps(y+i+(0<<2)); + acc[0] = _mm_add_ps(acc[0], _mm_mul_ps(vx[0], vy[0])); + vx[1] = _mm_loadu_ps(x+i+(1<<2)); + vy[1] = _mm_loadu_ps(y+i+(1<<2)); + acc[1] = _mm_add_ps(acc[1], _mm_mul_ps(vx[1], vy[1])); + vx[2] = _mm_loadu_ps(x+i+(2<<2)); + vy[2] = _mm_loadu_ps(y+i+(2<<2)); + acc[2] = _mm_add_ps(acc[2], _mm_mul_ps(vx[2], vy[2])); + vx[3] = _mm_loadu_ps(x+i+(3<<2)); + vy[3] = _mm_loadu_ps(y+i+(3<<2)); + acc[3] = _mm_add_ps(acc[3], _mm_mul_ps(vx[3], vy[3])); + } + #ifdef __SSE3__ + acc[1] = _mm_add_ps(acc[1], acc[3]); + *acc = _mm_add_ps(*acc, acc[2]); + *acc = _mm_add_ps(*acc, acc[1]); + *acc = _mm_hadd_ps(*acc, *acc); + *acc = _mm_hadd_ps(*acc, *acc); + float sum = _mm_cvtss_f32(*acc); + #else + __m128 shuf = _mm_shuffle_ps(*acc, *acc, _MM_SHUFFLE(2, 3, 0, 1)); + __m128 sums = _mm_add_ps(*acc, shuf); + shuf = _mm_movehl_ps(shuf, sums); + sums = _mm_add_ss(sums, shuf); + float sum = 
_mm_cvtss_f32(sums); + #endif + for (int64_t i=k; i < n; ++i) sum += x[i]*y[i]; /* Process leftovers scalar-wise */ + return sum; +#else + double r = 0.0; + for (int64_t i=0; i < n; ++i) r += (double)x[i] * (double)y[i]; + return (float)r; +#endif +} + +static double MAG_HOTPROC mag_vsum_f64_f32( /* Σx. */ + const int64_t n, + const float* const x +) { +#ifdef MAG_ACCELERATE + float sum; + vDSP_sve(x, 1, &sum, n); + return (double)sum; +#else + double sum = 0.0; + for (int64_t i=0; i < n; ++i) { + sum += (double)x[i]; + } + return sum; +#endif +} + +static float MAG_HOTPROC mag_vmin_f32( /* min x */ + const int64_t n, + const float* const x +) { + float min = INFINITY; + for (int64_t i=0; i < n; ++i) { + min = fminf(min, x[i]); + } + return min; +} + +static float MAG_HOTPROC mag_vmax_f32( /* max x */ + const int64_t n, + const float* const x +) { + float min = -INFINITY; + for (int64_t i=0; i < n; ++i) { + min = fmaxf(min, x[i]); + } + return min; +} + +static void MAG_HOTPROC mag_vabs_f32( /* o = |x| */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = fabsf(x[i]); + } +} + +static void MAG_HOTPROC mag_vneg_f32( /* o = -x */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = -x[i]; + } +} + +static void MAG_HOTPROC mag_vlog_f32( /* o = log x */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + const float32x4_t one = vdupq_n_f32(1); + for (; i+3 < n; i += 4) { + float32x4_t xi = vld1q_f32(x+i); + xi = vmaxq_f32(xi, vdupq_n_f32(0)); + uint32x4_t invalid_mask = vcleq_f32(xi, vdupq_n_f32(0)); + int32x4_t ux = vreinterpretq_s32_f32(xi); + int32x4_t emm0 = vshrq_n_s32(ux, 23); + ux = vandq_s32(ux, vdupq_n_s32(~0x7f800000u)); + ux = vorrq_s32(ux, vreinterpretq_s32_f32(vdupq_n_f32(0.5f))); + xi = vreinterpretq_f32_s32(ux); + emm0 = vsubq_s32(emm0, vdupq_n_s32(0x7f)); + float32x4_t e = vcvtq_f32_s32(emm0); + e = vaddq_f32(e, one); + uint32x4_t mask = vcltq_f32(xi, vdupq_n_f32(0.707106781186547524f)); + float32x4_t tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(xi), mask)); + xi = vsubq_f32(xi, one); + e = vsubq_f32(e, vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(one), mask))); + xi = vaddq_f32(xi, tmp); + float32x4_t z = vmulq_f32(xi, xi); + float32x4_t y = vdupq_n_f32(7.0376836292e-2f); + y = vmlaq_f32(vdupq_n_f32(-1.1514610310e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(1.1676998740e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(-1.2420140846e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(1.4249322787e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(-1.6668057665e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(2.0000714765e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(-2.4999993993e-1f), y, xi); + y = vmlaq_f32(vdupq_n_f32(3.3333331174e-1f), y, xi); + y = vmulq_f32(y, xi); + y = vmulq_f32(y, z); + y = vmlaq_f32(y, e, vdupq_n_f32(-2.12194440e-4f)); + y = vmlsq_f32(y, z, vdupq_n_f32(0.5f)); + xi = vaddq_f32(xi, y); + xi = vmlaq_f32(xi, e, vdupq_n_f32(0.693359375f)); + xi = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(xi), invalid_mask)); + vst1q_f32(o+i, xi); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) +#error TODO +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) +#error TODO +#elif MAG_APPROXMATH && defined(__SSE2__) + const __m128 one = _mm_set1_ps(1.0f); + for (; i+3 < n; i += 4) { + __m128 xi = _mm_loadu_ps(x+i); 
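+        /* Range-reduction log (coefficients match the classic Cephes logf):
+           split x into mantissa m in [0.5, 1) and integer exponent e by bit
+           manipulation, evaluate a polynomial in m-1, then reassemble
+           log(x) = p(m-1) + e*ln(2), with ln(2) applied in two parts for
+           accuracy. Non-positive inputs turn into NaN via invalid_mask. */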
+ xi = _mm_max_ps(xi, _mm_set1_ps(0.0f)); + __m128 invalid_mask = _mm_cmple_ps(xi, _mm_set1_ps(0.0f)); + __m128i ux = _mm_castps_si128(xi); + __m128i emm0 = _mm_srli_epi32(ux, 23); + ux = _mm_and_si128(ux, _mm_set1_epi32(~0x7f800000u)); + ux = _mm_or_si128(ux, _mm_castps_si128(_mm_set1_ps(0.5f))); + xi = _mm_castsi128_ps(ux); + emm0 = _mm_sub_epi32(emm0, _mm_set1_epi32(0x7f)); + __m128 e = _mm_cvtepi32_ps(emm0); + e = _mm_add_ps(e, one); + __m128 mask = _mm_cmplt_ps(xi, _mm_set1_ps(0.707106781186547524f)); + __m128 tmp = _mm_and_ps(xi, mask); + xi = _mm_sub_ps(xi, one); + e = _mm_sub_ps(e, _mm_and_ps(one, mask)); + xi = _mm_add_ps(xi, tmp); + __m128 z = _mm_mul_ps(xi, xi); + __m128 y = _mm_set1_ps(7.0376836292e-2f); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(-1.1514610310e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(1.1676998740e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(-1.2420140846e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(1.4249322787e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(-1.6668057665e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(2.0000714765e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(-2.4999993993e-1f)); + y = _mm_add_ps(_mm_mul_ps(y, xi), _mm_set1_ps(3.3333331174e-1f)); + y = _mm_mul_ps(y, xi); + y = _mm_mul_ps(y, z); + y = _mm_add_ps(_mm_mul_ps(e, _mm_set1_ps(-2.12194440e-4f)), y); + y = _mm_sub_ps(y, _mm_mul_ps(z, _mm_set1_ps(0.5f))); + xi = _mm_add_ps(xi, y); + xi = _mm_add_ps(_mm_mul_ps(e, _mm_set1_ps(0.693359375f)), xi); + xi = _mm_or_ps(xi, invalid_mask); + _mm_storeu_ps(o+i, xi); + } +#endif + for (; i < n; ++i) { /* Process leftovers scalar-wise */ + o[i] = logf(x[i]); + } +} + +static void MAG_HOTPROC mag_vsqr_f32( /* o = x² */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) + o[i] = x[i]*x[i]; +} + +static void MAG_HOTPROC mag_vsqrt_f32( /* o = √x */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = sqrtf(x[i]); + } +} + +static void MAG_HOTPROC mag_vsin_f32( /* o = sin x */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + for (; i+3 < n; i += 4) { + float32x4_t xi = vld1q_f32(x+i); + float32x4_t ocos; + mag_simd_sincos(xi, &xi, &ocos); + vst1q_f32(o+i, xi); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) +#error TODO +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) +#error TODO +#elif MAG_APPROXMATH && defined(__SSE2__) + for (; i+3 < n; i += 4) { + __m128 xi = _mm_loadu_ps(x+i); + __m128 ocos; + mag_simd_sincos(xi, &xi, &ocos); + _mm_storeu_ps(o+i, xi); + } +#endif + for (; i < n; ++i) { /* Process leftovers scalar-wise */ + o[i] = sinf(x[i]); + } +} + +static void MAG_HOTPROC mag_vcos_f32( /* o = cos x */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + for (; i+3 < n; i += 4) { + float32x4_t xi = vld1q_f32(x+i); + float32x4_t osin; + mag_simd_sincos(xi, &osin, &xi); + vst1q_f32(o+i, xi); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) +#error TODO +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) +#error TODO +#elif MAG_APPROXMATH && defined(__SSE2__) + for (; i+3 < n; i += 4) { + __m128 xi = _mm_loadu_ps(x+i); + __m128 osin; + mag_simd_sincos(xi, &osin, &xi); + 
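+        /* mag_simd_sincos computes both results in one pass; only the cosine
+           output (written back into xi) is stored here, the sine is unused. */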
_mm_storeu_ps(o+i, xi); + } +#endif + for (; i < n; ++i) { /* Process leftovers scalar-wise */ + o[i] = cosf(x[i]); + } +} + +static void MAG_HOTPROC mag_vstep_f32( /* Heaviside step function. */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] >= 0.0f ? 1.0f : 0.0f; + } +} + +static void MAG_HOTPROC mag_vsoftmax_f32( /* softmax : ℝ -> (0, ∞), x |-> e^x */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + for (; i+3 < n; i += 4) { + vst1q_f32(o+i, mag_simd_expf(vld1q_f32(x+i))); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i+15 < n; i += 16) { + _mm512_storeu_ps(o+i, mag_simd_expf(_mm512_loadu_ps(x+i))); + } +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) + for (; i+7 < n; i += 8) { + _mm256_storeu_ps(o+i, mag_simd_expf(_mm256_loadu_ps(x+i))); + } +#elif MAG_APPROXMATH && defined(__SSE2__) + for (; i+3 < n; i += 4) { + _mm_storeu_ps(o+i, mag_simd_expf(_mm_loadu_ps(x+i))); + } +#endif + for (; i < n; ++i) o[i] = expf(x[i]); /* Process leftovers scalar-wise */ +} + +static void MAG_HOTPROC mag_vsoftmax_dv_f32( /* softmax' = softmax : ℝ -> (0, ∞), x |-> e^x */ + const int64_t n, + float* const o, + const float* const x +) { + mag_vsoftmax_f32(n, o, x); +} + +static void MAG_HOTPROC mag_vsigmoid_f32( /* σ : ℝ -> (0, 1), x |-> 1/(1 + e^(-x)) */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + const float32x4_t one = vdupq_n_f32(1.0f); + const float32x4_t zero = vdupq_n_f32(0.0f); + for (; i+3 < n; i += 4) { + float32x4_t xx = vld1q_f32(x+i); + float32x4_t neg_x = vsubq_f32(zero, xx); + float32x4_t exp_neg_x = mag_simd_expf(neg_x); + float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + vst1q_f32(o+i, vdivq_f32(one, one_plus_exp_neg_x)); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 one = _mm512_set1_ps(1.0f); + __m512 zero = _mm512_setzero_ps(); + for (; i+15 < n; i += 16) { + __m512 xx = _mm512_loadu_ps(x+i); + __m512 neg_x = _mm512_sub_ps(zero, xx); + __m512 exp_neg_x = mag_simd_expf(neg_x); + __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + _mm512_storeu_ps(o+i, _mm512_div_ps(one, one_plus_exp_neg_x)); + } +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) + __m256 one = _mm256_set1_ps(1.0f); + __m256 zero = _mm256_setzero_ps(); + for (; i+7 < n; i += 8) { + __m256 xx = _mm256_loadu_ps(x+i); + __m256 neg_x = _mm256_sub_ps(zero, xx); + __m256 exp_neg_x = mag_simd_expf(neg_x); + __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + _mm256_storeu_ps(o+i, _mm256_div_ps(one, one_plus_exp_neg_x)); + } +#elif MAG_APPROXMATH && defined(__SSE2__) + __m128 one = _mm_set1_ps(1.0f); + __m128 zero = _mm_setzero_ps(); + for (; i+3 < n; i += 4) { + __m128 xx = _mm_loadu_ps(x+i); + __m128 neg_x = _mm_sub_ps(zero, xx); + __m128 exp_neg_x = mag_simd_expf(neg_x); + __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + _mm_storeu_ps(o+i, _mm_div_ps(one, one_plus_exp_neg_x)); + } +#endif + for (; i < n; ++i) o[i] = 1.0f / (1.0f + expf(-x[i])); /* Process leftovers scalar-wise */ +} + +static void MAG_HOTPROC mag_vsigmoid_dv_f32( /* σ' : ℝ -> (0, 1), x |-> x * (1-x) */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = 
x[i] * (1.0f - x[i]); + } +} + +static void MAG_HOTPROC mag_vhard_sigmoid_f32( /* σ^ : ℝ -> (0, 1), x |-> min(1, max(0, (x + 3)/6)) */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); + } +} + +static void MAG_HOTPROC mag_vsilu_f32( /* silu : ℝ -> ℝ, x |-> x/(1 + e^(-x)) */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t zero = vdupq_n_f32(0.0f); + for (; i+3 < n; i += 4) { + float32x4_t xx = vld1q_f32(x+i); + float32x4_t neg_x = vsubq_f32(zero, xx); + float32x4_t exp_neg_x = mag_simd_expf(neg_x); + float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); + vst1q_f32(o+i, vdivq_f32(xx, one_plus_exp_neg_x)); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 one = _mm512_set1_ps(1); + __m512 zero = _mm512_setzero_ps(); + for (; i+15 < n; i += 16) { + __m512 xx = _mm512_loadu_ps(x+i); + __m512 neg_x = _mm512_sub_ps(zero, xx); + __m512 exp_neg_x = mag_simd_expf(neg_x); + __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); + _mm512_storeu_ps(o+i, _mm512_div_ps(xx, one_plus_exp_neg_x)); + } +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) + __m256 one = _mm256_set1_ps(1); + __m256 zero = _mm256_setzero_ps(); + for (; i+7 < n; i += 8) { + const __m256 xx = _mm256_loadu_ps(x+i); + __m256 neg_x = _mm256_sub_ps(zero, xx); + __m256 exp_neg_x = mag_simd_expf(neg_x); + __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); + _mm256_storeu_ps(o+i, _mm256_div_ps(xx, one_plus_exp_neg_x)); + } +#elif MAG_APPROXMATH && defined(__SSE2__) + __m128 one = _mm_set1_ps(1); + __m128 zero = _mm_setzero_ps(); + for (; i+3 < n; i += 4) { + __m128 xx = _mm_loadu_ps(x+i); + __m128 neg_x = _mm_sub_ps(zero, xx); + __m128 exp_neg_x = mag_simd_expf(neg_x); + __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); + _mm_storeu_ps(o+i, _mm_div_ps(xx, one_plus_exp_neg_x)); + } +#endif + for (; i < n; ++i) { + o[i] = x[i] / (1.0f + expf(-x[i])); + } +} + +static void MAG_HOTPROC mag_vsilu_dv_f32( /* silu' : ℝ -> TODO */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + mag_panic("NYI!"); + } +} + +static void MAG_HOTPROC mag_vtanh_f32( /* tanh : ℝ -> (-1, 1), x |-> tanh x */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + for (; i+3 < n; i += 4) { + vst1q_f32(o+i, mag_simd_tanh(vld1q_f32(x+i))); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i+15 < n; i += 16) { + _mm512_storeu_ps(o+i, mag_simd_tanh(_mm512_loadu_ps(x+i))); + } +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) + for (; i+7 < n; i += 8) { + _mm256_storeu_ps(o+i, mag_simd_tanh(_mm256_loadu_ps(x+i))); + } +#elif MAG_APPROXMATH && defined(__SSE2__) + for (; i+3 < n; i += 4) { + _mm_storeu_ps(o+i, mag_simd_tanh(_mm_loadu_ps(x+i))); + } +#endif + for (; i < n; ++i) { + o[i] = tanhf(x[i]); + } +} + +static void MAG_HOTPROC mag_vtanh_dv_f32( /* tanh' : ℝ -> (-1, 1), x |-> 1 / ((cosh x)^2) */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + const float cx = coshf(x[i]); + o[i] = 1.0f / (cx*cx); + } +} + +static void MAG_HOTPROC mag_vrelu_f32( /* relu : ℝ -> 
ℝ^+, x |-> max {x, 0} */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = mag_xmax(x[i], 0.0f); + } +} + +static void MAG_HOTPROC mag_vrelu_dv_f32( /* relu' : ℝ -> ℝ^+, x |-> { 0 if x < 0, UB if x = 0, else 1 */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + o[i] = x[i] <= 0.0f ? 0.0f : 1.0f; /* relu' is mathematically undefined for x = 0, but we return 0 in this case. */ + } +} + +static void MAG_HOTPROC mag_vgelu_f32( /* gelu : ℝ -> ℝ, x |-> TODO */ + const int64_t n, + float* const o, + const float* const x +) { + int64_t i=0; +#if MAG_APPROXMATH && (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + float32x4_t half = vdupq_n_f32(0.5f); + float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t coeff1 = vdupq_n_f32(0.79788456080286535587989211986876f); + float32x4_t coeff2 = vdupq_n_f32(MAG_GELU_COEFF); + for (; i+3 < n; i += 4) { + float32x4_t xx = vld1q_f32(x+i); + float32x4_t a = vaddq_f32(one, vmulq_f32(coeff2, vmulq_f32(xx, xx))); + float32x4_t b = vaddq_f32(one, mag_simd_tanh(vmulq_f32(coeff1, vmulq_f32(xx, a)))); + float32x4_t c = vmulq_f32(half, vmulq_f32(xx, b)); + vst1q_f32(o+i, c); + } +#elif MAG_APPROXMATH && defined(__AVX512F__) && defined(__AVX512DQ__) + __m512 half = _mm512_set1_ps(0.5f); + __m512 one = _mm512_set1_ps(1.0f); + __m512 coeff1 = _mm512_set1_ps(0.79788456080286535587989211986876f); + __m512 coeff2 = _mm512_set1_ps(MAG_GELU_COEFF); + for (; i+15 < n; i += 16) { + __m512 xx = _mm512_loadu_ps(x+i); + __m512 a = _mm512_fmadd_ps(coeff2, _mm512_mul_ps(xx, xx), one); + __m512 b = _mm512_add_ps(one, mag_simd_tanh(_mm512_mul_ps(coeff1, _mm512_mul_ps(xx, a)))); + __m512 c = _mm512_mul_ps(half, _mm512_mul_ps(xx, b)); + _mm512_storeu_ps(o+i, c); + } +#elif MAG_APPROXMATH && defined(__AVX2__) && defined(__FMA__) + __m256 half = _mm256_set1_ps(0.5f); + __m256 one = _mm256_set1_ps(1.0f); + __m256 coeff1 = _mm256_set1_ps(0.79788456080286535587989211986876f); + __m256 coeff2 = _mm256_set1_ps(MAG_GELU_COEFF); + for (; i+7 < n; i += 8) { + __m256 xx = _mm256_loadu_ps(x+i); + __m256 a = _mm256_fmadd_ps(coeff2, _mm256_mul_ps(xx, xx), one); + __m256 b = _mm256_add_ps(one, mag_simd_tanh(_mm256_mul_ps(coeff1, _mm256_mul_ps(xx, a)))); + __m256 c = _mm256_mul_ps(half, _mm256_mul_ps(xx, b)); + _mm256_storeu_ps(o+i, c); + } +#elif MAG_APPROXMATH && defined(__SSE2__) + __m128 half = _mm_set1_ps(0.5f); + __m128 one = _mm_set1_ps(1.0f); + __m128 coeff1 = _mm_set1_ps(0.79788456080286535587989211986876f); + __m128 coeff2 = _mm_set1_ps(MAG_GELU_COEFF); + for (; i+3 < n; i += 4) { + __m128 xx = _mm_loadu_ps(x+i); + __m128 a = _mm_add_ps(one, _mm_mul_ps(coeff2, _mm_mul_ps(xx, xx))); + __m128 b = _mm_add_ps(one, mag_simd_tanh(_mm_mul_ps(coeff1, _mm_mul_ps(xx, a)))); + __m128 c = _mm_mul_ps(half, _mm_mul_ps(xx, b)); + _mm_storeu_ps(o+i, c); + } +#endif + for (; i < n; ++i) { + o[i] = 0.5f*x[i]*(1.0f + tanhf(0.79788456080286535587989211986876f*x[i]*(1.0f + MAG_GELU_COEFF*x[i]*x[i]))); + } +} + +static void MAG_HOTPROC mag_vgelu_dv_f32( /* gelu' : ℝ -> ℝ, x |-> TODO */ + const int64_t n, + float* const o, + const float* const x +) { + for (int64_t i=0; i < n; ++i) { + mag_panic("NYI"); /* TODO */ + } +} + +static void mag_blas_nop( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs +) { + (void)bci; + (void)r; + (void)inputs; +} + +static void mag_blas_clone( + const mag_cpu_thread_local_ctx_t* const bci, + 
mag_tensor_t* const r, + const mag_tensor_t** const inputs +) { + const mag_tensor_t* const x = inputs[0]; + mag_assert2(mag_tensor_is_shape_eq(x, r)); + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + memcpy(b_r, b_x, mag_tensor_data_size(r)); +} + +static void MAG_HOTPROC mag_blas_mean_f32( /* Σx/n Arithmetic mean */ + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_d, shape); + mag_load_local_storage_group(x, x_s, strides); + double sum = 0.0; + for (int64_t i5=0; i5 < x_d5; ++i5) { + for (int64_t i4=0; i4 < x_d4; ++i4) { + for (int64_t i3=0; i3 < x_d3; ++i3) { + for (int64_t i2=0; i2 < x_d2; ++i2) { + for (int64_t i1=0; i1 < x_d1; ++i1) { + const float* const p_x = b_x + i1*x_s1 + i2*x_s2 + i3*x_s3 + i4*x_s4 + i5*x_s5; + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + sum += mag_vsum_f64_f32( + x_d0, + p_x + ); + } + } + } + } + } + sum /= (double)x->numel; + *b_r = (float)sum; +} + +static void MAG_HOTPROC mag_blas_min_f32( /* min x */ + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_d, shape); + mag_load_local_storage_group(x, x_s, strides); + float min = INFINITY; + for (int64_t i5=0; i5 < x_d5; ++i5) { + for (int64_t i4=0; i4 < x_d4; ++i4) { + for (int64_t i3=0; i3 < x_d3; ++i3) { + for (int64_t i2=0; i2 < x_d2; ++i2) { + for (int64_t i1=0; i1 < x_d1; ++i1) { + const float* const p_x = b_x + i1*x_s1 + i2*x_s2 + i3*x_s3 + i4*x_s4 + i5*x_s5; + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + min = fminf(mag_vmin_f32(x_d0, p_x), min); + } + } + } + } + } + *b_r = min; +} + +static void MAG_HOTPROC mag_blas_max_f32( /* max x */ + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_d, shape); + mag_load_local_storage_group(x, x_s, strides); + float max = -INFINITY; + for (int64_t i5=0; i5 < x_d5; ++i5) { + for (int64_t i4=0; i4 < x_d4; ++i4) { + for (int64_t i3=0; i3 < x_d3; ++i3) { + for (int64_t i2=0; i2 < x_d2; ++i2) { + for (int64_t i1=0; i1 < x_d1; ++i1) { + const float* const p_x = b_x + i1*x_s1 + i2*x_s2 + i3*x_s3 + i4*x_s4 + i5*x_s5; + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + max = fmaxf(mag_vmax_f32(x_d0, p_x), max); + } + } + } + } + } + *b_r = max; +} + +static void MAG_HOTPROC mag_blas_sum_f32( /* Σx/n Arithmetic mean */ + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! 
*/ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_d, shape); + mag_load_local_storage_group(x, x_s, strides); + double sum = 0.0; + for (int64_t i5=0; i5 < x_d5; ++i5) { + for (int64_t i4=0; i4 < x_d4; ++i4) { + for (int64_t i3=0; i3 < x_d3; ++i3) { + for (int64_t i2=0; i2 < x_d2; ++i2) { + for (int64_t i1=0; i1 < x_d1; ++i1) { + const float* const p_x = b_x + i1*x_s1 + i2*x_s2 + i3*x_s3 + i4*x_s4 + i5*x_s5; + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + sum += mag_vsum_f64_f32(x_d0, p_x); + } + } + } + } + } + *b_r = (float)sum; +} + +static void MAG_HOTPROC mag_blas_abs_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vabs_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_neg_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vneg_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_log_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vlog_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_sqr_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! 
*/ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsqr_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_sqrt_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsqrt_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_sin_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsin_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_cos_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vcos_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_step_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! 
*/ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vstep_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_softmax_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsoftmax_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_softmax_dv_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsoftmax_dv_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_sigmoid_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsigmoid_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_sigmoid_dv_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! 
*/ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsigmoid_dv_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_hard_sigmoid_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vhard_sigmoid_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_silu_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsilu_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_silu_dv_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vsilu_dv_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_tanh_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! 
*/ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vtanh_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_tanh_dv_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vtanh_dv_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_relu_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vrelu_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_relu_dv_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vrelu_dv_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_gelu_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! 
*/ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vgelu_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_gelu_dv_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* const r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + (void)bci; + const mag_tensor_t* const x = inputs[0]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + const int64_t cc = mag_tensor_num_cols(x); + for (int64_t ri=0; ri < rc; ++ri) { + float* const p_r = b_r + ri*r_s1; + const float* const p_x = b_x + ri*x_s1; + mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x)); + mag_vgelu_dv_f32(cc, p_r, p_x); + } +} + +static void MAG_HOTPROC mag_blas_add_f32( + const mag_cpu_thread_local_ctx_t* const bci, + mag_tensor_t* r, + const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */ +) { + const mag_tensor_t* const x = inputs[0]; + const mag_tensor_t* const y = inputs[1]; + float* const b_r = mag_f32p_mut(r); + const float* const b_x = mag_f32p(x); + const float* const b_y = mag_f32p(y); + mag_load_local_storage_group(r, r_d, shape); + mag_load_local_storage_group(r, r_s, strides); + mag_load_local_storage_group(x, x_d, shape); + mag_load_local_storage_group(x, x_s, strides); + mag_load_local_storage_group(y, y_d, shape); + mag_load_local_storage_group(y, y_s, strides); + const int64_t rc = mag_tensor_num_rows(x); + int64_t x_i1 = 0; + int64_t x_i2 = 0; + int64_t x_i3 = 0; + int64_t x_i4 = 0; + int64_t x_i5 = 0; + if (y_s0 == 1) { /* Fast path for contiguous input tensors. */ + for (int64_t ri=0; ri < rc; ++ri) { /* For each row */ + /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */ + const int64_t y_i5 = x_i5 % y_d5; + const int64_t y_i4 = x_i4 % y_d4; + const int64_t y_i3 = x_i3 % y_d3; + const int64_t y_i2 = x_i2 % y_d2; + const int64_t y_i1 = x_i1 % y_d1; + float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5; + const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5; + const float* const p_y = b_y + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5; + mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y)); + const int64_t pa = x_d0 / y_d0; + for (int64_t i=0; i < pa; ++i) { /* For each element in row */ + float* const pp_r = p_r + i*y_d0; /* Compute result ptr. */ + const float* const pp_x = p_x + i*y_d0; /* Compute x ptr. */ + mag_bnd_chk(pp_r, b_r, mag_tensor_data_size(r)); + mag_bnd_chk(pp_x, b_x, mag_tensor_data_size(x)); + mag_vadd_f32(y_d0, pp_r, pp_x, p_y); /* Apply micro kernel vector op */ + } + /* Incremental index computation. 
*/
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    } else { /* Slow path for non-contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            for (int64_t i=0; i < r_d0; ++i) { /* For each element in row */
+                const float* const p_y = b_y + i%y_d0*y_s0 + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5; /* Compute broadcasted y operand ptr. */
+                mag_bnd_chk(p_r+i, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(p_x+i, b_x, mag_tensor_data_size(x));
+                mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+                p_r[i] = p_x[i] + *p_y; /* Apply scalar op. */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
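/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): the slow
** path above walks x row-by-row while reducing each x index modulo y's shape, which is
** what implements broadcasting (a y dimension of size 1 repeats for every x index).
** The standalone program below demonstrates the same modulo-index scheme in 2D; all
** names are hypothetical and chosen for the example only.
*/
#include <stdio.h>

/* Broadcast-add x (rows x cols) and y (y_rows x y_cols), where each y dim is either
** equal to the corresponding x dim or 1, mirroring the modulo trick used above. */
static void bcast_add_2d(int rows, int cols, int y_rows, int y_cols,
                         const float* x, const float* y, float* r) {
    for (int i = 0; i < rows; ++i) {
        int yi = i % y_rows;                 /* size-1 dims collapse to index 0 */
        for (int j = 0; j < cols; ++j) {
            int yj = j % y_cols;
            r[i*cols + j] = x[i*cols + j] + y[yi*y_cols + yj];
        }
    }
}

int main(void) {
    const float x[2][3] = {{1, 2, 3}, {4, 5, 6}};
    const float y[1][3] = {{10, 20, 30}};    /* broadcast along rows */
    float r[2][3];
    bcast_add_2d(2, 3, 1, 3, &x[0][0], &y[0][0], &r[0][0]);
    for (int i = 0; i < 2; ++i) {            /* prints: 11 22 33 / 14 25 36 */
        for (int j = 0; j < 3; ++j) printf("%g ", r[i][j]);
        puts("");
    }
    return 0;
}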
+
+static void MAG_HOTPROC mag_blas_sub_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    const mag_tensor_t* const x = inputs[0];
+    const mag_tensor_t* const y = inputs[1];
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    const float* const b_y = mag_f32p(y);
+    mag_load_local_storage_group(r, r_d, shape);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_d, shape);
+    mag_load_local_storage_group(x, x_s, strides);
+    mag_load_local_storage_group(y, y_d, shape);
+    mag_load_local_storage_group(y, y_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    int64_t x_i1 = 0;
+    int64_t x_i2 = 0;
+    int64_t x_i3 = 0;
+    int64_t x_i4 = 0;
+    int64_t x_i5 = 0;
+    if (y_s0 == 1) { /* Fast path for contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            const float* const p_y = b_y + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5;
+            mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+            const int64_t pa = x_d0 / y_d0;
+            for (int64_t i=0; i < pa; ++i) { /* For each element in row */
+                float* const pp_r = p_r + i*y_d0; /* Compute result ptr. */
+                const float* const pp_x = p_x + i*y_d0; /* Compute x ptr. */
+                mag_bnd_chk(pp_r, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(pp_x, b_x, mag_tensor_data_size(x));
+                mag_vsub_f32(y_d0, pp_r, pp_x, p_y); /* Apply micro kernel vector op */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    } else { /* Slow path for non-contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            for (int64_t i=0; i < r_d0; ++i) { /* For each element in row */
+                const float* const p_y = b_y + i%y_d0*y_s0 + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5; /* Compute broadcasted y operand ptr. */
+                mag_bnd_chk(p_r+i, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(p_x+i, b_x, mag_tensor_data_size(x));
+                mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+                p_r[i] = p_x[i] - *p_y; /* Apply scalar op. */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
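/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): the fast
** path above relies on y being contiguous in its innermost dimension (y_s0 == 1). Each
** row of x is then split into x_d0/y_d0 contiguous segments and the vector micro kernel
** runs segment-by-segment against the same y row, i.e. the innermost broadcast repeat is
** hoisted out of the element loop. Minimal 1D model of that splitting (names illustrative):
*/
#include <assert.h>

static void vsub_f32(long n, float* r, const float* x, const float* y) {
    for (long i = 0; i < n; ++i) r[i] = x[i] - y[i]; /* stand-in for mag_vsub_f32 */
}

/* Subtract y (length m) from x (length n) where m divides n: n/m kernel calls. */
static void sub_bcast_inner(long n, long m, float* r, const float* x, const float* y) {
    assert(m > 0 && n % m == 0);
    for (long seg = 0; seg < n/m; ++seg)
        vsub_f32(m, r + seg*m, x + seg*m, y);
}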
+
+static void MAG_HOTPROC mag_blas_mul_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    const mag_tensor_t* const x = inputs[0];
+    const mag_tensor_t* const y = inputs[1];
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    const float* const b_y = mag_f32p(y);
+    mag_load_local_storage_group(r, r_d, shape);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_d, shape);
+    mag_load_local_storage_group(x, x_s, strides);
+    mag_load_local_storage_group(y, y_d, shape);
+    mag_load_local_storage_group(y, y_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    int64_t x_i1 = 0;
+    int64_t x_i2 = 0;
+    int64_t x_i3 = 0;
+    int64_t x_i4 = 0;
+    int64_t x_i5 = 0;
+    if (y_s0 == 1) { /* Fast path for contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            const float* const p_y = b_y + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5;
+            mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+            const int64_t pa = x_d0 / y_d0;
+            for (int64_t i=0; i < pa; ++i) { /* For each element in row */
+                float* const pp_r = p_r + i*y_d0; /* Compute result ptr. */
+                const float* const pp_x = p_x + i*y_d0; /* Compute x ptr. */
+                mag_bnd_chk(pp_r, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(pp_x, b_x, mag_tensor_data_size(x));
+                mag_vmul_f32(y_d0, pp_r, pp_x, p_y); /* Apply micro kernel vector op */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    } else { /* Slow path for non-contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            for (int64_t i=0; i < r_d0; ++i) { /* For each element in row */
+                const float* const p_y = b_y + i%y_d0*y_s0 + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5; /* Compute broadcasted y operand ptr. */
+                mag_bnd_chk(p_r+i, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(p_x+i, b_x, mag_tensor_data_size(x));
+                mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+                p_r[i] = p_x[i] * *p_y; /* Apply scalar op. */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void MAG_HOTPROC mag_blas_div_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    const mag_tensor_t* const x = inputs[0];
+    const mag_tensor_t* const y = inputs[1];
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    const float* const b_y = mag_f32p(y);
+    mag_load_local_storage_group(r, r_d, shape);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_d, shape);
+    mag_load_local_storage_group(x, x_s, strides);
+    mag_load_local_storage_group(y, y_d, shape);
+    mag_load_local_storage_group(y, y_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    int64_t x_i1 = 0;
+    int64_t x_i2 = 0;
+    int64_t x_i3 = 0;
+    int64_t x_i4 = 0;
+    int64_t x_i5 = 0;
+    if (y_s0 == 1) { /* Fast path for contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            const float* const p_y = b_y + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5;
+            mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+            const int64_t pa = x_d0 / y_d0;
+            for (int64_t i=0; i < pa; ++i) { /* For each element in row */
+                float* const pp_r = p_r + i*y_d0; /* Compute result ptr. */
+                const float* const pp_x = p_x + i*y_d0; /* Compute x ptr. */
+                mag_bnd_chk(pp_r, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(pp_x, b_x, mag_tensor_data_size(x));
+                mag_vdiv_f32(y_d0, pp_r, pp_x, p_y); /* Apply micro kernel vector op */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    } else { /* Slow path for non-contiguous input tensors. */
+        for (int64_t ri=0; ri < rc; ++ri) { /* For each row */
+            /* Compute broadcasted 5D indices for y and the resulting buffer ptrs for r,x,y. */
+            const int64_t y_i5 = x_i5 % y_d5;
+            const int64_t y_i4 = x_i4 % y_d4;
+            const int64_t y_i3 = x_i3 % y_d3;
+            const int64_t y_i2 = x_i2 % y_d2;
+            const int64_t y_i1 = x_i1 % y_d1;
+            float* const p_r = b_r + x_i1*r_s1 + x_i2*r_s2 + x_i3*r_s3 + x_i4*r_s4 + x_i5*r_s5;
+            const float* const p_x = b_x + x_i1*x_s1 + x_i2*x_s2 + x_i3*x_s3 + x_i4*x_s4 + x_i5*x_s5;
+            for (int64_t i=0; i < r_d0; ++i) { /* For each element in row */
+                const float* const p_y = b_y + i%y_d0*y_s0 + y_i1*y_s1 + y_i2*y_s2 + y_i3*y_s3 + y_i4*y_s4 + y_i5*y_s5; /* Compute broadcasted y operand ptr. */
+                mag_bnd_chk(p_r+i, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(p_x+i, b_x, mag_tensor_data_size(x));
+                mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+                p_r[i] = p_x[i] / *p_y; /* Apply scalar op. */
+            }
+            /* Incremental index computation. */
+            ++x_i1;
+            if (x_i1 >= x_d1) {
+                x_i1 = 0;
+                ++x_i2;
+                if (x_i2 >= x_d2) {
+                    x_i2 = 0;
+                    ++x_i3;
+                    if (x_i3 >= x_d3) {
+                        x_i3 = 0;
+                        ++x_i4;
+                        if (x_i4 >= x_d4) {
+                            x_i4 = 0;
+                            ++x_i5;
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+static void MAG_HOTPROC mag_blas_adds_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    (void)bci;
+    const mag_tensor_t* const x = inputs[0];
+    const float xi = r->op_params->x.f32;
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    const int64_t cc = mag_tensor_num_cols(x);
+    for (int64_t ri=0; ri < rc; ++ri) {
+        float* const p_r = b_r + ri*r_s1;
+        const float* const p_x = b_x + ri*x_s1;
+        mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r));
+        mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x));
+        mag_vadds_f32(cc, p_r, p_x, xi);
+    }
+}
+
+static void MAG_HOTPROC mag_blas_subs_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    (void)bci;
+    const mag_tensor_t* const x = inputs[0];
+    const float xi = r->op_params->x.f32;
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    const int64_t cc = mag_tensor_num_cols(x);
+    for (int64_t ri=0; ri < rc; ++ri) {
+        float* const p_r = b_r + ri*r_s1;
+        const float* const p_x = b_x + ri*x_s1;
+        mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r));
+        mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x));
+        mag_vsubs_f32(cc, p_r, p_x, xi);
+    }
+}
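/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): the *s
** (scalar) kernels take their immediate operand from the result tensor's first op
** parameter (r->op_params->x.f32) instead of a second input tensor, so the same row
** loop serves both tensor-tensor and tensor-scalar ops. A hypothetical stand-in for
** that parameter plumbing and the mag_vadds_f32 micro kernel:
*/
#include <stdint.h>

typedef union demo_op_param_t { float f32; int64_t i64; } demo_op_param_t; /* hypothetical */

static void vadds_f32(int64_t n, float* r, const float* x, float k) {
    for (int64_t i = 0; i < n; ++i) r[i] = x[i] + k; /* stand-in for mag_vadds_f32 */
}

static void adds_demo(void) {
    demo_op_param_t p = { .f32 = 3.5f };  /* would live in r->op_params[0] */
    const float x[3] = {1.f, 2.f, 3.f};
    float r[3];
    vadds_f32(3, r, x, p.f32);            /* r = {4.5, 5.5, 6.5} */
}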
+
+static void MAG_HOTPROC mag_blas_muls_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    (void)bci;
+    const mag_tensor_t* const x = inputs[0];
+    const float xi = r->op_params->x.f32;
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    const int64_t cc = mag_tensor_num_cols(x);
+    for (int64_t ri=0; ri < rc; ++ri) {
+        float* const p_r = b_r + ri*r_s1;
+        const float* const p_x = b_x + ri*x_s1;
+        mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r));
+        mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x));
+        mag_vmuls_f32(cc, p_r, p_x, xi);
+    }
+}
+
+static void MAG_HOTPROC mag_blas_divs_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    (void)bci;
+    const mag_tensor_t* const x = inputs[0];
+    const float xi = r->op_params->x.f32;
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_s, strides);
+    const int64_t rc = mag_tensor_num_rows(x);
+    const int64_t cc = mag_tensor_num_cols(x);
+    for (int64_t ri=0; ri < rc; ++ri) {
+        float* const p_r = b_r + ri*r_s1;
+        const float* const p_x = b_x + ri*x_s1;
+        mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r));
+        mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x));
+        mag_vdivs_f32(cc, p_r, p_x, xi);
+    }
+}
+
+#if 0 /* Naive matrix multiplication, kept disabled for reference; no broadcasting support. */
+static void MAG_HOTPROC mag_blas_matmul_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    const mag_tensor_t* const x = inputs[0];
+    const mag_tensor_t* const y = inputs[1];
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    const float* const b_y = mag_f32p(y);
+    mag_load_local_storage_group(r, r_d, shape);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_d, shape);
+    mag_load_local_storage_group(x, x_s, strides);
+    mag_load_local_storage_group(y, y_d, shape);
+    mag_load_local_storage_group(y, y_s, strides);
+    for (int64_t i3=0; i3 < r_d3; ++i3) {
+        for (int64_t i2=0; i2 < r_d2; ++i2) {
+            for (int64_t i1=0; i1 < r_d1; ++i1) {
+                for (int64_t i0=0; i0 < r_d0; ++i0) {
+                    double sum = 0.0;
+                    for (int64_t k=0; k < x_d0; ++k) {
+                        const float* const p_x = b_x + k*x_s0 + i0*x_s1 + i2*x_s2 + i3*x_s3;
+                        const float* const p_y = b_y + i1*y_s0 + k*y_s1 + i2*y_s2 + i3*y_s3;
+                        mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x));
+                        mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+                        sum += (double)(*p_x * *p_y);
+                    }
+                    float* const p_r = b_r + i1*r_s0 + i0*r_s1 + i2*r_s2 + i3*r_s3;
+                    mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r));
+                    *p_r = (float)sum;
+                }
+            }
+        }
+    }
+}
+#endif
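/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): the live
** kernel below iterates i-k-j with += accumulation instead of the textbook i-j-k dot
** product. With row-major data this streams the r and y rows contiguously in the inner
** loop (better cache behavior), but it requires r to start zeroed, which the CPU
** allocator guarantees via memset. Minimal standalone version (names hypothetical):
*/
#include <stdio.h>

static void matmul_ikj(int n, int m, int p, const float* a, const float* b, float* r) {
    /* r (n x p) must be zeroed; a is n x m, b is m x p, all row-major. */
    for (int i = 0; i < n; ++i)
        for (int k = 0; k < m; ++k) {
            float aik = a[i*m + k];
            for (int j = 0; j < p; ++j)
                r[i*p + j] += aik * b[k*p + j];
        }
}

int main(void) {
    const float a[2*2] = {1, 2, 3, 4}, b[2*2] = {5, 6, 7, 8};
    float r[2*2] = {0};
    matmul_ikj(2, 2, 2, a, b, r);
    printf("%g %g %g %g\n", r[0], r[1], r[2], r[3]); /* 19 22 43 50 */
    return 0;
}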
+
+/*
+** Matrix multiplication.
+** R = A x B
+*/
+static void MAG_HOTPROC mag_blas_matmul_f32(
+    const mag_cpu_thread_local_ctx_t* const bci,
+    mag_tensor_t* const r,
+    const mag_tensor_t** const inputs /* Assumes correct inputs for op, all != NULL! */
+) {
+    const mag_tensor_t* const x = inputs[0];
+    const mag_tensor_t* const y = inputs[1];
+    float* const b_r = mag_f32p_mut(r);
+    const float* const b_x = mag_f32p(x);
+    const float* const b_y = mag_f32p(y);
+    mag_load_local_storage_group(r, r_d, shape);
+    mag_load_local_storage_group(r, r_s, strides);
+    mag_load_local_storage_group(x, x_d, shape);
+    mag_load_local_storage_group(x, x_s, strides);
+    mag_load_local_storage_group(y, y_d, shape);
+    mag_load_local_storage_group(y, y_s, strides);
+    mag_assert2(x_d2 == 1 && x_d3 == 1);
+    mag_assert2(y_d2 == 1 && y_d3 == 1);
+    const int64_t rows = x_d0;
+    const int64_t cols = x_d1;  /* Shared (inner) dimension. */
+    const int64_t ycols = y_d1; /* Columns of y and r. */
+    for (int64_t i=0; i < rows; ++i) { /* i-k-j order; r is assumed zero-initialized (CPU storage is memset on allocation). */
+        for (int64_t k=0; k < cols; ++k) {
+            const float* const p_x = b_x + x_d1*i + k;
+            mag_bnd_chk(p_x, b_x, mag_tensor_data_size(x));
+            for (int64_t j = 0; j < ycols; ++j) {
+                float* const p_r = b_r + r_d1*i + j;
+                const float* const p_y = b_y + y_d1*k + j;
+                mag_bnd_chk(p_r, b_r, mag_tensor_data_size(r));
+                mag_bnd_chk(p_y, b_y, mag_tensor_data_size(y));
+                *p_r += *p_x * *p_y;
+            }
+        }
+    }
+}
+
+static void (*mag_blas_dispatch_table_forward[MAG_OP__NUM])(const mag_cpu_thread_local_ctx_t*, mag_tensor_t*, const mag_tensor_t**) = {
+    [MAG_OP_NOP] = &mag_blas_nop, /* No operation */
+    [MAG_OP_CLONE] = &mag_blas_clone,
+    [MAG_OP_VIEW] = &mag_blas_nop, /* View is a no-op */
+    [MAG_OP_TRANSPOSE] = &mag_blas_nop, /* Transpose is a runtime no-op */
+    [MAG_OP_PERMUTE] = &mag_blas_nop, /* Permute is a runtime no-op */
+    [MAG_OP_MEAN] = &mag_blas_mean_f32,
+    [MAG_OP_MIN] = &mag_blas_min_f32,
+    [MAG_OP_MAX] = &mag_blas_max_f32,
+    [MAG_OP_SUM] = &mag_blas_sum_f32,
+    [MAG_OP_ABS] = &mag_blas_abs_f32,
+    [MAG_OP_NEG] = &mag_blas_neg_f32,
+    [MAG_OP_LOG] = &mag_blas_log_f32,
+    [MAG_OP_SQR] = &mag_blas_sqr_f32,
+    [MAG_OP_SQRT] = &mag_blas_sqrt_f32,
+    [MAG_OP_SIN] = &mag_blas_sin_f32,
+    [MAG_OP_COS] = &mag_blas_cos_f32,
+    [MAG_OP_STEP] = &mag_blas_step_f32,
+    [MAG_OP_SOFTMAX] = &mag_blas_softmax_f32,
+    [MAG_OP_SOFTMAX_DV] = &mag_blas_softmax_dv_f32,
+    [MAG_OP_SIGMOID] = &mag_blas_sigmoid_f32,
+    [MAG_OP_SIGMOID_DV] = &mag_blas_sigmoid_dv_f32,
+    [MAG_OP_HARD_SIGMOID] = &mag_blas_hard_sigmoid_f32,
+    [MAG_OP_SILU] = &mag_blas_silu_f32,
+    [MAG_OP_SILU_DV] = &mag_blas_silu_dv_f32,
+    [MAG_OP_TANH] = &mag_blas_tanh_f32,
+    [MAG_OP_TANH_DV] = &mag_blas_tanh_dv_f32,
+    [MAG_OP_RELU] = &mag_blas_relu_f32,
+    [MAG_OP_RELU_DV] = &mag_blas_relu_dv_f32,
+    [MAG_OP_GELU] = &mag_blas_gelu_f32,
+    [MAG_OP_GELU_DV] = &mag_blas_gelu_dv_f32,
+    [MAG_OP_ADD] = &mag_blas_add_f32,
+    [MAG_OP_SUB] = &mag_blas_sub_f32,
+    [MAG_OP_MUL] = &mag_blas_mul_f32,
+    [MAG_OP_DIV] = &mag_blas_div_f32,
+    [MAG_OP_ADDS] = &mag_blas_adds_f32,
+    [MAG_OP_SUBS] = &mag_blas_subs_f32,
+    [MAG_OP_MULS] = &mag_blas_muls_f32,
+    [MAG_OP_DIVS] = &mag_blas_divs_f32,
+    [MAG_OP_MATMUL] = &mag_blas_matmul_f32
+};
+
+static void (*mag_blas_dispatch_table_backward[MAG_OP__NUM])(const mag_cpu_thread_local_ctx_t*, mag_tensor_t*, const mag_tensor_t**) = {
+    [MAG_OP_NOP] = &mag_blas_nop, /* No operation */
+    [MAG_OP_CLONE] = &mag_blas_clone,
+    [MAG_OP_VIEW] = &mag_blas_nop, /* View is a no-op */
+    [MAG_OP_TRANSPOSE] = &mag_blas_nop, /* Transpose is a runtime no-op */
+    [MAG_OP_PERMUTE] = &mag_blas_nop, /* Permute is a runtime no-op */
+    [MAG_OP_MEAN] = &mag_blas_mean_f32,
+    [MAG_OP_MIN] = &mag_blas_min_f32,
+    [MAG_OP_MAX] = &mag_blas_max_f32,
+    [MAG_OP_SUM] = &mag_blas_sum_f32,
+    [MAG_OP_ABS] = &mag_blas_abs_f32,
+    [MAG_OP_NEG] = &mag_blas_neg_f32,
+    [MAG_OP_LOG] = &mag_blas_log_f32,
[MAG_OP_SQR] = &mag_blas_sqr_f32, + [MAG_OP_SQRT] = &mag_blas_sqrt_f32, + [MAG_OP_SIN] = &mag_blas_sin_f32, + [MAG_OP_COS] = &mag_blas_cos_f32, + [MAG_OP_STEP] = &mag_blas_step_f32, + [MAG_OP_SOFTMAX] = &mag_blas_softmax_f32, + [MAG_OP_SOFTMAX_DV] = &mag_blas_softmax_dv_f32, + [MAG_OP_SIGMOID] = &mag_blas_sigmoid_f32, + [MAG_OP_SIGMOID_DV] = &mag_blas_sigmoid_dv_f32, + [MAG_OP_HARD_SIGMOID] = &mag_blas_hard_sigmoid_f32, + [MAG_OP_SILU] = &mag_blas_silu_f32, + [MAG_OP_SILU_DV] = &mag_blas_silu_dv_f32, + [MAG_OP_TANH] = &mag_blas_tanh_f32, + [MAG_OP_TANH_DV] = &mag_blas_tanh_dv_f32, + [MAG_OP_RELU] = &mag_blas_relu_f32, + [MAG_OP_RELU_DV] = &mag_blas_relu_dv_f32, + [MAG_OP_GELU] = &mag_blas_gelu_f32, + [MAG_OP_GELU_DV] = &mag_blas_gelu_dv_f32, + [MAG_OP_ADD] = &mag_blas_add_f32, + [MAG_OP_SUB] = &mag_blas_sub_f32, + [MAG_OP_MUL] = &mag_blas_mul_f32, + [MAG_OP_DIV] = &mag_blas_div_f32, + [MAG_OP_ADDS] = &mag_blas_adds_f32, + [MAG_OP_SUBS] = &mag_blas_subs_f32, + [MAG_OP_MULS] = &mag_blas_muls_f32, + [MAG_OP_DIVS] = &mag_blas_divs_f32, + [MAG_OP_MATMUL] = &mag_blas_matmul_f32 +}; + +static MAG_HOTPROC void mag_cpu_exec_fwd(mag_compute_device_t* dvc, mag_tensor_t* root) { + mag_cpu_thread_local_ctx_t* bci = (mag_cpu_thread_local_ctx_t*)(dvc+1); // todo + mag_blas_dispatch_table_forward[root->op](bci, root, (const mag_tensor_t**)root->op_inputs); +} + +static MAG_HOTPROC void mag_cpu_exec_bwd(mag_compute_device_t* dvc, mag_tensor_t* root) { + mag_cpu_thread_local_ctx_t* bci = (mag_cpu_thread_local_ctx_t*)(dvc+1); // todo + mag_blas_dispatch_table_backward[root->op](bci, root, (const mag_tensor_t**)root->op_inputs); +} + +static void mag_cpu_buf_set(mag_storage_buffer_t* sto, size_t offs, uint8_t x) { + mag_assert2(sto->base+offs <= sto->base+sto->size); + memset((void*)(sto->base+offs), x, sto->size-offs); /* On CPU just plain old memset with offset. */ +} + +static void mag_cpu_buf_cpy_host_device(mag_storage_buffer_t* sto, size_t offs, const void* src, size_t n) { + mag_assert2(sto->base+offs+n <= sto->base+sto->size); + memcpy((void*)(sto->base+offs), src, n); /* On CPU just plain old memcpy with offset. */ +} + +static void mag_cpu_buf_cpy_device_host(mag_storage_buffer_t* sto, size_t offs, void* dst, size_t n) { + mag_assert2(sto->base+offs+n <= sto->base+sto->size); + memcpy(dst, (void*)(sto->base+offs), n); /* On CPU just plain old memcpy with offset. */ +} + +static void mag_cpu_alloc_storage(mag_compute_device_t* host, mag_storage_buffer_t* out, size_t size, size_t align) { + mag_assert2(size); + void* block = mag_alloc_aligned(size, align); + memset(block, 0, size); + *out = (mag_storage_buffer_t){ /* Set up storage buffer. */ + .base = (uintptr_t)block, + .size = size, + .alignment = align, + .host = host, + .set = &mag_cpu_buf_set, + .cpy_host_device = &mag_cpu_buf_cpy_host_device, + .cpy_device_host = &mag_cpu_buf_cpy_device_host + }; +} + +static void mag_cpu_free_storage(mag_compute_device_t* dvc, mag_storage_buffer_t* buf) { + mag_free_aligned((void*)buf->base); + memset(buf, 0, sizeof(*buf)); /* Set to zero. 
*/
+}
+
+static mag_compute_device_t* mag_cpu_init_interface(mag_ctx_t* ctx) {
+    mag_compute_device_t* dvc = (*mag_alloc)(NULL, sizeof(*dvc));
+    *dvc = (mag_compute_device_t){
+        .name = "CPU",
+        .impl = NULL,
+        .is_async = false,
+        .type = MAG_COMPUTE_DEVICE_TYPE_CPU,
+        .eager_exec_fwd = &mag_cpu_exec_fwd,
+        .eager_exec_bwd = &mag_cpu_exec_bwd,
+        .alloc_storage = &mag_cpu_alloc_storage,
+        .free_storage = &mag_cpu_free_storage
+    };
+    snprintf(dvc->name, sizeof(dvc->name), "%s - %s", mag_device_type_get_name(dvc->type), ctx->sys.cpu_name);
+    return dvc;
+}
+
+static void mag_cpu_release_interface(mag_compute_device_t* dvc) {
+    (*mag_alloc)(dvc, 0);
+}
+
+mag_compute_device_t* mag_init_device_cpu(mag_ctx_t* ctx) {
+    /* Thread count selection (e.g. ctx->sys.cpu_virtual_cores) is not consumed yet; the CPU backend runs single-threaded. */
+    return mag_cpu_init_interface(ctx);
+}
+
+void mag_destroy_device_cpu(mag_compute_device_t* dvc) {
+    mag_cpu_release_interface(dvc);
+}
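/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): how the
** storage interface above is meant to be driven. Everything goes through the function
** pointers on mag_storage_buffer_t/mag_compute_device_t, so callers never care whether
** the buffer lives in host RAM or on a GPU. Assumes a device obtained from
** mag_init_device_cpu; the alignment value 16 is a conservative choice for the example.
*/
static void storage_roundtrip_demo(mag_compute_device_t* dvc) {
    float in[4] = {1.f, 2.f, 3.f, 4.f}, out[4] = {0};
    mag_storage_buffer_t sto;
    (*dvc->alloc_storage)(dvc, &sto, sizeof(in), 16);
    (*sto.cpy_host_device)(&sto, 0, in, sizeof(in));   /* host -> device */
    (*sto.cpy_device_host)(&sto, 0, out, sizeof(out)); /* device -> host */
    (*dvc->free_storage)(dvc, &sto);                   /* out now equals in */
}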
diff --git a/magnetron/magnetron_cuda.cu b/magnetron/magnetron_cuda.cu new file mode 100644 index 0000000..f9fb72f --- /dev/null +++ b/magnetron/magnetron_cuda.cu @@ -0,0 +1,158 @@
+/* (c) 2024 Mario "Neo" Sieg. */
+
+#include "magnetron_cuda.cuh"
+
+#include <algorithm>
+#include <bit>
+#include <cstdio>
+#include <vector>
+
+namespace mag::cuda {
+    /* Driver result check. */
+    #define mag_cu_chk_rdv(expr) \
+        do { \
+            if (auto rrr {(expr)}; rrr != CUDA_SUCCESS) [[unlikely]] { \
+                const char* err_str = "?"; \
+                cuGetErrorString(rrr, &err_str); \
+                mag_panic(#expr, __func__, __FILE__, __LINE__, err_str); \
+            } \
+        } while (0)
+
+    /* Runtime result check. */
+    #define mag_cu_chk_rt(expr) \
+        do { \
+            if (auto rrr {(expr)}; rrr != cudaSuccess) [[unlikely]] { \
+                mag_panic(#expr, __func__, __FILE__, __LINE__, cudaGetErrorString(rrr)); \
+            } \
+        } while (0)
+
+    #define mag_cu_assert(expr, msg, ...) \
+        if (!(expr)) [[unlikely]] { \
+            mag_panic("%s:%d Assertion failed: " #expr " <- " msg, __FILE__, __LINE__, ## __VA_ARGS__);\
+        }
+    #define mag_cu_assert2(expr) mag_cu_assert(expr, "")
+
+    static constinit bool s_is_init {};
+    static constinit std::int32_t active_device_id {};
+    static std::vector<physical_device> s_devices {};
+
+    auto mag_init_device_cuda([[maybe_unused]] mag_ctx_t* ctx) -> mag_compute_device_t* {
+        std::int32_t active_device_id {0}; // TODO: Implement device selection.
+        const physical_device* device {cuda_init(active_device_id)};
+        if (!device) [[unlikely]] { /* No devices available or initialization failed; let the runtime fall back to another compute device. */
+            mag_log_error("No CUDA device with id %d available, using CPU processing", active_device_id);
+            return nullptr;
+        }
+        auto* dvc {static_cast<mag_compute_device_t*>((*mag_alloc)(nullptr, sizeof(mag_compute_device_t)))};
+        new (dvc) mag_compute_device_t {
+            .name = "GPU",
+            .impl = nullptr,
+            .is_async = true,
+            .type = MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA,
+            .eager_exec_fwd = nullptr,
+            .eager_exec_bwd = nullptr,
+            .alloc_storage = nullptr,
+            .free_storage = nullptr
+        };
+        double vram;
+        const char* unit;
+        mag_humanize_memory_size(device->vram, &vram, &unit);
+        std::snprintf(dvc->name, sizeof(dvc->name), "%s - %s - %.03f %s VRAM", mag_device_type_get_name(dvc->type), device->name.data(), vram, unit);
+        return dvc;
+    }
+
+    void mag_destroy_device_cuda(mag_compute_device_t* dvc) {
+        s_is_init = false;
+        s_devices.clear();
+        dvc->~mag_compute_device_t();
+        (*mag_alloc)(dvc, 0);
+    }
+
+    /*
+    ** Initialize the CUDA runtime. Returns the active device, or nullptr if initialization failed or no devices are available.
+    ** Normally we panic when a CUDA runtime function fails, but here we just return nullptr
+    ** to allow the runtime to fall back to CPU processing when no usable CUDA device exists.
+    */
+    auto cuda_init(std::int32_t use_device) -> const physical_device* {
+        if (s_is_init) return &s_devices[active_device_id];
+        std::int32_t numgpus {};
+        if (cudaGetDeviceCount(&numgpus) != cudaSuccess) [[unlikely]] return nullptr;
+        if (numgpus < 1 || numgpus > max_devices) [[unlikely]] return nullptr;
+        s_devices.reserve(numgpus);
+        for (std::int32_t id {}; id < numgpus; ++id) { /* Iterate over devices */
+            physical_device& dvc {s_devices.emplace_back()};
+            CUdevice cu_dvc {};
+            std::int32_t vmm_support {};
+            if (cuDeviceGet(&cu_dvc, id) != CUDA_SUCCESS) [[unlikely]] continue; /* Get device handle */
+            if (cuDeviceGetAttribute(&vmm_support, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, cu_dvc) != CUDA_SUCCESS) [[unlikely]] continue; /* Check VMM support */
+            if (vmm_support) { /* Virtual memory management supported */
+                CUmemAllocationProp alloc_props {};
+                alloc_props.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+                alloc_props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+                alloc_props.location.id = id;
+                if (cuMemGetAllocationGranularity(&dvc.vmm_granularity, &alloc_props, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED) != CUDA_SUCCESS) [[unlikely]] continue; /* Get VMM granularity */
+            }
+            dvc.has_vmm = !!vmm_support;
+            cudaDeviceProp props {};
+            if (cudaGetDeviceProperties(&props, id) != cudaSuccess) [[unlikely]] continue; /* Get device properties */
+            dvc.id = id;
+            dvc.name = std::bit_cast<std::array<char, 256>>(props.name);
+            dvc.nsm = props.multiProcessorCount;
+            dvc.smpb = props.sharedMemPerBlock;
+            dvc.smpb_opt = props.sharedMemPerBlockOptin;
+            dvc.cl = 100*props.major + 10*props.minor;
+            dvc.ntpb = props.maxThreadsPerBlock;
+            dvc.vram = props.totalGlobalMem;
+        }
+        active_device_id = std::clamp(use_device, 0, numgpus-1);
+        s_is_init = true;
+        return &s_devices[active_device_id];
+    }
+
+    vm_pool::vm_pool(const physical_device& dvc) {
+        m_granularity = dvc.vmm_granularity;
+        m_dvc_id = dvc.id;
+    }
+
+    vm_pool::~vm_pool() {
+        if (m_dvc_address) {
+            mag_cu_chk_rdv(cuMemUnmap(m_dvc_address, m_cap));
+            mag_cu_chk_rdv(cuMemAddressFree(m_dvc_address, max_size));
+        }
+    }
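/*
** Worked example (editorial addition, not part of the magnetron sources): vm_pool::alloc
** below rounds sizes up with (sz+align-1)&-align. For an unsigned operand, -align wraps
** to ~align+1, so the AND clears the low bits; this only works when align is a power of
** two. E.g. sz=13, align=8: (13+7) & ~7 = 20 & ...11111000 = 16.
*/
#include <assert.h>
#include <stddef.h>

static size_t align_up(size_t sz, size_t align) { /* align must be a power of 2 */
    return (sz + align - 1) & -align;
}

static void align_up_demo(void) {
    assert(align_up(13, 8) == 16);
    assert(align_up(16, 8) == 16);
    assert(align_up(1, 64) == 64);
}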
+
+    auto vm_pool::alloc(std::size_t sz, std::size_t align, std::size_t& out_sz) -> void* {
+        sz = (sz+align-1)&-align; /* Overallocate to ensure alignment. */
+        if (std::size_t free {m_cap-m_needle}; sz > free) {
+            std::size_t reserve {sz-free};
+            reserve = (reserve+m_granularity-1)&-m_granularity; /* Align to granularity. */
+            mag_cu_assert2(m_cap+reserve <= max_size);
+            CUmemAllocationProp prop {};
+            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            prop.location.id = m_dvc_id;
+            CUmemGenericAllocationHandle handle {};
+            mag_cu_chk_rdv(cuMemCreate(&handle, reserve, &prop, 0));
+            if (!m_dvc_address) /* Reserve virtual memory */
+                mag_cu_chk_rdv(cuMemAddressReserve(&m_dvc_address, max_size, 0, 0, 0));
+            mag_cu_chk_rdv(cuMemMap(m_dvc_address+m_cap, reserve, 0, handle, 0));
+            mag_cu_chk_rdv(cuMemRelease(handle)); /* Release handle */
+            CUmemAccessDesc access {};
+            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+            access.location.id = m_dvc_id;
+            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+            mag_cu_chk_rdv(cuMemSetAccess(m_dvc_address+m_cap, reserve, &access, 1)); /* Set access */
+            m_cap += reserve;
+        }
+        mag_cu_assert2(m_dvc_address);
+        auto* p {reinterpret_cast<void*>(m_dvc_address+m_needle)};
+        out_sz = sz;
+        m_needle += sz;
+        return p;
+    }
+
+    auto vm_pool::free(void* ptr, std::size_t sz) -> void {
+        m_needle -= sz;
+        mag_cu_assert2(ptr == reinterpret_cast<void*>(m_dvc_address + m_needle));
+    }
+}
diff --git a/magnetron/magnetron_cuda.cuh b/magnetron/magnetron_cuda.cuh new file mode 100644 index 0000000..716e965 --- /dev/null +++ b/magnetron/magnetron_cuda.cuh @@ -0,0 +1,71 @@
+/* (c) 2024 Mario "Neo" Sieg. */
+
+#pragma once
+
+#include "magnetron_internal.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+
+#include <cuda.h>
+
+namespace mag::cuda {
+    extern "C" auto mag_init_device_cuda(mag_ctx_t* ctx) -> mag_compute_device_t*; /* Initialize GPU compute device. (Invoked from magnetron C core) */
+    extern "C" auto mag_destroy_device_cuda(mag_compute_device_t* dvc) -> void; /* Destroy GPU compute device. (Invoked from magnetron C core) */
+
+    constexpr std::size_t max_devices = 32;
+    constexpr std::uint32_t warp_size = 32;
+    constexpr std::uint32_t max_streams = 8;
+
+    struct physical_device final {
+        std::int32_t id {};             /* Device ID */
+        std::array<char, 256> name {};  /* Device name */
+        std::size_t vram {};            /* Video memory in bytes */
+        std::uint32_t cl {};            /* Compute capability */
+        std::uint32_t nsm {};           /* Number of SMs */
+        std::uint32_t ntpb {};          /* Number of threads per block */
+        std::size_t smpb {};            /* Shared memory per block */
+        std::size_t smpb_opt {};        /* Shared memory per block opt-in */
+        bool has_vmm {};                /* Has virtual memory management */
+        std::size_t vmm_granularity {}; /* Virtual memory management granularity */
+    };
+
+    /* Initialize CUDA runtime. Returns the active device info if initialization succeeded, else nullptr. */
+    [[nodiscard]] extern auto cuda_init(std::int32_t use_device) -> const physical_device*;
+
+    /* Memory pool interface. */
+    class pool {
+    public:
+        pool(const pool&) = delete;
+        pool(pool&&) = delete;
+        auto operator=(const pool&) -> pool& = delete;
+        auto operator=(pool&&) -> pool& = delete;
+        virtual ~pool() = default;
+
+        virtual auto alloc(std::size_t sz, std::size_t align, std::size_t& out_sz) -> void* = 0;
+        virtual auto free(void* ptr, std::size_t sz) -> void = 0;
+
+    protected:
+        pool() = default;
+    };
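/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): vm_pool
** below is a bump allocator over a reserved virtual address range. A needle marks the
** high-water mark; alloc bumps it, growing the physically backed region in
** granularity-sized steps, and free only supports LIFO order - it asserts the freed
** block is the most recent one. Host-side model of the same bookkeeping:
*/
#include <assert.h>
#include <stddef.h>

typedef struct demo_bump_pool { /* hypothetical stand-in for vm_pool state */
    size_t needle; /* first free byte */
    size_t cap;    /* bytes currently backed */
    size_t gran;   /* growth granularity (power of 2) */
} demo_bump_pool;

static size_t demo_pool_alloc(demo_bump_pool* p, size_t sz) {
    if (sz > p->cap - p->needle) { /* grow backing in granularity steps */
        size_t reserve = sz - (p->cap - p->needle);
        reserve = (reserve + p->gran - 1) & -p->gran;
        p->cap += reserve; /* real code: cuMemCreate + cuMemMap + cuMemSetAccess here */
    }
    size_t off = p->needle;
    p->needle += sz;
    return off; /* offset into the reserved range */
}

static void demo_pool_free(demo_bump_pool* p, size_t off, size_t sz) {
    p->needle -= sz;
    assert(off == p->needle); /* LIFO discipline, mirroring vm_pool::free */
}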
+
+    /* Memory pool allocating virtual memory pages. */
+    class vm_pool final : public pool {
+    public:
+        static constexpr std::size_t max_size = 64ull<<30; /* 64 GiB */
+
+        explicit vm_pool(const physical_device& dvc);
+        ~vm_pool() override;
+
+        auto alloc(std::size_t sz, std::size_t align, std::size_t& out_sz) -> void* override;
+        auto free(void* ptr, std::size_t sz) -> void override;
+
+    private:
+        CUdeviceptr m_dvc_address {};
+        std::int32_t m_dvc_id {};
+        std::size_t m_needle {};
+        std::size_t m_cap {};
+        std::size_t m_granularity {};
+    };
+}
diff --git a/magnetron/magnetron_device_registry.c b/magnetron/magnetron_device_registry.c new file mode 100644 index 0000000..e2d8016 --- /dev/null +++ b/magnetron/magnetron_device_registry.c @@ -0,0 +1,57 @@
+/* (c) 2024 Mario "Neo" Sieg. */
+
+#include "magnetron_internal.h"
+
+extern mag_compute_device_t* mag_init_device_cpu(mag_ctx_t* ctx); /* Initialize CPU compute device. Implemented in magnetron_cpu.c */
+extern void mag_destroy_device_cpu(mag_compute_device_t* dvc); /* Destroy CPU compute device. Implemented in magnetron_cpu.c */
+
+#ifdef MAG_ENABLE_CUDA
+extern mag_compute_device_t* mag_init_device_cuda(mag_ctx_t* ctx); /* Initialize GPU compute device. Implemented in magnetron_cuda.cu */
+extern void mag_destroy_device_cuda(mag_compute_device_t* dvc); /* Destroy GPU compute device. Implemented in magnetron_cuda.cu */
+#endif
+
+#define MAG_DEVICE_FALLBACK MAG_COMPUTE_DEVICE_TYPE_CPU
+
+static const mag_device_factory_t* const mag_device_factories[MAG_COMPUTE_DEVICE_TYPE__NUM] = {
+    [MAG_COMPUTE_DEVICE_TYPE_CPU] = &(mag_device_factory_t){
+        .init = &mag_init_device_cpu,
+        .destroy = &mag_destroy_device_cpu,
+    },
+#ifdef MAG_ENABLE_CUDA
+    [MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA] = &(mag_device_factory_t){
+        .init = &mag_init_device_cuda,
+        .destroy = &mag_destroy_device_cuda,
+    },
+#else
+    [MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA] = NULL, /* CUDA not enabled. */
+#endif
+};
+
+mag_compute_device_t* mag_init_dynamic_device(mag_ctx_t* ctx, mag_compute_device_type_t* type) {
+    mag_assert2(ctx && *type < MAG_COMPUTE_DEVICE_TYPE__NUM);
+    mag_assert2(mag_device_factories[MAG_DEVICE_FALLBACK]); /* Fallback factory must be present. */
+    const mag_device_factory_t* factory = mag_device_factories[*type];
+    if (mag_unlikely(!factory)) {
+        mag_log_error("No device factory for type '%s', falling back to CPU.", mag_device_type_get_name(*type));
+        goto fallback;
+    }
+    mag_compute_device_t* dvc = (*factory->init)(ctx);
+    if (mag_unlikely(!dvc)) {
+        mag_log_error("Failed to initialize device of type '%s', falling back to CPU.", mag_device_type_get_name(*type));
+        goto fallback;
+    }
+    return dvc;
+    fallback:
+    *type = MAG_DEVICE_FALLBACK;
+    factory = mag_device_factories[*type];
+    mag_assert2(factory);
+    dvc = (*factory->init)(ctx);
+    mag_assert2(dvc); /* Ensure fallback device is created. */
+    return dvc;
+}
+
+void mag_destroy_dynamic_device(mag_compute_device_t* dvc) {
+    if (!dvc) return;
+    (*mag_device_factories[dvc->type]->destroy)(dvc);
+}
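/*
** Illustrative sketch (editorial addition, not part of the magnetron sources): the
** registry above lets callers request any backend and transparently degrade to the CPU
** when the factory is missing (build without CUDA) or its init returns NULL (no usable
** GPU). The requested type is passed by pointer so the caller learns what it actually
** got. Assumes a valid mag_ctx_t* (context creation is not part of this hunk):
*/
static void device_fallback_demo(mag_ctx_t* ctx) {
    mag_compute_device_type_t want = MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA;
    mag_compute_device_t* dvc = mag_init_dynamic_device(ctx, &want);
    /* want is now MAG_COMPUTE_DEVICE_TYPE_CPU if the CUDA path fell through. */
    mag_destroy_dynamic_device(dvc);
}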
diff --git a/magnetron/magnetron_internal.h b/magnetron/magnetron_internal.h new file mode 100644 index 0000000..4de9c3a --- /dev/null +++ b/magnetron/magnetron_internal.h @@ -0,0 +1,520 @@
+/* (c) 2024 Mario "Neo" Sieg. */
+
+#ifndef MAGNETRON_INTERNAL_H
+#define MAGNETRON_INTERNAL_H
+
+#include "magnetron.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#ifdef _MSC_VER
+#include <intrin.h>
+#else
+#ifdef __aarch64__
+#include <arm_neon.h>
+#include <arm_acle.h>
+#elif defined(__x86_64__) || defined(_M_X64)
+#include <immintrin.h>
+#include <cpuid.h>
+#endif
+#endif
+
+#if defined(__GLIBC__) || defined(__GNU_LIBRARY__) || defined(__ANDROID__)
+#include <endian.h>
+#elif defined(__APPLE__) && defined(__MACH__)
+#include <machine/endian.h>
+#elif defined(BSD) || defined(_SYSTYPE_BSD)
+#if defined(__OpenBSD__)
+#include <machine/endian.h>
+#else
+#include <sys/endian.h>
+#endif
+#endif
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define MAG_GELU_COEFF 0.044715f
+#define MAG_GRA_FWD MAG_GRAPH_EVAL_ORDER_FORWARD
+#define MAG_GRA_BWD MAG_GRAPH_EVAL_ORDER_REVERSE
+#define MAG_GRA_LEN 2
+#define MAG_MAX_CPUS 8192
+#define MAG_MAX_NUMA_NODES 64
+#define MAG_STORAGE_EXT ".magnetron"
+
+#if defined(__GNUC__) || defined(__clang__) || defined(__INTEL_COMPILER)
+
+#define MAG_NORET __attribute__((noreturn))
+#define MAG_ALIGN(x) __attribute__((aligned(x)))
+#define MAG_AINLINE inline __attribute__((always_inline))
+#define MAG_NOINLINE __attribute__((noinline))
+#define MAG_HOTPROC __attribute__((hot))
+#define MAG_COLDPROC __attribute__((cold))
+#define MAG_PACKED __attribute__((packed))
+#define MAG_FALLTHROUGH __attribute__((fallthrough))
+#define MAG_UNUSED __attribute__((unused))
+#define mag_likely(x) __builtin_expect(!!(x), 1)
+#define mag_unlikely(x) __builtin_expect(!!(x), 0)
+#define mag_ffs(x) ((uint32_t)__builtin_ctz(x))
+#define mag_fls(x) ((uint32_t)(__builtin_clz(x)^31))
+#define mag_ffs64(x) ((uint32_t)__builtin_ctzll(x))
+#define mag_fls64(x) ((uint32_t)(__builtin_clzll(x)^63))
+
+typedef int32_t mag_atomic_t; /* Atomic integer type */
+typedef enum mag_mo_t { /* Atomic memory order */
+    MAG_MO_RELAXED = __ATOMIC_RELAXED,
+    MAG_MO_CONSUME = __ATOMIC_CONSUME,
+    MAG_MO_ACQUIRE = __ATOMIC_ACQUIRE,
+    MAG_MO_RELEASE = __ATOMIC_RELEASE,
+    MAG_MO_ACQ_REL = __ATOMIC_ACQ_REL,
+    MAG_MO_SEQ_CST = __ATOMIC_SEQ_CST
+} mag_mo_t;
+
+static MAG_AINLINE void mag_atomic_store(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    __atomic_store_n(o, x, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_load(volatile mag_atomic_t* o, mag_mo_t order) {
+    return __atomic_load_n(o, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_fetch_add(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    return __atomic_fetch_add(o, x, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_fetch_sub(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    return __atomic_fetch_sub(o, x, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_fetch_and(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    return __atomic_fetch_and(o, x, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_fetch_or(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    return __atomic_fetch_or(o, x, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_fetch_xor(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    return __atomic_fetch_xor(o, x, order);
+}
+static MAG_AINLINE mag_atomic_t mag_atomic_exchange(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) {
+    return __atomic_exchange_n(o, x, order);
+}
+static MAG_AINLINE bool mag_atomic_compare_exchange_weak(volatile mag_atomic_t* o, mag_atomic_t *exp, mag_atomic_t *des, mag_mo_t order_succ, mag_mo_t order_fail) {
+    return __atomic_compare_exchange(o, exp, des, true, order_succ, order_fail);
+}
+static MAG_AINLINE bool mag_atomic_compare_exchange_strong(volatile mag_atomic_t* o,
mag_atomic_t *exp, mag_atomic_t *des, mag_mo_t order_succ, mag_mo_t order_fail) { + return __atomic_compare_exchange(o, exp, des, false, order_succ, order_fail); +} + +#else + +unsigned char _BitScanForward64(unsigned long*, unsigned __int64); +unsigned char _BitScanReverse64(unsigned long*, unsigned __int64); +#pragma intrinsic(_BitScanForward64) +#pragma intrinsic(_BitScanReverse64) +#define MAG_NORET __declspec(noreturn) +#define MAG_ALIGN(x) __declspec(align(x)) +#define MAG_AINLINE inline __forceinline +#define MAG_NOINLINE __declspec(noinline) +#define MAG_HOTPROC +#define MAG_COLDPROC +#define MAG_PACKED __declspec(align(1)) +#define MAG_FALLTHROUGH +#define MAG_UNUSED +#define mag_likely(x) (x) +#define mag_unlikely(x) (x) +static MAG_AINLINE uint32_t mag_ffs(const uint32_t x) { + unsigned long r; _BitScanForward(&r, x); return (uint32_t)r; +} +static MAG_AINLINE uint32_t mag_fls(const uint32_t x) { + unsigned long r; _BitScanReverse(&r, x); return (uint32_t)r; +} +static MAG_AINLINE uint32_t mag_ffs64(const uint64_t x) { + unsigned long r; _BitScanForward64(&r, x); return (uint32_t)r; +} +static MAG_AINLINE uint32_t mag_fls64(const uint64_t x) { + unsigned long r; _BitScanReverse64(&r, x); return (uint32_t)r; +} +#define __alignof__ __alignof + +typedef long mag_atomic_t; /* Atomic integer type */ +typedef enum mag_mo_t { /* Atomic memory order */ + MAG_MO_RELAXED, + MAG_MO_CONSUME, + MAG_MO_ACQUIRE, + MAG_MO_RELEASE, + MAG_MO_ACQ_REL, + MAG_MO_SEQ_CST +} mag_mo_t; + +static MAG_AINLINE void mag_atomic_store(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; _InterlockedExchange(o, x); +} +static MAG_AINLINE mag_atomic_t mag_atomic_load(volatile mag_atomic_t* o, mag_mo_t order) { + (void)order; + mag_atomic_t r; + _InterlockedExchange(&r, *o); + return r; +} +static MAG_AINLINE mag_atomic_t mag_atomic_fetch_add(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; + return _InterlockedExchangeAdd(o, x); +} +static MAG_AINLINE mag_atomic_t mag_atomic_fetch_sub(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; + return _InterlockedExchangeAdd(o, -x); +} +static MAG_AINLINE mag_atomic_t mag_atomic_fetch_and(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; + return _InterlockedAnd(o, x); +} +static MAG_AINLINE mag_atomic_t mag_atomic_fetch_or(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; + return _InterlockedOr(o, x); +} +static MAG_AINLINE mag_atomic_t mag_atomic_fetch_xor(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; + return _InterlockedXor(o, x); +} +static MAG_AINLINE mag_atomic_t mag_atomic_exchange(volatile mag_atomic_t* o, mag_atomic_t x, mag_mo_t order) { + (void)order; + return _InterlockedExchange(o, x); +} +static MAG_AINLINE bool mag_atomic_compare_exchange_weak(volatile mag_atomic_t* o, mag_atomic_t *exp, mag_atomic_t *des, mag_mo_t order_succ, mag_mo_t order_fail) { + (void)order_succ; (void)order_fail; + return _InterlockedCompareExchange(o, *des, *exp) == *exp; +} +static MAG_AINLINE bool mag_atomic_compare_exchange_strong(volatile mag_atomic_t* o, mag_atomic_t *exp, mag_atomic_t *des, mag_mo_t order_succ, mag_mo_t order_fail) { + (void)order_succ; (void)order_fail; + return _InterlockedCompareExchange(o, *des, *exp) == *exp; +} + +#endif + +mag_static_assert(sizeof(0u) == 4); +mag_static_assert(sizeof(0ull) == 8); + +#ifdef __BYTE_ORDER +#if defined(__BIG_ENDIAN) && (__BYTE_ORDER == __BIG_ENDIAN) +#define MAG_BE 
+#elif defined(__LITTLE_ENDIAN) && (__BYTE_ORDER == __LITTLE_ENDIAN)
+#define MAG_LE
+#endif
+#elif defined(_BYTE_ORDER)
+#if defined(_BIG_ENDIAN) && (_BYTE_ORDER == _BIG_ENDIAN)
+#define MAG_BE
+#elif defined(_LITTLE_ENDIAN) && (_BYTE_ORDER == _LITTLE_ENDIAN)
+#define MAG_LE
+#endif
+#elif defined(__BIG_ENDIAN__)
+#define MAG_BE
+#elif defined(__LITTLE_ENDIAN__)
+#define MAG_LE
+#else
+#if defined(__ARMEL__) || defined(__THUMBEL__) || defined(__AARCH64EL__) || \
+defined(_MIPSEL) || defined(__MIPSEL) || defined(__MIPSEL__) || \
+defined(__ia64__) || defined(_IA64) || defined(__IA64__) || defined(__ia64) || \
+defined(_M_IA64) || defined(__itanium__) || defined(i386) || defined(__i386__) || \
+defined(__i486__) || defined(__i586__) || defined(__i686__) || defined(__i386) || \
+defined(_M_IX86) || defined(_X86_) || defined(__THW_INTEL__) || defined(__I86__) || \
+defined(__INTEL__) || defined(__x86_64) || defined(__x86_64__) || \
+defined(__amd64__) || defined(__amd64) || defined(_M_X64) || \
+defined(__bfin__) || defined(__BFIN__) || defined(bfin) || defined(BFIN)
+#define MAG_LE
+#elif defined(__m68k__) || defined(M68000) || defined(__hppa__) || defined(__hppa) || defined(__HPPA__) || \
+defined(__sparc__) || defined(__sparc) || defined(__370__) || defined(__THW_370__) || \
+defined(__s390__) || defined(__s390x__) || defined(__SYSC_ZARCH__)
+#define MAG_BE
+#elif defined(__arm__) || defined(__arm64) || defined(__thumb__) || \
+defined(__TARGET_ARCH_ARM) || defined(__TARGET_ARCH_THUMB) || defined(__ARM_ARCH) || \
+defined(_M_ARM) || defined(_M_ARM64)
+#if defined(_WIN32) || defined(_WIN64) || \
+defined(__WIN32__) || defined(__TOS_WIN__) || defined(__WINDOWS__)
+#define MAG_LE
+#else
+#error "Unknown endianness"
+#endif
+#endif
+#endif
+
+static uint32_t MAG_AINLINE mag_bswap32(uint32_t x) { /* Swap bytes for endianness switch. Should be optimized to a (bswap/rev) instruction on modern compilers. */
+    #ifdef MAG_BE
+    #if (defined(__GNUC__) && ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3))) || defined(__clang__)
+    x = (uint32_t)__builtin_bswap32((int32_t)x);
+    #else
+    x = (x & 0xff000000) >> 24 |
+        (x & 0xff0000) >> 8 |
+        (x & 0xff00) << 8 |
+        (x & 0xff) << 24;
+    #endif
+    #endif
+    return x;
+}
+
+static uint64_t MAG_AINLINE mag_bswap64(uint64_t x) { /* Swap bytes for endianness switch. Should be optimized to a (bswap/rev) instruction on modern compilers. */
+    #ifdef MAG_BE
+    #if (defined(__GNUC__) && ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3))) || defined(__clang__)
+    x = (uint64_t)__builtin_bswap64((int64_t)x);
+    #else
+    x = (x & 0xff00000000000000) >> 56 |
+        (x & 0xff000000000000) >> 40 |
+        (x & 0xff0000000000) >> 24 |
+        (x & 0xff00000000) >> 8 |
+        (x & 0xff000000) << 8 |
+        (x & 0xff0000) << 24 |
+        (x & 0xff00) << 40 |
+        (x & 0xff) << 56;
+    #endif
+    #endif
+    return x;
+}
+
+extern MAG_NORET MAG_COLDPROC MAG_EXPORT void mag_panic(const char* msg, ...);
+extern MAG_EXPORT bool mag_log_enabled;
+extern MAG_EXPORT void* (*mag_alloc)(void* blk, size_t size);
+extern MAG_EXPORT void* mag_alloc_aligned(size_t size, size_t align);
+extern MAG_EXPORT void mag_free_aligned(void* blk);
+
+extern MAG_EXPORT void mag_humanize_memory_size(size_t n, double* out, const char** unit);
+
+#define mag_swap(T, a, b) do { T tmp = (a); (a) = (b); (b) = tmp; } while (0)
+#define mag_xmax(x, y) (((x) > (y)) ? (x) : (y))
+#define mag_xmin(x, y) (((x) < (y)) ?
(x) : (y)) +#define MAG_CC_RED "\x1b[31m" +#define MAG_CC_GREEN "\x1b[32m" +#define MAG_CC_YELLOW "\x1b[33m" +#define MAG_CC_BLUE "\x1b[34m" +#define MAG_CC_MAGENTA "\x1b[35m" +#define MAG_CC_CYAN "\x1b[36m" +#define MAG_CC_RESET "\x1b[0m" +#define MAG_STRINGIZE2(x) #x +#define MAG_STRINGIZE(x) MAG_STRINGIZE2(x) +#ifdef __FILE_NAME__ +# define MAG_SRC_NAME __FILE_NAME__ ":" MAG_STRINGIZE(__LINE__) +#else +# define MAG_SRC_NAME __FILE__ ":" MAG_STRINGIZE(__LINE__) +#endif +#define mag_log_info(msg, ...) do { if (mag_unlikely(mag_log_enabled)) fprintf(stdout, MAG_CC_CYAN "[magnetron] " MAG_CC_RESET msg "\n", ## __VA_ARGS__); } while (0) +#define mag_log_info_force(msg, ...) do { fprintf(stdout, MAG_CC_CYAN "[magnetron] " MAG_CC_RESET msg "\n", ## __VA_ARGS__); } while (0) +#define mag_log_warn(msg, ...) do { fprintf(stdout, MAG_CC_CYAN "[magnetron] " MAG_CC_RESET MAG_SRC_NAME " " MAG_CC_YELLOW msg MAG_CC_RESET "\n", ## __VA_ARGS__); fflush(stdout); } while (0) +#define mag_log_error(msg, ...) do { fprintf(stdout, MAG_CC_CYAN "[magnetron] " MAG_CC_RESET MAG_SRC_NAME " " MAG_CC_RED msg MAG_CC_RESET "\n", ## __VA_ARGS__); fflush(stdout); } while (0) + +#define mag_assert(expr, msg, ...) \ + if (mag_unlikely(!(expr))) { \ + mag_panic("%s:%d Assertion failed: " #expr " <- " msg, __FILE__, __LINE__, ## __VA_ARGS__);\ + } +#define mag_assert2(expr) mag_assert(expr, "") + +#if MAG_BOUNDS_CHECK +#define mag_bnd_chk(ptr, base, n) \ + mag_assert((uintptr_t)(ptr) >= (uintptr_t)(base) && (uintptr_t)(ptr) < (uintptr_t)(base) + (n), \ + "\nBound check failed: %p not in [%p, %p), base+%zu, end+%zu", \ + (void*)(ptr), \ + (void*)(base), \ + (void*)((uintptr_t)(base)+(n)), \ + (size_t)llabs((long long)((int64_t)(ptr)-(int64_t)(base))), \ + (size_t)llabs((long long)(((int64_t)(base)+(n))-(int64_t)(ptr))) \ + ) +#else +#define mag_bnd_chk(ptr, base, n) +#endif + +/* Increment pointer or size with correct type alignment. */ +static MAG_AINLINE void* mag_pincr(void** p, size_t sz, size_t align) { + void* pp = (void*)(((uintptr_t)*p+align-1)&-align); + *p = (void*)((uint8_t*)pp+sz); + return pp; +} + +typedef struct mag_intrusive_chunk mag_intrusive_chunk; +struct mag_intrusive_chunk { + uint8_t* bot; /* Bottom (base) of chunk */ + uint8_t* top; /* Top of chunk, grows downwards towards bottom */ + mag_intrusive_chunk* next; /* Link to next chunk */ +}; + +/* Fast memory allocator for memory blocks of same size. Obtains a memory pool and freelist for fast de/allocation. */ +typedef struct mag_fixed_intrusive_pool { + size_t block_size; /* Size of each allocated block */ + size_t block_align; /* Alignment requirements of each block. 
*/ + size_t blocks_per_chunk; /* How many blocks fit in each chunk */ + mag_intrusive_chunk* chunks; /* Linked list of all chunks */ + mag_intrusive_chunk* chunk_head; /* Last chunk */ + void* free_list; /* Intrusive single linked list of free chunks */ + uint64_t num_freelist_hits; /* Number of cache (free-list) hits */ + uint64_t num_pool_hits; /* Number of cache (pool) hits */ + uint64_t num_chunks; /* Number of used chunks */ + uint64_t num_allocs; /* Number of total allocations */ +} mag_fixed_intrusive_pool; + +extern MAG_EXPORT void mag_fixed_intrusive_pool_init(mag_fixed_intrusive_pool* pool, size_t block_size, size_t block_align, size_t blocks_per_chunk); +extern MAG_EXPORT void* mag_fixed_intrusive_pool_malloc(mag_fixed_intrusive_pool* pool); +extern MAG_EXPORT void mag_fixed_intrusive_pool_free(mag_fixed_intrusive_pool* pool, void* blk); +extern MAG_EXPORT void mag_fixed_intrusive_pool_destroy(mag_fixed_intrusive_pool* pool); +extern MAG_EXPORT void mag_fixed_intrusive_pool_print_info(mag_fixed_intrusive_pool* pool, const char* name); + +/* Device interface to any compute backend device (CPU, GPU, TPU etc..) */ +typedef struct mag_compute_device_t mag_compute_device_t; + +/* Buffer interface on a compute device */ +typedef struct mag_storage_buffer_t mag_storage_buffer_t; +struct mag_storage_buffer_t { + uintptr_t base; /* Pointer to buffer on device. Might point to GPU or any other device memory. */ + size_t size; /* Size of buffer in bytes. */ + size_t alignment; /* Alignment of buffer. */ + mag_compute_device_t* host; /* Host device. */ + void (*set)(mag_storage_buffer_t* sto, size_t offs, uint8_t x); /* Memset buffer. */ + void (*cpy_host_device)(mag_storage_buffer_t* sto, size_t offs, const void* src, size_t n); /* Copy data from host to device. */ + void (*cpy_device_host)(mag_storage_buffer_t* sto, size_t offs, void* dst, size_t n); /* Copy data from device to host. */ +}; + +/* Device interface to any compute backend device (CPU, GPU, TPU etc..) */ +struct mag_compute_device_t { + char name[128]; /* Device name. */ + void* impl; /* Device specific implementation, if applicable. */ + bool is_async; /* If device is async. */ + mag_compute_device_type_t type; /* Device type enum. */ + void (*eager_exec_fwd)(mag_compute_device_t* dvc, mag_tensor_t* root); /* Execute a single op forward. */ + void (*eager_exec_bwd)(mag_compute_device_t* dvc, mag_tensor_t* root); /* Execute a single op backwards. */ + void (*alloc_storage)(mag_compute_device_t* dvc, mag_storage_buffer_t* out, size_t size, size_t align); + void (*free_storage)(mag_compute_device_t* dvc, mag_storage_buffer_t* buf); +}; + +/* Device creation and destruction. */ +typedef struct mag_device_factory_t { + mag_compute_device_t* (*init)(mag_ctx_t* ctx); /* Initialize device. */ + void (*destroy)(mag_compute_device_t* dvc); /* Destroy device. */ +} mag_device_factory_t; + +/* Global device factories. Implemented in magnetron_device_registry.c */ +extern mag_compute_device_t* mag_init_dynamic_device(mag_ctx_t* ctx, mag_compute_device_type_t* type); +extern void mag_destroy_dynamic_device(mag_compute_device_t* dvc); + +/* Profiling performance monitor per op. */ +typedef struct mag_perf_mon_t { + uint64_t elapsed_ns; + uint64_t elapsed_ns_acc; + uint64_t n_execs; +} mag_perf_mon_t; + +/* Performance monitor for profiler session. 
*/ +typedef struct mag_op_perf_info_t { + uint64_t elapsed_ns_acc; + uint64_t n_execs; +} mag_op_perf_info_t; + +#if MAG_SANITIZE_RC +typedef struct mag_tensor_node_t mag_tensor_node_t; +struct mag_tensor_node_t { + mag_tensor_t* tensor; + mag_tensor_node_t* next; +}; +#endif + +/* +** Context contains all isolated state and data. +** Lifetimes of tensors and compute graphs are bound to the context - the context is the owner. +** Context itself is not thread-safe, use a thread-local context or synchronize access. (Multiple contexts can be used.) +*/ +struct mag_ctx_t { + struct { + char os_name[128]; /* OS name. */ + char cpu_name[128]; /* CPU name. */ + uint32_t cpu_virtual_cores; /* Virtual CPUs. */ + uint32_t cpu_physical_cores; /* Physical CPU cores. */ + uint32_t cpu_sockets; /* CPU sockets. */ + uint64_t phys_mem_total; /* Total physical memory in bytes. */ + uint64_t phys_mem_free; /* Free physical memory in bytes. */ +#if defined(__x86_64__) || defined(_M_X64) + uint32_t x86_64_cpu_features[8][4]; /* x86-64 CPU features. */ +#endif + } sys; +#if MAG_SANITIZE_RC + mag_tensor_node_t* rc_tracked; /* Linked list of RC tensors for sanitize. */ +#endif + mag_fixed_intrusive_pool tensor_pool; /* Fixed-size memory pool for tensors. */ + mag_exec_mode_t exec_mode; + bool profiler_enabled; + mag_op_perf_info_t op_perf_mons_total[MAG_OP__NUM]; + union { + struct { + uint64_t state; + uint64_t inc; + } pcg; + struct { + uint32_t remaining; + uint32_t next; + uint32_t state[624]; + } mersenne; + } prng; + mag_prng_algorithm_t prng_algorithm; /* PRNG algorithm. */ + uintptr_t tr_id; /* Host thread ID. */ + size_t sh_len; /* Number of shutdown hooks. */ + size_t sh_cap; /* Maximum number of shutdown hooks. */ + mag_compute_device_type_t device_type; /* Active compute device. */ + mag_compute_device_t* device; /* Active compute device. */ + uint8_t* (*image_load_fn)(const char*, uint32_t(*)[3], mag_color_channels_t); /* Image loader. stb_image by default, you can plug-in your own. */ + void (*image_load_free_fn)(uint8_t*); /* Image loader free function. stb_image by default, you can plug-in your own. */ + bool (*image_save_fn)(const char*, const uint8_t*, const uint32_t(*)[3]); /* Image saver. stb_image by default, you can plug-in your own. */ + void* ud; /* User data. */ +}; + +typedef enum mag_tensor_flags_t { + MAG_TFLAG_NONE = 0, + MAG_TFLAG_OWNER = 1<<0, /* Tensor is the owner of the buffer. */ + MAG_TFLAG_VIEW = 1<<1, /* Tensor is a view. */ + MAG_FLAG_GRAD = 1<<2, /* Tensor is a gradient. */ + MAG_TFLAG_EXEC_EAGER = 1<<3, /* Tensor is executed eagerly. */ + + MAG_TFLAG_LEN = 4 +} mag_tensor_flags_t; +mag_static_assert(MAG_TFLAG_LEN <= 0xff); + +/* +** Tensor with up to 6 Dimensions. +*/ +struct mag_tensor_t { + struct { + uint32_t rc_strong; /* Strong reference count. */ + uint32_t rc_weak; /* Weak reference count. */ +#if MAG_SANITIZE_RC + void (*dtor)(mag_tensor_t*); /* Debug destructor. */ +#endif + } rcb; /* Reference count control block. */ + mag_ctx_t* ctx; /* Host context. */ + int64_t rank; /* Number of active dimensions. [1, MAX_DIMS] */ + int64_t shape[MAG_MAX_DIMS]; /* Shape of the tensor. */ + int64_t strides[MAG_MAX_DIMS]; /* Strides of the tensor. We store the strides in element counts and NOT in bytes. */ + mag_dtype_t dtype; /* Data type of the tensor. */ + mag_storage_buffer_t storage; /* Storage buffer. */ + int64_t numel; /* Number of elements in the tensor. */ + mag_tensor_flags_t flags; /* Tensor flags. */ + mag_op_t op; /* Opcode for operators. 
*/ + mag_tensor_t* op_inputs[MAG_MAX_INPUT_TENSORS]; /* Input tensors for operators. */ + mag_op_param_t op_params[MAG_MAX_OP_PARAMS]; /* Operator parameters. */ + mag_tensor_t* view_uplink; /* View base tensor. */ + size_t view_offs; /* Offset in view tensor. */ + mag_tensor_t* grad; /* ∇f - Gradient tensor. */ + mag_perf_mon_t pmon; /* Performance monitor. */ + char name[MAG_MAX_TENSOR_NAME_LEN]; /* Tensor debug name. */ + void* ud; /* User data. */ +}; + +#define mag_load_local_storage_group_arr(arr, prefix) \ + const int64_t prefix##0 = (arr)[0]; \ + const int64_t prefix##1 = (arr)[1]; \ + const int64_t prefix##2 = (arr)[2]; \ + const int64_t prefix##3 = (arr)[3]; \ + const int64_t prefix##4 = (arr)[4]; \ + const int64_t prefix##5 = (arr)[5]; \ + (void)prefix##0; \ + (void)prefix##1; \ + (void)prefix##2; \ + (void)prefix##3; \ + (void)prefix##4; \ + (void)prefix##5 + +#define mag_load_local_storage_group(xk, prefix, var) mag_load_local_storage_group_arr((xk)->var, prefix) + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/media/magnetron-logo-noborder.png b/media/magnetron-logo-noborder.png new file mode 100644 index 0000000..e7a6b0e Binary files /dev/null and b/media/magnetron-logo-noborder.png differ diff --git a/media/magnetron-logo.svg b/media/magnetron-logo.svg new file mode 100644 index 0000000..8dea70d --- /dev/null +++ b/media/magnetron-logo.svg @@ -0,0 +1,27 @@ + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/media/magnetron-torus.png b/media/magnetron-torus.png new file mode 100644 index 0000000..b61e253 Binary files /dev/null and b/media/magnetron-torus.png differ diff --git a/media/xor.png b/media/xor.png new file mode 100644 index 0000000..1975d98 Binary files /dev/null and b/media/xor.png differ diff --git a/python/benchmarks/tune.py b/python/benchmarks/tune.py new file mode 100644 index 0000000..e69de29 diff --git a/python/chat/requirements.txt b/python/chat/requirements.txt new file mode 100644 index 0000000..87531d5 --- /dev/null +++ b/python/chat/requirements.txt @@ -0,0 +1,2 @@ +python-dotenv +flask \ No newline at end of file diff --git a/python/chat/server.py b/python/chat/server.py new file mode 100644 index 0000000..bcae468 --- /dev/null +++ b/python/chat/server.py @@ -0,0 +1,68 @@ +from math import ceil + +from dotenv import load_dotenv +from flask import Flask, render_template, request +import magnetron as mag +from magnetron.models import * + +load_dotenv() + +EPOCHS: int = 10000 +LEARNING_RATE: float = 0.8 + +# Inputs +inputs = [ + mag.Tensor.const([0.0, 0.0]), + mag.Tensor.const([0.0, 1.0]), + mag.Tensor.const([1.0, 0.0]), + mag.Tensor.const([1.0, 1.0]) +] + +# Targets +targets = [ + mag.Tensor.const([0.0]), + mag.Tensor.const([1.0]), + mag.Tensor.const([1.0]), + mag.Tensor.const([0.0]) +] + +mlp = SequentialModel([ + DenseLayer(2, 4), + DenseLayer(4, 1) +]) + +# Train model XOR +print('Training XOR model...') +losses = mlp.train(inputs, targets, EPOCHS, LEARNING_RATE) + +print('Launching Flask server...') +app = Flask(__name__) + +@app.route('/') +def index(): + return render_template('index.html') + + +@app.route('/api/v1/bot_response') +def get_response(): + try: + message = request.args.get('message') + splits = message.split(' ') + a: float = float(splits[0]) + b: float = float(splits[1]) + inp = mag.Tensor.const([a, b]) + result: float = mlp.forward(inp).scalar() + result_rounded: float = mlp.forward(inp, activation=mag.Operator.HARD_SIGMOID).scalar() + return f'{a} ^ {b} = {result} ≈ {result_rounded} => 
{int(result_rounded) == 1}'
+    except Exception:
+        return 'Please enter a valid input. Enter two numbers (between 0 and 1) separated by spaces. For example: 1 1 or 1 0 or 0 0.'
+
+
+@app.route('/api/v1/system_info')
+def get_system_info():
+    ctx = mag.Context.active()
+    return f'{ctx.os_name} | {ctx.cpu_name} ({ctx.cpu_virtual_cores}) | {(ctx.physical_memory_total / (1 << 30)):.2f} GiB RAM | {(ctx.total_allocated_pool_memory / (1 << 20)):.2f} MiB POOL'
+
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', use_reloader=False)
diff --git a/python/chat/static/css/style.css b/python/chat/static/css/style.css
new file mode 100644
index 0000000..737437e
--- /dev/null
+++ b/python/chat/static/css/style.css
@@ -0,0 +1,99 @@
+* {
+    margin: 0;
+    padding: 0;
+    box-sizing: border-box;
+}
+body {
+    font-family: Arial, sans-serif;
+    display: flex;
+    justify-content: center;
+    align-items: center;
+    min-height: 100vh;
+    background-color: #f9f9f9;
+    color: #333;
+}
+.main-container {
+    width: 100%;
+    max-width: 1200px;
+    display: flex;
+    flex-direction: column;
+    align-items: center;
+    padding: 1rem;
+}
+.logo-container img {
+    width: 100px;
+    margin-bottom: 1rem;
+}
+h1 {
+    font-size: 1.5rem;
+    text-align: center;
+    margin-bottom: 1rem;
+    color: #444;
+}
+.chat-container {
+    width: 100%;
+    height: 70vh;
+    background-color: #fff;
+    border-radius: 8px;
+    box-shadow: 0px 4px 12px rgba(0, 0, 0, 0.1);
+    padding: 1rem;
+    display: flex;
+    flex-direction: column;
+    justify-content: space-between;
+}
+.chatbox {
+    flex-grow: 1;
+    overflow-y: auto;
+    border: 1px solid #ddd;
+    border-radius: 8px;
+    padding: 1rem;
+    background-color: #fafafa;
+}
+.message {
+    margin-bottom: 1rem;
+}
+.user {
+    color: #007bff;
+}
+.bot {
+    color: #28a745;
+}
+.input-container {
+    display: flex;
+    gap: 0.5rem;
+    margin-top: 1rem;
+}
+input[type="text"] {
+    flex-grow: 1;
+    padding: 0.75rem;
+    border: 1px solid #ddd;
+    border-radius: 8px;
+    font-size: 1rem;
+    transition: border-color 0.2s;
+}
+input[type="text"]:focus {
+    border-color: #007bff;
+    outline: none;
+}
+button {
+    padding: 0.75rem 1rem;
+    border: none;
+    background-color: #007bff;
+    color: #fff;
+    font-size: 1rem;
+    border-radius: 8px;
+    cursor: pointer;
+    transition: background-color 0.2s;
+}
+button:hover {
+    background-color: #0056b3;
+}
+footer {
+    margin-top: 1rem;
+    text-align: center;
+    font-size: 0.875rem;
+    color: #666;
+}
+.info-text {
+    margin-bottom: 0.25rem;
+}
diff --git a/python/chat/static/images/wavelet-logo.png b/python/chat/static/images/wavelet-logo.png
new file mode 100644
index 0000000..d05670b
Binary files /dev/null and b/python/chat/static/images/wavelet-logo.png differ
diff --git a/python/chat/static/js/script.js b/python/chat/static/js/script.js
new file mode 100644
index 0000000..10bc17e
--- /dev/null
+++ b/python/chat/static/js/script.js
@@ -0,0 +1,57 @@
+const chatbox = document.getElementById('chatbox');
+const message = document.getElementById('message');
+const send = document.getElementById('send');
+const infoText = document.getElementById('info-text');
+
+function botSay(botMessage) {
+    const botDiv = document.createElement('div');
+    botDiv.className = 'message bot';
+    botDiv.innerHTML = `magnetron: ${botMessage}`;
+    chatbox.appendChild(botDiv);
+    chatbox.scrollTop = chatbox.scrollHeight;
+}
+
+send.addEventListener('click', function() {
+    const userMessage = message.value.trim();
+    if (userMessage) {
+        const userDiv = document.createElement('div');
+        userDiv.className = 'message user';
+        userDiv.innerHTML = `You: ${userMessage}`;
+        chatbox.appendChild(userDiv);
+
+        message.value = "";
+
+        fetch(`/api/v1/bot_response?message=${encodeURIComponent(userMessage)}`)
+            .then(response => response.text())
+            .then(botMessage => {
+                botSay(botMessage);
+            })
+            .catch(error => {
+                console.error('Error fetching response:', error);
+            });
+
+        fetch(`/api/v1/system_info`)
+            .then(response => response.text())
+            .then(infoString => {
+                infoText.innerHTML = infoString;
+            })
+            .catch(error => {
+                console.error('Error fetching response:', error);
+            });
+    }
+});
+
+fetch(`/api/v1/system_info`)
+    .then(response => response.text())
+    .then(infoString => {
+        infoText.innerHTML = infoString;
+    })
+    .catch(error => {
+        console.error('Error fetching response:', error);
+    });
+
+message.addEventListener('keypress', function(event) {
+    if (event.key === 'Enter') send.click();
+});
+
+botSay('Hello! I\'m an AI trained to solve the XOR function. Send me two space-separated numbers in {0, 1} and press Enter. Like 1 0. Or 1 1.')
diff --git a/python/chat/templates/index.html b/python/chat/templates/index.html
new file mode 100644
index 0000000..306b17c
--- /dev/null
+++ b/python/chat/templates/index.html
@@ -0,0 +1,28 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Chat with magnetron</title>
+    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
+</head>
+<body>
+    <div class="main-container">
+        <div class="logo-container">
+            <img src="{{ url_for('static', filename='images/wavelet-logo.png') }}" alt="magnetron Logo">
+        </div>
+        <h1>Chat with magnetron</h1>
+        <div class="chat-container">
+            <div class="chatbox" id="chatbox"></div>
+            <div class="input-container">
+                <input type="text" id="message" placeholder="Enter two numbers, e.g. 1 0">
+                <button id="send">Send</button>
+            </div>
+        </div>
+        <footer>
+            <p class="info-text" id="info-text"></p>
+        </footer>
+    </div>
+    <script src="{{ url_for('static', filename='js/script.js') }}"></script>
+</body>
+</html>
\ No newline at end of file
diff --git a/python/examples/research/approx.py b/python/examples/research/approx.py
new file mode 100644
index 0000000..94dd163
--- /dev/null
+++ b/python/examples/research/approx.py
@@ -0,0 +1,39 @@
+# (c) 2024 Mario 'Neo' Sieg.
+
+import magnetron as mag
+import matplotlib
+import matplotlib.pyplot as plt
+
+matplotlib.rcParams["axes.formatter.limits"] = (-99, 99)  # Disable scientific notation to show all digits
+
+
+def plot_approximation_error(name: str, exact_func: callable, approx_op: mag.Operator, domain: (float, float),
+                             step: float = 0.0001):
+    x_values = [i * step for i in range(int(domain[0] / step), int(domain[1] / step))]
+    exact = [exact_func(x) for x in x_values]
+    approx = mag.Tensor.operator(approx_op, False, None, mag.Tensor.const(x_values)).data_as_f32()
+    errors = [abs(exact[i] - approx[i]) for i in range(len(exact))]
+    assert len(exact) == len(approx) == len(errors) == len(x_values)
+
+    plt.figure(figsize=(10, 5))
+    plt.plot(x_values, exact, label=f'Exact {name}', color='blue')
+    plt.plot(x_values, approx, label=f'Approx {name}', color='orange', linestyle='--')
+    plt.title(f'Exact vs Approx {name}')
+    plt.xlabel('x')
+    plt.ylabel(f'{name}(x)')
+    plt.legend()
+    plt.grid(True)
+
+    plt.figure(figsize=(10, 5))
+    plt.plot(x_values, errors, label='Absolute Error', color='red')
+    plt.title(f'Error in {name} Approximation')
+    plt.xlabel('x')
+    plt.ylabel('Absolute Error')
+    plt.plot([], [], ' ', label=f'Samples {len(errors)}')
+    plt.plot([], [], ' ', label=f'Mean Error: {sum(errors) / len(errors):.20f}')
+    plt.plot([], [], ' ', label=f'Min Error: {min(errors):.20f}')
+    plt.plot([], [], ' ', label=f'Max Error: {max(errors):.20f}')
+    plt.legend()
+    plt.grid(True)
+
+    plt.show()
diff --git a/python/examples/research/approx_gelu.py b/python/examples/research/approx_gelu.py
new file mode 100644
index 0000000..d4d6475
--- /dev/null
+++ b/python/examples/research/approx_gelu.py
@@ -0,0 +1,11 @@
+# (c) 2024 Mario 'Neo' Sieg.
+
+import approx
+import math
+
+import magnetron as mag
+
+def gelu(x: float) -> float:
+    return 0.5 * x * (1.0 + math.tanh(0.79788456080286535587989211986876 * x * (1.0 + 0.044715 * x * x)))
+
+approx.plot_approximation_error('gelu', gelu, mag.Operator.GELU, domain=(-2, 2))
diff --git a/python/examples/research/approx_silu.py b/python/examples/research/approx_silu.py
new file mode 100644
index 0000000..48d1263
--- /dev/null
+++ b/python/examples/research/approx_silu.py
@@ -0,0 +1,13 @@
+# (c) 2024 Mario 'Neo' Sieg.
+
+import approx
+import math
+
+import magnetron as mag
+
+
+def silu(x: float) -> float:
+    return x / (1.0 + math.exp(-x))
+
+
+approx.plot_approximation_error('silu', silu, mag.Operator.SILU, domain=(-2, 2))
diff --git a/python/examples/research/approx_softmax.py b/python/examples/research/approx_softmax.py
new file mode 100644
index 0000000..d6b9cf9
--- /dev/null
+++ b/python/examples/research/approx_softmax.py
@@ -0,0 +1,8 @@
+# (c) 2024 Mario 'Neo' Sieg.
+
+import approx
+import math
+
+import magnetron as mag
+
+approx.plot_approximation_error('softmax', math.exp, mag.Operator.SOFTMAX, domain=(-10, 10))
diff --git a/python/examples/research/approx_tanh.py b/python/examples/research/approx_tanh.py
new file mode 100644
index 0000000..d8c6d39
--- /dev/null
+++ b/python/examples/research/approx_tanh.py
@@ -0,0 +1,8 @@
+# (c) 2024 Mario 'Neo' Sieg.
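+# Plots math.tanh against magnetron's TANH operator approximation on [-2, 2];
+# the shared plotting logic lives in approx.plot_approximation_error.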
+ +import approx +import math + +import magnetron as mag + +approx.plot_approximation_error('tanh', math.tanh, mag.Operator.TANH, domain=(-2, 2)) diff --git a/python/examples/research/matmul_numpy.py b/python/examples/research/matmul_numpy.py new file mode 100644 index 0000000..8a510c0 --- /dev/null +++ b/python/examples/research/matmul_numpy.py @@ -0,0 +1,21 @@ +# (c) 2024 Mario "Neo" Sieg. + +import numpy as np +import time + +N = 1024 +A = np.random.randn(N, N).astype(dtype=np.float32) +B = np.random.randn(N, N).astype(dtype=np.float32) + +flop = 2*N**3 +avg = 0 +I = 10 +for _ in range(I): + st = time.monotonic() + C = A @ B + et = time.monotonic() + s = et - st + print(f'{flop/s * 1e-12} TFLOP/s') + avg += flop/s + +print(f'Average: {avg/I * 1e-12} TFLOP/s') \ No newline at end of file diff --git a/python/examples/research/matmul_ref.py b/python/examples/research/matmul_ref.py new file mode 100644 index 0000000..484eb0e --- /dev/null +++ b/python/examples/research/matmul_ref.py @@ -0,0 +1,13 @@ +# (c) 2024 Mario "Neo" Sieg. + +import numpy as np + +A = np.array([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0] +]) +B = np.array([ + 0.5, -1.0 +]) +print(np.matmul(A, B)) diff --git a/python/examples/research/matmul_wl.py b/python/examples/research/matmul_wl.py new file mode 100644 index 0000000..d76aad6 --- /dev/null +++ b/python/examples/research/matmul_wl.py @@ -0,0 +1,23 @@ +# (c) 2024 Mario "Neo" Sieg. + +import magnetron as mag +import time + + +N = 1024 +A = mag.Tensor.uniform((N, N), name='A') +B = mag.Tensor.uniform((N, N), name='B') +print(A.shape) + +flop = 2*N**3 +avg = 0 +I = 10 +for _ in range(I): + st = time.monotonic() + C = A @ B + et = time.monotonic() + s = et - st + print(f'{flop/s * 1e-12} TFLOP/s') + avg += flop/s + +print(f'Average: {avg/I * 1e-12} TFLOP/s') \ No newline at end of file diff --git a/python/examples/simple/cuda.py b/python/examples/simple/cuda.py new file mode 100644 index 0000000..0df9189 --- /dev/null +++ b/python/examples/simple/cuda.py @@ -0,0 +1,10 @@ +# (c) 2024 Mario "Neo" Sieg. + +import magnetron as mag + +mag.GlobalConfig.verbose = True +mag.GlobalConfig.compute_device = mag.ComputeDevice.CUDA + +test = mag.Tensor.uniform(shape=(4, 4)) +r = test + test +print(r) diff --git a/python/examples/simple/mnist.py b/python/examples/simple/mnist.py new file mode 100644 index 0000000..1d99dbd --- /dev/null +++ b/python/examples/simple/mnist.py @@ -0,0 +1,14 @@ +# (c) 2024 Mario 'Neo' Sieg. + +from magnetron.models import SequentialModel, DenseLayer + +EPOCHS: int = 10 +LEARNING_RATE: float = 1e-3 + +network = SequentialModel([ + DenseLayer(784, 250), + DenseLayer(250, 100), + DenseLayer(100, 10) +]) + +network.summary() diff --git a/python/examples/simple/mnist_dataset.py b/python/examples/simple/mnist_dataset.py new file mode 100644 index 0000000..70316b9 --- /dev/null +++ b/python/examples/simple/mnist_dataset.py @@ -0,0 +1,35 @@ +# (c) 2024 Mario 'Neo' Sieg. 
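+# Reads the classic MNIST IDX (ubyte) format: a 4-byte magic number
+# (0x00000803 for images, 0x00000801 for labels, stored big-endian), followed
+# by big-endian uint32 dimension sizes, then the raw uint8 payload. Note that
+# the first struct.unpack below uses native ('I') byte order, so on
+# little-endian hosts the image magic compares equal to 0x3080000.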
+
+import gzip
+import struct
+import numpy as np
+import magnetron as mag
+
+
+def ubyte_load_data(src: str, num_samples: int) -> np.ndarray:
+    with gzip.open(src) as gz:
+        n = struct.unpack("I", gz.read(4))
+        assert n[0] == 0x3080000
+        n = struct.unpack(">I", gz.read(4))[0]
+        assert n == num_samples
+        crow = struct.unpack(">I", gz.read(4))[0]
+        ccol = struct.unpack(">I", gz.read(4))[0]
+        assert crow == ccol == 28
+        res = np.frombuffer(gz.read(num_samples * crow * ccol), dtype=np.uint8)
+        return res.astype(dtype=np.float32).reshape((num_samples, crow, ccol)) / 256.0
+
+
+def ubyte_load_labels(src: str, num_samples: int) -> np.ndarray:
+    with gzip.open(src) as gz:
+        n = struct.unpack("I", gz.read(4))
+        assert n[0] == 0x1080000
+        n = struct.unpack(">I", gz.read(4))
+        assert n[0] == num_samples
+        res = np.frombuffer(gz.read(num_samples), dtype=np.uint8)
+        return res.reshape(num_samples)
+
+
+data = ubyte_load_data('../../../datasets/mnist/images-idx3-ubyte.gz', 60000)
+mnist = mag.Tensor.const(data=data.tolist())
+mnist.print(True, False)
+mnist.save('mnist_images.magnetron')
diff --git a/python/examples/simple/perceptron.py b/python/examples/simple/perceptron.py
new file mode 100644
index 0000000..548120e
--- /dev/null
+++ b/python/examples/simple/perceptron.py
@@ -0,0 +1,23 @@
+import magnetron as mag
+
+# Perceptron function (McCulloch–Pitts neuron)
+def perceptron(x: mag.Tensor, w: mag.Tensor, b: mag.Tensor) -> mag.Tensor:
+    return (w @ x + b).heaviside_step()
+
+# Negating perceptron
+def logical_not(input: float) -> float:
+    x = mag.Tensor.const([input])
+    w = mag.Tensor.const([-1])
+    b = mag.Tensor.const([0.5])
+    r = perceptron(x, w, b)
+    return r.scalar()
+
+truth_table = [
+    0.0,
+    1.0,
+    0.0,
+    1.0,
+]
+
+for bit in truth_table:
+    print(f'¬{bit} = {logical_not(bit)}')
diff --git a/python/examples/simple/requirements.txt b/python/examples/simple/requirements.txt
new file mode 100644
index 0000000..68ae195
--- /dev/null
+++ b/python/examples/simple/requirements.txt
@@ -0,0 +1,4 @@
+cffi
+matplotlib
+numpy
+opencv-python
diff --git a/python/examples/simple/xor.py b/python/examples/simple/xor.py
new file mode 100644
index 0000000..6b39793
--- /dev/null
+++ b/python/examples/simple/xor.py
@@ -0,0 +1,47 @@
+# (c) 2024 Mario 'Neo' Sieg.
+
+import magnetron as mag
+from magnetron.models import SequentialModel, DenseLayer
+import matplotlib.pyplot as plt
+
+mag.enable_log(True)
+
+EPOCHS: int = 10000
+LEARNING_RATE: float = 0.8
+
+# Inputs
+inputs = [
+    mag.Tensor.const([0.0, 0.0]),
+    mag.Tensor.const([0.0, 1.0]),
+    mag.Tensor.const([1.0, 0.0]),
+    mag.Tensor.const([1.0, 1.0])
+]
+
+# Targets
+targets = [
+    mag.Tensor.const([0.0]),
+    mag.Tensor.const([1.0]),
+    mag.Tensor.const([1.0]),
+    mag.Tensor.const([0.0])
+]
+
+mlp = SequentialModel([
+    DenseLayer(2, 4),
+    DenseLayer(4, 1)
+])
+
+# Train model
+losses = mlp.train(inputs, targets, EPOCHS, LEARNING_RATE)
+
+# Inference
+for input_tensor in inputs:
+    input_data = input_tensor.to_list()
+    output: float = mlp.forward(input_tensor).scalar()
+    print(f'{input_data[0]} ^ {input_data[1]} = {output}')
+
+# Plot MSE loss
+plt.plot(range(len(losses)), losses)
+plt.xlabel('Epochs')
+plt.ylabel('MSE Loss')
+plt.title('XOR Problem')
+plt.show()
diff --git a/python/examples/utils.py b/python/examples/utils.py
new file mode 100644
index 0000000..aa93444
--- /dev/null
+++ b/python/examples/utils.py
@@ -0,0 +1,56 @@
+# (c) 2024 Mario 'Neo' Sieg.
+# Implements utility functions for magnetron. Requires matplotlib and numpy.
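+# to_numpy(deinterleave=True) expects a channels-first (C, H, W) image tensor
+# and returns an (H, W, C) uint8 array suitable for plt.imshow; without
+# deinterleaving, the raw float data is simply reshaped to tensor.shape.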
+
+from magnetron.core import Tensor
+from matplotlib import pyplot as plt
+import numpy as np
+
+
+def to_numpy(tensor: Tensor, deinterleave: bool = False) -> np.ndarray:
+    if deinterleave:
+        w: int = tensor.width
+        h: int = tensor.height
+        c: int = tensor.channels
+        data = np.array(tensor.data_as_f32()).reshape(c, h, w)
+        data = np.transpose(data, (1, 2, 0))
+        return (data * 255.0).astype(np.uint8)
+    else:
+        return np.array(tensor.data_as_f32(), dtype=np.float32).reshape(tensor.shape)
+
+
+def plot_tensor_image(tensor: Tensor, title: str | None = None) -> None:
+    if title is not None:
+        plt.title(title)
+    plt.imshow(to_numpy(tensor, deinterleave=True))
+    plt.show()
+
+
+def plot_tensor_scatter(tensor: Tensor, title: str | None = None, deinterleave: bool = False) -> None:
+    data = to_numpy(tensor, deinterleave)
+    match tensor.rank:
+        case 2:
+            x, y = np.where(data > 0)
+            plt.scatter(x, y, marker='o')
+            if title is not None:
+                plt.title(title)
+            plt.show()
+        case 3:
+            x, y, z = np.where(data > 0)
+            fig = plt.figure()
+            if title is not None:
+                plt.title(title)
+            ax = fig.add_subplot(111, projection='3d')
+            ax.scatter(x, y, z, marker='o')
+            plt.show()
+        case 4:
+            x, y, z, w = np.where(data > 0)
+            fig = plt.figure()
+            if title is not None:
+                plt.title(title)
+            ax = fig.add_subplot(111, projection='3d')
+            scatter = ax.scatter(x, y, z, c=w, cmap='viridis', marker='o')
+            plt.colorbar(scatter, label='4th Dimension')
+            plt.show()
+        case _:
+            raise ValueError('Only 2D, 3D and 4D tensors are supported.')
diff --git a/python/magnetron_framework/bing_gen.py b/python/magnetron_framework/bing_gen.py
new file mode 100644
index 0000000..0ebd7f1
--- /dev/null
+++ b/python/magnetron_framework/bing_gen.py
@@ -0,0 +1,74 @@
+# (c) 2024 Mario "Neo" Sieg.
+
+import datetime
+import re
+
+C_HDR_FILE: str = '../../magnetron/magnetron.h'
+OUTPUT_FILE: str = 'magnetron/_ffi_cdecl_generated.py'
+
+print(f'Generating {OUTPUT_FILE} from {C_HDR_FILE}...')
+
+
+def comment_replacer(match):
+    s = match.group(0)
+    if s.startswith('/'):
+        return ' '
+    else:
+        return s
+
+
+macro_substitutions: dict[str, str] = {
+    'MAG_EXPORT': ' ',
+    'MAG_MAX_DIMS': str(6),  # SYNC with magnetron.h
+    'MAG_MAX_OP_PARAMS': str(6)  # SYNC with magnetron.h
+}
+
+enums_names: list[str] = []
+struct_names: list[str] = []
+
+
+def keep_line(line: str) -> bool:
+    if line == '' or line.startswith('#'):
+        return False
+    if line.startswith('extern') and not line.startswith('extern "C"'):
+        return True
+    if line.startswith('typedef enum'):
+        enums_names.append(line.split()[2])
+        return False
+    if line.startswith('typedef struct'):
+        struct_names.append(line.split()[2])
+        return False
+    if line.startswith('typedef'):
+        return True
+    return False
+
+
+c_input: list[str] = []
+with open(C_HDR_FILE, 'rt') as f:
+    full_src: str = f.read()
+    pattern = re.compile(
+        r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"',
+        re.DOTALL | re.MULTILINE
+    )
+    full_src = re.sub(pattern, comment_replacer, full_src)  # remove comments
+    for macro, replacement in macro_substitutions.items():
+        full_src = full_src.replace(macro, replacement)
+    c_input = [line.strip() for line in full_src.splitlines()]  # strip surrounding whitespace
+    c_input = [line for line in c_input if keep_line(line)]  # keep only relevant declarations
+
+out = f'# Autogenerated by {__file__} {datetime.datetime.now()}, do NOT edit!\n\n'
+out += "__MAG_CDECLS: str = '''\n\n"
+for struct in struct_names:
+    out += f'typedef struct {struct} {struct};\n'
+out += '\n'
+for enum in enums_names:
+    out += f'typedef int {enum};\n'
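+# The prelude above is all cffi needs: opaque struct typedefs plus enums
+# mapped to plain int allow ffi.cdef() to parse the extern declarations
+# below without requiring the full C definitions.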
+out += '\n'
+for line in c_input:
+    out += f'{line}\n'
+out += "'''\n\n"
+
+with open(OUTPUT_FILE, 'wt') as f:
+    f.write(out)
+
+print(f'Generated {OUTPUT_FILE} with {len(c_input)} lines.')
diff --git a/python/magnetron_framework/build_wheel.bat b/python/magnetron_framework/build_wheel.bat
new file mode 100644
index 0000000..27aca60
--- /dev/null
+++ b/python/magnetron_framework/build_wheel.bat
@@ -0,0 +1,11 @@
+@echo off
+if exist build (
+    rmdir /s /q build
+)
+if exist dist (
+    rmdir /s /q dist
+)
+if exist magnetron.egg-info (
+    rmdir /s /q magnetron.egg-info
+)
+pip3 wheel --verbose -w dist .
diff --git a/python/magnetron_framework/build_wheel.sh b/python/magnetron_framework/build_wheel.sh
new file mode 100644
index 0000000..5baeb27
--- /dev/null
+++ b/python/magnetron_framework/build_wheel.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+
+rm -rf ./build
+rm -rf ./dist
+rm -rf ./magnetron.egg-info
+pip3 wheel --verbose -w dist .
diff --git a/python/magnetron_framework/install_wheel_local.bat b/python/magnetron_framework/install_wheel_local.bat
new file mode 100644
index 0000000..e2d2d52
--- /dev/null
+++ b/python/magnetron_framework/install_wheel_local.bat
@@ -0,0 +1,5 @@
+@echo off
+call build_wheel.bat
+for %%f in (dist\*.whl) do (
+    pip3 install "%%f" --force-reinstall
+)
\ No newline at end of file
diff --git a/python/magnetron_framework/install_wheel_local.sh b/python/magnetron_framework/install_wheel_local.sh
new file mode 100644
index 0000000..b820571
--- /dev/null
+++ b/python/magnetron_framework/install_wheel_local.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+
+bash ./build_wheel.sh
+pip3 install ./dist/*.whl --force-reinstall
diff --git a/python/magnetron_framework/magnetron/__init__.py b/python/magnetron_framework/magnetron/__init__.py
new file mode 100644
index 0000000..65c77a7
--- /dev/null
+++ b/python/magnetron_framework/magnetron/__init__.py
@@ -0,0 +1,10 @@
+# (c) 2024 Mario "Neo" Sieg.
+
+__version__ = "0.1.0"
+__author__ = "Mario Sieg"
+__email__ = "mario.sieg.64@gmail.com"
+__author_email__ = "mario.sieg.64@gmail.com"
+__license__ = "Apache 2.0"
+__url__ = "https://github.com/MarioSieg/magnetron"
+
+from .core import *
diff --git a/python/magnetron_framework/magnetron/_ffi_cdecl_generated.py b/python/magnetron_framework/magnetron/_ffi_cdecl_generated.py
new file mode 100644
index 0000000..57fb4b9
--- /dev/null
+++ b/python/magnetron_framework/magnetron/_ffi_cdecl_generated.py
@@ -0,0 +1,169 @@
+# Autogenerated by /Users/mariosieg/Documents/projects/magnetron/python/magnetron_framework/bing_gen.py 2024-12-29 15:56:14.554105, do NOT edit!
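+# This module only carries the C declarations as a raw string; cffi parses it
+# in _lib_loader.load_native_module() via ffi.cdef() before dlopen()-ing the
+# magnetron shared library.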
+ +__MAG_CDECLS: str = ''' + +typedef struct mag_ctx_t mag_ctx_t; +typedef struct mag_tensor_t mag_tensor_t; +typedef struct mag_dtype_meta_t mag_dtype_meta_t; +typedef struct mag_op_param_t mag_op_param_t; +typedef struct mag_op_meta_t mag_op_meta_t; + +typedef int mag_compute_device_type_t; +typedef int mag_exec_mode_t; +typedef int mag_prng_algorithm_t; +typedef int mag_color_channels_t; +typedef int mag_dtype_t; +typedef int mag_op_t; +typedef int mag_op_param_type_t; +typedef int mag_graph_eval_order_t; + +extern const char* mag_device_type_get_name(mag_compute_device_type_t op); +extern void* (*mag_get_alloc_fn(void))(void* blk, size_t size); +extern void mag_set_alloc_fn(void* (*alloc)(void* blk, size_t size)); +extern void mag_set_log_mode(bool enabled); +typedef uint32_t mag_char32_t; +extern mag_ctx_t* mag_ctx_create(mag_compute_device_type_t device); +extern mag_exec_mode_t mag_ctx_get_exec_mode(const mag_ctx_t* ctx); +extern void mag_ctx_set_exec_mode(mag_ctx_t* ctx, mag_exec_mode_t mode); +extern mag_prng_algorithm_t mag_ctx_get_prng_algorithm(const mag_ctx_t* ctx); +extern void mag_ctx_set_prng_algorithm(mag_ctx_t* ctx, mag_prng_algorithm_t algorithm, uint64_t seed); +extern mag_compute_device_type_t mag_ctx_get_compute_device_type(const mag_ctx_t* ctx); +extern const char* mag_ctx_get_compute_device_name(const mag_ctx_t* ctx); +extern const char* mag_ctx_get_os_name(const mag_ctx_t* ctx); +extern const char* mag_ctx_get_cpu_name(const mag_ctx_t* ctx); +extern uint32_t mag_ctx_get_cpu_virtual_cores(const mag_ctx_t* ctx); +extern uint32_t mag_ctx_get_cpu_physical_cores(const mag_ctx_t* ctx); +extern uint32_t mag_ctx_get_cpu_sockets(const mag_ctx_t* ctx); +extern uint64_t mag_ctx_get_physical_memory_total(const mag_ctx_t* ctx); +extern uint64_t mag_ctx_get_physical_memory_free(const mag_ctx_t* ctx); +extern bool mag_ctx_is_numa_system(const mag_ctx_t* ctx); +extern size_t mag_ctx_get_total_tensors_created(const mag_ctx_t* ctx); +extern void mag_ctx_profile_start_recording(mag_ctx_t* ctx); +extern void mag_ctx_profile_stop_recording(mag_ctx_t* ctx, const char* export_csv_file); +extern void mag_ctx_destroy(mag_ctx_t* ctx); +extern const mag_dtype_meta_t* mag_dtype_meta_of(mag_dtype_t type); +extern const mag_op_meta_t* mag_op_meta_of(mag_op_t type); +extern uint32_t mag_pack_color_u8(uint8_t r, uint8_t g, uint8_t b); +extern uint32_t mag_pack_color_f32(float r, float g, float b); +extern mag_tensor_t* mag_tensor_create_1d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1); +extern mag_tensor_t* mag_tensor_create_2d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2); +extern mag_tensor_t* mag_tensor_create_3d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3); +extern mag_tensor_t* mag_tensor_create_4d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4); +extern mag_tensor_t* mag_tensor_create_5d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5); +extern mag_tensor_t* mag_tensor_create_6d(mag_ctx_t* ctx, mag_dtype_t type, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5, int64_t d6); +extern mag_tensor_t* mag_clone(mag_tensor_t* x); +extern mag_tensor_t* mag_view(mag_tensor_t* x); +extern mag_tensor_t* mag_transpose(mag_tensor_t* x); +extern mag_tensor_t* mag_permute(mag_tensor_t* x, uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3, uint32_t d4, uint32_t d5); +extern mag_tensor_t* mag_mean(mag_tensor_t* x); +extern mag_tensor_t* mag_min(mag_tensor_t* x); +extern 
mag_tensor_t* mag_max(mag_tensor_t* x); +extern mag_tensor_t* mag_sum(mag_tensor_t* x); +extern mag_tensor_t* mag_abs(mag_tensor_t* x); +extern mag_tensor_t* mag_abs_(mag_tensor_t* x); +extern mag_tensor_t* mag_neg(mag_tensor_t* x); +extern mag_tensor_t* mag_neg_(mag_tensor_t* x); +extern mag_tensor_t* mag_log(mag_tensor_t* x); +extern mag_tensor_t* mag_log_(mag_tensor_t* x); +extern mag_tensor_t* mag_sqr(mag_tensor_t* x); +extern mag_tensor_t* mag_sqr_(mag_tensor_t* x); +extern mag_tensor_t* mag_sqrt(mag_tensor_t* x); +extern mag_tensor_t* mag_sqrt_(mag_tensor_t* x); +extern mag_tensor_t* mag_sin(mag_tensor_t* x); +extern mag_tensor_t* mag_sin_(mag_tensor_t* x); +extern mag_tensor_t* mag_cos(mag_tensor_t* x); +extern mag_tensor_t* mag_cos_(mag_tensor_t* x); +extern mag_tensor_t* mag_step(mag_tensor_t* x); +extern mag_tensor_t* mag_step_(mag_tensor_t* x); +extern mag_tensor_t* mag_softmax(mag_tensor_t* x); +extern mag_tensor_t* mag_softmax_(mag_tensor_t* x); +extern mag_tensor_t* mag_softmax_dv(mag_tensor_t* x); +extern mag_tensor_t* mag_softmax_dv_(mag_tensor_t* x); +extern mag_tensor_t* mag_sigmoid(mag_tensor_t* x); +extern mag_tensor_t* mag_sigmoid_(mag_tensor_t* x); +extern mag_tensor_t* mag_sigmoid_dv(mag_tensor_t* x); +extern mag_tensor_t* mag_sigmoid_dv_(mag_tensor_t* x); +extern mag_tensor_t* mag_hard_sigmoid(mag_tensor_t* x); +extern mag_tensor_t* mag_hard_sigmoid_(mag_tensor_t* x); +extern mag_tensor_t* mag_silu(mag_tensor_t* x); +extern mag_tensor_t* mag_silu_(mag_tensor_t* x); +extern mag_tensor_t* mag_silu_dv(mag_tensor_t* x); +extern mag_tensor_t* mag_silu_dv_(mag_tensor_t* x); +extern mag_tensor_t* mag_tanh(mag_tensor_t* x); +extern mag_tensor_t* mag_tanh_(mag_tensor_t* x); +extern mag_tensor_t* mag_tanh_dv(mag_tensor_t* x); +extern mag_tensor_t* mag_tanh_dv_(mag_tensor_t* x); +extern mag_tensor_t* mag_relu(mag_tensor_t* x); +extern mag_tensor_t* mag_relu_(mag_tensor_t* x); +extern mag_tensor_t* mag_relu_dv(mag_tensor_t* x); +extern mag_tensor_t* mag_relu_dv_(mag_tensor_t* x); +extern mag_tensor_t* mag_gelu(mag_tensor_t* x); +extern mag_tensor_t* mag_gelu_(mag_tensor_t* x); +extern mag_tensor_t* mag_gelu_dv(mag_tensor_t* x); +extern mag_tensor_t* mag_gelu_dv_(mag_tensor_t* x); +extern mag_tensor_t* mag_add(mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_add_(mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_sub( mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_sub_(mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_mul( mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_mul_(mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_div(mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_div_(mag_tensor_t* x, mag_tensor_t* y); +extern mag_tensor_t* mag_adds(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_adds_(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_subs( mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_subs_(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_muls(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_muls_(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_divs(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_divs_(mag_tensor_t* x, float xi); +extern mag_tensor_t* mag_matmul(mag_tensor_t* a, mag_tensor_t* b); +extern void mag_tensor_incref(mag_tensor_t* t); +extern bool mag_tensor_decref(mag_tensor_t* t); +extern void mag_tensor_copy_buffer_from(mag_tensor_t* t, const void* data, size_t size); +extern void mag_tensor_fill(mag_tensor_t* t, float x); +extern 
void mag_tensor_fill_random_uniform(mag_tensor_t* t, float min, float max); +extern void mag_tensor_fill_random_normal(mag_tensor_t* t, float mean, float stddev); +extern uint64_t mag_tensor_get_packed_refcounts(const mag_tensor_t* t); +extern void mag_tensor_retain(mag_tensor_t* t); +extern size_t mag_tensor_get_memory_usage(const mag_tensor_t* t); +extern void mag_tensor_print(const mag_tensor_t* t, bool with_header, bool with_data); +extern void mag_tensor_set_name(mag_tensor_t* t, const char* name); +extern void mag_tensor_fmt_name(mag_tensor_t* t, const char* fmt, ...); +extern const char* mag_tensor_get_name(const mag_tensor_t* t); +extern int64_t mag_tensor_rank(const mag_tensor_t* t); +extern const int64_t* mag_tensor_shape(const mag_tensor_t* t); +extern const int64_t* mag_tensor_strides(const mag_tensor_t* t); +extern mag_dtype_t mag_tensor_dtype(const mag_tensor_t* t); +extern void* mag_tensor_data_ptr(const mag_tensor_t* t); +extern int64_t mag_tensor_data_size(const mag_tensor_t* t); +extern int64_t mag_tensor_numel(const mag_tensor_t* t); +extern int64_t mag_tensor_num_rows(const mag_tensor_t* t); +extern int64_t mag_tensor_num_cols(const mag_tensor_t* t); +extern bool mag_tensor_is_scalar(const mag_tensor_t* t); +extern bool mag_tensor_is_vector(const mag_tensor_t* t); +extern bool mag_tensor_is_matrix(const mag_tensor_t* t); +extern bool mag_tensor_is_volume(const mag_tensor_t* t); +extern bool mag_tensor_is_shape_eq(const mag_tensor_t* a, const mag_tensor_t* b); +extern bool mag_tensor_are_strides_eq(const mag_tensor_t* a, const mag_tensor_t* b); +extern bool mag_tensor_can_broadcast(const mag_tensor_t* a, const mag_tensor_t* b); +extern bool mag_tensor_is_transposed(const mag_tensor_t* t); +extern bool mag_tensor_is_permuted(const mag_tensor_t* t); +extern bool mag_tensor_is_contiguous(const mag_tensor_t* t); +extern float mag_tensor_get_scalar_physical_index(mag_tensor_t* t, int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5); +extern void mag_tensor_set_scalar_physical_index(mag_tensor_t* t, int64_t d0, int64_t d1, int64_t d2, int64_t d3, int64_t d4, int64_t d5, float x); +extern float mag_tensor_get_scalar_virtual_index(mag_tensor_t* t, int64_t v_idx); +extern void mag_tensor_set_scalar_virtual_index(mag_tensor_t* t, int64_t v_idx, float x); +extern bool mag_tensor_eq(const mag_tensor_t* a, const mag_tensor_t* b); +extern bool mag_tensor_is_close(const mag_tensor_t* a, const mag_tensor_t* b, float eps, double* percent_eq); +extern void mag_tensor_img_draw_box(mag_tensor_t* t, int32_t x1, int32_t y1, int32_t x2, int32_t y2, int32_t wi, uint32_t rgb); +extern void mag_tensor_img_draw_text(mag_tensor_t* t, int32_t x, int32_t y, int32_t size, uint32_t rgb, const char* txt); +extern mag_ctx_t* mag_tensor_get_ctx(const mag_tensor_t* t); +extern void* mag_tensor_get_user_data(const mag_tensor_t* t); +extern void mag_tensor_set_user_data(mag_tensor_t* t, void* ud); +extern void mag_tensor_save(const mag_tensor_t* t, const char* file); +extern mag_tensor_t* mag_tensor_load(mag_ctx_t* ctx, const char* file); +extern mag_tensor_t* mag_tensor_load_image(mag_ctx_t* ctx, const char* file, mag_color_channels_t channels, uint32_t resize_w, uint32_t resize_h); +extern void mag_tensor_save_image(const mag_tensor_t* t, const char* file); +''' + diff --git a/python/magnetron_framework/magnetron/_lib_loader.py b/python/magnetron_framework/magnetron/_lib_loader.py new file mode 100644 index 0000000..a4f5102 --- /dev/null +++ 
b/python/magnetron_framework/magnetron/_lib_loader.py @@ -0,0 +1,26 @@ +from pathlib import Path +from magnetron._ffi_cdecl_generated import __MAG_CDECLS +import sys + +MAG_LIBS = [ + ('win32', 'magnetron.dll'), + ('linux', 'libmagnetron.so'), + ('darwin', 'libmagnetron.dylib'), +] + +def load_native_module(): + platform = sys.platform + lib_name = next((lib for os, lib in MAG_LIBS if platform.startswith(os)), None) + assert lib_name, f"Unsupported platform: {platform}" + + # Locate the library in the package directory + pkg_path = Path(__file__).parent + lib_path = pkg_path / lib_name + assert lib_path.exists(), f"magnetron shared library not found: {lib_path}" + + # Load the library using cffi + from cffi import FFI + ffi = FFI() + ffi.cdef(__MAG_CDECLS) # Define the C declarations + lib = ffi.dlopen(str(lib_path)) # Load the shared library + return ffi, lib \ No newline at end of file diff --git a/python/magnetron_framework/magnetron/core.py b/python/magnetron_framework/magnetron/core.py new file mode 100644 index 0000000..dda022b --- /dev/null +++ b/python/magnetron_framework/magnetron/core.py @@ -0,0 +1,1537 @@ +# (c) 2024 Mario "Neo" Sieg. + +# Implements core functionality: Context, Tensors and Operations. + +# To debug Python to C FFI calls: +# $ cp examples/perceptron.py tmp.py && gdb -ex r --args python3 tmp.py +# See also https://wiki.python.org/moin/DebuggingWithGdb + +import faulthandler +import weakref +from dataclasses import dataclass +from os import getenv +from os.path import isfile +from magnetron._lib_loader import load_native_module +from enum import Enum, auto + +# Enable faulthandler for debugging +faulthandler.enable() + +ffi, C = load_native_module() # Load the native magnetron shared library + +# Common constants +MAX_DIMS: int = 6 +MAX_ARG_TENSORS: int = 2 +MAG_MAX_OP_PARAMS: int = 6 +DIM_MAX: int = ((1 << 64) - 1) >> 1 + +def enable_log(enable: bool) -> None: + """ + Set logging mode for the magnetron backend. + + Parameters + ---------- + enable : bool + If True, enables logging of operations and internal states. + """ + C.mag_set_log_mode(enable) + +def pack_color(r: int, g: int, b: int) -> int: + """ + Packs three 8-bit color components (R, G, B) into a single 24-bit integer. + + Parameters + ---------- + r : int + Red component [0-255]. + g : int + Green component [0-255]. + b : int + Blue component [0-255]. + + Returns + ------- + int + A single integer representing the packed RGB color. + """ + return C.mag_pack_color_u8(r, g, b) + +class ComputeDevice(Enum): + """ + Compute devices available for parallel computations. + """ + CPU = 0 + CUDA = auto() + +class PRNGAlgorithm(Enum): + """ + Pseudorandom number generator algorithms. + """ + MERSENNE_TWISTER = 0 # Default - Mersenne Twister + PCG = auto() # Permuted Congruential Generator + +class DType(Enum): + """ + Supported data types for tensors. + """ + F32 = 0 + +class ColorChannels(Enum): + """ + Desired color channels when loading images. 
+ """ + AUTO = 0 # Automatically determine the number of color channels + GRAY = auto() # Grayscale F32 + GRAY_A = auto() # Grayscale F32 + Alpha + RGB = auto() # R32G32B32 + RGBA = auto() # R32G32B32A32 + + @property + def name(self) -> str: + """Returns the operator name as a string.""" + assert self.value < self._COUNT.value + return ffi.string(C.mag_op_get_name(self.value)).decode('utf-8') + + @property + def mnemonic(self) -> str: + """Returns a short mnemonic name for the operator.""" + assert self.value < self._COUNT.value + return ffi.string(C.mag_op_get_mnemonic(self.value)).decode('utf-8') + + @property + def argument_count(self) -> int: + """Returns the number of tensor arguments required by this operator.""" + assert self.value < self._COUNT.value + return C.mag_op_get_argcount(self.value) + + @property + def supports_inplace(self) -> bool: + """Checks if the operator supports in-place modifications.""" + return C.mag_op_supports_inplace(self.value) + + @property + def is_unary(self) -> bool: + """Checks if the operator is a unary operator.""" + return self.argument_count == 1 + + @property + def is_binary(self) -> bool: + """Checks if the operator is a binary operator.""" + return self.argument_count == 2 + +class OpParam: + """ + Represents an operation parameter to be passed to a tensor operator. + """ + + def __init__(self, value: int) -> None: + """ + Internal constructor. + + Parameters + ---------- + value : int + Internal integer representation of the parameter. + """ + self.value = value + + @staticmethod + def new_int(x: int) -> 'OpParam': + """ + Creates an integer operation parameter. + + Parameters + ---------- + x : int + The integer value to store. + + Returns + ------- + OpParam + A new integer operation parameter. + """ + return OpParam(C.mag_op_param_int(x)) + + @staticmethod + def new_float(x: float) -> 'OpParam': + """ + Creates a float operation parameter. + + Parameters + ---------- + x : float + The float value to store. + + Returns + ------- + OpParam + A new float operation parameter. + """ + return OpParam(C.mag_op_param_float(x)) + + @property + def is_int(self) -> bool: + """Checks if the parameter is an integer type.""" + return C.mag_op_param_is_int(self.value) + + @property + def is_float(self) -> bool: + """Checks if the parameter is a float type.""" + return C.mag_op_param_is_float(self.value) + + @property + def unpack_int(self) -> int: + """ + Returns the stored integer value. + + Returns + ------- + int + The integer value of the parameter. + """ + assert self.is_int + return C.mag_op_param_unpack_int(self.value) + + @property + def unpack_float(self) -> float: + """ + Returns the stored float value. + + Returns + ------- + float + The float value of the parameter. + """ + assert self.is_float + return C.mag_op_param_unpack_float(self.value) + +class GraphEvalOrder(Enum): + """ + Order in which the computation graph should be evaluated. + Applies to deferred execution mode only. + """ + FORWARD = 0 # Evaluate graph left-to-right + REVERSE = 1 # Evaluate graph right-to-left + +class ExecutionMode(Enum): + """ + Execution modes for the magnetron context. 
+ """ + EAGER = 0 # Execute operations immediately (dynamic graph) + DEFERRED = 1 # Build computation graph and execute later (static graph) + +@dataclass +class GlobalConfig: + verbose: bool = (getenv('MAG_VERBOSE', '0') == '1') + compute_device: ComputeDevice = ComputeDevice.CUDA if getenv('MAG_COMPUTE_DEVICE') == 'CUDA' else ComputeDevice.CPU + +class Context: + """ + Manages the magnetron context and tensor lifecycles, including device selection, + memory allocation, and execution mode. + + A global active context is created automatically when needed. + """ + + _active: 'Context' = None # Global context + + @staticmethod + def active() -> 'Context': + """ + Returns the active global context, creating one if it does not exist. + + Returns + ------- + Context + The currently active context. + """ + if Context._active is None: + enable_log(GlobalConfig.verbose) + Context._active = Context(GlobalConfig.compute_device) + return Context._active + + def __init__(self, device: ComputeDevice, *, execution_mode: ExecutionMode = ExecutionMode.EAGER): + """ + Initializes a new magnetron context. + + Parameters + ---------- + device : ComputeDevice + The compute device (CPU or CUDA). + execution_mode : ExecutionMode, optional + The execution mode (eager or deferred), by default EAGER. + """ + self.ctx = C.mag_ctx_create(device.value) + self.execution_mode = execution_mode + + @property + def compute_device(self) -> ComputeDevice: + """ + Returns the compute device used by this context. + + Returns + ------- + ComputeDevice + The compute device (CPU or CUDA). + """ + return ComputeDevice(C.mag_ctx_get_compute_device_type(self.ctx)) + + @property + def compute_device_name(self) -> str: + """ + Returns the name of the active compute device. + + Returns + ------- + str + Name of the device, e.g. "CPU" or "NVIDIA GPU". + """ + return ffi.string(C.mag_ctx_get_compute_device_name(self.ctx)).decode('utf-8') + + @property + def execution_mode(self) -> ExecutionMode: + """ + Returns the execution mode of the context. + + Returns + ------- + ExecutionMode + The current execution mode (EAGER or DEFERRED). + """ + return ExecutionMode(C.mag_ctx_get_exec_mode(self.ctx)) + + @execution_mode.setter + def execution_mode(self, mode: ExecutionMode): + """ + Sets the execution mode of the context. + + Parameters + ---------- + mode : ExecutionMode + Desired mode (EAGER or DEFERRED). + """ + C.mag_ctx_set_exec_mode(self.ctx, mode.value) + + @property + def prng_algorithm(self) -> PRNGAlgorithm: + """ + Returns the PRNG algorithm currently used. + + Returns + ------- + PRNGAlgorithm + The PRNG algorithm in use. + """ + return PRNGAlgorithm(C.mag_ctx_get_prng_algorithm(self.ctx)) + + @prng_algorithm.setter + def prng_algorithm(self, algorithm: PRNGAlgorithm): + """ + Sets the PRNG algorithm and re-seeds it. + + Parameters + ---------- + algorithm : PRNGAlgorithm + The desired PRNG algorithm. + """ + C.mag_ctx_set_prng_algorithm(self.ctx, algorithm.value, 0) + + @property + def os_name(self) -> str: + """ + Returns the operating system name. + + Returns + ------- + str + The OS name (e.g., "Linux"). + """ + return ffi.string(C.mag_ctx_get_os_name(self.ctx)).decode('utf-8') + + @property + def cpu_name(self) -> str: + """ + Returns the CPU name/brand string. + + Returns + ------- + str + The CPU name. + """ + return ffi.string(C.mag_ctx_get_cpu_name(self.ctx)).decode('utf-8') + + @property + def cpu_virtual_cores(self) -> int: + """ + Returns the number of virtual CPU cores (logical processors). 
+ + Returns + ------- + int + Count of virtual CPU cores. + """ + return C.mag_ctx_get_cpu_virtual_cores(self.ctx) + + @property + def cpu_physical_cores(self) -> int: + """ + Returns the number of physical CPU cores. + + Returns + ------- + int + Count of physical CPU cores. + """ + return C.mag_ctx_get_cpu_physical_cores(self.ctx) + + @property + def cpu_sockets(self) -> int: + """ + Returns the number of CPU sockets present on the system. + + Returns + ------- + int + CPU socket count. + """ + return C.mag_ctx_get_cpu_sockets(self.ctx) + + @property + def physical_memory_total(self) -> int: + """ + Returns the total physical memory (RAM) in bytes. + + Returns + ------- + int + Total system memory in bytes. + """ + return C.mag_ctx_get_physical_memory_total(self.ctx) + + @property + def physical_memory_free(self) -> int: + """ + Returns the amount of free physical memory (RAM) in bytes. + + Returns + ------- + int + Free memory in bytes. + """ + return C.mag_ctx_get_physical_memory_free(self.ctx) + + @property + def physical_memory_used(self) -> int: + """ + Returns the amount of used physical memory (RAM) in bytes. + + Returns + ------- + int + Used memory in bytes. + """ + return abs(self.physical_memory_total - self.physical_memory_free) + + @property + def is_numa_system(self) -> bool: + """ + Checks if the system uses Non-Uniform Memory Access (NUMA). + + Returns + ------- + bool + True if NUMA is supported, False otherwise. + """ + return C.mag_ctx_is_numa_system(self.ctx) + + @property + def total_allocated_pool_memory(self) -> int: + """ + Returns the total allocated memory (pool) by this context. + + Returns + ------- + int + Total allocated memory in bytes. + """ + return C.mag_ctx_total_allocated_pool_memory(self.ctx) + + @property + def total_tensors_created(self) -> int: + """ + Returns the total number of tensors created (including views). + + Returns + ------- + int + Count of all tensors created. + """ + return C.mag_ctx_get_total_tensors_created(self.ctx) + + @property + def total_tensors_allocated(self) -> int: + """ + Returns the total number of allocated tensors (not including views). + + Returns + ------- + int + Count of allocated tensors. + """ + return C.mag_ctx_get_total_tensors_allocated(self.ctx) + + def start_profiler(self) -> None: + """ + Starts recording profiling information for operations. + Profiling must be stopped to produce a report. Slightly reduces performance. + """ + C.mag_ctx_profile_start_recording(self.ctx) + + def stop_profiler(self, export_csv_file: str | None = None) -> None: + """ + Stops recording profiling information and generates a report. + + Parameters + ---------- + export_csv_file : str, optional + Path to export a CSV profiling report. If None, only an internal report is generated. + """ + csv_file = ffi.NULL if export_csv_file is None else bytes(export_csv_file, 'utf-8') + C.mag_ctx_profile_stop_recording(self.ctx, csv_file) + + def __del__(self): + """ + Destructor that releases context resources. + """ + C.mag_ctx_destroy(self.ctx) + self.ctx = ffi.NULL + +class Tensor: + """ + Represents a tensor in the magnetron library. Supports various operations and transformations. + """ + + def __init__(self, internal_instance: ffi.CData | None = None) -> None: + """ + Internal constructor. Do not call directly. Use static methods or operators. + + Parameters + ---------- + internal_instance : ffi.CData or None + Internal tensor instance (C data pointer). 
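+
+        Notes
+        -----
+        Prefer the factory methods (``Tensor.empty``, ``Tensor.full``,
+        ``Tensor.const``, ``Tensor.zeros``, ``Tensor.uniform``, ``Tensor.normal``)
+        over calling this constructor directly; they allocate through the active
+        Context, while this constructor merely wraps an existing C handle.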
+ """ + self.context_ref = None + self.tensor = internal_instance + + def __del__(self) -> None: + """Releases tensor resources upon object destruction.""" + C.mag_tensor_decref(self.tensor) + self.tensor = ffi.NULL + + _DISPATCH = { + 1: C.mag_tensor_create_1d, + 2: C.mag_tensor_create_2d, + 3: C.mag_tensor_create_3d, + 4: C.mag_tensor_create_4d, + 5: C.mag_tensor_create_5d, + 6: C.mag_tensor_create_6d + } + assert len(_DISPATCH) == MAX_DIMS + + def _new(self, ctx: Context, *, shape: tuple[int, ...], dtype: DType = DType.F32, + name: str | None = None) -> None: + """ + Internal helper to create a new tensor instance. + + Parameters + ---------- + ctx : Context + The magnetron context where this tensor belongs. + shape : tuple[int, ...] + Dimensions of the tensor. + dtype : DType, optional + Data type of the tensor, by default DType.F32. + name : str or None, optional + A friendly name for the tensor, by default None. + """ + assert 0 < len(shape) <= MAX_DIMS, f'Invalid number of dimensions: {len(shape)}' + assert all(0 < dim <= DIM_MAX for dim in shape), 'Invalid dimension size' + self.context_ref = weakref.ref(ctx) + self.tensor = self._DISPATCH[len(shape)](ctx.ctx, dtype.value, *shape) + self.name = f'Tensor {self.shape}' if name is None else name + + @classmethod + def empty(cls, shape: tuple[int, ...], *, dtype: DType = DType.F32, name: str | None = None) -> 'Tensor': + """ + Creates an empty tensor with uninitialized data. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the tensor. + dtype : DType, optional + Data type of the tensor, by default F32. + name : str or None, optional + A friendly name for the tensor, by default None. + + Returns + ------- + Tensor + The newly created empty tensor. + """ + tensor = cls(None) + tensor._new(Context.active(), shape=shape, dtype=dtype, name=name) + return tensor + + @classmethod + def full(cls, shape: tuple[int, ...], *, fill_value: float, dtype: DType = DType.F32, + name: str | None = None) -> 'Tensor': + """ + Creates a tensor filled with a given constant value. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the tensor. + fill_value : float + The constant value to fill. + dtype : DType, optional + Data type, by default F32. + name : str or None, optional + A friendly name for the tensor, by default None. + + Returns + ------- + Tensor + The filled tensor. + """ + tensor = cls(None) + tensor._new(Context.active(), shape=shape, dtype=dtype, name=name) + C.mag_tensor_fill(tensor.tensor, fill_value) + return tensor + + @classmethod + def const(cls, data, *, dtype: DType = DType.F32, + name: str | None = None) -> 'Tensor': + """ + Creates a tensor from a nested Python list or a single scalar. + + Parameters + ---------- + data : scalar or nested list + Data to create the tensor from. + dtype : DType, optional + Data type, by default F32. + name : str or None, optional + A friendly name for the tensor, by default None. + + Returns + ------- + Tensor + The constructed tensor. 
+ """ + def flatten_nested_lists(nested): + if not isinstance(nested, list): + return (), [nested] + elif len(nested) == 0: + return (0,), [] + else: + shapes = [] + flattened = [] + for item in nested: + shape_lst, flat = flatten_nested_lists(item) + shapes.append(shape_lst) + flattened.extend(flat) + first_shape = shapes[0] + for s in shapes: + assert s == first_shape, 'All sub-lists must have the same shape' + return (len(nested),) + first_shape, flattened + + shape, flattened_data = flatten_nested_lists(data) + tensor = cls(None) + tensor._new(Context.active(), shape=tuple(shape), dtype=dtype, name=name) + size: int = len(flattened_data) * ffi.sizeof('float') + C.mag_tensor_copy_buffer_from(tensor.tensor, ffi.new(f'float[{len(flattened_data)}]', flattened_data), size) + return tensor + + @classmethod + def zeros(cls, shape: tuple[int, ...], *, dtype: DType = DType.F32, + name: str | None = None) -> 'Tensor': + """ + Creates a tensor filled with zeros. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the tensor. + dtype : DType, optional + Data type, by default F32. + name : str or None, optional + A friendly name for the tensor, by default None. + + Returns + ------- + Tensor + The zero-filled tensor. + """ + return cls.full(shape, fill_value=0.0, dtype=dtype, name=name) + + @classmethod + def uniform(cls, shape: tuple[int, ...], *, interval: (float, float) = (-1.0, 1.0), dtype: DType = DType.F32, + name: str | None = None) -> 'Tensor': + """ + Creates a tensor filled with random uniform values within a given interval. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the tensor. + interval : (float, float), optional + Min and max values for uniform distribution, by default (-1.0, 1.0). + dtype : DType, optional + Data type, by default F32. + name : str or None, optional + A friendly name for the tensor, by default None. + + Returns + ------- + Tensor + The tensor filled with random values. + """ + tensor = cls(None) + tensor._new(Context.active(), shape=shape, dtype=dtype, name=name) + if interval[1] < interval[0]: + interval = (interval[1], interval[0]) + C.mag_tensor_fill_random_uniform(tensor.tensor, interval[0], interval[1]) + return tensor + + @classmethod + def normal(cls, shape: tuple[int, ...], *, mean: float, stddev: float) -> 'Tensor': + """ + Creates a tensor filled with random values from a normal distribution. + + Parameters + ---------- + shape : tuple[int, ...] + The shape of the tensor. + mean : float + Mean of the normal distribution. + stddev : float + Standard deviation of the normal distribution. + + Returns + ------- + Tensor + The tensor filled with normally distributed values. + """ + tensor = cls(None) + tensor._new(Context.active(), shape=shape, dtype=DType.F32) + C.mag_tensor_fill_random_normal(tensor.tensor, mean, stddev) + return tensor + + @classmethod + def load(cls, file_path: str) -> 'Tensor': + """ + Loads a tensor from a binary magnetron file. + + Parameters + ---------- + file_path : str + Path to the .magnetron file. + + Returns + ------- + Tensor + The loaded tensor. + """ + assert file_path.endswith('.magnetron'), 'File must be a magnetron file' + instance = C.mag_tensor_load(Context.active().ctx, bytes(file_path, 'utf-8')) + return cls(internal_instance=instance) + + @classmethod + def load_image(cls, file_path: str, *, + name: str | None = None, + channels=ColorChannels.AUTO, + resize_to: (int, int) = (0, 0)) -> 'Tensor': + """ + Loads an image from a file and creates a tensor. 
+ + Parameters + ---------- + file_path : str + Path to the image file. + name : str or None, optional + A friendly name for the tensor, by default None. + channels : ColorChannels, optional + Desired color channels to load, by default AUTO. + resize_to : (int, int), optional + If not (0,0), resize the image to (width, height), by default (0,0). + + Returns + ------- + Tensor + The created image tensor. + """ + assert isfile(file_path), f'File not found: {file_path}' + instance = C.mag_tensor_load_image(Context.active().ctx, bytes(file_path, 'utf-8'), channels.value, + resize_to[0], resize_to[1]) + tensor = cls(instance) + if name is not None: + tensor.name = name + return tensor + + def print(self, print_header: bool = False, print_data: bool = True) -> None: + """ + Prints the tensor metadata and optionally its data. + + Parameters + ---------- + print_header : bool, optional + If True, prints a header line, by default False. + print_data : bool, optional + If True, prints the tensor data, by default True. + """ + C.mag_tensor_print(self.tensor, print_header, print_data) + + @property + def name(self) -> str: + """ + Returns the name of the tensor. + + Returns + ------- + str + The tensor's name. + """ + return ffi.string(C.mag_tensor_get_name(self.tensor)).decode('utf-8') + + @name.setter + def name(self, name: str) -> None: + """ + Sets a name for the tensor. + + Parameters + ---------- + name : str + The new name for the tensor. + """ + C.mag_tensor_set_name(self.tensor, bytes(name, 'utf-8')) + + @property + def rank(self) -> int: + """ + Returns the rank (number of dimensions) of the tensor. + + Returns + ------- + int + Number of dimensions. + """ + return C.mag_tensor_rank(self.tensor) + + @property + def shape(self) -> tuple[int, ...]: + """ + Returns the shape (dimensions) of the tensor. + + Returns + ------- + tuple[int, ...] + The dimensions of the tensor. + """ + return tuple(ffi.unpack(C.mag_tensor_shape(self.tensor), self.rank)) + + @property + def strides(self) -> tuple[int, ...]: + """ + Returns the strides of the tensor. + + Returns + ------- + tuple[int, ...] + The strides for each dimension. + """ + return tuple(ffi.unpack(C.mag_tensor_strides(self.tensor), self.rank)) + + @property + def dtype(self) -> DType: + """ + Returns the data type of the tensor. + + Returns + ------- + DType + The data type, e.g. DType.F32. + """ + return DType(C.mag_tensor_dtype(self.tensor)) + + @property + def data_ptr(self) -> int: + """ + Returns the pointer to the tensor's data buffer. + + Returns + ------- + int + Memory address of the tensor data. + """ + return int(ffi.cast('uintptr_t', C.mag_tensor_data_ptr(self.tensor))) + + def to_list(self) -> list[float]: + """ + Returns the tensor data as a Python list of floats. + + Returns + ------- + list[float] + A flat list containing all tensor elements. + """ + assert self.dtype == DType.F32, 'Invalid data type' + return ffi.unpack(ffi.cast('float*', C.mag_tensor_data_ptr(self.tensor)), self.numel) + + def scalar(self) -> float: + """ + Returns the scalar value of a 0D tensor. + + Returns + ------- + float + The single scalar value. + """ + assert self.dtype == DType.F32, 'Invalid data type' + return ffi.unpack(ffi.cast('float*', C.mag_tensor_data_ptr(self.tensor)), 1)[0] + + @property + def data_size(self) -> int: + """ + Returns the size of the tensor buffer in bytes. + + Returns + ------- + int + Data size in bytes. 
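+
+        For the default F32 dtype this is ``numel * 4`` bytes (illustrative):
+
+            t = Tensor.zeros((2, 3))
+            assert t.data_size == t.numel * 4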
+ """ + return C.mag_tensor_data_size(self.tensor) + + @property + def numel(self) -> int: + """ + Returns the number of elements in the tensor. + + Returns + ------- + int + Number of elements. + """ + return C.mag_tensor_numel(self.tensor) + + @property + def num_rows(self) -> int: + """ + Returns the number of rows in a 2D tensor (matrix). + + Returns + ------- + int + Number of rows. + """ + return C.mag_tensor_num_rows(self.tensor) + + @property + def num_cols(self) -> int: + """ + Returns the number of columns in a 2D tensor (matrix). + + Returns + ------- + int + Number of columns. + """ + return C.mag_tensor_num_cols(self.tensor) + + @property + def is_scalar(self) -> bool: + """ + Checks if the tensor is a scalar (0D). + + Returns + ------- + bool + True if scalar, otherwise False. + """ + return C.mag_tensor_is_scalar(self.tensor) + + @property + def is_vector(self) -> bool: + """ + Checks if the tensor is a vector (1D). + + Returns + ------- + bool + True if vector, otherwise False. + """ + return C.mag_tensor_is_vector(self.tensor) + + @property + def is_matrix(self) -> bool: + """ + Checks if the tensor is a matrix (2D). + + Returns + ------- + bool + True if matrix, otherwise False. + """ + return C.mag_tensor_is_matrix(self.tensor) + + @property + def is_volume(self) -> bool: + """ + Checks if the tensor is a volume (3D or higher). + + Returns + ------- + bool + True if volume, otherwise False. + """ + return C.mag_tensor_is_volume(self.tensor) + + @property + def is_transposed(self) -> bool: + """ + Checks if the tensor is transposed. + + Returns + ------- + bool + True if transposed, otherwise False. + """ + return C.mag_tensor_is_transposed(self.tensor) + + @property + def is_permuted(self) -> bool: + """ + Checks if the tensor is permuted. + + Returns + ------- + bool + True if permuted, otherwise False. + """ + return C.mag_tensor_is_permuted(self.tensor) + + def is_shape_eq(self, other: 'Tensor') -> bool: + """ + Checks if two tensors have the same shape. + + Parameters + ---------- + other : Tensor + Another tensor to compare shape. + + Returns + ------- + bool + True if shapes match, otherwise False. + """ + return C.mag_tensor_is_shape_eq(self.tensor, other.tensor) + + def are_strides_eq(self, other: 'Tensor') -> bool: + """ + Checks if two tensors have identical strides. + + Parameters + ---------- + other : Tensor + Another tensor to compare strides. + + Returns + ------- + bool + True if strides match, otherwise False. + """ + return C.mag_tensor_are_strides_eq(self.tensor, other.tensor) + + def can_broadcast(self, other: 'Tensor') -> bool: + """ + Checks if `other` tensor can be broadcasted to the shape of this tensor. + + Parameters + ---------- + other : Tensor + Another tensor. + + Returns + ------- + bool + True if broadcastable, otherwise False. + """ + return C.mag_tensor_can_broadcast(self.tensor, other.tensor) + + @property + def width(self) -> int: + """ + Returns the width of an image tensor (assumes layout: CxHxW). + + Returns + ------- + int + Width dimension. + """ + return self.shape[2] + + @property + def height(self) -> int: + """ + Returns the height of an image tensor (assumes layout: CxHxW). + + Returns + ------- + int + Height dimension. + """ + return self.shape[1] + + @property + def channels(self) -> int: + """ + Returns the number of channels in an image tensor (assumes layout: CxHxW). + + Returns + ------- + int + Number of channels. 
+ """ + return self.shape[0] + + @property + def is_contiguous(self) -> bool: + """ + Checks if the tensor is contiguous in memory. + + Returns + ------- + bool + True if contiguous, otherwise False. + """ + return C.mag_tensor_is_contiguous(self.tensor) + + def is_close(self, other: 'Tensor', eps: float = -1.0, print_eq_percent: bool = False) -> (bool, float): + """ + Checks if the tensor is close to another tensor within a given epsilon. + + Parameters + ---------- + other : Tensor + Another tensor to compare to. + eps : float, optional + Epsilon tolerance. If -1.0, uses a default internal value. + print_eq_percent : bool, optional + If True, prints the percentage of equal elements. + + Returns + ------- + (bool, float) + A tuple (close_status, eq_percent). + close_status is True if the tensors are close. + eq_percent is the percentage of approximately equal elements. + """ + percent_eq = ffi.new('double[1]') + is_eq = C.mag_tensor_is_close(self.tensor, other.tensor, eps, percent_eq) + if print_eq_percent: + print(f'Tensors are close: {is_eq}, Percent equal: {percent_eq[0]:.2f}%') + return is_eq, percent_eq[0] + + def draw_box(self, p1: (int, int), p2: (int, int), width: int = 2, rgb: int = 0xffffff): + """ + Draws a rectangular box on an image tensor. + + Parameters + ---------- + p1 : (int, int) + Top-left corner coordinates (x, y). + p2 : (int, int) + Bottom-right corner coordinates (x, y). + width : int, optional + Line width of the box, by default 2. + rgb : int, optional + 24-bit RGB color (e.g. 0xFFFFFF for white), by default 0xffffff. + """ + assert p2[0] > p1[0] and p2[1] > p1[1] and width > 0 + C.mag_tensor_img_draw_box(self.tensor, p1[0], p1[1], p2[0], p2[1], width, rgb & 0xffffff) + + def draw_text(self, p: (int, int), size: int, txt: str, rgb: int = 0xffffff): + """ + Draws text on an image tensor. + + Parameters + ---------- + p : (int, int) + Position to draw text (x, y). + size : int + Font size. + txt : str + The text to draw. + rgb : int, optional + 24-bit RGB color, by default white (0xffffff). + """ + C.mag_tensor_img_draw_text(self.tensor, p[0], p[1], size, rgb & 0xffffff, bytes(txt, 'utf-8')) + + def save(self, file_path: str) -> None: + """ + Saves the tensor to a binary magnetron file. + + Parameters + ---------- + file_path : str + File path to save the tensor. Appends '.magnetron' if not present. + """ + if not file_path.endswith('.magnetron'): + file_path += '.magnetron' + C.mag_tensor_save(self.tensor, bytes(file_path, 'utf-8')) + + def save_image(self, file_path: str) -> None: + """ + Saves a 3D image tensor as a JPG image file. + + Parameters + ---------- + file_path : str + File path to save the image. + + Raises + ------ + AssertionError + If tensor is not 3D or channel count is not supported. + """ + assert self.rank == 3, 'Tensor must be a 3D image tensor' + assert self.channels in (1, 3, 4), 'Invalid number of color channels' + C.mag_tensor_save_image(self.tensor, bytes(file_path, 'utf-8')) + + def clone(self) -> 'Tensor': + """ + Creates a new tensor with the same data as this one (deep copy). + + Returns + ------- + Tensor + A cloned tensor. + """ + return Tensor(C.mag_clone(self.tensor)) + + def view(self) -> 'Tensor': + """ + Creates a view of the tensor that shares underlying data (shallow copy). + + Returns + ------- + Tensor + A view tensor. + """ + return Tensor(C.mag_view(self.tensor)) + + def transpose(self) -> 'Tensor': + """ + Transposes the tensor (swaps the last two dimensions). + + Returns + ------- + Tensor + A transposed tensor. 
+ """ + return Tensor(C.mag_transpose(self.tensor)) + + def permute(self, axes: tuple[int, ...]) -> 'Tensor': + """ + Permutes the dimensions of the tensor. + + Parameters + ---------- + axes : tuple[int, ...] + A tuple specifying the permutation of axes. + + Returns + ------- + Tensor + A tensor with permuted dimensions. + """ + assert len(axes) == MAX_DIMS, f'Invalid number of axes: {axes}' + for i in range(MAX_DIMS): + assert 0 <= axes[i] < MAX_DIMS + for j in range(i + 1, MAX_DIMS): + assert axes[i] != axes[j], f'Duplicate axis: {axes[i]}' + return Tensor(C.mag_permute(self.tensor, *axes)) + + def mean(self) -> 'Tensor': + """Computes the mean of all elements in the tensor.""" + return Tensor(C.mag_mean(self.tensor)) + + def min(self) -> 'Tensor': + """Computes the minimum value in the tensor.""" + return Tensor(C.mag_min(self.tensor)) + + def max(self) -> 'Tensor': + """Computes the maximum value in the tensor.""" + return Tensor(C.mag_max(self.tensor)) + + def sum(self) -> 'Tensor': + """Computes the sum of all elements in the tensor.""" + return Tensor(C.mag_sum(self.tensor)) + + def abs(self) -> 'Tensor': + """Computes element-wise absolute value.""" + return Tensor(C.mag_abs(self.tensor)) + + def abs_(self) -> 'Tensor': + """In-place element-wise absolute value.""" + return Tensor(C.mag_abs_(self.tensor)) + + def neg(self) -> 'Tensor': + """Computes element-wise negation.""" + return Tensor(C.mag_neg(self.tensor)) + + def neg_(self) -> 'Tensor': + """In-place element-wise negation.""" + return Tensor(C.mag_neg_(self.tensor)) + + def __neg__(self) -> 'Tensor': + """Overloads unary negation: -X.""" + return self.neg() + + def log(self) -> 'Tensor': + """Computes element-wise natural logarithm.""" + return Tensor(C.mag_log(self.tensor)) + + def log_(self) -> 'Tensor': + """In-place element-wise natural logarithm.""" + return Tensor(C.mag_log_(self.tensor)) + + def sqr(self) -> 'Tensor': + """Computes element-wise square of values.""" + return Tensor(C.mag_sqr(self.tensor)) + + def sqr_(self) -> 'Tensor': + """In-place element-wise square of values.""" + return Tensor(C.mag_sqr_(self.tensor)) + + def sqrt(self) -> 'Tensor': + """Computes element-wise square root.""" + return Tensor(C.mag_sqrt(self.tensor)) + + def sqrt_(self) -> 'Tensor': + """In-place element-wise square root.""" + return Tensor(C.mag_sqrt_(self.tensor)) + + def sin(self) -> 'Tensor': + """Computes element-wise sine.""" + return Tensor(C.mag_sin(self.tensor)) + + def sin_(self) -> 'Tensor': + """In-place element-wise sine.""" + return Tensor(C.mag_sin_(self.tensor)) + + def cos(self) -> 'Tensor': + """Computes element-wise cosine.""" + return Tensor(C.mag_cos(self.tensor)) + + def cos_(self) -> 'Tensor': + """In-place element-wise cosine.""" + return Tensor(C.mag_cos_(self.tensor)) + + def heaviside_step(self) -> 'Tensor': + """Computes element-wise Heaviside step function.""" + return Tensor(C.mag_step(self.tensor)) + + def heaviside_step_(self) -> 'Tensor': + """In-place element-wise Heaviside step function.""" + return Tensor(C.mag_step_(self.tensor)) + + def softmax(self, derivative: bool = False) -> 'Tensor': + """ + Applies softmax or its derivative on the tensor. + + Parameters + ---------- + derivative : bool, optional + If True, computes softmax derivative instead of softmax, by default False. + + Returns + ------- + Tensor + The transformed tensor. 
+ """ + return Tensor(C.mag_softmax_dv(self.tensor) if derivative else C.mag_softmax(self.tensor)) + + def softmax_(self, derivative: bool = False) -> 'Tensor': + """In-place softmax or softmax derivative.""" + return Tensor(C.mag_softmax_dv_(self.tensor) if derivative else C.mag_softmax_(self.tensor)) + + def sigmoid(self, derivative: bool = False) -> 'Tensor': + """ + Applies sigmoid or its derivative on the tensor. + + Parameters + ---------- + derivative : bool, optional + If True, computes sigmoid derivative, by default False. + + Returns + ------- + Tensor + The transformed tensor. + """ + return Tensor(C.mag_sigmoid_dv(self.tensor) if derivative else C.mag_sigmoid(self.tensor)) + + def sigmoid_(self, derivative: bool = False) -> 'Tensor': + """In-place sigmoid or sigmoid derivative.""" + return Tensor(C.mag_sigmoid_dv_(self.tensor) if derivative else C.mag_sigmoid_(self.tensor)) + + def hard_sigmoid(self) -> 'Tensor': + """Applies hard sigmoid to the tensor.""" + return Tensor(C.mag_hard_sigmoid(self.tensor)) + + def hard_sigmoid_(self) -> 'Tensor': + """In-place hard sigmoid.""" + return Tensor(C.mag_hard_sigmoid_(self.tensor)) + + def silu(self, derivative: bool = False) -> 'Tensor': + """ + Applies SiLU (Sigmoid-Weighted Linear Unit) or its derivative. + + Parameters + ---------- + derivative : bool, optional + If True, applies SiLU derivative, by default False. + + Returns + ------- + Tensor + The transformed tensor. + """ + return C.mag_silu_dv(self.tensor) if derivative else C.mag_silu(self.tensor) + + def silu_(self, derivative: bool = False) -> 'Tensor': + """In-place SiLU or SiLU derivative.""" + return Tensor(C.mag_silu_dv_(self.tensor) if derivative else C.mag_silu_(self.tensor)) + + def tanh(self, derivative: bool = False) -> 'Tensor': + """ + Applies hyperbolic tangent or its derivative. + + Parameters + ---------- + derivative : bool, optional + If True, computes tanh derivative, by default False. + + Returns + ------- + Tensor + The transformed tensor. + """ + return Tensor(C.mag_tanh_dv(self.tensor) if derivative else C.mag_tanh(self.tensor)) + + def tanh_(self, derivative: bool = False) -> 'Tensor': + """In-place tanh or tanh derivative.""" + return Tensor(C.mag_tanh_dv_(self.tensor) if derivative else C.mag_tanh_(self.tensor)) + + def relu(self, derivative: bool = False) -> 'Tensor': + """ + Applies ReLU (Rectified Linear Unit) or its derivative. + + Parameters + ---------- + derivative : bool, optional + If True, computes ReLU derivative, by default False. + + Returns + ------- + Tensor + The transformed tensor. + """ + return Tensor(C.mag_relu_dv(self.tensor) if derivative else C.mag_relu(self.tensor)) + + def relu_(self, derivative: bool = False) -> 'Tensor': + """In-place ReLU or ReLU derivative.""" + return Tensor(C.mag_relu_dv_(self.tensor) if derivative else C.mag_relu_(self.tensor)) + + def gelu(self, derivative: bool = False) -> 'Tensor': + """ + Applies GELU (Gaussian Error Linear Unit) or its derivative. + + Parameters + ---------- + derivative : bool, optional + If True, computes GELU derivative, by default False. + + Returns + ------- + Tensor + The transformed tensor. 
+ """ + return Tensor(C.mag_gelu_dv(self.tensor) if derivative else C.mag_gelu(self.tensor)) + + def gelu_(self, derivative: bool = False) -> 'Tensor': + """In-place GELU or GELU derivative.""" + return Tensor(C.mag_gelu_dv_(self.tensor) if derivative else C.mag_gelu_(self.tensor)) + + def __add__(self, other: object | int | float) -> 'Tensor': + """Element-wise addition with another tensor or scalar.""" + return Tensor(C.mag_add(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_adds(self.tensor, float(other))) + + def __iadd__(self, other: object | int | float) -> 'Tensor': + """In-place element-wise addition.""" + return Tensor(C.mag_add_(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_adds_(self.tensor, float(other))) + + def __sub__(self, other: object | int | float) -> 'Tensor': + """Element-wise subtraction with another tensor or scalar.""" + return Tensor(C.mag_sub(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_subs(self.tensor, float(other))) + + def __isub__(self, other: object | int | float) -> 'Tensor': + """In-place element-wise subtraction.""" + return Tensor(C.mag_sub_(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_subs_(self.tensor, float(other))) + + def __mul__(self, other: object | int | float) -> 'Tensor': + """Element-wise multiplication with another tensor or scalar.""" + return Tensor(C.mag_mul(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_muls(self.tensor, float(other))) + + def __imul__(self, other: object | int | float) -> 'Tensor': + """In-place element-wise multiplication.""" + return Tensor(C.mag_mul_(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_muls_(self.tensor, float(other))) + + def __truediv__(self, other: object | int | float) -> 'Tensor': + """Element-wise division with another tensor or scalar.""" + return Tensor(C.mag_div(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_divs(self.tensor, float(other))) + + def __itruediv__(self, other: object | int | float) -> 'Tensor': + """In-place element-wise division.""" + return Tensor(C.mag_div_(self.tensor, other.tensor) if isinstance(other, Tensor) else C.mag_divs_(self.tensor, float(other))) + + def __matmul__(self, other: 'Tensor') -> 'Tensor': + """Matrix multiplication with another tensor: A @ B.""" + return Tensor(C.mag_matmul(self.tensor, other.tensor)) + + def __imatmul__(self, other: 'Tensor') -> 'Tensor': + """In-place matrix multiplication: A @= B.""" + return Tensor(C.mag_matmul_(self.tensor, other.tensor)) + + def __eq__(self, other: 'Tensor') -> bool: + """ + Checks if two tensors have identical data and shape. + + Parameters + ---------- + other : Tensor + Another tensor to compare. + + Returns + ------- + bool + True if equal, otherwise False. + """ + return C.mag_tensor_eq(self.tensor, other.tensor) + + def __str__(self) -> str: + """String representation prints the tensor header and data.""" + self.print(True, True) + return '' diff --git a/python/magnetron_framework/magnetron/models.py b/python/magnetron_framework/magnetron/models.py new file mode 100644 index 0000000..0a9a0ef --- /dev/null +++ b/python/magnetron_framework/magnetron/models.py @@ -0,0 +1,110 @@ +# (c) 2024 Mario "Neo" Sieg. +# Implements high level model classes for neural networks based on the magnetron.core module. 
+
+import time
+from abc import ABC
+
+import magnetron as mag
+
+
+class Layer(ABC):
+    def forward(self, inputs: mag.Tensor) -> mag.Tensor:
+        pass
+
+    def backward(self, is_in: bool, cache: mag.Tensor, delta: mag.Tensor, rate: float) -> mag.Tensor:
+        pass
+
+
+class Model(ABC):
+    def forward(self, inputs: mag.Tensor) -> mag.Tensor:
+        pass
+
+    def backward(self, outputs: mag.Tensor, targets: mag.Tensor, rate: float):
+        pass
+
+    def train(self, inputs: list[mag.Tensor], targets: list[mag.Tensor], epochs: int, learning_rate: float):
+        pass
+
+    def summary(self):
+        pass
+
+
+class Optim:
+    @staticmethod
+    def mse(y: mag.Tensor, y_hat: mag.Tensor) -> float:
+        """Mean Squared Error"""
+        return (y - y_hat).sqr_().mean().scalar()
+
+    @staticmethod
+    def cross_entropy(y: mag.Tensor, y_hat: mag.Tensor) -> float:
+        """Cross Entropy Loss"""
+        return -(y * y_hat.log()).sum().scalar()  # out-of-place log; log_() would mutate the caller's y_hat
+
+
+class DenseLayer(Layer):
+    def __init__(self, in_features: int, out_features: int):
+        self.weight = mag.Tensor.uniform(shape=(out_features, in_features))
+        self.bias = mag.Tensor.uniform(shape=(out_features, 1))
+        self.cache = None
+
+    def forward(self, prev: mag.Tensor) -> mag.Tensor:
+        return (self.weight @ prev + self.bias).sigmoid()
+
+    def backward(self, is_in: bool, cache: mag.Tensor, delta: mag.Tensor, rate: float) -> mag.Tensor:
+        self.weight -= (delta @ cache.transpose().clone()) * rate
+        self.bias -= delta * rate
+        if is_in:
+            return (self.weight.transpose() @ delta) * cache.sigmoid(derivative=True)
+        else:
+            return delta
+
+
+class SequentialModel(Model):
+    def __init__(self, layers: list[DenseLayer]):
+        super().__init__()
+        assert len(layers) > 0
+        self.layers = layers
+        self.cache = []
+        self.loss_epoch_step = 1000
+
+    def forward(self, inputs: mag.Tensor) -> mag.Tensor:
+        prev = inputs
+        self.cache.clear()
+        self.cache.append(prev)
+        for i in range(0, len(self.layers)):
+            prev = self.layers[i].forward(prev)
+            self.cache.append(prev)
+        return prev
+
+    def backward(self, outputs: mag.Tensor, targets: mag.Tensor, rate: float):
+        delta = (outputs - targets) * outputs.sigmoid(derivative=True)
+        for i in reversed(range(0, len(self.layers))):
+            delta = self.layers[i].backward(i > 0, self.cache[i], delta, rate)
+
+    def train(self, inputs: list[mag.Tensor], targets: list[mag.Tensor], epochs: int, learning_rate: float):
+        assert len(inputs) == len(targets)
+        print(f'Training started: {epochs} epochs with learning rate {learning_rate}')
+        now = time.time_ns()
+        losses = []
+        for epoch in range(epochs):  # was range(0, epochs - 1), which silently skipped the last epoch
+            total_loss: float = 0
+            for i in range(0, len(inputs)):
+                pred: mag.Tensor = self.forward(inputs[i])
+                self.backward(pred, targets[i], learning_rate)
+                total_loss += Optim.mse(pred, targets[i])
+            mean_loss = total_loss / len(inputs)
+            losses.append(mean_loss)
+            if epoch % self.loss_epoch_step == 0:
+                print(f'Epoch: {epoch}, Loss: {mean_loss}')
+        print(f'Training finished in {(time.time_ns() - now) / 1e9} seconds')
+        return losses
+
+    def summary(self):
+        trainable_params = 0
+        for layer in self.layers:
+            trainable_params += layer.weight.numel + layer.bias.numel
+        layers: int = len(self.layers)
+        print('---- Model Summary ----')
+        print(f'Trainable Parameters: {trainable_params}')
+        print(f'Layers: {layers}')
+        print('-----------------------')
diff --git a/python/magnetron_framework/requirements.txt b/python/magnetron_framework/requirements.txt
new file mode 100644
index 0000000..2dc0a16
--- /dev/null
+++ b/python/magnetron_framework/requirements.txt
@@ -0,0 +1 @@
+cffi
\ No newline at end of file
diff --git a/python/magnetron_framework/setup.py b/python/magnetron_framework/setup.py
new file mode 100644
index 0000000..ed3806b
--- /dev/null
+++ b/python/magnetron_framework/setup.py
@@ -0,0 +1,79 @@
+# (c) 2024 Mario 'Neo' Sieg.
+
+import os
+import subprocess
+import multiprocessing
+
+from setuptools import setup, Extension
+from setuptools.command.build_ext import build_ext
+
+CMAKE_ROOT: str = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))  # Root directory of the CMake project
+NUM_JOBS: int = max(multiprocessing.cpu_count() - 1, 1)  # Use all but one core
+
+class BuildException(Exception):
+    def __init__(self, message: str):
+        self.message = message
+        super().__init__(self.message)
+
+class CMakeBuildExtension(Extension):
+    def __init__(self, name, root_dir: str=''):
+        super().__init__(name, sources=[])
+        self.root_dir = os.path.abspath(root_dir)
+
+class CMakeBuildExecutor(build_ext):
+    def initialize_options(self):
+        super().initialize_options()
+
+    def run(self):
+        try:
+            print(subprocess.check_output(['cmake', '--version']))
+        except OSError:
+            raise BuildException('CMake must be installed to build the magnetron binaries from source. Please install CMake and try again.')
+        super().run()
+        for ext in self.extensions:
+            self.build_extension(ext)
+
+    def build_extension(self, ext):
+        if not os.path.exists(self.build_temp):
+            os.makedirs(self.build_temp)
+        cmake_args = [
+            '-DMAGNETRON_ENABLE_CUDA=OFF',  # TODO: Fix cuda compilation
+            f'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={os.path.abspath(os.path.join(self.build_lib, "magnetron"))}',
+            '-DCMAKE_BUILD_TYPE=Release',
+        ]
+        build_args = [
+            '--target', 'magnetron',  # Only build the magnetron library; must be two argv entries, a single '--target magnetron' string is not parsed by cmake
+            f'-j{NUM_JOBS}',
+            '-v'
+        ]
+        subprocess.check_call(['cmake', ext.root_dir] + cmake_args, cwd=self.build_temp)
+        subprocess.check_call(['cmake', '--build', '.'] + build_args, cwd=self.build_temp)
+
+# Setup dependencies from requirements.txt
+lib_folder = os.path.dirname(os.path.realpath(__file__))
+requirement_path = f'{lib_folder}/requirements.txt'
+install_requires = []
+if os.path.isfile(requirement_path):
+    with open(requirement_path) as f:
+        install_requires = f.read().splitlines()
+
+# Setup magnetron package
+setup(
+    name='magnetron',
+    version='0.1.0',
+    author='Mario Sieg',
+    author_email='mario.sieg.64@gmail.com',
+    description='A lightweight machine learning library with GPU support.',
+    long_description='A lightweight machine learning library with GPU support.',
+    packages=['magnetron'],
+    package_data={
+        'magnetron': ['*.dylib', '*.so', '*.dll'],
+    },
+    include_package_data=True,
+    ext_modules=[CMakeBuildExtension('magnetron', root_dir=CMAKE_ROOT)],
+    cmdclass={
+        'build_ext': CMakeBuildExecutor,
+    },
+    zip_safe=False,
+    install_requires=install_requires
+)
diff --git a/python/magnetron_framework/wavelet.egg-info/PKG-INFO b/python/magnetron_framework/wavelet.egg-info/PKG-INFO
new file mode 100644
index 0000000..44ffd12
--- /dev/null
+++ b/python/magnetron_framework/wavelet.egg-info/PKG-INFO
@@ -0,0 +1,9 @@
+Metadata-Version: 2.1
+Name: magnetron
+Version: 0.1.0
+Summary: A lightweight machine learning library with GPU support.
+Author: Mario Sieg
+Author-email: mario.sieg.64@gmail.com
+Requires-Dist: cffi
+
+A lightweight machine learning library with GPU support.
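
The models module above defines layers, losses, and the training loop but ships no example of wiring them together. The following is a minimal, hypothetical usage sketch of the SequentialModel API; the XOR data, layer sizes, epoch count, and learning rate are illustrative choices, not values from the repository:

    import magnetron as mag
    from magnetron.models import DenseLayer, SequentialModel

    # XOR truth table; inputs are (2, 1) column vectors, targets are (1, 1).
    pairs = ((0, 0), (0, 1), (1, 0), (1, 1))
    inputs = [mag.Tensor.const([[a], [b]]) for a, b in pairs]
    targets = [mag.Tensor.const([[float(a != b)]]) for a, b in pairs]

    # Two dense layers: 2 -> 4 -> 1 (hyperparameters are illustrative).
    model = SequentialModel([DenseLayer(2, 4), DenseLayer(4, 1)])
    model.summary()
    losses = model.train(inputs, targets, epochs=10000, learning_rate=0.5)
    print(model.forward(inputs[1]).scalar())  # expected to approach 1.0

Note that DenseLayer hard-codes a sigmoid activation and SequentialModel.backward assumes it, so the sketch stays within that constraint.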
diff --git a/python/magnetron_framework/wavelet.egg-info/SOURCES.txt b/python/magnetron_framework/wavelet.egg-info/SOURCES.txt new file mode 100644 index 0000000..6b679ef --- /dev/null +++ b/python/magnetron_framework/wavelet.egg-info/SOURCES.txt @@ -0,0 +1,12 @@ +setup.py +magnetron/__init__.py +magnetron/_ffi_cdecl_generated.py +magnetron/_lib_loader.py +magnetron/core.py +magnetron/models.py +magnetron.egg-info/PKG-INFO +magnetron.egg-info/SOURCES.txt +magnetron.egg-info/dependency_links.txt +magnetron.egg-info/not-zip-safe +magnetron.egg-info/requires.txt +magnetron.egg-info/top_level.txt \ No newline at end of file diff --git a/python/magnetron_framework/wavelet.egg-info/dependency_links.txt b/python/magnetron_framework/wavelet.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/python/magnetron_framework/wavelet.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/python/magnetron_framework/wavelet.egg-info/not-zip-safe b/python/magnetron_framework/wavelet.egg-info/not-zip-safe new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/python/magnetron_framework/wavelet.egg-info/not-zip-safe @@ -0,0 +1 @@ + diff --git a/python/magnetron_framework/wavelet.egg-info/requires.txt b/python/magnetron_framework/wavelet.egg-info/requires.txt new file mode 100644 index 0000000..6a88e4b --- /dev/null +++ b/python/magnetron_framework/wavelet.egg-info/requires.txt @@ -0,0 +1 @@ +cffi diff --git a/python/magnetron_framework/wavelet.egg-info/top_level.txt b/python/magnetron_framework/wavelet.egg-info/top_level.txt new file mode 100644 index 0000000..520649a --- /dev/null +++ b/python/magnetron_framework/wavelet.egg-info/top_level.txt @@ -0,0 +1 @@ +magnetron diff --git a/python/magnetron_viewer/icons/folder.png b/python/magnetron_viewer/icons/folder.png new file mode 100644 index 0000000..2ce1ce7 Binary files /dev/null and b/python/magnetron_viewer/icons/folder.png differ diff --git a/python/magnetron_viewer/icons/icon.png b/python/magnetron_viewer/icons/icon.png new file mode 100644 index 0000000..b5c9fee Binary files /dev/null and b/python/magnetron_viewer/icons/icon.png differ diff --git a/python/magnetron_viewer/icons/icons_credits.txt b/python/magnetron_viewer/icons/icons_credits.txt new file mode 100644 index 0000000..9e89525 --- /dev/null +++ b/python/magnetron_viewer/icons/icons_credits.txt @@ -0,0 +1 @@ +Icons designed by Smashicons from Flaticon diff --git a/python/magnetron_viewer/icons/metadata.png b/python/magnetron_viewer/icons/metadata.png new file mode 100644 index 0000000..5d780a7 Binary files /dev/null and b/python/magnetron_viewer/icons/metadata.png differ diff --git a/python/magnetron_viewer/icons/tensor.png b/python/magnetron_viewer/icons/tensor.png new file mode 100644 index 0000000..7049dcc Binary files /dev/null and b/python/magnetron_viewer/icons/tensor.png differ diff --git a/python/magnetron_viewer/logo.png b/python/magnetron_viewer/logo.png new file mode 100644 index 0000000..d05670b Binary files /dev/null and b/python/magnetron_viewer/logo.png differ diff --git a/python/magnetron_viewer/main.py b/python/magnetron_viewer/main.py new file mode 100644 index 0000000..69ceffc --- /dev/null +++ b/python/magnetron_viewer/main.py @@ -0,0 +1,158 @@ +# (c) 2024 Mario 'Neo' Sieg. 
+ +import time +import sys +import os +from magnetron.core import Tensor +from PyQt5.QtWidgets import * +from PyQt5.QtGui import * +from PyQt5.QtCore import * + +FONT_SIZE: int = 14 + + +def process_events_idle(): + for _ in range(0, 5): # Sleep for a bit to allow the loading box to show up + QApplication.processEvents() + time.sleep(0.1) + + +class MAGNETRONViewer(QMainWindow): + def __init__(self): + super().__init__() + self.window_icon = QIcon('logo.png') + self.tensor_icon = QIcon('icons/tensor.png') + self.folder_icon = QIcon('icons/folder.png') + self.metadata_icon = QIcon('icons/metadata.png') + self.setWindowTitle('magnetron File Viewer') + self.setWindowIcon(self.window_icon) + self.resize(1920, 1080) + + main_widget = QWidget() + main_layout = QVBoxLayout(main_widget) + splitter = QSplitter(Qt.Horizontal) + + self.tensor_tree = QTreeWidget() + self.tensor_tree.setHeaderHidden(True) + self.tensor_tree.setStyleSheet(f'font-size: {FONT_SIZE}px;') + self.tensor_tree.itemClicked.connect(self.show_tensor_data) + + self.data_view_container = QWidget() + data_view_layout = QVBoxLayout(self.data_view_container) + + self.data_view = QPlainTextEdit() + self.data_view.setReadOnly(True) + self.data_view.setStyleSheet(f'font-size: {FONT_SIZE}px;') + + data_view_layout.addWidget(self.data_view) + data_view_layout.setContentsMargins(0, 0, 0, 0) + + self.info_panel = QTextEdit() + self.info_panel.setReadOnly(True) + self.info_panel.setStyleSheet(f'font-size: {FONT_SIZE}px;') + + splitter.addWidget(self.tensor_tree) + splitter.addWidget(self.data_view_container) + splitter.addWidget(self.info_panel) + splitter.setStretchFactor(0, 1) + splitter.setStretchFactor(1, 8) + splitter.setStretchFactor(2, 1) + + main_layout.addWidget(splitter) + self.setCentralWidget(main_widget) + + menu_bar = self.menuBar() + file_menu = menu_bar.addMenu('File') + open_action = QAction('Open', self) + open_action.triggered.connect(self.open_file) + file_menu.addAction(open_action) + + self.tensors = {} + self.metadata = {} + + def open_file(self): + file_name, _ = QFileDialog.getOpenFileName(self, 'Open magnetron File', os.getcwd(), + 'magnetron Files (*.magnetron);;All Files (*)') + if not file_name: + print('No file selected') + return + + self.setWindowTitle(f'magnetron File Viewer - {os.path.basename(file_name)}') + + self.tensor_tree.clear() + self.tensors.clear() + self.metadata.clear() + + tensor_file = Tensor.load(file_name) + tensors_item = QTreeWidgetItem(self.tensor_tree) + tensors_item.setText(0, 'Tensors') + tensors_item.setIcon(0, self.folder_icon) + for tensor in [tensor_file]: + tensor.name = tensor.name or f'Tensor {len(self.tensors) + 1}' + self.tensors[tensor.name] = tensor + tensor_item = QTreeWidgetItem(tensors_item) + tensor_item.setText(0, tensor.name) + tensor_item.setIcon(0, self.tensor_icon) + + metadata_item = QTreeWidgetItem(self.tensor_tree) + metadata_item.setText(0, 'Metadata') + metadata_item.setIcon(0, self.folder_icon) + metadata = { + 'File Name': os.path.basename(file_name), + 'File Size': os.path.getsize(file_name), + 'File Path': file_name + } + for key, value in metadata.items(): + self.metadata[key] = value + metadata_entry = QTreeWidgetItem(metadata_item) + metadata_entry.setText(0, f'{key}: {value}') + metadata_entry.setIcon(0, self.metadata_icon) + + self.tensor_tree.expandAll() + + def show_tensor_data(self, item): + parent = item.parent() + if parent is None or parent.text(0) != 'Tensors': + return + + tensor_name = item.text(0) + if tensor_name not in self.tensors: + return 
+ tensor = self.tensors[tensor_name] + tensor_data = tensor.to_list() + + rows = [] + elements_per_row = 16 + for i in range(0, len(tensor_data), elements_per_row): + row = ' '.join(f'{value:10.5f}' for value in tensor_data[i:i + elements_per_row]) + rows.append(row) + data_str = '\n'.join(rows) + self.data_view.setPlainText(data_str) + + extra_info = [ + f'Name: {tensor.name}', + f'Dimensions: {tensor.shape}', + f'Strides: {tensor.strides}', + f'DType: {tensor.dtype}', + f'Rank: {tensor.rank}', + f'Total Elements: {tensor.numel}', + f'Total Bytes: {tensor.data_size}', + f'Min: {min(tensor_data)}', + f'Max: {max(tensor_data)}', + f'Mean: {sum(tensor_data) / len(tensor_data)}', + f'Transposed: {tensor.is_transposed}', + f'Permuted: {tensor.is_permuted}', + f'Contiguous: {tensor.is_contiguous}', + ] + self.info_panel.setText('\n'.join(extra_info)) + + +def main(): + app = QApplication(sys.argv) + viewer = MAGNETRONViewer() + viewer.show() + sys.exit(app.exec_()) + + +if __name__ == '__main__': + main() diff --git a/python/requirements.txt b/python/requirements.txt new file mode 100644 index 0000000..fe5b858 --- /dev/null +++ b/python/requirements.txt @@ -0,0 +1,7 @@ +cffi +pycparser +setuptools +pytest +pyqt5 +numpy +matplotlib diff --git a/python/tests/__init__.py b/python/tests/__init__.py new file mode 100644 index 0000000..d6ecd02 --- /dev/null +++ b/python/tests/__init__.py @@ -0,0 +1 @@ +# (c) 2024 Mario "Neo" Sieg. diff --git a/python/tests/tests.py b/python/tests/tests.py new file mode 100644 index 0000000..f5908bc --- /dev/null +++ b/python/tests/tests.py @@ -0,0 +1,91 @@ +# (c) 2024 Mario 'Neo' Sieg. + +def test_import_magnetron(): + import magnetron + assert magnetron.__version__ is not None + + +def test_simple_exec(): + import magnetron as mag + a = mag.Tensor.const([1, 4, 1]) + assert a.max().scalar() == 4 + +from magnetron import * + +def test_context_creation(): + # Test that a context can be created and defaults are correct. 
+ ctx = Context.active() + assert ctx.compute_device in (ComputeDevice.CPU, ComputeDevice.CUDA) + assert ctx.execution_mode.name in ('EAGER', 'DEFERRED') + assert isinstance(ctx.os_name, str) + assert isinstance(ctx.cpu_name, str) + assert ctx.cpu_virtual_cores >= 1 + assert ctx.cpu_physical_cores >= 1 + +def test_tensor_creation(): + # Test creating a simple tensor + t = Tensor.empty((3, 3)) + assert t.shape == (3, 3) + assert t.numel == 9 + assert t.dtype == DType.F32 + assert t.is_matrix is True + +def test_tensor_fill(): + # Test creating a tensor filled with a constant value + val = 42.0 + t = Tensor.full((2, 2), fill_value=val) + data = t.to_list() + assert all(x == val for x in data) + assert t.shape == (2, 2) + +def test_tensor_zeros(): + # Test zero tensor + t = Tensor.zeros((2, 3)) + data = t.to_list() + assert all(x == 0.0 for x in data) + +def test_tensor_ops(): + # Test basic arithmetic operations + a = Tensor.const([[1, 2], [3, 4]]) + b = Tensor.const([[4, 3], [2, 1]]) + + c = a + b + assert c.rank == 2 + assert c.shape == (2, 2) + assert c.to_list() == [5.0, 5.0, 5.0, 5.0] + + d = a - b + assert d.to_list() == [-3.0, -1.0, 1.0, 3.0] + + e = a * 2.0 + assert e.to_list() == [2.0, 4.0, 6.0, 8.0] + +def test_tensor_inplace_ops(): + # Test in-place operations + t = Tensor.const([1, 2, 3]) + t += 1 + assert t.to_list() == [2.0, 3.0, 4.0] + + t *= 2 + assert t.to_list() == [4.0, 6.0, 8.0] + +def test_tensor_unary_ops(): + # Test unary operations like neg, abs, sqrt + t = Tensor.const([-1, -4, 9]) + neg_t = t.neg() + assert neg_t.to_list() == [1.0, 4.0, -9.0] + + abs_t = t.abs() + assert abs_t.to_list() == [1.0, 4.0, 9.0] + + sqrt_t = abs_t.sqrt() + # sqrt(1)=1, sqrt(4)=2, sqrt(9)=3 + assert sqrt_t.to_list() == [1.0, 2.0, 3.0] + +def test_save_and_load(tmp_path): + t = Tensor.const([[1, 2], [3, 4]]) + file_path = tmp_path / 'test_tensor.magnetron' + t.save(str(file_path)) + loaded = Tensor.load(str(file_path)) + assert loaded.shape == t.shape + assert loaded.to_list() == t.to_list() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 0000000..f0e8356 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,21 @@ +# (c) 2024 Mario "Neo" Sieg. + +enable_language(CXX) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +add_subdirectory(googletest) + +file(GLOB_RECURSE TEST_SOURCES *.cpp) +add_executable(magnetron_test ${TEST_SOURCES}) +target_link_libraries(magnetron_test magnetron) +target_include_directories(magnetron_test PRIVATE ../magnetron) +target_link_libraries( + magnetron_test + GTest::gtest_main +) +add_test(NAME magnetron_test COMMAND magnetron_test) + +if (${MAGNETRON_ENABLE_CUDA}) + target_compile_definitions(magnetron_test PRIVATE MAG_ENABLE_CUDA) +endif() diff --git a/test/allocators.cpp b/test/allocators.cpp new file mode 100644 index 0000000..47a2e0e --- /dev/null +++ b/test/allocators.cpp @@ -0,0 +1,72 @@ +// (c) 2024 Mario "Neo" Sieg. 
+ +#include "prelude.hpp" + +TEST(allocators, fixed_intrusive_pool_alloc_free) { + mag_fixed_intrusive_pool pool {}; + mag_fixed_intrusive_pool_init(&pool, sizeof(int), alignof(int), 8); + ASSERT_EQ(pool.num_allocs, 0); + ASSERT_EQ(pool.num_chunks, 1); + ASSERT_EQ(pool.num_freelist_hits, 0); + ASSERT_EQ(pool.num_pool_hits, 0); + for (int i = 0; i < 8192; ++i) { + int* x = static_cast(mag_fixed_intrusive_pool_malloc(&pool)); + ASSERT_EQ(pool.num_allocs, i+1); + if (i >= 1) { + ASSERT_EQ(pool.num_freelist_hits, i); + } + *x = -1; + ASSERT_EQ(*x, -1); + mag_fixed_intrusive_pool_free(&pool, x); + } + ASSERT_EQ(pool.num_chunks, 1); + ASSERT_EQ(pool.num_pool_hits, 1); + mag_fixed_intrusive_pool_destroy(&pool); +} + +TEST(allocators, fixed_intrusive_pool_exhaust_pool) { + mag_fixed_intrusive_pool pool {}; + mag_fixed_intrusive_pool_init(&pool, sizeof(int), alignof(int), 8); + ASSERT_EQ(pool.num_allocs, 0); + ASSERT_EQ(pool.num_chunks, 1); + ASSERT_EQ(pool.num_freelist_hits, 0); + ASSERT_EQ(pool.num_pool_hits, 0); + for (int i = 0; i < 8192; ++i) { + [[maybe_unused]] int* volatile x = static_cast(mag_fixed_intrusive_pool_malloc(&pool)); + } + ASSERT_EQ(pool.num_chunks, 8192/8); + //ASSERT_EQ(pool.num_pool_hits, 1); + ASSERT_EQ(pool.num_freelist_hits, 0); + mag_fixed_intrusive_pool_destroy(&pool); +} + +#ifndef _MSC_VER // MSVC fucks around with linking a __declspex(dllexport) ed function ptr. TODO: fix + +TEST(allocators, alloc_small) { + int* i = static_cast((*mag_alloc)(nullptr, sizeof(int))); + ASSERT_NE(i, nullptr); + *i = -1; + ASSERT_EQ(*i, -1); + (*mag_alloc)(i, 0); +} + +TEST(allocators, alloc_large) { + void* i = (*mag_alloc)(nullptr, 1ull<<30); + ASSERT_NE(i, nullptr); + std::memset(i, 5, 1ull<<30); + (*mag_alloc)(i, 0); +} + +#endif + +TEST(allocators, alloc_aligned) { + std::size_t align = alignof(int); + for (; align <= 8192; align <<= 1) { + int* i = static_cast(mag_alloc_aligned(sizeof(int), align)); + ASSERT_NE(i, nullptr); + *i = -1; + ASSERT_EQ(*i, -1); + ASSERT_EQ(reinterpret_cast(i) % align, 0); + mag_free_aligned(i); + } +} diff --git a/test/compute_cpu.cpp b/test/compute_cpu.cpp new file mode 100644 index 0000000..4730a38 --- /dev/null +++ b/test/compute_cpu.cpp @@ -0,0 +1,536 @@ +// (c) 2024 Mario "Neo" Sieg. + +// For some tests, we use a larger absolute error than machine epsilon, +// because the BLAS uses SIMD for certain functions which have higher accuracy than the scalar high-precision lambdas. 
+ +#include "prelude.hpp" +#include + +static constexpr std::int64_t k_lim_same_shape = 4; +static constexpr std::int64_t k_lim_broadcast = 2; + +#define impl_test_unary_op(name, eps, op, scalar_op) \ + TEST(compute_cpu, name##_same_shape) { \ + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \ + \ + for (std::int64_t i0=1; i0 <= k_lim_same_shape; ++i0) \ + for (std::int64_t i1=1; i1 <= k_lim_same_shape; ++i1) \ + for (std::int64_t i2=1; i2 <= k_lim_same_shape; ++i2) \ + for (std::int64_t i3=1; i3 <= k_lim_same_shape; ++i3) \ + for (std::int64_t i4=1; i4 <= k_lim_same_shape; ++i4) \ + for (std::int64_t i5=1; i5 <= k_lim_same_shape; ++i5) { \ + mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \ + mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \ + \ + mag_tensor_t* r = mag_##op(x); \ + \ + const auto* b_x = static_cast(mag_tensor_data_ptr(x)); \ + const auto* b_r = static_cast(mag_tensor_data_ptr(r)); \ + ASSERT_EQ(mag_tensor_numel(x), mag_tensor_numel(r)); \ + ASSERT_NE(mag_tensor_data_ptr(r), mag_tensor_data_ptr(x)); \ + for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \ + ASSERT_NEAR(b_r[i], scalar_op(b_x[i]), (eps)); \ + } \ + mag_tensor_decref(r); \ + mag_tensor_decref(x); \ + } \ + \ + mag_ctx_destroy(ctx); \ + } \ + TEST(compute_cpu, name##_same_shape_inplace) { \ + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \ + \ + for (std::int64_t i0=1; i0 <= k_lim_same_shape; ++i0) \ + for (std::int64_t i1=1; i1 <= k_lim_same_shape; ++i1) \ + for (std::int64_t i2=1; i2 <= k_lim_same_shape; ++i2) \ + for (std::int64_t i3=1; i3 <= k_lim_same_shape; ++i3) \ + for (std::int64_t i4=1; i4 <= k_lim_same_shape; ++i4) \ + for (std::int64_t i5=1; i5 <= k_lim_same_shape; ++i5) { \ + mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \ + mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \ + std::vector x_origin {}; \ + mag_tensor_buf_f32_to_vec(x, x_origin); \ + \ + mag_tensor_t* r = mag_##op##_(x); \ + mag_tensor_set_name(r, "result"); \ + \ + const auto* b_x = x_origin.data(); \ + mag_tensor_set_name(x, "X"); \ + const auto* b_r = static_cast(mag_tensor_data_ptr(r)); \ + ASSERT_EQ(mag_tensor_numel(x), mag_tensor_numel(r)); \ + ASSERT_EQ(mag_tensor_data_ptr(x), mag_tensor_data_ptr(r)); \ + for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \ + ASSERT_NEAR(b_r[i], scalar_op(b_x[i]), (eps)); \ + } \ + mag_tensor_decref(r); \ + mag_tensor_decref(x); \ + } \ + \ + mag_ctx_destroy(ctx); \ + } + +impl_test_unary_op(abs, 1e-6, abs, [](float x) -> float { + return std::abs(x); +}) + +TEST(compute_cpu, neg_same_shape) { + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); + for (std::int64_t i0=1; i0 <= k_lim_same_shape; ++i0) + for (std::int64_t i1=1; i1 <= k_lim_same_shape; ++i1) + for (std::int64_t i2=1; i2 <= k_lim_same_shape; ++i2) + for (std::int64_t i3=1; i3 <= k_lim_same_shape; ++i3) + for (std::int64_t i4=1; i4 <= k_lim_same_shape; ++i4) + for (std::int64_t i5=1; i5 <= k_lim_same_shape; ++i5) { + mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); + mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); + mag_tensor_t* r = mag_neg(x); + const auto* b_x = static_cast(mag_tensor_data_ptr(x)); + const auto* b_r = static_cast(mag_tensor_data_ptr(r)); + ASSERT_EQ(mag_tensor_numel(x), mag_tensor_numel(r)); + for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { + ASSERT_EQ(b_r[i], -b_x[i]); + } + mag_tensor_decref(x); + mag_tensor_decref(r); + } + mag_ctx_destroy(ctx); +} 
+
+impl_test_unary_op(log, 1e-6, log, [](float x) -> float {
+    return std::log(x);
+})
+
+impl_test_unary_op(sqr, 1e-9, sqr, [](float x) -> float {
+    return x*x;
+})
+
+impl_test_unary_op(sqrt, 1e-9, sqrt, [](float x) -> float {
+    return std::sqrt(x);
+})
+
+impl_test_unary_op(sin, 1e-6, sin, [](float x) -> float {
+    return std::sin(x);
+})
+
+#if 0 // TODO fix
+impl_test_unary_op(cos, 1e-6, cos, [](float x) -> float {
+    return std::cos(x);
+})
+#endif
+
+impl_test_unary_op(step, 1e-9, step, [](float x) -> float {
+    return x >= 0.0f ? 1.0f : 0.0f;
+})
+
+impl_test_unary_op(softmax, 1e-6, softmax, [](float x) -> float {
+    return std::exp(x);
+})
+impl_test_unary_op(softmax_dv, 1e-6, softmax_dv, [](float x) -> float {
+    return std::exp(x);
+})
+
+impl_test_unary_op(sigmoid, 1e-6, sigmoid, [](float x) -> float {
+    return 1.0f / (1.0f + std::exp(-x));
+})
+impl_test_unary_op(sigmoid_dv, 1e-6, sigmoid_dv, [](float x) -> float {
+    return x * (1.0f - x);
+})
+
+impl_test_unary_op(hard_sigmoid, 1e-6, hard_sigmoid, [](float x) -> float {
+    return std::min(1.0f, std::max(0.0f, (x + 3.0f) / 6.0f));
+})
+//impl_test_unary_op(hard_sigmoid_dv, HARD_SIGMOID_DV, [](float x) -> float {
+//    return -(std::exp(x) / ((std::exp(x)+1.0f)*(std::exp(x)+1.0f)));
+//})
+
+impl_test_unary_op(silu, 1e-6, silu, [](float x) -> float {
+    return x / (1.0f + std::exp(-x));
+})
+//impl_test_unary_op(silu_dv, SILU_DV, [](float x) -> float {
+//    return -(std::exp(x) / ((std::exp(x)+1.0f)*(std::exp(x)+1.0f)));
+//})
+
+impl_test_unary_op(tanh, 1e-3, tanh, [](float x) -> float {
+    return std::tanh(x);
+})
+impl_test_unary_op(tanh_dv, 1e-9, tanh_dv, [](float x) -> float {
+    return 1.0f / (std::cosh(x)*std::cosh(x));
+})
+
+impl_test_unary_op(relu, 1e-9, relu, [](float x) -> float {
+    return std::max(x, 0.0f);
+})
+impl_test_unary_op(relu_dv, 1e-9, relu_dv, [](float x) -> float {
+    return x <= 0.0f ? 0.0f : 1.0f;
+})
+
+impl_test_unary_op(gelu, 1e-3, gelu, [](float x) -> float {
+    return 0.5f*x*(1.0f + std::tanh(0.79788456080286535587989211986876f*x*(1.0f + 0.044715f*x*x)));
+})
+//impl_test_unary_op(gelu_dv, GELU_DV, [](float x) -> float {
+//    return x <= 0.0f ? 0.0f : 1.0f;
+//})
+
+#undef impl_test_unary_op
+
+#define impl_test_binary_op(name, op, scalar_op) \
+    TEST(compute_cpu, name##_same_shape) { \
+        mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \
+        for (std::int64_t i0=1; i0 <= k_lim_same_shape; ++i0) \
+        for (std::int64_t i1=1; i1 <= k_lim_same_shape; ++i1) \
+        for (std::int64_t i2=1; i2 <= k_lim_same_shape; ++i2) \
+        for (std::int64_t i3=1; i3 <= k_lim_same_shape; ++i3) \
+        for (std::int64_t i4=1; i4 <= k_lim_same_shape; ++i4) \
+        for (std::int64_t i5=1; i5 <= k_lim_same_shape; ++i5) { \
+            mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \
+            mag_tensor_t* y = mag_clone(x); \
+            mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \
+            mag_tensor_fill_random_uniform(y, -5.0f, 5.0f); \
+            \
+            mag_tensor_t* r = mag_##op(x, y); \
+            \
+            const auto* b_x = static_cast<const float*>(mag_tensor_data_ptr(x)); \
+            const auto* b_y = static_cast<const float*>(mag_tensor_data_ptr(y)); \
+            const auto* b_r = static_cast<const float*>(mag_tensor_data_ptr(r)); \
+            ASSERT_NE(mag_tensor_data_ptr(r), mag_tensor_data_ptr(x)); \
+            ASSERT_EQ(mag_tensor_numel(x), mag_tensor_numel(y)); \
+            ASSERT_EQ(mag_tensor_numel(r), mag_tensor_numel(y)); \
+            for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \
+                ASSERT_FLOAT_EQ(b_r[i], b_x[i] scalar_op b_y[i]); \
+            } \
+            mag_tensor_decref(r); \
+            mag_tensor_decref(x); \
+            mag_tensor_decref(y); \
+        } \
+        \
+        mag_ctx_destroy(ctx); \
+    } \
+    \
+    TEST(compute_cpu, name##_scalar_broadcast) { \
+        mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \
+        for (std::int64_t factor=2; factor <= 4; ++factor) \
+        for (std::int64_t i0=1; i0 <= k_lim_broadcast; ++i0) \
+        for (std::int64_t i1=1; i1 <= k_lim_broadcast; ++i1) \
+        for (std::int64_t i2=1; i2 <= k_lim_broadcast; ++i2) \
+        for (std::int64_t i3=1; i3 <= k_lim_broadcast; ++i3) \
+        for (std::int64_t i4=1; i4 <= k_lim_broadcast; ++i4) \
+        for (std::int64_t i5=1; i5 <= k_lim_broadcast; ++i5) { \
+            mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0*factor, i1*factor, i2*factor, i3*factor, i4*factor, i5*factor); \
+            mag_tensor_t* y = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \
+            mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \
+            mag_tensor_fill(y, 2.2f); \
+            \
+            mag_tensor_t* r = mag_##op(x, y); \
+            \
+            const auto* b_x = static_cast<const float*>(mag_tensor_data_ptr(x)); \
+            const auto* b_r = static_cast<const float*>(mag_tensor_data_ptr(r)); \
+            ASSERT_EQ(mag_tensor_numel(r), mag_tensor_numel(x)); \
+            ASSERT_NE(mag_tensor_numel(x), mag_tensor_numel(y)); \
+            ASSERT_NE(mag_tensor_data_ptr(r), mag_tensor_data_ptr(x)); \
+            for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \
+                ASSERT_FLOAT_EQ(b_r[i], b_x[i] scalar_op 2.2f); \
+            } \
+            mag_tensor_decref(r); \
+            mag_tensor_decref(x); \
+            mag_tensor_decref(y); \
+        } \
+        \
+        mag_ctx_destroy(ctx); \
+    } \
+    TEST(compute_cpu, name##_same_shape_inplace) { \
+        mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \
+        for (std::int64_t i0=1; i0 <= k_lim_same_shape; ++i0) \
+        for (std::int64_t i1=1; i1 <= k_lim_same_shape; ++i1) \
+        for (std::int64_t i2=1; i2 <= k_lim_same_shape; ++i2) \
+        for (std::int64_t i3=1; i3 <= k_lim_same_shape; ++i3) \
+        for (std::int64_t i4=1; i4 <= k_lim_same_shape; ++i4) \
+        for (std::int64_t i5=1; i5 <= k_lim_same_shape; ++i5) { \
+            mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \
+            mag_tensor_t* y = mag_clone(x); \
+            mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \
+            mag_tensor_fill_random_uniform(y, -5.0f, 5.0f); \
+            std::vector<float> x_origin {}; \
+            mag_tensor_buf_f32_to_vec(x, x_origin); \
+            \
+            mag_tensor_t* r = mag_##op##_(x, y); \
+            \
+            const auto* b_x = x_origin.data(); \
+            const auto* b_xx = static_cast<const float*>(mag_tensor_data_ptr(x)); \
+            const auto* b_y = static_cast<const float*>(mag_tensor_data_ptr(y)); \
+            const auto* b_r = static_cast<const float*>(mag_tensor_data_ptr(r)); \
+            ASSERT_EQ(mag_tensor_numel(x), mag_tensor_numel(y)); \
+            ASSERT_EQ(mag_tensor_numel(r), mag_tensor_numel(y)); \
+            ASSERT_EQ(mag_tensor_data_ptr(r), mag_tensor_data_ptr(x)); \
+            for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \
+                ASSERT_FLOAT_EQ(b_r[i], b_xx[i]); \
+                ASSERT_FLOAT_EQ(b_r[i], b_x[i] scalar_op b_y[i]); \
+            } \
+            mag_tensor_decref(r); \
+            mag_tensor_decref(x); \
+            mag_tensor_decref(y); \
+        } \
+        \
+        mag_ctx_destroy(ctx); \
+    } \
+    \
+    TEST(compute_cpu, name##_scalar_broadcast_inplace) { \
+        mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \
+        for (std::int64_t factor=2; factor <= 4; ++factor) \
+        for (std::int64_t i0=1; i0 <= k_lim_broadcast; ++i0) \
+        for (std::int64_t i1=1; i1 <= k_lim_broadcast; ++i1) \
+        for (std::int64_t i2=1; i2 <= k_lim_broadcast; ++i2) \
+        for (std::int64_t i3=1; i3 <= k_lim_broadcast; ++i3) \
+        for (std::int64_t i4=1; i4 <= k_lim_broadcast; ++i4) \
+        for (std::int64_t i5=1; i5 <= k_lim_broadcast; ++i5) { \
+            mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0*factor, i1*factor, i2*factor, i3*factor, i4*factor, i5*factor); \
+            mag_tensor_t* y = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \
+            mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \
+            mag_tensor_fill(y, 2.2f); \
+            \
+            std::vector<float> x_origin {}; \
+            mag_tensor_buf_f32_to_vec(x, x_origin); \
+            \
+            mag_tensor_t* r = mag_##op##_(x, y); \
+            \
+            const auto* b_x = x_origin.data(); \
+            const auto* b_xx = static_cast<const float*>(mag_tensor_data_ptr(x)); \
+            const auto* b_r = static_cast<const float*>(mag_tensor_data_ptr(r)); \
+            ASSERT_EQ(mag_tensor_numel(r), mag_tensor_numel(x)); \
+            ASSERT_NE(mag_tensor_numel(x), mag_tensor_numel(y)); \
+            ASSERT_EQ(mag_tensor_data_ptr(r), mag_tensor_data_ptr(x)); \
+            for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \
+                ASSERT_FLOAT_EQ(b_r[i], b_xx[i]); \
+                ASSERT_FLOAT_EQ(b_r[i], b_x[i] scalar_op 2.2f); \
+            } \
+            mag_tensor_decref(r); \
+            mag_tensor_decref(x); \
+            mag_tensor_decref(y); \
+        } \
+        \
+        mag_ctx_destroy(ctx); \
+    } \
+    TEST(compute_cpu, name##_scalar) { \
+        mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); \
+        for (std::int64_t i0=1; i0 <= k_lim_same_shape; ++i0) \
+        for (std::int64_t i1=1; i1 <= k_lim_same_shape; ++i1) \
+        for (std::int64_t i2=1; i2 <= k_lim_same_shape; ++i2) \
+        for (std::int64_t i3=1; i3 <= k_lim_same_shape; ++i3) \
+        for (std::int64_t i4=1; i4 <= k_lim_same_shape; ++i4) \
+        for (std::int64_t i5=1; i5 <= k_lim_same_shape; ++i5) { \
+            mag_tensor_t* x = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, i0, i1, i2, i3, i4, i5); \
+            mag_tensor_fill_random_uniform(x, 0.0f, 1.0f); \
+            \
+            mag_tensor_t* r = mag_##op##s(x, static_cast<float>(i0+i1+i2+i3+i4+i5)*0.221f); \
+            \
+            const auto* b_x = static_cast<const float*>(mag_tensor_data_ptr(x)); \
+            const auto* b_r = static_cast<const float*>(mag_tensor_data_ptr(r)); \
+            ASSERT_NE(mag_tensor_data_ptr(r), mag_tensor_data_ptr(x)); \
+            for (std::int64_t i=0; i < mag_tensor_numel(x); ++i) { \
+                ASSERT_FLOAT_EQ(b_r[i], b_x[i] scalar_op (static_cast<float>(i0+i1+i2+i3+i4+i5)*0.221f)); \
+            } \
+            mag_tensor_decref(r); \
+            mag_tensor_decref(x); \
+        } \
+        \
+        mag_ctx_destroy(ctx); \
+    }
+
+impl_test_binary_op(add_f32, add, +)
+impl_test_binary_op(sub_f32, sub, -)
+impl_test_binary_op(mul_f32, mul, *)
+impl_test_binary_op(div_f32, div, /)
+
+#undef impl_test_binary_op
+
+static void mag__inner_matmul_naive(
+    const float* A,
+    const float* B,
+    float* C,
+    const int64_t M,
+    const int64_t N,
+    const int64_t K
+) {
+    for (int64_t i = 0; i < M * N; ++i) C[i] = 0.0f;
+    for (int64_t i = 0; i < M; ++i) {          // Rows of A and C
+        for (int64_t k = 0; k < K; ++k) {      // Columns of A, rows of B
+            float a_ik = A[i * K + k];         // Access A[i][k]
+            for (int64_t j = 0; j < N; ++j) {  // Columns of B and C
+                C[i * N + j] += a_ik * B[k * N + j];  // C[i][j] += A[i][k] * B[k][j]
+            }
+        }
+    }
+}
+
+TEST(compute_cpu, matmul_inner_naive) {
+    static constexpr float A[6] = {
+        1.0f, 2.0f,
+        3.0f, 4.0f,
+        5.0f, 6.0f
+    };
+    static constexpr float B[2] = {0.5f, -1.0f};
+    float C[3];
+    mag__inner_matmul_naive(A, B, C, 3, 1, 2);
+    ASSERT_FLOAT_EQ(C[0], -1.5f);
+    ASSERT_FLOAT_EQ(C[1], -2.5f);
+    ASSERT_FLOAT_EQ(C[2], -3.5f);
+}
+
+TEST(compute_cpu, matmul_f32_same_shape_2x2) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+
+    static constexpr float A_values[2][2] = {
+        {1.6354027, -1.3607267},
+        {1.8556793, 1.1689897}
+    };
+    static constexpr float B_values[2][2] = {
+        {-0.6105532, 0.10695228},
+        {-1.0069681, -0.40955952}
+    };
+
+    // Manually set known values for A and B
+    mag_tensor_t* A = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 2, 2);
+    mag_tensor_copy_buffer_from(A, A_values, sizeof(A_values));
+
+    mag_tensor_t* B = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 2, 2);
+    mag_tensor_copy_buffer_from(B, B_values, sizeof(B_values));
+
+    // Create result tensor R for matrix multiplication
+    mag_tensor_t* R = mag_matmul(A, B);
+    mag_tensor_print(R, true, true);
+    const auto* buf = static_cast<const float*>(mag_tensor_data_ptr(R));
+
+    static constexpr float expected[2*2] = {
+        0.3717081, 0.7322086,
+        -2.3101263, -0.28030172
+    };
+
+    for (int i = 0; i < 2*2; ++i) {
+        ASSERT_FLOAT_EQ(buf[i], expected[i]);
+    }
+
+    mag_tensor_decref(A);
+    mag_tensor_decref(B);
+    mag_tensor_decref(R);
+
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, matmul_f32_different_shape_2x2) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+
+    static constexpr float AV[3*2] = {
+        1.0f, 2.0f,
+        3.0f, 4.0f,
+        5.0f, 6.0f
+    };
+    static constexpr float BV[2] = {0.5f, -1.0f};
+
+    // Manually set known values for A and B
+    mag_tensor_t* A = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 3, 2);
+    mag_tensor_copy_buffer_from(A, AV, sizeof(AV));
+
+    mag_tensor_t* B = mag_tensor_create_1d(ctx, MAG_DTYPE_F32, 2);
+    mag_tensor_copy_buffer_from(B, BV, sizeof(BV));
+
+    // Create result tensor R for matrix multiplication
+    mag_tensor_t* R = mag_matmul(A, B);
+    //ASSERT_EQ(mag_tensor_rank(R), 1);
+    ASSERT_EQ(mag_tensor_shape(R)[0], 3);
+    const auto* C = static_cast<const float*>(mag_tensor_data_ptr(R));
+    //mag__inner_matmul_naive(A, B, C, 3, 1, 2);
+    ASSERT_FLOAT_EQ(C[0], -1.5f);
+    ASSERT_FLOAT_EQ(C[1], -2.5f);
+    ASSERT_FLOAT_EQ(C[2], -3.5f);
+
+    mag_tensor_decref(A);
+    mag_tensor_decref(B);
+    mag_tensor_decref(R);
+
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, arithmetic_mean) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    mag_tensor_t* A = mag_tensor_create_4d(ctx, MAG_DTYPE_F32, 4, 1, 3, 2);
+    mag_tensor_fill_random_uniform(A, -1.0f, 1.0f);
+    mag_tensor_t* R = mag_mean(A);
+    ASSERT_NE(R, nullptr);
+    double a_mean = 0.0;
+    for (std::int64_t i=0; i < mag_tensor_numel(A); ++i)
+        a_mean += static_cast<double>(static_cast<const float*>(mag_tensor_data_ptr(A))[i]);
+    a_mean /= static_cast<double>(mag_tensor_numel(A));
+    ASSERT_NEAR(static_cast<float>(a_mean), *static_cast<const float*>(mag_tensor_data_ptr(R)), 1e-9);
+    mag_tensor_decref(A);
+    mag_tensor_decref(R);
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, min) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    mag_tensor_t* A = mag_tensor_create_4d(ctx, MAG_DTYPE_F32, 4, 1, 3, 2);
+    mag_tensor_fill_random_uniform(A, -1.0f, 1.0f);
+    mag_tensor_t* R = mag_min(A);
+    ASSERT_NE(R, nullptr);
+    float a_min = *std::min_element(static_cast<const float*>(mag_tensor_data_ptr(A)), static_cast<const float*>(mag_tensor_data_ptr(A)) + mag_tensor_numel(A));
+    ASSERT_FLOAT_EQ(a_min, *static_cast<const float*>(mag_tensor_data_ptr(R)));
+    mag_tensor_decref(A);
+    mag_tensor_decref(R);
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, max) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    mag_tensor_t* A = mag_tensor_create_4d(ctx, MAG_DTYPE_F32, 4, 1, 3, 2);
+    mag_tensor_fill_random_uniform(A, -1.0f, 1.0f);
+    mag_tensor_t* R = mag_max(A);
+    ASSERT_NE(R, nullptr);
+    float a_max = *std::max_element(static_cast<const float*>(mag_tensor_data_ptr(A)), static_cast<const float*>(mag_tensor_data_ptr(A)) + mag_tensor_numel(A)); // was misleadingly named a_min
+    ASSERT_FLOAT_EQ(a_max, *static_cast<const float*>(mag_tensor_data_ptr(R)));
+    mag_tensor_decref(A);
+    mag_tensor_decref(R);
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, hsum) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    mag_tensor_t* A = mag_tensor_create_4d(ctx, MAG_DTYPE_F32, 4, 1, 3, 2);
+    mag_tensor_fill_random_uniform(A, -1.0f, 1.0f);
+    mag_tensor_t* R = mag_sum(A);
+    ASSERT_NE(R, nullptr);
+    double a_sum = 0.0;
+    for (std::int64_t i=0; i < mag_tensor_numel(A); ++i)
+        a_sum += static_cast<double>(static_cast<const float*>(mag_tensor_data_ptr(A))[i]);
+    ASSERT_NEAR(static_cast<float>(a_sum), *static_cast<const float*>(mag_tensor_data_ptr(R)), 1e-9);
+    mag_tensor_decref(A);
+    mag_tensor_decref(R);
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, heavy_compute_single_op) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    mag_tensor_t* A = mag_tensor_create_3d(ctx, MAG_DTYPE_F32, 8192, 8192, 3);
+    mag_tensor_t* B = mag_clone(A);
+    mag_tensor_fill(B, 3.0);
+    mag_tensor_t* R = mag_add(A, B);
+    ASSERT_NE(R, nullptr);
+    mag_tensor_decref(A);
+    mag_tensor_decref(B);
+    mag_tensor_decref(R);
+    mag_ctx_destroy(ctx);
+}
+
+TEST(compute_cpu, heavy_compute_single_op_scalar) {
+    mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU);
+    mag_tensor_t* A = mag_tensor_create_1d(ctx, MAG_DTYPE_F32, 1);
+    mag_tensor_t* B = mag_clone(A);
+    mag_tensor_fill(B, 3.0);
+    mag_tensor_t* R = mag_add(A, B);
+    ASSERT_NE(R, nullptr);
+    mag_tensor_decref(A);
+    mag_tensor_decref(B);
+    mag_tensor_decref(R);
+    mag_ctx_destroy(ctx);
+}
diff --git a/test/core.cpp b/test/core.cpp
new file mode 100644
index 0000000..c3baf39
--- /dev/null
+++ b/test/core.cpp
@@ -0,0 +1,62 @@
+// (c) 2024 Mario "Neo" Sieg.
+ +#include "prelude.hpp" + +#include +#include + +TEST(core, profiler_small_dims) { + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); + mag_ctx_profile_start_recording(ctx); + for (int i=0; i < 1000000; ++i) { + mag_tensor_t* A = mag_tensor_create_2d(ctx, MAG_DTYPE_F32, 3, 3); + mag_tensor_t* B = mag_sin(A); + mag_tensor_t* C = mag_cos(B); + [[maybe_unused]] + mag_tensor_t* D = mag_tanh(C); + mag_tensor_decref(A); + mag_tensor_decref(B); + mag_tensor_decref(C); + mag_tensor_decref(D); + } + mag_ctx_profile_stop_recording(ctx, "perf.csv"); + ASSERT_TRUE(std::filesystem::exists("perf.csv")); + std::filesystem::remove("perf.csv"); + mag_ctx_destroy(ctx); +} + +TEST(core, profiler_big_dims) { + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); + mag_ctx_profile_start_recording(ctx); + mag_tensor_t* A = mag_tensor_create_6d(ctx, MAG_DTYPE_F32, 32, 32, 4, 4, 4, 4); + mag_tensor_t* B = mag_sin(A); + mag_tensor_t* C = mag_cos(B); + [[maybe_unused]] + mag_tensor_t* D = mag_tanh(C); + mag_ctx_profile_stop_recording(ctx, nullptr); + mag_tensor_decref(A); + mag_tensor_decref(B); + mag_tensor_decref(C); + mag_tensor_decref(D); + mag_ctx_destroy(ctx); +} + +#if 0 +TEST(core, crc32) { + ASSERT_EQ(mag__crc32c("Hello, World!", std::strlen("Hello, World!")), 1297420392); + uint8_t y = 0x3f; + ASSERT_EQ(mag__crc32c(& y, sizeof(y)), 1015883460); + ASSERT_EQ(mag__crc32c(nullptr, 0), 0); + ASSERT_EQ(mag__crc32c("AB", std::strlen("AB")), 3180610794); + ASSERT_EQ(mag__crc32c( + "Ich liebe Berliner Kebap, der ist einfach ultra schmackofatz, gerade um 4 Uhr Morgens nach einer langen Clubnacht.", + std::strlen( + "Ich liebe Berliner Kebap, der ist einfach ultra schmackofatz, gerade um 4 Uhr Morgens nach einer langen Clubnacht.")), 60440201); + std::vector huge {}; + huge.resize(0xffff); + for (std::size_t i = 0; i < huge.size(); ++i) { + huge[i] = i % 0xff; + } + ASSERT_EQ(mag__crc32c(huge.data(), huge.size()), 2008503331); +} +#endif \ No newline at end of file diff --git a/test/ctx.cpp b/test/ctx.cpp new file mode 100644 index 0000000..bc0dcd5 --- /dev/null +++ b/test/ctx.cpp @@ -0,0 +1,21 @@ +// (c) 2024 Mario "Neo" Sieg. + +#include "prelude.hpp" + +TEST(ctx, create_destroy_cpu) { + mag_set_log_mode(true); + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_CPU); + ASSERT_NE(ctx, nullptr); + mag_ctx_destroy(ctx); + mag_set_log_mode(false); +} + +#ifdef MAG_ENABLE_CUDA +TEST(ctx, create_destroy_cuda) { + mag_set_log_mode(true); + mag_ctx_t* ctx = mag_ctx_create(MAG_COMPUTE_DEVICE_TYPE_GPU_CUDA); + ASSERT_NE(ctx, nullptr); + mag_ctx_destroy(ctx); + mag_set_log_mode(false); +} +#endif diff --git a/test/cuda.cpp b/test/cuda.cpp new file mode 100644 index 0000000..44791e7 --- /dev/null +++ b/test/cuda.cpp @@ -0,0 +1,38 @@ +// (c) 2024 Mario "Neo" Sieg. 
+ +#ifdef MAG_ENABLE_CUDA + +#include "prelude.hpp" +#include "magnetron_cuda.cuh" + +using namespace mag::cuda; + +TEST(cuda, init) { + const auto* device = cuda_init(0); + ASSERT_NE(device, nullptr); + const auto& dev = *device; + std::cout << "Device ID: " << dev.id << std::endl; + std::cout << "Device Name: " << dev.name.data() << std::endl; + std::cout << "Video Memory: " << static_cast(dev.vram) / 1e9 << " GB" << std::endl; + std::cout << "Compute Capability: " << dev.cl << std::endl; + std::cout << "Number of SMs: " << dev.nsm << std::endl; + std::cout << "Number of Threads per Block: " << dev.ntpb << std::endl; + std::cout << "Shared Memory per Block: " << static_cast(dev.smpb) / 1e3 << " KB" << std::endl; + std::cout << "Shared Memory per Block Opt-In: " << static_cast(dev.smpb_opt) / 1e3 << " KB" << std::endl; + std::cout << "Has Virtual Memory Management: " << dev.has_vmm << std::endl; + std::cout << "Virtual Memory Management Granularity: " << static_cast(dev.vmm_granularity) / 1e3 << " KB" << std::endl; +} + +TEST(cuda, vm_pool) { + const auto* device = cuda_init(0); + ASSERT_NE(device, nullptr); + vm_pool pool {*device}; + std::size_t actual_sz = 0; + auto* ptr = pool.alloc(1024, 32, actual_sz); + ASSERT_NE(ptr, nullptr); + ASSERT_EQ(reinterpret_cast(ptr) % 32, 0); + std::cout << "Allocated: " << actual_sz << " bytes" << std::endl; + pool.free(ptr, actual_sz); +} + +#endif diff --git a/test/googletest/.clang-format b/test/googletest/.clang-format new file mode 100644 index 0000000..5b9bfe6 --- /dev/null +++ b/test/googletest/.clang-format @@ -0,0 +1,4 @@ +# Run manually to reformat a file: +# clang-format -i --style=file +Language: Cpp +BasedOnStyle: Google diff --git a/test/googletest/.github/ISSUE_TEMPLATE/00-bug_report.yml b/test/googletest/.github/ISSUE_TEMPLATE/00-bug_report.yml new file mode 100644 index 0000000..586779a --- /dev/null +++ b/test/googletest/.github/ISSUE_TEMPLATE/00-bug_report.yml @@ -0,0 +1,53 @@ +name: Bug Report +description: Let us know that something does not work as expected. +title: "[Bug]: Please title this bug report" +body: + - type: textarea + id: what-happened + attributes: + label: Describe the issue + description: What happened, and what did you expect to happen? + validations: + required: true + - type: textarea + id: steps + attributes: + label: Steps to reproduce the problem + description: It is important that we are able to reproduce the problem that you are experiencing. Please provide all code and relevant steps to reproduce the problem, including your `BUILD`/`CMakeLists.txt` file and build commands. Links to a GitHub branch or [godbolt.org](https://godbolt.org/) that demonstrate the problem are also helpful. + validations: + required: true + - type: textarea + id: version + attributes: + label: What version of GoogleTest are you using? + description: Please include the output of `git rev-parse HEAD` or the GoogleTest release version number that you are using. + validations: + required: true + - type: textarea + id: os + attributes: + label: What operating system and version are you using? + description: If you are using a Linux distribution please include the name and version of the distribution as well. + validations: + required: true + - type: textarea + id: compiler + attributes: + label: What compiler and version are you using? + description: Please include the output of `gcc -v` or `clang -v`, or the equivalent for your compiler. 
+ validations: + required: true + - type: textarea + id: buildsystem + attributes: + label: What build system are you using? + description: Please include the output of `bazel --version` or `cmake --version`, or the equivalent for your build system. + validations: + required: true + - type: textarea + id: additional + attributes: + label: Additional context + description: Add any other context about the problem here. + validations: + required: false diff --git a/test/googletest/.github/ISSUE_TEMPLATE/10-feature_request.yml b/test/googletest/.github/ISSUE_TEMPLATE/10-feature_request.yml new file mode 100644 index 0000000..f3bbc09 --- /dev/null +++ b/test/googletest/.github/ISSUE_TEMPLATE/10-feature_request.yml @@ -0,0 +1,33 @@ +name: Feature request +description: Propose a new feature. +title: "[FR]: Please title this feature request" +labels: "enhancement" +body: + - type: textarea + id: version + attributes: + label: Does the feature exist in the most recent commit? + description: We recommend using the latest commit from GitHub in your projects. + validations: + required: true + - type: textarea + id: why + attributes: + label: Why do we need this feature? + description: Ideally, explain why a combination of existing features cannot be used instead. + validations: + required: true + - type: textarea + id: proposal + attributes: + label: Describe the proposal. + description: Include a detailed description of the feature, with usage examples. + validations: + required: true + - type: textarea + id: platform + attributes: + label: Is the feature specific to an operating system, compiler, or build system version? + description: If it is, please specify which versions. + validations: + required: true diff --git a/test/googletest/.github/ISSUE_TEMPLATE/config.yml b/test/googletest/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..65170d1 --- /dev/null +++ b/test/googletest/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: Get Help + url: https://github.com/google/googletest/discussions + about: Please ask and answer questions here. diff --git a/test/googletest/.gitignore b/test/googletest/.gitignore new file mode 100644 index 0000000..f0df39d --- /dev/null +++ b/test/googletest/.gitignore @@ -0,0 +1,89 @@ +# Ignore CI build directory +build/ +xcuserdata +cmake-build-debug/ +.idea/ +bazel-bin +bazel-genfiles +bazel-googletest +bazel-out +bazel-testlogs +MODULE.bazel.lock +# python +*.pyc + +# Visual Studio files +.vs +*.sdf +*.opensdf +*.VC.opendb +*.suo +*.user +_ReSharper.Caches/ +Win32-Debug/ +Win32-Release/ +x64-Debug/ +x64-Release/ + +# VSCode files +.cache/ +cmake-variants.yaml + +# Ignore autoconf / automake files +Makefile.in +aclocal.m4 +configure +build-aux/ +autom4te.cache/ +googletest/m4/libtool.m4 +googletest/m4/ltoptions.m4 +googletest/m4/ltsugar.m4 +googletest/m4/ltversion.m4 +googletest/m4/lt~obsolete.m4 +googlemock/m4 + +# Ignore generated directories. +googlemock/fused-src/ +googletest/fused-src/ + +# macOS files +.DS_Store +googletest/.DS_Store +googletest/xcode/.DS_Store + +# Ignore cmake generated directories and files. 
+CMakeFiles +CTestTestfile.cmake +Makefile +cmake_install.cmake +googlemock/CMakeFiles +googlemock/CTestTestfile.cmake +googlemock/Makefile +googlemock/cmake_install.cmake +googlemock/gtest +/bin +/googlemock/gmock.dir +/googlemock/gmock_main.dir +/googlemock/RUN_TESTS.vcxproj.filters +/googlemock/RUN_TESTS.vcxproj +/googlemock/INSTALL.vcxproj.filters +/googlemock/INSTALL.vcxproj +/googlemock/gmock_main.vcxproj.filters +/googlemock/gmock_main.vcxproj +/googlemock/gmock.vcxproj.filters +/googlemock/gmock.vcxproj +/googlemock/gmock.sln +/googlemock/ALL_BUILD.vcxproj.filters +/googlemock/ALL_BUILD.vcxproj +/lib +/Win32 +/ZERO_CHECK.vcxproj.filters +/ZERO_CHECK.vcxproj +/RUN_TESTS.vcxproj.filters +/RUN_TESTS.vcxproj +/INSTALL.vcxproj.filters +/INSTALL.vcxproj +/googletest-distribution.sln +/CMakeCache.txt +/ALL_BUILD.vcxproj.filters +/ALL_BUILD.vcxproj diff --git a/test/googletest/BUILD.bazel b/test/googletest/BUILD.bazel new file mode 100644 index 0000000..0306468 --- /dev/null +++ b/test/googletest/BUILD.bazel @@ -0,0 +1,236 @@ +# Copyright 2017 Google Inc. +# All Rights Reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Bazel Build for Google C++ Testing Framework(Google Test) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +config_setting( + name = "qnx", + constraint_values = ["@platforms//os:qnx"], +) + +config_setting( + name = "windows", + constraint_values = ["@platforms//os:windows"], +) + +config_setting( + name = "freebsd", + constraint_values = ["@platforms//os:freebsd"], +) + +config_setting( + name = "openbsd", + constraint_values = ["@platforms//os:openbsd"], +) + +# NOTE: Fuchsia is not an officially supported platform. +config_setting( + name = "fuchsia", + constraint_values = ["@platforms//os:fuchsia"], +) + +config_setting( + name = "msvc_compiler", + flag_values = { + "@bazel_tools//tools/cpp:compiler": "msvc-cl", + }, + visibility = [":__subpackages__"], +) + +config_setting( + name = "has_absl", + values = {"define": "absl=1"}, +) + +# Library that defines the FRIEND_TEST macro. 
+cc_library( + name = "gtest_prod", + hdrs = ["googletest/include/gtest/gtest_prod.h"], + includes = ["googletest/include"], +) + +# Google Test including Google Mock +cc_library( + name = "gtest", + srcs = glob( + include = [ + "googletest/src/*.cc", + "googletest/src/*.h", + "googletest/include/gtest/**/*.h", + "googlemock/src/*.cc", + "googlemock/include/gmock/**/*.h", + ], + exclude = [ + "googletest/src/gtest-all.cc", + "googletest/src/gtest_main.cc", + "googlemock/src/gmock-all.cc", + "googlemock/src/gmock_main.cc", + ], + ), + hdrs = glob([ + "googletest/include/gtest/*.h", + "googlemock/include/gmock/*.h", + ]), + copts = select({ + ":qnx": [], + ":windows": [], + "//conditions:default": ["-pthread"], + }), + defines = select({ + ":has_absl": ["GTEST_HAS_ABSL=1"], + "//conditions:default": [], + }), + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + includes = [ + "googlemock", + "googlemock/include", + "googletest", + "googletest/include", + ], + linkopts = select({ + ":qnx": ["-lregex"], + ":windows": [], + ":freebsd": [ + "-lm", + "-pthread", + ], + ":openbsd": [ + "-lm", + "-pthread", + ], + "//conditions:default": ["-pthread"], + }), + deps = select({ + ":has_absl": [ + "@abseil-cpp//absl/container:flat_hash_set", + "@abseil-cpp//absl/debugging:failure_signal_handler", + "@abseil-cpp//absl/debugging:stacktrace", + "@abseil-cpp//absl/debugging:symbolize", + "@abseil-cpp//absl/flags:flag", + "@abseil-cpp//absl/flags:parse", + "@abseil-cpp//absl/flags:reflection", + "@abseil-cpp//absl/flags:usage", + "@abseil-cpp//absl/strings", + "@abseil-cpp//absl/types:any", + "@abseil-cpp//absl/types:optional", + "@abseil-cpp//absl/types:variant", + "@re2//:re2", + ], + "//conditions:default": [], + }) + select({ + # `gtest-death-test.cc` has `EXPECT_DEATH` that spawns a process, + # expects it to crash and inspects its logs with the given matcher, + # so that's why these libraries are needed. + # Otherwise, builds targeting Fuchsia would fail to compile. + ":fuchsia": [ + "@fuchsia_sdk//pkg/fdio", + "@fuchsia_sdk//pkg/syslog", + "@fuchsia_sdk//pkg/zx", + ], + "//conditions:default": [], + }), +) + +cc_library( + name = "gtest_main", + srcs = ["googlemock/src/gmock_main.cc"], + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), + deps = [":gtest"], +) + +# The following rules build samples of how to use gTest. 
+cc_library( + name = "gtest_sample_lib", + srcs = [ + "googletest/samples/sample1.cc", + "googletest/samples/sample2.cc", + "googletest/samples/sample4.cc", + ], + hdrs = [ + "googletest/samples/prime_tables.h", + "googletest/samples/sample1.h", + "googletest/samples/sample2.h", + "googletest/samples/sample3-inl.h", + "googletest/samples/sample4.h", + ], + features = select({ + ":windows": ["windows_export_all_symbols"], + "//conditions:default": [], + }), +) + +cc_test( + name = "gtest_samples", + size = "small", + # All Samples except: + # sample9 (main) + # sample10 (main and takes a command line option and needs to be separate) + srcs = [ + "googletest/samples/sample1_unittest.cc", + "googletest/samples/sample2_unittest.cc", + "googletest/samples/sample3_unittest.cc", + "googletest/samples/sample4_unittest.cc", + "googletest/samples/sample5_unittest.cc", + "googletest/samples/sample6_unittest.cc", + "googletest/samples/sample7_unittest.cc", + "googletest/samples/sample8_unittest.cc", + ], + linkstatic = 0, + deps = [ + "gtest_sample_lib", + ":gtest_main", + ], +) + +cc_test( + name = "sample9_unittest", + size = "small", + srcs = ["googletest/samples/sample9_unittest.cc"], + deps = [":gtest"], +) + +cc_test( + name = "sample10_unittest", + size = "small", + srcs = ["googletest/samples/sample10_unittest.cc"], + deps = [":gtest"], +) diff --git a/test/googletest/CMakeLists.txt b/test/googletest/CMakeLists.txt new file mode 100644 index 0000000..512e5c3 --- /dev/null +++ b/test/googletest/CMakeLists.txt @@ -0,0 +1,36 @@ +# Note: CMake support is community-based. The maintainers do not use CMake +# internally. + +cmake_minimum_required(VERSION 3.13) + +project(googletest-distribution) +set(GOOGLETEST_VERSION 1.15.2) + +if(NOT CYGWIN AND NOT MSYS AND NOT ${CMAKE_SYSTEM_NAME} STREQUAL QNX) + set(CMAKE_CXX_EXTENSIONS OFF) +endif() + +enable_testing() + +include(CMakeDependentOption) +include(GNUInstallDirs) + +# Note that googlemock target already builds googletest. +option(BUILD_GMOCK "Builds the googlemock subproject" ON) +option(INSTALL_GTEST "Enable installation of googletest. (Projects embedding googletest may want to turn this OFF.)" ON) +option(GTEST_HAS_ABSL "Use Abseil and RE2. Requires Abseil and RE2 to be separately added to the build." OFF) + +if(GTEST_HAS_ABSL) + if(NOT TARGET absl::base) + find_package(absl REQUIRED) + endif() + if(NOT TARGET re2::re2) + find_package(re2 REQUIRED) + endif() +endif() + +if(BUILD_GMOCK) + add_subdirectory( googlemock ) +else() + add_subdirectory( googletest ) +endif() diff --git a/test/googletest/CONTRIBUTING.md b/test/googletest/CONTRIBUTING.md new file mode 100644 index 0000000..ab5a47b --- /dev/null +++ b/test/googletest/CONTRIBUTING.md @@ -0,0 +1,141 @@ +# How to become a contributor and submit your own code + +## Contributor License Agreements + +We'd love to accept your patches! Before we can take them, we have to jump a +couple of legal hurdles. + +Please fill out either the individual or corporate Contributor License Agreement +(CLA). + +* If you are an individual writing original source code and you're sure you + own the intellectual property, then you'll need to sign an + [individual CLA](https://developers.google.com/open-source/cla/individual). +* If you work for a company that wants to allow you to contribute your work, + then you'll need to sign a + [corporate CLA](https://developers.google.com/open-source/cla/corporate). 
+ +Follow either of the two links above to access the appropriate CLA and +instructions for how to sign and return it. Once we receive it, we'll be able to +accept your pull requests. + +## Are you a Googler? + +If you are a Googler, please make an attempt to submit an internal contribution +rather than a GitHub Pull Request. If you are not able to submit internally, a +PR is acceptable as an alternative. + +## Contributing A Patch + +1. Submit an issue describing your proposed change to the + [issue tracker](https://github.com/google/googletest/issues). +2. Please don't mix more than one logical change per submittal, because it + makes the history hard to follow. If you want to make a change that doesn't + have a corresponding issue in the issue tracker, please create one. +3. Also, coordinate with team members that are listed on the issue in question. + This ensures that work isn't being duplicated and communicating your plan + early also generally leads to better patches. +4. If your proposed change is accepted, and you haven't already done so, sign a + Contributor License Agreement + ([see details above](#contributor-license-agreements)). +5. Fork the desired repo, develop and test your code changes. +6. Ensure that your code adheres to the existing style in the sample to which + you are contributing. +7. Ensure that your code has an appropriate set of unit tests which all pass. +8. Submit a pull request. + +## The Google Test and Google Mock Communities + +The Google Test community exists primarily through the +[discussion group](https://groups.google.com/group/googletestframework) and the +GitHub repository. Likewise, the Google Mock community exists primarily through +their own [discussion group](https://groups.google.com/group/googlemock). You +are definitely encouraged to contribute to the discussion and you can also help +us to keep the effectiveness of the group high by following and promoting the +guidelines listed here. + +### Please Be Friendly + +Showing courtesy and respect to others is a vital part of the Google culture, +and we strongly encourage everyone participating in Google Test development to +join us in accepting nothing less. Of course, being courteous is not the same as +failing to constructively disagree with each other, but it does mean that we +should be respectful of each other when enumerating the 42 technical reasons +that a particular proposal may not be the best choice. There's never a reason to +be antagonistic or dismissive toward anyone who is sincerely trying to +contribute to a discussion. + +Sure, C++ testing is serious business and all that, but it's also a lot of fun. +Let's keep it that way. Let's strive to be one of the friendliest communities in +all of open source. + +As always, discuss Google Test in the official GoogleTest discussion group. You +don't have to actually submit code in order to sign up. Your participation +itself is a valuable contribution. + +## Style + +To keep the source consistent, readable, diffable and easy to merge, we use a +fairly rigid coding style, as defined by the +[google-styleguide](https://github.com/google/styleguide) project. All patches +will be expected to conform to the style outlined +[here](https://google.github.io/styleguide/cppguide.html). Use +[.clang-format](https://github.com/google/googletest/blob/main/.clang-format) to +check your formatting. 
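+
+For instance, a single file can be reformatted in place against the
+repository's `.clang-format` like so (a sketch; the path is only an example):
+
+```
+clang-format -i --style=file googletest/src/gtest.cc
+```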
+ +## Requirements for Contributors + +If you plan to contribute a patch, you need to build Google Test, Google Mock, +and their own tests from a git checkout, which has further requirements: + +* [Python](https://www.python.org/) v3.6 or newer (for running some of the + tests and re-generating certain source files from templates) +* [CMake](https://cmake.org/) v2.8.12 or newer + +## Developing Google Test and Google Mock + +This section discusses how to make your own changes to the Google Test project. + +### Testing Google Test and Google Mock Themselves + +To make sure your changes work as intended and don't break existing +functionality, you'll want to compile and run Google Test and GoogleMock's own +tests. For that you can use CMake: + +``` +mkdir mybuild +cd mybuild +cmake -Dgtest_build_tests=ON -Dgmock_build_tests=ON ${GTEST_REPO_DIR} +``` + +To choose between building only Google Test or Google Mock, you may modify your +cmake command to be one of each + +``` +cmake -Dgtest_build_tests=ON ${GTEST_DIR} # sets up Google Test tests +cmake -Dgmock_build_tests=ON ${GMOCK_DIR} # sets up Google Mock tests +``` + +Make sure you have Python installed, as some of Google Test's tests are written +in Python. If the cmake command complains about not being able to find Python +(`Could NOT find PythonInterp (missing: PYTHON_EXECUTABLE)`), try telling it +explicitly where your Python executable can be found: + +``` +cmake -DPYTHON_EXECUTABLE=path/to/python ... +``` + +Next, you can build Google Test and / or Google Mock and all desired tests. On +\*nix, this is usually done by + +``` +make +``` + +To run the tests, do + +``` +make test +``` + +All tests should pass. diff --git a/test/googletest/CONTRIBUTORS b/test/googletest/CONTRIBUTORS new file mode 100644 index 0000000..ccea41e --- /dev/null +++ b/test/googletest/CONTRIBUTORS @@ -0,0 +1,66 @@ +# This file contains a list of people who've made non-trivial +# contribution to the Google C++ Testing Framework project. People +# who commit code to the project are encouraged to add their names +# here. Please keep the list sorted by first names. + +Ajay Joshi +Balázs Dán +Benoit Sigoure +Bharat Mediratta +Bogdan Piloca +Chandler Carruth +Chris Prince +Chris Taylor +Dan Egnor +Dave MacLachlan +David Anderson +Dean Sturtevant +Eric Roman +Gene Volovich +Hady Zalek +Hal Burch +Jeffrey Yasskin +Jim Keller +Joe Walnes +Jon Wray +Jói Sigurðsson +Keir Mierle +Keith Ray +Kenton Varda +Kostya Serebryany +Krystian Kuzniarek +Lev Makhlis +Manuel Klimek +Mario Tanev +Mark Paskin +Markus Heule +Martijn Vels +Matthew Simmons +Mika Raento +Mike Bland +Miklós Fazekas +Neal Norwitz +Nermin Ozkiranartli +Owen Carlsen +Paneendra Ba +Pasi Valminen +Patrick Hanna +Patrick Riley +Paul Menage +Peter Kaminski +Piotr Kaminski +Preston Jackson +Rainer Klaffenboeck +Russ Cox +Russ Rufer +Sean Mcafee +Sigurður Ásgeirsson +Soyeon Kim +Sverre Sundsdal +Szymon Sobik +Takeshi Yoshino +Tracy Bialik +Vadim Berman +Vlad Losev +Wolfgang Klier +Zhanyong Wan diff --git a/test/googletest/LICENSE b/test/googletest/LICENSE new file mode 100644 index 0000000..1941a11 --- /dev/null +++ b/test/googletest/LICENSE @@ -0,0 +1,28 @@ +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. 
+ * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/test/googletest/MODULE.bazel b/test/googletest/MODULE.bazel new file mode 100644 index 0000000..c9a52e0 --- /dev/null +++ b/test/googletest/MODULE.bazel @@ -0,0 +1,67 @@ +# Copyright 2024 Google Inc. +# All Rights Reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# https://bazel.build/external/overview#bzlmod + +module( + name = "googletest", + version = "head", + compatibility_level = 1, +) + +# Only direct dependencies need to be listed below. +# Please keep the versions in sync with the versions in the WORKSPACE file. 
+ +bazel_dep(name = "abseil-cpp", + version = "20240116.2") + +bazel_dep(name = "platforms", + version = "0.0.10") + +bazel_dep(name = "re2", + version = "2024-07-02") + +bazel_dep(name = "rules_python", + version = "0.34.0", + dev_dependency = True) + +# https://rules-python.readthedocs.io/en/stable/toolchains.html#library-modules-with-dev-only-python-usage +python = use_extension( + "@rules_python//python/extensions:python.bzl", + "python", + dev_dependency = True +) + +python.toolchain(python_version = "3.12", + is_default = True, + ignore_root_user_error = True) + +fake_fuchsia_sdk = use_repo_rule("//:fake_fuchsia_sdk.bzl", "fake_fuchsia_sdk") +fake_fuchsia_sdk(name = "fuchsia_sdk") diff --git a/test/googletest/README.md b/test/googletest/README.md new file mode 100644 index 0000000..03c70a1 --- /dev/null +++ b/test/googletest/README.md @@ -0,0 +1,142 @@ +# GoogleTest + +### Announcements + +#### Live at Head + +GoogleTest now follows the +[Abseil Live at Head philosophy](https://abseil.io/about/philosophy#upgrade-support). +We recommend +[updating to the latest commit in the `main` branch as often as possible](https://github.com/abseil/abseil-cpp/blob/master/FAQ.md#what-is-live-at-head-and-how-do-i-do-it). +We do publish occasional semantic versions, tagged with +`v${major}.${minor}.${patch}` (e.g. `v1.15.2`). + +#### Documentation Updates + +Our documentation is now live on GitHub Pages at +https://google.github.io/googletest/. We recommend browsing the documentation on +GitHub Pages rather than directly in the repository. + +#### Release 1.15.2 + +[Release 1.15.2](https://github.com/google/googletest/releases/tag/v1.15.2) is +now available. + +The 1.15.x branch requires at least C++14. + +#### Continuous Integration + +We use Google's internal systems for continuous integration. + +#### Coming Soon + +* We are planning to take a dependency on + [Abseil](https://github.com/abseil/abseil-cpp). + +## Welcome to **GoogleTest**, Google's C++ test framework! + +This repository is a merger of the formerly separate GoogleTest and GoogleMock +projects. These were so closely related that it makes sense to maintain and +release them together. + +### Getting Started + +See the [GoogleTest User's Guide](https://google.github.io/googletest/) for +documentation. We recommend starting with the +[GoogleTest Primer](https://google.github.io/googletest/primer.html). + +More information about building GoogleTest can be found at +[googletest/README.md](googletest/README.md). 
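+
+As a minimal sketch of what a first test looks like (assuming a test target
+that links against the `GTest::gtest_main` CMake target, which supplies
+`main()`):
+
+```c++
+#include <gtest/gtest.h>
+
+// Two basic assertions; no main() is needed when linking gtest_main.
+TEST(HelloTest, BasicAssertions) {
+  EXPECT_STRNE("hello", "world");  // Expect two strings not to be equal.
+  EXPECT_EQ(7 * 6, 42);            // Expect equality.
+}
+```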
+ +## Features + +* xUnit test framework: \ + Googletest is based on the [xUnit](https://en.wikipedia.org/wiki/XUnit) + testing framework, a popular architecture for unit testing +* Test discovery: \ + Googletest automatically discovers and runs your tests, eliminating the need + to manually register your tests +* Rich set of assertions: \ + Googletest provides a variety of assertions, such as equality, inequality, + exceptions, and more, making it easy to test your code +* User-defined assertions: \ + You can define your own assertions with Googletest, making it simple to + write tests that are specific to your code +* Death tests: \ + Googletest supports death tests, which verify that your code exits in a + certain way, making it useful for testing error-handling code +* Fatal and non-fatal failures: \ + You can specify whether a test failure should be treated as fatal or + non-fatal with Googletest, allowing tests to continue running even if a + failure occurs +* Value-parameterized tests: \ + Googletest supports value-parameterized tests, which run multiple times with + different input values, making it useful for testing functions that take + different inputs +* Type-parameterized tests: \ + Googletest also supports type-parameterized tests, which run with different + data types, making it useful for testing functions that work with different + data types +* Various options for running tests: \ + Googletest provides many options for running tests including running + individual tests, running tests in a specific order and running tests in + parallel + +## Supported Platforms + +GoogleTest follows Google's +[Foundational C++ Support Policy](https://opensource.google/documentation/policies/cplusplus-support). +See +[this table](https://github.com/google/oss-policies-info/blob/main/foundational-cxx-support-matrix.md) +for a list of currently supported versions of compilers, platforms, and build +tools. + +## Who Is Using GoogleTest? + +In addition to many internal projects at Google, GoogleTest is also used by the +following notable projects: + +* The [Chromium projects](https://www.chromium.org/) (behind the Chrome + browser and Chrome OS). +* The [LLVM](https://llvm.org/) compiler. +* [Protocol Buffers](https://github.com/google/protobuf), Google's data + interchange format. +* The [OpenCV](https://opencv.org/) computer vision library. + +## Related Open Source Projects + +[GTest Runner](https://github.com/nholthaus/gtest-runner) is a Qt5 based +automated test-runner and Graphical User Interface with powerful features for +Windows and Linux platforms. + +[GoogleTest UI](https://github.com/ospector/gtest-gbar) is a test runner that +runs your test binary, allows you to track its progress via a progress bar, and +displays a list of test failures. Clicking on one shows failure text. GoogleTest +UI is written in C#. + +[GTest TAP Listener](https://github.com/kinow/gtest-tap-listener) is an event +listener for GoogleTest that implements the +[TAP protocol](https://en.wikipedia.org/wiki/Test_Anything_Protocol) for test +result output. If your test runner understands TAP, you may find it useful. + +[gtest-parallel](https://github.com/google/gtest-parallel) is a test runner that +runs tests from your binary in parallel to provide significant speed-up. + +[GoogleTest Adapter](https://marketplace.visualstudio.com/items?itemName=DavidSchuldenfrei.gtest-adapter) +is a VS Code extension allowing to view GoogleTest in a tree view and run/debug +your tests. 
+ +[C++ TestMate](https://github.com/matepek/vscode-catch2-test-adapter) is a VS +Code extension allowing to view GoogleTest in a tree view and run/debug your +tests. + +[Cornichon](https://pypi.org/project/cornichon/) is a small Gherkin DSL parser +that generates stub code for GoogleTest. + +## Contributing Changes + +Please read +[`CONTRIBUTING.md`](https://github.com/google/googletest/blob/main/CONTRIBUTING.md) +for details on how to contribute to this project. + +Happy testing! diff --git a/test/googletest/WORKSPACE b/test/googletest/WORKSPACE new file mode 100644 index 0000000..63f7681 --- /dev/null +++ b/test/googletest/WORKSPACE @@ -0,0 +1,62 @@ +# Copyright 2024 Google Inc. +# All Rights Reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +workspace(name = "googletest") + +load("//:googletest_deps.bzl", "googletest_deps") +googletest_deps() + +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") + +http_archive( + name = "rules_python", + sha256 = "d71d2c67e0bce986e1c5a7731b4693226867c45bfe0b7c5e0067228a536fc580", + strip_prefix = "rules_python-0.29.0", + urls = ["https://github.com/bazelbuild/rules_python/releases/download/0.29.0/rules_python-0.29.0.tar.gz"], +) + +# https://github.com/bazelbuild/rules_python/releases/tag/0.29.0 +load("@rules_python//python:repositories.bzl", "py_repositories") +py_repositories() + +http_archive( + name = "bazel_skylib", + sha256 = "cd55a062e763b9349921f0f5db8c3933288dc8ba4f76dd9416aac68acee3cb94", + urls = ["https://github.com/bazelbuild/bazel-skylib/releases/download/1.5.0/bazel-skylib-1.5.0.tar.gz"], +) + +http_archive( + name = "platforms", + urls = [ + "https://mirror.bazel.build/github.com/bazelbuild/platforms/releases/download/0.0.10/platforms-0.0.10.tar.gz", + "https://github.com/bazelbuild/platforms/releases/download/0.0.10/platforms-0.0.10.tar.gz", + ], + sha256 = "218efe8ee736d26a3572663b374a253c012b716d8af0c07e842e82f238a0a7ee", +) diff --git a/test/googletest/WORKSPACE.bzlmod b/test/googletest/WORKSPACE.bzlmod new file mode 100644 index 0000000..381432c --- /dev/null +++ b/test/googletest/WORKSPACE.bzlmod @@ -0,0 +1,35 @@ +# Copyright 2024 Google Inc. +# All Rights Reserved. +# +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# https://bazel.build/external/migration#workspace.bzlmod +# +# This file is intentionally empty. When bzlmod is enabled and this +# file exists, the content of WORKSPACE is ignored. This prevents +# bzlmod builds from unintentionally depending on the WORKSPACE file. diff --git a/test/googletest/ci/linux-presubmit.sh b/test/googletest/ci/linux-presubmit.sh new file mode 100644 index 0000000..6d2b3fb --- /dev/null +++ b/test/googletest/ci/linux-presubmit.sh @@ -0,0 +1,139 @@ +#!/bin/bash +# +# Copyright 2020, Google Inc. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set -euox pipefail + +readonly LINUX_LATEST_CONTAINER="gcr.io/google.com/absl-177019/linux_hybrid-latest:20240523" +readonly LINUX_GCC_FLOOR_CONTAINER="gcr.io/google.com/absl-177019/linux_gcc-floor:20230120" + +if [[ -z ${GTEST_ROOT:-} ]]; then + GTEST_ROOT="$(realpath $(dirname ${0})/..)" +fi + +if [[ -z ${STD:-} ]]; then + STD="c++14 c++17 c++20" +fi + +# Test the CMake build +for cc in /usr/local/bin/gcc /opt/llvm/clang/bin/clang; do + for cmake_off_on in OFF ON; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --tmpfs="/build:exec" \ + --workdir="/build" \ + --rm \ + --env="CC=${cc}" \ + --env=CXXFLAGS="-Werror -Wdeprecated" \ + ${LINUX_LATEST_CONTAINER} \ + /bin/bash -c " + cmake /src \ + -DCMAKE_CXX_STANDARD=14 \ + -Dgtest_build_samples=ON \ + -Dgtest_build_tests=ON \ + -Dgmock_build_tests=ON \ + -Dcxx_no_exception=${cmake_off_on} \ + -Dcxx_no_rtti=${cmake_off_on} && \ + make -j$(nproc) && \ + ctest -j$(nproc) --output-on-failure" + done +done + +# Do one test with an older version of GCC +# TODO(googletest-team): This currently uses Bazel 5. When upgrading to a +# version of Bazel that supports Bzlmod, add --enable_bzlmod=false to keep test +# coverage for the old WORKSPACE dependency management. +time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/usr/local/bin/gcc" \ + --env="BAZEL_CXXOPTS=-std=c++14" \ + ${LINUX_GCC_FLOOR_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wundef" \ + --copt="-Wno-error=pragmas" \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors + +# Test GCC +for std in ${STD}; do + for absl in 0 1; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/usr/local/bin/gcc" \ + --env="BAZEL_CXXOPTS=-std=${std}" \ + ${LINUX_LATEST_CONTAINER} \ + /usr/local/bin/bazel test ... 
\ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wundef" \ + --define="absl=${absl}" \ + --enable_bzlmod=true \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors + done +done + +# Test Clang +for std in ${STD}; do + for absl in 0 1; do + time docker run \ + --volume="${GTEST_ROOT}:/src:ro" \ + --workdir="/src" \ + --rm \ + --env="CC=/opt/llvm/clang/bin/clang" \ + --env="BAZEL_CXXOPTS=-std=${std}" \ + ${LINUX_LATEST_CONTAINER} \ + /usr/local/bin/bazel test ... \ + --copt="--gcc-toolchain=/usr/local" \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wuninitialized" \ + --copt="-Wundef" \ + --define="absl=${absl}" \ + --enable_bzlmod=true \ + --features=external_include_paths \ + --keep_going \ + --linkopt="--gcc-toolchain=/usr/local" \ + --show_timestamps \ + --test_output=errors + done +done diff --git a/test/googletest/ci/macos-presubmit.sh b/test/googletest/ci/macos-presubmit.sh new file mode 100644 index 0000000..70eaa74 --- /dev/null +++ b/test/googletest/ci/macos-presubmit.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# +# Copyright 2020, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set -euox pipefail + +if [[ -z ${GTEST_ROOT:-} ]]; then + GTEST_ROOT="$(realpath $(dirname ${0})/..)" +fi + +# Test the CMake build +for cmake_off_on in OFF ON; do + BUILD_DIR=$(mktemp -d build_dir.XXXXXXXX) + cd ${BUILD_DIR} + time cmake ${GTEST_ROOT} \ + -DCMAKE_CXX_STANDARD=14 \ + -Dgtest_build_samples=ON \ + -Dgtest_build_tests=ON \ + -Dgmock_build_tests=ON \ + -Dcxx_no_exception=${cmake_off_on} \ + -Dcxx_no_rtti=${cmake_off_on} + time make + time ctest -j$(nproc) --output-on-failure +done + +# Test the Bazel build + +# If we are running on Kokoro, check for a versioned Bazel binary. 
+KOKORO_GFILE_BAZEL_BIN="bazel-7.0.0-darwin-x86_64" +if [[ ${KOKORO_GFILE_DIR:-} ]] && [[ -f ${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN} ]]; then + BAZEL_BIN="${KOKORO_GFILE_DIR}/${KOKORO_GFILE_BAZEL_BIN}" + chmod +x ${BAZEL_BIN} +else + BAZEL_BIN="bazel" +fi + +cd ${GTEST_ROOT} +for absl in 0 1; do + ${BAZEL_BIN} test ... \ + --copt="-Wall" \ + --copt="-Werror" \ + --copt="-Wundef" \ + --cxxopt="-std=c++14" \ + --define="absl=${absl}" \ + --enable_bzlmod=true \ + --features=external_include_paths \ + --keep_going \ + --show_timestamps \ + --test_output=errors +done diff --git a/test/googletest/ci/windows-presubmit.bat b/test/googletest/ci/windows-presubmit.bat new file mode 100644 index 0000000..1adc1a1 --- /dev/null +++ b/test/googletest/ci/windows-presubmit.bat @@ -0,0 +1,63 @@ +SETLOCAL ENABLEDELAYEDEXPANSION + +SET BAZEL_EXE=%KOKORO_GFILE_DIR%\bazel-7.0.0-windows-x86_64.exe + +SET PATH=C:\Python34;%PATH% +SET BAZEL_PYTHON=C:\python34\python.exe +SET BAZEL_SH=C:\tools\msys64\usr\bin\bash.exe +SET CMAKE_BIN="cmake.exe" +SET CTEST_BIN="ctest.exe" +SET CTEST_OUTPUT_ON_FAILURE=1 +SET CMAKE_BUILD_PARALLEL_LEVEL=16 +SET CTEST_PARALLEL_LEVEL=16 + +IF EXIST git\googletest ( + CD git\googletest +) ELSE IF EXIST github\googletest ( + CD github\googletest +) + +IF %errorlevel% neq 0 EXIT /B 1 + +:: ---------------------------------------------------------------------------- +:: CMake +MKDIR cmake_msvc2022 +CD cmake_msvc2022 + +%CMAKE_BIN% .. ^ + -G "Visual Studio 17 2022" ^ + -DPYTHON_EXECUTABLE:FILEPATH=c:\python37\python.exe ^ + -DPYTHON_INCLUDE_DIR:PATH=c:\python37\include ^ + -DPYTHON_LIBRARY:FILEPATH=c:\python37\lib\site-packages\pip ^ + -Dgtest_build_samples=ON ^ + -Dgtest_build_tests=ON ^ + -Dgmock_build_tests=ON +IF %errorlevel% neq 0 EXIT /B 1 + +%CMAKE_BIN% --build . --target ALL_BUILD --config Debug -- -maxcpucount +IF %errorlevel% neq 0 EXIT /B 1 + +%CTEST_BIN% -C Debug --timeout 600 +IF %errorlevel% neq 0 EXIT /B 1 + +CD .. +RMDIR /S /Q cmake_msvc2022 + +:: ---------------------------------------------------------------------------- +:: Bazel + +:: The default home directory on Kokoro is a long path which causes errors +:: because of Windows limitations on path length. +:: --output_user_root=C:\tmp causes Bazel to use a shorter path. +SET BAZEL_VS=C:\Program Files\Microsoft Visual Studio\2022\Community +%BAZEL_EXE% ^ + --output_user_root=C:\tmp ^ + test ... 
^ + --compilation_mode=dbg ^ + --copt=/std:c++14 ^ + --copt=/WX ^ + --enable_bzlmod=true ^ + --keep_going ^ + --test_output=errors ^ + --test_tag_filters=-no_test_msvc2017 +IF %errorlevel% neq 0 EXIT /B 1 diff --git a/test/googletest/docs/_config.yml b/test/googletest/docs/_config.yml new file mode 100644 index 0000000..d12867e --- /dev/null +++ b/test/googletest/docs/_config.yml @@ -0,0 +1 @@ +title: GoogleTest diff --git a/test/googletest/docs/_data/navigation.yml b/test/googletest/docs/_data/navigation.yml new file mode 100644 index 0000000..9f33327 --- /dev/null +++ b/test/googletest/docs/_data/navigation.yml @@ -0,0 +1,43 @@ +nav: +- section: "Get Started" + items: + - title: "Supported Platforms" + url: "/platforms.html" + - title: "Quickstart: Bazel" + url: "/quickstart-bazel.html" + - title: "Quickstart: CMake" + url: "/quickstart-cmake.html" +- section: "Guides" + items: + - title: "GoogleTest Primer" + url: "/primer.html" + - title: "Advanced Topics" + url: "/advanced.html" + - title: "Mocking for Dummies" + url: "/gmock_for_dummies.html" + - title: "Mocking Cookbook" + url: "/gmock_cook_book.html" + - title: "Mocking Cheat Sheet" + url: "/gmock_cheat_sheet.html" +- section: "References" + items: + - title: "Testing Reference" + url: "/reference/testing.html" + - title: "Mocking Reference" + url: "/reference/mocking.html" + - title: "Assertions" + url: "/reference/assertions.html" + - title: "Matchers" + url: "/reference/matchers.html" + - title: "Actions" + url: "/reference/actions.html" + - title: "Testing FAQ" + url: "/faq.html" + - title: "Mocking FAQ" + url: "/gmock_faq.html" + - title: "Code Samples" + url: "/samples.html" + - title: "Using pkg-config" + url: "/pkgconfig.html" + - title: "Community Documentation" + url: "/community_created_documentation.html" diff --git a/test/googletest/docs/_layouts/default.html b/test/googletest/docs/_layouts/default.html new file mode 100644 index 0000000..c7f331b --- /dev/null +++ b/test/googletest/docs/_layouts/default.html @@ -0,0 +1,58 @@ + + + + + + + +{% seo %} + + + + + + +
+<!-- Layout scaffolding (head tags, sidebar navigation, toggle script) was lost in extraction; only the content slot below survives. -->
+{{ content }}
+ + + + diff --git a/test/googletest/docs/_sass/main.scss b/test/googletest/docs/_sass/main.scss new file mode 100644 index 0000000..92edc87 --- /dev/null +++ b/test/googletest/docs/_sass/main.scss @@ -0,0 +1,200 @@ +// Styles for GoogleTest docs website on GitHub Pages. +// Color variables are defined in +// https://github.com/pages-themes/primer/tree/master/_sass/primer-support/lib/variables + +$sidebar-width: 260px; + +body { + display: flex; + margin: 0; +} + +.sidebar { + background: $black; + color: $text-white; + flex-shrink: 0; + height: 100vh; + overflow: auto; + position: sticky; + top: 0; + width: $sidebar-width; +} + +.sidebar h1 { + font-size: 1.5em; +} + +.sidebar h2 { + color: $gray-light; + font-size: 0.8em; + font-weight: normal; + margin-bottom: 0.8em; + padding-left: 2.5em; + text-transform: uppercase; +} + +.sidebar .header { + background: $black; + padding: 2em; + position: sticky; + top: 0; + width: 100%; +} + +.sidebar .header a { + color: $text-white; + text-decoration: none; +} + +.sidebar .nav-toggle { + display: none; +} + +.sidebar .expander { + cursor: pointer; + display: none; + height: 3em; + position: absolute; + right: 1em; + top: 1.5em; + width: 3em; +} + +.sidebar .expander .arrow { + border: solid $white; + border-width: 0 3px 3px 0; + display: block; + height: 0.7em; + margin: 1em auto; + transform: rotate(45deg); + transition: transform 0.5s; + width: 0.7em; +} + +.sidebar nav { + width: 100%; +} + +.sidebar nav ul { + list-style-type: none; + margin-bottom: 1em; + padding: 0; + + &:last-child { + margin-bottom: 2em; + } + + a { + text-decoration: none; + } + + li { + color: $text-white; + padding-left: 2em; + text-decoration: none; + } + + li.active { + background: $border-gray-darker; + font-weight: bold; + } + + li:hover { + background: $border-gray-darker; + } +} + +.main { + background-color: $bg-gray; + width: calc(100% - #{$sidebar-width}); +} + +.main .main-inner { + background-color: $white; + padding: 2em; +} + +.main .footer { + margin: 0; + padding: 2em; +} + +.main table th { + text-align: left; +} + +.main .callout { + border-left: 0.25em solid $white; + padding: 1em; + + a { + text-decoration: underline; + } + + &.important { + background-color: $bg-yellow-light; + border-color: $bg-yellow; + color: $black; + } + + &.note { + background-color: $bg-blue-light; + border-color: $text-blue; + color: $text-blue; + } + + &.tip { + background-color: $green-000; + border-color: $green-700; + color: $green-700; + } + + &.warning { + background-color: $red-000; + border-color: $text-red; + color: $text-red; + } +} + +.main .good pre { + background-color: $bg-green-light; +} + +.main .bad pre { + background-color: $red-000; +} + +@media all and (max-width: 768px) { + body { + flex-direction: column; + } + + .sidebar { + height: auto; + position: relative; + width: 100%; + } + + .sidebar .expander { + display: block; + } + + .sidebar nav { + height: 0; + overflow: hidden; + } + + .sidebar .nav-toggle:checked { + & ~ nav { + height: auto; + } + + & + .expander .arrow { + transform: rotate(-135deg); + } + } + + .main { + width: 100%; + } +} diff --git a/test/googletest/docs/advanced.md b/test/googletest/docs/advanced.md new file mode 100644 index 0000000..240588a --- /dev/null +++ b/test/googletest/docs/advanced.md @@ -0,0 +1,2446 @@ +# Advanced GoogleTest Topics + +## Introduction + +Now that you have read the [GoogleTest Primer](primer.md) and learned how to +write tests using GoogleTest, it's time to learn some new tricks. 
This document +will show you more assertions as well as how to construct complex failure +messages, propagate fatal failures, reuse and speed up your test fixtures, and +use various flags with your tests. + +## More Assertions + +This section covers some less frequently used, but still significant, +assertions. + +### Explicit Success and Failure + +See [Explicit Success and Failure](reference/assertions.md#success-failure) in +the Assertions Reference. + +### Exception Assertions + +See [Exception Assertions](reference/assertions.md#exceptions) in the Assertions +Reference. + +### Predicate Assertions for Better Error Messages + +Even though GoogleTest has a rich set of assertions, they can never be complete, +as it's impossible (nor a good idea) to anticipate all scenarios a user might +run into. Therefore, sometimes a user has to use `EXPECT_TRUE()` to check a +complex expression, for lack of a better macro. This has the problem of not +showing you the values of the parts of the expression, making it hard to +understand what went wrong. As a workaround, some users choose to construct the +failure message by themselves, streaming it into `EXPECT_TRUE()`. However, this +is awkward especially when the expression has side-effects or is expensive to +evaluate. + +GoogleTest gives you three different options to solve this problem: + +#### Using an Existing Boolean Function + +If you already have a function or functor that returns `bool` (or a type that +can be implicitly converted to `bool`), you can use it in a *predicate +assertion* to get the function arguments printed for free. See +[`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) in the Assertions +Reference for details. + +#### Using a Function That Returns an AssertionResult + +While `EXPECT_PRED*()` and friends are handy for a quick job, the syntax is not +satisfactory: you have to use different macros for different arities, and it +feels more like Lisp than C++. The `::testing::AssertionResult` class solves +this problem. + +An `AssertionResult` object represents the result of an assertion (whether it's +a success or a failure, and an associated message). You can create an +`AssertionResult` using one of these factory functions: + +```c++ +namespace testing { + +// Returns an AssertionResult object to indicate that an assertion has +// succeeded. +AssertionResult AssertionSuccess(); + +// Returns an AssertionResult object to indicate that an assertion has +// failed. +AssertionResult AssertionFailure(); + +} +``` + +You can then use the `<<` operator to stream messages to the `AssertionResult` +object. + +To provide more readable messages in Boolean assertions (e.g. `EXPECT_TRUE()`), +write a predicate function that returns `AssertionResult` instead of `bool`. 
For
+example, if you define `IsEven()` as:
+
+```c++
+testing::AssertionResult IsEven(int n) {
+  if ((n % 2) == 0)
+    return testing::AssertionSuccess();
+  else
+    return testing::AssertionFailure() << n << " is odd";
+}
+```
+
+instead of:
+
+```c++
+bool IsEven(int n) {
+  return (n % 2) == 0;
+}
+```
+
+the failed assertion `EXPECT_TRUE(IsEven(Fib(4)))` will print:
+
+```none
+Value of: IsEven(Fib(4))
+  Actual: false (3 is odd)
+Expected: true
+```
+
+instead of a more opaque
+
+```none
+Value of: IsEven(Fib(4))
+  Actual: false
+Expected: true
+```
+
+If you want informative messages in `EXPECT_FALSE` and `ASSERT_FALSE` as well
+(one third of Boolean assertions in the Google code base are negative ones), and
+are fine with making the predicate slower in the success case, you can supply a
+success message:
+
+```c++
+testing::AssertionResult IsEven(int n) {
+  if ((n % 2) == 0)
+    return testing::AssertionSuccess() << n << " is even";
+  else
+    return testing::AssertionFailure() << n << " is odd";
+}
+```
+
+Then the statement `EXPECT_FALSE(IsEven(Fib(6)))` will print
+
+```none
+  Value of: IsEven(Fib(6))
+     Actual: true (8 is even)
+  Expected: false
+```
+
+#### Using a Predicate-Formatter
+
+If you find the default message generated by
+[`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) and
+[`EXPECT_TRUE`](reference/assertions.md#EXPECT_TRUE) unsatisfactory, or some
+arguments to your predicate do not support streaming to `ostream`, you can
+instead use *predicate-formatter assertions* to *fully* customize how the
+message is formatted. See
+[`EXPECT_PRED_FORMAT*`](reference/assertions.md#EXPECT_PRED_FORMAT) in the
+Assertions Reference for details.
+
+### Floating-Point Comparison
+
+See [Floating-Point Comparison](reference/assertions.md#floating-point) in the
+Assertions Reference.
+
+#### Floating-Point Predicate-Format Functions
+
+Some floating-point operations are useful, but not that often used. In order to
+avoid an explosion of new macros, we provide them as predicate-format functions
+that can be used in the predicate assertion macro
+[`EXPECT_PRED_FORMAT2`](reference/assertions.md#EXPECT_PRED_FORMAT), for
+example:
+
+```c++
+using ::testing::FloatLE;
+using ::testing::DoubleLE;
+...
+EXPECT_PRED_FORMAT2(FloatLE, val1, val2);
+EXPECT_PRED_FORMAT2(DoubleLE, val1, val2);
+```
+
+The above code verifies that `val1` is less than, or approximately equal to,
+`val2`.
+
+### Asserting Using gMock Matchers
+
+See [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) in the Assertions
+Reference.
+
+### More String Assertions
+
+(Please read the [previous](#asserting-using-gmock-matchers) section first if
+you haven't.)
+
+You can use the gMock [string matchers](reference/matchers.md#string-matchers)
+with [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) to do more string
+comparison tricks (substring, prefix, suffix, regular expression, etc.). For
+example,
+
+```c++
+using ::testing::HasSubstr;
+using ::testing::MatchesRegex;
+...
+  ASSERT_THAT(foo_string, HasSubstr("needle"));
+  EXPECT_THAT(bar_string, MatchesRegex("\\w*\\d+"));
+```
+
+### Windows HRESULT assertions
+
+See [Windows HRESULT Assertions](reference/assertions.md#HRESULT) in the
+Assertions Reference.
+
+### Type Assertions
+
+You can call the function
+
+```c++
+::testing::StaticAssertTypeEq<T1, T2>();
+```
+
+to assert that types `T1` and `T2` are the same. The function does nothing if
+the assertion is satisfied.
If the types are different, the function call will
+fail to compile, and the compiler error message will say that `T1 and T2 are not
+the same type` and most likely (depending on the compiler) show you the actual
+values of `T1` and `T2`. This is mainly useful inside template code.
+
+**Caveat**: When used inside a member function of a class template or a function
+template, `StaticAssertTypeEq<T1, T2>()` is effective only if the function is
+instantiated. For example, given:
+
+```c++
+template <typename T> class Foo {
+ public:
+  void Bar() { testing::StaticAssertTypeEq<int, T>(); }
+};
+```
+
+the code:
+
+```c++
+void Test1() { Foo<bool> foo; }
+```
+
+will not generate a compiler error, as `Foo<bool>::Bar()` is never actually
+instantiated. Instead, you need:
+
+```c++
+void Test2() { Foo<bool> foo; foo.Bar(); }
+```
+
+to cause a compiler error.
+
+### Assertion Placement
+
+You can use assertions in any C++ function. In particular, it doesn't have to be
+a method of the test fixture class. The one constraint is that assertions that
+generate a fatal failure (`FAIL*` and `ASSERT_*`) can only be used in
+void-returning functions. This is a consequence of Google's not using
+exceptions. By placing it in a non-void function you'll get a confusing compile
+error like `"error: void value not ignored as it ought to be"` or `"cannot
+initialize return object of type 'bool' with an rvalue of type 'void'"` or
+`"error: no viable conversion from 'void' to 'string'"`.
+
+If you need to use fatal assertions in a function that returns non-void, one
+option is to make the function return the value in an out parameter instead. For
+example, you can rewrite `T2 Foo(T1 x)` to `void Foo(T1 x, T2* result)`. You
+need to make sure that `*result` contains some sensible value even when the
+function returns prematurely. As the function now returns `void`, you can use
+any assertion inside of it.
+
+If changing the function's type is not an option, you should just use assertions
+that generate non-fatal failures, such as `ADD_FAILURE*` and `EXPECT_*`.
+
+{: .callout .note}
+NOTE: Constructors and destructors are not considered void-returning functions,
+according to the C++ language specification, and so you may not use fatal
+assertions in them; you'll get a compilation error if you try. Instead, either
+call `abort` and crash the entire test executable, or put the fatal assertion in
+a `SetUp`/`TearDown` function; see
+[constructor/destructor vs. `SetUp`/`TearDown`](faq.md#CtorVsSetUp)
+
+{: .callout .warning}
+WARNING: A fatal assertion in a helper function (private void-returning method)
+called from a constructor or destructor does not terminate the current test, as
+your intuition might suggest: it merely returns from the constructor or
+destructor early, possibly leaving your object in a partially-constructed or
+partially-destructed state! You almost certainly want to `abort` or use
+`SetUp`/`TearDown` instead.
+
+## Skipping test execution
+
+Related to the assertions `SUCCEED()` and `FAIL()`, you can prevent further test
+execution at runtime with the `GTEST_SKIP()` macro. This is useful when you need
+to check for preconditions of the system under test during runtime and skip
+tests in a meaningful way.
+
+`GTEST_SKIP()` can be used in individual test cases or in the `SetUp()` methods
+of classes derived from either `::testing::Environment` or `::testing::Test`.
If changing the function's type is not an option, you should just use assertions
that generate non-fatal failures, such as `ADD_FAILURE*` and `EXPECT_*`.

{: .callout .note}
NOTE: Constructors and destructors are not considered void-returning functions,
according to the C++ language specification, and so you may not use fatal
assertions in them; you'll get a compilation error if you try. Instead, either
call `abort` and crash the entire test executable, or put the fatal assertion in
a `SetUp`/`TearDown` function; see
[constructor/destructor vs. `SetUp`/`TearDown`](faq.md#CtorVsSetUp).

{: .callout .warning}
WARNING: A fatal assertion in a helper function (private void-returning method)
called from a constructor or destructor does not terminate the current test, as
your intuition might suggest: it merely returns from the constructor or
destructor early, possibly leaving your object in a partially-constructed or
partially-destructed state! You almost certainly want to `abort` or use
`SetUp`/`TearDown` instead.

## Skipping test execution

Related to the assertions `SUCCEED()` and `FAIL()`, you can prevent further test
execution at runtime with the `GTEST_SKIP()` macro. This is useful when you need
to check for preconditions of the system under test during runtime and skip
tests in a meaningful way.

`GTEST_SKIP()` can be used in individual test cases or in the `SetUp()` methods
of classes derived from either `::testing::Environment` or `::testing::Test`.
For example:

```c++
TEST(SkipTest, DoesSkip) {
  GTEST_SKIP() << "Skipping single test";
  EXPECT_EQ(0, 1);  // Won't fail; it won't be executed
}

class SkipFixture : public ::testing::Test {
 protected:
  void SetUp() override {
    GTEST_SKIP() << "Skipping all tests for this fixture";
  }
};

// Tests for SkipFixture won't be executed.
TEST_F(SkipFixture, SkipsOneTest) {
  EXPECT_EQ(5, 7);  // Won't fail
}
```

As with assertion macros, you can stream a custom message into `GTEST_SKIP()`.

## Teaching GoogleTest How to Print Your Values

When a test assertion such as `EXPECT_EQ` fails, GoogleTest prints the argument
values to help you debug. It does this using a user-extensible value printer.

This printer knows how to print built-in C++ types, native arrays, STL
containers, and any type that supports the `<<` operator. For other types, it
prints the raw bytes in the value and hopes that you, the user, can figure it
out.

As mentioned earlier, the printer is *extensible*. That means you can teach it
to do a better job at printing your particular type than to dump the bytes. To
do that, define an `AbslStringify()` overload as a `friend` function template
for your type:

```cpp
namespace foo {

class Point {  // We want GoogleTest to be able to print instances of this.
  ...
  // Provide a friend overload.
  template <typename Sink>
  friend void AbslStringify(Sink& sink, const Point& point) {
    absl::Format(&sink, "(%d, %d)", point.x, point.y);
  }

  int x;
  int y;
};

// If you can't declare the function in the class it's important that the
// AbslStringify overload is defined in the SAME namespace that defines Point.
// C++'s look-up rules rely on that.
enum class EnumWithStringify { kMany = 0, kChoices = 1 };

template <typename Sink>
void AbslStringify(Sink& sink, EnumWithStringify e) {
  absl::Format(&sink, "%s", e == EnumWithStringify::kMany ? "Many" : "Choices");
}

}  // namespace foo
```

{: .callout .note}
Note: `AbslStringify()` utilizes a generic "sink" buffer to construct its
string. For more information about supported operations on `AbslStringify()`'s
sink, see go/abslstringify.

`AbslStringify()` can also use `absl::StrFormat`'s catch-all `%v` type specifier
within its own format strings to perform type deduction. `Point` above could be
formatted as `"(%v, %v)"` for example, and deduce the `int` values as `%d`.

Sometimes, `AbslStringify()` might not be an option: your team may wish to print
types with extra debugging information for testing purposes only. If so, you can
instead define a `PrintTo()` function like this:

```c++
#include <ostream>

namespace foo {

class Point {
  ...
  friend void PrintTo(const Point& point, std::ostream* os) {
    *os << "(" << point.x << "," << point.y << ")";
  }

  int x;
  int y;
};

// If you can't declare the function in the class it's important that PrintTo()
// is defined in the SAME namespace that defines Point. C++'s look-up rules
// rely on that.
void PrintTo(const Point& point, std::ostream* os) {
  *os << "(" << point.x << "," << point.y << ")";
}

}  // namespace foo
```

If you have defined both `AbslStringify()` and `PrintTo()`, the latter will be
used by GoogleTest. This allows you to customize how the value appears in
GoogleTest's output without affecting code that relies on the behavior of
`AbslStringify()`.

If you have an existing `<<` operator and would like to define an
`AbslStringify()`, the latter will be used for GoogleTest printing.
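For completeness, here is a sketch of the existing-`<<`-operator route
mentioned above, written as a `friend` (like the `PrintTo()` example) so it can
reach `Point`'s private members:

```c++
#include <ostream>

namespace foo {

class Point {
  ...
  // Makes Point printable both to GoogleTest and to ordinary code that
  // streams to std::ostream. AbslStringify() and PrintTo(), if defined,
  // take precedence for GoogleTest output, as described above.
  friend std::ostream& operator<<(std::ostream& os, const Point& point) {
    return os << "(" << point.x << "," << point.y << ")";
  }

  int x;
  int y;
};

}  // namespace foo
```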
If you want to print a value `x` using GoogleTest's value printer yourself, just
call `::testing::PrintToString(x)`, which returns an `std::string`:

```c++
vector<pair<Point, int>> point_ints = GetPointIntVector();

EXPECT_TRUE(IsCorrectPointIntVector(point_ints))
    << "point_ints = " << testing::PrintToString(point_ints);
```

For more details regarding `AbslStringify()` and its integration with other
libraries, see go/abslstringify.

## Death Tests

In many applications, there are assertions that can cause application failure if
a condition is not met. These consistency checks, which ensure that the program
is in a known good state, are there to fail at the earliest possible time after
some program state is corrupted. If the assertion checks the wrong condition,
then the program may proceed in an erroneous state, which could lead to memory
corruption, security holes, or worse. Hence it is vitally important to test that
such assertion statements work as expected.

Since these precondition checks cause the processes to die, we call such tests
_death tests_. More generally, any test that checks that a program terminates
(except by throwing an exception) in an expected fashion is also a death test.

Note that if a piece of code throws an exception, we don't consider it "death"
for the purpose of death tests, as the caller of the code could catch the
exception and avoid the crash. If you want to verify exceptions thrown by your
code, see [Exception Assertions](#ExceptionAssertions).

If you want to test `EXPECT_*()/ASSERT_*()` failures in your test code, see
["Catching" Failures](#catching-failures).

### How to Write a Death Test

GoogleTest provides assertion macros to support death tests. See
[Death Assertions](reference/assertions.md#death) in the Assertions Reference
for details.

To write a death test, simply use one of the macros inside your test function.
For example,

```c++
TEST(MyDeathTest, Foo) {
  // This death test uses a compound statement.
  ASSERT_DEATH({
    int n = 5;
    Foo(&n);
  }, "Error on line .* of Foo()");
}

TEST(MyDeathTest, NormalExit) {
  EXPECT_EXIT(NormalExit(), testing::ExitedWithCode(0), "Success");
}

TEST(MyDeathTest, KillProcess) {
  EXPECT_EXIT(KillProcess(), testing::KilledBySignal(SIGKILL),
              "Sending myself unblockable signal");
}
```

verifies that:

*   calling `Foo(&n)` with `n = 5` causes the process to die with the given
    error message,
*   calling `NormalExit()` causes the process to print `"Success"` to stderr and
    exit with exit code 0, and
*   calling `KillProcess()` kills the process with signal `SIGKILL`.

The test function body may contain other assertions and statements as well, if
necessary.

Note that a death test only cares about three things:

1.  does `statement` abort or exit the process?
2.  (in the case of `ASSERT_EXIT` and `EXPECT_EXIT`) does the exit status
    satisfy `predicate`? Or (in the case of `ASSERT_DEATH` and `EXPECT_DEATH`)
    is the exit status non-zero? And
3.  does the stderr output match `matcher`?

In particular, if `statement` generates an `ASSERT_*` or `EXPECT_*` failure, it
will **not** cause the death test to fail, as GoogleTest assertions don't abort
the process.

### Death Test Naming

{: .callout .important}
IMPORTANT: We strongly recommend you to follow the convention of naming your
**test suite** (not test) `*DeathTest` when it contains a death test, as
demonstrated in the above example.
The
[Death Tests And Threads](#death-tests-and-threads) section below explains why.

If a test fixture class is shared by normal tests and death tests, you can use
`using` or `typedef` to introduce an alias for the fixture class and avoid
duplicating its code:

```c++
class FooTest : public testing::Test { ... };

using FooDeathTest = FooTest;

TEST_F(FooTest, DoesThis) {
  // normal test
}

TEST_F(FooDeathTest, DoesThat) {
  // death test
}
```

### Regular Expression Syntax

When built with Bazel and using Abseil, GoogleTest uses the
[RE2](https://github.com/google/re2/wiki/Syntax) syntax. Otherwise, for POSIX
systems (Linux, Cygwin, Mac), GoogleTest uses the
[POSIX extended regular expression](https://www.opengroup.org/onlinepubs/009695399/basedefs/xbd_chap09.html#tag_09_04)
syntax. To learn about POSIX syntax, you may want to read this
[Wikipedia entry](https://en.wikipedia.org/wiki/Regular_expression#POSIX_extended).

On Windows, GoogleTest uses its own simple regular expression implementation. It
lacks many features. For example, we don't support union (`"x|y"`), grouping
(`"(xy)"`), brackets (`"[xy]"`), and repetition count (`"x{5,7}"`), among
others. Below is what we do support (`A` denotes a literal character, period
(`.`), or a single `\\` escape sequence; `x` and `y` denote regular
expressions.):

Expression | Meaning
---------- | --------------------------------------------------------------
`c`        | matches any literal character `c`
`\\d`      | matches any decimal digit
`\\D`      | matches any character that's not a decimal digit
`\\f`      | matches `\f`
`\\n`      | matches `\n`
`\\r`      | matches `\r`
`\\s`      | matches any ASCII whitespace, including `\n`
`\\S`      | matches any character that's not a whitespace
`\\t`      | matches `\t`
`\\v`      | matches `\v`
`\\w`      | matches any letter, `_`, or decimal digit
`\\W`      | matches any character that `\\w` doesn't match
`\\c`      | matches any literal character `c`, which must be a punctuation
`.`        | matches any single character except `\n`
`A?`       | matches 0 or 1 occurrences of `A`
`A*`       | matches 0 or many occurrences of `A`
`A+`       | matches 1 or many occurrences of `A`
`^`        | matches the beginning of a string (not that of each line)
`$`        | matches the end of a string (not that of each line)
`xy`       | matches `x` followed by `y`

To help you determine which capability is available on your system, GoogleTest
defines macros to govern which regular expression it is using. The macros are:
`GTEST_USES_SIMPLE_RE=1` or `GTEST_USES_POSIX_RE=1`. If you want your death
tests to work in all cases, you can either `#if` on these macros or use the more
limited syntax only.
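For example, a death-test expectation can branch on these macros so the same
test builds everywhere. A sketch, assuming a `Foo()` that dies with a message
like `Error on line 42 of Foo`:

```c++
TEST(MyDeathTest, PortableRegex) {
#if GTEST_USES_POSIX_RE
  // POSIX ERE supports bracket expressions like [0-9].
  EXPECT_DEATH(Foo(), "Error on line [0-9]+ of Foo");
#else
  // The simple implementation has no brackets, but \d is supported
  // (written \\d in a C++ string literal).
  EXPECT_DEATH(Foo(), "Error on line \\d+ of Foo");
#endif
}
```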
### How It Works

See [Death Assertions](reference/assertions.md#death) in the Assertions
Reference.

### Death Tests And Threads

The reason for the two death test styles has to do with thread safety. Due to
well-known problems with forking in the presence of threads, death tests should
be run in a single-threaded context. Sometimes, however, it isn't feasible to
arrange that kind of environment. For example, statically-initialized modules
may start threads before main is ever reached. Once threads have been created,
it may be difficult or impossible to clean them up.

GoogleTest has three features intended to raise awareness of threading issues.

1.  A warning is emitted if multiple threads are running when a death test is
    encountered.
2.  Test suites with a name ending in "DeathTest" are run before all other
    tests.
3.  It uses `clone()` instead of `fork()` to spawn the child process on Linux
    (`clone()` is not available on Cygwin and Mac), as `fork()` is more likely
    to cause the child to hang when the parent process has multiple threads.

It's perfectly fine to create threads inside a death test statement; they are
executed in a separate process and cannot affect the parent.

### Death Test Styles

The "threadsafe" death test style was introduced in order to help mitigate the
risks of testing in a possibly multithreaded environment. It trades increased
test execution time (potentially dramatically so) for improved thread safety.

The automated testing framework does not set the style flag. You can choose a
particular style of death tests by setting the flag programmatically:

```c++
GTEST_FLAG_SET(death_test_style, "threadsafe");
```

You can do this in `main()` to set the style for all death tests in the binary,
or in individual tests. Recall that flags are saved before running each test and
restored afterwards, so you need not do that yourself. For example:

```c++
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  GTEST_FLAG_SET(death_test_style, "fast");
  return RUN_ALL_TESTS();
}

TEST(MyDeathTest, TestOne) {
  GTEST_FLAG_SET(death_test_style, "threadsafe");
  // This test is run in the "threadsafe" style:
  ASSERT_DEATH(ThisShouldDie(), "");
}

TEST(MyDeathTest, TestTwo) {
  // This test is run in the "fast" style:
  ASSERT_DEATH(ThisShouldDie(), "");
}
```

### Caveats

The `statement` argument of `ASSERT_EXIT()` can be any valid C++ statement. If
it leaves the current function via a `return` statement or by throwing an
exception, the death test is considered to have failed. Some GoogleTest macros
may return from the current function (e.g. `ASSERT_TRUE()`), so be sure to avoid
them in `statement`.

Since `statement` runs in the child process, any in-memory side effect (e.g.
modifying a variable, releasing memory, etc.) it causes will *not* be observable
in the parent process. In particular, if you release memory in a death test,
your program will fail the heap check as the parent process will never see the
memory reclaimed. To solve this problem, you can

1.  try not to free memory in a death test;
2.  free the memory again in the parent process; or
3.  do not use the heap checker in your program.

Due to an implementation detail, you cannot place multiple death test assertions
on the same line; otherwise, compilation will fail with an unobvious error
message.

Despite the improved thread safety afforded by the "threadsafe" style of death
test, thread problems such as deadlock are still possible in the presence of
handlers registered with `pthread_atfork(3)`.

## Using Assertions in Sub-routines

{: .callout .note}
Note: If you want to put a series of test assertions in a subroutine to check
for a complex condition, consider using
[a custom GMock matcher](gmock_cook_book.md#NewMatchers) instead. This lets you
provide a more readable error message in case of failure and avoid all of the
issues described below.

### Adding Traces to Assertions

If a test sub-routine is called from several places, when an assertion inside it
fails, it can be hard to tell which invocation of the sub-routine the failure is
from. You can alleviate this problem using extra logging or custom failure
messages, but that usually clutters up your tests.
A better solution is to use
the `SCOPED_TRACE` macro or the `ScopedTrace` utility:

```c++
SCOPED_TRACE(message);
```

```c++
ScopedTrace trace("file_path", line_number, message);
```

where `message` can be anything streamable to `std::ostream`. The `SCOPED_TRACE`
macro will cause the current file name, line number, and the given message to be
added in every failure message. `ScopedTrace` accepts explicit file name and
line number in arguments, which is useful for writing test helpers. The effect
will be undone when the control leaves the current lexical scope.

For example,

```c++
10: void Sub1(int n) {
11:   EXPECT_EQ(Bar(n), 1);
12:   EXPECT_EQ(Bar(n + 1), 2);
13: }
14:
15: TEST(FooTest, Bar) {
16:   {
17:     SCOPED_TRACE("A");  // This trace point will be included in
18:                         // every failure in this scope.
19:     Sub1(1);
20:   }
21:   // Now it won't.
22:   Sub1(9);
23: }
```

could result in messages like these:

```none
path/to/foo_test.cc:11: Failure
Value of: Bar(n)
Expected: 1
  Actual: 2
Google Test trace:
path/to/foo_test.cc:17: A

path/to/foo_test.cc:12: Failure
Value of: Bar(n + 1)
Expected: 2
  Actual: 3
```

Without the trace, it would've been difficult to know which invocation of
`Sub1()` the two failures come from respectively. (You could add an extra
message to each assertion in `Sub1()` to indicate the value of `n`, but that's
tedious.)

Some tips on using `SCOPED_TRACE`:

1.  With a suitable message, it's often enough to use `SCOPED_TRACE` at the
    beginning of a sub-routine, instead of at each call site.
2.  When calling sub-routines inside a loop, make the loop iterator part of the
    message in `SCOPED_TRACE` such that you can know which iteration the failure
    is from.
3.  Sometimes the line number of the trace point is enough for identifying the
    particular invocation of a sub-routine. In this case, you don't have to
    choose a unique message for `SCOPED_TRACE`. You can simply use `""`.
4.  You can use `SCOPED_TRACE` in an inner scope when there is one in the outer
    scope. In this case, all active trace points will be included in the failure
    messages, in reverse order they are encountered.
5.  The trace dump is clickable in Emacs - hit `return` on a line number and
    you'll be taken to that line in the source file!

### Propagating Fatal Failures

A common pitfall when using `ASSERT_*` and `FAIL*` is not understanding that
when they fail they only abort the _current function_, not the entire test. For
example, the following test will segfault:

```c++
void Subroutine() {
  // Generates a fatal failure and aborts the current function.
  ASSERT_EQ(1, 2);

  // The following won't be executed.
  ...
}

TEST(FooTest, Bar) {
  Subroutine();  // The intended behavior is for the fatal failure
                 // in Subroutine() to abort the entire test.

  // The actual behavior: the function goes on after Subroutine() returns.
  int* p = nullptr;
  *p = 3;  // Segfault!
}
```

To alleviate this, GoogleTest provides three different solutions. You could use
either exceptions, the `(ASSERT|EXPECT)_NO_FATAL_FAILURE` assertions or the
`HasFatalFailure()` function. They are described in the following three
subsections.
#### Asserting on Subroutines with an exception

The following code can turn ASSERT-failure into an exception:

```c++
class ThrowListener : public testing::EmptyTestEventListener {
  void OnTestPartResult(const testing::TestPartResult& result) override {
    if (result.type() == testing::TestPartResult::kFatalFailure) {
      throw testing::AssertionException(result);
    }
  }
};
int main(int argc, char** argv) {
  ...
  testing::UnitTest::GetInstance()->listeners().Append(new ThrowListener);
  return RUN_ALL_TESTS();
}
```

This listener should be added after other listeners if you have any, otherwise
they won't see failed `OnTestPartResult`.

#### Asserting on Subroutines

As shown above, if your test calls a subroutine that has an `ASSERT_*` failure
in it, the test will continue after the subroutine returns. This may not be what
you want.

Often people want fatal failures to propagate like exceptions. For that
GoogleTest offers the following macros:

Fatal assertion | Nonfatal assertion | Verifies
------------------------------------- | ------------------------------------- | --------
`ASSERT_NO_FATAL_FAILURE(statement);` | `EXPECT_NO_FATAL_FAILURE(statement);` | `statement` doesn't generate any new fatal failures in the current thread.

Only failures in the thread that executes the assertion are checked to determine
the result of this type of assertions. If `statement` creates new threads,
failures in these threads are ignored.

Examples:

```c++
ASSERT_NO_FATAL_FAILURE(Foo());

int i;
EXPECT_NO_FATAL_FAILURE({
  i = Bar();
});
```

Assertions from multiple threads are currently not supported on Windows.

#### Checking for Failures in the Current Test

`HasFatalFailure()` in the `::testing::Test` class returns `true` if an
assertion in the current test has suffered a fatal failure. This allows
functions to catch fatal failures in a sub-routine and return early.

```c++
class Test {
 public:
  ...
  static bool HasFatalFailure();
};
```

The typical usage, which basically simulates the behavior of a thrown exception,
is:

```c++
TEST(FooTest, Bar) {
  Subroutine();
  // Aborts if Subroutine() had a fatal failure.
  if (HasFatalFailure()) return;

  // The following won't be executed.
  ...
}
```

If `HasFatalFailure()` is used outside of `TEST()`, `TEST_F()`, or a test
fixture, you must add the `::testing::Test::` prefix, as in:

```c++
if (testing::Test::HasFatalFailure()) return;
```

Similarly, `HasNonfatalFailure()` returns `true` if the current test has at
least one non-fatal failure, and `HasFailure()` returns `true` if the current
test has at least one failure of either kind.

## Logging Additional Information

In your test code, you can call `RecordProperty("key", value)` to log additional
information, where `value` can be either a string or an `int`. The *last* value
recorded for a key will be emitted to the
[XML output](#generating-an-xml-report) if you specify one. For example, the
test

```c++
TEST_F(WidgetUsageTest, MinAndMaxWidgets) {
  RecordProperty("MaximumWidgets", ComputeMaxUsage());
  RecordProperty("MinimumWidgets", ComputeMinUsage());
}
```

will output XML like this:

```xml
  ...
    <testcase name="MinAndMaxWidgets" status="run" time="0.006"
              classname="WidgetUsageTest" MaximumWidgets="12"
              MinimumWidgets="9" />
  ...
```

{: .callout .note}
> NOTE:
>
> *   `RecordProperty()` is a static member of the `Test` class. Therefore it
>     needs to be prefixed with `::testing::Test::` if used outside of the
>     `TEST` body and the test fixture class.
> *   *`key`* must be a valid XML attribute name, and cannot conflict with the
>     ones already used by GoogleTest (`name`, `status`, `time`, `classname`,
>     `type_param`, and `value_param`).
> *   Calling `RecordProperty()` outside of the lifespan of a test is allowed.
>     If it's called outside of a test but between a test suite's
>     `SetUpTestSuite()` and `TearDownTestSuite()` methods, it will be
>     attributed to the XML element for the test suite. If it's called outside
>     of all test suites (e.g. in a test environment), it will be attributed to
>     the top-level XML element.

## Sharing Resources Between Tests in the Same Test Suite

GoogleTest creates a new test fixture object for each test in order to make
tests independent and easier to debug. However, sometimes tests use resources
that are expensive to set up, making the one-copy-per-test model prohibitively
expensive.

If the tests don't change the resource, there's no harm in their sharing a
single resource copy. So, in addition to per-test set-up/tear-down, GoogleTest
also supports per-test-suite set-up/tear-down. To use it:

1.  In your test fixture class (say `FooTest`), declare as `static` some member
    variables to hold the shared resources.
2.  Outside your test fixture class (typically just below it), define those
    member variables, optionally giving them initial values.
3.  In the same test fixture class, define a public member function `static void
    SetUpTestSuite()` (remember not to spell it as **`SetupTestSuite`** with a
    small `u`!) to set up the shared resources and a `static void
    TearDownTestSuite()` function to tear them down.

That's it! GoogleTest automatically calls `SetUpTestSuite()` before running the
*first test* in the `FooTest` test suite (i.e. before creating the first
`FooTest` object), and calls `TearDownTestSuite()` after running the *last test*
in it (i.e. after deleting the last `FooTest` object). In between, the tests can
use the shared resources.

Remember that the test order is undefined, so your code can't depend on a test
preceding or following another. Also, the tests must either not modify the state
of any shared resource, or, if they do modify the state, they must restore the
state to its original value before passing control to the next test.

Note that `SetUpTestSuite()` may be called multiple times for a test fixture
class that has derived classes, so you should not expect code in the function
body to be run only once. Also, derived classes still have access to shared
resources defined as static members, so careful consideration is needed when
managing shared resources to avoid memory leaks if shared resources are not
properly cleaned up in `TearDownTestSuite()`.

Here's an example of per-test-suite set-up and tear-down:

```c++
class FooTest : public testing::Test {
 protected:
  // Per-test-suite set-up.
  // Called before the first test in this test suite.
  // Can be omitted if not needed.
  static void SetUpTestSuite() {
    shared_resource_ = new ...;

    // If `shared_resource_` is **not deleted** in `TearDownTestSuite()`,
    // reallocation should be prevented because `SetUpTestSuite()` may be called
    // in subclasses of FooTest and lead to memory leak.
    //
    // if (shared_resource_ == nullptr) {
    //   shared_resource_ = new ...;
    // }
  }

  // Per-test-suite tear-down.
  // Called after the last test in this test suite.
  // Can be omitted if not needed.
  static void TearDownTestSuite() {
    delete shared_resource_;
    shared_resource_ = nullptr;
  }

  // You can define per-test set-up logic as usual.
  void SetUp() override { ... }

  // You can define per-test tear-down logic as usual.
  void TearDown() override { ... }

  // Some expensive resource shared by all tests.
  static T* shared_resource_;
};

T* FooTest::shared_resource_ = nullptr;

TEST_F(FooTest, Test1) {
  ... you can refer to shared_resource_ here ...
}

TEST_F(FooTest, Test2) {
  ... you can refer to shared_resource_ here ...
}
```

{: .callout .note}
NOTE: Though the above code declares `SetUpTestSuite()` protected, it may
sometimes be necessary to declare it public, such as when using it with
`TEST_P`.

## Global Set-Up and Tear-Down

Just as you can do set-up and tear-down at the test level and the test suite
level, you can also do it at the test program level. Here's how.

First, you subclass the `::testing::Environment` class to define a test
environment, which knows how to set-up and tear-down:

```c++
class Environment : public ::testing::Environment {
 public:
  ~Environment() override {}

  // Override this to define how to set up the environment.
  void SetUp() override {}

  // Override this to define how to tear down the environment.
  void TearDown() override {}
};
```

Then, you register an instance of your environment class with GoogleTest by
calling the `::testing::AddGlobalTestEnvironment()` function:

```c++
Environment* AddGlobalTestEnvironment(Environment* env);
```

Now, when `RUN_ALL_TESTS()` is invoked, it first calls the `SetUp()` method. The
tests are then executed, provided that none of the environments have reported
fatal failures and `GTEST_SKIP()` has not been invoked. Finally, `TearDown()` is
called.

Note that `SetUp()` and `TearDown()` are only invoked if there is at least one
test to be performed. Importantly, `TearDown()` is executed even if the test is
not run due to a fatal failure or `GTEST_SKIP()`.

Calling `SetUp()` and `TearDown()` for each iteration depends on the flag
`gtest_recreate_environments_when_repeating`. `SetUp()` and `TearDown()` are
called for each environment object when the object is recreated for each
iteration. However, if test environments are not recreated for each iteration,
`SetUp()` is called only on the first iteration, and `TearDown()` is called only
on the last iteration.

It's OK to register multiple environment objects. In this case, their `SetUp()`
will be called in the order they are registered, and their `TearDown()` will be
called in the reverse order.

Note that GoogleTest takes ownership of the registered environment objects.
Therefore **do not delete them** by yourself.

You should call `AddGlobalTestEnvironment()` before `RUN_ALL_TESTS()` is called,
probably in `main()`. If you use `gtest_main`, you need to call this before
`main()` starts for it to take effect.
One way to do this is to define a global
variable like this:

```c++
testing::Environment* const foo_env =
    testing::AddGlobalTestEnvironment(new FooEnvironment);
```

However, we strongly recommend you to write your own `main()` and call
`AddGlobalTestEnvironment()` there, as relying on initialization of global
variables makes the code harder to read and may cause problems when you register
multiple environments from different translation units and the environments have
dependencies among them (remember that the compiler doesn't guarantee the order
in which global variables from different translation units are initialized).
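A minimal sketch of that recommended `main()`, reusing the `FooEnvironment`
name from above:

```c++
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  // Explicit registration makes the order obvious. GoogleTest takes
  // ownership of the environment object; do not delete it yourself.
  testing::AddGlobalTestEnvironment(new FooEnvironment);
  return RUN_ALL_TESTS();
}
```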
## Value-Parameterized Tests

*Value-parameterized tests* allow you to test your code with different
parameters without writing multiple copies of the same test. This is useful in a
number of situations, for example:

*   You have a piece of code whose behavior is affected by one or more
    command-line flags. You want to make sure your code performs correctly for
    various values of those flags.
*   You want to test different implementations of an OO interface.
*   You want to test your code over various inputs (a.k.a. data-driven testing).
    This feature is easy to abuse, so please exercise your good sense when doing
    it!

### How to Write Value-Parameterized Tests

To write value-parameterized tests, first you should define a fixture class. It
must be derived from both `testing::Test` and `testing::WithParamInterface<T>`
(the latter is a pure interface), where `T` is the type of your parameter
values. For convenience, you can just derive the fixture class from
`testing::TestWithParam<T>`, which itself is derived from both `testing::Test`
and `testing::WithParamInterface<T>`. `T` can be any copyable type. If it's a
raw pointer, you are responsible for managing the lifespan of the pointed
values.

{: .callout .note}
NOTE: If your test fixture defines `SetUpTestSuite()` or `TearDownTestSuite()`
they must be declared **public** rather than **protected** in order to use
`TEST_P`.

```c++
class FooTest :
    public testing::TestWithParam<const char*> {
  // You can implement all the usual fixture class members here.
  // To access the test parameter, call GetParam() from class
  // TestWithParam<T>.
};

// Or, when you want to add parameters to a pre-existing fixture class:
class BaseTest : public testing::Test {
  ...
};
class BarTest : public BaseTest,
                public testing::WithParamInterface<const char*> {
  ...
};
```

Then, use the `TEST_P` macro to define as many test patterns using this fixture
as you want. The `_P` suffix is for "parameterized" or "pattern", whichever you
prefer.

```c++
TEST_P(FooTest, DoesBlah) {
  // Inside a test, access the test parameter with the GetParam() method
  // of the TestWithParam<T> class:
  EXPECT_TRUE(foo.Blah(GetParam()));
  ...
}

TEST_P(FooTest, HasBlahBlah) {
  ...
}
```

Finally, you can use the `INSTANTIATE_TEST_SUITE_P` macro to instantiate the
test suite with any set of parameters you want. GoogleTest defines a number of
functions for generating test parameters—see details at
[`INSTANTIATE_TEST_SUITE_P`](reference/testing.md#INSTANTIATE_TEST_SUITE_P) in
the Testing Reference.

For example, the following statement will instantiate tests from the `FooTest`
test suite each with parameter values `"meeny"`, `"miny"`, and `"moe"` using the
[`Values`](reference/testing.md#param-generators) parameter generator:

```c++
INSTANTIATE_TEST_SUITE_P(MeenyMinyMoe,
                         FooTest,
                         testing::Values("meeny", "miny", "moe"));
```

{: .callout .note}
NOTE: The code above must be placed at global or namespace scope, not at
function scope.

The first argument to `INSTANTIATE_TEST_SUITE_P` is a unique name for the
instantiation of the test suite. The next argument is the name of the test
pattern, and the last is the
[parameter generator](reference/testing.md#param-generators).

The parameter generator expression is not evaluated until GoogleTest is
initialized (via `InitGoogleTest()`). Any prior initialization done in the
`main` function will be accessible from the parameter generator, for example,
the results of flag parsing.

You can instantiate a test pattern more than once, so to distinguish different
instances of the pattern, the instantiation name is added as a prefix to the
actual test suite name. Remember to pick unique prefixes for different
instantiations. The tests from the instantiation above will have these names:

*   `MeenyMinyMoe/FooTest.DoesBlah/0` for `"meeny"`
*   `MeenyMinyMoe/FooTest.DoesBlah/1` for `"miny"`
*   `MeenyMinyMoe/FooTest.DoesBlah/2` for `"moe"`
*   `MeenyMinyMoe/FooTest.HasBlahBlah/0` for `"meeny"`
*   `MeenyMinyMoe/FooTest.HasBlahBlah/1` for `"miny"`
*   `MeenyMinyMoe/FooTest.HasBlahBlah/2` for `"moe"`

You can use these names in [`--gtest_filter`](#running-a-subset-of-the-tests).

The following statement will instantiate all tests from `FooTest` again, each
with parameter values `"cat"` and `"dog"` using the
[`ValuesIn`](reference/testing.md#param-generators) parameter generator:

```c++
constexpr absl::string_view kPets[] = {"cat", "dog"};
INSTANTIATE_TEST_SUITE_P(Pets, FooTest, testing::ValuesIn(kPets));
```

The tests from the instantiation above will have these names:

*   `Pets/FooTest.DoesBlah/0` for `"cat"`
*   `Pets/FooTest.DoesBlah/1` for `"dog"`
*   `Pets/FooTest.HasBlahBlah/0` for `"cat"`
*   `Pets/FooTest.HasBlahBlah/1` for `"dog"`

Please note that `INSTANTIATE_TEST_SUITE_P` will instantiate *all* tests in the
given test suite, whether their definitions come before or *after* the
`INSTANTIATE_TEST_SUITE_P` statement.

Additionally, by default, every `TEST_P` without a corresponding
`INSTANTIATE_TEST_SUITE_P` causes a failing test in test suite
`GoogleTestVerification`. If you have a test suite where that omission is not an
error, for example it is in a library that may be linked in for other reasons or
where the list of test cases is dynamic and may be empty, then this check can be
suppressed by tagging the test suite:

```c++
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(FooTest);
```

You can see [sample7_unittest.cc] and [sample8_unittest.cc] for more examples.

[sample7_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample7_unittest.cc "Parameterized Test example"
[sample8_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample8_unittest.cc "Parameterized Test example with multiple parameters"

### Creating Value-Parameterized Abstract Tests

In the above, we define and instantiate `FooTest` in the *same* source file.
Sometimes you may want to define value-parameterized tests in a library and let
other people instantiate them later. This pattern is known as *abstract tests*.
As an example of its application, when you are designing an interface you can
write a standard suite of abstract tests (perhaps using a factory function as
the test parameter) that all implementations of the interface are expected to
pass. When someone implements the interface, they can instantiate your suite to
get all the interface-conformance tests for free.

To define abstract tests, you should organize your code like this:

1.  Put the definition of the parameterized test fixture class (e.g. `FooTest`)
    in a header file, say `foo_param_test.h`. Think of this as *declaring* your
    abstract tests.
2.  Put the `TEST_P` definitions in `foo_param_test.cc`, which includes
    `foo_param_test.h`. Think of this as *implementing* your abstract tests.

Once they are defined, you can instantiate them by including `foo_param_test.h`,
invoking `INSTANTIATE_TEST_SUITE_P()`, and depending on the library target that
contains `foo_param_test.cc`. You can instantiate the same abstract test suite
multiple times, possibly in different source files.
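Put concretely, the layout might look like the sketch below. The factory
parameter and all names other than `FooTest` and the `foo_param_test.*` files
are illustrative assumptions:

```c++
// foo_param_test.h -- *declares* the abstract tests.
class FooTest : public testing::TestWithParam<FooFactory> { ... };

// foo_param_test.cc -- *implements* them; compiled into a library target.
#include "foo_param_test.h"
TEST_P(FooTest, DoesBlah) { ... }

// my_foo_test.cc -- an implementer instantiates the whole suite.
#include "foo_param_test.h"
INSTANTIATE_TEST_SUITE_P(MyFoo, FooTest, testing::Values(&MakeMyFoo));
```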
### Specifying Names for Value-Parameterized Test Parameters

The optional last argument to `INSTANTIATE_TEST_SUITE_P()` allows the user to
specify a function or functor that generates custom test name suffixes based on
the test parameters. The function should accept one argument of type
`testing::TestParamInfo<class ParamType>`, and return `std::string`.

`testing::PrintToStringParamName` is a builtin test suffix generator that
returns the value of `testing::PrintToString(GetParam())`. It does not work for
`std::string` or C strings.

{: .callout .note}
NOTE: test names must be non-empty, unique, and may only contain ASCII
alphanumeric characters. In particular, they
[should not contain underscores](faq.md#why-should-test-suite-names-and-test-names-not-contain-underscore)

```c++
class MyTestSuite : public testing::TestWithParam<int> {};

TEST_P(MyTestSuite, MyTest)
{
  std::cout << "Example Test Param: " << GetParam() << std::endl;
}

INSTANTIATE_TEST_SUITE_P(MyGroup, MyTestSuite, testing::Range(0, 10),
                         testing::PrintToStringParamName());
```

Providing a custom functor allows for more control over test parameter name
generation, especially for types where the automatic conversion does not
generate helpful parameter names (e.g. strings as demonstrated above). The
following example illustrates this for multiple parameters, an enumeration type
and a string, and also demonstrates how to combine generators. It uses a lambda
for conciseness:

```c++
enum class MyType { MY_FOO = 0, MY_BAR = 1 };

class MyTestSuite : public testing::TestWithParam<std::tuple<MyType, std::string>> {
};

INSTANTIATE_TEST_SUITE_P(
    MyGroup, MyTestSuite,
    testing::Combine(
        testing::Values(MyType::MY_FOO, MyType::MY_BAR),
        testing::Values("A", "B")),
    [](const testing::TestParamInfo<MyTestSuite::ParamType>& info) {
      std::string name = absl::StrCat(
          std::get<0>(info.param) == MyType::MY_FOO ? "Foo" : "Bar",
          std::get<1>(info.param));
      absl::c_replace_if(name, [](char c) { return !std::isalnum(c); }, '_');
      return name;
    });
```

## Typed Tests

Suppose you have multiple implementations of the same interface and want to make
sure that all of them satisfy some common requirements. Or, you may have defined
several types that are supposed to conform to the same "concept" and you want to
verify it.
In both cases, you want the same test logic repeated for different
types.

While you can write one `TEST` or `TEST_F` for each type you want to test (and
you may even factor the test logic into a function template that you invoke from
the `TEST`), it's tedious and doesn't scale: if you want `m` tests over `n`
types, you'll end up writing `m*n` `TEST`s.

*Typed tests* allow you to repeat the same test logic over a list of types. You
only need to write the test logic once, although you must know the type list
when writing typed tests. Here's how you do it:

First, define a fixture class template. It should be parameterized by a type.
Remember to derive it from `::testing::Test`:

```c++
template <typename T>
class FooTest : public testing::Test {
 public:
  ...
  using List = std::list<T>;
  static T shared_;
  T value_;
};
```

Next, associate a list of types with the test suite, which will be repeated for
each type in the list:

```c++
using MyTypes = ::testing::Types<char, int, unsigned int>;
TYPED_TEST_SUITE(FooTest, MyTypes);
```

The type alias (`using` or `typedef`) is necessary for the `TYPED_TEST_SUITE`
macro to parse correctly. Otherwise the compiler will think that each comma in
the type list introduces a new macro argument.

Then, use `TYPED_TEST()` instead of `TEST_F()` to define a typed test for this
test suite. You can repeat this as many times as you want:

```c++
TYPED_TEST(FooTest, DoesBlah) {
  // Inside a test, refer to the special name TypeParam to get the type
  // parameter. Since we are inside a derived class template, C++ requires
  // us to visit the members of FooTest via 'this'.
  TypeParam n = this->value_;

  // To visit static members of the fixture, add the 'TestFixture::'
  // prefix.
  n += TestFixture::shared_;

  // To refer to typedefs in the fixture, add the 'typename TestFixture::'
  // prefix. The 'typename' is required to satisfy the compiler.
  typename TestFixture::List values;

  values.push_back(n);
  ...
}

TYPED_TEST(FooTest, HasPropertyA) { ... }
```

You can see [sample6_unittest.cc] for a complete example.

[sample6_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample6_unittest.cc "Typed Test example"

## Type-Parameterized Tests

*Type-parameterized tests* are like typed tests, except that they don't require
you to know the list of types ahead of time. Instead, you can define the test
logic first and instantiate it with different type lists later. You can even
instantiate it more than once in the same program.

If you are designing an interface or concept, you can define a suite of
type-parameterized tests to verify properties that any valid implementation of
the interface/concept should have. Then, the author of each implementation can
just instantiate the test suite with their type to verify that it conforms to
the requirements, without having to write similar tests repeatedly. Here's an
example:

First, define a fixture class template, as we did with typed tests:

```c++
template <typename T>
class FooTest : public testing::Test {
  void DoSomethingInteresting();
  ...
};
```

Next, declare that you will define a type-parameterized test suite:

```c++
TYPED_TEST_SUITE_P(FooTest);
```

Then, use `TYPED_TEST_P()` to define a type-parameterized test. You can repeat
this as many times as you want:

```c++
TYPED_TEST_P(FooTest, DoesBlah) {
  // Inside a test, refer to TypeParam to get the type parameter.
  TypeParam n = 0;

  // You will need to use `this` explicitly to refer to fixture members.
  this->DoSomethingInteresting()
  ...
}

TYPED_TEST_P(FooTest, HasPropertyA) { ... }
```

Now the tricky part: you need to register all test patterns using the
`REGISTER_TYPED_TEST_SUITE_P` macro before you can instantiate them. The first
argument of the macro is the test suite name; the rest are the names of the
tests in this test suite:

```c++
REGISTER_TYPED_TEST_SUITE_P(FooTest,
                            DoesBlah, HasPropertyA);
```

Finally, you are free to instantiate the pattern with the types you want. If you
put the above code in a header file, you can `#include` it in multiple C++
source files and instantiate it multiple times.

```c++
using MyTypes = ::testing::Types<char, int, unsigned int>;
INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, MyTypes);
```

To distinguish different instances of the pattern, the first argument to the
`INSTANTIATE_TYPED_TEST_SUITE_P` macro is a prefix that will be added to the
actual test suite name. Remember to pick unique prefixes for different
instances.

In the special case where the type list contains only one type, you can write
that type directly without `::testing::Types<...>`, like this:

```c++
INSTANTIATE_TYPED_TEST_SUITE_P(My, FooTest, int);
```

You can see [sample6_unittest.cc] for a complete example.

## Testing Private Code

If you change your software's internal implementation, your tests should not
break as long as the change is not observable by users. Therefore, **per the
black-box testing principle, most of the time you should test your code through
its public interfaces.**

**If you still find yourself needing to test internal implementation code,
consider if there's a better design.** The desire to test internal
implementation is often a sign that the class is doing too much. Consider
extracting an implementation class, and testing it. Then use that implementation
class in the original class.

If you absolutely have to test non-public interface code though, you can. There
are two cases to consider:

*   Static functions (*not* the same as static member functions!) or unnamed
    namespaces, and
*   Private or protected class members

To test them, we use the following special techniques:

*   Both static functions and definitions/declarations in an unnamed namespace
    are only visible within the same translation unit. To test them, you can
    `#include` the entire `.cc` file being tested in your `*_test.cc` file.
    (#including `.cc` files is not a good way to reuse code - you should not do
    this in production code!)

    However, a better approach is to move the private code into the
    `foo::internal` namespace, where `foo` is the namespace your project
    normally uses, and put the private declarations in a `*-internal.h` file.
    Your production `.cc` files and your tests are allowed to include this
    internal header, but your clients are not. This way, you can fully test your
    internal implementation without leaking it to your clients. (A sketch of
    this layout appears after this list.)

*   Private class members are only accessible from within the class or by
    friends. To access a class' private members, you can declare your test
    fixture as a friend to the class and define accessors in your fixture. Tests
    using the fixture can then access the private members of your production
    class via the accessors in the fixture.
    Note that even though your fixture
    is a friend to your production class, your tests are not automatically
    friends to it, as they are technically defined in sub-classes of the
    fixture.

    Another way to test private members is to refactor them into an
    implementation class, which is then declared in a `*-internal.h` file. Your
    clients aren't allowed to include this header but your tests can. This is
    called the
    [Pimpl](https://www.gamedev.net/articles/programming/general-and-gameplay-programming/the-c-pimpl-r1794/)
    (Private Implementation) idiom.

    Or, you can declare an individual test as a friend of your class by adding
    this line in the class body:

    ```c++
    FRIEND_TEST(TestSuiteName, TestName);
    ```

    For example,

    ```c++
    // foo.h
    class Foo {
      ...
     private:
      FRIEND_TEST(FooTest, BarReturnsZeroOnNull);

      int Bar(void* x);
    };

    // foo_test.cc
    ...
    TEST(FooTest, BarReturnsZeroOnNull) {
      Foo foo;
      EXPECT_EQ(foo.Bar(NULL), 0);  // Uses Foo's private member Bar().
    }
    ```

    Pay special attention when your class is defined in a namespace. If you want
    your test fixtures and tests to be friends of your class, then they must be
    defined in the exact same namespace (no anonymous or inline namespaces).

    For example, if the code to be tested looks like:

    ```c++
    namespace my_namespace {

    class Foo {
      friend class FooTest;
      FRIEND_TEST(FooTest, Bar);
      FRIEND_TEST(FooTest, Baz);
      ... definition of the class Foo ...
    };

    }  // namespace my_namespace
    ```

    Your test code should be something like:

    ```c++
    namespace my_namespace {

    class FooTest : public testing::Test {
     protected:
      ...
    };

    TEST_F(FooTest, Bar) { ... }
    TEST_F(FooTest, Baz) { ... }

    }  // namespace my_namespace
    ```
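Here is a sketch of the `foo::internal` / `*-internal.h` approach from the
first bullet above; all names are illustrative:

```c++
// foo-internal.h -- included by foo's own .cc files and by its tests,
// but never shipped to clients.
namespace foo {
namespace internal {
int ClampToRange(int value, int low, int high);  // private helper under test
}  // namespace internal
}  // namespace foo

// foo_test.cc
#include "foo-internal.h"

TEST(ClampToRangeTest, ClampsValuesBelowRange) {
  EXPECT_EQ(foo::internal::ClampToRange(-5, 0, 10), 0);
}
```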
## "Catching" Failures

If you are building a testing utility on top of GoogleTest, you'll want to test
your utility. What framework would you use to test it? GoogleTest, of course.

The challenge is to verify that your testing utility reports failures correctly.
In frameworks that report a failure by throwing an exception, you could catch
the exception and assert on it. But GoogleTest doesn't use exceptions, so how do
we test that a piece of code generates an expected failure?

`"gtest/gtest-spi.h"` contains some constructs to do this.
After #including this header, you can use

```c++
  EXPECT_FATAL_FAILURE(statement, substring);
```

to assert that `statement` generates a fatal (e.g. `ASSERT_*`) failure in the
current thread whose message contains the given `substring`, or use

```c++
  EXPECT_NONFATAL_FAILURE(statement, substring);
```

if you are expecting a non-fatal (e.g. `EXPECT_*`) failure.

Only failures in the current thread are checked to determine the result of this
type of expectations. If `statement` creates new threads, failures in these
threads are also ignored. If you want to catch failures in other threads as
well, use one of the following macros instead:

```c++
  EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substring);
  EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substring);
```

{: .callout .note}
NOTE: Assertions from multiple threads are currently not supported on Windows.

For technical reasons, there are some caveats:

1.  You cannot stream a failure message to either macro.

2.  `statement` in `EXPECT_FATAL_FAILURE{_ON_ALL_THREADS}()` cannot reference
    local non-static variables or non-static members of `this` object.

3.  `statement` in `EXPECT_FATAL_FAILURE{_ON_ALL_THREADS}()` cannot return a
    value.

## Registering tests programmatically

The `TEST` macros handle the vast majority of all use cases, but there are a few
where runtime registration logic is required. For those cases, the framework
provides the `::testing::RegisterTest` function, which allows callers to
register arbitrary tests dynamically.

This is an advanced API only to be used when the `TEST` macros are insufficient.
The macros should be preferred when possible, as they avoid most of the
complexity of calling this function.

It provides the following signature:

```c++
template <int&... ExplicitParameterBarrier, typename Factory>
TestInfo* RegisterTest(const char* test_suite_name, const char* test_name,
                       const char* type_param, const char* value_param,
                       const char* file, int line, Factory factory);
```

The `factory` argument is a factory callable (move-constructible) object or
function pointer that creates a new instance of the Test object. It hands
ownership to the caller. The signature of the callable is `Fixture*()`, where
`Fixture` is the test fixture class for the test. All tests registered with the
same `test_suite_name` must return the same fixture type. This is checked at
runtime.

The framework will infer the fixture class from the factory and will call the
`SetUpTestSuite` and `TearDownTestSuite` for it.

Must be called before `RUN_ALL_TESTS()` is invoked, otherwise behavior is
undefined.

Use case example:

```c++
class MyFixture : public testing::Test {
 public:
  // All of these are optional, just like in regular macro usage.
  static void SetUpTestSuite() { ... }
  static void TearDownTestSuite() { ... }
  void SetUp() override { ... }
  void TearDown() override { ... }
};

class MyTest : public MyFixture {
 public:
  explicit MyTest(int data) : data_(data) {}
  void TestBody() override { ... }

 private:
  int data_;
};

void RegisterMyTests(const std::vector<int>& values) {
  for (int v : values) {
    testing::RegisterTest(
        "MyFixture", ("Test" + std::to_string(v)).c_str(), nullptr,
        std::to_string(v).c_str(),
        __FILE__, __LINE__,
        // Important to use the fixture type as the return type here.
        [=]() -> MyFixture* { return new MyTest(v); });
  }
}
...
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  std::vector<int> values_to_test = LoadValuesFromConfig();
  RegisterMyTests(values_to_test);
  ...
  return RUN_ALL_TESTS();
}
```

## Getting the Current Test's Name

Sometimes a function may need to know the name of the currently running test.
For example, you may be using the `SetUp()` method of your test fixture to set
the golden file name based on which test is running. The
[`TestInfo`](reference/testing.md#TestInfo) class has this information.

To obtain a `TestInfo` object for the currently running test, call
`current_test_info()` on the [`UnitTest`](reference/testing.md#UnitTest)
singleton object:

```c++
  // Gets information about the currently running test.
  // Do NOT delete the returned object - it's managed by the UnitTest class.
  const testing::TestInfo* const test_info =
      testing::UnitTest::GetInstance()->current_test_info();

  printf("We are in test %s of test suite %s.\n",
         test_info->name(),
         test_info->test_suite_name());
```

`current_test_info()` returns a null pointer if no test is running. In
particular, you cannot find the test suite name in `SetUpTestSuite()`,
`TearDownTestSuite()` (where you know the test suite name implicitly), or
functions called from them.
## Extending GoogleTest by Handling Test Events

GoogleTest provides an **event listener API** to let you receive notifications
about the progress of a test program and test failures. The events you can
listen to include the start and end of the test program, a test suite, or a test
method, among others. You may use this API to augment or replace the standard
console output, replace the XML output, or provide a completely different form
of output, such as a GUI or a database. You can also use test events as
checkpoints to implement a resource leak checker, for example.

### Defining Event Listeners

To define an event listener, you subclass either
[`testing::TestEventListener`](reference/testing.md#TestEventListener) or
[`testing::EmptyTestEventListener`](reference/testing.md#EmptyTestEventListener).
The former is an (abstract) interface, where *each pure virtual method can be
overridden to handle a test event* (For example, when a test starts, the
`OnTestStart()` method will be called.). The latter provides an empty
implementation of all methods in the interface, such that a subclass only needs
to override the methods it cares about.

When an event is fired, its context is passed to the handler function as an
argument. The following argument types are used:

*   `UnitTest` reflects the state of the entire test program,
*   `TestSuite` has information about a test suite, which can contain one or
    more tests,
*   `TestInfo` contains the state of a test, and
*   `TestPartResult` represents the result of a test assertion.

An event handler function can examine the argument it receives to find out
interesting information about the event and the test program's state.

Here's an example:

```c++
  class MinimalistPrinter : public testing::EmptyTestEventListener {
    // Called before a test starts.
    void OnTestStart(const testing::TestInfo& test_info) override {
      printf("*** Test %s.%s starting.\n",
             test_info.test_suite_name(), test_info.name());
    }

    // Called after a failed assertion or a SUCCESS().
    void OnTestPartResult(const testing::TestPartResult& test_part_result) override {
      printf("%s in %s:%d\n%s\n",
             test_part_result.failed() ? "*** Failure" : "Success",
             test_part_result.file_name(),
             test_part_result.line_number(),
             test_part_result.summary());
    }

    // Called after a test ends.
    void OnTestEnd(const testing::TestInfo& test_info) override {
      printf("*** Test %s.%s ending.\n",
             test_info.test_suite_name(), test_info.name());
    }
  };
```

### Using Event Listeners

To use the event listener you have defined, add an instance of it to the
GoogleTest event listener list (represented by class
[`TestEventListeners`](reference/testing.md#TestEventListeners) - note the "s"
at the end of the name) in your `main()` function, before calling
`RUN_ALL_TESTS()`:

```c++
int main(int argc, char** argv) {
  testing::InitGoogleTest(&argc, argv);
  // Gets hold of the event listener list.
  testing::TestEventListeners& listeners =
      testing::UnitTest::GetInstance()->listeners();
  // Adds a listener to the end. GoogleTest takes the ownership.
  listeners.Append(new MinimalistPrinter);
  return RUN_ALL_TESTS();
}
```

There's only one problem: the default test result printer is still in effect, so
its output will mingle with the output from your minimalist printer. To suppress
the default printer, just release it from the event listener list and delete it.
You can do so by adding one line:

```c++
  ...
  delete listeners.Release(listeners.default_result_printer());
  listeners.Append(new MinimalistPrinter);
  return RUN_ALL_TESTS();
```

Now, sit back and enjoy a completely different output from your tests. For more
details, see [sample9_unittest.cc].

[sample9_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample9_unittest.cc "Event listener example"

You may append more than one listener to the list. When an `On*Start()` or
`OnTestPartResult()` event is fired, the listeners will receive it in the order
they appear in the list (since new listeners are added to the end of the list,
the default text printer and the default XML generator will receive the event
first). An `On*End()` event will be received by the listeners in the *reverse*
order. This allows output by listeners added later to be framed by output from
listeners added earlier.

### Generating Failures in Listeners

You may use failure-raising macros (`EXPECT_*()`, `ASSERT_*()`, `FAIL()`, etc)
when processing an event. There are some restrictions:

1.  You cannot generate any failure in `OnTestPartResult()` (otherwise it will
    cause `OnTestPartResult()` to be called recursively).
2.  A listener that handles `OnTestPartResult()` is not allowed to generate any
    failure.

When you add listeners to the listener list, you should put listeners that
handle `OnTestPartResult()` *before* listeners that can generate failures. This
ensures that failures generated by the latter are attributed to the right test
by the former.

See [sample10_unittest.cc] for an example of a failure-raising listener.

[sample10_unittest.cc]: https://github.com/google/googletest/blob/main/googletest/samples/sample10_unittest.cc "Failure-raising listener example"

## Running Test Programs: Advanced Options

GoogleTest test programs are ordinary executables. Once built, you can run them
directly and affect their behavior via the following environment variables
and/or command line flags. For the flags to work, your programs must call
`::testing::InitGoogleTest()` before calling `RUN_ALL_TESTS()`.

To see a list of supported flags and their usage, please run your test program
with the `--help` flag.

If an option is specified both by an environment variable and by a flag, the
latter takes precedence.

### Selecting Tests

#### Listing Test Names

Sometimes it is necessary to list the available tests in a program before
running them so that a filter may be applied if needed. Including the flag
`--gtest_list_tests` overrides all other flags and lists tests in the following
format:

```none
TestSuite1.
  TestName1
  TestName2
TestSuite2.
  TestName
```

None of the tests listed are actually run if the flag is provided. There is no
corresponding environment variable for this flag.

#### Running a Subset of the Tests

By default, a GoogleTest program runs all tests the user has defined. Sometimes,
you want to run only a subset of the tests (e.g. for debugging or quickly
verifying a change). If you set the `GTEST_FILTER` environment variable or the
`--gtest_filter` flag to a filter string, GoogleTest will only run the tests
whose full names (in the form of `TestSuiteName.TestName`) match the filter.

The format of a filter is a '`:`'-separated list of wildcard patterns (called
the *positive patterns*) optionally followed by a '`-`' and another
'`:`'-separated pattern list (called the *negative patterns*).
+A test matches the filter if and only if it matches any of the positive
+patterns but does not match any of the negative patterns.
+
+A pattern may contain `'*'` (matches any string) or `'?'` (matches any single
+character). For convenience, the filter `'*-NegativePatterns'` can also be
+written as `'-NegativePatterns'`.
+
+For example:
+
+* `./foo_test` Has no flag, and thus runs all its tests.
+* `./foo_test --gtest_filter=*` Also runs everything, due to the single
+  match-everything `*` value.
+* `./foo_test --gtest_filter=FooTest.*` Runs everything in test suite
+  `FooTest`.
+* `./foo_test --gtest_filter=*Null*:*Constructor*` Runs any test whose full
+  name contains either `"Null"` or `"Constructor"`.
+* `./foo_test --gtest_filter=-*DeathTest.*` Runs all non-death tests.
+* `./foo_test --gtest_filter=FooTest.*-FooTest.Bar` Runs everything in test
+  suite `FooTest` except `FooTest.Bar`.
+* `./foo_test --gtest_filter=FooTest.*:BarTest.*-FooTest.Bar:BarTest.Foo` Runs
+  everything in test suite `FooTest` except `FooTest.Bar` and everything in
+  test suite `BarTest` except `BarTest.Foo`.
+
+#### Stop test execution upon first failure
+
+By default, a GoogleTest program runs all tests the user has defined. In some
+cases (e.g. iterative test development & execution) it may be desirable to stop
+test execution upon first failure (trading improved latency for completeness).
+If the `GTEST_FAIL_FAST` environment variable or the `--gtest_fail_fast` flag
+is set, the test runner will stop execution as soon as the first test failure
+is found.
+
+#### Temporarily Disabling Tests
+
+If you have a broken test that you cannot fix right away, you can add the
+`DISABLED_` prefix to its name. This will exclude it from execution. This is
+better than commenting out the code or using `#if 0`, as disabled tests are
+still compiled (and thus won't rot).
+
+If you need to disable all tests in a test suite, you can either add `DISABLED_`
+to the front of the name of each test, or alternatively add it to the front of
+the test suite name.
+
+For example, the following tests won't be run by GoogleTest, even though they
+will still be compiled:
+
+```c++
+// Tests that Foo does Abc.
+TEST(FooTest, DISABLED_DoesAbc) { ... }
+
+class DISABLED_BarTest : public testing::Test { ... };
+
+// Tests that Bar does Xyz.
+TEST_F(DISABLED_BarTest, DoesXyz) { ... }
+```
+
+{: .callout .note}
+NOTE: This feature should only be used for temporary pain-relief. You still have
+to fix the disabled tests at a later date. As a reminder, GoogleTest will print
+a banner warning you if a test program contains any disabled tests.
+
+{: .callout .tip}
+TIP: You can easily count the number of disabled tests you have using `grep`.
+This number can be used as a metric for improving your test quality.
+
+#### Temporarily Enabling Disabled Tests
+
+To include disabled tests in test execution, just invoke the test program with
+the `--gtest_also_run_disabled_tests` flag or set the
+`GTEST_ALSO_RUN_DISABLED_TESTS` environment variable to a value other than `0`.
+You can combine this with the `--gtest_filter` flag to further select which
+disabled tests to run.
+
+### Repeating the Tests
+
+Once in a while you'll run into a test whose result is hit-or-miss. Perhaps it
+will fail only 1% of the time, making it rather hard to reproduce the bug under
+a debugger. This can be a major source of frustration.
+
+The `--gtest_repeat` flag allows you to repeat all (or selected) test methods in
+a program many times.
Hopefully, a flaky test will eventually fail and give you +a chance to debug. Here's how to use it: + +```none +$ foo_test --gtest_repeat=1000 +Repeat foo_test 1000 times and don't stop at failures. + +$ foo_test --gtest_repeat=-1 +A negative count means repeating forever. + +$ foo_test --gtest_repeat=1000 --gtest_break_on_failure +Repeat foo_test 1000 times, stopping at the first failure. This +is especially useful when running under a debugger: when the test +fails, it will drop into the debugger and you can then inspect +variables and stacks. + +$ foo_test --gtest_repeat=1000 --gtest_filter=FooBar.* +Repeat the tests whose name matches the filter 1000 times. +``` + +If your test program contains +[global set-up/tear-down](#global-set-up-and-tear-down) code, it will be +repeated in each iteration as well, as the flakiness may be in it. To avoid +repeating global set-up/tear-down, specify +`--gtest_recreate_environments_when_repeating=false`{.nowrap}. + +You can also specify the repeat count by setting the `GTEST_REPEAT` environment +variable. + +### Shuffling the Tests + +You can specify the `--gtest_shuffle` flag (or set the `GTEST_SHUFFLE` +environment variable to `1`) to run the tests in a program in a random order. +This helps to reveal bad dependencies between tests. + +By default, GoogleTest uses a random seed calculated from the current time. +Therefore you'll get a different order every time. The console output includes +the random seed value, such that you can reproduce an order-related test failure +later. To specify the random seed explicitly, use the `--gtest_random_seed=SEED` +flag (or set the `GTEST_RANDOM_SEED` environment variable), where `SEED` is an +integer in the range [0, 99999]. The seed value 0 is special: it tells +GoogleTest to do the default behavior of calculating the seed from the current +time. + +If you combine this with `--gtest_repeat=N`, GoogleTest will pick a different +random seed and re-shuffle the tests in each iteration. + +### Distributing Test Functions to Multiple Machines + +If you have more than one machine you can use to run a test program, you might +want to run the test functions in parallel and get the result faster. We call +this technique *sharding*, where each machine is called a *shard*. + +GoogleTest is compatible with test sharding. To take advantage of this feature, +your test runner (not part of GoogleTest) needs to do the following: + +1. Allocate a number of machines (shards) to run the tests. +1. On each shard, set the `GTEST_TOTAL_SHARDS` environment variable to the total + number of shards. It must be the same for all shards. +1. On each shard, set the `GTEST_SHARD_INDEX` environment variable to the index + of the shard. Different shards must be assigned different indices, which + must be in the range `[0, GTEST_TOTAL_SHARDS - 1]`. +1. Run the same test program on all shards. When GoogleTest sees the above two + environment variables, it will select a subset of the test functions to run. + Across all shards, each test function in the program will be run exactly + once. +1. Wait for all shards to finish, then collect and report the results. + +Your project may have tests that were written without GoogleTest and thus don't +understand this protocol. In order for your test runner to figure out which test +supports sharding, it can set the environment variable `GTEST_SHARD_STATUS_FILE` +to a non-existent file path. 
If a test program supports sharding, it will create +this file to acknowledge that fact; otherwise it will not create it. The actual +contents of the file are not important at this time, although we may put some +useful information in it in the future. + +Here's an example to make it clear. Suppose you have a test program `foo_test` +that contains the following 5 test functions: + +``` +TEST(A, V) +TEST(A, W) +TEST(B, X) +TEST(B, Y) +TEST(B, Z) +``` + +Suppose you have 3 machines at your disposal. To run the test functions in +parallel, you would set `GTEST_TOTAL_SHARDS` to 3 on all machines, and set +`GTEST_SHARD_INDEX` to 0, 1, and 2 on the machines respectively. Then you would +run the same `foo_test` on each machine. + +GoogleTest reserves the right to change how the work is distributed across the +shards, but here's one possible scenario: + +* Machine #0 runs `A.V` and `B.X`. +* Machine #1 runs `A.W` and `B.Y`. +* Machine #2 runs `B.Z`. + +### Controlling Test Output + +#### Colored Terminal Output + +GoogleTest can use colors in its terminal output to make it easier to spot the +important information: + +
+```none
+...
+[----------] 1 test from FooTest
+[ RUN      ] FooTest.DoesAbc
+[       OK ] FooTest.DoesAbc
+[----------] 2 tests from BarTest
+[ RUN      ] BarTest.HasXyzProperty
+[       OK ] BarTest.HasXyzProperty
+[ RUN      ] BarTest.ReturnsTrueOnSuccess
+... some error messages ...
+[   FAILED ] BarTest.ReturnsTrueOnSuccess
+...
+[==========] 30 tests from 14 test suites ran.
+[   PASSED ] 28 tests.
+[   FAILED ] 2 tests, listed below:
+[   FAILED ] BarTest.ReturnsTrueOnSuccess
+[   FAILED ] AnotherTest.DoesXyz
+
+ 2 FAILED TESTS
+```
+
+You can set the `GTEST_COLOR` environment variable or the `--gtest_color`
+command line flag to `yes`, `no`, or `auto` (the default) to enable colors,
+disable colors, or let GoogleTest decide. When the value is `auto`, GoogleTest
+will use colors if and only if the output goes to a terminal and (on non-Windows
+platforms) the `TERM` environment variable is set to `xterm` or `xterm-color`.
+
+#### Suppressing test passes
+
+By default, GoogleTest prints 1 line of output for each test, indicating if it
+passed or failed. To show only test failures, run the test program with
+`--gtest_brief=1`, or set the `GTEST_BRIEF` environment variable to `1`.
+
+#### Suppressing the Elapsed Time
+
+By default, GoogleTest prints the time it takes to run each test. To disable
+that, run the test program with the `--gtest_print_time=0` command line flag, or
+set the `GTEST_PRINT_TIME` environment variable to `0`.
+
+#### Suppressing UTF-8 Text Output
+
+In case of assertion failures, GoogleTest prints expected and actual values of
+type `string` both as hex-encoded strings as well as in readable UTF-8 text if
+they contain valid non-ASCII UTF-8 characters. If you want to suppress the UTF-8
+text because, for example, you don't have a UTF-8 compatible output medium, run
+the test program with `--gtest_print_utf8=0` or set the `GTEST_PRINT_UTF8`
+environment variable to `0`.
+
+#### Generating an XML Report
+
+GoogleTest can emit a detailed XML report to a file in addition to its normal
+textual output. The report contains the duration of each test, and thus can help
+you identify slow tests.
+
+To generate the XML report, set the `GTEST_OUTPUT` environment variable or the
+`--gtest_output` flag to the string `"xml:path_to_output_file"`, which will
+create the file at the given location. You can also just use the string `"xml"`,
+in which case the output can be found in the `test_detail.xml` file in the
+current directory.
+
+If you specify a directory (for example, `"xml:output/directory/"` on Linux or
+`"xml:output\directory\"` on Windows), GoogleTest will create the XML file in
+that directory, named after the test executable (e.g. `foo_test.xml` for test
+program `foo_test` or `foo_test.exe`). If the file already exists (perhaps left
+over from a previous run), GoogleTest will pick a different name (e.g.
+`foo_test_1.xml`) to avoid overwriting it.
+
+The report is based on the `junitreport` Ant task. Since that format was
+originally intended for Java, a little interpretation is required to make it
+apply to GoogleTest tests, as shown here:
+
+```xml
+<testsuites name="AllTests" ...>
+  <testsuite name="test_case_name" ...>
+    <testcase name="test_name" ...>
+      <failure message="..."/>
+      <failure message="..."/>
+      <failure message="..."/>
+    </testcase>
+  </testsuite>
+</testsuites>
+```
+
+* The root `<testsuites>` element corresponds to the entire test program.
+* `<testsuite>` elements correspond to GoogleTest test suites.
+* `<testcase>` elements correspond to GoogleTest test functions.
+
+For instance, the following program
+
+```c++
+TEST(MathTest, Addition) { ... }
+TEST(MathTest, Subtraction) { ... }
+TEST(LogicTest, NonContradiction) { ... }
+```
+
+could generate this report:
+
+```xml
+<?xml version="1.0" encoding="UTF-8"?>
+<testsuites tests="3" failures="1" errors="0" time="0.035" timestamp="2011-10-31T18:52:42" name="AllTests">
+  <testsuite name="MathTest" tests="2" failures="1" errors="0" time="0.015">
+    <testcase name="Addition" file="test.cpp" line="1" status="run" time="0.007" classname="">
+      <failure message="Value of: add(1, 1)&#x0A;  Actual: 3&#x0A;Expected: 2" type=""/>
+      <failure message="Value of: add(1, -1)&#x0A;  Actual: 1&#x0A;Expected: 0" type=""/>
+    </testcase>
+    <testcase name="Subtraction" file="test.cpp" line="2" status="run" time="0.005" classname=""/>
+  </testsuite>
+  <testsuite name="LogicTest" tests="1" failures="0" errors="0" time="0.005">
+    <testcase name="NonContradiction" file="test.cpp" line="3" status="run" time="0.005" classname=""/>
+  </testsuite>
+</testsuites>
+```
+
+Things to note:
+
+* The `tests` attribute of a `<testsuites>` or `<testsuite>` element tells how
+  many test functions the GoogleTest program or test suite contains, while the
+  `failures` attribute tells how many of them failed.
+
+* The `time` attribute expresses the duration of the test, test suite, or
+  entire test program in seconds.
+
+* The `timestamp` attribute records the local date and time of the test
+  execution.
+
+* The `file` and `line` attributes record the source file location where the
+  test was defined.
+
+* Each `<failure>` element corresponds to a single failed GoogleTest
+  assertion.
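+
+If you would rather pick the report location in code than via the environment,
+the output flag can also be set programmatically before `InitGoogleTest()` is
+called. A minimal sketch, assuming a GoogleTest version that provides the
+`GTEST_FLAG_SET` macro (older releases assigned to
+`::testing::GTEST_FLAG(output)` instead); flags given on the command line still
+take precedence:
+
+```c++
+#include "gtest/gtest.h"
+
+int main(int argc, char** argv) {
+  // Equivalent to passing --gtest_output=xml:report.xml on the command line.
+  GTEST_FLAG_SET(output, "xml:report.xml");
+  testing::InitGoogleTest(&argc, argv);  // Command-line flags override this.
+  return RUN_ALL_TESTS();
+}
+```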
+
+#### Generating a JSON Report
+
+GoogleTest can also emit a JSON report as an alternative format to XML. To
+generate the JSON report, set the `GTEST_OUTPUT` environment variable or the
+`--gtest_output` flag to the string `"json:path_to_output_file"`, which will
+create the file at the given location. You can also just use the string
+`"json"`, in which case the output can be found in the `test_detail.json` file
+in the current directory.
+
+The report format conforms to the following JSON Schema:
+
+```json
+{
+  "$schema": "https://json-schema.org/schema#",
+  "type": "object",
+  "definitions": {
+    "TestCase": {
+      "type": "object",
+      "properties": {
+        "name": { "type": "string" },
+        "tests": { "type": "integer" },
+        "failures": { "type": "integer" },
+        "disabled": { "type": "integer" },
+        "time": { "type": "string" },
+        "testsuite": {
+          "type": "array",
+          "items": {
+            "$ref": "#/definitions/TestInfo"
+          }
+        }
+      }
+    },
+    "TestInfo": {
+      "type": "object",
+      "properties": {
+        "name": { "type": "string" },
+        "file": { "type": "string" },
+        "line": { "type": "integer" },
+        "status": {
+          "type": "string",
+          "enum": ["RUN", "NOTRUN"]
+        },
+        "time": { "type": "string" },
+        "classname": { "type": "string" },
+        "failures": {
+          "type": "array",
+          "items": {
+            "$ref": "#/definitions/Failure"
+          }
+        }
+      }
+    },
+    "Failure": {
+      "type": "object",
+      "properties": {
+        "failures": { "type": "string" },
+        "type": { "type": "string" }
+      }
+    }
+  },
+  "properties": {
+    "tests": { "type": "integer" },
+    "failures": { "type": "integer" },
+    "disabled": { "type": "integer" },
+    "errors": { "type": "integer" },
+    "timestamp": {
+      "type": "string",
+      "format": "date-time"
+    },
+    "time": { "type": "string" },
+    "name": { "type": "string" },
+    "testsuites": {
+      "type": "array",
+      "items": {
+        "$ref": "#/definitions/TestCase"
+      }
+    }
+  }
+}
+```
+
+Equivalently, the report format conforms to the following Proto3 definition,
+using the
+[JSON encoding](https://developers.google.com/protocol-buffers/docs/proto3#json):
+
+```proto
+syntax = "proto3";
+
+package googletest;
+
+import "google/protobuf/timestamp.proto";
+import "google/protobuf/duration.proto";
+
+message UnitTest {
+  int32 tests = 1;
+  int32 failures = 2;
+  int32 disabled = 3;
+  int32 errors = 4;
+  google.protobuf.Timestamp timestamp = 5;
+  google.protobuf.Duration time = 6;
+  string name = 7;
+  repeated TestCase testsuites = 8;
+}
+
+message TestCase {
+  string name = 1;
+  int32 tests = 2;
+  int32 failures = 3;
+  int32 disabled = 4;
+  int32 errors = 5;
+  google.protobuf.Duration time = 6;
+  repeated TestInfo testsuite = 7;
+}
+
+message TestInfo {
+  string name = 1;
+  string file = 6;
+  int32 line = 7;
+  enum Status {
+    RUN = 0;
+    NOTRUN = 1;
+  }
+  Status status = 2;
+  google.protobuf.Duration time = 3;
+  string classname = 4;
+  message Failure {
+    string failures = 1;
+    string type = 2;
+  }
+  repeated Failure failures = 5;
+}
+```
+
+For instance, the following program
+
+```c++
+TEST(MathTest, Addition) { ... }
+TEST(MathTest, Subtraction) { ... }
+TEST(LogicTest, NonContradiction) { ...
} +``` + +could generate this report: + +```json +{ + "tests": 3, + "failures": 1, + "errors": 0, + "time": "0.035s", + "timestamp": "2011-10-31T18:52:42Z", + "name": "AllTests", + "testsuites": [ + { + "name": "MathTest", + "tests": 2, + "failures": 1, + "errors": 0, + "time": "0.015s", + "testsuite": [ + { + "name": "Addition", + "file": "test.cpp", + "line": 1, + "status": "RUN", + "time": "0.007s", + "classname": "", + "failures": [ + { + "message": "Value of: add(1, 1)\n Actual: 3\nExpected: 2", + "type": "" + }, + { + "message": "Value of: add(1, -1)\n Actual: 1\nExpected: 0", + "type": "" + } + ] + }, + { + "name": "Subtraction", + "file": "test.cpp", + "line": 2, + "status": "RUN", + "time": "0.005s", + "classname": "" + } + ] + }, + { + "name": "LogicTest", + "tests": 1, + "failures": 0, + "errors": 0, + "time": "0.005s", + "testsuite": [ + { + "name": "NonContradiction", + "file": "test.cpp", + "line": 3, + "status": "RUN", + "time": "0.005s", + "classname": "" + } + ] + } + ] +} +``` + +{: .callout .important} +IMPORTANT: The exact format of the JSON document is subject to change. + +### Controlling How Failures Are Reported + +#### Detecting Test Premature Exit + +Google Test implements the _premature-exit-file_ protocol for test runners to +catch any kind of unexpected exits of test programs. Upon start, Google Test +creates the file which will be automatically deleted after all work has been +finished. Then, the test runner can check if this file exists. In case the file +remains undeleted, the inspected test has exited prematurely. + +This feature is enabled only if the `TEST_PREMATURE_EXIT_FILE` environment +variable has been set. + +#### Turning Assertion Failures into Break-Points + +When running test programs under a debugger, it's very convenient if the +debugger can catch an assertion failure and automatically drop into interactive +mode. GoogleTest's *break-on-failure* mode supports this behavior. + +To enable it, set the `GTEST_BREAK_ON_FAILURE` environment variable to a value +other than `0`. Alternatively, you can use the `--gtest_break_on_failure` +command line flag. + +#### Disabling Catching Test-Thrown Exceptions + +GoogleTest can be used either with or without exceptions enabled. If a test +throws a C++ exception or (on Windows) a structured exception (SEH), by default +GoogleTest catches it, reports it as a test failure, and continues with the next +test method. This maximizes the coverage of a test run. Also, on Windows an +uncaught exception will cause a pop-up window, so catching the exceptions allows +you to run the tests automatically. + +When debugging the test failures, however, you may instead want the exceptions +to be handled by the debugger, such that you can examine the call stack when an +exception is thrown. To achieve that, set the `GTEST_CATCH_EXCEPTIONS` +environment variable to `0`, or use the `--gtest_catch_exceptions=0` flag when +running the tests. + +### Sanitizer Integration + +The +[Undefined Behavior Sanitizer](https://clang.llvm.org/docs/UndefinedBehaviorSanitizer.html), +[Address Sanitizer](https://github.com/google/sanitizers/wiki/AddressSanitizer), +and +[Thread Sanitizer](https://github.com/google/sanitizers/wiki/ThreadSanitizerCppManual) +all provide weak functions that you can override to trigger explicit failures +when they detect sanitizer errors, such as creating a reference from `nullptr`. 
+To override these functions, place definitions for them in a source file that +you compile as part of your main binary: + +``` +extern "C" { +void __ubsan_on_report() { + FAIL() << "Encountered an undefined behavior sanitizer error"; +} +void __asan_on_error() { + FAIL() << "Encountered an address sanitizer error"; +} +void __tsan_on_report() { + FAIL() << "Encountered a thread sanitizer error"; +} +} // extern "C" +``` + +After compiling your project with one of the sanitizers enabled, if a particular +test triggers a sanitizer error, GoogleTest will report that it failed. diff --git a/test/googletest/docs/assets/css/style.scss b/test/googletest/docs/assets/css/style.scss new file mode 100644 index 0000000..bb30f41 --- /dev/null +++ b/test/googletest/docs/assets/css/style.scss @@ -0,0 +1,5 @@ +--- +--- + +@import "jekyll-theme-primer"; +@import "main"; diff --git a/test/googletest/docs/community_created_documentation.md b/test/googletest/docs/community_created_documentation.md new file mode 100644 index 0000000..4569075 --- /dev/null +++ b/test/googletest/docs/community_created_documentation.md @@ -0,0 +1,7 @@ +# Community-Created Documentation + +The following is a list, in no particular order, of links to documentation +created by the Googletest community. + +* [Googlemock Insights](https://github.com/ElectricRCAircraftGuy/eRCaGuy_dotfiles/blob/master/googletest/insights.md), + by [ElectricRCAircraftGuy](https://github.com/ElectricRCAircraftGuy) diff --git a/test/googletest/docs/faq.md b/test/googletest/docs/faq.md new file mode 100644 index 0000000..c7d10b5 --- /dev/null +++ b/test/googletest/docs/faq.md @@ -0,0 +1,671 @@ +# GoogleTest FAQ + +## Why should test suite names and test names not contain underscore? + +{: .callout .note} +Note: GoogleTest reserves underscore (`_`) for special-purpose keywords, such as +[the `DISABLED_` prefix](advanced.md#temporarily-disabling-tests), in addition +to the following rationale. + +Underscore (`_`) is special, as C++ reserves the following to be used by the +compiler and the standard library: + +1. any identifier that starts with an `_` followed by an upper-case letter, and +2. any identifier that contains two consecutive underscores (i.e. `__`) + *anywhere* in its name. + +User code is *prohibited* from using such identifiers. + +Now let's look at what this means for `TEST` and `TEST_F`. + +Currently `TEST(TestSuiteName, TestName)` generates a class named +`TestSuiteName_TestName_Test`. What happens if `TestSuiteName` or `TestName` +contains `_`? + +1. If `TestSuiteName` starts with an `_` followed by an upper-case letter (say, + `_Foo`), we end up with `_Foo_TestName_Test`, which is reserved and thus + invalid. +2. If `TestSuiteName` ends with an `_` (say, `Foo_`), we get + `Foo__TestName_Test`, which is invalid. +3. If `TestName` starts with an `_` (say, `_Bar`), we get + `TestSuiteName__Bar_Test`, which is invalid. +4. If `TestName` ends with an `_` (say, `Bar_`), we get + `TestSuiteName_Bar__Test`, which is invalid. + +So clearly `TestSuiteName` and `TestName` cannot start or end with `_` +(Actually, `TestSuiteName` can start with `_`—as long as the `_` isn't followed +by an upper-case letter. But that's getting complicated. So for simplicity we +just say that it cannot start with `_`.). + +It may seem fine for `TestSuiteName` and `TestName` to contain `_` in the +middle. However, consider this: + +```c++ +TEST(Time, Flies_Like_An_Arrow) { ... } +TEST(Time_Flies, Like_An_Arrow) { ... 
}
+```
+
+Now, the two `TEST`s will both generate the same class
+(`Time_Flies_Like_An_Arrow_Test`). That's not good.
+
+So for simplicity, we just ask the users to avoid `_` in `TestSuiteName` and
+`TestName`. The rule is more constraining than necessary, but it's simple and
+easy to remember. It also gives GoogleTest some wiggle room in case its
+implementation needs to change in the future.
+
+If you violate the rule, there may not be immediate consequences, but your test
+may (just may) break with a new compiler (or a new version of the compiler you
+are using) or with a new version of GoogleTest. Therefore it's best to follow
+the rule.
+
+## Why does GoogleTest support `EXPECT_EQ(NULL, ptr)` and `ASSERT_EQ(NULL, ptr)` but not `EXPECT_NE(NULL, ptr)` and `ASSERT_NE(NULL, ptr)`?
+
+First of all, you can use `nullptr` with each of these macros, e.g.
+`EXPECT_EQ(ptr, nullptr)`, `EXPECT_NE(ptr, nullptr)`, `ASSERT_EQ(ptr, nullptr)`,
+`ASSERT_NE(ptr, nullptr)`. This is the preferred syntax in the style guide
+because `nullptr` does not have the type problems that `NULL` does.
+
+Due to some peculiarity of C++, it requires some non-trivial template
+metaprogramming tricks to support using `NULL` as an argument of the
+`EXPECT_XX()` and `ASSERT_XX()` macros. Therefore we only do it where it's most
+needed (otherwise we make the implementation of GoogleTest harder to maintain
+and more error-prone than necessary).
+
+Historically, the `EXPECT_EQ()` macro took the *expected* value as its first
+argument and the *actual* value as the second, though this argument order is now
+discouraged. It was reasonable that someone wanted
+to write `EXPECT_EQ(NULL, some_expression)`, and this indeed was requested
+several times. Therefore we implemented it.
+
+The need for `EXPECT_NE(NULL, ptr)` wasn't nearly as strong. When the assertion
+fails, you already know that `ptr` must be `NULL`, so it doesn't add any
+information to print `ptr` in this case. That means `EXPECT_TRUE(ptr != NULL)`
+works just as well.
+
+If we were to support `EXPECT_NE(NULL, ptr)`, for consistency we'd have to
+support `EXPECT_NE(ptr, NULL)` as well. This means using the template
+metaprogramming tricks twice in the implementation, making it even harder to
+understand and maintain. We believe the benefit doesn't justify the cost.
+
+Finally, with the growth of the gMock matcher library, we are encouraging people
+to use the unified `EXPECT_THAT(value, matcher)` syntax more often in tests. One
+significant advantage of the matcher approach is that matchers can be easily
+combined to form new matchers, while the `EXPECT_NE`, etc., macros cannot be
+easily combined. Therefore we want to invest more in the matchers than in the
+`EXPECT_XX()` macros.
+
+## I need to test that different implementations of an interface satisfy some common requirements. Should I use typed tests or value-parameterized tests?
+
+For testing various implementations of the same interface, either typed tests or
+value-parameterized tests can get it done. It's really up to you, the user, to
+decide which is more convenient for you, depending on your particular case. Some
+rough guidelines:
+
+* Typed tests can be easier to write if instances of the different
+  implementations can be created the same way, modulo the type. For example,
+  if all these implementations have a public default constructor (such that
+  you can write `new TypeParam`), or if their factory functions have the same
+  form (e.g. `CreateInstance()`).
+* Value-parameterized tests can be easier to write if you need different code
+  patterns to create different implementations' instances, e.g. `new Foo` vs
+  `new Bar(5)`. To accommodate the differences, you can write factory
+  function wrappers and pass these function pointers to the tests as their
+  parameters (see the sketch after this list).
+* When a typed test fails, the default output includes the name of the type,
+  which can help you quickly identify which implementation is wrong.
+  Value-parameterized tests only show the number of the failed iteration by
+  default. You will need to define a function that returns the iteration name
+  and pass it as the third parameter to `INSTANTIATE_TEST_SUITE_P` to have more
+  useful output.
+* When using typed tests, you need to make sure you are testing against the
+  interface type, not the concrete types (in other words, you want to make
+  sure `implicit_cast<MyInterface*>(my_concrete_impl)` works, not just that
+  `my_concrete_impl` works). It's less likely to make mistakes in this area
+  when using value-parameterized tests.
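+
+To make the factory-function approach in the second bullet concrete, here is a
+minimal sketch. `Queue`, `VectorQueue`, and `ListQueue` are hypothetical
+stand-ins for your interface and its implementations:
+
+```c++
+#include <cstddef>
+#include <memory>
+
+#include "gtest/gtest.h"
+
+// Hypothetical interface and two implementations under test.
+class Queue {
+ public:
+  virtual ~Queue() = default;
+  virtual std::size_t Size() const = 0;
+};
+class VectorQueue : public Queue {
+ public:
+  std::size_t Size() const override { return 0; }
+};
+class ListQueue : public Queue {
+ public:
+  std::size_t Size() const override { return 0; }
+};
+
+// Factory function wrappers with a common signature.
+Queue* MakeVectorQueue() { return new VectorQueue; }
+Queue* MakeListQueue() { return new ListQueue; }
+
+using QueueFactory = Queue* (*)();
+
+class QueueImplTest : public testing::TestWithParam<QueueFactory> {};
+
+// The test body only talks to the Queue interface.
+TEST_P(QueueImplTest, IsEmptyInitially) {
+  std::unique_ptr<Queue> queue(GetParam()());
+  EXPECT_EQ(queue->Size(), 0u);
+}
+
+INSTANTIATE_TEST_SUITE_P(AllImplementations, QueueImplTest,
+                         testing::Values(&MakeVectorQueue, &MakeListQueue));
+```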
+
+I hope I didn't confuse you more. :-) If you don't mind, I'd suggest you give
+both approaches a try. Practice is a much better way to grasp the subtle
+differences between the two tools. Once you have some concrete experience, you
+can much more easily decide which one to use the next time.
+
+## My death test modifies some state, but the change seems lost after the death test finishes. Why?
+
+Death tests (`EXPECT_DEATH`, etc.) are executed in a sub-process so that the
+expected crash won't kill the test program (i.e. the parent process). As a
+result, any in-memory side effects they incur are observable in their respective
+sub-processes, but not in the parent process. You can think of them as running
+in a parallel universe, more or less.
+
+In particular, if you use mocking and the death test statement invokes some mock
+methods, the parent process will think the calls have never occurred. Therefore,
+you may want to move your `EXPECT_CALL` statements inside the `EXPECT_DEATH`
+macro.
+
+## EXPECT_EQ(htonl(blah), blah_blah) generates weird compiler errors in opt mode. Is this a GoogleTest bug?
+
+Actually, the bug is in `htonl()`.
+
+According to `'man htonl'`, `htonl()` is a *function*, which means it's valid to
+use `htonl` as a function pointer. However, in opt mode `htonl()` is defined as
+a *macro*, which breaks this usage.
+
+Worse, the macro definition of `htonl()` uses a `gcc` extension and is *not*
+standard C++. That hacky implementation has some ad hoc limitations. In
+particular, it prevents you from writing `Foo<sizeof(i)>()`, where `Foo`
+is a template that has an integral argument.
+
+The implementation of `EXPECT_EQ(a, b)` uses `sizeof(... a ...)` inside a
+template argument, and thus doesn't compile in opt mode when `a` contains a call
+to `htonl()`. It is difficult to make `EXPECT_EQ` bypass the `htonl()` bug, as
+the solution must work with different compilers on various platforms.
+
+## The compiler complains about "undefined references" to some static const member variables, but I did define them in the class body. What's wrong?
+
+If your class has a static data member:
+
+```c++
+// foo.h
+class Foo {
+  ...
+  static const int kBar = 100;
+};
+```
+
+you also need to define it *outside* of the class body in `foo.cc`:
+
+```c++
+const int Foo::kBar;  // No initializer here.
+```
+
+Otherwise your code is **invalid C++**, and may break in unexpected ways. In
+particular, using it in GoogleTest comparison assertions (`EXPECT_EQ`, etc.)
+will generate an "undefined reference" linker error. The fact that "it used to +work" doesn't mean it's valid. It just means that you were lucky. :-) + +If the declaration of the static data member is `constexpr` then it is +implicitly an `inline` definition, and a separate definition in `foo.cc` is not +needed: + +```c++ +// foo.h +class Foo { + ... + static constexpr int kBar = 100; // Defines kBar, no need to do it in foo.cc. +}; +``` + +## Can I derive a test fixture from another? + +Yes. + +Each test fixture has a corresponding and same named test suite. This means only +one test suite can use a particular fixture. Sometimes, however, multiple test +cases may want to use the same or slightly different fixtures. For example, you +may want to make sure that all of a GUI library's test suites don't leak +important system resources like fonts and brushes. + +In GoogleTest, you share a fixture among test suites by putting the shared logic +in a base test fixture, then deriving from that base a separate fixture for each +test suite that wants to use this common logic. You then use `TEST_F()` to write +tests using each derived fixture. + +Typically, your code looks like this: + +```c++ +// Defines a base test fixture. +class BaseTest : public ::testing::Test { + protected: + ... +}; + +// Derives a fixture FooTest from BaseTest. +class FooTest : public BaseTest { + protected: + void SetUp() override { + BaseTest::SetUp(); // Sets up the base fixture first. + ... additional set-up work ... + } + + void TearDown() override { + ... clean-up work for FooTest ... + BaseTest::TearDown(); // Remember to tear down the base fixture + // after cleaning up FooTest! + } + + ... functions and variables for FooTest ... +}; + +// Tests that use the fixture FooTest. +TEST_F(FooTest, Bar) { ... } +TEST_F(FooTest, Baz) { ... } + +... additional fixtures derived from BaseTest ... +``` + +If necessary, you can continue to derive test fixtures from a derived fixture. +GoogleTest has no limit on how deep the hierarchy can be. + +For a complete example using derived test fixtures, see +[sample5_unittest.cc](https://github.com/google/googletest/blob/main/googletest/samples/sample5_unittest.cc). + +## My compiler complains "void value not ignored as it ought to be." What does this mean? + +You're probably using an `ASSERT_*()` in a function that doesn't return `void`. +`ASSERT_*()` can only be used in `void` functions, due to exceptions being +disabled by our build system. Please see more details +[here](advanced.md#assertion-placement). + +## My death test hangs (or seg-faults). How do I fix it? + +In GoogleTest, death tests are run in a child process and the way they work is +delicate. To write death tests you really need to understand how they work—see +the details at [Death Assertions](reference/assertions.md#death) in the +Assertions Reference. + +In particular, death tests don't like having multiple threads in the parent +process. So the first thing you can try is to eliminate creating threads outside +of `EXPECT_DEATH()`. For example, you may want to use mocks or fake objects +instead of real ones in your tests. + +Sometimes this is impossible as some library you must use may be creating +threads before `main()` is even reached. In this case, you can try to minimize +the chance of conflicts by either moving as many activities as possible inside +`EXPECT_DEATH()` (in the extreme case, you want to move everything inside), or +leaving as few things as possible in it. 
Also, you can try to set the death test +style to `"threadsafe"`, which is safer but slower, and see if it helps. + +If you go with thread-safe death tests, remember that they rerun the test +program from the beginning in the child process. Therefore make sure your +program can run side-by-side with itself and is deterministic. + +In the end, this boils down to good concurrent programming. You have to make +sure that there are no race conditions or deadlocks in your program. No silver +bullet - sorry! + +## Should I use the constructor/destructor of the test fixture or SetUp()/TearDown()? {#CtorVsSetUp} + +The first thing to remember is that GoogleTest does **not** reuse the same test +fixture object across multiple tests. For each `TEST_F`, GoogleTest will create +a **fresh** test fixture object, immediately call `SetUp()`, run the test body, +call `TearDown()`, and then delete the test fixture object. + +When you need to write per-test set-up and tear-down logic, you have the choice +between using the test fixture constructor/destructor or `SetUp()`/`TearDown()`. +The former is usually preferred, as it has the following benefits: + +* By initializing a member variable in the constructor, we have the option to + make it `const`, which helps prevent accidental changes to its value and + makes the tests more obviously correct. +* In case we need to subclass the test fixture class, the subclass' + constructor is guaranteed to call the base class' constructor *first*, and + the subclass' destructor is guaranteed to call the base class' destructor + *afterward*. With `SetUp()/TearDown()`, a subclass may make the mistake of + forgetting to call the base class' `SetUp()/TearDown()` or call them at the + wrong time. + +You may still want to use `SetUp()/TearDown()` in the following cases: + +* C++ does not allow virtual function calls in constructors and destructors. + You can call a method declared as virtual, but it will not use dynamic + dispatch. It will use the definition from the class the constructor of which + is currently executing. This is because calling a virtual method before the + derived class constructor has a chance to run is very dangerous - the + virtual method might operate on uninitialized data. Therefore, if you need + to call a method that will be overridden in a derived class, you have to use + `SetUp()/TearDown()`. +* In the body of a constructor (or destructor), it's not possible to use the + `ASSERT_xx` macros. Therefore, if the set-up operation could cause a fatal + test failure that should prevent the test from running, it's necessary to + use `abort` and abort the whole test + executable, or to use `SetUp()` instead of a constructor. +* If the tear-down operation could throw an exception, you must use + `TearDown()` as opposed to the destructor, as throwing in a destructor leads + to undefined behavior and usually will kill your program right away. Note + that many standard libraries (like STL) may throw when exceptions are + enabled in the compiler. Therefore you should prefer `TearDown()` if you + want to write portable tests that work with or without exceptions. +* The GoogleTest team is considering making the assertion macros throw on + platforms where exceptions are enabled (e.g. Windows, Mac OS, and Linux + client-side), which will eliminate the need for the user to propagate + failures from a subroutine to its caller. Therefore, you shouldn't use + GoogleTest assertions in a destructor if your code could run on such a + platform. 
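+
+To make the trade-offs above concrete, here is a minimal sketch showing both
+styles side by side. `Database` and its `Connect()` method are hypothetical
+stand-ins for a dependency whose set-up can fail:
+
+```c++
+class Database {  // hypothetical dependency
+ public:
+  bool Connect() { return true; }
+};
+
+// Constructor style: members can be initialized in the initializer list,
+// which allows them to be const.
+class FooTest : public testing::Test {
+ protected:
+  FooTest() : answer_(42) {}
+  const int answer_;
+};
+
+// SetUp() style: needed when set-up can fail fatally, because ASSERT_*
+// macros may be used in SetUp() but not in a constructor.
+class BarTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    ASSERT_TRUE(db_.Connect()) << "Cannot run these tests without a database.";
+  }
+  Database db_;
+};
+```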
+
+## The compiler complains "no matching function to call" when I use `ASSERT_PRED*`. How do I fix it?
+
+See details for [`EXPECT_PRED*`](reference/assertions.md#EXPECT_PRED) in the
+Assertions Reference.
+
+## My compiler complains about "ignoring return value" when I call RUN_ALL_TESTS(). Why?
+
+Some people had been ignoring the return value of `RUN_ALL_TESTS()`. That is,
+instead of
+
+```c++
+  return RUN_ALL_TESTS();
+```
+
+they write
+
+```c++
+  RUN_ALL_TESTS();
+```
+
+This is **wrong and dangerous**. The testing service needs to see the return
+value of `RUN_ALL_TESTS()` in order to determine if a test has passed. If your
+`main()` function ignores it, your test will be considered successful even if it
+has a GoogleTest assertion failure. Very bad.
+
+We have decided to fix this (thanks to Michael Chastain for the idea). Now, your
+code will no longer be able to ignore `RUN_ALL_TESTS()` when compiled with
+`gcc`. If you do so, you'll get a compiler error.
+
+If you see the compiler complaining about you ignoring the return value of
+`RUN_ALL_TESTS()`, the fix is simple: just make sure its value is used as the
+return value of `main()`.
+
+But how could we introduce a change that breaks existing tests? Well, in this
+case, the code was already broken in the first place, so we didn't break it. :-)
+
+## My compiler complains that a constructor (or destructor) cannot return a value. What's going on?
+
+Due to a peculiarity of C++, in order to support the syntax for streaming
+messages to an `ASSERT_*`, e.g.
+
+```c++
+  ASSERT_EQ(1, Foo()) << "blah blah" << foo;
+```
+
+we had to give up using `ASSERT*` and `FAIL*` (but not `EXPECT*` and
+`ADD_FAILURE*`) in constructors and destructors. The workaround is to move the
+content of your constructor/destructor to a private void member function, or
+switch to `EXPECT_*()` if that works. This
+[section](advanced.md#assertion-placement) in the user's guide explains it.
+
+## My SetUp() function is not called. Why?
+
+C++ is case-sensitive. Did you spell it as `Setup()`?
+
+Similarly, sometimes people spell `SetUpTestSuite()` as `SetupTestSuite()` and
+wonder why it's never called.
+
+## I have several test suites which share the same test fixture logic; do I have to define a new test fixture class for each of them? This seems pretty tedious.
+
+You don't have to. Instead of
+
+```c++
+class FooTest : public BaseTest {};
+
+TEST_F(FooTest, Abc) { ... }
+TEST_F(FooTest, Def) { ... }
+
+class BarTest : public BaseTest {};
+
+TEST_F(BarTest, Abc) { ... }
+TEST_F(BarTest, Def) { ... }
+```
+
+you can simply `typedef` the test fixtures:
+
+```c++
+typedef BaseTest FooTest;
+
+TEST_F(FooTest, Abc) { ... }
+TEST_F(FooTest, Def) { ... }
+
+typedef BaseTest BarTest;
+
+TEST_F(BarTest, Abc) { ... }
+TEST_F(BarTest, Def) { ... }
+```
+
+## GoogleTest output is buried in a whole bunch of LOG messages. What do I do?
+
+The GoogleTest output is meant to be a concise and human-friendly report. If
+your test generates textual output itself, it will mix with the GoogleTest
+output, making it hard to read. However, there is an easy solution to this
+problem.
+
+Since `LOG` messages go to stderr, we decided to let GoogleTest output go to
+stdout. This way, you can easily separate the two using redirection. For
+example:
+
+```shell
+$ ./my_test > gtest_output.txt
+```
+
+## Why should I prefer test fixtures over global variables?
+
+There are several good reasons:
+
+1. It's likely your test needs to change the states of its global variables.
+   This makes it difficult to keep side effects from escaping one test and
+   contaminating others, making debugging difficult. By using fixtures, each
+   test has a fresh set of variables that's different (but with the same
+   names). Thus, tests are kept independent of each other.
+2. Global variables pollute the global namespace.
+3. Test fixtures can be reused via subclassing, which cannot be done easily
+   with global variables. This is useful if many test suites have something in
+   common.
+
+## What can the statement argument in ASSERT_DEATH() be?
+
+`ASSERT_DEATH(statement, matcher)` (or any death assertion macro) can be used
+wherever *`statement`* is valid. So basically *`statement`* can be any C++
+statement that makes sense in the current context. In particular, it can
+reference global and/or local variables, and can be:
+
+* a simple function call (often the case),
+* a complex expression, or
+* a compound statement.
+
+Some examples are shown here:
+
+```c++
+// A death test can be a simple function call.
+TEST(MyDeathTest, FunctionCall) {
+  ASSERT_DEATH(Xyz(5), "Xyz failed");
+}
+
+// Or a complex expression that references variables and functions.
+TEST(MyDeathTest, ComplexExpression) {
+  const bool c = Condition();
+  ASSERT_DEATH((c ? Func1(0) : object2.Method("test")),
+               "(Func1|Method) failed");
+}
+
+// Death assertions can be used anywhere in a function. In
+// particular, they can be inside a loop.
+TEST(MyDeathTest, InsideLoop) {
+  // Verifies that Foo(0), Foo(1), ..., and Foo(4) all die.
+  for (int i = 0; i < 5; i++) {
+    EXPECT_DEATH_M(Foo(i), "Foo has \\d+ errors",
+                   ::testing::Message() << "where i is " << i);
+  }
+}
+
+// A death assertion can contain a compound statement.
+TEST(MyDeathTest, CompoundStatement) {
+  // Verifies that at least one of Bar(0), Bar(1), ..., and
+  // Bar(4) dies.
+  ASSERT_DEATH({
+    for (int i = 0; i < 5; i++) {
+      Bar(i);
+    }
+  },
+  "Bar has \\d+ errors");
+}
+```
+
+## I have a fixture class `FooTest`, but `TEST_F(FooTest, Bar)` gives me error ``"no matching function for call to `FooTest::FooTest()'"``. Why?
+
+GoogleTest needs to be able to create objects of your test fixture class, so it
+must have a default constructor. Normally the compiler will define one for you.
+However, there are cases where you have to define your own:
+
+* If you explicitly declare a non-default constructor for class `FooTest`
+  (`DISALLOW_EVIL_CONSTRUCTORS()` does this), then you need to define a
+  default constructor, even if it would be empty.
+* If `FooTest` has a const non-static data member, then you have to define the
+  default constructor *and* initialize the const member in the initializer
+  list of the constructor. (Early versions of `gcc` don't force you to
+  initialize the const member. It's a bug that has been fixed in `gcc 4`.)
+
+## Why does ASSERT_DEATH complain about previous threads that were already joined?
+
+With the Linux pthread library, there is no turning back once you cross the line
+from a single thread to multiple threads. The first time you create a thread, a
+manager thread is created in addition, so you get 3, not 2, threads. Later when
+the thread you create joins the main thread, the thread count decrements by 1,
+but the manager thread will never be killed, so you still have 2 threads, which
+means you cannot safely run a death test.
+
+The new NPTL thread library doesn't suffer from this problem, as it doesn't
+create a manager thread. However, if you don't control which machine your test
+runs on, you shouldn't depend on this.
+
+## Why does GoogleTest require the entire test suite, instead of individual tests, to be named `*DeathTest` when it uses `ASSERT_DEATH`?
+
+GoogleTest does not interleave tests from different test suites. That is, it
+runs all tests in one test suite first, and then runs all tests in the next test
+suite, and so on. GoogleTest does this because it needs to set up a test suite
+before the first test in it is run, and tear it down afterwards. Splitting up
+the test suite would require multiple set-up and tear-down processes, which is
+inefficient and makes the semantics unclean.
+
+If we were to determine the order of tests based on test name instead of test
+suite name, then we would have a problem with the following situation:
+
+```c++
+TEST_F(FooTest, AbcDeathTest) { ... }
+TEST_F(FooTest, Uvw) { ... }
+
+TEST_F(BarTest, DefDeathTest) { ... }
+TEST_F(BarTest, Xyz) { ... }
+```
+
+Since `FooTest.AbcDeathTest` needs to run before `BarTest.Xyz`, and we don't
+interleave tests from different test suites, we need to run all tests in the
+`FooTest` suite before running any test in the `BarTest` suite. This contradicts
+the requirement to run `BarTest.DefDeathTest` before `FooTest.Uvw`.
+
+## But I don't like calling my entire test suite `*DeathTest` when it contains both death tests and non-death tests. What do I do?
+
+You don't have to, but if you like, you may split up the test suite into
+`FooTest` and `FooDeathTest`, where the names make it clear that they are
+related:
+
+```c++
+class FooTest : public ::testing::Test { ... };
+
+TEST_F(FooTest, Abc) { ... }
+TEST_F(FooTest, Def) { ... }
+
+using FooDeathTest = FooTest;
+
+TEST_F(FooDeathTest, Uvw) { ... EXPECT_DEATH(...) ... }
+TEST_F(FooDeathTest, Xyz) { ... ASSERT_DEATH(...) ... }
+```
+
+## GoogleTest prints the LOG messages in a death test's child process only when the test fails. How can I see the LOG messages when the death test succeeds?
+
+Printing the LOG messages generated by the statement inside `EXPECT_DEATH()`
+makes it harder to search for real problems in the parent's log. Therefore,
+GoogleTest only prints them when the death test has failed.
+
+If you really need to see such LOG messages, a workaround is to temporarily
+break the death test (e.g. by changing the regex pattern it is expected to
+match). Admittedly, this is a hack. We'll consider a more permanent solution
+after the fork-and-exec-style death tests are implemented.
+
+## The compiler complains about `no match for 'operator<<'` when I use an assertion. What gives?
+
+If you use a user-defined type `FooType` in an assertion, you must make sure
+there is an `std::ostream& operator<<(std::ostream&, const FooType&)` function
+defined such that we can print a value of `FooType`.
+
+In addition, if `FooType` is declared in a namespace, the `<<` operator also
+needs to be defined in the *same* namespace. See
+[Tip of the Week #49](https://abseil.io/tips/49) for details.
+
+## How do I suppress the memory leak messages on Windows?
+
+Since the statically initialized GoogleTest singleton requires allocations on
+the heap, the Visual C++ memory leak detector will report memory leaks at the
+end of the program run. The easiest way to avoid this is to use the
+`_CrtMemCheckpoint` and `_CrtMemDumpAllObjectsSince` calls to not report any
+statically initialized heap objects. See MSDN for more details and additional
+heap check/debug routines.
+
+## How can my code detect if it is running in a test?
+ +If you write code that sniffs whether it's running in a test and does different +things accordingly, you are leaking test-only logic into production code and +there is no easy way to ensure that the test-only code paths aren't run by +mistake in production. Such cleverness also leads to +[Heisenbugs](https://en.wikipedia.org/wiki/Heisenbug). Therefore we strongly +advise against the practice, and GoogleTest doesn't provide a way to do it. + +In general, the recommended way to cause the code to behave differently under +test is [Dependency Injection](https://en.wikipedia.org/wiki/Dependency_injection). You can inject +different functionality from the test and from the production code. Since your +production code doesn't link in the for-test logic at all (the +[`testonly`](https://docs.bazel.build/versions/master/be/common-definitions.html#common.testonly) attribute for BUILD targets helps to ensure +that), there is no danger in accidentally running it. + +However, if you *really*, *really*, *really* have no choice, and if you follow +the rule of ending your test program names with `_test`, you can use the +*horrible* hack of sniffing your executable name (`argv[0]` in `main()`) to know +whether the code is under test. + +## How do I temporarily disable a test? + +If you have a broken test that you cannot fix right away, you can add the +`DISABLED_` prefix to its name. This will exclude it from execution. This is +better than commenting out the code or using `#if 0`, as disabled tests are +still compiled (and thus won't rot). + +To include disabled tests in test execution, just invoke the test program with +the `--gtest_also_run_disabled_tests` flag. + +## Is it OK if I have two separate `TEST(Foo, Bar)` test methods defined in different namespaces? + +Yes. + +The rule is **all test methods in the same test suite must use the same fixture +class**. This means that the following is **allowed** because both tests use the +same fixture class (`::testing::Test`). + +```c++ +namespace foo { +TEST(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace foo + +namespace bar { +TEST(CoolTest, DoSomething) { + SUCCEED(); +} +} // namespace bar +``` + +However, the following code is **not allowed** and will produce a runtime error +from GoogleTest because the test methods are using different test fixture +classes with the same test suite name. 
+
+```c++
+namespace foo {
+class CoolTest : public ::testing::Test {};  // Fixture: foo::CoolTest
+TEST_F(CoolTest, DoSomething) {
+  SUCCEED();
+}
+}  // namespace foo
+
+namespace bar {
+class CoolTest : public ::testing::Test {};  // Fixture: bar::CoolTest
+TEST_F(CoolTest, DoSomething) {
+  SUCCEED();
+}
+}  // namespace bar
+```
diff --git a/test/googletest/docs/gmock_cheat_sheet.md b/test/googletest/docs/gmock_cheat_sheet.md
new file mode 100644
index 0000000..ddafaaa
--- /dev/null
+++ b/test/googletest/docs/gmock_cheat_sheet.md
@@ -0,0 +1,241 @@
+# gMock Cheat Sheet
+
+## Defining a Mock Class
+
+### Mocking a Normal Class {#MockClass}
+
+Given
+
+```cpp
+class Foo {
+ public:
+  virtual ~Foo();
+  virtual int GetSize() const = 0;
+  virtual string Describe(const char* name) = 0;
+  virtual string Describe(int type) = 0;
+  virtual bool Process(Bar elem, int count) = 0;
+};
+```
+
+(note that `~Foo()` **must** be virtual) we can define its mock as
+
+```cpp
+#include <gmock/gmock.h>
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(int, GetSize, (), (const, override));
+  MOCK_METHOD(string, Describe, (const char* name), (override));
+  MOCK_METHOD(string, Describe, (int type), (override));
+  MOCK_METHOD(bool, Process, (Bar elem, int count), (override));
+};
+```
+
+To create a "nice" mock, which ignores all uninteresting calls, a "naggy" mock,
+which warns on all uninteresting calls, or a "strict" mock, which treats them as
+failures:
+
+```cpp
+using ::testing::NiceMock;
+using ::testing::NaggyMock;
+using ::testing::StrictMock;
+
+NiceMock<MockFoo> nice_foo;      // The type is a subclass of MockFoo.
+NaggyMock<MockFoo> naggy_foo;    // The type is a subclass of MockFoo.
+StrictMock<MockFoo> strict_foo;  // The type is a subclass of MockFoo.
+```
+
+{: .callout .note}
+**Note:** A mock object is currently naggy by default. We may make it nice by
+default in the future.
+
+### Mocking a Class Template {#MockTemplate}
+
+Class templates can be mocked just like any class.
+
+To mock
+
+```cpp
+template <typename Elem>
+class StackInterface {
+ public:
+  virtual ~StackInterface();
+  virtual int GetSize() const = 0;
+  virtual void Push(const Elem& x) = 0;
+};
+```
+
+(note that all member functions that are mocked, including `~StackInterface()`,
+**must** be virtual).
+
+```cpp
+template <typename Elem>
+class MockStack : public StackInterface<Elem> {
+ public:
+  MOCK_METHOD(int, GetSize, (), (const, override));
+  MOCK_METHOD(void, Push, (const Elem& x), (override));
+};
+```
+
+### Specifying Calling Conventions for Mock Functions
+
+If your mock function doesn't use the default calling convention, you can
+specify it by adding `Calltype(convention)` to `MOCK_METHOD`'s 4th parameter.
+For example,
+
+```cpp
+  MOCK_METHOD(bool, Foo, (int n), (Calltype(STDMETHODCALLTYPE)));
+  MOCK_METHOD(int, Bar, (double x, double y),
+              (const, Calltype(STDMETHODCALLTYPE)));
+```
+
+where `STDMETHODCALLTYPE` is defined by `<objbase.h>` on Windows.
+
+## Using Mocks in Tests {#UsingMocks}
+
+The typical work flow is:
+
+1. Import the gMock names you need to use. All gMock symbols are in the
+   `testing` namespace unless they are macros or otherwise noted.
+2. Create the mock objects.
+3. Optionally, set the default actions of the mock objects.
+4. Set your expectations on the mock objects (How will they be called? What
+   will they do?).
+5. Exercise code that uses the mock objects; if necessary, check the result
+   using googletest assertions.
+6. When a mock object is destructed, gMock automatically verifies that all
+   expectations on it have been satisfied.
+
+Here's an example:
+
+```cpp
+using ::testing::Return;                          // #1
+
+TEST(BarTest, DoesThis) {
+  MockFoo foo;                                    // #2
+
+  ON_CALL(foo, GetSize())                         // #3
+      .WillByDefault(Return(1));
+  // ... other default actions ...
+
+  EXPECT_CALL(foo, Describe(5))                   // #4
+      .Times(3)
+      .WillRepeatedly(Return("Category 5"));
+  // ... other expectations ...
+
+  EXPECT_EQ(MyProductionFunction(&foo), "good");  // #5
+}                                                 // #6
+```
+
+## Setting Default Actions {#OnCall}
+
+gMock has a **built-in default action** for any function that returns `void`,
+`bool`, a numeric value, or a pointer. In C++11, it will additionally return
+the default-constructed value, if one exists for the given type.
+
+To customize the default action for functions with return type `T`, use
+[`DefaultValue<T>`](reference/mocking.md#DefaultValue). For example:
+
+```cpp
+  // Sets the default action for return type std::unique_ptr<Buzz> to
+  // creating a new Buzz every time.
+  DefaultValue<std::unique_ptr<Buzz>>::SetFactory(
+      [] { return std::make_unique<Buzz>(AccessLevel::kInternal); });
+
+  // When this fires, the default action of MakeBuzz() will run, which
+  // will return a new Buzz object.
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello")).Times(AnyNumber());
+
+  auto buzz1 = mock_buzzer_.MakeBuzz("hello");
+  auto buzz2 = mock_buzzer_.MakeBuzz("hello");
+  EXPECT_NE(buzz1, nullptr);
+  EXPECT_NE(buzz2, nullptr);
+  EXPECT_NE(buzz1, buzz2);
+
+  // Resets the default action for return type std::unique_ptr<Buzz>,
+  // to avoid interfering with other tests.
+  DefaultValue<std::unique_ptr<Buzz>>::Clear();
+```
+
+To customize the default action for a particular method of a specific mock
+object, use [`ON_CALL`](reference/mocking.md#ON_CALL). `ON_CALL` has a similar
+syntax to `EXPECT_CALL`, but it is used for setting default behaviors when you
+do not require that the mock method is called. See
+[Knowing When to Expect](gmock_cook_book.md#UseOnCall) for a more detailed
+discussion.
+
+## Setting Expectations {#ExpectCall}
+
+See [`EXPECT_CALL`](reference/mocking.md#EXPECT_CALL) in the Mocking Reference.
+
+## Matchers {#MatcherList}
+
+See the [Matchers Reference](reference/matchers.md).
+
+## Actions {#ActionList}
+
+See the [Actions Reference](reference/actions.md).
+
+## Cardinalities {#CardinalityList}
+
+See the [`Times` clause](reference/mocking.md#EXPECT_CALL.Times) of
+`EXPECT_CALL` in the Mocking Reference.
+
+## Expectation Order
+
+By default, expectations can be matched in *any* order. If some or all
+expectations must be matched in a given order, you can use the
+[`After` clause](reference/mocking.md#EXPECT_CALL.After) or
+[`InSequence` clause](reference/mocking.md#EXPECT_CALL.InSequence) of
+`EXPECT_CALL`, or use an [`InSequence` object](reference/mocking.md#InSequence).
+
+## Verifying and Resetting a Mock
+
+gMock will verify the expectations on a mock object when it is destructed, or
+you can do it earlier:
+
+```cpp
+using ::testing::Mock;
+...
+// Verifies and removes the expectations on mock_obj;
+// returns true if and only if successful.
+Mock::VerifyAndClearExpectations(&mock_obj);
+...
+// Verifies and removes the expectations on mock_obj;
+// also removes the default actions set by ON_CALL();
+// returns true if and only if successful.
+Mock::VerifyAndClear(&mock_obj);
+```
+
+Do not set new expectations after verifying and clearing a mock after its use.
+Setting expectations after code that exercises the mock has undefined behavior.
+See [Using Mocks in Tests](gmock_for_dummies.md#using-mocks-in-tests) for more
+information.
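+
+For instance, a test can force verification at a specific point so that an
+unsatisfied expectation is reported there rather than at the end of the test.
+A minimal sketch, reusing the `MockFoo` mock from above (`ExerciseCode()` is a
+hypothetical stand-in for the code under test):
+
+```cpp
+using ::testing::Mock;
+using ::testing::Return;
+
+TEST(BarTest, ReportsMockFailuresEarly) {
+  MockFoo foo;
+  EXPECT_CALL(foo, Describe(5)).WillOnce(Return("Category 5"));
+
+  ExerciseCode(&foo);  // hypothetical function that should call Describe(5)
+
+  // Verify now: a missing Describe(5) call is reported at this line, and the
+  // rest of the test can still run its mock-free assertions.
+  EXPECT_TRUE(Mock::VerifyAndClearExpectations(&foo));
+
+  // ... further assertions that do not involve the mock ...
+}
+```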
+
+You can also tell gMock that a mock object can be leaked and doesn't need to be
+verified:
+
+```cpp
+Mock::AllowLeak(&mock_obj);
+```
+
+## Mock Classes
+
+gMock defines a convenient mock class template
+
+```cpp
+class MockFunction<R(A1, ..., An)> {
+ public:
+  MOCK_METHOD(R, Call, (A1, ..., An));
+};
+```
+
+See this [recipe](gmock_cook_book.md#UsingCheckPoints) for one application of
+it.
+
+## Flags
+
+| Flag | Description |
+| :----------------------------- | :---------------------------------------- |
+| `--gmock_catch_leaked_mocks=0` | Don't report leaked mock objects as failures. |
+| `--gmock_verbose=LEVEL` | Sets the default verbosity level (`info`, `warning`, or `error`) of Google Mock messages. |
diff --git a/test/googletest/docs/gmock_cook_book.md b/test/googletest/docs/gmock_cook_book.md
new file mode 100644
index 0000000..f1b10b4
--- /dev/null
+++ b/test/googletest/docs/gmock_cook_book.md
@@ -0,0 +1,4379 @@
+# gMock Cookbook
+
+You can find recipes for using gMock here. If you haven't yet, please read
+[the dummy guide](gmock_for_dummies.md) first to make sure you understand the
+basics.
+
+{: .callout .note}
+**Note:** gMock lives in the `testing` namespace. For readability, it is
+recommended to write `using ::testing::Foo;` once in your file before using the
+name `Foo` defined by gMock. We omit such `using` statements in this section for
+brevity, but you should do it in your own code.
+
+## Creating Mock Classes
+
+Mock classes are defined as normal classes, using the `MOCK_METHOD` macro to
+generate mocked methods. The macro gets 3 or 4 parameters:
+
+```cpp
+class MyMock {
+ public:
+  MOCK_METHOD(ReturnType, MethodName, (Args...));
+  MOCK_METHOD(ReturnType, MethodName, (Args...), (Specs...));
+};
+```
+
+The first 3 parameters are simply the method declaration, split into 3 parts.
+The 4th parameter accepts a closed list of qualifiers, which affect the
+generated method:
+
+* **`const`** - Makes the mocked method a `const` method. Required if
+  overriding a `const` method.
+* **`override`** - Marks the method with `override`. Recommended if overriding
+  a `virtual` method.
+* **`noexcept`** - Marks the method with `noexcept`. Required if overriding a
+  `noexcept` method.
+* **`Calltype(...)`** - Sets the call type for the method (e.g. to
+  `STDMETHODCALLTYPE`), useful in Windows.
+* **`ref(...)`** - Marks the method with the reference qualification
+  specified. Required if overriding a method that has reference
+  qualifications. E.g. `ref(&)` or `ref(&&)`.
+
+### Dealing with unprotected commas
+
+Unprotected commas, i.e. commas which are not surrounded by parentheses, prevent
+`MOCK_METHOD` from parsing its arguments correctly:
+
+{: .bad}
+```cpp
+class MockFoo {
+ public:
+  MOCK_METHOD(std::pair<bool, int>, GetPair, ());              // Won't compile!
+  MOCK_METHOD(bool, CheckMap, (std::map<int, double>, bool));  // Won't compile!
+};
+```
+
+Solution 1 - wrap with parentheses:
+
+{: .good}
+```cpp
+class MockFoo {
+ public:
+  MOCK_METHOD((std::pair<bool, int>), GetPair, ());
+  MOCK_METHOD(bool, CheckMap, ((std::map<int, double>), bool));
+};
+```
+
+Note that wrapping a return or argument type with parentheses is, in general,
+invalid C++. `MOCK_METHOD` removes the parentheses.
+
+Solution 2 - define an alias:
+
+{: .good}
+```cpp
+class MockFoo {
+ public:
+  using BoolAndInt = std::pair<bool, int>;
+  MOCK_METHOD(BoolAndInt, GetPair, ());
+  using MapIntDouble = std::map<int, double>;
+  MOCK_METHOD(bool, CheckMap, (MapIntDouble, bool));
+};
+```
+
+### Mocking Private or Protected Methods
+
+You must always put a mock method definition (`MOCK_METHOD`) in a `public:`
+section of the mock class, regardless of whether the method being mocked is
+`public`, `protected`, or `private` in the base class. This allows `ON_CALL`
+and `EXPECT_CALL` to reference the mock function from outside of the mock
+class. (Yes, C++ allows a subclass to change the access level of a virtual
+function in the base class.) Example:
+
+```cpp
+class Foo {
+ public:
+  ...
+  virtual bool Transform(Gadget* g) = 0;
+
+ protected:
+  virtual void Resume();
+
+ private:
+  virtual int GetTimeOut();
+};
+
+class MockFoo : public Foo {
+ public:
+  ...
+  MOCK_METHOD(bool, Transform, (Gadget* g), (override));
+
+  // The following must be in the public section, even though the
+  // methods are protected or private in the base class.
+  MOCK_METHOD(void, Resume, (), (override));
+  MOCK_METHOD(int, GetTimeOut, (), (override));
+};
+```
+
+### Mocking Overloaded Methods
+
+You can mock overloaded functions as usual. No special attention is required:
+
+```cpp
+class Foo {
+  ...
+
+  // Must be virtual as we'll inherit from Foo.
+  virtual ~Foo();
+
+  // Overloaded on the types and/or numbers of arguments.
+  virtual int Add(Element x);
+  virtual int Add(int times, Element x);
+
+  // Overloaded on the const-ness of this object.
+  virtual Bar& GetBar();
+  virtual const Bar& GetBar() const;
+};
+
+class MockFoo : public Foo {
+  ...
+  MOCK_METHOD(int, Add, (Element x), (override));
+  MOCK_METHOD(int, Add, (int times, Element x), (override));
+
+  MOCK_METHOD(Bar&, GetBar, (), (override));
+  MOCK_METHOD(const Bar&, GetBar, (), (const, override));
+};
+```
+
+{: .callout .note}
+**Note:** if you don't mock all versions of the overloaded method, the compiler
+will give you a warning about some methods in the base class being hidden. To
+fix that, use `using` to bring them in scope:
+
+```cpp
+class MockFoo : public Foo {
+  ...
+  using Foo::Add;
+  MOCK_METHOD(int, Add, (Element x), (override));
+  // We don't want to mock int Add(int times, Element x);
+  ...
+};
+```
+
+### Mocking Class Templates
+
+You can mock class templates just like any class.
+
+```cpp
+template <typename Elem>
+class StackInterface {
+  ...
+  // Must be virtual as we'll inherit from StackInterface.
+  virtual ~StackInterface();
+
+  virtual int GetSize() const = 0;
+  virtual void Push(const Elem& x) = 0;
+};
+
+template <typename Elem>
+class MockStack : public StackInterface<Elem> {
+  ...
+  MOCK_METHOD(int, GetSize, (), (override));
+  MOCK_METHOD(void, Push, (const Elem& x), (override));
+};
+```
+
+### Mocking Non-virtual Methods {#MockingNonVirtualMethods}
+
+gMock can mock non-virtual functions to be used in hi-perf dependency
+injection.
+
+In this case, instead of sharing a common base class with the real class, your
+mock class will be *unrelated* to the real class, but contain methods with the
+same signatures. The syntax for mocking non-virtual methods is the *same* as
+mocking virtual methods (just don't add `override`):
+
+```cpp
+// A simple packet stream class. None of its members is virtual.
+class ConcretePacketStream {
+ public:
+  void AppendPacket(Packet* new_packet);
+  const Packet* GetPacket(size_t packet_number) const;
+  size_t NumberOfPackets() const;
+  ...
+};
+
+// A mock packet stream class. It inherits from no other class, but defines
+// GetPacket() and NumberOfPackets().
+class MockPacketStream {
+ public:
+  MOCK_METHOD(const Packet*, GetPacket, (size_t packet_number), (const));
+  MOCK_METHOD(size_t, NumberOfPackets, (), (const));
+  ...
+};
+```
+
+Note that the mock class doesn't define `AppendPacket()`, unlike the real
+class. That's fine as long as the test doesn't need to call it.
+
+Next, you need a way to say that you want to use `ConcretePacketStream` in
+production code, and use `MockPacketStream` in tests. Since the functions are
+not virtual and the two classes are unrelated, you must specify your choice at
+*compile time* (as opposed to run time).
+
+One way to do it is to templatize your code that needs to use a packet stream.
+More specifically, you will give your code a template type argument for the
+type of the packet stream. In production, you will instantiate your template
+with `ConcretePacketStream` as the type argument. In tests, you will
+instantiate the same template with `MockPacketStream`. For example, you may
+write:
+
+```cpp
+template <class PacketStream>
+void CreateConnection(PacketStream* stream) { ... }
+
+template <class PacketStream>
+class PacketReader {
+ public:
+  void ReadPackets(PacketStream* stream, size_t packet_num);
+};
+```
+
+Then you can use `CreateConnection<ConcretePacketStream>()` and
+`PacketReader<ConcretePacketStream>` in production code, and use
+`CreateConnection<MockPacketStream>()` and `PacketReader<MockPacketStream>` in
+tests.
+
+```cpp
+  MockPacketStream mock_stream;
+  EXPECT_CALL(mock_stream, ...)...;
+  .. set more expectations on mock_stream ...
+  PacketReader<MockPacketStream> reader(&mock_stream);
+  ... exercise reader ...
+```
+
+### Mocking Free Functions
+
+It is not possible to directly mock a free function (i.e. a C-style function
+or a static method). If you need to, you can rewrite your code to use an
+interface (abstract class).
+
+Instead of calling a free function (say, `OpenFile`) directly, introduce an
+interface for it and have a concrete subclass that calls the free function:
+
+```cpp
+class FileInterface {
+ public:
+  ...
+  virtual bool Open(const char* path, const char* mode) = 0;
+};
+
+class File : public FileInterface {
+ public:
+  ...
+  bool Open(const char* path, const char* mode) override {
+    return OpenFile(path, mode);
+  }
+};
+```
+
+Your code should talk to `FileInterface` to open a file. Now it's easy to mock
+out the function.
+
+This may seem like a lot of hassle, but in practice you often have multiple
+related functions that you can put in the same interface, so the per-function
+syntactic overhead will be much lower.
+
+If you are concerned about the performance overhead incurred by virtual
+functions, and profiling confirms your concern, you can combine this with the
+recipe for [mocking non-virtual methods](#MockingNonVirtualMethods).
+
+Alternatively, instead of introducing a new interface, you can rewrite your
+code to accept a `std::function` instead of the free function, and then use
+[MockFunction](#MockFunction) to mock the `std::function`.
+
+### Old-Style `MOCK_METHODn` Macros
+
+Before the generic `MOCK_METHOD` macro
+[was introduced in 2018](https://github.com/google/googletest/commit/c5f08bf91944ce1b19bcf414fa1760e69d20afc2),
+mocks were created using a family of macros collectively called
+`MOCK_METHODn`. These macros are still supported, though migration to the new
+`MOCK_METHOD` is recommended.
+
+The macros in the `MOCK_METHODn` family differ from `MOCK_METHOD`:
+
+* The general structure is `MOCK_METHODn(MethodName, ReturnType(Args))`,
+  instead of `MOCK_METHOD(ReturnType, MethodName, (Args))`.
+* The number `n` must equal the number of arguments.
+* When mocking a const method, one must use `MOCK_CONST_METHODn`.
+* When mocking a class template, the macro name must be suffixed with `_T`.
+* In order to specify the call type, the macro name must be suffixed with
+  `_WITH_CALLTYPE`, and the call type is the first macro argument.
+
+Old macros and their new equivalents:
+
+| Purpose | Old | New |
+| :--- | :--- | :--- |
+| Simple | `MOCK_METHOD1(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int))` |
+| Const Method | `MOCK_CONST_METHOD1(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const))` |
+| Method in a Class Template | `MOCK_METHOD1_T(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int))` |
+| Const Method in a Class Template | `MOCK_CONST_METHOD1_T(Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const))` |
+| Method with Call Type | `MOCK_METHOD1_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (Calltype(STDMETHODCALLTYPE)))` |
+| Const Method with Call Type | `MOCK_CONST_METHOD1_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const, Calltype(STDMETHODCALLTYPE)))` |
+| Method with Call Type in a Class Template | `MOCK_METHOD1_T_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (Calltype(STDMETHODCALLTYPE)))` |
+| Const Method with Call Type in a Class Template | `MOCK_CONST_METHOD1_T_WITH_CALLTYPE(STDMETHODCALLTYPE, Foo, bool(int))` | `MOCK_METHOD(bool, Foo, (int), (const, Calltype(STDMETHODCALLTYPE)))` |
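+
+To make the mapping concrete, here is a small sketch of the same mock written
+both ways (the `Timer` base class and its methods are hypothetical, used only
+for illustration):
+
+```cpp
+// Old style: the argument count and const-ness are encoded in the macro name.
+class OldMockTimer : public Timer {
+ public:
+  MOCK_METHOD2(Schedule, bool(int delay_ms, int id));
+  MOCK_CONST_METHOD0(IsRunning, bool());
+};
+
+// New style: one macro; const/override go in the optional 4th parameter.
+class NewMockTimer : public Timer {
+ public:
+  MOCK_METHOD(bool, Schedule, (int delay_ms, int id), (override));
+  MOCK_METHOD(bool, IsRunning, (), (const, override));
+};
+```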
+
+### The Nice, the Strict, and the Naggy {#NiceStrictNaggy}
+
+If a mock method has no `EXPECT_CALL` spec but is called, we say that it's an
+"uninteresting call", and the default action (which can be specified using
+`ON_CALL()`) of the method will be taken. Currently, an uninteresting call
+will also by default cause gMock to print a warning.
+
+However, sometimes you may want to ignore these uninteresting calls, and
+sometimes you may want to treat them as errors. gMock lets you make the
+decision on a per-mock-object basis.
+
+Suppose your test uses a mock class `MockFoo`:
+
+```cpp
+TEST(...) {
+  MockFoo mock_foo;
+  EXPECT_CALL(mock_foo, DoThis());
+  ... code that uses mock_foo ...
+}
+```
+
+If a method of `mock_foo` other than `DoThis()` is called, you will get a
+warning. However, if you rewrite your test to use `NiceMock<MockFoo>` instead,
+you can suppress the warning:
+
+```cpp
+using ::testing::NiceMock;
+
+TEST(...) {
+  NiceMock<MockFoo> mock_foo;
+  EXPECT_CALL(mock_foo, DoThis());
+  ... code that uses mock_foo ...
+}
+```
+
+`NiceMock<MockFoo>` is a subclass of `MockFoo`, so it can be used wherever
+`MockFoo` is accepted.
+
+It also works if `MockFoo`'s constructor takes some arguments, as
+`NiceMock<MockFoo>` "inherits" `MockFoo`'s constructors:
+
+```cpp
+using ::testing::NiceMock;
+
+TEST(...) {
+  NiceMock<MockFoo> mock_foo(5, "hi");  // Calls MockFoo(5, "hi").
+  EXPECT_CALL(mock_foo, DoThis());
+  ... code that uses mock_foo ...
+}
+```
+
+The usage of `StrictMock` is similar, except that it makes all uninteresting
+calls failures:
+
+```cpp
+using ::testing::StrictMock;
+
+TEST(...) {
+  StrictMock<MockFoo> mock_foo;
+  EXPECT_CALL(mock_foo, DoThis());
+  ... code that uses mock_foo ...
+
+  // The test will fail if a method of mock_foo other than DoThis()
+  // is called.
+}
+```
+
+{: .callout .note}
+NOTE: `NiceMock` and `StrictMock` only affect *uninteresting* calls (calls of
+*methods* with no expectations); they do not affect *unexpected* calls (calls
+of methods with expectations, but they don't match). See
+[Understanding Uninteresting vs Unexpected Calls](#uninteresting-vs-unexpected).
+
+There are some caveats though (sadly they are side effects of C++'s
+limitations):
+
+1. `NiceMock<MockFoo>` and `StrictMock<MockFoo>` only work for mock methods
+   defined using the `MOCK_METHOD` macro **directly** in the `MockFoo` class.
+   If a mock method is defined in a **base class** of `MockFoo`, the "nice" or
+   "strict" modifier may not affect it, depending on the compiler. In
+   particular, nesting `NiceMock` and `StrictMock` (e.g.
+   `NiceMock<StrictMock<MockFoo>>`) is **not** supported.
+2. `NiceMock<MockFoo>` and `StrictMock<MockFoo>` may not work correctly if the
+   destructor of `MockFoo` is not virtual. We would like to fix this, but it
+   requires cleaning up existing tests.
+
+Finally, you should be **very cautious** about when to use naggy or strict
+mocks, as they tend to make tests more brittle and harder to maintain. When
+you refactor your code without changing its externally visible behavior,
+ideally you shouldn't need to update any tests. If your code interacts with a
+naggy mock, however, you may start to get spammed with warnings as the result
+of your change. Worse, if your code interacts with a strict mock, your tests
+may start to fail and you'll be forced to fix them. Our general recommendation
+is to use nice mocks (not yet the default) most of the time, use naggy mocks
+(the current default) when developing or debugging tests, and use strict mocks
+only as the last resort.
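+
+As a quick side-by-side sketch (still assuming the `MockFoo` above, plus a
+mocked `DoThat()` method that the test sets no expectations on):
+
+```cpp
+using ::testing::NiceMock;
+using ::testing::StrictMock;
+
+MockFoo naggy;               // calling naggy.DoThat() prints a warning
+NiceMock<MockFoo> nice;      // calling nice.DoThat() is silently allowed
+StrictMock<MockFoo> strict;  // calling strict.DoThat() fails the test
+```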
+ +### Simplifying the Interface without Breaking Existing Code {#SimplerInterfaces} + +Sometimes a method has a long list of arguments that is mostly uninteresting. +For example: + +```cpp +class LogSink { + public: + ... + virtual void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, + const struct tm* tm_time, + const char* message, size_t message_len) = 0; +}; +``` + +This method's argument list is lengthy and hard to work with (the `message` +argument is not even 0-terminated). If we mock it as is, using the mock will be +awkward. If, however, we try to simplify this interface, we'll need to fix all +clients depending on it, which is often infeasible. + +The trick is to redispatch the method in the mock class: + +```cpp +class ScopedMockLog : public LogSink { + public: + ... + void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, const tm* tm_time, + const char* message, size_t message_len) override { + // We are only interested in the log severity, full file name, and + // log message. + Log(severity, full_filename, std::string(message, message_len)); + } + + // Implements the mock method: + // + // void Log(LogSeverity severity, + // const string& file_path, + // const string& message); + MOCK_METHOD(void, Log, + (LogSeverity severity, const string& file_path, + const string& message)); +}; +``` + +By defining a new mock method with a trimmed argument list, we make the mock +class more user-friendly. + +This technique may also be applied to make overloaded methods more amenable to +mocking. For example, when overloads have been used to implement default +arguments: + +```cpp +class MockTurtleFactory : public TurtleFactory { + public: + Turtle* MakeTurtle(int length, int weight) override { ... } + Turtle* MakeTurtle(int length, int weight, int speed) override { ... } + + // the above methods delegate to this one: + MOCK_METHOD(Turtle*, DoMakeTurtle, ()); +}; +``` + +This allows tests that don't care which overload was invoked to avoid specifying +argument matchers: + +```cpp +ON_CALL(factory, DoMakeTurtle) + .WillByDefault(Return(MakeMockTurtle())); +``` + +### Alternative to Mocking Concrete Classes + +Often you may find yourself using classes that don't implement interfaces. In +order to test your code that uses such a class (let's call it `Concrete`), you +may be tempted to make the methods of `Concrete` virtual and then mock it. + +Try not to do that. + +Making a non-virtual function virtual is a big decision. It creates an extension +point where subclasses can tweak your class' behavior. This weakens your control +on the class because now it's harder to maintain the class invariants. You +should make a function virtual only when there is a valid reason for a subclass +to override it. + +Mocking concrete classes directly is problematic as it creates a tight coupling +between the class and the tests - any small change in the class may invalidate +your tests and make test maintenance a pain. + +To avoid such problems, many programmers have been practicing "coding to +interfaces": instead of talking to the `Concrete` class, your code would define +an interface and talk to it. Then you implement that interface as an adaptor on +top of `Concrete`. In tests, you can easily mock that interface to observe how +your code is doing. + +This technique incurs some overhead: + +* You pay the cost of virtual function calls (usually not a problem). +* There is more abstraction for the programmers to learn. 
+
+However, it can also bring significant benefits in addition to better
+testability:
+
+* `Concrete`'s API may not fit your problem domain very well, as you may not
+  be the only client it tries to serve. By designing your own interface, you
+  have a chance to tailor it to your need - you may add higher-level
+  functionalities, rename stuff, etc. instead of just trimming the class. This
+  allows you to write your code (user of the interface) in a more natural way,
+  which means it will be more readable, more maintainable, and you'll be more
+  productive.
+* If `Concrete`'s implementation ever has to change, you don't have to rewrite
+  every place it is used. Instead, you can absorb the change in your
+  implementation of the interface, and your other code and tests will be
+  insulated from this change.
+
+Some people worry that if everyone is practicing this technique, they will end
+up writing lots of redundant code. This concern is totally understandable.
+However, there are two reasons why it may not be the case:
+
+* Different projects may need to use `Concrete` in different ways, so the best
+  interfaces for them will be different. Therefore, each of them will have its
+  own domain-specific interface on top of `Concrete`, and they will not be the
+  same code.
+* If enough projects want to use the same interface, they can always share it,
+  just like they have been sharing `Concrete`. You can check in the interface
+  and the adaptor somewhere near `Concrete` (perhaps in a `contrib`
+  sub-directory) and let many projects use it.
+
+You need to weigh the pros and cons carefully for your particular problem, but
+I'd like to assure you that the Java community has been practicing this for a
+long time and it's a proven effective technique applicable in a wide variety
+of situations. :-)
+
+### Delegating Calls to a Fake {#DelegatingToFake}
+
+Sometimes you have a non-trivial fake implementation of an interface. For
+example:
+
+```cpp
+class Foo {
+ public:
+  virtual ~Foo() {}
+  virtual char DoThis(int n) = 0;
+  virtual void DoThat(const char* s, int* p) = 0;
+};
+
+class FakeFoo : public Foo {
+ public:
+  char DoThis(int n) override {
+    return (n > 0) ? '+' :
+           (n < 0) ? '-' : '0';
+  }
+
+  void DoThat(const char* s, int* p) override {
+    *p = strlen(s);
+  }
+};
+```
+
+Now you want to mock this interface such that you can set expectations on it.
+However, you also want to use `FakeFoo` for the default behavior, as
+duplicating it in the mock object is, well, a lot of work.
+
+When you define the mock class using gMock, you can have it delegate its
+default action to a fake class you already have, using this pattern:
+
+```cpp
+class MockFoo : public Foo {
+ public:
+  // Normal mock method definitions using gMock.
+  MOCK_METHOD(char, DoThis, (int n), (override));
+  MOCK_METHOD(void, DoThat, (const char* s, int* p), (override));
+
+  // Delegates the default actions of the methods to a FakeFoo object.
+  // This must be called *before* the custom ON_CALL() statements.
+  void DelegateToFake() {
+    ON_CALL(*this, DoThis).WillByDefault([this](int n) {
+      return fake_.DoThis(n);
+    });
+    ON_CALL(*this, DoThat).WillByDefault([this](const char* s, int* p) {
+      fake_.DoThat(s, p);
+    });
+  }
+
+ private:
+  FakeFoo fake_;  // Keeps an instance of the fake in the mock.
+};
+```
+
+With that, you can use `MockFoo` in your tests as usual.
+Just remember that if you don't explicitly set an action in an `ON_CALL()` or
+`EXPECT_CALL()`, the fake will be called upon to do it:
+
+```cpp
+using ::testing::_;
+
+TEST(AbcTest, Xyz) {
+  MockFoo foo;
+
+  foo.DelegateToFake();  // Enables the fake for delegation.
+
+  // Put your ON_CALL(foo, ...)s here, if any.
+
+  // No action specified, meaning to use the default action.
+  EXPECT_CALL(foo, DoThis(5));
+  EXPECT_CALL(foo, DoThat(_, _));
+
+  int n = 0;
+  EXPECT_EQ(foo.DoThis(5), '+');  // FakeFoo::DoThis() is invoked.
+  foo.DoThat("Hi", &n);           // FakeFoo::DoThat() is invoked.
+  EXPECT_EQ(n, 2);
+}
+```
+
+**Some tips:**
+
+* If you want, you can still override the default action by providing your own
+  `ON_CALL()` or using `.WillOnce()` / `.WillRepeatedly()` in `EXPECT_CALL()`.
+* In `DelegateToFake()`, you only need to delegate the methods whose fake
+  implementation you intend to use.
+
+* The general technique discussed here works for overloaded methods, but
+  you'll need to tell the compiler which version you mean. To disambiguate a
+  mock function (the one you specify inside the parentheses of `ON_CALL()`),
+  use [this technique](#SelectOverload); to disambiguate a fake function (the
+  one you place inside `Invoke()`), use a `static_cast` to specify the
+  function's type. For instance, if class `Foo` has methods `char DoThis(int
+  n)` and `bool DoThis(double x) const`, and you want to invoke the latter,
+  you need to write
+  `Invoke(&fake_, static_cast<bool (FakeFoo::*)(double) const>(&FakeFoo::DoThis))`
+  instead of `Invoke(&fake_, &FakeFoo::DoThis)` (The strange-looking thing
+  inside the angled brackets of `static_cast` is the type of a function
+  pointer to the second `DoThis()` method.).
+
+* Having to mix a mock and a fake is often a sign of something gone wrong.
+  Perhaps you haven't got used to the interaction-based way of testing yet. Or
+  perhaps your interface is taking on too many roles and should be split up.
+  Therefore, **don't abuse this**. We would only recommend to do it as an
+  intermediate step when you are refactoring your code.
+
+Regarding the tip on mixing a mock and a fake, here's an example on why it may
+be a bad sign: Suppose you have a class `System` for low-level system
+operations. In particular, it does file and I/O operations. And suppose you
+want to test how your code uses `System` to do I/O, and you just want the file
+operations to work normally. If you mock out the entire `System` class, you'll
+have to provide a fake implementation for the file operation part, which
+suggests that `System` is taking on too many roles.
+
+Instead, you can define a `FileOps` interface and an `IOOps` interface and
+split `System`'s functionalities into the two. Then you can mock `IOOps`
+without mocking `FileOps`.
+
+### Delegating Calls to a Real Object
+
+When using testing doubles (mocks, fakes, stubs, etc.), sometimes their
+behaviors will differ from those of the real objects. This difference could be
+either intentional (as in simulating an error such that you can test the error
+handling code) or unintentional. If your mocks have different behaviors than
+the real objects by mistake, you could end up with code that passes the tests
+but fails in production.
+
+You can use the *delegating-to-real* technique to ensure that your mock has
+the same behavior as the real object while retaining the ability to validate
+calls. This technique is very similar to the
+[delegating-to-fake](#DelegatingToFake) technique, the difference being that
+we use a real object instead of a fake.
+Here's an example:
+
+```cpp
+using ::testing::AtLeast;
+
+class MockFoo : public Foo {
+ public:
+  MockFoo() {
+    // By default, all calls are delegated to the real object.
+    ON_CALL(*this, DoThis).WillByDefault([this](int n) {
+      return real_.DoThis(n);
+    });
+    ON_CALL(*this, DoThat).WillByDefault([this](const char* s, int* p) {
+      real_.DoThat(s, p);
+    });
+    ...
+  }
+  MOCK_METHOD(char, DoThis, ...);
+  MOCK_METHOD(void, DoThat, ...);
+  ...
+ private:
+  Foo real_;
+};
+
+...
+  MockFoo mock;
+  EXPECT_CALL(mock, DoThis())
+      .Times(3);
+  EXPECT_CALL(mock, DoThat("Hi"))
+      .Times(AtLeast(1));
+  ... use mock in test ...
+```
+
+With this, gMock will verify that your code made the right calls (with the
+right arguments, in the right order, called the right number of times, etc),
+and a real object will answer the calls (so the behavior will be the same as
+in production). This gives you the best of both worlds.
+
+### Delegating Calls to a Parent Class
+
+Ideally, you should code to interfaces, whose methods are all pure virtual. In
+reality, sometimes you do need to mock a virtual method that is not pure
+(i.e., it already has an implementation). For example:
+
+```cpp
+class Foo {
+ public:
+  virtual ~Foo();
+
+  virtual void Pure(int n) = 0;
+  virtual int Concrete(const char* str) { ... }
+};
+
+class MockFoo : public Foo {
+ public:
+  // Mocking a pure method.
+  MOCK_METHOD(void, Pure, (int n), (override));
+  // Mocking a concrete method. Foo::Concrete() is shadowed.
+  MOCK_METHOD(int, Concrete, (const char* str), (override));
+};
+```
+
+Sometimes you may want to call `Foo::Concrete()` instead of
+`MockFoo::Concrete()`. Perhaps you want to do it as part of a stub action, or
+perhaps your test doesn't need to mock `Concrete()` at all (but it would be
+oh-so painful to have to define a new mock class whenever you don't need to
+mock one of its methods).
+
+You can call `Foo::Concrete()` inside an action by:
+
+```cpp
+...
+  EXPECT_CALL(foo, Concrete).WillOnce([&foo](const char* str) {
+    return foo.Foo::Concrete(str);
+  });
+```
+
+or tell the mock object that you don't want to mock `Concrete()`:
+
+```cpp
+...
+  ON_CALL(foo, Concrete).WillByDefault([&foo](const char* str) {
+    return foo.Foo::Concrete(str);
+  });
+```
+
+(Why don't we just write `{ return foo.Concrete(str); }`? If you do that,
+`MockFoo::Concrete()` will be called (and cause an infinite recursion) since
+`Foo::Concrete()` is virtual. That's just how C++ works.)
+
+## Using Matchers
+
+### Matching Argument Values Exactly
+
+You can specify exactly which arguments a mock method is expecting:
+
+```cpp
+using ::testing::Return;
+...
+  EXPECT_CALL(foo, DoThis(5))
+      .WillOnce(Return('a'));
+  EXPECT_CALL(foo, DoThat("Hello", bar));
+```
+
+### Using Simple Matchers
+
+You can use matchers to match arguments that have a certain property:
+
+```cpp
+using ::testing::Ge;
+using ::testing::NotNull;
+using ::testing::Return;
+...
+  EXPECT_CALL(foo, DoThis(Ge(5)))  // The argument must be >= 5.
+      .WillOnce(Return('a'));
+  EXPECT_CALL(foo, DoThat("Hello", NotNull()));
+      // The second argument must not be NULL.
+```
+
+A frequently used matcher is `_`, which matches anything:
+
+```cpp
+  EXPECT_CALL(foo, DoThat(_, NotNull()));
+```
+
+### Combining Matchers {#CombiningMatchers}
+
+You can build complex matchers from existing ones using `AllOf()`,
+`AllOfArray()`, `AnyOf()`, `AnyOfArray()` and `Not()`:
+
+```cpp
+using ::testing::AllOf;
+using ::testing::Gt;
+using ::testing::HasSubstr;
+using ::testing::Ne;
+using ::testing::Not;
+...
+  // The argument must be > 5 and != 10.
+  EXPECT_CALL(foo, DoThis(AllOf(Gt(5),
+                                Ne(10))));
+
+  // The first argument must not contain sub-string "blah".
+  EXPECT_CALL(foo, DoThat(Not(HasSubstr("blah")),
+                          NULL));
+```
+
+Matchers are function objects, and parametrized matchers can be composed just
+like any other function. However, because their types can be long and rarely
+provide meaningful information, it can be easier to express them with C++14
+generic lambdas to avoid specifying types. For example,
+
+```cpp
+using ::testing::Contains;
+using ::testing::Property;
+
+inline constexpr auto HasFoo = [](const auto& f) {
+  return Property("foo", &MyClass::foo, Contains(f));
+};
+...
+  EXPECT_THAT(x, HasFoo("blah"));
+```
+
+### Casting Matchers {#SafeMatcherCast}
+
+gMock matchers are statically typed, meaning that the compiler can catch your
+mistake if you use a matcher of the wrong type (for example, if you use
+`Eq(5)` to match a `string` argument). Good for you!
+
+Sometimes, however, you know what you're doing and want the compiler to give
+you some slack. One example is that you have a matcher for `long` and the
+argument you want to match is `int`. While the two types aren't exactly the
+same, there is nothing really wrong with using a `Matcher<long>` to match an
+`int` - after all, we can first convert the `int` argument to a `long`
+losslessly before giving it to the matcher.
+
+To support this need, gMock gives you the `SafeMatcherCast<T>(m)` function. It
+casts a matcher `m` to type `Matcher<T>`. To ensure safety, gMock checks that
+(let `U` be the type `m` accepts):
+
+1. Type `T` can be *implicitly* cast to type `U`;
+2. When both `T` and `U` are built-in arithmetic types (`bool`, integers, and
+   floating-point numbers), the conversion from `T` to `U` is not lossy (in
+   other words, any value representable by `T` can also be represented by
+   `U`); and
+3. When `U` is a reference, `T` must also be a reference (as the underlying
+   matcher may be interested in the address of the `U` value).
+
+The code won't compile if any of these conditions isn't met.
+
+Here's one example:
+
+```cpp
+using ::testing::SafeMatcherCast;
+
+// A base class and a child class.
+class Base { ... };
+class Derived : public Base { ... };
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(void, DoThis, (Derived* derived), (override));
+};
+
+...
+  MockFoo foo;
+  // m is a Matcher<Base*> we got from somewhere.
+  EXPECT_CALL(foo, DoThis(SafeMatcherCast<Derived*>(m)));
+```
+
+If you find `SafeMatcherCast<T>(m)` too limiting, you can use a similar
+function `MatcherCast<T>(m)`. The difference is that `MatcherCast` works as
+long as you can `static_cast` type `T` to type `U`.
+
+`MatcherCast` essentially lets you bypass C++'s type system (`static_cast`
+isn't always safe as it could throw away information, for example), so be
+careful not to misuse/abuse it.
+
+### Selecting Between Overloaded Functions {#SelectOverload}
+
+If you expect an overloaded function to be called, the compiler may need some
+help on which overloaded version it is.
+
+To disambiguate functions overloaded on the const-ness of this object, use the
+`Const()` argument wrapper.
+
+```cpp
+using ::testing::ReturnRef;
+
+class MockFoo : public Foo {
+  ...
+  MOCK_METHOD(Bar&, GetBar, (), (override));
+  MOCK_METHOD(const Bar&, GetBar, (), (const, override));
+};
+
+...
+  MockFoo foo;
+  Bar bar1, bar2;
+  EXPECT_CALL(foo, GetBar())         // The non-const GetBar().
+      .WillOnce(ReturnRef(bar1));
+  EXPECT_CALL(Const(foo), GetBar())  // The const GetBar().
+      .WillOnce(ReturnRef(bar2));
+```
+
+(`Const()` is defined by gMock and returns a `const` reference to its
+argument.)
+
+To disambiguate overloaded functions with the same number of arguments but
+different argument types, you may need to specify the exact type of a matcher,
+either by wrapping your matcher in `Matcher<type>()`, or using a matcher whose
+type is fixed (`TypedEq<type>`, `An<type>()`, etc):
+
+```cpp
+using ::testing::An;
+using ::testing::Lt;
+using ::testing::Matcher;
+using ::testing::TypedEq;
+
+class MockPrinter : public Printer {
+ public:
+  MOCK_METHOD(void, Print, (int n), (override));
+  MOCK_METHOD(void, Print, (char c), (override));
+};
+
+TEST(PrinterTest, Print) {
+  MockPrinter printer;
+
+  EXPECT_CALL(printer, Print(An<int>()));            // void Print(int);
+  EXPECT_CALL(printer, Print(Matcher<int>(Lt(5))));  // void Print(int);
+  EXPECT_CALL(printer, Print(TypedEq<char>('a')));   // void Print(char);
+
+  printer.Print(3);
+  printer.Print(6);
+  printer.Print('a');
+}
+```
+
+### Performing Different Actions Based on the Arguments
+
+When a mock method is called, the *last* matching expectation that's still
+active will be selected (think "newer overrides older"). So, you can make a
+method do different things depending on its argument values like this:
+
+```cpp
+using ::testing::_;
+using ::testing::Lt;
+using ::testing::Return;
+...
+  // The default case.
+  EXPECT_CALL(foo, DoThis(_))
+      .WillRepeatedly(Return('b'));
+  // The more specific case.
+  EXPECT_CALL(foo, DoThis(Lt(5)))
+      .WillRepeatedly(Return('a'));
+```
+
+Now, if `foo.DoThis()` is called with a value less than 5, `'a'` will be
+returned; otherwise `'b'` will be returned.
+
+### Matching Multiple Arguments as a Whole
+
+Sometimes it's not enough to match the arguments individually. For example, we
+may want to say that the first argument must be less than the second argument.
+The `With()` clause allows us to match all arguments of a mock function as a
+whole. For example,
+
+```cpp
+using ::testing::_;
+using ::testing::Ne;
+using ::testing::Lt;
+...
+  EXPECT_CALL(foo, InRange(Ne(0), _))
+      .With(Lt());
+```
+
+says that the first argument of `InRange()` must not be 0, and must be less
+than the second argument.
+
+The expression inside `With()` must be a matcher of type
+`Matcher<std::tuple<A1, ..., An>>`, where `A1`, ..., `An` are the types of the
+function arguments.
+
+You can also write `AllArgs(m)` instead of `m` inside `.With()`. The two forms
+are equivalent, but `.With(AllArgs(Lt()))` is more readable than `.With(Lt())`.
+
+You can use `Args<k1, ..., kn>(m)` to match the `n` selected arguments (as a
+tuple) against `m`. For example,
+
+```cpp
+using ::testing::_;
+using ::testing::AllOf;
+using ::testing::Args;
+using ::testing::Lt;
+...
+  EXPECT_CALL(foo, Blah)
+      .With(AllOf(Args<0, 1>(Lt()), Args<1, 2>(Lt())));
+```
+
+says that `Blah` will be called with arguments `x`, `y`, and `z` where `x < y
+< z`. Note that in this example, it wasn't necessary to specify the positional
+matchers.
+
+As a convenience and example, gMock provides some matchers for 2-tuples,
+including the `Lt()` matcher above. See
+[Multi-argument Matchers](reference/matchers.md#MultiArgMatchers) for the
+complete list.
+
+Note that if you want to pass the arguments to a predicate of your own (e.g.
+`.With(Args<0, 1>(Truly(&MyPredicate)))`), that predicate MUST be written to
+take a `std::tuple` as its argument; gMock will pass the `n` selected
+arguments as *one* single tuple to the predicate.
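+
+For instance, a minimal sketch of such a tuple-taking predicate, reusing the
+`Blah` mock method from the example above (the predicate itself is assumed for
+illustration):
+
+```cpp
+using ::testing::Args;
+using ::testing::Truly;
+
+// Receives the two selected arguments packed into one std::tuple.
+bool AreConsecutive(const std::tuple<int, int>& args) {
+  return std::get<0>(args) + 1 == std::get<1>(args);
+}
+...
+  EXPECT_CALL(foo, Blah)
+      .With(Args<0, 1>(Truly(AreConsecutive)));
+```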
+
+### Using Matchers as Predicates
+
+Have you noticed that a matcher is just a fancy predicate that also knows how
+to describe itself? Many existing algorithms take predicates as arguments
+(e.g. those defined in STL's `<algorithm>` header), and it would be a shame if
+gMock matchers were not allowed to participate.
+
+Luckily, you can use a matcher where a unary predicate functor is expected by
+wrapping it inside the `Matches()` function. For example,
+
+```cpp
+#include <algorithm>
+#include <vector>
+
+using ::testing::Matches;
+using ::testing::Ge;
+
+vector<int> v;
+...
+// How many elements in v are >= 10?
+const int count = count_if(v.begin(), v.end(), Matches(Ge(10)));
+```
+
+Since you can build complex matchers from simpler ones easily using gMock,
+this gives you a way to conveniently construct composite predicates (doing the
+same using STL's `<functional>` header is just painful). For example, here's a
+predicate that's satisfied by any number that is >= 0, <= 100, and != 50:
+
+```cpp
+using ::testing::AllOf;
+using ::testing::Ge;
+using ::testing::Le;
+using ::testing::Matches;
+using ::testing::Ne;
+...
+Matches(AllOf(Ge(0), Le(100), Ne(50)))
+```
+
+### Using Matchers in googletest Assertions
+
+See [`EXPECT_THAT`](reference/assertions.md#EXPECT_THAT) in the Assertions
+Reference.
+
+### Using Predicates as Matchers
+
+gMock provides a set of built-in matchers for matching arguments with expected
+values—see the [Matchers Reference](reference/matchers.md) for more
+information. In case you find the built-in set lacking, you can use an
+arbitrary unary predicate function or functor as a matcher - as long as the
+predicate accepts a value of the type you want. You do this by wrapping the
+predicate inside the `Truly()` function, for example:
+
+```cpp
+using ::testing::Truly;
+
+int IsEven(int n) { return (n % 2) == 0 ? 1 : 0; }
+...
+  // Bar() must be called with an even number.
+  EXPECT_CALL(foo, Bar(Truly(IsEven)));
+```
+
+Note that the predicate function / functor doesn't have to return `bool`. It
+works as long as the return value can be used as the condition in the
+statement `if (condition) ...`.
+
+### Matching Arguments that Are Not Copyable
+
+When you do an `EXPECT_CALL(mock_obj, Foo(bar))`, gMock saves away a copy of
+`bar`. When `Foo()` is called later, gMock compares the argument to `Foo()`
+with the saved copy of `bar`. This way, you don't need to worry about `bar`
+being modified or destroyed after the `EXPECT_CALL()` is executed. The same is
+true when you use matchers like `Eq(bar)`, `Le(bar)`, and so on.
+
+But what if `bar` cannot be copied (i.e. has no copy constructor)? You could
+define your own matcher function or callback and use it with `Truly()`, as the
+previous couple of recipes have shown. Or, you may be able to get away from it
+if you can guarantee that `bar` won't be changed after the `EXPECT_CALL()` is
+executed. Just tell gMock that it should save a reference to `bar`, instead of
+a copy of it. Here's how:
+
+```cpp
+using ::testing::Eq;
+using ::testing::Lt;
+...
+  // Expects that Foo()'s argument == bar.
+  EXPECT_CALL(mock_obj, Foo(Eq(std::ref(bar))));
+
+  // Expects that Foo()'s argument < bar.
+  EXPECT_CALL(mock_obj, Foo(Lt(std::ref(bar))));
+```
+
+Remember: if you do this, don't change `bar` after the `EXPECT_CALL()`, or the
+result is undefined.
+
+### Validating a Member of an Object
+
+Often a mock function takes a reference to an object as an argument.
+When matching the argument, you may not want to compare the entire object
+against a fixed object, as that may be over-specification. Instead, you may
+need to validate a certain member variable or the result of a certain getter
+method of the object. You can do this with `Field()` and `Property()`. More
+specifically,
+
+```cpp
+Field(&Foo::bar, m)
+```
+
+is a matcher that matches a `Foo` object whose `bar` member variable satisfies
+matcher `m`.
+
+```cpp
+Property(&Foo::baz, m)
+```
+
+is a matcher that matches a `Foo` object whose `baz()` method returns a value
+that satisfies matcher `m`.
+
+For example:
+
+| Expression | Description |
+| :--------------------------- | :--------------------------------------- |
+| `Field(&Foo::number, Ge(3))` | Matches `x` where `x.number >= 3`. |
+| `Property(&Foo::name, StartsWith("John "))` | Matches `x` where `x.name()` starts with `"John "`. |
+
+Note that in `Property(&Foo::baz, ...)`, method `baz()` must take no argument
+and be declared as `const`. Don't use `Property()` against member functions
+that you do not own, because taking addresses of functions is fragile and
+generally not part of the contract of the function.
+
+`Field()` and `Property()` can also match plain pointers to objects. For
+instance,
+
+```cpp
+using ::testing::Field;
+using ::testing::Ge;
+...
+Field(&Foo::number, Ge(3))
+```
+
+matches a plain pointer `p` where `p->number >= 3`. If `p` is `NULL`, the
+match will always fail regardless of the inner matcher.
+
+What if you want to validate more than one member at the same time? Remember
+that there are [`AllOf()` and `AllOfArray()`](#CombiningMatchers).
+
+Finally, `Field()` and `Property()` provide overloads that take the field or
+property name as the first argument to include it in the error message. This
+can be useful when creating combined matchers.
+
+```cpp
+using ::testing::AllOf;
+using ::testing::Field;
+using ::testing::Matcher;
+using ::testing::SafeMatcherCast;
+
+Matcher<const Foo&> IsFoo(const Foo& foo) {
+  return AllOf(Field("some_field", &Foo::some_field, foo.some_field),
+               Field("other_field", &Foo::other_field, foo.other_field),
+               Field("last_field", &Foo::last_field, foo.last_field));
+}
+```
+
+### Validating the Value Pointed to by a Pointer Argument
+
+C++ functions often take pointers as arguments. You can use matchers like
+`IsNull()`, `NotNull()`, and other comparison matchers to match a pointer, but
+what if you want to make sure the value *pointed to* by the pointer, instead
+of the pointer itself, has a certain property? Well, you can use the
+`Pointee(m)` matcher.
+
+`Pointee(m)` matches a pointer if and only if `m` matches the value the
+pointer points to. For example:
+
+```cpp
+using ::testing::Ge;
+using ::testing::Pointee;
+...
+  EXPECT_CALL(foo, Bar(Pointee(Ge(3))));
+```
+
+expects `foo.Bar()` to be called with a pointer that points to a value greater
+than or equal to 3.
+
+One nice thing about `Pointee()` is that it treats a `NULL` pointer as a match
+failure, so you can write `Pointee(m)` instead of
+
+```cpp
+using ::testing::AllOf;
+using ::testing::NotNull;
+using ::testing::Pointee;
+...
+  AllOf(NotNull(), Pointee(m))
+```
+
+without worrying that a `NULL` pointer will crash your test.
+
+Also, did we tell you that `Pointee()` works with both raw pointers **and**
+smart pointers (`std::unique_ptr`, `std::shared_ptr`, etc)?
+
+What if you have a pointer to pointer? You guessed it - you can use nested
+`Pointee()` to probe deeper inside the value.
+For example, `Pointee(Pointee(Lt(3)))` matches a pointer that points to a
+pointer that points to a number less than 3 (what a mouthful...).
+
+### Defining a Custom Matcher Class {#CustomMatcherClass}
+
+Most matchers can be simply defined using [the MATCHER* macros](#NewMatchers),
+which are terse and flexible, and produce good error messages. However, these
+macros are not very explicit about the interfaces they create and are not
+always suitable, especially for matchers that will be widely reused.
+
+For more advanced cases, you may need to define your own matcher class. A
+custom matcher allows you to test a specific invariant property of that
+object. Let's take a look at how to do so.
+
+Imagine you have a mock function that takes an object of type `Foo`, which has
+an `int bar()` method and an `int baz()` method. You want to constrain that
+the argument's `bar()` value plus its `baz()` value is a given number. (This
+is an invariant.) Here's how we can write and use a matcher class to do so:
+
+```cpp
+class BarPlusBazEqMatcher {
+ public:
+  using is_gtest_matcher = void;
+
+  explicit BarPlusBazEqMatcher(int expected_sum)
+      : expected_sum_(expected_sum) {}
+
+  bool MatchAndExplain(const Foo& foo,
+                       std::ostream* /* listener */) const {
+    return (foo.bar() + foo.baz()) == expected_sum_;
+  }
+
+  void DescribeTo(std::ostream* os) const {
+    *os << "bar() + baz() equals " << expected_sum_;
+  }
+
+  void DescribeNegationTo(std::ostream* os) const {
+    *os << "bar() + baz() does not equal " << expected_sum_;
+  }
+
+ private:
+  const int expected_sum_;
+};
+
+::testing::Matcher<const Foo&> BarPlusBazEq(int expected_sum) {
+  return BarPlusBazEqMatcher(expected_sum);
+}
+
+...
+  Foo foo;
+  EXPECT_THAT(foo, BarPlusBazEq(5))...;
+```
+
+### Matching Containers
+
+Sometimes an STL container (e.g. list, vector, map, ...) is passed to a mock
+function and you may want to validate it. Since most STL containers support
+the `==` operator, you can write `Eq(expected_container)` or simply
+`expected_container` to match a container exactly.
+
+Sometimes, though, you may want to be more flexible (for example, the first
+element must be an exact match, but the second element can be any positive
+number, and so on). Also, containers used in tests often have a small number
+of elements, and having to define the expected container out-of-line is a bit
+of a hassle.
+
+You can use the `ElementsAre()` or `UnorderedElementsAre()` matcher in such
+cases:
+
+```cpp
+using ::testing::_;
+using ::testing::ElementsAre;
+using ::testing::Gt;
+...
+  MOCK_METHOD(void, Foo, (const vector<int>& numbers), (override));
+...
+  EXPECT_CALL(mock, Foo(ElementsAre(1, Gt(0), _, 5)));
+```
+
+The above matcher says that the container must have 4 elements, which must be
+1, greater than 0, anything, and 5 respectively.
+
+If you instead write:
+
+```cpp
+using ::testing::_;
+using ::testing::Gt;
+using ::testing::UnorderedElementsAre;
+...
+  MOCK_METHOD(void, Foo, (const vector<int>& numbers), (override));
+...
+  EXPECT_CALL(mock, Foo(UnorderedElementsAre(1, Gt(0), _, 5)));
+```
+
+It means that the container must have 4 elements, which (under some
+permutation) must be 1, greater than 0, anything, and 5 respectively.
+
+As an alternative you can place the arguments in a C-style array and use
+`ElementsAreArray()` or `UnorderedElementsAreArray()` instead:
+
+```cpp
+using ::testing::ElementsAreArray;
+...
+  // ElementsAreArray accepts an array of element values.
+  const int expected_vector1[] = {1, 5, 2, 4, ...};
+  EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector1)));
+
+  // Or, an array of element matchers.
+  Matcher<int> expected_vector2[] = {1, Gt(2), _, 3, ...};
+  EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector2)));
+```
+
+In case the array needs to be dynamically created (and therefore the array
+size cannot be inferred by the compiler), you can give `ElementsAreArray()` an
+additional argument to specify the array size:
+
+```cpp
+using ::testing::ElementsAreArray;
+...
+  int* const expected_vector3 = new int[count];
+  ... fill expected_vector3 with values ...
+  EXPECT_CALL(mock, Foo(ElementsAreArray(expected_vector3, count)));
+```
+
+Use `Pair` when comparing maps or other associative containers.
+
+{% raw %}
+
+```cpp
+using ::testing::UnorderedElementsAre;
+using ::testing::Pair;
+...
+  absl::flat_hash_map<std::string, int> m = {{"a", 1}, {"b", 2}, {"c", 3}};
+  EXPECT_THAT(m, UnorderedElementsAre(
+      Pair("a", 1), Pair("b", 2), Pair("c", 3)));
+```
+
+{% endraw %}
+
+**Tips:**
+
+* `ElementsAre*()` can be used to match *any* container that implements the
+  STL iterator pattern (i.e. it has a `const_iterator` type and supports
+  `begin()/end()`), not just the ones defined in STL. It will even work with
+  container types yet to be written - as long as they follow the above
+  pattern.
+* You can use nested `ElementsAre*()` to match nested (multi-dimensional)
+  containers.
+* If the container is passed by pointer instead of by reference, just write
+  `Pointee(ElementsAre*(...))`.
+* The order of elements *matters* for `ElementsAre*()`. If you are using it
+  with containers whose element order is undefined (such as a
+  `std::unordered_map`) you should use `UnorderedElementsAre`.
+
+### Sharing Matchers
+
+Under the hood, a gMock matcher object consists of a pointer to a ref-counted
+implementation object. Copying matchers is allowed and very efficient, as only
+the pointer is copied. When the last matcher that references the
+implementation object dies, the implementation object will be deleted.
+
+Therefore, if you have some complex matcher that you want to use again and
+again, there is no need to build it every time. Just assign it to a matcher
+variable and use that variable repeatedly! For example,
+
+```cpp
+using ::testing::AllOf;
+using ::testing::Gt;
+using ::testing::Le;
+using ::testing::Matcher;
+...
+  Matcher<int> in_range = AllOf(Gt(5), Le(10));
+  ... use in_range as a matcher in multiple EXPECT_CALLs ...
+```
+
+### Matchers must have no side-effects {#PureMatchers}
+
+{: .callout .warning}
+WARNING: gMock does not guarantee when or how many times a matcher will be
+invoked. Therefore, all matchers must be *purely functional*: they cannot have
+any side effects, and the match result must not depend on anything other than
+the matcher's parameters and the value being matched.
+
+This requirement must be satisfied no matter how a matcher is defined (e.g.,
+if it is one of the standard matchers, or a custom matcher). In particular, a
+matcher can never call a mock function, as that will affect the state of the
+mock object and gMock.
+
+## Setting Expectations
+
+### Knowing When to Expect {#UseOnCall}
+
+**`ON_CALL`** is likely the *single most under-utilized construct* in gMock.
+
+There are basically two constructs for defining the behavior of a mock object:
+`ON_CALL` and `EXPECT_CALL`. The difference? `ON_CALL` defines what happens
+when a mock method is called, but doesn't imply any expectation on the method
+being called.
+`EXPECT_CALL` not only defines the behavior, but also sets an expectation that
+the method will be called with the given arguments, for the given number of
+times (and *in the given order* when you specify the order too).
+
+Since `EXPECT_CALL` does more, isn't it better than `ON_CALL`? Not really.
+Every `EXPECT_CALL` adds a constraint on the behavior of the code under test.
+Having more constraints than necessary is *baaad* - even worse than not having
+enough constraints.
+
+This may be counter-intuitive. How could tests that verify more be worse than
+tests that verify less? Isn't verification the whole point of tests?
+
+The answer lies in *what* a test should verify. **A good test verifies the
+contract of the code.** If a test over-specifies, it doesn't leave enough
+freedom to the implementation. As a result, changing the implementation
+without breaking the contract (e.g. refactoring and optimization), which
+should be perfectly fine to do, can break such tests. Then you have to spend
+time fixing them, only to see them broken again the next time the
+implementation is changed.
+
+Keep in mind that one doesn't have to verify more than one property in one
+test. In fact, **it's a good style to verify only one thing in one test.** If
+you do that, a bug will likely break only one or two tests instead of dozens
+(which case would you rather debug?). If you are also in the habit of giving
+tests descriptive names that tell what they verify, you can often easily guess
+what's wrong just from the test log itself.
+
+So use `ON_CALL` by default, and only use `EXPECT_CALL` when you actually
+intend to verify that the call is made. For example, you may have a bunch of
+`ON_CALL`s in your test fixture to set the common mock behavior shared by all
+tests in the same group, and write (scarcely) different `EXPECT_CALL`s in
+different `TEST_F`s to verify different aspects of the code's behavior.
+Compared with the style where each `TEST` has many `EXPECT_CALL`s, this leads
+to tests that are more resilient to implementational changes (and thus less
+likely to require maintenance) and makes the intent of the tests more obvious
+(so they are easier to maintain when you do need to maintain them).
+
+If you are bothered by the "Uninteresting mock function call" message printed
+when a mock method without an `EXPECT_CALL` is called, you may use a
+`NiceMock` instead to suppress all such messages for the mock object, or
+suppress the message for specific methods by adding
+`EXPECT_CALL(...).Times(AnyNumber())`. DO NOT suppress it by blindly adding an
+`EXPECT_CALL(...)`, or you'll have a test that's a pain to maintain.
+
+### Ignoring Uninteresting Calls
+
+If you are not interested in how a mock method is called, just don't say
+anything about it. In this case, if the method is ever called, gMock will
+perform its default action to allow the test program to continue. If you are
+not happy with the default action taken by gMock, you can override it using
+`DefaultValue<T>::Set()` (described [here](#DefaultValue)) or `ON_CALL()`.
+
+Please note that once you expressed interest in a particular mock method (via
+`EXPECT_CALL()`), all invocations to it must match some expectation. If this
+function is called but the arguments don't match any `EXPECT_CALL()`
+statement, it will be an error.
+
+### Disallowing Unexpected Calls
+
+If a mock method shouldn't be called at all, explicitly say so:
+
+```cpp
+using ::testing::_;
+...
+  EXPECT_CALL(foo, Bar(_))
+      .Times(0);
+```
+
+If some calls to the method are allowed, but the rest are not, just list all
+the expected calls:
+
+```cpp
+using ::testing::AnyNumber;
+using ::testing::Gt;
+...
+  EXPECT_CALL(foo, Bar(5));
+  EXPECT_CALL(foo, Bar(Gt(10)))
+      .Times(AnyNumber());
+```
+
+A call to `foo.Bar()` that doesn't match any of the `EXPECT_CALL()` statements
+will be an error.
+
+### Understanding Uninteresting vs Unexpected Calls {#uninteresting-vs-unexpected}
+
+*Uninteresting* calls and *unexpected* calls are different concepts in gMock.
+*Very* different.
+
+A call `x.Y(...)` is **uninteresting** if there's *not even a single*
+`EXPECT_CALL(x, Y(...))` set. In other words, the test isn't interested in the
+`x.Y()` method at all, as evident in that the test doesn't care to say
+anything about it.
+
+A call `x.Y(...)` is **unexpected** if there are *some* `EXPECT_CALL(x,
+Y(...))`s set, but none of them matches the call. Put another way, the test is
+interested in the `x.Y()` method (therefore it explicitly sets some
+`EXPECT_CALL` to verify how it's called); however, the verification fails as
+the test doesn't expect this particular call to happen.
+
+**An unexpected call is always an error,** as the code under test doesn't
+behave the way the test expects it to behave.
+
+**By default, an uninteresting call is not an error,** as it violates no
+constraint specified by the test. (gMock's philosophy is that saying nothing
+means there is no constraint.) However, it leads to a warning, as it *might*
+indicate a problem (e.g. the test author might have forgotten to specify a
+constraint).
+
+In gMock, `NiceMock` and `StrictMock` can be used to make a mock class "nice"
+or "strict". How does this affect uninteresting calls and unexpected calls?
+
+A **nice mock** suppresses uninteresting call *warnings*. It is less chatty
+than the default mock, but otherwise is the same. If a test fails with a
+default mock, it will also fail using a nice mock instead. And vice versa.
+Don't expect making a mock nice to change the test's result.
+
+A **strict mock** turns uninteresting call warnings into errors. So making a
+mock strict may change the test's result.
+
+Let's look at an example:
+
+```cpp
+TEST(...) {
+  NiceMock<MockDomainRegistry> mock_registry;
+  EXPECT_CALL(mock_registry, GetDomainOwner("google.com"))
+      .WillRepeatedly(Return("Larry Page"));
+
+  // Use mock_registry in code under test.
+  ... &mock_registry ...
+}
+```
+
+The sole `EXPECT_CALL` here says that all calls to `GetDomainOwner()` must
+have `"google.com"` as the argument. If `GetDomainOwner("yahoo.com")` is
+called, it will be an unexpected call, and thus an error. *Having a nice mock
+doesn't change the severity of an unexpected call.*
+
+So how do we tell gMock that `GetDomainOwner()` can be called with some other
+arguments as well? The standard technique is to add a "catch all"
+`EXPECT_CALL`:
+
+```cpp
+  EXPECT_CALL(mock_registry, GetDomainOwner(_))
+      .Times(AnyNumber());  // catches all other calls to this method.
+  EXPECT_CALL(mock_registry, GetDomainOwner("google.com"))
+      .WillRepeatedly(Return("Larry Page"));
+```
+
+Remember that `_` is the wildcard matcher that matches anything. With this, if
+`GetDomainOwner("google.com")` is called, it will do what the second
+`EXPECT_CALL` says; if it is called with a different argument, it will do what
+the first `EXPECT_CALL` says.
+
+Note that the order of the two `EXPECT_CALL`s is important, as a newer
+`EXPECT_CALL` takes precedence over an older one.
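+
+Putting the two expectations into a complete test, a sketch (reusing the
+`MockDomainRegistry` from above; `LookUpOwner()` is a hypothetical function
+under test) might read:
+
+```cpp
+using ::testing::_;
+using ::testing::AnyNumber;
+using ::testing::NiceMock;
+using ::testing::Return;
+
+TEST(RegistryTest, ReturnsOwnerForKnownDomain) {
+  NiceMock<MockDomainRegistry> mock_registry;
+  // Catch-all first: other domains may be queried any number of times.
+  EXPECT_CALL(mock_registry, GetDomainOwner(_))
+      .Times(AnyNumber());
+  // The newer, more specific expectation takes precedence for "google.com".
+  EXPECT_CALL(mock_registry, GetDomainOwner("google.com"))
+      .WillRepeatedly(Return("Larry Page"));
+
+  EXPECT_EQ(LookUpOwner(&mock_registry, "google.com"), "Larry Page");
+}
+```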
+ +For more on uninteresting calls, nice mocks, and strict mocks, read +["The Nice, the Strict, and the Naggy"](#NiceStrictNaggy). + +### Ignoring Uninteresting Arguments {#ParameterlessExpectations} + +If your test doesn't care about the parameters (it only cares about the number +or order of calls), you can often simply omit the parameter list: + +```cpp + // Expect foo.Bar( ... ) twice with any arguments. + EXPECT_CALL(foo, Bar).Times(2); + + // Delegate to the given method whenever the factory is invoked. + ON_CALL(foo_factory, MakeFoo) + .WillByDefault(&BuildFooForTest); +``` + +This functionality is only available when a method is not overloaded; to prevent +unexpected behavior it is a compilation error to try to set an expectation on a +method where the specific overload is ambiguous. You can work around this by +supplying a [simpler mock interface](#SimplerInterfaces) than the mocked class +provides. + +This pattern is also useful when the arguments are interesting, but match logic +is substantially complex. You can leave the argument list unspecified and use +SaveArg actions to [save the values for later verification](#SaveArgVerify). If +you do that, you can easily differentiate calling the method the wrong number of +times from calling it with the wrong arguments. + +### Expecting Ordered Calls {#OrderedCalls} + +Although an `EXPECT_CALL()` statement defined later takes precedence when gMock +tries to match a function call with an expectation, by default calls don't have +to happen in the order `EXPECT_CALL()` statements are written. For example, if +the arguments match the matchers in the second `EXPECT_CALL()`, but not those in +the first and third, then the second expectation will be used. + +If you would rather have all calls occur in the order of the expectations, put +the `EXPECT_CALL()` statements in a block where you define a variable of type +`InSequence`: + +```cpp +using ::testing::_; +using ::testing::InSequence; + + { + InSequence s; + + EXPECT_CALL(foo, DoThis(5)); + EXPECT_CALL(bar, DoThat(_)) + .Times(2); + EXPECT_CALL(foo, DoThis(6)); + } +``` + +In this example, we expect a call to `foo.DoThis(5)`, followed by two calls to +`bar.DoThat()` where the argument can be anything, which are in turn followed by +a call to `foo.DoThis(6)`. If a call occurred out-of-order, gMock will report an +error. + +### Expecting Partially Ordered Calls {#PartialOrder} + +Sometimes requiring everything to occur in a predetermined order can lead to +brittle tests. For example, we may care about `A` occurring before both `B` and +`C`, but aren't interested in the relative order of `B` and `C`. In this case, +the test should reflect our real intent, instead of being overly constraining. + +gMock allows you to impose an arbitrary DAG (directed acyclic graph) on the +calls. One way to express the DAG is to use the +[`After` clause](reference/mocking.md#EXPECT_CALL.After) of `EXPECT_CALL`. + +Another way is via the `InSequence()` clause (not the same as the `InSequence` +class), which we borrowed from jMock 2. It's less flexible than `After()`, but +more convenient when you have long chains of sequential calls, as it doesn't +require you to come up with different names for the expectations in the chains. +Here's how it works: + +If we view `EXPECT_CALL()` statements as nodes in a graph, and add an edge from +node A to node B wherever A must occur before B, we can get a DAG. We use the +term "sequence" to mean a directed path in this DAG. 
Now, if we decompose the +DAG into sequences, we just need to know which sequences each `EXPECT_CALL()` +belongs to in order to be able to reconstruct the original DAG. + +So, to specify the partial order on the expectations we need to do two things: +first to define some `Sequence` objects, and then for each `EXPECT_CALL()` say +which `Sequence` objects it is part of. + +Expectations in the same sequence must occur in the order they are written. For +example, + +```cpp +using ::testing::Sequence; +... + Sequence s1, s2; + + EXPECT_CALL(foo, A()) + .InSequence(s1, s2); + EXPECT_CALL(bar, B()) + .InSequence(s1); + EXPECT_CALL(bar, C()) + .InSequence(s2); + EXPECT_CALL(foo, D()) + .InSequence(s2); +``` + +specifies the following DAG (where `s1` is `A -> B`, and `s2` is `A -> C -> D`): + +```text + +---> B + | + A ---| + | + +---> C ---> D +``` + +This means that A must occur before B and C, and C must occur before D. There's +no restriction about the order other than these. + +### Controlling When an Expectation Retires + +When a mock method is called, gMock only considers expectations that are still +active. An expectation is active when created, and becomes inactive (aka +*retires*) when a call that has to occur later has occurred. For example, in + +```cpp +using ::testing::_; +using ::testing::Sequence; +... + Sequence s1, s2; + + EXPECT_CALL(log, Log(WARNING, _, "File too large.")) // #1 + .Times(AnyNumber()) + .InSequence(s1, s2); + EXPECT_CALL(log, Log(WARNING, _, "Data set is empty.")) // #2 + .InSequence(s1); + EXPECT_CALL(log, Log(WARNING, _, "User not found.")) // #3 + .InSequence(s2); +``` + +as soon as either #2 or #3 is matched, #1 will retire. If a warning `"File too +large."` is logged after this, it will be an error. + +Note that an expectation doesn't retire automatically when it's saturated. For +example, + +```cpp +using ::testing::_; +... + EXPECT_CALL(log, Log(WARNING, _, _)); // #1 + EXPECT_CALL(log, Log(WARNING, _, "File too large.")); // #2 +``` + +says that there will be exactly one warning with the message `"File too +large."`. If the second warning contains this message too, #2 will match again +and result in an upper-bound-violated error. + +If this is not what you want, you can ask an expectation to retire as soon as it +becomes saturated: + +```cpp +using ::testing::_; +... + EXPECT_CALL(log, Log(WARNING, _, _)); // #1 + EXPECT_CALL(log, Log(WARNING, _, "File too large.")) // #2 + .RetiresOnSaturation(); +``` + +Here #2 can be used only once, so if you have two warnings with the message +`"File too large."`, the first will match #2 and the second will match #1 - +there will be no error. + +## Using Actions + +### Returning References from Mock Methods + +If a mock function's return type is a reference, you need to use `ReturnRef()` +instead of `Return()` to return a result: + +```cpp +using ::testing::ReturnRef; + +class MockFoo : public Foo { + public: + MOCK_METHOD(Bar&, GetBar, (), (override)); +}; +... + MockFoo foo; + Bar bar; + EXPECT_CALL(foo, GetBar()) + .WillOnce(ReturnRef(bar)); +... +``` + +### Returning Live Values from Mock Methods + +The `Return(x)` action saves a copy of `x` when the action is created, and +always returns the same value whenever it's executed. Sometimes you may want to +instead return the *live* value of `x` (i.e. its value at the time when the +action is *executed*.). Use either `ReturnRef()` or `ReturnPointee()` for this +purpose. 
+
+If the mock function's return type is a reference, you can do it using
+`ReturnRef(x)`, as shown in the previous recipe ("Returning References from Mock
+Methods"). However, gMock doesn't let you use `ReturnRef()` in a mock function
+whose return type is not a reference, as doing that usually indicates a user
+error. So, what shall you do?
+
+Though you may be tempted, DO NOT use `std::ref()`:
+
+```cpp
+using ::testing::Return;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(int, GetValue, (), (override));
+};
+...
+  int x = 0;
+  MockFoo foo;
+  EXPECT_CALL(foo, GetValue())
+      .WillRepeatedly(Return(std::ref(x)));  // Wrong!
+  x = 42;
+  EXPECT_EQ(foo.GetValue(), 42);
+```
+
+Unfortunately, it doesn't work here. The above code will fail with error:
+
+```text
+Value of: foo.GetValue()
+  Actual: 0
+Expected: 42
+```
+
+The reason is that `Return(*value*)` converts `value` to the actual return type
+of the mock function at the time when the action is *created*, not when it is
+*executed*. (This behavior was chosen for the action to be safe when `value` is
+a proxy object that references some temporary objects.) As a result,
+`std::ref(x)` is converted to an `int` value (instead of a `const int&`) when
+the expectation is set, and `Return(std::ref(x))` will always return 0.
+
+`ReturnPointee(pointer)` was provided to solve this problem specifically. It
+returns the value pointed to by `pointer` at the time the action is *executed*:
+
+```cpp
+using ::testing::ReturnPointee;
+...
+  int x = 0;
+  MockFoo foo;
+  EXPECT_CALL(foo, GetValue())
+      .WillRepeatedly(ReturnPointee(&x));  // Note the & here.
+  x = 42;
+  EXPECT_EQ(foo.GetValue(), 42);  // This will succeed now.
+```
+
+### Combining Actions
+
+Want to do more than one thing when a function is called? That's fine. `DoAll()`
+allows you to do a sequence of actions every time. Only the return value of the
+last action in the sequence will be used.
+
+```cpp
+using ::testing::_;
+using ::testing::DoAll;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(bool, Bar, (int n), (override));
+};
+...
+  EXPECT_CALL(foo, Bar(_))
+      .WillOnce(DoAll(action_1,
+                      action_2,
+                      ...
+                      action_n));
+```
+
+The return value of the last action **must** match the return type of the mocked
+method. In the example above, `action_n` could be `Return(true)`, or a lambda
+that returns a `bool`, but not `SaveArg`, which returns `void`. Otherwise the
+signature of `DoAll` would not match the signature expected by `WillOnce`, which
+is the signature of the mocked method, and it wouldn't compile.
+
+### Verifying Complex Arguments {#SaveArgVerify}
+
+If you want to verify that a method is called with a particular argument but the
+match criteria is complex, it can be difficult to distinguish between
+cardinality failures (calling the method the wrong number of times) and argument
+match failures. Similarly, if you are matching multiple parameters, it may not
+be easy to distinguish which argument failed to match. For example:
+
+```cpp
+  // Not ideal: this could fail because of a problem with arg1 or arg2, or maybe
+  // just the method wasn't called.
+  EXPECT_CALL(foo, SendValues(_, ElementsAre(1, 4, 4, 7), EqualsProto( ... )));
+```
+
+You can instead save the arguments and test them individually:
+
+```cpp
+  EXPECT_CALL(foo, SendValues)
+      .WillOnce(DoAll(SaveArg<1>(&actual_array), SaveArg<2>(&actual_proto)));
+  ... run the test
+  EXPECT_THAT(actual_array, ElementsAre(1, 4, 4, 7));
+  EXPECT_THAT(actual_proto, EqualsProto( ...
 ));
+```
+
+### Mocking Side Effects {#MockingSideEffects}
+
+Sometimes a method exhibits its effect not via returning a value but via side
+effects. For example, it may change some global state or modify an output
+argument. To mock side effects, in general you can define your own action by
+implementing `::testing::ActionInterface`.
+
+If all you need to do is to change an output argument, the built-in
+`SetArgPointee()` action is convenient:
+
+```cpp
+using ::testing::_;
+using ::testing::SetArgPointee;
+
+class MockMutator : public Mutator {
+ public:
+  MOCK_METHOD(void, Mutate, (bool mutate, int* value), (override));
+  ...
+};
+...
+  MockMutator mutator;
+  EXPECT_CALL(mutator, Mutate(true, _))
+      .WillOnce(SetArgPointee<1>(5));
+```
+
+In this example, when `mutator.Mutate()` is called, we will assign 5 to the
+`int` variable pointed to by argument #1 (0-based).
+
+`SetArgPointee()` conveniently makes an internal copy of the value you pass to
+it, removing the need to keep the value in scope and alive. The implication
+however is that the value must have a copy constructor and assignment operator.
+
+If the mock method needs to return a value as well, you can chain
+`SetArgPointee()` with `Return()` using `DoAll()`, remembering to put the
+`Return()` statement last:
+
+```cpp
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Return;
+using ::testing::SetArgPointee;
+
+class MockMutator : public Mutator {
+ public:
+  ...
+  MOCK_METHOD(bool, MutateInt, (int* value), (override));
+};
+...
+  MockMutator mutator;
+  EXPECT_CALL(mutator, MutateInt(_))
+      .WillOnce(DoAll(SetArgPointee<0>(5),
+                      Return(true)));
+```
+
+Note, however, that if you use the `ReturnOKWith()` method, it will override the
+values provided by `SetArgPointee()` in the response parameters of your function
+call.
+
+If the output argument is an array, use the `SetArrayArgument<N>(first, last)`
+action instead. It copies the elements in source range `[first, last)` to the
+array pointed to by the `N`-th (0-based) argument:
+
+```cpp
+using ::testing::NotNull;
+using ::testing::SetArrayArgument;
+
+class MockArrayMutator : public ArrayMutator {
+ public:
+  MOCK_METHOD(void, Mutate, (int* values, int num_values), (override));
+  ...
+};
+...
+  MockArrayMutator mutator;
+  int values[5] = {1, 2, 3, 4, 5};
+  EXPECT_CALL(mutator, Mutate(NotNull(), 5))
+      .WillOnce(SetArrayArgument<0>(values, values + 5));
+```
+
+This also works when the argument is an output iterator:
+
+```cpp
+using ::testing::_;
+using ::testing::SetArrayArgument;
+
+class MockRolodex : public Rolodex {
+ public:
+  MOCK_METHOD(void, GetNames, (std::back_insert_iterator<vector<string>>),
+              (override));
+  ...
+};
+...
+  MockRolodex rolodex;
+  vector<string> names = {"George", "John", "Thomas"};
+  EXPECT_CALL(rolodex, GetNames(_))
+      .WillOnce(SetArrayArgument<0>(names.begin(), names.end()));
+```
+
+### Changing a Mock Object's Behavior Based on the State
+
+If you expect a call to change the behavior of a mock object, you can use
+`::testing::InSequence` to specify different behaviors before and after the
+call:
+
+```cpp
+using ::testing::InSequence;
+using ::testing::Return;
+
+...
+  {
+    InSequence seq;
+    EXPECT_CALL(my_mock, IsDirty())
+        .WillRepeatedly(Return(true));
+    EXPECT_CALL(my_mock, Flush());
+    EXPECT_CALL(my_mock, IsDirty())
+        .WillRepeatedly(Return(false));
+  }
+  my_mock.FlushIfDirty();
+```
+
+This makes `my_mock.IsDirty()` return `true` before `my_mock.Flush()` is called
+and return `false` afterwards.
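+
+For reference, a mock that fits the snippet above could look like the following
+sketch. The `Buffer` base class and the concrete `FlushIfDirty()` method are
+assumptions for illustration only, not part of the original example:
+
+```cpp
+class MockBuffer : public Buffer {
+ public:
+  MOCK_METHOD(bool, IsDirty, (), (override));
+  MOCK_METHOD(void, Flush, (), (override));
+
+  // A real (non-mocked) helper that exercises the mocked methods, so the
+  // InSequence expectations above describe one IsDirty()/Flush() round trip.
+  void FlushIfDirty() {
+    if (IsDirty()) Flush();
+  }
+};
+```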
+
+If the behavior change is more complex, you can store the effects in a variable
+and make a mock method get its return value from that variable:
+
+```cpp
+using ::testing::_;
+using ::testing::SaveArg;
+using ::testing::Return;
+
+ACTION_P(ReturnPointee, p) { return *p; }
+...
+  int previous_value = 0;
+  EXPECT_CALL(my_mock, GetPrevValue)
+      .WillRepeatedly(ReturnPointee(&previous_value));
+  EXPECT_CALL(my_mock, UpdateValue)
+      .WillRepeatedly(SaveArg<0>(&previous_value));
+  my_mock.DoSomethingToUpdateValue();
+```
+
+Here `my_mock.GetPrevValue()` will always return the argument of the last
+`UpdateValue()` call.
+
+### Setting the Default Value for a Return Type {#DefaultValue}
+
+If a mock method's return type is a built-in C++ type or pointer, by default it
+will return 0 when invoked. Also, in C++11 and above, a mock method whose
+return type has a default constructor will return a default-constructed value by
+default. You only need to specify an action if this default value doesn't work
+for you.
+
+Sometimes, you may want to change this default value, or you may want to specify
+a default value for types gMock doesn't know about. You can do this using the
+`::testing::DefaultValue` class template:
+
+```cpp
+using ::testing::DefaultValue;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(Bar, CalculateBar, (), (override));
+};
+
+...
+  Bar default_bar;
+  // Sets the default return value for type Bar.
+  DefaultValue<Bar>::Set(default_bar);
+
+  MockFoo foo;
+
+  // We don't need to specify an action here, as the default
+  // return value works for us.
+  EXPECT_CALL(foo, CalculateBar());
+
+  foo.CalculateBar();  // This should return default_bar.
+
+  // Unsets the default return value.
+  DefaultValue<Bar>::Clear();
+```
+
+Please note that changing the default value for a type can make your tests hard
+to understand. We recommend you to use this feature judiciously. For example,
+you may want to make sure the `Set()` and `Clear()` calls are right next to the
+code that uses your mock.
+
+### Setting the Default Actions for a Mock Method
+
+You've learned how to change the default value of a given type. However, this
+may be too coarse for your purpose: perhaps you have two mock methods with the
+same return type and you want them to have different behaviors. The `ON_CALL()`
+macro allows you to customize your mock's behavior at the method level:
+
+```cpp
+using ::testing::_;
+using ::testing::AnyNumber;
+using ::testing::Gt;
+using ::testing::Return;
+...
+  ON_CALL(foo, Sign(_))
+      .WillByDefault(Return(-1));
+  ON_CALL(foo, Sign(0))
+      .WillByDefault(Return(0));
+  ON_CALL(foo, Sign(Gt(0)))
+      .WillByDefault(Return(1));
+
+  EXPECT_CALL(foo, Sign(_))
+      .Times(AnyNumber());
+
+  foo.Sign(5);   // This should return 1.
+  foo.Sign(-9);  // This should return -1.
+  foo.Sign(0);   // This should return 0.
+```
+
+As you may have guessed, when there is more than one `ON_CALL()` statement, the
+newer ones in the order take precedence over the older ones. In other words, the
+**last** one that matches the function arguments will be used. This matching
+order allows you to set up the common behavior in a mock object's constructor or
+the test fixture's set-up phase and specialize the mock's behavior later.
+
+Note that both `ON_CALL` and `EXPECT_CALL` have the same "later statements take
+precedence" rule, but they don't interact. That is, `EXPECT_CALL`s have their
+own precedence order distinct from the `ON_CALL` precedence order.
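+
+One practical consequence of this separation, sketched below with the `Sign()`
+mock from above: an `EXPECT_CALL` that specifies no action of its own still
+picks up the default action from the matching `ON_CALL`:
+
+```cpp
+using ::testing::_;
+using ::testing::Return;
+...
+  ON_CALL(foo, Sign(_))
+      .WillByDefault(Return(-1));
+
+  // No .WillOnce()/.WillRepeatedly() here: this EXPECT_CALL matches and counts
+  // the call, but the returned value still comes from the ON_CALL above.
+  EXPECT_CALL(foo, Sign(5));
+
+  foo.Sign(5);  // Returns -1, supplied by the ON_CALL default action.
+```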
+
+### Using Functions/Methods/Functors/Lambdas as Actions {#FunctionsAsActions}
+
+If the built-in actions don't suit you, you can use an existing callable
+(function, `std::function`, method, functor, lambda) as an action.
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(int, Sum, (int x, int y), (override));
+  MOCK_METHOD(bool, ComplexJob, (int x), (override));
+};
+
+int CalculateSum(int x, int y) { return x + y; }
+int Sum3(int x, int y, int z) { return x + y + z; }
+
+class Helper {
+ public:
+  bool ComplexJob(int x);
+};
+
+...
+  MockFoo foo;
+  Helper helper;
+  EXPECT_CALL(foo, Sum(_, _))
+      .WillOnce(&CalculateSum)
+      .WillRepeatedly(Invoke(NewPermanentCallback(Sum3, 1)));
+  EXPECT_CALL(foo, ComplexJob(_))
+      .WillOnce(Invoke(&helper, &Helper::ComplexJob))
+      .WillOnce([] { return true; })
+      .WillRepeatedly([](int x) { return x > 0; });
+
+  foo.Sum(5, 6);         // Invokes CalculateSum(5, 6).
+  foo.Sum(2, 3);         // Invokes Sum3(1, 2, 3).
+  foo.ComplexJob(10);    // Invokes helper.ComplexJob(10).
+  foo.ComplexJob(-1);    // Invokes the inline lambda.
+```
+
+The only requirement is that the type of the function, etc. must be *compatible*
+with the signature of the mock function, meaning that the latter's arguments (if
+it takes any) can be implicitly converted to the corresponding arguments of the
+former, and the former's return type can be implicitly converted to that of the
+latter. So, you can invoke something whose type is *not* exactly the same as the
+mock function, as long as it's safe to do so - nice, huh?
+
+Note that:
+
+*   The action takes ownership of the callback and will delete it when the
+    action itself is destructed.
+*   If the type of a callback is derived from a base callback type `C`, you need
+    to implicitly cast it to `C` to resolve the overloading, e.g.
+
+    ```cpp
+    using ::testing::Invoke;
+    ...
+      ResultCallback<bool>* is_ok = ...;
+      ... Invoke(is_ok) ...;  // This works.
+
+      BlockingClosure* done = new BlockingClosure;
+      ... Invoke(implicit_cast<Closure*>(done)) ...;  // The cast is necessary.
+    ```
+
+### Using Functions with Extra Info as Actions
+
+The function or functor you call using `Invoke()` must have the same number of
+arguments as the mock function you use it for. Sometimes you may have a function
+that takes more arguments, and you are willing to pass in the extra arguments
+yourself to fill the gap. You can do this in gMock using callbacks with
+pre-bound arguments. Here's an example:
+
+```cpp
+using ::testing::Invoke;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(char, DoThis, (int n), (override));
+};
+
+char SignOfSum(int x, int y) {
+  const int sum = x + y;
+  return (sum > 0) ? '+' : (sum < 0) ? '-' : '0';
+}
+
+TEST_F(FooTest, Test) {
+  MockFoo foo;
+
+  EXPECT_CALL(foo, DoThis(2))
+      .WillOnce(Invoke(NewPermanentCallback(SignOfSum, 5)));
+  EXPECT_EQ(foo.DoThis(2), '+');  // Invokes SignOfSum(5, 2).
+}
+```
+
+### Invoking a Function/Method/Functor/Lambda/Callback Without Arguments
+
+`Invoke()` passes the mock function's arguments to the function, etc. being
+invoked such that the callee has the full context of the call to work with. If
+the invoked function is not interested in some or all of the arguments, it can
+simply ignore them.
+
+Yet, a common pattern is that a test author wants to invoke a function without
+the arguments of the mock function. She could do that using a wrapper function
+that throws away the arguments before invoking an underlying nullary function.
+Needless to say, this can be tedious and obscures the intent of the test.
+
+There are two solutions to this problem. First, you can pass any callable of
+zero args as an action. Alternatively, use `InvokeWithoutArgs()`, which is like
+`Invoke()` except that it doesn't pass the mock function's arguments to the
+callee. Here's an example of each:
+
+```cpp
+using ::testing::_;
+using ::testing::InvokeWithoutArgs;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(bool, ComplexJob, (int n), (override));
+};
+
+bool Job1() { ... }
+bool Job2(int n, char c) { ... }
+
+...
+  MockFoo foo;
+  EXPECT_CALL(foo, ComplexJob(_))
+      .WillOnce([] { Job1(); })
+      .WillOnce(InvokeWithoutArgs(NewPermanentCallback(Job2, 5, 'a')));
+
+  foo.ComplexJob(10);  // Invokes Job1().
+  foo.ComplexJob(20);  // Invokes Job2(5, 'a').
+```
+
+Note that:
+
+*   The action takes ownership of the callback and will delete it when the
+    action itself is destructed.
+*   If the type of a callback is derived from a base callback type `C`, you need
+    to implicitly cast it to `C` to resolve the overloading, e.g.
+
+    ```cpp
+    using ::testing::InvokeWithoutArgs;
+    ...
+      ResultCallback<bool>* is_ok = ...;
+      ... InvokeWithoutArgs(is_ok) ...;  // This works.
+
+      BlockingClosure* done = ...;
+      ... InvokeWithoutArgs(implicit_cast<Closure*>(done)) ...;
+      // The cast is necessary.
+    ```
+
+### Invoking an Argument of the Mock Function
+
+Sometimes a mock function will receive a function pointer or a functor (in other
+words, a "callable") as an argument, e.g.
+
+```cpp
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(bool, DoThis, (int n, (ResultCallback1<bool, int>* callback)),
+              (override));
+};
+```
+
+and you may want to invoke this callable argument:
+
+```cpp
+using ::testing::_;
+...
+  MockFoo foo;
+  EXPECT_CALL(foo, DoThis(_, _))
+      .WillOnce(...);
+  // Will execute callback->Run(5), where callback is the
+  // second argument DoThis() receives.
+```
+
+{: .callout .note}
+NOTE: The section below is legacy documentation from before C++ had lambdas:
+
+Arghh, you need to refer to a mock function argument but C++ has no lambda
+(yet), so you have to define your own action. :-( Or do you really?
+
+Well, gMock has an action to solve *exactly* this problem:
+
+```cpp
+InvokeArgument<N>(arg_1, arg_2, ..., arg_m)
+```
+
+will invoke the `N`-th (0-based) argument the mock function receives, with
+`arg_1`, `arg_2`, ..., and `arg_m`. No matter if the argument is a function
+pointer, a functor, or a callback. gMock handles them all.
+
+With that, you could write:
+
+```cpp
+using ::testing::_;
+using ::testing::InvokeArgument;
+...
+  EXPECT_CALL(foo, DoThis(_, _))
+      .WillOnce(InvokeArgument<1>(5));
+  // Will execute callback->Run(5), where callback is the
+  // second argument DoThis() receives.
+```
+
+What if the callable takes an argument by reference? No problem - just wrap it
+inside `std::ref()`:
+
+```cpp
+  ...
+  MOCK_METHOD(bool, Bar,
+              ((ResultCallback2<bool, int, const Helper&>* callback)),
+              (override));
+  ...
+  using ::testing::_;
+  using ::testing::InvokeArgument;
+  ...
+  MockFoo foo;
+  Helper helper;
+  ...
+  EXPECT_CALL(foo, Bar(_))
+      .WillOnce(InvokeArgument<0>(5, std::ref(helper)));
+  // std::ref(helper) guarantees that a reference to helper, not a copy of
+  // it, will be passed to the callback.
+```
+
+What if the callable takes an argument by reference and we do **not** wrap the
+argument in `std::ref()`? Then `InvokeArgument<N>()` will *make a copy* of the
+argument, and pass a *reference to the copy*, instead of a reference to the
+original value, to the callable. 
This is especially handy when the argument is a
+temporary value:
+
+```cpp
+  ...
+  MOCK_METHOD(bool, DoThat, (bool (*f)(const double& x, const string& s)),
+              (override));
+  ...
+  using ::testing::_;
+  using ::testing::InvokeArgument;
+  ...
+  MockFoo foo;
+  ...
+  EXPECT_CALL(foo, DoThat(_))
+      .WillOnce(InvokeArgument<0>(5.0, string("Hi")));
+  // Will execute (*f)(5.0, string("Hi")), where f is the function pointer
+  // DoThat() receives. Note that the values 5.0 and string("Hi") are
+  // temporary and dead once the EXPECT_CALL() statement finishes. Yet
+  // it's fine to perform this action later, since copies of the values
+  // are kept inside the InvokeArgument action.
+```
+
+### Ignoring an Action's Result
+
+Sometimes you have an action that returns *something*, but you need an action
+that returns `void` (perhaps you want to use it in a mock function that returns
+`void`, or perhaps it needs to be used in `DoAll()` and it's not the last in the
+list). `IgnoreResult()` lets you do that. For example:
+
+```cpp
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::IgnoreResult;
+using ::testing::Return;
+
+int Process(const MyData& data);
+string DoSomething();
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(void, Abc, (const MyData& data), (override));
+  MOCK_METHOD(bool, Xyz, (), (override));
+};
+
+  ...
+  MockFoo foo;
+  EXPECT_CALL(foo, Abc(_))
+      // .WillOnce(Invoke(Process));
+      // The above line won't compile as Process() returns int but Abc() needs
+      // to return void.
+      .WillOnce(IgnoreResult(Process));
+  EXPECT_CALL(foo, Xyz())
+      .WillOnce(DoAll(IgnoreResult(DoSomething),
+                      // Ignores the string DoSomething() returns.
+                      Return(true)));
+```
+
+Note that you **cannot** use `IgnoreResult()` on an action that already returns
+`void`. Doing so will lead to ugly compiler errors.
+
+### Selecting an Action's Arguments {#SelectingArgs}
+
+Say you have a mock function `Foo()` that takes seven arguments, and you have a
+custom action that you want to invoke when `Foo()` is called. Trouble is, the
+custom action only wants three arguments:
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+...
+  MOCK_METHOD(bool, Foo,
+              (bool visible, const string& name, int x, int y,
+               (const map<pair<int, int>, double>&) weight, double min_weight,
+               double max_weight));
+...
+bool IsVisibleInQuadrant1(bool visible, int x, int y) {
+  return visible && x >= 0 && y >= 0;
+}
+...
+  EXPECT_CALL(mock, Foo)
+      .WillOnce(Invoke(IsVisibleInQuadrant1));  // Uh, won't compile. :-(
+```
+
+To please the compiler God, you need to define an "adaptor" that has the same
+signature as `Foo()` and calls the custom action with the right arguments:
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+...
+bool MyIsVisibleInQuadrant1(bool visible, const string& name, int x, int y,
+                            const map<pair<int, int>, double>& weight,
+                            double min_weight, double max_weight) {
+  return IsVisibleInQuadrant1(visible, x, y);
+}
+...
+  EXPECT_CALL(mock, Foo)
+      .WillOnce(Invoke(MyIsVisibleInQuadrant1));  // Now it works.
+```
+
+But isn't this awkward?
+
+gMock provides a generic *action adaptor*, so you can spend your time minding
+more important business than writing your own adaptors. Here's the syntax:
+
+```cpp
+WithArgs<N1, N2, ..., Nk>(action)
+```
+
+creates an action that passes the arguments of the mock function at the given
+indices (0-based) to the inner `action` and performs it. Using `WithArgs`, our
+original example can be written as:
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::WithArgs;
+...
+  EXPECT_CALL(mock, Foo)
+      .WillOnce(WithArgs<0, 2, 3>(Invoke(IsVisibleInQuadrant1)));  // No need to define your own adaptor.
+```
+
+For better readability, gMock also gives you:
+
+*   `WithoutArgs(action)` when the inner `action` takes *no* argument, and
+*   `WithArg<N>(action)` (no `s` after `Arg`) when the inner `action` takes
+    *one* argument.
+
+As you may have realized, `InvokeWithoutArgs(...)` is just syntactic sugar for
+`WithoutArgs(Invoke(...))`.
+
+Here are more tips:
+
+*   The inner action used in `WithArgs` and friends does not have to be
+    `Invoke()` -- it can be anything.
+*   You can repeat an argument in the argument list if necessary, e.g.
+    `WithArgs<2, 3, 3, 5>(...)`.
+*   You can change the order of the arguments, e.g. `WithArgs<3, 2, 1>(...)`.
+*   The types of the selected arguments do *not* have to match the signature of
+    the inner action exactly. It works as long as they can be implicitly
+    converted to the corresponding arguments of the inner action. For example,
+    if the 4-th argument of the mock function is an `int` and `my_action` takes
+    a `double`, `WithArg<4>(my_action)` will work.
+
+### Ignoring Arguments in Action Functions
+
+The [selecting-an-action's-arguments](#SelectingArgs) recipe showed us one way
+to make a mock function and an action with incompatible argument lists fit
+together. The downside is that wrapping the action in `WithArgs<...>()` can get
+tedious for people writing the tests.
+
+If you are defining a function (or method, functor, lambda, callback) to be used
+with `Invoke*()`, and you are not interested in some of its arguments, an
+alternative to `WithArgs` is to declare the uninteresting arguments as `Unused`.
+This makes the definition less cluttered and less fragile in case the types of
+the uninteresting arguments change. It could also increase the chance the action
+function can be reused. For example, given
+
+```cpp
+ public:
+  MOCK_METHOD(double, Foo, (const string& label, double x, double y),
+              (override));
+  MOCK_METHOD(double, Bar, (int index, double x, double y), (override));
+```
+
+instead of
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+
+double DistanceToOriginWithLabel(const string& label, double x, double y) {
+  return sqrt(x*x + y*y);
+}
+double DistanceToOriginWithIndex(int index, double x, double y) {
+  return sqrt(x*x + y*y);
+}
+...
+  EXPECT_CALL(mock, Foo("abc", _, _))
+      .WillOnce(Invoke(DistanceToOriginWithLabel));
+  EXPECT_CALL(mock, Bar(5, _, _))
+      .WillOnce(Invoke(DistanceToOriginWithIndex));
+```
+
+you could write
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+using ::testing::Unused;
+
+double DistanceToOrigin(Unused, double x, double y) {
+  return sqrt(x*x + y*y);
+}
+...
+  EXPECT_CALL(mock, Foo("abc", _, _))
+      .WillOnce(Invoke(DistanceToOrigin));
+  EXPECT_CALL(mock, Bar(5, _, _))
+      .WillOnce(Invoke(DistanceToOrigin));
+```
+
+### Sharing Actions
+
+Just like matchers, a gMock action object consists of a pointer to a ref-counted
+implementation object. Therefore copying actions is also allowed and very
+efficient. When the last action that references the implementation object dies,
+the implementation object will be deleted.
+
+If you have some complex action that you want to use again and again, you may
+not have to build it from scratch every time. If the action doesn't have an
+internal state (i.e. if it always does the same thing no matter how many times
+it has been called), you can assign it to an action variable and use that
+variable repeatedly. 
For example:
+
+```cpp
+using ::testing::Action;
+using ::testing::DoAll;
+using ::testing::Return;
+using ::testing::SetArgPointee;
+...
+  Action<bool(int*)> set_flag = DoAll(SetArgPointee<0>(5),
+                                      Return(true));
+  ... use set_flag in .WillOnce() and .WillRepeatedly() ...
+```
+
+However, if the action has its own state, you may be surprised if you share the
+action object. Suppose you have an action factory `IncrementCounter(init)` which
+creates an action that increments and returns a counter whose initial value is
+`init`. Then using two actions created from the same expression and using a
+shared action will exhibit different behaviors. Example:
+
+```cpp
+  EXPECT_CALL(foo, DoThis())
+      .WillRepeatedly(IncrementCounter(0));
+  EXPECT_CALL(foo, DoThat())
+      .WillRepeatedly(IncrementCounter(0));
+  foo.DoThis();  // Returns 1.
+  foo.DoThis();  // Returns 2.
+  foo.DoThat();  // Returns 1 - DoThat() uses a different
+                 // counter than DoThis()'s.
+```
+
+versus
+
+```cpp
+using ::testing::Action;
+...
+  Action<int()> increment = IncrementCounter(0);
+  EXPECT_CALL(foo, DoThis())
+      .WillRepeatedly(increment);
+  EXPECT_CALL(foo, DoThat())
+      .WillRepeatedly(increment);
+  foo.DoThis();  // Returns 1.
+  foo.DoThis();  // Returns 2.
+  foo.DoThat();  // Returns 3 - the counter is shared.
+```
+
+### Testing Asynchronous Behavior
+
+One oft-encountered problem with gMock is that it can be hard to test
+asynchronous behavior. Suppose you had an `EventQueue` class that you wanted to
+test, and you created a separate `EventDispatcher` interface so that you could
+easily mock it out. However, the implementation of the class fired all the
+events on a background thread, which made test timings difficult. You could just
+insert `sleep()` statements and hope for the best, but that makes your test
+behavior nondeterministic. A better way is to use gMock actions and
+`Notification` objects to force your asynchronous test to behave synchronously.
+
+```cpp
+class MockEventDispatcher : public EventDispatcher {
+  MOCK_METHOD(bool, DispatchEvent, (int32), (override));
+};
+
+TEST(EventQueueTest, EnqueueEventTest) {
+  MockEventDispatcher mock_event_dispatcher;
+  EventQueue event_queue(&mock_event_dispatcher);
+
+  const int32 kEventId = 321;
+  absl::Notification done;
+  EXPECT_CALL(mock_event_dispatcher, DispatchEvent(kEventId))
+      .WillOnce([&done] { done.Notify(); });
+
+  event_queue.EnqueueEvent(kEventId);
+  done.WaitForNotification();
+}
+```
+
+In the example above, we set our normal gMock expectations, but then add an
+additional action to notify the `Notification` object. Now we can just call
+`Notification::WaitForNotification()` in the main thread to wait for the
+asynchronous call to finish. After that, our test suite is complete and we can
+safely exit.
+
+{: .callout .note}
+Note: this example has a downside: namely, if the expectation is not satisfied,
+our test will run forever. It will eventually time-out and fail, but it will
+take longer and be slightly harder to debug. To alleviate this problem, you can
+use `WaitForNotificationWithTimeout(ms)` instead of `WaitForNotification()`.
+
+## Misc Recipes on Using gMock
+
+### Mocking Methods That Use Move-Only Types
+
+C++11 introduced *move-only types*. A move-only-typed value can be moved from
+one object to another, but cannot be copied. `std::unique_ptr` is probably
+the most commonly used move-only type.
+
+Mocking a method that takes and/or returns move-only types presents some
+challenges, but nothing insurmountable. This recipe shows you how you can do it.
+Note that the support for move-only method arguments was only introduced to
+gMock in April 2017; in older code, you may find more complex
+[workarounds](#LegacyMoveOnly) for lack of this feature.
+
+Let’s say we are working on a fictional project that lets one post and share
+snippets called “buzzes”. Your code uses these types:
+
+```cpp
+enum class AccessLevel { kInternal, kPublic };
+
+class Buzz {
+ public:
+  explicit Buzz(AccessLevel access) { ... }
+  ...
+};
+
+class Buzzer {
+ public:
+  virtual ~Buzzer() {}
+  virtual std::unique_ptr<Buzz> MakeBuzz(StringPiece text) = 0;
+  virtual bool ShareBuzz(std::unique_ptr<Buzz> buzz, int64_t timestamp) = 0;
+  ...
+};
+```
+
+A `Buzz` object represents a snippet being posted. A class that implements the
+`Buzzer` interface is capable of creating and sharing `Buzz`es. Methods in
+`Buzzer` may return a `unique_ptr<Buzz>` or take a `unique_ptr<Buzz>`. Now we
+need to mock `Buzzer` in our tests.
+
+To mock a method that accepts or returns move-only types, you just use the
+familiar `MOCK_METHOD` syntax as usual:
+
+```cpp
+class MockBuzzer : public Buzzer {
+ public:
+  MOCK_METHOD(std::unique_ptr<Buzz>, MakeBuzz, (StringPiece text), (override));
+  MOCK_METHOD(bool, ShareBuzz, (std::unique_ptr<Buzz> buzz, int64_t timestamp),
+              (override));
+};
+```
+
+Now that we have the mock class defined, we can use it in tests. In the
+following code examples, we assume that we have defined a `MockBuzzer` object
+named `mock_buzzer_`:
+
+```cpp
+  MockBuzzer mock_buzzer_;
+```
+
+First let’s see how we can set expectations on the `MakeBuzz()` method, which
+returns a `unique_ptr<Buzz>`.
+
+As usual, if you set an expectation without an action (i.e. the `.WillOnce()` or
+`.WillRepeatedly()` clause), when that expectation fires, the default action for
+that method will be taken. Since `unique_ptr<>` has a default constructor that
+returns a null `unique_ptr`, that’s what you’ll get if you don’t specify an
+action:
+
+```cpp
+using ::testing::IsNull;
+...
+  // Use the default action.
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello"));
+
+  // Triggers the previous EXPECT_CALL.
+  EXPECT_THAT(mock_buzzer_.MakeBuzz("hello"), IsNull());
+```
+
+If you are not happy with the default action, you can tweak it as usual; see
+[Setting Default Actions](#OnCall).
+
+If you just need to return a move-only value, you can use it in combination with
+`WillOnce`. For example:
+
+```cpp
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("hello"))
+      .WillOnce(Return(std::make_unique<Buzz>(AccessLevel::kInternal)));
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("hello"));
+```
+
+Quiz time! What do you think will happen if a `Return` action is performed more
+than once (e.g. you write `...  .WillRepeatedly(Return(std::move(...)));`)? Come
+think of it, after the first time the action runs, the source value will be
+consumed (since it’s a move-only value), so the next time around, there’s no
+value to move from -- you’ll get a run-time error that `Return(std::move(...))`
+can only be run once.
+
+If you need your mock method to do more than just moving a pre-defined value,
+remember that you can always use a lambda or a callable object, which can do
+pretty much anything you want:
+
+```cpp
+  EXPECT_CALL(mock_buzzer_, MakeBuzz("x"))
+      .WillRepeatedly([](StringPiece text) {
+        return std::make_unique<Buzz>(AccessLevel::kInternal);
+      });
+
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("x"));
+  EXPECT_NE(nullptr, mock_buzzer_.MakeBuzz("x"));
+```
+
+Every time this `EXPECT_CALL` fires, a new `unique_ptr<Buzz>` will be created
+and returned. 
You cannot do this with `Return(std::make_unique<...>(...))`.
+
+That covers returning move-only values; but how do we work with methods
+accepting move-only arguments? The answer is that they work normally, although
+some actions will not compile when any of the method's arguments are move-only.
+You can always use `Return`, or a [lambda or functor](#FunctionsAsActions):
+
+```cpp
+  using ::testing::Unused;
+
+  EXPECT_CALL(mock_buzzer_, ShareBuzz(NotNull(), _)).WillOnce(Return(true));
+  EXPECT_TRUE(mock_buzzer_.ShareBuzz(std::make_unique<Buzz>(AccessLevel::kInternal),
+                                     0));
+
+  EXPECT_CALL(mock_buzzer_, ShareBuzz(_, _)).WillOnce(
+      [](std::unique_ptr<Buzz> buzz, Unused) { return buzz != nullptr; });
+  EXPECT_FALSE(mock_buzzer_.ShareBuzz(nullptr, 0));
+```
+
+Many built-in actions (`WithArgs`, `WithoutArgs`, `DeleteArg`, `SaveArg`, ...)
+could in principle support move-only arguments, but the support for this is not
+implemented yet. If this is blocking you, please file a bug.
+
+A few actions (e.g. `DoAll`) copy their arguments internally, so they can never
+work with non-copyable objects; you'll have to use functors instead.
+
+#### Legacy workarounds for move-only types {#LegacyMoveOnly}
+
+Support for move-only function arguments was only introduced to gMock in April
+of 2017. In older code, you may encounter the following workaround for the lack
+of this feature (it is no longer necessary - we're including it just for
+reference):
+
+```cpp
+class MockBuzzer : public Buzzer {
+ public:
+  MOCK_METHOD(bool, DoShareBuzz, (Buzz* buzz, Time timestamp));
+  bool ShareBuzz(std::unique_ptr<Buzz> buzz, Time timestamp) override {
+    return DoShareBuzz(buzz.get(), timestamp);
+  }
+};
+```
+
+The trick is to delegate the `ShareBuzz()` method to a mock method (let’s call
+it `DoShareBuzz()`) that does not take move-only parameters. Then, instead of
+setting expectations on `ShareBuzz()`, you set them on the `DoShareBuzz()` mock
+method:
+
+```cpp
+  MockBuzzer mock_buzzer_;
+  EXPECT_CALL(mock_buzzer_, DoShareBuzz(NotNull(), _));
+
+  // When one calls ShareBuzz() on the MockBuzzer like this, the call is
+  // forwarded to DoShareBuzz(), which is mocked. Therefore this statement
+  // will trigger the above EXPECT_CALL.
+  mock_buzzer_.ShareBuzz(std::make_unique<Buzz>(AccessLevel::kInternal), 0);
+```
+
+### Making the Compilation Faster
+
+Believe it or not, the *vast majority* of the time spent on compiling a mock
+class is in generating its constructor and destructor, as they perform
+non-trivial tasks (e.g. verification of the expectations). What's more, mock
+methods with different signatures have different types and thus their
+constructors/destructors need to be generated by the compiler separately. As a
+result, if you mock many different types of methods, compiling your mock class
+can get really slow.
+
+If you are experiencing slow compilation, you can move the definition of your
+mock class' constructor and destructor out of the class body and into a `.cc`
+file. This way, even if you `#include` your mock class in N files, the compiler
+only needs to generate its constructor and destructor once, resulting in a much
+faster compilation.
+
+Let's illustrate the idea using an example. Here's the definition of a mock
+class before applying this recipe:
+
+```cpp
+// File mock_foo.h.
+...
+class MockFoo : public Foo {
+ public:
+  // Since we don't declare the constructor or the destructor,
+  // the compiler will generate them in every translation unit
+  // where this mock class is used.
+
+  MOCK_METHOD(int, DoThis, (), (override));
+  MOCK_METHOD(bool, DoThat, (const char* str), (override));
+  ... more mock methods ...
+};
+```
+
+After the change, it would look like:
+
+```cpp
+// File mock_foo.h.
+...
+class MockFoo : public Foo {
+ public:
+  // The constructor and destructor are declared, but not defined, here.
+  MockFoo();
+  virtual ~MockFoo();
+
+  MOCK_METHOD(int, DoThis, (), (override));
+  MOCK_METHOD(bool, DoThat, (const char* str), (override));
+  ... more mock methods ...
+};
+```
+
+and
+
+```cpp
+// File mock_foo.cc.
+#include "path/to/mock_foo.h"
+
+// The definitions may appear trivial, but the functions actually do a
+// lot of things through the constructors/destructors of the member
+// variables used to implement the mock methods.
+MockFoo::MockFoo() {}
+MockFoo::~MockFoo() {}
+```
+
+### Forcing a Verification
+
+When it's being destroyed, your friendly mock object will automatically verify
+that all expectations on it have been satisfied, and will generate googletest
+failures if not. This is convenient as it leaves you with one less thing to
+worry about. That is, unless you are not sure if your mock object will be
+destroyed.
+
+How could it be that your mock object won't eventually be destroyed? Well, it
+might be created on the heap and owned by the code you are testing. Suppose
+there's a bug in that code and it doesn't delete the mock object properly - you
+could end up with a passing test when there's actually a bug.
+
+Using a heap checker is a good idea and can alleviate the concern, but its
+implementation is not 100% reliable. So, sometimes you do want to *force* gMock
+to verify a mock object before it is (hopefully) destructed. You can do this
+with `Mock::VerifyAndClearExpectations(&mock_object)`:
+
+```cpp
+TEST(MyServerTest, ProcessesRequest) {
+  using ::testing::Mock;
+
+  MockFoo* const foo = new MockFoo;
+  EXPECT_CALL(*foo, ...)...;
+  // ... other expectations ...
+
+  // server now owns foo.
+  MyServer server(foo);
+  server.ProcessRequest(...);
+
+  // In case server's destructor forgets to delete foo,
+  // this will verify the expectations anyway.
+  Mock::VerifyAndClearExpectations(foo);
+}  // server is destroyed when it goes out of scope here.
+```
+
+{: .callout .tip}
+**Tip:** The `Mock::VerifyAndClearExpectations()` function returns a `bool` to
+indicate whether the verification was successful (`true` for yes), so you can
+wrap that function call inside an `ASSERT_TRUE()` if there is no point going
+further when the verification has failed.
+
+Do not set new expectations after verifying and clearing a mock after its use.
+Setting expectations after code that exercises the mock has undefined behavior.
+See [Using Mocks in Tests](gmock_for_dummies.md#using-mocks-in-tests) for more
+information.
+
+### Using Checkpoints {#UsingCheckPoints}
+
+Sometimes you might want to test a mock object's behavior in phases whose sizes
+are each manageable, or you might want to set more detailed expectations about
+which API calls invoke which mock functions.
+
+A technique you can use is to put the expectations in a sequence and insert
+calls to a dummy "checkpoint" function at specific places. Then you can verify
+that the mock function calls do happen at the right time. 
For example, if you
+are exercising the code:
+
+```cpp
+  Foo(1);
+  Foo(2);
+  Foo(3);
+```
+
+and want to verify that `Foo(1)` and `Foo(3)` both invoke `mock.Bar("a")`, but
+`Foo(2)` doesn't invoke anything, you can write:
+
+```cpp
+using ::testing::MockFunction;
+
+TEST(FooTest, InvokesBarCorrectly) {
+  MyMock mock;
+  // Class MockFunction<F> has exactly one mock method. It is named
+  // Call() and has type F.
+  MockFunction<void(string check_point_name)> check;
+  {
+    InSequence s;
+
+    EXPECT_CALL(mock, Bar("a"));
+    EXPECT_CALL(check, Call("1"));
+    EXPECT_CALL(check, Call("2"));
+    EXPECT_CALL(mock, Bar("a"));
+  }
+  Foo(1);
+  check.Call("1");
+  Foo(2);
+  check.Call("2");
+  Foo(3);
+}
+```
+
+The expectation spec says that the first `Bar("a")` call must happen before
+checkpoint "1", the second `Bar("a")` call must happen after checkpoint "2", and
+nothing should happen between the two checkpoints. The explicit checkpoints make
+it clear which `Bar("a")` is called by which call to `Foo()`.
+
+### Mocking Destructors
+
+Sometimes you want to make sure a mock object is destructed at the right time,
+e.g. after `bar->A()` is called but before `bar->B()` is called. We already know
+that you can specify constraints on the [order](#OrderedCalls) of mock function
+calls, so all we need to do is to mock the destructor of the mock function.
+
+This sounds simple, except for one problem: a destructor is a special function
+with special syntax and special semantics, and the `MOCK_METHOD` macro doesn't
+work for it:
+
+```cpp
+MOCK_METHOD(void, ~MockFoo, ());  // Won't compile!
+```
+
+The good news is that you can use a simple pattern to achieve the same effect.
+First, add a mock function `Die()` to your mock class and call it in the
+destructor, like this:
+
+```cpp
+class MockFoo : public Foo {
+  ...
+  // Add the following two lines to the mock class.
+  MOCK_METHOD(void, Die, ());
+  ~MockFoo() override { Die(); }
+};
+```
+
+(If the name `Die()` clashes with an existing symbol, choose another name.) Now,
+we have translated the problem of testing when a `MockFoo` object dies to
+testing when its `Die()` method is called:
+
+```cpp
+  MockFoo* foo = new MockFoo;
+  MockBar* bar = new MockBar;
+  ...
+  {
+    InSequence s;
+
+    // Expects *foo to die after bar->A() and before bar->B().
+    EXPECT_CALL(*bar, A());
+    EXPECT_CALL(*foo, Die());
+    EXPECT_CALL(*bar, B());
+  }
+```
+
+And that's that.
+
+### Using gMock and Threads {#UsingThreads}
+
+In a **unit** test, it's best if you could isolate and test a piece of code in a
+single-threaded context. That avoids race conditions and deadlocks, and makes
+debugging your test much easier.
+
+Yet most programs are multi-threaded, and sometimes to test something we need to
+pound on it from more than one thread. gMock works for this purpose too.
+
+Remember the steps for using a mock:
+
+1.  Create a mock object `foo`.
+2.  Set its default actions and expectations using `ON_CALL()` and
+    `EXPECT_CALL()`.
+3.  The code under test calls methods of `foo`.
+4.  Optionally, verify and reset the mock.
+5.  Destroy the mock yourself, or let the code under test destroy it. The
+    destructor will automatically verify it.
+
+If you follow these simple rules, your mocks and threads can live happily
+together:
+
+*   Execute your *test code* (as opposed to the code being tested) in *one*
+    thread. This makes your test easy to follow.
+*   Obviously, you can do step #1 without locking.
+*   When doing steps #2 and #5, make sure no other thread is accessing `foo`.
+    Obvious too, huh?
+*   #3 and #4 can be done either in one thread or in multiple threads - any way
+    you want. gMock takes care of the locking, so you don't have to do any -
+    unless required by your test logic.
+
+If you violate the rules (for example, if you set expectations on a mock while
+another thread is calling its methods), you get undefined behavior. That's not
+fun, so don't do it.
+
+gMock guarantees that the action for a mock function is done in the same thread
+that called the mock function. For example, in
+
+```cpp
+  EXPECT_CALL(mock, Foo(1))
+      .WillOnce(action1);
+  EXPECT_CALL(mock, Foo(2))
+      .WillOnce(action2);
+```
+
+if `Foo(1)` is called in thread 1 and `Foo(2)` is called in thread 2, gMock will
+execute `action1` in thread 1 and `action2` in thread 2.
+
+gMock does *not* impose a sequence on actions performed in different threads
+(doing so may create deadlocks as the actions may need to cooperate). This means
+that the execution of `action1` and `action2` in the above example *may*
+interleave. If this is a problem, you should add proper synchronization logic to
+`action1` and `action2` to make the test thread-safe.
+
+Also, remember that `DefaultValue<T>` is a global resource that potentially
+affects *all* living mock objects in your program. Naturally, you won't want to
+mess with it from multiple threads or when there still are mocks in action.
+
+### Controlling How Much Information gMock Prints
+
+When gMock sees something that has the potential of being an error (e.g. a mock
+function with no expectation is called, a.k.a. an uninteresting call, which is
+allowed but perhaps you forgot to explicitly ban the call), it prints some
+warning messages, including the arguments of the function, the return value, and
+the stack trace. Hopefully this will remind you to take a look and see if there
+is indeed a problem.
+
+Sometimes you are confident that your tests are correct and may not appreciate
+such friendly messages. Some other times, you are debugging your tests or
+learning about the behavior of the code you are testing, and wish you could
+observe every mock call that happens (including argument values, the return
+value, and the stack trace). Clearly, one size doesn't fit all.
+
+You can control how much gMock tells you using the `--gmock_verbose=LEVEL`
+command-line flag, where `LEVEL` is a string with three possible values:
+
+*   `info`: gMock will print all informational messages, warnings, and errors
+    (most verbose). At this setting, gMock will also log any calls to the
+    `ON_CALL/EXPECT_CALL` macros. It will include a stack trace in
+    "uninteresting call" warnings.
+*   `warning`: gMock will print both warnings and errors (less verbose); it will
+    omit the stack traces in "uninteresting call" warnings. This is the default.
+*   `error`: gMock will print errors only (least verbose).
+
+Alternatively, you can adjust the value of that flag from within your tests like
+so:
+
+```cpp
+  ::testing::FLAGS_gmock_verbose = "error";
+```
+
+If you find gMock printing too many stack frames with its informational or
+warning messages, remember that you can control their amount with the
+`--gtest_stack_trace_depth=max_depth` flag.
+
+Now, judiciously use the right flag to make gMock serve you better!
+
+### Gaining Super Vision into Mock Calls
+
+You have a test using gMock. It fails: gMock tells you some expectations aren't
+satisfied. However, you aren't sure why: Is there a typo somewhere in the
+matchers? Did you mess up the order of the `EXPECT_CALL`s? 
Or is the code under
+test doing something wrong? How can you find out the cause?
+
+Wouldn't it be nice if you had X-ray vision and could actually see the trace of
+all `EXPECT_CALL`s and mock method calls as they are made? For each call, would
+you like to see its actual argument values and which `EXPECT_CALL` gMock thinks
+it matches? If you still need some help to figure out who made these calls, how
+about being able to see the complete stack trace at each mock call?
+
+You can unlock this power by running your test with the `--gmock_verbose=info`
+flag. For example, given the test program:
+
+```cpp
+#include <gmock/gmock.h>
+
+using ::testing::_;
+using ::testing::HasSubstr;
+using ::testing::Return;
+
+class MockFoo {
+ public:
+  MOCK_METHOD(void, F, (const string& x, const string& y));
+};
+
+TEST(Foo, Bar) {
+  MockFoo mock;
+  EXPECT_CALL(mock, F(_, _)).WillRepeatedly(Return());
+  EXPECT_CALL(mock, F("a", "b"));
+  EXPECT_CALL(mock, F("c", HasSubstr("d")));
+
+  mock.F("a", "good");
+  mock.F("a", "b");
+}
+```
+
+if you run it with `--gmock_verbose=info`, you will see this output:
+
+```shell
+[ RUN       ] Foo.Bar
+
+foo_test.cc:14: EXPECT_CALL(mock, F(_, _)) invoked
+Stack trace: ...
+
+foo_test.cc:15: EXPECT_CALL(mock, F("a", "b")) invoked
+Stack trace: ...
+
+foo_test.cc:16: EXPECT_CALL(mock, F("c", HasSubstr("d"))) invoked
+Stack trace: ...
+
+foo_test.cc:14: Mock function call matches EXPECT_CALL(mock, F(_, _))...
+    Function call: F(@0x7fff7c8dad40"a",@0x7fff7c8dad10"good")
+Stack trace: ...
+
+foo_test.cc:15: Mock function call matches EXPECT_CALL(mock, F("a", "b"))...
+    Function call: F(@0x7fff7c8dada0"a",@0x7fff7c8dad70"b")
+Stack trace: ...
+
+foo_test.cc:16: Failure
+Actual function call count doesn't match EXPECT_CALL(mock, F("c", HasSubstr("d")))...
+         Expected: to be called once
+           Actual: never called - unsatisfied and active
+[  FAILED  ] Foo.Bar
+```
+
+Suppose the bug is that the `"c"` in the third `EXPECT_CALL` is a typo and
+should actually be `"a"`. With the above message, you should see that the actual
+`F("a", "good")` call is matched by the first `EXPECT_CALL`, not the third as
+you thought. From that it should be obvious that the third `EXPECT_CALL` is
+written wrong. Case solved.
+
+If you are interested in the mock call trace but not the stack traces, you can
+combine `--gmock_verbose=info` with `--gtest_stack_trace_depth=0` on the test
+command line.
+
+### Running Tests in Emacs
+
+If you build and run your tests in Emacs using the `M-x google-compile` command
+(as many googletest users do), the source file locations of gMock and googletest
+errors will be highlighted. Just press `<Enter>` on one of them and you'll be
+taken to the offending line. Or, you can just type ``C-x ` `` to jump to the
+next error.
+
+To make it even easier, you can add the following lines to your `~/.emacs` file:
+
+```text
+(global-set-key "\M-m"  'google-compile)  ; m is for make
+(global-set-key [M-down] 'next-error)
+(global-set-key [M-up]  '(lambda () (interactive) (next-error -1)))
+```
+
+Then you can type `M-m` to start a build (if you want to run the test as well,
+just make sure `foo_test.run` or `runtests` is in the build command you supply
+after typing `M-m`), or `M-up`/`M-down` to move back and forth between errors.
+
+## Extending gMock
+
+### Writing New Matchers Quickly {#NewMatchers}
+
+{: .callout .warning}
+WARNING: gMock does not guarantee when or how many times a matcher will be
+invoked. Therefore, all matchers must be functionally pure. See
+[this section](#PureMatchers) for more details.
+ +The `MATCHER*` family of macros can be used to define custom matchers easily. +The syntax: + +```cpp +MATCHER(name, description_string_expression) { statements; } +``` + +will define a matcher with the given name that executes the statements, which +must return a `bool` to indicate if the match succeeds. Inside the statements, +you can refer to the value being matched by `arg`, and refer to its type by +`arg_type`. + +The *description string* is a `string`-typed expression that documents what the +matcher does, and is used to generate the failure message when the match fails. +It can (and should) reference the special `bool` variable `negation`, and should +evaluate to the description of the matcher when `negation` is `false`, or that +of the matcher's negation when `negation` is `true`. + +For convenience, we allow the description string to be empty (`""`), in which +case gMock will use the sequence of words in the matcher name as the +description. + +#### Basic Example + +```cpp +MATCHER(IsDivisibleBy7, "") { return (arg % 7) == 0; } +``` + +allows you to write + +```cpp + // Expects mock_foo.Bar(n) to be called where n is divisible by 7. + EXPECT_CALL(mock_foo, Bar(IsDivisibleBy7())); +``` + +or, + +```cpp + using ::testing::Not; + ... + // Verifies that a value is divisible by 7 and the other is not. + EXPECT_THAT(some_expression, IsDivisibleBy7()); + EXPECT_THAT(some_other_expression, Not(IsDivisibleBy7())); +``` + +If the above assertions fail, they will print something like: + +```shell + Value of: some_expression + Expected: is divisible by 7 + Actual: 27 + ... + Value of: some_other_expression + Expected: not (is divisible by 7) + Actual: 21 +``` + +where the descriptions `"is divisible by 7"` and `"not (is divisible by 7)"` are +automatically calculated from the matcher name `IsDivisibleBy7`. + +#### Adding Custom Failure Messages + +As you may have noticed, the auto-generated descriptions (especially those for +the negation) may not be so great. You can always override them with a `string` +expression of your own: + +```cpp +MATCHER(IsDivisibleBy7, + absl::StrCat(negation ? "isn't" : "is", " divisible by 7")) { + return (arg % 7) == 0; +} +``` + +Optionally, you can stream additional information to a hidden argument named +`result_listener` to explain the match result. For example, a better definition +of `IsDivisibleBy7` is: + +```cpp +MATCHER(IsDivisibleBy7, "") { + if ((arg % 7) == 0) + return true; + + *result_listener << "the remainder is " << (arg % 7); + return false; +} +``` + +With this definition, the above assertion will give a better message: + +```shell + Value of: some_expression + Expected: is divisible by 7 + Actual: 27 (the remainder is 6) +``` + +#### Using EXPECT_ Statements in Matchers + +You can also use `EXPECT_...` (and `ASSERT_...`) statements inside custom +matcher definitions. In many cases, this allows you to write your matcher more +concisely while still providing an informative error message. For example: + +```cpp +MATCHER(IsDivisibleBy7, "") { + const auto remainder = arg % 7; + EXPECT_EQ(remainder, 0); + return true; +} +``` + +If you write a test that includes the line `EXPECT_THAT(27, IsDivisibleBy7());`, +you will get an error something like the following: + +```shell +Expected equality of these values: + remainder + Which is: 6 + 0 +``` + +#### `MatchAndExplain` + +You should let `MatchAndExplain()` print *any additional information* that can +help a user understand the match result. 
Note that it should explain why the
+match succeeds in case of a success (unless it's obvious) - this is useful when
+the matcher is used inside `Not()`. There is no need to print the argument value
+itself, as gMock already prints it for you.
+
+#### Argument Types
+
+The type of the value being matched (`arg_type`) is determined by the
+context in which you use the matcher and is supplied to you by the compiler, so
+you don't need to worry about declaring it (nor can you). This allows the
+matcher to be polymorphic. For example, `IsDivisibleBy7()` can be used to match
+any type where the value of `(arg % 7) == 0` can be implicitly converted to a
+`bool`. In the `Bar(IsDivisibleBy7())` example above, if method `Bar()` takes an
+`int`, `arg_type` will be `int`; if it takes an `unsigned long`, `arg_type` will
+be `unsigned long`; and so on.
+
+### Writing New Parameterized Matchers Quickly
+
+Sometimes you'll want to define a matcher that has parameters. For that you can
+use the macro:
+
+```cpp
+MATCHER_P(name, param_name, description_string) { statements; }
+```
+
+where the description string can be either `""` or a `string` expression that
+references `negation` and `param_name`.
+
+For example:
+
+```cpp
+MATCHER_P(HasAbsoluteValue, value, "") { return abs(arg) == value; }
+```
+
+will allow you to write:
+
+```cpp
+  EXPECT_THAT(Blah("a"), HasAbsoluteValue(n));
+```
+
+which may lead to this message (assuming `n` is 10):
+
+```shell
+  Value of: Blah("a")
+  Expected: has absolute value 10
+    Actual: -9
+```
+
+Note that both the matcher description and its parameter are printed, making the
+message human-friendly.
+
+In the matcher definition body, you can write `foo_type` to reference the type
+of a parameter named `foo`. For example, in the body of
+`MATCHER_P(HasAbsoluteValue, value)` above, you can write `value_type` to refer
+to the type of `value`.
+
+gMock also provides `MATCHER_P2`, `MATCHER_P3`, ..., up to `MATCHER_P10` to
+support multi-parameter matchers:
+
+```cpp
+MATCHER_Pk(name, param_1, ..., param_k, description_string) { statements; }
+```
+
+Please note that the custom description string is for a particular *instance* of
+the matcher, where the parameters have been bound to actual values. Therefore
+usually you'll want the parameter values to be part of the description. gMock
+lets you do that by referencing the matcher parameters in the description string
+expression.
+
+For example,
+
+```cpp
+using ::testing::PrintToString;
+MATCHER_P2(InClosedRange, low, hi,
+           absl::StrFormat("%s in range [%s, %s]", negation ? "isn't" : "is",
+                           PrintToString(low), PrintToString(hi))) {
+  return low <= arg && arg <= hi;
+}
+...
+EXPECT_THAT(3, InClosedRange(4, 6));
+```
+
+would generate a failure that contains the message:
+
+```shell
+  Expected: is in range [4, 6]
+```
+
+If you specify `""` as the description, the failure message will contain the
+sequence of words in the matcher name followed by the parameter values printed
+as a tuple. For example,
+
+```cpp
+  MATCHER_P2(InClosedRange, low, hi, "") { ... }
+  ...
+  EXPECT_THAT(3, InClosedRange(4, 6));
+```
+
+would generate a failure that contains the text:
+
+```shell
+  Expected: in closed range (4, 6)
+```
+
+For the purpose of typing, you can view
+
+```cpp
+MATCHER_Pk(Foo, p1, ..., pk, description_string) { ... }
+```
+
+as shorthand for
+
+```cpp
+template <typename p1_type, ..., typename pk_type>
+FooMatcherPk<p1_type, ..., pk_type>
+Foo(p1_type p1, ..., pk_type pk) { ... 
+While you can instantiate a matcher template with reference types, passing the
+parameters by pointer usually makes your code more readable. If, however, you
+still want to pass a parameter by reference, be aware that in the failure
+message generated by the matcher you will see the value of the referenced
+object but not its address.
+
+You can overload matchers with different numbers of parameters:
+
+```cpp
+MATCHER_P(Blah, a, description_string_1) { ... }
+MATCHER_P2(Blah, a, b, description_string_2) { ... }
+```
+
+While it's tempting to always use the `MATCHER*` macros when defining a new
+matcher, you should also consider implementing the matcher interface directly
+instead (see the recipes that follow), especially if you need to use the
+matcher a lot. While these approaches require more work, they give you more
+control over the types of the value being matched and the matcher parameters,
+which in general leads to better compiler error messages that pay off in the
+long run. They also allow overloading matchers based on parameter types (as
+opposed to just based on the number of parameters).
+
+### Writing New Monomorphic Matchers
+
+A matcher of argument type `T` implements the matcher interface for `T` and
+does two things: it tests whether a value of type `T` matches the matcher, and
+can describe what kind of values it matches. The latter ability is used for
+generating readable error messages when expectations are violated.
+
+A matcher of `T` must declare a typedef like:
+
+```cpp
+using is_gtest_matcher = void;
+```
+
+and support the following operations:
+
+```cpp
+// Match a value and optionally explain into an ostream.
+bool matched = matcher.MatchAndExplain(value, maybe_os);
+// where `value` is of type `T` and
+// `maybe_os` is of type `std::ostream*`, where it can be null if the caller
+// is not interested in the textual explanation.
+
+matcher.DescribeTo(os);
+matcher.DescribeNegationTo(os);
+// where `os` is of type `std::ostream*`.
+```
+
+If you need a custom matcher but `Truly()` is not a good option (for example,
+you may not be happy with the way `Truly(predicate)` describes itself, or you
+may want your matcher to be polymorphic as `Eq(value)` is), you can define a
+matcher to do whatever you want in two steps: first implement the matcher
+interface, and then define a factory function to create a matcher instance.
+The second step is not strictly needed but it makes the syntax of using the
+matcher nicer.
+
+For example, you can define a matcher to test whether an `int` is divisible by
+7 and then use it like this:
+
+```cpp
+using ::testing::Matcher;
+
+class DivisibleBy7Matcher {
+ public:
+  using is_gtest_matcher = void;
+
+  bool MatchAndExplain(int n, std::ostream*) const {
+    return (n % 7) == 0;
+  }
+
+  void DescribeTo(std::ostream* os) const {
+    *os << "is divisible by 7";
+  }
+
+  void DescribeNegationTo(std::ostream* os) const {
+    *os << "is not divisible by 7";
+  }
+};
+
+Matcher<int> DivisibleBy7() {
+  return DivisibleBy7Matcher();
+}
+
+...
+  EXPECT_CALL(foo, Bar(DivisibleBy7()));
+```
+
+You may improve the matcher message by streaming additional information to the
+`os` argument in `MatchAndExplain()`:
+
+```cpp
+class DivisibleBy7Matcher {
+ public:
+  bool MatchAndExplain(int n, std::ostream* os) const {
+    const int remainder = n % 7;
+    if (remainder != 0 && os != nullptr) {
+      *os << "the remainder is " << remainder;
+    }
+    return remainder == 0;
+  }
+  ...
+};
+```
+
+Then, `EXPECT_THAT(x, DivisibleBy7());` may generate a message like this:
+
+```shell
+Value of: x
+Expected: is divisible by 7
+  Actual: 23 (the remainder is 2)
+```
+
+{: .callout .tip}
+Tip: for convenience, `MatchAndExplain()` can take a `MatchResultListener*`
+instead of `std::ostream*`.
+
+### Writing New Polymorphic Matchers
+
+Expanding what we learned above to *polymorphic* matchers is now just as
+simple as adding templates in the right place.
+
+```cpp
+class NotNullMatcher {
+ public:
+  using is_gtest_matcher = void;
+
+  // To implement a polymorphic matcher, we just need to make MatchAndExplain
+  // a template on its first argument.
+
+  // In this example, we want to use NotNull() with any pointer, so
+  // MatchAndExplain() accepts a pointer of any type as its first argument.
+  // In general, you can define MatchAndExplain() as an ordinary method or
+  // a method template, or even overload it.
+  template <typename T>
+  bool MatchAndExplain(T* p, std::ostream*) const {
+    return p != nullptr;
+  }
+
+  // Describes the property of a value matching this matcher.
+  void DescribeTo(std::ostream* os) const { *os << "is not NULL"; }
+
+  // Describes the property of a value NOT matching this matcher.
+  void DescribeNegationTo(std::ostream* os) const { *os << "is NULL"; }
+};
+
+NotNullMatcher NotNull() {
+  return NotNullMatcher();
+}
+
+...
+
+  EXPECT_CALL(foo, Bar(NotNull()));  // The argument must be a non-NULL pointer.
+```
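+
+A short usage sketch (ours, not from the original text) showing the deduction
+at work with two unrelated pointer types:
+
+```cpp
+int n = 42;
+std::string s = "hi";
+
+// The same NotNull() matcher works for both assertions: T is deduced as int
+// for the first one and as std::string for the second.
+EXPECT_THAT(&n, NotNull());
+EXPECT_THAT(&s, NotNull());
+```
+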
+### Legacy Matcher Implementation
+
+Defining matchers used to be more complicated: it required several supporting
+classes and virtual functions. To implement a matcher for type `T` using the
+legacy API you have to derive from `MatcherInterface<T>` and call
+`MakeMatcher` to construct the object.
+
+The interface looks like this:
+
+```cpp
+class MatchResultListener {
+ public:
+  ...
+  // Streams x to the underlying ostream; does nothing if the ostream
+  // is NULL.
+  template <typename T>
+  MatchResultListener& operator<<(const T& x);
+
+  // Returns the underlying ostream.
+  std::ostream* stream();
+};
+
+template <typename T>
+class MatcherInterface {
+ public:
+  virtual ~MatcherInterface();
+
+  // Returns true if and only if the matcher matches x; also explains the match
+  // result to 'listener'.
+  virtual bool MatchAndExplain(T x, MatchResultListener* listener) const = 0;
+
+  // Describes this matcher to an ostream.
+  virtual void DescribeTo(std::ostream* os) const = 0;
+
+  // Describes the negation of this matcher to an ostream.
+  virtual void DescribeNegationTo(std::ostream* os) const;
+};
+```
+
+Fortunately, most of the time you can define a polymorphic matcher easily with
+the help of `MakePolymorphicMatcher()`. Here's how you can define `NotNull()`
+as an example:
+
+```cpp
+using ::testing::MakePolymorphicMatcher;
+using ::testing::MatchResultListener;
+using ::testing::PolymorphicMatcher;
+
+class NotNullMatcher {
+ public:
+  // To implement a polymorphic matcher, first define a COPYABLE class
+  // that has three members MatchAndExplain(), DescribeTo(), and
+  // DescribeNegationTo(), like the following.
+
+  // In this example, we want to use NotNull() with any pointer, so
+  // MatchAndExplain() accepts a pointer of any type as its first argument.
+  // In general, you can define MatchAndExplain() as an ordinary method or
+  // a method template, or even overload it.
+  template <typename T>
+  bool MatchAndExplain(T* p,
+                       MatchResultListener* /* listener */) const {
+    return p != NULL;
+  }
+
+  // Describes the property of a value matching this matcher.
+  void DescribeTo(std::ostream* os) const { *os << "is not NULL"; }
+
+  // Describes the property of a value NOT matching this matcher.
+  void DescribeNegationTo(std::ostream* os) const { *os << "is NULL"; }
+};
+
+// To construct a polymorphic matcher, pass an instance of the class
+// to MakePolymorphicMatcher(). Note the return type.
+PolymorphicMatcher<NotNullMatcher> NotNull() {
+  return MakePolymorphicMatcher(NotNullMatcher());
+}
+
+...
+
+  EXPECT_CALL(foo, Bar(NotNull()));  // The argument must be a non-NULL pointer.
+```
+
+{: .callout .note}
+**Note:** Your polymorphic matcher class does **not** need to inherit from
+`MatcherInterface` or any other class, and its methods do **not** need to be
+virtual.
+
+Like in a monomorphic matcher, you may explain the match result by streaming
+additional information to the `listener` argument in `MatchAndExplain()`.
+
+### Writing New Cardinalities
+
+A cardinality is used in `Times()` to tell gMock how many times you expect a
+call to occur. It doesn't have to be exact. For example, you can say
+`AtLeast(5)` or `Between(2, 4)`.
+
+If the [built-in set](gmock_cheat_sheet.md#CardinalityList) of cardinalities
+doesn't suit you, you are free to define your own by implementing the
+following interface (in namespace `testing`):
+
+```cpp
+class CardinalityInterface {
+ public:
+  virtual ~CardinalityInterface();
+
+  // Returns true if and only if call_count calls will satisfy this
+  // cardinality.
+  virtual bool IsSatisfiedByCallCount(int call_count) const = 0;
+
+  // Returns true if and only if call_count calls will saturate this
+  // cardinality.
+  virtual bool IsSaturatedByCallCount(int call_count) const = 0;
+
+  // Describes self to an ostream.
+  virtual void DescribeTo(std::ostream* os) const = 0;
+};
+```
+
+For example, to specify that a call must occur an even number of times, you
+can write
+
+```cpp
+using ::testing::Cardinality;
+using ::testing::CardinalityInterface;
+using ::testing::MakeCardinality;
+
+class EvenNumberCardinality : public CardinalityInterface {
+ public:
+  bool IsSatisfiedByCallCount(int call_count) const override {
+    return (call_count % 2) == 0;
+  }
+
+  bool IsSaturatedByCallCount(int call_count) const override {
+    return false;
+  }
+
+  void DescribeTo(std::ostream* os) const override {
+    *os << "called even number of times";
+  }
+};
+
+Cardinality EvenNumber() {
+  return MakeCardinality(new EvenNumberCardinality);
+}
+
+...
+  EXPECT_CALL(foo, Bar(3))
+      .Times(EvenNumber());
+```
+
+### Writing New Actions {#QuickNewActions}
+
+If the built-in actions don't work for you, you can easily define your own.
+All you need is a call operator with a signature compatible with the mocked
+function. So you can use a lambda:
+
+```cpp
+MockFunction<int(int)> mock;
+EXPECT_CALL(mock, Call).WillOnce([](const int input) { return input * 7; });
+EXPECT_EQ(mock.AsStdFunction()(2), 14);
+```
+
+Or a struct with a call operator (even a templated one):
+
+```cpp
+struct MultiplyBy {
+  template <typename T>
+  T operator()(T arg) { return arg * multiplier; }
+
+  int multiplier;
+};
+
+// Then use:
+// EXPECT_CALL(...).WillOnce(MultiplyBy{7});
+```
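+
+A complete usage sketch (our own, pairing the struct above with gMock's
+`MockFunction` helper):
+
+```cpp
+using ::testing::MockFunction;
+
+MockFunction<int(int)> mock;
+// The templated call operator is instantiated with T = int here.
+EXPECT_CALL(mock, Call).WillOnce(MultiplyBy{7});
+EXPECT_EQ(mock.AsStdFunction()(3), 21);
+```
+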
+It's also fine for the callable to take no arguments, ignoring the arguments
+supplied to the mock function:
+
+```cpp
+MockFunction<int(int)> mock;
+EXPECT_CALL(mock, Call).WillOnce([] { return 17; });
+EXPECT_EQ(mock.AsStdFunction()(0), 17);
+```
+
+When used with `WillOnce`, the callable can assume it will be called at most
+once and is allowed to be a move-only type:
+
+```cpp
+// An action that contains move-only types and has an &&-qualified operator,
+// demanding in the type system that it be called at most once. This can be
+// used with WillOnce, but the compiler will reject it if handed to
+// WillRepeatedly.
+struct MoveOnlyAction {
+  std::unique_ptr<int> move_only_state;
+  std::unique_ptr<int> operator()() && { return std::move(move_only_state); }
+};
+
+MockFunction<std::unique_ptr<int>()> mock;
+EXPECT_CALL(mock, Call).WillOnce(MoveOnlyAction{std::make_unique<int>(17)});
+EXPECT_THAT(mock.AsStdFunction()(), Pointee(Eq(17)));
+```
+
+More generally, to use with a mock function whose signature is `R(Args...)`
+the object can be anything convertible to `OnceAction<R(Args...)>` or
+`Action<R(Args...)>`. The difference between the two is that `OnceAction` has
+weaker requirements (`Action` requires a copy-constructible input that can be
+called repeatedly, whereas `OnceAction` requires only a move-constructible one
+and supports `&&`-qualified call operators), but can be used only with
+`WillOnce`. `OnceAction` is typically relevant only when supporting move-only
+types or actions that want a type-system guarantee that they will be called at
+most once.
+
+Typically the `OnceAction` and `Action` templates need not be referenced
+directly in your actions: a struct or class with a call operator is
+sufficient, as in the examples above. But fancier polymorphic actions that
+need to know the specific return type of the mock function can define
+templated conversion operators to make that possible. See `gmock-actions.h`
+for examples.
+
+#### Legacy macro-based Actions
+
+Before C++11, the functor-based actions were not supported; the old way of
+writing actions was through a set of `ACTION*` macros. We suggest avoiding
+them in new code; they hide a lot of logic behind the macro, potentially
+leading to harder-to-understand compiler errors. Nevertheless, we cover them
+here for completeness.
+
+By writing
+
+```cpp
+ACTION(name) { statements; }
+```
+
+in a namespace scope (i.e. not inside a class or function), you will define an
+action with the given name that executes the statements. The value returned by
+`statements` will be used as the return value of the action. Inside the
+statements, you can refer to the K-th (0-based) argument of the mock function
+as `argK`. For example:
+
+```cpp
+ACTION(IncrementArg1) { return ++(*arg1); }
+```
+
+allows you to write
+
+```cpp
+... WillOnce(IncrementArg1());
+```
+
+Note that you don't need to specify the types of the mock function arguments.
+Rest assured that your code is type-safe though: you'll get a compiler error
+if `*arg1` doesn't support the `++` operator, or if the type of `++(*arg1)`
+isn't compatible with the mock function's return type.
+
+Another example:
+
+```cpp
+ACTION(Foo) {
+  (*arg2)(5);
+  Blah();
+  *arg1 = 0;
+  return arg0;
+}
+```
+
+defines an action `Foo()` that invokes argument #2 (a function pointer) with
+5, calls function `Blah()`, sets the value pointed to by argument #1 to 0, and
+returns argument #0.
+
+For more convenience and flexibility, you can also use the following
+pre-defined symbols in the body of `ACTION`:
+
+`argK_type` | The type of the K-th (0-based) argument of the mock function
+:-------------- | :-----------------------------------------------------------
+`args` | All arguments of the mock function as a tuple
+`args_type` | The type of all arguments of the mock function as a tuple
+`return_type` | The return type of the mock function
+`function_type` | The type of the mock function
+
+For example, when using an `ACTION` as a stub action for the mock function:
+
+```cpp
+int DoSomething(bool flag, int* ptr);
+```
+
+we have:
+
+Pre-defined Symbol | Is Bound To
+------------------ | ---------------------------------
+`arg0` | the value of `flag`
+`arg0_type` | the type `bool`
+`arg1` | the value of `ptr`
+`arg1_type` | the type `int*`
+`args` | the tuple `(flag, ptr)`
+`args_type` | the type `std::tuple<bool, int*>`
+`return_type` | the type `int`
+`function_type` | the type `int(bool, int*)`
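+
+As a small sketch (ours, using the symbols from the table) of a stub action
+for `DoSomething`:
+
+```cpp
+ACTION(StoreFlagAndReturnIt) {
+  *arg1 = arg0 ? 1 : 0;                    // arg1 is ptr, arg0 is flag.
+  return static_cast<return_type>(arg0);   // return_type is int here.
+}
+
+// ... .WillOnce(StoreFlagAndReturnIt());
+```
+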
+#### Legacy macro-based parameterized Actions
+
+Sometimes you'll want to parameterize an action you define. For that we have
+another macro
+
+```cpp
+ACTION_P(name, param) { statements; }
+```
+
+For example,
+
+```cpp
+ACTION_P(Add, n) { return arg0 + n; }
+```
+
+will allow you to write
+
+```cpp
+// Returns argument #0 + 5.
+... WillOnce(Add(5));
+```
+
+For convenience, we use the term *arguments* for the values used to invoke the
+mock function, and the term *parameters* for the values used to instantiate an
+action.
+
+Note that you don't need to provide the type of the parameter either. If the
+parameter is named `param`, you can use the gMock-defined symbol `param_type`
+to refer to the type of the parameter as inferred by the compiler. For
+example, in the body of `ACTION_P(Add, n)` above, you can write `n_type` for
+the type of `n`.
+
+gMock also provides `ACTION_P2`, `ACTION_P3`, and so on to support
+multi-parameter actions. For example,
+
+```cpp
+ACTION_P2(ReturnDistanceTo, x, y) {
+  double dx = arg0 - x;
+  double dy = arg1 - y;
+  return sqrt(dx*dx + dy*dy);
+}
+```
+
+lets you write
+
+```cpp
+... WillOnce(ReturnDistanceTo(5.0, 26.5));
+```
+
+You can view `ACTION` as a degenerate parameterized action where the number of
+parameters is 0.
+
+You can also easily define actions overloaded on the number of parameters:
+
+```cpp
+ACTION_P(Plus, a) { ... }
+ACTION_P2(Plus, a, b) { ... }
+```
+
+### Restricting the Type of an Argument or Parameter in an ACTION
+
+For maximum brevity and reusability, the `ACTION*` macros don't ask you to
+provide the types of the mock function arguments and the action parameters.
+Instead, we let the compiler infer the types for us.
+
+Sometimes, however, we may want to be more explicit about the types. There are
+several tricks to do that. For example:
+
+```cpp
+ACTION(Foo) {
+  // Makes sure arg0 can be converted to int.
+  int n = arg0;
+  ... use n instead of arg0 here ...
+}
+
+ACTION_P(Bar, param) {
+  // Makes sure the type of arg1 is const char*.
+  ::testing::StaticAssertTypeEq<const char*, arg1_type>();
+
+  // Makes sure param can be converted to bool.
+  bool flag = param;
+}
+```
+
+where `StaticAssertTypeEq` is a compile-time assertion in googletest that
+verifies two types are the same.
+
+### Writing New Action Templates Quickly
+
+Sometimes you want to give an action explicit template parameters that cannot
+be inferred from its value parameters. `ACTION_TEMPLATE()` supports that and
+can be viewed as an extension to `ACTION()` and `ACTION_P*()`.
+
+The syntax:
+
+```cpp
+ACTION_TEMPLATE(ActionName,
+                HAS_m_TEMPLATE_PARAMS(kind1, name1, ..., kind_m, name_m),
+                AND_n_VALUE_PARAMS(p1, ..., p_n)) { statements; }
+```
+
+defines an action template that takes *m* explicit template parameters and *n*
+value parameters, where *m* is in [1, 10] and *n* is in [0, 10]. `name_i` is
+the name of the *i*-th template parameter, and `kind_i` specifies whether it's
+a `typename`, an integral constant, or a template. `p_i` is the name of the
+*i*-th value parameter.
+
+Example:
+
+```cpp
+// DuplicateArg<k, T>(output) converts the k-th argument of the mock
+// function to type T and copies it to *output.
+ACTION_TEMPLATE(DuplicateArg,
+                // Note the comma between int and k:
+                HAS_2_TEMPLATE_PARAMS(int, k, typename, T),
+                AND_1_VALUE_PARAMS(output)) {
+  *output = T(std::get<k>(args));
+}
+```
+
+To create an instance of an action template, write:
+
+```cpp
+ActionName<t1, ..., t_m>(v1, ..., v_n)
+```
+
+where the `t`s are the template arguments and the `v`s are the value
+arguments. The value argument types are inferred by the compiler. For example:
+
+```cpp
+using ::testing::_;
+...
+  int n;
+  EXPECT_CALL(mock, Foo).WillOnce(DuplicateArg<1, unsigned char>(&n));
+```
+
+If you want to explicitly specify the value argument types, you can provide
+additional template arguments:
+
+```cpp
+ActionName<t1, ..., t_m, u1, ..., u_n>(v1, ..., v_n)
+```
+
+where `u_i` is the desired type of `v_i`.
+
+`ACTION_TEMPLATE` and `ACTION`/`ACTION_P*` can be overloaded on the number of
+value parameters, but not on the number of template parameters. Without the
+restriction, the meaning of the following is unclear:
+
+```cpp
+  OverloadedAction<int, bool>(x);
+```
+
+Are we using a single-template-parameter action where `bool` refers to the
+type of `x`, or a two-template-parameter action where the compiler is asked to
+infer the type of `x`?
+
+### Using the ACTION Object's Type
+
+If you are writing a function that returns an `ACTION` object, you'll need to
+know its type. The type depends on the macro used to define the action and the
+parameter types. The rule is relatively simple:
+
+
+| Given Definition | Expression | Has Type |
+| ----------------------------- | ------------------- | --------------------- |
+| `ACTION(Foo)` | `Foo()` | `FooAction` |
+| `ACTION_TEMPLATE(Foo, HAS_m_TEMPLATE_PARAMS(...), AND_0_VALUE_PARAMS())` | `Foo<t1, ..., t_m>()` | `FooAction<t1, ..., t_m>` |
+| `ACTION_P(Bar, param)` | `Bar(int_value)` | `BarActionP<int>` |
+| `ACTION_TEMPLATE(Bar, HAS_m_TEMPLATE_PARAMS(...), AND_1_VALUE_PARAMS(p1))` | `Bar<t1, ..., t_m>(int_value)` | `BarActionP<t1, ..., t_m, int>` |
+| `ACTION_P2(Baz, p1, p2)` | `Baz(bool_value, int_value)` | `BazActionP2<bool, int>` |
+| `ACTION_TEMPLATE(Baz, HAS_m_TEMPLATE_PARAMS(...), AND_2_VALUE_PARAMS(p1, p2))` | `Baz<t1, ..., t_m>(bool_value, int_value)` | `BazActionP2<t1, ..., t_m, bool, int>` |
+| ... | ... | ... |
+
+
+Note that we have to pick different suffixes (`Action`, `ActionP`, `ActionP2`,
+etc.) for actions with different numbers of value parameters, or the action
+definitions cannot be overloaded on the number of them.
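+
+As a sketch (our own) of how these types can be spelled out, a factory
+function returning an instance of the `Add` action defined earlier:
+
+```cpp
+// Per the table above, Add(5) has type AddActionP<int>, so a factory
+// function can name that type explicitly as its return type.
+AddActionP<int> AddFive() { return Add(5); }
+
+// ... .WillOnce(AddFive());
+```
+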
+### Writing New Monomorphic Actions {#NewMonoActions}
+
+While the `ACTION*` macros are very convenient, sometimes they are
+inappropriate. For example, despite the tricks shown in the previous recipes,
+they don't let you directly specify the types of the mock function arguments
+and the action parameters, which in general leads to unclear compiler error
+messages that can baffle unfamiliar users. They also don't allow overloading
+actions based on parameter types without jumping through some hoops.
+
+An alternative to the `ACTION*` macros is to implement
+`::testing::ActionInterface<F>`, where `F` is the type of the mock function in
+which the action will be used. For example:
+
+```cpp
+template <typename F>
+class ActionInterface {
+ public:
+  virtual ~ActionInterface();
+
+  // Performs the action. Result is the return type of function type
+  // F, and ArgumentTuple is the tuple of arguments of F.
+  //
+  // For example, if F is int(bool, const string&), then Result would
+  // be int, and ArgumentTuple would be std::tuple<bool, const string&>.
+  virtual Result Perform(const ArgumentTuple& args) = 0;
+};
+```
+
+```cpp
+using ::testing::_;
+using ::testing::Action;
+using ::testing::ActionInterface;
+using ::testing::MakeAction;
+
+typedef int IncrementMethod(int*);
+
+class IncrementArgumentAction : public ActionInterface<IncrementMethod> {
+ public:
+  int Perform(const std::tuple<int*>& args) override {
+    int* p = std::get<0>(args);  // Grabs the first argument.
+    return (*p)++;  // Increments it and returns the old value.
+  }
+};
+
+Action<IncrementMethod> IncrementArgument() {
+  return MakeAction(new IncrementArgumentAction);
+}
+
+...
+  EXPECT_CALL(foo, Baz(_))
+      .WillOnce(IncrementArgument());
+
+  int n = 5;
+  foo.Baz(&n);  // Should return 5 and change n to 6.
+```
+
+### Writing New Polymorphic Actions {#NewPolyActions}
+
+The previous recipe showed you how to define your own action. This is all
+good, except that you need to know the type of the function in which the
+action will be used. Sometimes that can be a problem, for example when you
+want to use the action in functions with *different* types (as you can with
+`Return()` and `SetArgPointee()`).
+
+If an action can be used in several types of mock functions, we say it's
+*polymorphic*. The `MakePolymorphicAction()` function template makes it easy
+to define such an action:
+
+```cpp
+namespace testing {
+template <typename Impl>
+PolymorphicAction<Impl> MakePolymorphicAction(const Impl& impl);
+}  // namespace testing
+```
+
+As an example, let's define an action that returns the second argument in the
+mock function's argument list. The first step is to define an implementation
+class:
+
+```cpp
+class ReturnSecondArgumentAction {
+ public:
+  template <typename Result, typename ArgumentTuple>
+  Result Perform(const ArgumentTuple& args) const {
+    // To get the i-th (0-based) argument, use std::get<i>(args).
+    return std::get<1>(args);
+  }
+};
+```
+
+This implementation class does *not* need to inherit from any particular
+class. What matters is that it must have a `Perform()` method template. This
+method template takes the mock function's arguments as a tuple in a **single**
+argument, and returns the result of the action. It can be either `const` or
+not, but must be invocable with exactly one template argument, which is the
+result type. In other words, you must be able to call `Perform<R>(args)` where
+`R` is the mock function's return type and `args` is its arguments in a tuple.
+
+Next, we use `MakePolymorphicAction()` to turn an instance of the
+implementation class into the polymorphic action we need. It will be
+convenient to have a wrapper for this:
+
+```cpp
+using ::testing::MakePolymorphicAction;
+using ::testing::PolymorphicAction;
+
+PolymorphicAction<ReturnSecondArgumentAction> ReturnSecondArgument() {
+  return MakePolymorphicAction(ReturnSecondArgumentAction());
+}
+```
+
+Now, you can use this polymorphic action the same way you use the built-in
+ones:
+
+```cpp
+using ::testing::_;
+
+class MockFoo : public Foo {
+ public:
+  MOCK_METHOD(int, DoThis, (bool flag, int n), (override));
+  MOCK_METHOD(string, DoThat, (int x, const char* str1, const char* str2),
+              (override));
+};
+
+  ...
+  MockFoo foo;
+  EXPECT_CALL(foo, DoThis).WillOnce(ReturnSecondArgument());
+  EXPECT_CALL(foo, DoThat).WillOnce(ReturnSecondArgument());
+  ...
+  foo.DoThis(true, 5);         // Will return 5.
+  foo.DoThat(1, "Hi", "Bye");  // Will return "Hi".
+```
+
+### Teaching gMock How to Print Your Values
+
+When an uninteresting or unexpected call occurs, gMock prints the argument
+values and the stack trace to help you debug. Assertion macros like
+`EXPECT_THAT` and `EXPECT_EQ` also print the values in question when the
+assertion fails. gMock and googletest do this using googletest's
+user-extensible value printer.
+
+This printer knows how to print built-in C++ types, native arrays, STL
+containers, and any type that supports the `<<` operator. For other types, it
+prints the raw bytes in the value and hopes that you the user can figure it
+out.
+[The GoogleTest advanced guide](advanced.md#teaching-googletest-how-to-print-your-values)
+explains how to extend the printer to do a better job at printing your
+particular type than dumping the bytes.
+
+## Useful Mocks Created Using gMock
+
+
+
+### Mock std::function {#MockFunction}
+
+`std::function` is a general function type introduced in C++11. It is a
+preferred way of passing callbacks to new interfaces. Functions are copyable,
+and are not usually passed around by pointer, which makes them tricky to mock.
+But fear not - `MockFunction` can help you with that.
+
+`MockFunction<R(T1, ..., Tn)>` has a mock method `Call()` with the signature:
+
+```cpp
+  R Call(T1, ..., Tn);
+```
+
+It also has an `AsStdFunction()` method, which creates a `std::function` proxy
+forwarding to `Call`:
+
+```cpp
+  std::function<R(T1, ..., Tn)> AsStdFunction();
+```
+
+To use `MockFunction`, first create a `MockFunction` object and set up
+expectations on its `Call` method. Then pass the proxy obtained from
+`AsStdFunction()` to the code you are testing. For example:
+
+```cpp
+TEST(FooTest, RunsCallbackWithBarArgument) {
+  // 1. Create a mock object.
+  MockFunction<int(std::string)> mock_function;
+
+  // 2. Set expectations on Call() method.
+  EXPECT_CALL(mock_function, Call("bar")).WillOnce(Return(1));
+
+  // 3. Exercise code that uses std::function.
+  Foo(mock_function.AsStdFunction());
+  // Foo's signature can be either of:
+  // void Foo(const std::function<int(std::string)>& fun);
+  // void Foo(std::function<int(std::string)> fun);
+
+  // 4. All expectations will be verified when mock_function
+  //    goes out of scope and is destroyed.
+}
+```
+
+Remember that function objects created with `AsStdFunction()` are just
+forwarders. If you create multiple of them, they will share the same set of
+expectations.
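+
+For example (a sketch of ours), two proxies created from the same mock count
+toward the same expectation:
+
+```cpp
+using ::testing::MockFunction;
+using ::testing::Return;
+
+MockFunction<int(std::string)> mock;
+EXPECT_CALL(mock, Call("ping")).Times(2).WillRepeatedly(Return(0));
+
+// Both forwarders hit the same underlying Call() method, so together they
+// satisfy the single Times(2) expectation above.
+std::function<int(std::string)> f1 = mock.AsStdFunction();
+std::function<int(std::string)> f2 = mock.AsStdFunction();
+f1("ping");
+f2("ping");
+```
+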
+Although `std::function` supports an unlimited number of arguments, the
+`MockFunction` implementation is limited to ten. If you ever hit that
+limit... well, your callback has bigger problems than being mockable. :-)
diff --git a/test/googletest/docs/gmock_faq.md b/test/googletest/docs/gmock_faq.md
new file mode 100644
index 0000000..8f220bf
--- /dev/null
+++ b/test/googletest/docs/gmock_faq.md
@@ -0,0 +1,390 @@
+# Legacy gMock FAQ
+
+### When I call a method on my mock object, the method for the real object is invoked instead. What's the problem?
+
+In order for a method to be mocked, it must be *virtual*, unless you use the
+[high-perf dependency injection technique](gmock_cook_book.md#MockingNonVirtualMethods).
+
+### Can I mock a variadic function?
+
+You cannot mock a variadic function (i.e. a function taking ellipsis (`...`)
+arguments) directly in gMock.
+
+The problem is that in general, there is *no way* for a mock object to know
+how many arguments are passed to the variadic method, and what the arguments'
+types are. Only the *author of the base class* knows the protocol, and we
+cannot look into his or her head.
+
+Therefore, to mock such a function, the *user* must teach the mock object how
+to figure out the number of arguments and their types. One way to do it is to
+provide overloaded versions of the function.
+
+Ellipsis arguments are inherited from C and not really a C++ feature. They are
+unsafe to use and don't work with arguments that have constructors or
+destructors. Therefore we recommend avoiding them in C++ as much as possible.
+
+### MSVC gives me warning C4301 or C4373 when I define a mock method with a const parameter. Why?
+
+If you compile this using Microsoft Visual C++ 2005 SP1:
+
+```cpp
+class Foo {
+  ...
+  virtual void Bar(const int i) = 0;
+};
+
+class MockFoo : public Foo {
+  ...
+  MOCK_METHOD(void, Bar, (const int i), (override));
+};
+```
+
+You may get the following warning:
+
+```shell
+warning C4301: 'MockFoo::Bar': overriding virtual function only differs from 'Foo::Bar' by const/volatile qualifier
+```
+
+This is an MSVC bug. The same code compiles fine with gcc, for example. If you
+use Visual C++ 2008 SP1, you would get the warning:
+
+```shell
+warning C4373: 'MockFoo::Bar': virtual function overrides 'Foo::Bar', previous versions of the compiler did not override when parameters only differed by const/volatile qualifiers
+```
+
+In C++, if you *declare* a function with a `const` parameter, the `const`
+modifier is ignored. Therefore, the `Foo` base class above is equivalent to:
+
+```cpp
+class Foo {
+  ...
+  virtual void Bar(int i) = 0;  // int or const int? Makes no difference.
+};
+```
+
+In fact, you can *declare* `Bar()` with an `int` parameter, and define it with
+a `const int` parameter. The compiler will still match them up.
+
+Since making a parameter `const` is meaningless in the method declaration, we
+recommend removing it in both `Foo` and `MockFoo`. That should work around the
+VC bug.
+
+Note that we are talking about the *top-level* `const` modifier here. If the
+function parameter is passed by pointer or reference, declaring the pointee or
+referee as `const` is still meaningful. For example, the following two
+declarations are *not* equivalent:
+
+```cpp
+void Bar(int* p);        // Neither p nor *p is const.
+void Bar(const int* p);  // p is not const, but *p is.
+```
+
+### I can't figure out why gMock thinks my expectations are not satisfied. What should I do?
+
+You might want to run your test with `--gmock_verbose=info`. This flag lets
+gMock print a trace of every mock function call it receives. By studying the
+trace, you'll gain insights on why the expectations you set are not met.
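+
+For example (assuming your test binary is named `my_test`; the flag itself is
+real gMock functionality):
+
+```shell
+./my_test --gmock_verbose=info
+```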
+
+If you see the message "The mock function has no default action set, and its
+return type has no default value set.", then try
+[adding a default action](gmock_cheat_sheet.md#OnCall). Due to a known issue,
+unexpected calls on mocks without default actions don't print out a detailed
+comparison between the actual arguments and the expected arguments.
+
+### My program crashed and `ScopedMockLog` spit out tons of messages. Is it a gMock bug?
+
+gMock and `ScopedMockLog` are likely doing the right thing here.
+
+When a test crashes, the failure signal handler will try to log a lot of
+information (the stack trace, and the address map, for example). The messages
+are compounded if you have many threads with deep stacks. When `ScopedMockLog`
+intercepts these messages and finds that they don't match any expectations, it
+prints an error for each of them.
+
+You can learn to ignore the errors, or you can rewrite your expectations to
+make your test more robust, for example, by adding something like:
+
+```cpp
+using ::testing::AnyNumber;
+using ::testing::Not;
+...
+  // Ignores any log not done by us.
+  EXPECT_CALL(log, Log(_, Not(EndsWith("/my_file.cc")), _))
+      .Times(AnyNumber());
+```
+
+### How can I assert that a function is NEVER called?
+
+```cpp
+using ::testing::_;
+...
+  EXPECT_CALL(foo, Bar(_))
+      .Times(0);
+```
+
+### I have a failed test where gMock tells me TWICE that a particular expectation is not satisfied. Isn't this redundant?
+
+When gMock detects a failure, it prints relevant information (the mock
+function arguments, the state of relevant expectations, etc.) to help the user
+debug. If another failure is detected, gMock will do the same, including
+printing the state of relevant expectations.
+
+Sometimes an expectation's state didn't change between two failures, and
+you'll see the same description of the state twice. They are however *not*
+redundant, as they refer to *different points in time*. The fact they are the
+same *is* interesting information.
+
+### I get a heapcheck failure when using a mock object, but using a real object is fine. What can be wrong?
+
+Does the class (hopefully a pure interface) you are mocking have a virtual
+destructor?
+
+Whenever you derive from a base class, make sure its destructor is virtual.
+Otherwise Bad Things will happen. Consider the following code:
+
+```cpp
+class Base {
+ public:
+  // Not virtual, but should be.
+  ~Base() { ... }
+  ...
+};
+
+class Derived : public Base {
+ public:
+  ...
+ private:
+  std::string value_;
+};
+
+...
+  Base* p = new Derived;
+  ...
+  delete p;  // Surprise! ~Base() will be called, but ~Derived() will not
+             // - value_ is leaked.
+```
+
+By changing `~Base()` to virtual, `~Derived()` will be correctly called when
+`delete p` is executed, and the heap checker will be happy.
+
+### The "newer expectations override older ones" rule makes writing expectations awkward. Why does gMock do that?
+
+When people complain about this, often they are referring to code like:
+
+```cpp
+using ::testing::Return;
+...
+  // foo.Bar() should be called twice, return 1 the first time, and return
+  // 2 the second time. However, I have to write the expectations in the
+  // reverse order. This sucks big time!!!
+  EXPECT_CALL(foo, Bar())
+      .WillOnce(Return(2))
+      .RetiresOnSaturation();
+  EXPECT_CALL(foo, Bar())
+      .WillOnce(Return(1))
+      .RetiresOnSaturation();
+```
+
+The problem is that they didn't pick the **best** way to express the test's
+intent.
+
+By default, expectations don't have to be matched in *any* particular order.
+If you want them to match in a certain order, you need to be explicit. This is
+gMock's (and jMock's) fundamental philosophy: it's easy to accidentally
+over-specify your tests, and we want to make it harder to do so.
+
+There are two better ways to write the test spec. You could either put the
+expectations in sequence:
+
+```cpp
+using ::testing::InSequence;
+using ::testing::Return;
+...
+  // foo.Bar() should be called twice, return 1 the first time, and return
+  // 2 the second time. Using a sequence, we can write the expectations
+  // in their natural order.
+  {
+    InSequence s;
+    EXPECT_CALL(foo, Bar())
+        .WillOnce(Return(1))
+        .RetiresOnSaturation();
+    EXPECT_CALL(foo, Bar())
+        .WillOnce(Return(2))
+        .RetiresOnSaturation();
+  }
+```
+
+or you can put the sequence of actions in the same expectation:
+
+```cpp
+using ::testing::Return;
+...
+  // foo.Bar() should be called twice, return 1 the first time, and return
+  // 2 the second time.
+  EXPECT_CALL(foo, Bar())
+      .WillOnce(Return(1))
+      .WillOnce(Return(2))
+      .RetiresOnSaturation();
+```
+
+Back to the original question: why does gMock search the expectations (and
+`ON_CALL`s) from back to front? Because this allows a user to set up a mock's
+behavior for the common case early (e.g. in the mock's constructor or the test
+fixture's set-up phase) and customize it with more specific rules later. If
+gMock searched from front to back, this very useful pattern wouldn't be
+possible.
+
+### gMock prints a warning when a function without EXPECT_CALL is called, even if I have set its behavior using ON_CALL. Would it be reasonable not to show the warning in this case?
+
+When choosing between being neat and being safe, we lean toward the latter. So
+the answer is that we think it's better to show the warning.
+
+Often people write `ON_CALL`s in the mock object's constructor or `SetUp()`,
+as the default behavior rarely changes from test to test. Then in the test
+body they set the expectations, which are often different for each test.
+Having an `ON_CALL` in the set-up part of a test doesn't mean that the calls
+are expected. If there's no `EXPECT_CALL` and the method is called, it's
+possibly an error. If we quietly let the call go through without notifying the
+user, bugs may creep in unnoticed.
+
+If, however, you are sure that the calls are OK, you can write
+
+```cpp
+using ::testing::_;
+...
+  EXPECT_CALL(foo, Bar(_))
+      .WillRepeatedly(...);
+```
+
+instead of
+
+```cpp
+using ::testing::_;
+...
+  ON_CALL(foo, Bar(_))
+      .WillByDefault(...);
+```
+
+This tells gMock that you do expect the calls and no warning should be
+printed.
+
+Also, you can control the verbosity by specifying `--gmock_verbose=error`.
+Other values are `info` and `warning`. If you find the output too noisy when
+debugging, just choose a less verbose level.
+
+### How can I delete the mock function's argument in an action?
+
+If your mock function takes a pointer argument and you want to delete that
+argument, you can use `testing::DeleteArg<N>()` to delete the N'th
+(zero-indexed) argument:
+
+```cpp
+using ::testing::_;
+  ...
+  MOCK_METHOD(void, Bar, (X* x, const Y& y));
+  ...
+  EXPECT_CALL(mock_foo_, Bar(_, _))
+      .WillOnce(testing::DeleteArg<0>());
+```
+
+### How can I perform an arbitrary action on a mock function's argument?
+
+If you find yourself needing to perform some action that's not supported by
+gMock directly, remember that you can define your own actions using
+[`MakeAction()`](#NewMonoActions) or
+[`MakePolymorphicAction()`](#NewPolyActions), or you can write a stub function
+and invoke it using [`Invoke()`](#FunctionsAsActions).
+
+```cpp
+using ::testing::_;
+using ::testing::Invoke;
+  ...
+  MOCK_METHOD(void, Bar, (X* p));
+  ...
+  EXPECT_CALL(mock_foo_, Bar(_))
+      .WillOnce(Invoke(MyAction(...)));
+```
+
+### My code calls a static/global function. Can I mock it?
+
+You can, but you need to make some changes.
+
+In general, if you find yourself needing to mock a static function, it's a
+sign that your modules are too tightly coupled (and less flexible, less
+reusable, less testable, etc.). You are probably better off defining a small
+interface and calling the function through that interface, which then can be
+easily mocked. It's a bit of work initially, but usually pays for itself
+quickly.
+
+This Google Testing Blog
+[post](https://testing.googleblog.com/2008/06/defeat-static-cling.html) says
+it excellently. Check it out.
+
+### My mock object needs to do complex stuff. It's a lot of pain to specify the actions. gMock sucks!
+
+I know it's not a question, but you get an answer for free anyway. :-)
+
+With gMock, you can create mocks in C++ easily. And people might be tempted to
+use them everywhere. Sometimes they work great, and sometimes you may find
+them, well, a pain to use. So, what's wrong in the latter case?
+
+When you write a test without using mocks, you exercise the code and assert
+that it returns the correct value or that the system is in an expected state.
+This is sometimes called "state-based testing".
+
+Mocks are great for what some call "interaction-based" testing: instead of
+checking the system state at the very end, mock objects verify that they are
+invoked the right way and report an error as soon as it arises, giving you a
+handle on the precise context in which the error was triggered. This is often
+more effective and economical to do than state-based testing.
+
+If you are doing state-based testing and using a test double just to simulate
+the real object, you are probably better off using a fake. Using a mock in
+this case causes pain, as it's not a strong point for mocks to perform complex
+actions. If you experience this and think that mocks suck, you are just not
+using the right tool for your problem. Or, you might be trying to solve the
+wrong problem. :-)
+
+### I got a warning "Uninteresting function call encountered - default action taken." Should I panic?
+
+By all means, NO! It's just an FYI. :-)
+
+What it means is that you have a mock function, you haven't set any
+expectations on it (by gMock's rule this means that you are not interested in
+calls to this function and therefore it can be called any number of times),
+and it is called. That's OK - you didn't say it's not OK to call the function!
+
+What if you actually meant to disallow this function to be called, but forgot
+to write `EXPECT_CALL(foo, Bar()).Times(0)`? While one can argue that it's the
+user's fault, gMock tries to be nice and prints you a note.
+
+So, when you see the message and believe that there shouldn't be any
+uninteresting calls, you should investigate what's going on. To make your life
+easier, gMock dumps the stack trace when an uninteresting call is encountered.
+From that you can figure out which mock function it is, and how it is called.
+
+### I want to define a custom action. Should I use Invoke() or implement the ActionInterface interface?
+
+Either way is fine - you want to choose the one that's more convenient for
+your circumstance.
+
+Usually, if your action is for a particular function type, defining it using
+`Invoke()` should be easier; if your action can be used in functions of
+different types (e.g. if you are defining `Return(*value*)`),
+`MakePolymorphicAction()` is easiest. Sometimes you want precise control on
+what types of functions the action can be used in, and implementing
+`ActionInterface` is the way to go here. See the implementation of `Return()`
+in `gmock-actions.h` for an example.
+
+### I use SetArgPointee() in WillOnce(), but gcc complains about "conflicting return type specified". What does it mean?
+
+You got this error because gMock has no idea what value it should return when
+the mock method is called. `SetArgPointee()` says what the side effect is, but
+doesn't say what the return value should be. You need `DoAll()` to chain a
+`SetArgPointee()` with a `Return()` that provides a value appropriate to the
+API being mocked.
+
+See this [recipe](gmock_cook_book.md#mocking-side-effects) for more details
+and an example.
+
+### I have a huge mock class, and Microsoft Visual C++ runs out of memory when compiling it. What can I do?
+
+We've noticed that when the `/clr` compiler flag is used, Visual C++ uses 5~6
+times as much memory when compiling a mock class. We suggest avoiding `/clr`
+when compiling native C++ mocks.
diff --git a/test/googletest/docs/gmock_for_dummies.md b/test/googletest/docs/gmock_for_dummies.md
new file mode 100644
index 0000000..ed2297c
--- /dev/null
+++ b/test/googletest/docs/gmock_for_dummies.md
@@ -0,0 +1,702 @@
+# gMock for Dummies
+
+## What Is gMock?
+
+When you write a prototype or test, often it's not feasible or wise to rely on
+real objects entirely. A **mock object** implements the same interface as a
+real object (so it can be used as one), but lets you specify at run time how
+it will be used and what it should do (which methods will be called? in which
+order? how many times? with what arguments? what will they return? etc).
+
+It is easy to confuse the term *fake objects* with mock objects. Fakes and
+mocks actually mean very different things in the Test-Driven Development (TDD)
+community:
+
+*   **Fake** objects have working implementations, but usually take some
+    shortcut (perhaps to make the operations less expensive), which makes them
+    not suitable for production. An in-memory file system would be an example
+    of a fake.
+*   **Mocks** are objects pre-programmed with *expectations*, which form a
+    specification of the calls they are expected to receive.
+
+If all this seems too abstract for you, don't worry - the most important thing
+to remember is that a mock allows you to check the *interaction* between
+itself and code that uses it. The difference between fakes and mocks shall
+become much clearer once you start to use mocks.
+
+**gMock** is a library (sometimes we also call it a "framework" to make it
+sound cool) for creating mock classes and using them. It does to C++ what
+jMock/EasyMock does to Java (well, more or less).
+
+When using gMock,
+
+1.  first, you use some simple macros to describe the interface you want to
+    mock, and they will expand to the implementation of your mock class;
+2.  next, you create some mock objects and specify their expectations and
+    behavior using an intuitive syntax;
+3.  then you exercise code that uses the mock objects. gMock will catch any
+    violation of the expectations as soon as it arises.
+
+## Why gMock?
+
+While mock objects help you remove unnecessary dependencies in tests and make
+them fast and reliable, using mocks manually in C++ is *hard*:
+
+*   Someone has to implement the mocks. The job is usually tedious and
+    error-prone. No wonder people go to great lengths to avoid it.
+*   The quality of those manually written mocks is a bit, uh, unpredictable.
+    You may see some really polished ones, but you may also see some that were
+    hacked up in a hurry and have all sorts of ad hoc restrictions.
+*   The knowledge you gained from using one mock doesn't transfer to the next
+    one.
+
+In contrast, Java and Python programmers have some fine mock frameworks
+(jMock, EasyMock, etc), which automate the creation of mocks. As a result,
+mocking is a proven effective technique and widely adopted practice in those
+communities. Having the right tool absolutely makes the difference.
+
+gMock was built to help C++ programmers. It was inspired by jMock and
+EasyMock, but designed with C++'s specifics in mind. It is your friend if any
+of the following problems is bothering you:
+
+*   You are stuck with a sub-optimal design and wish you had done more
+    prototyping before it was too late, but prototyping in C++ is by no means
+    "rapid".
+*   Your tests are slow as they depend on too many libraries or use expensive
+    resources (e.g. a database).
+*   Your tests are brittle as some resources they use are unreliable (e.g. the
+    network).
+*   You want to test how your code handles a failure (e.g. a file checksum
+    error), but it's not easy to cause one.
+*   You need to make sure that your module interacts with other modules in the
+    right way, but it's hard to observe the interaction; therefore you resort
+    to observing the side effects at the end of the action, but it's awkward
+    at best.
+*   You want to "mock out" your dependencies, except that they don't have mock
+    implementations yet; and, frankly, you aren't thrilled by some of those
+    hand-written mocks.
+
+We encourage you to use gMock as
+
+*   a *design* tool, for it lets you experiment with your interface design
+    early and often. More iterations lead to better designs!
+*   a *testing* tool to cut your tests' outbound dependencies and probe the
+    interaction between your module and its collaborators.
+
+## Getting Started
+
+gMock is bundled with googletest.
+
+## A Case for Mock Turtles
+
+Let's look at an example. Suppose you are developing a graphics program that
+relies on a [LOGO](https://en.wikipedia.org/wiki/Logo_programming_language)-like
+API for drawing. How would you test that it does the right thing? Well, you
+can run it and compare the screen with a golden screen snapshot, but let's
+admit it: tests like this are expensive to run and fragile (What if you just
+upgraded to a shiny new graphics card that has better anti-aliasing? Suddenly
+you have to update all your golden images.). It would be too painful if all
+your tests were like this. Fortunately, you learned about
+[Dependency Injection](https://en.wikipedia.org/wiki/Dependency_injection) and
+know the right thing to do: instead of having your application talk to the
+system API directly, wrap the API in an interface (say, `Turtle`) and code to
+that interface:
+
+```cpp
+class Turtle {
+  ...
+  virtual ~Turtle() {}
+  virtual void PenUp() = 0;
+  virtual void PenDown() = 0;
+  virtual void Forward(int distance) = 0;
+  virtual void Turn(int degrees) = 0;
+  virtual void GoTo(int x, int y) = 0;
+  virtual int GetX() const = 0;
+  virtual int GetY() const = 0;
+};
+```
+
+(Note that the destructor of `Turtle` **must** be virtual, as is the case for
+**all** classes you intend to inherit from - otherwise the destructor of the
+derived class will not be called when you delete an object through a base
+pointer, and you'll get corrupted program states like memory leaks.)
+
+You can control whether the turtle's movement will leave a trace using
+`PenUp()` and `PenDown()`, and control its movement using `Forward()`,
+`Turn()`, and `GoTo()`. Finally, `GetX()` and `GetY()` tell you the current
+position of the turtle.
+
+Your program will normally use a real implementation of this interface. In
+tests, you can use a mock implementation instead. This allows you to easily
+check what drawing primitives your program is calling, with what arguments,
+and in which order. Tests written this way are much more robust (they won't
+break because your new machine does anti-aliasing differently), easier to read
+and maintain (the intent of a test is expressed in the code, not in some
+binary images), and run *much, much faster*.
+
+## Writing the Mock Class
+
+If you are lucky, the mocks you need to use have already been implemented by
+some nice people. If, however, you find yourself in the position to write a
+mock class, relax - gMock turns this task into a fun game! (Well, almost.)
+
+### How to Define It
+
+Using the `Turtle` interface as example, here are the simple steps you need
+to follow:
+
+*   Derive a class `MockTurtle` from `Turtle`.
+*   Take a *virtual* function of `Turtle` (while it's possible to
+    [mock non-virtual methods using templates](gmock_cook_book.md#MockingNonVirtualMethods),
+    it's much more involved).
+*   In the `public:` section of the child class, write `MOCK_METHOD();`
+*   Now comes the fun part: you take the function signature, cut-and-paste it
+    into the macro, and add two commas - one between the return type and the
+    name, another between the name and the argument list.
+*   If you're mocking a const method, add a 4th parameter containing `(const)`
+    (the parentheses are required).
+*   Since you're overriding a virtual method, we suggest adding the `override`
+    keyword. For const methods the 4th parameter becomes `(const, override)`,
+    for non-const methods just `(override)`. This isn't mandatory.
+*   Repeat until all virtual functions you want to mock are done. (It goes
+    without saying that *all* pure virtual methods in your abstract class must
+    be either mocked or overridden.)
+
+After the process, you should have something like:
+
+```cpp
+#include "gmock/gmock.h"  // Brings in gMock.
+
+class MockTurtle : public Turtle {
+ public:
+  ...
+  MOCK_METHOD(void, PenUp, (), (override));
+  MOCK_METHOD(void, PenDown, (), (override));
+  MOCK_METHOD(void, Forward, (int distance), (override));
+  MOCK_METHOD(void, Turn, (int degrees), (override));
+  MOCK_METHOD(void, GoTo, (int x, int y), (override));
+  MOCK_METHOD(int, GetX, (), (const, override));
+  MOCK_METHOD(int, GetY, (), (const, override));
+};
+```
+
+You don't need to define these mock methods anywhere else - the `MOCK_METHOD`
+macro will generate the definitions for you. It's that simple!
+
+### Where to Put It
+
+When you define a mock class, you need to decide where to put its definition.
+Some people put it in a `_test.cc`. This is fine when the interface being
+mocked (say, `Foo`) is owned by the same person or team. Otherwise, when the
+owner of `Foo` changes it, your test could break. (You can't really expect
+`Foo`'s maintainer to fix every test that uses `Foo`, can you?)
+
+Generally, you should not mock classes you don't own. If you must mock such a
+class owned by others, define the mock class in `Foo`'s Bazel package (usually
+the same directory or a `testing` sub-directory), and put it in a `.h` and a
+`cc_library` with `testonly=True`. Then everyone can reference them from their
+tests. If `Foo` ever changes, there is only one copy of `MockFoo` to change,
+and only tests that depend on the changed methods need to be fixed.
+
+Another way to do it: you can introduce a thin layer `FooAdaptor` on top of
+`Foo` and code to this new interface. Since you own `FooAdaptor`, you can
+absorb changes in `Foo` much more easily. While this is more work initially,
+carefully choosing the adaptor interface can make your code easier to write
+and more readable (a net win in the long run), as you can choose `FooAdaptor`
+to fit your specific domain much better than `Foo` does.
+
+## Using Mocks in Tests
+
+Once you have a mock class, using it is easy. The typical work flow is:
+
+1.  Import the gMock names from the `testing` namespace such that you can use
+    them unqualified (You only have to do it once per file). Remember that
+    namespaces are a good idea.
+2.  Create some mock objects.
+3.  Specify your expectations on them (How many times will a method be called?
+    With what arguments? What should it do? etc.).
+4.  Exercise some code that uses the mocks; optionally, check the result using
+    googletest assertions. If a mock method is called more than expected or
+    with wrong arguments, you'll get an error immediately.
+5.  When a mock is destructed, gMock will automatically check whether all
+    expectations on it have been satisfied.
+
+Here's an example:
+
+```cpp
+#include "path/to/mock-turtle.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::AtLeast;                     // #1
+
+TEST(PainterTest, CanDrawSomething) {
+  MockTurtle turtle;                          // #2
+  EXPECT_CALL(turtle, PenDown())              // #3
+      .Times(AtLeast(1));
+
+  Painter painter(&turtle);                   // #4
+
+  EXPECT_TRUE(painter.DrawCircle(0, 0, 10));  // #5
+}
+```
+
+As you might have guessed, this test checks that `PenDown()` is called at
+least once. If the `painter` object didn't call this method, your test will
+fail with a message like this:
+
+```text
+path/to/my_test.cc:119: Failure
+Actual function call count doesn't match this expectation:
+Actually: never called;
+Expected: called at least once.
+Stack trace:
+...
+```
+
+**Tip 1:** If you run the test from an Emacs buffer, you can hit `<Enter>` on
+the line number to jump right to the failed expectation.
+
+**Tip 2:** If your mock objects are never deleted, the final verification
+won't happen. Therefore it's a good idea to turn on the heap checker in your
+tests when you allocate mocks on the heap. You get that automatically if you
+use the `gtest_main` library already.
+
+###### Expectation Ordering
+
+**Important note:** gMock requires expectations to be set **before** the mock
+functions are called, otherwise the behavior is **undefined**. Do not
+alternate between calls to `EXPECT_CALL()` and calls to the mock functions,
+and do not set any expectations on a mock after passing the mock to an API.
+
+This means `EXPECT_CALL()` should be read as expecting that a call will occur
+*in the future*, not that a call has occurred.
+Why does gMock work like that? Well, specifying the expectation beforehand
+allows gMock to report a violation as soon as it arises, when the context
+(stack trace, etc) is still available. This makes debugging much easier.
+
+Admittedly, this test is contrived and doesn't do much. You can easily achieve
+the same effect without using gMock. However, as we shall reveal soon, gMock
+allows you to do *so much more* with the mocks.
+
+## Setting Expectations
+
+The key to using a mock object successfully is to set the *right expectations*
+on it. If you set the expectations too strictly, your test will fail as the
+result of unrelated changes. If you set them too loosely, bugs can slip
+through. You want to do it just right such that your test can catch exactly
+the kind of bugs you intend it to catch. gMock provides the necessary means
+for you to do it "just right."
+
+### General Syntax
+
+In gMock we use the `EXPECT_CALL()` macro to set an expectation on a mock
+method. The general syntax is:
+
+```cpp
+EXPECT_CALL(mock_object, method(matchers))
+    .Times(cardinality)
+    .WillOnce(action)
+    .WillRepeatedly(action);
+```
+
+The macro has two arguments: first the mock object, and then the method and
+its arguments. Note that the two are separated by a comma (`,`), not a period
+(`.`). (Why a comma? The answer is that it was necessary for technical
+reasons.) If the method is not overloaded, the macro can also be called
+without matchers:
+
+```cpp
+EXPECT_CALL(mock_object, non-overloaded-method)
+    .Times(cardinality)
+    .WillOnce(action)
+    .WillRepeatedly(action);
+```
+
+This syntax allows the test writer to specify "called with any arguments"
+without explicitly specifying the number or types of arguments. To avoid
+unintended ambiguity, this syntax may only be used for methods that are not
+overloaded.
+
+Either form of the macro can be followed by some optional *clauses* that
+provide more information about the expectation. We'll discuss how each clause
+works in the coming sections.
+
+This syntax is designed to make an expectation read like English. For example,
+you can probably guess that
+
+```cpp
+using ::testing::Return;
+...
+EXPECT_CALL(turtle, GetX())
+    .Times(5)
+    .WillOnce(Return(100))
+    .WillOnce(Return(150))
+    .WillRepeatedly(Return(200));
+```
+
+says that the `turtle` object's `GetX()` method will be called five times; it
+will return 100 the first time, 150 the second time, and then 200 every time.
+Some people like to call this style of syntax a Domain-Specific Language
+(DSL).
+
+{: .callout .note}
+**Note:** Why do we use a macro to do this? Well it serves two purposes: first
+it makes expectations easily identifiable (either by `grep` or by a human
+reader), and second it allows gMock to include the source file location of a
+failed expectation in messages, making debugging easier.
+
+### Matchers: What Arguments Do We Expect?
+
+When a mock function takes arguments, we may specify what arguments we are
+expecting, for example:
+
+```cpp
+// Expects the turtle to move forward by 100 units.
+EXPECT_CALL(turtle, Forward(100));
+```
+
+Oftentimes you do not want to be too specific. Remember that talk about tests
+being too rigid? Overspecification leads to brittle tests and obscures the
+intent of tests. Therefore we encourage you to specify only what's
+necessary—no more, no less. If you aren't interested in the value of an
+argument, write `_` as the argument, which means "anything goes":
+// Expects that the turtle jumps to somewhere on the x=50 line.
+EXPECT_CALL(turtle, GoTo(50, _));
+```
+
+`_` is an instance of what we call **matchers**. A matcher is like a predicate
+and can test whether an argument is what we'd expect. You can use a matcher
+inside `EXPECT_CALL()` wherever a function argument is expected. `_` is a
+convenient way of saying "any value".
+
+In the above examples, `100` and `50` are also matchers; implicitly, they are
+the same as `Eq(100)` and `Eq(50)`, which specify that the argument must be
+equal (using `operator==`) to the matcher argument. There are many
+[built-in matchers](reference/matchers.md) for common types (as well as
+[custom matchers](gmock_cook_book.md#NewMatchers)); for example:
+
+```cpp
+using ::testing::Ge;
+...
+// Expects the turtle to move forward by at least 100.
+EXPECT_CALL(turtle, Forward(Ge(100)));
+```
+
+If you don't care about *any* arguments, rather than specifying `_` for each of
+them you may instead omit the parameter list:
+
+```cpp
+// Expects the turtle to move forward.
+EXPECT_CALL(turtle, Forward);
+// Expects the turtle to jump somewhere.
+EXPECT_CALL(turtle, GoTo);
+```
+
+This works for all non-overloaded methods; if a method is overloaded, you need
+to help gMock resolve which overload is expected by specifying the number of
+arguments and possibly also the
+[types of the arguments](gmock_cook_book.md#SelectOverload).
+
+### Cardinalities: How Many Times Will It Be Called?
+
+The first clause we can specify following an `EXPECT_CALL()` is `Times()`. We
+call its argument a **cardinality** as it tells *how many times* the call should
+occur. It allows us to repeat an expectation many times without actually writing
+it as many times. More importantly, a cardinality can be "fuzzy", just like a
+matcher can be. This allows a user to express the intent of a test exactly.
+
+An interesting special case is when we say `Times(0)`. You may have guessed - it
+means that the function shouldn't be called with the given arguments at all, and
+gMock will report a googletest failure whenever the function is (wrongfully)
+called.
+
+We've seen `AtLeast(n)` as an example of fuzzy cardinalities earlier. For the
+list of built-in cardinalities you can use, see
+[here](gmock_cheat_sheet.md#CardinalityList).
+
+The `Times()` clause can be omitted. **If you omit `Times()`, gMock will infer
+the cardinality for you.** The rules are easy to remember:
+
+* If **neither** `WillOnce()` **nor** `WillRepeatedly()` is in the
+  `EXPECT_CALL()`, the inferred cardinality is `Times(1)`.
+* If there are *n* `WillOnce()`'s but **no** `WillRepeatedly()`, where *n* >=
+  1, the cardinality is `Times(n)`.
+* If there are *n* `WillOnce()`'s and **one** `WillRepeatedly()`, where *n* >=
+  0, the cardinality is `Times(AtLeast(n))`.
+
+**Quick quiz:** what do you think will happen if a function is expected to be
+called twice but actually called four times?
+
+### Actions: What Should It Do?
+
+Remember that a mock object doesn't really have a working implementation? We as
+users have to tell it what to do when a method is invoked. This is easy in
+gMock.
+
+First, if the return type of a mock function is a built-in type or a pointer,
+the function has a **default action** (a `void` function will just return, a
+`bool` function will return `false`, and other functions will return 0). In
+addition, in C++11 and above, a mock function whose return type is
+default-constructible (i.e.
has a default constructor) has a default action of
+returning a default-constructed value. If you don't say anything, this behavior
+will be used.
+
+Second, if a mock function doesn't have a default action, or the default action
+doesn't suit you, you can specify the action to be taken each time the
+expectation matches using a series of `WillOnce()` clauses followed by an
+optional `WillRepeatedly()`. For example,
+
+```cpp
+using ::testing::Return;
+...
+EXPECT_CALL(turtle, GetX())
+    .WillOnce(Return(100))
+    .WillOnce(Return(200))
+    .WillOnce(Return(300));
+```
+
+says that `turtle.GetX()` will be called *exactly three times* (gMock inferred
+this from how many `WillOnce()` clauses we've written, since we didn't
+explicitly write `Times()`), and will return 100, 200, and 300 respectively.
+
+```cpp
+using ::testing::Return;
+...
+EXPECT_CALL(turtle, GetY())
+    .WillOnce(Return(100))
+    .WillOnce(Return(200))
+    .WillRepeatedly(Return(300));
+```
+
+says that `turtle.GetY()` will be called *at least twice* (gMock knows this as
+we've written two `WillOnce()` clauses and a `WillRepeatedly()` while having no
+explicit `Times()`), will return 100 and 200 respectively the first two times,
+and 300 from the third time on.
+
+Of course, if you explicitly write a `Times()`, gMock will not try to infer the
+cardinality itself. What if the number you specified is larger than the number
+of `WillOnce()` clauses? Well, after all `WillOnce()`s are used up, gMock will
+do the *default* action for the function every time (unless, of course, you
+have a `WillRepeatedly()`).
+
+What can we do inside `WillOnce()` besides `Return()`? You can return a
+reference using `ReturnRef(`*`variable`*`)`, or invoke a pre-defined function,
+among [others](gmock_cook_book.md#using-actions).
+
+**Important note:** The `EXPECT_CALL()` statement evaluates the action clause
+only once, even though the action may be performed many times. Therefore you
+must be careful about side effects. The following may not do what you want:
+
+```cpp
+using ::testing::Return;
+...
+int n = 100;
+EXPECT_CALL(turtle, GetX())
+    .Times(4)
+    .WillRepeatedly(Return(n++));
+```
+
+Instead of returning 100, 101, 102, ..., consecutively, this mock function will
+always return 100 as `n++` is only evaluated once. Similarly, `Return(new Foo)`
+will create a new `Foo` object when the `EXPECT_CALL()` is executed, and will
+return the same pointer every time. If you want the side effect to happen every
+time, you need to define a custom action, which we'll teach in the
+[cook book](gmock_cook_book.md).
+
+Time for another quiz! What do you think the following means?
+
+```cpp
+using ::testing::Return;
+...
+EXPECT_CALL(turtle, GetY())
+    .Times(4)
+    .WillOnce(Return(100));
+```
+
+Obviously `turtle.GetY()` is expected to be called four times. But if you think
+it will return 100 every time, think twice! Remember that one `WillOnce()`
+clause will be consumed each time the function is invoked and the default action
+will be taken afterwards. So the right answer is that `turtle.GetY()` will
+return 100 the first time, but **return 0 from the second time on**, as
+returning 0 is the default action for `int` functions.
+
+### Using Multiple Expectations {#MultiExpectations}
+
+So far we've only shown examples where you have a single expectation. More
+realistically, you'll specify expectations on multiple mock methods, which may
+be from multiple mock objects.
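+
+For example, a single test might combine several of the expectations we've seen
+so far (a minimal sketch reusing the `MockTurtle` from earlier; the particular
+cardinalities and return values are only for illustration):
+
+```cpp
+using ::testing::AtLeast;
+using ::testing::Return;
+...
+MockTurtle turtle;
+EXPECT_CALL(turtle, PenDown()).Times(AtLeast(1));       // Fuzzy cardinality.
+EXPECT_CALL(turtle, Forward(100));                      // Times(1) is inferred.
+EXPECT_CALL(turtle, GetX()).WillRepeatedly(Return(0));  // Any number of calls.
+EXPECT_CALL(turtle, GetY()).WillRepeatedly(Return(0));
+```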
+ +By default, when a mock method is invoked, gMock will search the expectations in +the **reverse order** they are defined, and stop when an active expectation that +matches the arguments is found (you can think of it as "newer rules override +older ones."). If the matching expectation cannot take any more calls, you will +get an upper-bound-violated failure. Here's an example: + +```cpp +using ::testing::_; +... +EXPECT_CALL(turtle, Forward(_)); // #1 +EXPECT_CALL(turtle, Forward(10)) // #2 + .Times(2); +``` + +If `Forward(10)` is called three times in a row, the third time it will be an +error, as the last matching expectation (#2) has been saturated. If, however, +the third `Forward(10)` call is replaced by `Forward(20)`, then it would be OK, +as now #1 will be the matching expectation. + +{: .callout .note} +**Note:** Why does gMock search for a match in the *reverse* order of the +expectations? The reason is that this allows a user to set up the default +expectations in a mock object's constructor or the test fixture's set-up phase +and then customize the mock by writing more specific expectations in the test +body. So, if you have two expectations on the same method, you want to put the +one with more specific matchers **after** the other, or the more specific rule +would be shadowed by the more general one that comes after it. + +{: .callout .tip} +**Tip:** It is very common to start with a catch-all expectation for a method +and `Times(AnyNumber())` (omitting arguments, or with `_` for all arguments, if +overloaded). This makes any calls to the method expected. This is not necessary +for methods that are not mentioned at all (these are "uninteresting"), but is +useful for methods that have some expectations, but for which other calls are +ok. See +[Understanding Uninteresting vs Unexpected Calls](gmock_cook_book.md#uninteresting-vs-unexpected). + +### Ordered vs Unordered Calls {#OrderedCalls} + +By default, an expectation can match a call even though an earlier expectation +hasn't been satisfied. In other words, the calls don't have to occur in the +order the expectations are specified. + +Sometimes, you may want all the expected calls to occur in a strict order. To +say this in gMock is easy: + +```cpp +using ::testing::InSequence; +... +TEST(FooTest, DrawsLineSegment) { + ... + { + InSequence seq; + + EXPECT_CALL(turtle, PenDown()); + EXPECT_CALL(turtle, Forward(100)); + EXPECT_CALL(turtle, PenUp()); + } + Foo(); +} +``` + +By creating an object of type `InSequence`, all expectations in its scope are +put into a *sequence* and have to occur *sequentially*. Since we are just +relying on the constructor and destructor of this object to do the actual work, +its name is really irrelevant. + +In this example, we test that `Foo()` calls the three expected functions in the +order as written. If a call is made out-of-order, it will be an error. + +(What if you care about the relative order of some of the calls, but not all of +them? Can you specify an arbitrary partial order? The answer is ... yes! The +details can be found [here](gmock_cook_book.md#OrderedCalls).) + +### All Expectations Are Sticky (Unless Said Otherwise) {#StickyExpectations} + +Now let's do a quick quiz to see how well you can use this mock stuff already. +How would you test that the turtle is asked to go to the origin *exactly twice* +(you want to ignore any other instructions it receives)? 
+
+After you've come up with your answer, take a look at ours and compare notes
+(solve it yourself first - don't cheat!):
+
+```cpp
+using ::testing::_;
+using ::testing::AnyNumber;
+...
+EXPECT_CALL(turtle, GoTo(_, _))  // #1
+    .Times(AnyNumber());
+EXPECT_CALL(turtle, GoTo(0, 0))  // #2
+    .Times(2);
+```
+
+Suppose `turtle.GoTo(0, 0)` is called three times. On the third call, gMock will
+see that the arguments match expectation #2 (remember that we always pick the
+last matching expectation). Now, since we said that there should be only two
+such calls, gMock will report an error immediately. This is basically what we've
+told you in the [Using Multiple Expectations](#MultiExpectations) section above.
+
+This example shows that **expectations in gMock are "sticky" by default**, in
+the sense that they remain active even after we have reached their invocation
+upper bounds. This is an important rule to remember, as it affects the meaning
+of the spec, and is **different** from how it's done in many other mocking
+frameworks. (Why'd we do that? Because we think our rule makes the common cases
+easier to express and understand.)
+
+Simple? Let's see if you've really understood it: what does the following code
+say?
+
+```cpp
+using ::testing::Return;
+...
+for (int i = n; i > 0; i--) {
+  EXPECT_CALL(turtle, GetX())
+      .WillOnce(Return(10*i));
+}
+```
+
+If you think it says that `turtle.GetX()` will be called `n` times and will
+return 10, 20, 30, ..., consecutively, think twice! The problem is that, as we
+said, expectations are sticky. So, the second time `turtle.GetX()` is called,
+the last (latest) `EXPECT_CALL()` statement will match, and will immediately
+lead to an "upper bound violated" error - this piece of code is not very useful!
+
+One correct way of saying that `turtle.GetX()` will return 10, 20, 30, ..., is
+to explicitly say that the expectations are *not* sticky. In other words, they
+should *retire* as soon as they are saturated:
+
+```cpp
+using ::testing::Return;
+...
+for (int i = n; i > 0; i--) {
+  EXPECT_CALL(turtle, GetX())
+      .WillOnce(Return(10*i))
+      .RetiresOnSaturation();
+}
+```
+
+And there's a better way to do it: in this case, we expect the calls to occur in
+a specific order, and we line up the actions to match the order. Since the order
+is important here, we should make it explicit using a sequence:
+
+```cpp
+using ::testing::InSequence;
+using ::testing::Return;
+...
+{
+  InSequence s;
+
+  for (int i = 1; i <= n; i++) {
+    EXPECT_CALL(turtle, GetX())
+        .WillOnce(Return(10*i))
+        .RetiresOnSaturation();
+  }
+}
+```
+
+By the way, the other situation where an expectation may *not* be sticky is when
+it's in a sequence - as soon as another expectation that comes after it in the
+sequence has been used, it automatically retires (and will never be used to
+match any call).
+
+### Uninteresting Calls
+
+A mock object may have many methods, and not all of them are that interesting.
+For example, in some tests we may not care about how many times `GetX()` and
+`GetY()` get called.
+
+In gMock, if you are not interested in a method, just don't say anything about
+it. If a call to this method occurs, you'll see a warning in the test output,
+but it won't be a failure. This is called "naggy" behavior; to change it, see
+[The Nice, the Strict, and the Naggy](gmock_cook_book.md#NiceStrictNaggy).
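+
+For example, you can change how a mock treats uninteresting calls by wrapping
+it in `NiceMock` or `StrictMock` (a minimal sketch using the `MockTurtle` from
+earlier):
+
+```cpp
+using ::testing::NiceMock;
+using ::testing::StrictMock;
+...
+NiceMock<MockTurtle> nice_turtle;      // Uninteresting calls: no warnings.
+StrictMock<MockTurtle> strict_turtle;  // Uninteresting calls: test failures.
+```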
diff --git a/test/googletest/docs/index.md b/test/googletest/docs/index.md new file mode 100644 index 0000000..b162c74 --- /dev/null +++ b/test/googletest/docs/index.md @@ -0,0 +1,22 @@ +# GoogleTest User's Guide + +## Welcome to GoogleTest! + +GoogleTest is Google's C++ testing and mocking framework. This user's guide has +the following contents: + +* [GoogleTest Primer](primer.md) - Teaches you how to write simple tests using + GoogleTest. Read this first if you are new to GoogleTest. +* [GoogleTest Advanced](advanced.md) - Read this when you've finished the + Primer and want to utilize GoogleTest to its full potential. +* [GoogleTest Samples](samples.md) - Describes some GoogleTest samples. +* [GoogleTest FAQ](faq.md) - Have a question? Want some tips? Check here + first. +* [Mocking for Dummies](gmock_for_dummies.md) - Teaches you how to create mock + objects and use them in tests. +* [Mocking Cookbook](gmock_cook_book.md) - Includes tips and approaches to + common mocking use cases. +* [Mocking Cheat Sheet](gmock_cheat_sheet.md) - A handy reference for + matchers, actions, invariants, and more. +* [Mocking FAQ](gmock_faq.md) - Contains answers to some mocking-specific + questions. diff --git a/test/googletest/docs/pkgconfig.md b/test/googletest/docs/pkgconfig.md new file mode 100644 index 0000000..bf05d59 --- /dev/null +++ b/test/googletest/docs/pkgconfig.md @@ -0,0 +1,144 @@ +## Using GoogleTest from various build systems + +GoogleTest comes with pkg-config files that can be used to determine all +necessary flags for compiling and linking to GoogleTest (and GoogleMock). +Pkg-config is a standardised plain-text format containing + +* the includedir (-I) path +* necessary macro (-D) definitions +* further required flags (-pthread) +* the library (-L) path +* the library (-l) to link to + +All current build systems support pkg-config in one way or another. For all +examples here we assume you want to compile the sample +`samples/sample3_unittest.cc`. + +### CMake + +Using `pkg-config` in CMake is fairly easy: + +```cmake +find_package(PkgConfig) +pkg_search_module(GTEST REQUIRED gtest_main) + +add_executable(testapp) +target_sources(testapp PRIVATE samples/sample3_unittest.cc) +target_link_libraries(testapp PRIVATE ${GTEST_LDFLAGS}) +target_compile_options(testapp PRIVATE ${GTEST_CFLAGS}) + +enable_testing() +add_test(first_and_only_test testapp) +``` + +It is generally recommended that you use `target_compile_options` + `_CFLAGS` +over `target_include_directories` + `_INCLUDE_DIRS` as the former includes not +just -I flags (GoogleTest might require a macro indicating to internal headers +that all libraries have been compiled with threading enabled. In addition, +GoogleTest might also require `-pthread` in the compiling step, and as such +splitting the pkg-config `Cflags` variable into include dirs and macros for +`target_compile_definitions()` might still miss this). The same recommendation +goes for using `_LDFLAGS` over the more commonplace `_LIBRARIES`, which happens +to discard `-L` flags and `-pthread`. + +### Help! pkg-config can't find GoogleTest! + +Let's say you have a `CMakeLists.txt` along the lines of the one in this +tutorial and you try to run `cmake`. 
It is very possible that you get a failure
+along the lines of:
+
+```
+-- Checking for one of the modules 'gtest_main'
+CMake Error at /usr/share/cmake/Modules/FindPkgConfig.cmake:640 (message):
+  None of the required 'gtest_main' found
+```
+
+These failures are common if you installed GoogleTest yourself and have not
+sourced it from a distro or other package manager. If so, you need to tell
+pkg-config where it can find the `.pc` files containing the information. Say you
+installed GoogleTest to `/usr/local`, then it might be that the `.pc` files are
+installed under `/usr/local/lib64/pkgconfig`. If you set
+
+```
+export PKG_CONFIG_PATH=/usr/local/lib64/pkgconfig
+```
+
+pkg-config will also try to look in `PKG_CONFIG_PATH` to find `gtest_main.pc`.
+
+### Using pkg-config in a cross-compilation setting
+
+Pkg-config can be used in a cross-compilation setting too. To do this, let's
+assume the final prefix of the cross-compiled installation will be `/usr`, and
+your sysroot is `/home/MYUSER/sysroot`. Configure and install GTest using
+
+```
+mkdir build && cd build && cmake -DCMAKE_INSTALL_PREFIX=/usr ..
+```
+
+Install into the sysroot using `DESTDIR`:
+
+```
+make -j install DESTDIR=/home/MYUSER/sysroot
+```
+
+Before we continue, it is recommended to **always** define the following two
+variables for pkg-config in a cross-compilation setting:
+
+```
+export PKG_CONFIG_ALLOW_SYSTEM_CFLAGS=yes
+export PKG_CONFIG_ALLOW_SYSTEM_LIBS=yes
+```
+
+otherwise `pkg-config` will filter `-I` and `-L` flags against standard prefixes
+such as `/usr` (see https://bugs.freedesktop.org/show_bug.cgi?id=28264#c3 for
+reasons why this stripping usually needs to occur).
+
+If you look at the generated pkg-config file, it will look something like
+
+```
+libdir=/usr/lib64
+includedir=/usr/include
+
+Name: gtest
+Description: GoogleTest (without main() function)
+Version: 1.11.0
+URL: https://github.com/google/googletest
+Libs: -L${libdir} -lgtest -lpthread
+Cflags: -I${includedir} -DGTEST_HAS_PTHREAD=1 -lpthread
+```
+
+Notice that the sysroot is not included in `libdir` and `includedir`! If you try
+to run `pkg-config` with the correct
+`PKG_CONFIG_LIBDIR=/home/MYUSER/sysroot/usr/lib64/pkgconfig` against this `.pc`
+file, you will get
+
+```
+$ pkg-config --cflags gtest
+-DGTEST_HAS_PTHREAD=1 -lpthread -I/usr/include
+$ pkg-config --libs gtest
+-L/usr/lib64 -lgtest -lpthread
+```
+
+which is obviously wrong and points to the `CBUILD` and not the `CHOST` root. In
+order to use this in a cross-compilation setting, we need to tell pkg-config to
+inject the actual sysroot into `-I` and `-L` variables. Let us now tell
+pkg-config about the actual sysroot
+
+```
+export PKG_CONFIG_DIR=
+export PKG_CONFIG_SYSROOT_DIR=/home/MYUSER/sysroot
+export PKG_CONFIG_LIBDIR=${PKG_CONFIG_SYSROOT_DIR}/usr/lib64/pkgconfig
+```
+
+and running `pkg-config` again we get
+
+```
+$ pkg-config --cflags gtest
+-DGTEST_HAS_PTHREAD=1 -lpthread -I/home/MYUSER/sysroot/usr/include
+$ pkg-config --libs gtest
+-L/home/MYUSER/sysroot/usr/lib64 -lgtest -lpthread
+```
+
+which contains the correct sysroot now.
For a more comprehensive guide to also
+including `${CHOST}` in build system calls, see the excellent tutorial by Diego
+Elio Pettenò: <https://autotools.io/pkgconfig/cross-compiling.html> diff --git a/test/googletest/docs/platforms.md b/test/googletest/docs/platforms.md
new file mode 100644
index 0000000..d35a7be
--- /dev/null
+++ b/test/googletest/docs/platforms.md
@@ -0,0 +1,8 @@
+# Supported Platforms
+
+GoogleTest follows Google's
+[Foundational C++ Support Policy](https://opensource.google/documentation/policies/cplusplus-support).
+See
+[this table](https://github.com/google/oss-policies-info/blob/main/foundational-cxx-support-matrix.md)
+for a list of currently supported versions of compilers, platforms, and build
+tools. diff --git a/test/googletest/docs/primer.md b/test/googletest/docs/primer.md
new file mode 100644
index 0000000..61806be
--- /dev/null
+++ b/test/googletest/docs/primer.md
@@ -0,0 +1,482 @@
+# GoogleTest Primer
+
+## Introduction: Why GoogleTest?
+
+*GoogleTest* helps you write better C++ tests.
+
+GoogleTest is a testing framework developed by the Testing Technology team with
+Google's specific requirements and constraints in mind. Whether you work on
+Linux, Windows, or a Mac, if you write C++ code, GoogleTest can help you. And it
+supports *any* kind of tests, not just unit tests.
+
+So what makes a good test, and how does GoogleTest fit in? We believe:
+
+1. Tests should be *independent* and *repeatable*. It's a pain to debug a test
+   that succeeds or fails as a result of other tests. GoogleTest isolates the
+   tests by running each of them on a different object. When a test fails,
+   GoogleTest allows you to run it in isolation for quick debugging.
+2. Tests should be well *organized* and reflect the structure of the tested
+   code. GoogleTest groups related tests into test suites that can share data
+   and subroutines. This common pattern is easy to recognize and makes tests
+   easy to maintain. Such consistency is especially helpful when people switch
+   projects and start to work on a new code base.
+3. Tests should be *portable* and *reusable*. Google has a lot of code that is
+   platform-neutral; its tests should also be platform-neutral. GoogleTest
+   works on different OSes, with different compilers, with or without
+   exceptions, so GoogleTest tests can work with a variety of configurations.
+4. When tests fail, they should provide as much *information* about the problem
+   as possible. GoogleTest doesn't stop at the first test failure. Instead, it
+   only stops the current test and continues with the next. You can also set up
+   tests that report non-fatal failures after which the current test continues.
+   Thus, you can detect and fix multiple bugs in a single run-edit-compile
+   cycle.
+5. The testing framework should liberate test writers from housekeeping chores
+   and let them focus on the test *content*. GoogleTest automatically keeps
+   track of all tests defined, and doesn't require the user to enumerate them
+   in order to run them.
+6. Tests should be *fast*. With GoogleTest, you can reuse shared resources
+   across tests and pay for the set-up/tear-down only once, without making
+   tests depend on each other.
+
+Since GoogleTest is based on the popular xUnit architecture, you'll feel right
+at home if you've used JUnit or PyUnit before. If not, it will take you about 10
+minutes to learn the basics and get started. So let's go!
+
+## Beware of the Nomenclature
+
+{: .callout .note}
+*Note:* There might be some confusion arising from different definitions of the
+terms *Test*, *Test Case* and *Test Suite*, so beware of misunderstanding these.
+
+Historically, GoogleTest started to use the term *Test Case* for grouping
+related tests, whereas current publications, including International Software
+Testing Qualifications Board ([ISTQB](https://www.istqb.org/)) materials and
+various textbooks on software quality, use the term
+*[Test Suite][istqb test suite]* for this.
+
+The related term *Test*, as it is used in GoogleTest, corresponds to the term
+*[Test Case][istqb test case]* of ISTQB and others.
+
+The term *Test* is broad enough to cover ISTQB's definition of *Test Case*, so
+it's not much of a problem here. But the term *Test Case* as it was used in
+Google Test contradicts the ISTQB sense and is thus confusing.
+
+GoogleTest recently started replacing the term *Test Case* with *Test Suite*.
+The preferred API is *TestSuite*. The older TestCase API is being slowly
+deprecated and refactored away.
+
+So please be aware of the different definitions of the terms:
+
+Meaning | GoogleTest Term | [ISTQB](https://www.istqb.org/) Term
+:----------------------------------------------------------------------------------- | :---------------------- | :----------------------------------
+Exercise a particular program path with specific input values and verify the results | [TEST()](#simple-tests) | [Test Case][istqb test case]
+
+[istqb test case]: https://glossary.istqb.org/en_US/term/test-case-2
+[istqb test suite]: https://glossary.istqb.org/en_US/term/test-suite-1-3
+
+## Basic Concepts
+
+When using GoogleTest, you start by writing *assertions*, which are statements
+that check whether a condition is true. An assertion's result can be *success*,
+*nonfatal failure*, or *fatal failure*. If a fatal failure occurs, it aborts the
+current function; otherwise the program continues normally.
+
+*Tests* use assertions to verify the tested code's behavior. If a test crashes
+or has a failed assertion, then it *fails*; otherwise it *succeeds*.
+
+A *test suite* contains one or many tests. You should group your tests into test
+suites that reflect the structure of the tested code. When multiple tests in a
+test suite need to share common objects and subroutines, you can put them into a
+*test fixture* class.
+
+A *test program* can contain multiple test suites.
+
+We'll now explain how to write a test program, starting at the individual
+assertion level and building up to tests and test suites.
+
+## Assertions
+
+GoogleTest assertions are macros that resemble function calls. You test a class
+or function by making assertions about its behavior. When an assertion fails,
+GoogleTest prints the assertion's source file and line number location, along
+with a failure message. You may also supply a custom failure message which will
+be appended to GoogleTest's message.
+
+The assertions come in pairs that test the same thing but have different effects
+on the current function. `ASSERT_*` versions generate fatal failures when they
+fail, and **abort the current function**. `EXPECT_*` versions generate nonfatal
+failures, which don't abort the current function. Usually `EXPECT_*` are
+preferred, as they allow more than one failure to be reported in a test.
+However, you should use `ASSERT_*` if it doesn't make sense to continue when the
+assertion in question fails.
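+
+For example, in the following sketch (the failing values are deliberate, and
+`<vector>` is assumed to be included), both `EXPECT_EQ` failures are reported,
+while the failing `ASSERT_EQ` aborts the test so the last check never runs:
+
+```c++
+TEST(AssertionDemoTest, ExpectContinuesAssertAborts) {
+  std::vector<int> v = {1, 2, 3};
+  EXPECT_EQ(v[0], 2);       // Fails (v[0] is 1), but the test keeps running...
+  EXPECT_EQ(v[1], 5);       // ...so this failure (v[1] is 2) is reported too.
+  ASSERT_EQ(v.size(), 4u);  // Fails fatally (size is 3): the test aborts here.
+  EXPECT_EQ(v[2], 3);       // Never reached.
+}
+```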
+ +Since a failed `ASSERT_*` returns from the current function immediately, +possibly skipping clean-up code that comes after it, it may cause a space leak. +Depending on the nature of the leak, it may or may not be worth fixing - so keep +this in mind if you get a heap checker error in addition to assertion errors. + +To provide a custom failure message, simply stream it into the macro using the +`<<` operator or a sequence of such operators. See the following example, using +the [`ASSERT_EQ` and `EXPECT_EQ`](reference/assertions.md#EXPECT_EQ) macros to +verify value equality: + +```c++ +ASSERT_EQ(x.size(), y.size()) << "Vectors x and y are of unequal length"; + +for (int i = 0; i < x.size(); ++i) { + EXPECT_EQ(x[i], y[i]) << "Vectors x and y differ at index " << i; +} +``` + +Anything that can be streamed to an `ostream` can be streamed to an assertion +macro--in particular, C strings and `string` objects. If a wide string +(`wchar_t*`, `TCHAR*` in `UNICODE` mode on Windows, or `std::wstring`) is +streamed to an assertion, it will be translated to UTF-8 when printed. + +GoogleTest provides a collection of assertions for verifying the behavior of +your code in various ways. You can check Boolean conditions, compare values +based on relational operators, verify string values, floating-point values, and +much more. There are even assertions that enable you to verify more complex +states by providing custom predicates. For the complete list of assertions +provided by GoogleTest, see the [Assertions Reference](reference/assertions.md). + +## Simple Tests + +To create a test: + +1. Use the `TEST()` macro to define and name a test function. These are + ordinary C++ functions that don't return a value. +2. In this function, along with any valid C++ statements you want to include, + use the various GoogleTest assertions to check values. +3. The test's result is determined by the assertions; if any assertion in the + test fails (either fatally or non-fatally), or if the test crashes, the + entire test fails. Otherwise, it succeeds. + +```c++ +TEST(TestSuiteName, TestName) { + ... test body ... +} +``` + +`TEST()` arguments go from general to specific. The *first* argument is the name +of the test suite, and the *second* argument is the test's name within the test +suite. Both names must be valid C++ identifiers, and they should not contain any +underscores (`_`). A test's *full name* consists of its containing test suite +and its individual name. Tests from different test suites can have the same +individual name. + +For example, let's take a simple integer function: + +```c++ +int Factorial(int n); // Returns the factorial of n +``` + +A test suite for this function might look like: + +```c++ +// Tests factorial of 0. +TEST(FactorialTest, HandlesZeroInput) { + EXPECT_EQ(Factorial(0), 1); +} + +// Tests factorial of positive numbers. +TEST(FactorialTest, HandlesPositiveInput) { + EXPECT_EQ(Factorial(1), 1); + EXPECT_EQ(Factorial(2), 2); + EXPECT_EQ(Factorial(3), 6); + EXPECT_EQ(Factorial(8), 40320); +} +``` + +GoogleTest groups the test results by test suites, so logically related tests +should be in the same test suite; in other words, the first argument to their +`TEST()` should be the same. In the above example, we have two tests, +`HandlesZeroInput` and `HandlesPositiveInput`, that belong to the same test +suite `FactorialTest`. 
+
+When naming your test suites and tests, you should follow the same convention as
+for
+[naming functions and classes](https://google.github.io/styleguide/cppguide.html#Function_Names).
+
+**Availability**: Linux, Windows, Mac.
+
+## Test Fixtures: Using the Same Data Configuration for Multiple Tests {#same-data-multiple-tests}
+
+If you find yourself writing two or more tests that operate on similar data, you
+can use a *test fixture*. This allows you to reuse the same configuration of
+objects for several different tests.
+
+To create a fixture:
+
+1. Derive a class from `testing::Test`. Start its body with `protected:`, as
+   we'll want to access fixture members from sub-classes.
+2. Inside the class, declare any objects you plan to use.
+3. If necessary, write a default constructor or `SetUp()` function to prepare
+   the objects for each test. A common mistake is to spell `SetUp()` as
+   **`Setup()`** with a small `u` - use `override` in C++11 to make sure you
+   spelled it correctly.
+4. If necessary, write a destructor or `TearDown()` function to release any
+   resources you allocated in `SetUp()`. To learn when you should use the
+   constructor/destructor and when you should use `SetUp()/TearDown()`, read
+   the [FAQ](faq.md#CtorVsSetUp).
+5. If needed, define subroutines for your tests to share.
+
+When using a fixture, use `TEST_F()` instead of `TEST()` as it allows you to
+access objects and subroutines in the test fixture:
+
+```c++
+TEST_F(TestFixtureClassName, TestName) {
+  ... test body ...
+}
+```
+
+Unlike `TEST()`, in `TEST_F()` the first argument must be the name of the test
+fixture class. (`_F` stands for "Fixture"). No test suite name is specified for
+this macro.
+
+Unfortunately, the C++ macro system does not allow us to create a single macro
+that can handle both types of tests. Using the wrong macro causes a compiler
+error.
+
+Also, you must first define a test fixture class before using it in a
+`TEST_F()`, or you'll get the compiler error "`virtual outside class
+declaration`".
+
+For each test defined with `TEST_F()`, GoogleTest will create a *fresh* test
+fixture at runtime, immediately initialize it via `SetUp()`, run the test, clean
+up by calling `TearDown()`, and then delete the test fixture. Note that
+different tests in the same test suite have different test fixture objects, and
+GoogleTest always deletes a test fixture before it creates the next one.
+GoogleTest does **not** reuse the same test fixture for multiple tests. Any
+changes one test makes to the fixture do not affect other tests.
+
+As an example, let's write tests for a FIFO queue class named `Queue`, which has
+the following interface:
+
+```c++
+template <typename E>  // E is the element type.
+class Queue {
+ public:
+  Queue();
+  void Enqueue(const E& element);
+  E* Dequeue();  // Returns NULL if the queue is empty.
+  size_t size() const;
+  ...
+};
+```
+
+First, define a fixture class. By convention, you should give it the name
+`FooTest` where `Foo` is the class being tested.
+
+```c++
+class QueueTest : public testing::Test {
+ protected:
+  QueueTest() {
+    // q0_ remains empty
+    q1_.Enqueue(1);
+    q2_.Enqueue(2);
+    q2_.Enqueue(3);
+  }
+
+  // ~QueueTest() override = default;
+
+  Queue<int> q0_;
+  Queue<int> q1_;
+  Queue<int> q2_;
+};
+```
+
+In this case, we don't need to define a destructor or a `TearDown()` method,
+because the implicit destructor generated by the compiler will perform all of
+the necessary cleanup.
+
+Now we'll write tests using `TEST_F()` and this fixture.
+
+```c++
+TEST_F(QueueTest, IsEmptyInitially) {
+  EXPECT_EQ(q0_.size(), 0);
+}
+
+TEST_F(QueueTest, DequeueWorks) {
+  int* n = q0_.Dequeue();
+  EXPECT_EQ(n, nullptr);
+
+  n = q1_.Dequeue();
+  ASSERT_NE(n, nullptr);
+  EXPECT_EQ(*n, 1);
+  EXPECT_EQ(q1_.size(), 0);
+  delete n;
+
+  n = q2_.Dequeue();
+  ASSERT_NE(n, nullptr);
+  EXPECT_EQ(*n, 2);
+  EXPECT_EQ(q2_.size(), 1);
+  delete n;
+}
+```
+
+The above uses both `ASSERT_*` and `EXPECT_*` assertions. The rule of thumb is
+to use `EXPECT_*` when you want the test to continue to reveal more errors after
+the assertion failure, and use `ASSERT_*` when continuing after failure doesn't
+make sense. For example, the second assertion in the `Dequeue` test is
+`ASSERT_NE(n, nullptr)`, as we need to dereference the pointer `n` later, which
+would lead to a segfault when `n` is `NULL`.
+
+When these tests run, the following happens:
+
+1. GoogleTest constructs a `QueueTest` object (let's call it `t1`).
+2. The first test (`IsEmptyInitially`) runs on `t1`.
+3. `t1` is destructed.
+4. The above steps are repeated on another `QueueTest` object, this time
+   running the `DequeueWorks` test.
+
+**Availability**: Linux, Windows, Mac.
+
+## Invoking the Tests
+
+`TEST()` and `TEST_F()` implicitly register their tests with GoogleTest. So,
+unlike with many other C++ testing frameworks, you don't have to re-list all
+your defined tests in order to run them.
+
+After defining your tests, you can run them with `RUN_ALL_TESTS()`, which
+returns `0` if all the tests are successful, or `1` otherwise. Note that
+`RUN_ALL_TESTS()` runs *all tests* in your link unit--they can be from different
+test suites, or even different source files.
+
+When invoked, the `RUN_ALL_TESTS()` macro:
+
+* Saves the state of all GoogleTest flags.
+
+* Creates a test fixture object for the first test.
+
+* Initializes it via `SetUp()`.
+
+* Runs the test on the fixture object.
+
+* Cleans up the fixture via `TearDown()`.
+
+* Deletes the fixture.
+
+* Restores the state of all GoogleTest flags.
+
+* Repeats the above steps for the next test, until all tests have run.
+
+If a fatal failure happens, the subsequent steps will be skipped.
+
+{: .callout .important}
+> IMPORTANT: You must **not** ignore the return value of `RUN_ALL_TESTS()`, or
+> you will get a compiler error. The rationale for this design is that the
+> automated testing service determines whether a test has passed based on its
+> exit code, not on its stdout/stderr output; thus your `main()` function must
+> return the value of `RUN_ALL_TESTS()`.
+>
+> Also, you should call `RUN_ALL_TESTS()` only **once**. Calling it more than
+> once conflicts with some advanced GoogleTest features (e.g., thread-safe
+> [death tests](advanced.md#death-tests)) and thus is not supported.
+
+**Availability**: Linux, Windows, Mac.
+
+## Writing the main() Function
+
+Most users should *not* need to write their own `main` function and instead link
+with `gtest_main` (as opposed to with `gtest`), which defines a suitable entry
+point. See the end of this section for details. The remainder of this section
+should only apply when you need to do something custom before the tests run that
+cannot be expressed within the framework of fixtures and test suites.
+
+If you write your own `main` function, it should return the value of
+`RUN_ALL_TESTS()`.
+
+You can start from this boilerplate:
+
+```c++
+#include "this/package/foo.h"
+
+#include <gtest/gtest.h>
+
+namespace my {
+namespace project {
+namespace {
+
+// The fixture for testing class Foo.
+class FooTest : public testing::Test { + protected: + // You can remove any or all of the following functions if their bodies would + // be empty. + + FooTest() { + // You can do set-up work for each test here. + } + + ~FooTest() override { + // You can do clean-up work that doesn't throw exceptions here. + } + + // If the constructor and destructor are not enough for setting up + // and cleaning up each test, you can define the following methods: + + void SetUp() override { + // Code here will be called immediately after the constructor (right + // before each test). + } + + void TearDown() override { + // Code here will be called immediately after each test (right + // before the destructor). + } + + // Class members declared here can be used by all tests in the test suite + // for Foo. +}; + +// Tests that the Foo::Bar() method does Abc. +TEST_F(FooTest, MethodBarDoesAbc) { + const std::string input_filepath = "this/package/testdata/myinputfile.dat"; + const std::string output_filepath = "this/package/testdata/myoutputfile.dat"; + Foo f; + EXPECT_EQ(f.Bar(input_filepath, output_filepath), 0); +} + +// Tests that Foo does Xyz. +TEST_F(FooTest, DoesXyz) { + // Exercises the Xyz feature of Foo. +} + +} // namespace +} // namespace project +} // namespace my + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} +``` + +The `testing::InitGoogleTest()` function parses the command line for GoogleTest +flags, and removes all recognized flags. This allows the user to control a test +program's behavior via various flags, which we'll cover in the +[AdvancedGuide](advanced.md). You **must** call this function before calling +`RUN_ALL_TESTS()`, or the flags won't be properly initialized. + +On Windows, `InitGoogleTest()` also works with wide strings, so it can be used +in programs compiled in `UNICODE` mode as well. + +But maybe you think that writing all those `main` functions is too much work? We +agree with you completely, and that's why Google Test provides a basic +implementation of main(). If it fits your needs, then just link your test with +the `gtest_main` library and you are good to go. + +{: .callout .note} +NOTE: `ParseGUnitFlags()` is deprecated in favor of `InitGoogleTest()`. + +## Known Limitations + +* Google Test is designed to be thread-safe. The implementation is thread-safe + on systems where the `pthreads` library is available. It is currently + *unsafe* to use Google Test assertions from two threads concurrently on + other systems (e.g. Windows). In most tests this is not an issue as usually + the assertions are done in the main thread. If you want to help, you can + volunteer to implement the necessary synchronization primitives in + `gtest-port.h` for your platform. diff --git a/test/googletest/docs/quickstart-bazel.md b/test/googletest/docs/quickstart-bazel.md new file mode 100644 index 0000000..5750f02 --- /dev/null +++ b/test/googletest/docs/quickstart-bazel.md @@ -0,0 +1,146 @@ +# Quickstart: Building with Bazel + +This tutorial aims to get you up and running with GoogleTest using the Bazel +build system. If you're using GoogleTest for the first time or need a refresher, +we recommend this tutorial as a starting point. + +## Prerequisites + +To complete this tutorial, you'll need: + +* A compatible operating system (e.g. Linux, macOS, Windows). +* A compatible C++ compiler that supports at least C++14. +* [Bazel](https://bazel.build/) 7.0 or higher, the preferred build system used + by the GoogleTest team. 
+
+See [Supported Platforms](platforms.md) for more information about platforms
+compatible with GoogleTest.
+
+If you don't already have Bazel installed, see the
+[Bazel installation guide](https://bazel.build/install).
+
+{: .callout .note} Note: The terminal commands in this tutorial show a Unix
+shell prompt, but the commands work on the Windows command line as well.
+
+## Set up a Bazel workspace
+
+A
+[Bazel workspace](https://docs.bazel.build/versions/main/build-ref.html#workspace)
+is a directory on your filesystem that you use to manage source files for the
+software you want to build. Each workspace directory has a text file named
+`MODULE.bazel` which may be empty, or may contain references to external
+dependencies required to build the outputs.
+
+First, create a directory for your workspace:
+
+```
+$ mkdir my_workspace && cd my_workspace
+```
+
+Next, you'll create the `MODULE.bazel` file to specify dependencies. As of Bazel
+7.0, the recommended way to consume GoogleTest is through the
+[Bazel Central Registry](https://registry.bazel.build/modules/googletest). To do
+this, create a `MODULE.bazel` file in the root directory of your Bazel workspace
+with the following content:
+
+```
+# MODULE.bazel
+
+# Choose the most recent version available at
+# https://registry.bazel.build/modules/googletest
+bazel_dep(name = "googletest", version = "1.15.2")
+```
+
+Now you're ready to build C++ code that uses GoogleTest.
+
+## Create and run a binary
+
+With your Bazel workspace set up, you can now use GoogleTest code within your
+own project.
+
+As an example, create a file named `hello_test.cc` in your `my_workspace`
+directory with the following contents:
+
+```cpp
+#include <gtest/gtest.h>
+
+// Demonstrate some basic assertions.
+TEST(HelloTest, BasicAssertions) {
+  // Expect two strings not to be equal.
+  EXPECT_STRNE("hello", "world");
+  // Expect equality.
+  EXPECT_EQ(7 * 6, 42);
+}
+```
+
+GoogleTest provides [assertions](primer.md#assertions) that you use to test the
+behavior of your code. The above sample includes the main GoogleTest header file
+and demonstrates some basic assertions.
+
+To build the code, create a file named `BUILD` in the same directory with the
+following contents:
+
+```
+cc_test(
+    name = "hello_test",
+    size = "small",
+    srcs = ["hello_test.cc"],
+    deps = [
+        "@googletest//:gtest",
+        "@googletest//:gtest_main",
+    ],
+)
+```
+
+This `cc_test` rule declares the C++ test binary you want to build, and links to
+the GoogleTest library (`@googletest//:gtest`) and the GoogleTest `main()`
+function (`@googletest//:gtest_main`). For more information about Bazel `BUILD`
+files, see the
+[Bazel C++ Tutorial](https://docs.bazel.build/versions/main/tutorial/cpp.html).
+
+{: .callout .note}
+NOTE: In the example below, we assume Clang or GCC and set `--cxxopt=-std=c++14`
+to ensure that GoogleTest is compiled as C++14 instead of the compiler's default
+setting (which could be C++11). For MSVC, the equivalent would be
+`--cxxopt=/std:c++14`. See [Supported Platforms](platforms.md) for more details
+on supported language versions.
+
+Now you can build and run your test:
+
+<pre>
+$ bazel test --cxxopt=-std=c++14 --test_output=all //:hello_test
+INFO: Analyzed target //:hello_test (26 packages loaded, 362 targets configured).
+INFO: Found 1 test target...
+INFO: From Testing //:hello_test:
+==================== Test output for //:hello_test:
+Running main() from gmock_main.cc
+[==========] Running 1 test from 1 test suite.
+[----------] Global test environment set-up.
+[----------] 1 test from HelloTest
+[ RUN      ] HelloTest.BasicAssertions
+[       OK ] HelloTest.BasicAssertions (0 ms)
+[----------] 1 test from HelloTest (0 ms total)
+
+[----------] Global test environment tear-down
+[==========] 1 test from 1 test suite ran. (0 ms total)
+[  PASSED  ] 1 test.
+================================================================================
+Target //:hello_test up-to-date:
+  bazel-bin/hello_test
+INFO: Elapsed time: 4.190s, Critical Path: 3.05s
+INFO: 27 processes: 8 internal, 19 linux-sandbox.
+INFO: Build completed successfully, 27 total actions
+//:hello_test                                                     PASSED in 0.1s
+
+INFO: Build completed successfully, 27 total actions
+</pre>
+ +Congratulations! You've successfully built and run a test binary using +GoogleTest. + +## Next steps + +* [Check out the Primer](primer.md) to start learning how to write simple + tests. +* [See the code samples](samples.md) for more examples showing how to use a + variety of GoogleTest features. diff --git a/test/googletest/docs/quickstart-cmake.md b/test/googletest/docs/quickstart-cmake.md new file mode 100644 index 0000000..4e422b7 --- /dev/null +++ b/test/googletest/docs/quickstart-cmake.md @@ -0,0 +1,157 @@ +# Quickstart: Building with CMake + +This tutorial aims to get you up and running with GoogleTest using CMake. If +you're using GoogleTest for the first time or need a refresher, we recommend +this tutorial as a starting point. If your project uses Bazel, see the +[Quickstart for Bazel](quickstart-bazel.md) instead. + +## Prerequisites + +To complete this tutorial, you'll need: + +* A compatible operating system (e.g. Linux, macOS, Windows). +* A compatible C++ compiler that supports at least C++14. +* [CMake](https://cmake.org/) and a compatible build tool for building the + project. + * Compatible build tools include + [Make](https://www.gnu.org/software/make/), + [Ninja](https://ninja-build.org/), and others - see + [CMake Generators](https://cmake.org/cmake/help/latest/manual/cmake-generators.7.html) + for more information. + +See [Supported Platforms](platforms.md) for more information about platforms +compatible with GoogleTest. + +If you don't already have CMake installed, see the +[CMake installation guide](https://cmake.org/install). + +{: .callout .note} +Note: The terminal commands in this tutorial show a Unix shell prompt, but the +commands work on the Windows command line as well. + +## Set up a project + +CMake uses a file named `CMakeLists.txt` to configure the build system for a +project. You'll use this file to set up your project and declare a dependency on +GoogleTest. + +First, create a directory for your project: + +``` +$ mkdir my_project && cd my_project +``` + +Next, you'll create the `CMakeLists.txt` file and declare a dependency on +GoogleTest. There are many ways to express dependencies in the CMake ecosystem; +in this quickstart, you'll use the +[`FetchContent` CMake module](https://cmake.org/cmake/help/latest/module/FetchContent.html). +To do this, in your project directory (`my_project`), create a file named +`CMakeLists.txt` with the following contents: + +```cmake +cmake_minimum_required(VERSION 3.14) +project(my_project) + +# GoogleTest requires at least C++14 +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +# For Windows: Prevent overriding the parent project's compiler/linker settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) +``` + +The above configuration declares a dependency on GoogleTest which is downloaded +from GitHub. In the above example, `03597a01ee50ed33e9dfd640b249b4be3799d395` is +the Git commit hash of the GoogleTest version to use; we recommend updating the +hash often to point to the latest version. + +For more information about how to create `CMakeLists.txt` files, see the +[CMake Tutorial](https://cmake.org/cmake/help/latest/guide/tutorial/index.html). + +## Create and run a binary + +With GoogleTest declared as a dependency, you can use GoogleTest code within +your own project. 
+
+As an example, create a file named `hello_test.cc` in your `my_project`
+directory with the following contents:
+
+```cpp
+#include <gtest/gtest.h>
+
+// Demonstrate some basic assertions.
+TEST(HelloTest, BasicAssertions) {
+  // Expect two strings not to be equal.
+  EXPECT_STRNE("hello", "world");
+  // Expect equality.
+  EXPECT_EQ(7 * 6, 42);
+}
+```
+
+GoogleTest provides [assertions](primer.md#assertions) that you use to test the
+behavior of your code. The above sample includes the main GoogleTest header file
+and demonstrates some basic assertions.
+
+To build the code, add the following to the end of your `CMakeLists.txt` file:
+
+```cmake
+enable_testing()
+
+add_executable(
+  hello_test
+  hello_test.cc
+)
+target_link_libraries(
+  hello_test
+  GTest::gtest_main
+)
+
+include(GoogleTest)
+gtest_discover_tests(hello_test)
+```
+
+The above configuration enables testing in CMake, declares the C++ test binary
+you want to build (`hello_test`), and links it to GoogleTest (`gtest_main`). The
+last two lines enable CMake's test runner to discover the tests included in the
+binary, using the
+[`GoogleTest` CMake module](https://cmake.org/cmake/help/git-stage/module/GoogleTest.html).
+
+Now you can build and run your test:
+
+<pre>
+my_project$ cmake -S . -B build
+-- The C compiler identification is GNU 10.2.1
+-- The CXX compiler identification is GNU 10.2.1
+...
+-- Build files have been written to: .../my_project/build
+
+my_project$ cmake --build build
+Scanning dependencies of target gtest
+...
+[100%] Built target gmock_main
+
+my_project$ cd build && ctest
+Test project .../my_project/build
+    Start 1: HelloTest.BasicAssertions
+1/1 Test #1: HelloTest.BasicAssertions ........   Passed    0.00 sec
+
+100% tests passed, 0 tests failed out of 1
+
+Total Test time (real) =   0.01 sec
+</pre>
+
+Congratulations! You've successfully built and run a test binary using
+GoogleTest.
+
+## Next steps
+
+* [Check out the Primer](primer.md) to start learning how to write simple
+  tests.
+* [See the code samples](samples.md) for more examples showing how to use a
+  variety of GoogleTest features. diff --git a/test/googletest/docs/reference/actions.md b/test/googletest/docs/reference/actions.md
new file mode 100644
index 0000000..ab81a12
--- /dev/null
+++ b/test/googletest/docs/reference/actions.md
@@ -0,0 +1,115 @@
+# Actions Reference
+
+[**Actions**](../gmock_for_dummies.md#actions-what-should-it-do) specify what a
+mock function should do when invoked. This page lists the built-in actions
+provided by GoogleTest. All actions are defined in the `::testing` namespace.
+
+## Returning a Value
+
+| Action | Description |
+| :-------------------------------- | :-------------------------------------------- |
+| `Return()` | Return from a `void` mock function. |
+| `Return(value)` | Return `value`. If the type of `value` is different from the mock function's return type, `value` is converted to the latter type at the time the expectation is set, not when the action is executed. |
+| `ReturnArg<N>()` | Return the `N`-th (0-based) argument. |
+| `ReturnNew<T>(a1, ..., ak)` | Return `new T(a1, ..., ak)`; a different object is created each time. |
+| `ReturnNull()` | Return a null pointer. |
+| `ReturnPointee(ptr)` | Return the value pointed to by `ptr`. |
+| `ReturnRef(variable)` | Return a reference to `variable`. |
+| `ReturnRefOfCopy(value)` | Return a reference to a copy of `value`; the copy lives as long as the action. |
+| `ReturnRoundRobin({a1, ..., ak})` | Each call will return the next `ai` in the list, starting at the beginning when the end of the list is reached. |
+
+## Side Effects
+
+| Action | Description |
+| :--------------------------------- | :-------------------------------------- |
+| `Assign(&variable, value)` | Assign `value` to variable. |
+| `DeleteArg<N>()` | Delete the `N`-th (0-based) argument, which must be a pointer. |
+| `SaveArg<N>(pointer)` | Save the `N`-th (0-based) argument to `*pointer`. |
+| `SaveArgPointee<N>(pointer)` | Save the value pointed to by the `N`-th (0-based) argument to `*pointer`. |
+| `SetArgReferee<N>(value)` | Assign `value` to the variable referenced by the `N`-th (0-based) argument. |
+| `SetArgPointee<N>(value)` | Assign `value` to the variable pointed by the `N`-th (0-based) argument. |
+| `SetArgumentPointee<N>(value)` | Same as `SetArgPointee<N>(value)`. Deprecated. Will be removed in v1.7.0. |
+| `SetArrayArgument<N>(first, last)` | Copies the elements in source range [`first`, `last`) to the array pointed to by the `N`-th (0-based) argument, which can be either a pointer or an iterator. The action does not take ownership of the elements in the source range. |
+| `SetErrnoAndReturn(error, value)` | Set `errno` to `error` and return `value`. |
+| `Throw(exception)` | Throws the given exception, which can be any copyable value. Available since v1.1.0. |
+
+## Using a Function, Functor, or Lambda as an Action
+
+In the following, by "callable" we mean a free function, `std::function`,
+functor, or lambda.
+
+| Action | Description |
+| :---------------------------------- | :------------------------------------- |
+| `f` | Invoke `f` with the arguments passed to the mock function, where `f` is a callable. |
+| `Invoke(f)` | Invoke `f` with the arguments passed to the mock function, where `f` can be a global/static function or a functor. |
+| `Invoke(object_pointer, &class::method)` | Invoke the method on the object with the arguments passed to the mock function. |
+| `InvokeWithoutArgs(f)` | Invoke `f`, which can be a global/static function or a functor. `f` must take no arguments. |
+| `InvokeWithoutArgs(object_pointer, &class::method)` | Invoke the method on the object, which takes no arguments. |
+| `InvokeArgument<N>(arg1, arg2, ..., argk)` | Invoke the mock function's `N`-th (0-based) argument, which must be a function or a functor, with the `k` arguments. |
+
+The return value of the invoked function is used as the return value of the
+action.
+
+When defining a callable to be used with `Invoke*()`, you can declare any unused
+parameters as `Unused`:
+
+```cpp
+using ::testing::Invoke;
+double Distance(Unused, double x, double y) { return sqrt(x*x + y*y); }
+...
+EXPECT_CALL(mock, Foo("Hi", _, _)).WillOnce(Invoke(Distance));
+```
+
+`Invoke(callback)` and `InvokeWithoutArgs(callback)` take ownership of
+`callback`, which must be permanent. The type of `callback` must be a base
+callback type instead of a derived one, e.g.
+
+```cpp
+  BlockingClosure* done = new BlockingClosure;
+  ... Invoke(done) ...;  // This won't compile!
+
+  Closure* done2 = new BlockingClosure;
+  ... Invoke(done2) ...;  // This works.
+```
+
+In `InvokeArgument<N>(...)`, if an argument needs to be passed by reference,
+wrap it inside `std::ref()`. For example,
+
+```cpp
+using ::testing::InvokeArgument;
+...
+InvokeArgument<2>(5, string("Hi"), std::ref(foo))
+```
+
+calls the mock function's #2 argument, passing to it `5` and `string("Hi")` by
+value, and `foo` by reference.
+
+## Default Action
+
+| Action | Description |
+| :------------ | :----------------------------------------------------- |
+| `DoDefault()` | Do the default action (specified by `ON_CALL()` or the built-in one). |
+
+{: .callout .note}
+**Note:** due to technical reasons, `DoDefault()` cannot be used inside a
+composite action - trying to do so will result in a run-time error.
+
+## Composite Actions
+
+| Action | Description |
+| :----------------------------- | :------------------------------------------ |
+| `DoAll(a1, a2, ..., an)` | Do all actions `a1` to `an` and return the result of `an` in each invocation. The first `n - 1` sub-actions must return void and will receive a readonly view of the arguments. |
+| `IgnoreResult(a)` | Perform action `a` and ignore its result. `a` must not return void. |
+| `WithArg<N>(a)` | Pass the `N`-th (0-based) argument of the mock function to action `a` and perform it. |
+| `WithArgs<N1, N2, ..., Nk>(a)` | Pass the selected (0-based) arguments of the mock function to action `a` and perform it. |
+| `WithoutArgs(a)` | Perform action `a` without any arguments. |
+
+## Defining Actions
+
+| Macro | Description |
+| :--------------------------------- | :-------------------------------------- |
+| `ACTION(Sum) { return arg0 + arg1; }` | Defines an action `Sum()` to return the sum of the mock function's argument #0 and #1. |
+| `ACTION_P(Plus, n) { return arg0 + n; }` | Defines an action `Plus(n)` to return the sum of the mock function's argument #0 and `n`. |
+| `ACTION_Pk(Foo, p1, ..., pk) { statements; }` | Defines a parameterized action `Foo(p1, ..., pk)` to execute the given `statements`. |
+
+The `ACTION*` macros cannot be used inside a function or class.
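+
+For example, a parameterized action defined at namespace scope can then be used
+like any built-in action. The following is a minimal sketch; `MockAdder` and
+its `Add()` method are hypothetical names used only for illustration:
+
+```cpp
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using ::testing::_;
+
+// ACTION* macros must appear at namespace scope, not inside a function/class.
+ACTION_P(Plus, n) { return arg0 + n; }
+
+class MockAdder {
+ public:
+  MOCK_METHOD(int, Add, (int x));
+};
+
+TEST(DefiningActionsTest, ParameterizedActionComputesReturnValue) {
+  MockAdder adder;
+  EXPECT_CALL(adder, Add(_)).WillOnce(Plus(5));
+  EXPECT_EQ(adder.Add(2), 7);  // arg0 is 2, so Plus(5) returns 7.
+}
+```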
diff --git a/test/googletest/docs/reference/assertions.md b/test/googletest/docs/reference/assertions.md
new file mode 100644
index 0000000..eeec4a0
--- /dev/null
+++ b/test/googletest/docs/reference/assertions.md
@@ -0,0 +1,640 @@
+# Assertions Reference
+
+This page lists the assertion macros provided by GoogleTest for verifying code
+behavior. To use them, add `#include <gtest/gtest.h>`.
+
+The majority of the macros listed below come as a pair with an `EXPECT_` variant
+and an `ASSERT_` variant. Upon failure, `EXPECT_` macros generate nonfatal
+failures and allow the current function to continue running, while `ASSERT_`
+macros generate fatal failures and abort the current function.
+
+All assertion macros support streaming a custom failure message into them with
+the `<<` operator, for example:
+
+```cpp
+EXPECT_TRUE(my_condition) << "My condition is not true";
+```
+
+Anything that can be streamed to an `ostream` can be streamed to an assertion
+macro—in particular, C strings and string objects. If a wide string (`wchar_t*`,
+`TCHAR*` in `UNICODE` mode on Windows, or `std::wstring`) is streamed to an
+assertion, it will be translated to UTF-8 when printed.
+
+## Explicit Success and Failure {#success-failure}
+
+The assertions in this section generate a success or failure directly instead of
+testing a value or expression. These are useful when control flow, rather than a
+Boolean expression, determines the test's success or failure, as shown by the
+following example:
+
+```c++
+switch(expression) {
+  case 1:
+    ... some checks ...
+  case 2:
+    ... some other checks ...
+  default:
+    FAIL() << "We shouldn't get here.";
+}
+```
+
+### SUCCEED {#SUCCEED}
+
+`SUCCEED()`
+
+Generates a success. This *does not* make the overall test succeed. A test is
+considered successful only if none of its assertions fail during its execution.
+
+The `SUCCEED` assertion is purely documentary and currently doesn't generate any
+user-visible output. However, we may add `SUCCEED` messages to GoogleTest output
+in the future.
+
+### FAIL {#FAIL}
+
+`FAIL()`
+
+Generates a fatal failure, which returns from the current function.
+
+Can only be used in functions that return `void`. See
+[Assertion Placement](../advanced.md#assertion-placement) for more information.
+
+### ADD_FAILURE {#ADD_FAILURE}
+
+`ADD_FAILURE()`
+
+Generates a nonfatal failure, which allows the current function to continue
+running.
+
+### ADD_FAILURE_AT {#ADD_FAILURE_AT}
+
+`ADD_FAILURE_AT(`*`file_path`*`,`*`line_number`*`)`
+
+Generates a nonfatal failure at the file and line number specified.
+
+## Generalized Assertion {#generalized}
+
+The following assertion allows [matchers](matchers.md) to be used to verify
+values.
+
+### EXPECT_THAT {#EXPECT_THAT}
+
+`EXPECT_THAT(`*`value`*`,`*`matcher`*`)` \
+`ASSERT_THAT(`*`value`*`,`*`matcher`*`)`
+
+Verifies that *`value`* matches the [matcher](matchers.md) *`matcher`*.
+
+For example, the following code verifies that the string `value1` starts with
+`"Hello"`, `value2` matches a regular expression, and `value3` is between 5 and
+10:
+
+```cpp
+#include <gmock/gmock.h>
+
+using ::testing::AllOf;
+using ::testing::Gt;
+using ::testing::Lt;
+using ::testing::MatchesRegex;
+using ::testing::StartsWith;
+
+...
+EXPECT_THAT(value1, StartsWith("Hello"));
+EXPECT_THAT(value2, MatchesRegex("Line \\d+"));
+ASSERT_THAT(value3, AllOf(Gt(5), Lt(10)));
+```
+
+Matchers enable assertions of this form to read like English and generate
+informative failure messages.
For example, if the above assertion on `value1` +fails, the resulting message will be similar to the following: + +``` +Value of: value1 + Actual: "Hi, world!" +Expected: starts with "Hello" +``` + +GoogleTest provides a built-in library of matchers—see the +[Matchers Reference](matchers.md). It is also possible to write your own +matchers—see [Writing New Matchers Quickly](../gmock_cook_book.md#NewMatchers). +The use of matchers makes `EXPECT_THAT` a powerful, extensible assertion. + +*The idea for this assertion was borrowed from Joe Walnes' Hamcrest project, +which adds `assertThat()` to JUnit.* + +## Boolean Conditions {#boolean} + +The following assertions test Boolean conditions. + +### EXPECT_TRUE {#EXPECT_TRUE} + +`EXPECT_TRUE(`*`condition`*`)` \ +`ASSERT_TRUE(`*`condition`*`)` + +Verifies that *`condition`* is true. + +### EXPECT_FALSE {#EXPECT_FALSE} + +`EXPECT_FALSE(`*`condition`*`)` \ +`ASSERT_FALSE(`*`condition`*`)` + +Verifies that *`condition`* is false. + +## Binary Comparison {#binary-comparison} + +The following assertions compare two values. The value arguments must be +comparable by the assertion's comparison operator, otherwise a compiler error +will result. + +If an argument supports the `<<` operator, it will be called to print the +argument when the assertion fails. Otherwise, GoogleTest will attempt to print +them in the best way it can—see +[Teaching GoogleTest How to Print Your Values](../advanced.md#teaching-googletest-how-to-print-your-values). + +Arguments are always evaluated exactly once, so it's OK for the arguments to +have side effects. However, the argument evaluation order is undefined and +programs should not depend on any particular argument evaluation order. + +These assertions work with both narrow and wide string objects (`string` and +`wstring`). + +See also the [Floating-Point Comparison](#floating-point) assertions to compare +floating-point numbers and avoid problems caused by rounding. + +### EXPECT_EQ {#EXPECT_EQ} + +`EXPECT_EQ(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_EQ(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`==`*`val2`*. + +Does pointer equality on pointers. If used on two C strings, it tests if they +are in the same memory location, not if they have the same value. Use +[`EXPECT_STREQ`](#EXPECT_STREQ) to compare C strings (e.g. `const char*`) by +value. + +When comparing a pointer to `NULL`, use `EXPECT_EQ(`*`ptr`*`, nullptr)` instead +of `EXPECT_EQ(`*`ptr`*`, NULL)`. + +### EXPECT_NE {#EXPECT_NE} + +`EXPECT_NE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_NE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`!=`*`val2`*. + +Does pointer equality on pointers. If used on two C strings, it tests if they +are in different memory locations, not if they have different values. Use +[`EXPECT_STRNE`](#EXPECT_STRNE) to compare C strings (e.g. `const char*`) by +value. + +When comparing a pointer to `NULL`, use `EXPECT_NE(`*`ptr`*`, nullptr)` instead +of `EXPECT_NE(`*`ptr`*`, NULL)`. + +### EXPECT_LT {#EXPECT_LT} + +`EXPECT_LT(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_LT(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`<`*`val2`*. + +### EXPECT_LE {#EXPECT_LE} + +`EXPECT_LE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_LE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`<=`*`val2`*. + +### EXPECT_GT {#EXPECT_GT} + +`EXPECT_GT(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_GT(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`>`*`val2`*. + +### EXPECT_GE {#EXPECT_GE} + +`EXPECT_GE(`*`val1`*`,`*`val2`*`)` \ +`ASSERT_GE(`*`val1`*`,`*`val2`*`)` + +Verifies that *`val1`*`>=`*`val2`*. 
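+
+As a quick illustration of the comparisons above, here is a minimal sketch
+(the values are invented for this example):
+
+```cpp
+#include <gtest/gtest.h>
+
+#include <string>
+
+TEST(BinaryComparisonDemo, ValuesAndPointers) {
+  int x = 5;
+  EXPECT_EQ(x, 5);   // Compares values.
+  EXPECT_LT(x, 10);  // 5 < 10.
+
+  std::string s1 = "hello";
+  std::string s2 = "hello";
+  EXPECT_EQ(s1, s2);  // std::string compares contents, so this passes.
+
+  // For C strings, EXPECT_EQ would compare pointers; use EXPECT_STREQ
+  // (next section) to compare contents.
+  const char* c = s1.c_str();
+  EXPECT_STREQ(c, "hello");
+}
+```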
+
+## String Comparison {#c-strings}
+
+The following assertions compare two **C strings**. To compare two `string`
+objects, use [`EXPECT_EQ`](#EXPECT_EQ) or [`EXPECT_NE`](#EXPECT_NE) instead.
+
+These assertions also accept wide C strings (`wchar_t*`). If a comparison of two
+wide strings fails, their values will be printed as UTF-8 narrow strings.
+
+To compare a C string with `NULL`, use `EXPECT_EQ(`*`c_string`*`, nullptr)` or
+`EXPECT_NE(`*`c_string`*`, nullptr)`.
+
+### EXPECT_STREQ {#EXPECT_STREQ}
+
+`EXPECT_STREQ(`*`str1`*`,`*`str2`*`)` \
+`ASSERT_STREQ(`*`str1`*`,`*`str2`*`)`
+
+Verifies that the two C strings *`str1`* and *`str2`* have the same contents.
+
+### EXPECT_STRNE {#EXPECT_STRNE}
+
+`EXPECT_STRNE(`*`str1`*`,`*`str2`*`)` \
+`ASSERT_STRNE(`*`str1`*`,`*`str2`*`)`
+
+Verifies that the two C strings *`str1`* and *`str2`* have different contents.
+
+### EXPECT_STRCASEEQ {#EXPECT_STRCASEEQ}
+
+`EXPECT_STRCASEEQ(`*`str1`*`,`*`str2`*`)` \
+`ASSERT_STRCASEEQ(`*`str1`*`,`*`str2`*`)`
+
+Verifies that the two C strings *`str1`* and *`str2`* have the same contents,
+ignoring case.
+
+### EXPECT_STRCASENE {#EXPECT_STRCASENE}
+
+`EXPECT_STRCASENE(`*`str1`*`,`*`str2`*`)` \
+`ASSERT_STRCASENE(`*`str1`*`,`*`str2`*`)`
+
+Verifies that the two C strings *`str1`* and *`str2`* have different contents,
+ignoring case.
+
+## Floating-Point Comparison {#floating-point}
+
+The following assertions compare two floating-point values.
+
+Due to rounding errors, it is very unlikely that two floating-point values will
+match exactly, so `EXPECT_EQ` is not suitable. In general, for floating-point
+comparison to make sense, the user needs to carefully choose the error bound.
+
+GoogleTest also provides assertions that use a default error bound based on
+Units in the Last Place (ULPs). To learn more about ULPs, see the article
+[Comparing Floating Point Numbers](https://randomascii.wordpress.com/2012/02/25/comparing-floating-point-numbers-2012-edition/).
+
+### EXPECT_FLOAT_EQ {#EXPECT_FLOAT_EQ}
+
+`EXPECT_FLOAT_EQ(`*`val1`*`,`*`val2`*`)` \
+`ASSERT_FLOAT_EQ(`*`val1`*`,`*`val2`*`)`
+
+Verifies that the two `float` values *`val1`* and *`val2`* are approximately
+equal, to within 4 ULPs from each other. Infinity and the largest finite float
+value are considered to be one ULP apart.
+
+### EXPECT_DOUBLE_EQ {#EXPECT_DOUBLE_EQ}
+
+`EXPECT_DOUBLE_EQ(`*`val1`*`,`*`val2`*`)` \
+`ASSERT_DOUBLE_EQ(`*`val1`*`,`*`val2`*`)`
+
+Verifies that the two `double` values *`val1`* and *`val2`* are approximately
+equal, to within 4 ULPs from each other. Infinity and the largest finite double
+value are considered to be one ULP apart.
+
+### EXPECT_NEAR {#EXPECT_NEAR}
+
+`EXPECT_NEAR(`*`val1`*`,`*`val2`*`,`*`abs_error`*`)` \
+`ASSERT_NEAR(`*`val1`*`,`*`val2`*`,`*`abs_error`*`)`
+
+Verifies that the difference between *`val1`* and *`val2`* does not exceed the
+absolute error bound *`abs_error`*.
+
+If *`val1`* and *`val2`* are both infinity of the same sign, the difference is
+considered to be 0. Otherwise, if either value is infinity, the difference is
+considered to be infinity. All non-NaN values (including infinity) are
+considered to not exceed an *`abs_error`* of infinity.
+
+## Exception Assertions {#exceptions}
+
+The following assertions verify that a piece of code throws, or does not throw,
+an exception. Usage requires exceptions to be enabled in the build environment.
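+
+As a brief sketch of the assertions described below, consider the following;
+the `ParseConfig()` function here is hypothetical, invented only for this
+example:
+
+```cpp
+#include <gtest/gtest.h>
+
+#include <stdexcept>
+#include <string>
+
+// Hypothetical function used only for illustration.
+int ParseConfig(const std::string& text) {
+  if (text.empty()) throw std::invalid_argument("empty config");
+  return 1;
+}
+
+TEST(ExceptionDemo, ThrowBehavior) {
+  EXPECT_THROW(ParseConfig(""), std::invalid_argument);  // specific type
+  EXPECT_ANY_THROW(ParseConfig(""));                     // any exception
+  EXPECT_NO_THROW(ParseConfig("ok"));                    // no exception
+}
+```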
+
+Note that the piece of code under test can be a compound statement, for example:
+
+```cpp
+EXPECT_NO_THROW({
+  int n = 5;
+  DoSomething(&n);
+});
+```
+
+### EXPECT_THROW {#EXPECT_THROW}
+
+`EXPECT_THROW(`*`statement`*`,`*`exception_type`*`)` \
+`ASSERT_THROW(`*`statement`*`,`*`exception_type`*`)`
+
+Verifies that *`statement`* throws an exception of type *`exception_type`*.
+
+### EXPECT_ANY_THROW {#EXPECT_ANY_THROW}
+
+`EXPECT_ANY_THROW(`*`statement`*`)` \
+`ASSERT_ANY_THROW(`*`statement`*`)`
+
+Verifies that *`statement`* throws an exception of any type.
+
+### EXPECT_NO_THROW {#EXPECT_NO_THROW}
+
+`EXPECT_NO_THROW(`*`statement`*`)` \
+`ASSERT_NO_THROW(`*`statement`*`)`
+
+Verifies that *`statement`* does not throw any exception.
+
+## Predicate Assertions {#predicates}
+
+The following assertions enable more complex predicates to be verified while
+printing a clearer failure message than if `EXPECT_TRUE` were used alone.
+
+### EXPECT_PRED* {#EXPECT_PRED}
+
+`EXPECT_PRED1(`*`pred`*`,`*`val1`*`)` \
+`EXPECT_PRED2(`*`pred`*`,`*`val1`*`,`*`val2`*`)` \
+`EXPECT_PRED3(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \
+`EXPECT_PRED4(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \
+`EXPECT_PRED5(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)`
+
+`ASSERT_PRED1(`*`pred`*`,`*`val1`*`)` \
+`ASSERT_PRED2(`*`pred`*`,`*`val1`*`,`*`val2`*`)` \
+`ASSERT_PRED3(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \
+`ASSERT_PRED4(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \
+`ASSERT_PRED5(`*`pred`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)`
+
+Verifies that the predicate *`pred`* returns `true` when passed the given values
+as arguments.
+
+The parameter *`pred`* is a function or functor that accepts as many arguments
+as the corresponding macro accepts values. If *`pred`* returns `true` for the
+given arguments, the assertion succeeds, otherwise the assertion fails.
+
+When the assertion fails, it prints the value of each argument. Arguments are
+always evaluated exactly once.
+
+As an example, see the following code:
+
+```cpp
+// Returns true if m and n have no common divisors except 1.
+bool MutuallyPrime(int m, int n) { ... }
+...
+const int a = 3;
+const int b = 4;
+const int c = 10;
+...
+EXPECT_PRED2(MutuallyPrime, a, b);  // Succeeds
+EXPECT_PRED2(MutuallyPrime, b, c);  // Fails
+```
+
+In the above example, the first assertion succeeds, and the second fails with
+the following message:
+
+```
+MutuallyPrime(b, c) is false, where
+b is 4
+c is 10
+```
+
+Note that if the given predicate is an overloaded function or a function
+template, the assertion macro might not be able to determine which version to
+use, and it might be necessary to explicitly specify the type of the function.
+For example, for a Boolean function `IsPositive()` overloaded to take either a
+single `int` or `double` argument, it would be necessary to write one of the
+following:
+
+```cpp
+EXPECT_PRED1(static_cast<bool (*)(int)>(IsPositive), 5);
+EXPECT_PRED1(static_cast<bool (*)(double)>(IsPositive), 3.14);
+```
+
+Writing simply `EXPECT_PRED1(IsPositive, 5);` would result in a compiler error.
+Similarly, to use a template function, specify the template arguments:
+
+```cpp
+template <typename T>
+bool IsNegative(T x) {
+  return x < 0;
+}
+...
+EXPECT_PRED1(IsNegative<int>, -5);  // Must specify type for IsNegative
+```
+
+If a template has multiple parameters, wrap the predicate in parentheses so the
+macro arguments are parsed correctly:
+
+```cpp
+ASSERT_PRED2((MyPredicate<int, int>), 5, 0);
+```
+
+### EXPECT_PRED_FORMAT* {#EXPECT_PRED_FORMAT}
+
+`EXPECT_PRED_FORMAT1(`*`pred_formatter`*`,`*`val1`*`)` \
+`EXPECT_PRED_FORMAT2(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`)` \
+`EXPECT_PRED_FORMAT3(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \
+`EXPECT_PRED_FORMAT4(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \
+`EXPECT_PRED_FORMAT5(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)`
+
+`ASSERT_PRED_FORMAT1(`*`pred_formatter`*`,`*`val1`*`)` \
+`ASSERT_PRED_FORMAT2(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`)` \
+`ASSERT_PRED_FORMAT3(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`)` \
+`ASSERT_PRED_FORMAT4(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`)` \
+`ASSERT_PRED_FORMAT5(`*`pred_formatter`*`,`*`val1`*`,`*`val2`*`,`*`val3`*`,`*`val4`*`,`*`val5`*`)`
+
+Verifies that the predicate *`pred_formatter`* succeeds when passed the given
+values as arguments.
+
+The parameter *`pred_formatter`* is a *predicate-formatter*, which is a function
+or functor with the signature:
+
+```cpp
+testing::AssertionResult PredicateFormatter(const char* expr1,
+                                            const char* expr2,
+                                            ...
+                                            const char* exprn,
+                                            T1 val1,
+                                            T2 val2,
+                                            ...
+                                            Tn valn);
+```
+
+where *`val1`*, *`val2`*, ..., *`valn`* are the values of the predicate
+arguments, and *`expr1`*, *`expr2`*, ..., *`exprn`* are the corresponding
+expressions as they appear in the source code. The types `T1`, `T2`, ..., `Tn`
+can be either value types or reference types; if an argument has type `T`, it
+can be declared as either `T` or `const T&`, whichever is appropriate. For more
+about the return type `testing::AssertionResult`, see
+[Using a Function That Returns an AssertionResult](../advanced.md#using-a-function-that-returns-an-assertionresult).
+
+As an example, see the following code:
+
+```cpp
+// Returns the smallest prime common divisor of m and n,
+// or 1 when m and n are mutually prime.
+int SmallestPrimeCommonDivisor(int m, int n) { ... }
+
+// Returns true if m and n have no common divisors except 1.
+bool MutuallyPrime(int m, int n) { ... }
+
+// A predicate-formatter for asserting that two integers are mutually prime.
+testing::AssertionResult AssertMutuallyPrime(const char* m_expr,
+                                             const char* n_expr,
+                                             int m,
+                                             int n) {
+  if (MutuallyPrime(m, n)) return testing::AssertionSuccess();
+
+  return testing::AssertionFailure() << m_expr << " and " << n_expr
+      << " (" << m << " and " << n << ") are not mutually prime, "
+      << "as they have a common divisor " << SmallestPrimeCommonDivisor(m, n);
+}
+
+...
+const int a = 3;
+const int b = 4;
+const int c = 10;
+...
+EXPECT_PRED_FORMAT2(AssertMutuallyPrime, a, b);  // Succeeds
+EXPECT_PRED_FORMAT2(AssertMutuallyPrime, b, c);  // Fails
+```
+
+In the above example, the final assertion fails and the predicate-formatter
+produces the following failure message:
+
+```
+b and c (4 and 10) are not mutually prime, as they have a common divisor 2
+```
+
+## Windows HRESULT Assertions {#HRESULT}
+
+The following assertions test for `HRESULT` success or failure. For example:
+
+```cpp
+CComPtr<IShellDispatch2> shell;
+ASSERT_HRESULT_SUCCEEDED(shell.CoCreateInstance(L"Shell.Application"));
+CComVariant empty;
+ASSERT_HRESULT_SUCCEEDED(shell->ShellExecute(CComBSTR(url), empty, empty, empty, empty));
+```
+
+The generated output contains the human-readable error message associated with
+the returned `HRESULT` code.
+
+### EXPECT_HRESULT_SUCCEEDED {#EXPECT_HRESULT_SUCCEEDED}
+
+`EXPECT_HRESULT_SUCCEEDED(`*`expression`*`)` \
+`ASSERT_HRESULT_SUCCEEDED(`*`expression`*`)`
+
+Verifies that *`expression`* is a success `HRESULT`.
+
+### EXPECT_HRESULT_FAILED {#EXPECT_HRESULT_FAILED}
+
+`EXPECT_HRESULT_FAILED(`*`expression`*`)` \
+`ASSERT_HRESULT_FAILED(`*`expression`*`)`
+
+Verifies that *`expression`* is a failure `HRESULT`.
+
+## Death Assertions {#death}
+
+The following assertions verify that a piece of code causes the process to
+terminate. For context, see [Death Tests](../advanced.md#death-tests).
+
+These assertions spawn a new process and execute the code under test in that
+process. How that happens depends on the platform and the variable
+`::testing::GTEST_FLAG(death_test_style)`, which is initialized from the
+command-line flag `--gtest_death_test_style`.
+
+*   On POSIX systems, `fork()` (or `clone()` on Linux) is used to spawn the
+    child, after which:
+    *   If the variable's value is `"fast"`, the death test statement is
+        immediately executed.
+    *   If the variable's value is `"threadsafe"`, the child process re-executes
+        the unit test binary just as it was originally invoked, but with some
+        extra flags to cause just the single death test under consideration to
+        be run.
+*   On Windows, the child is spawned using the `CreateProcess()` API, and
+    re-executes the binary to cause just the single death test under
+    consideration to be run - much like the `"threadsafe"` mode on POSIX.
+
+Other values for the variable are illegal and will cause the death test to fail.
+Currently, the flag's default value is **`"fast"`**.
+
+If the death test statement runs to completion without dying, the child process
+will nonetheless terminate, and the assertion fails.
+
+Note that the piece of code under test can be a compound statement, for example:
+
+```cpp
+EXPECT_DEATH({
+  int n = 5;
+  DoSomething(&n);
+}, "Error on line .* of DoSomething()");
+```
+
+### EXPECT_DEATH {#EXPECT_DEATH}
+
+`EXPECT_DEATH(`*`statement`*`,`*`matcher`*`)` \
+`ASSERT_DEATH(`*`statement`*`,`*`matcher`*`)`
+
+Verifies that *`statement`* causes the process to terminate with a nonzero exit
+status and produces `stderr` output that matches *`matcher`*.
+
+The parameter *`matcher`* is either a [matcher](matchers.md) for a `const
+std::string&`, or a regular expression (see
+[Regular Expression Syntax](../advanced.md#regular-expression-syntax))—a bare
+string *`s`* (with no matcher) is treated as
+[`ContainsRegex(s)`](matchers.md#string-matchers), **not**
+[`Eq(s)`](matchers.md#generic-comparison).
+
+For example, the following code verifies that calling `DoSomething(42)` causes
+the process to die with an error message that contains the text `My error`:
+
+```cpp
+EXPECT_DEATH(DoSomething(42), "My error");
+```
+
+### EXPECT_DEATH_IF_SUPPORTED {#EXPECT_DEATH_IF_SUPPORTED}
+
+`EXPECT_DEATH_IF_SUPPORTED(`*`statement`*`,`*`matcher`*`)` \
+`ASSERT_DEATH_IF_SUPPORTED(`*`statement`*`,`*`matcher`*`)`
+
+If death tests are supported, behaves the same as
+[`EXPECT_DEATH`](#EXPECT_DEATH). Otherwise, verifies nothing.
+
+### EXPECT_DEBUG_DEATH {#EXPECT_DEBUG_DEATH}
+
+`EXPECT_DEBUG_DEATH(`*`statement`*`,`*`matcher`*`)` \
+`ASSERT_DEBUG_DEATH(`*`statement`*`,`*`matcher`*`)`
+
+In debug mode, behaves the same as [`EXPECT_DEATH`](#EXPECT_DEATH). When not in
+debug mode (i.e. `NDEBUG` is defined), just executes *`statement`*.
+
+### EXPECT_EXIT {#EXPECT_EXIT}
+
+`EXPECT_EXIT(`*`statement`*`,`*`predicate`*`,`*`matcher`*`)` \
+`ASSERT_EXIT(`*`statement`*`,`*`predicate`*`,`*`matcher`*`)`
+
+Verifies that *`statement`* causes the process to terminate with an exit status
+that satisfies *`predicate`*, and produces `stderr` output that matches
+*`matcher`*.
+
+The parameter *`predicate`* is a function or functor that accepts an `int` exit
+status and returns a `bool`. GoogleTest provides two predicates to handle common
+cases:
+
+```cpp
+// Returns true if the program exited normally with the given exit status code.
+::testing::ExitedWithCode(exit_code);
+
+// Returns true if the program was killed by the given signal.
+// Not available on Windows.
+::testing::KilledBySignal(signal_number);
+```
+
+The parameter *`matcher`* is either a [matcher](matchers.md) for a `const
+std::string&`, or a regular expression (see
+[Regular Expression Syntax](../advanced.md#regular-expression-syntax))—a bare
+string *`s`* (with no matcher) is treated as
+[`ContainsRegex(s)`](matchers.md#string-matchers), **not**
+[`Eq(s)`](matchers.md#generic-comparison).
+
+For example, the following code verifies that calling `NormalExit()` causes the
+process to print a message containing the text `Success` to `stderr` and exit
+with exit status code 0:
+
+```cpp
+EXPECT_EXIT(NormalExit(), testing::ExitedWithCode(0), "Success");
+```
diff --git a/test/googletest/docs/reference/matchers.md b/test/googletest/docs/reference/matchers.md
new file mode 100644
index 0000000..243e3f9
--- /dev/null
+++ b/test/googletest/docs/reference/matchers.md
@@ -0,0 +1,302 @@
+# Matchers Reference
+
+A **matcher** matches a *single* argument. You can use it inside `ON_CALL()` or
+`EXPECT_CALL()`, or use it to validate a value directly using two macros:
+
+| Macro | Description |
+| :----------------------------------- | :------------------------------------ |
+| `EXPECT_THAT(actual_value, matcher)` | Asserts that `actual_value` matches `matcher`. |
+| `ASSERT_THAT(actual_value, matcher)` | The same as `EXPECT_THAT(actual_value, matcher)`, except that it generates a **fatal** failure. |
+
+{: .callout .warning}
+**WARNING:** Equality matching via `EXPECT_THAT(actual_value, expected_value)`
+is supported, however note that implicit conversions can cause surprising
+results. For example, `EXPECT_THAT(some_bool, "some string")` will compile and
+may pass unintentionally.
+
+**BEST PRACTICE:** Prefer to make the comparison explicit via
+`EXPECT_THAT(actual_value, Eq(expected_value))` or `EXPECT_EQ(actual_value,
+expected_value)`.
+
+Built-in matchers (where `argument` is the function argument, e.g.
+`actual_value` in the example above, or when used in the context of
+`EXPECT_CALL(mock_object, method(matchers))`, the arguments of `method`) are
+divided into several categories. All matchers are defined in the `::testing`
+namespace unless otherwise noted.
+
+## Wildcard
+
+Matcher                     | Description
+:-------------------------- | :-----------------------------------------------
+`_`                         | `argument` can be any value of the correct type.
+`A<type>()` or `An<type>()` | `argument` can be any value of type `type`.
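+
+As a small sketch of these wildcards in an expectation (the `MockSink` class
+here is hypothetical, invented only for this example):
+
+```cpp
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <string>
+
+using ::testing::_;
+using ::testing::An;
+
+// Hypothetical mock class used only for illustration.
+class MockSink {
+ public:
+  MOCK_METHOD(void, Write, (int id, const std::string& payload));
+};
+
+TEST(WildcardDemo, MatchesAnyArgument) {
+  MockSink sink;
+  EXPECT_CALL(sink, Write(An<int>(), _));  // any int id, any payload
+  sink.Write(42, "hello");
+}
+```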
+
+## Generic Comparison
+
+| Matcher | Description |
+| :--------------------- | :-------------------------------------------------- |
+| `Eq(value)` or `value` | `argument == value` |
+| `Ge(value)` | `argument >= value` |
+| `Gt(value)` | `argument > value` |
+| `Le(value)` | `argument <= value` |
+| `Lt(value)` | `argument < value` |
+| `Ne(value)` | `argument != value` |
+| `IsFalse()` | `argument` evaluates to `false` in a Boolean context. |
+| `IsTrue()` | `argument` evaluates to `true` in a Boolean context. |
+| `IsNull()` | `argument` is a `NULL` pointer (raw or smart). |
+| `NotNull()` | `argument` is a non-null pointer (raw or smart). |
+| `Optional(m)` | `argument` is `optional<>` that contains a value matching `m`. (For testing whether an `optional<>` is set, check for equality with `nullopt`. You may need to use `Eq(nullopt)` if the inner type doesn't have `==`.) |
+| `VariantWith<T>(m)` | `argument` is `variant<>` that holds the alternative of type `T` with a value matching `m`. |
+| `Ref(variable)` | `argument` is a reference to `variable`. |
+| `TypedEq<type>(value)` | `argument` has type `type` and is equal to `value`. You may need to use this instead of `Eq(value)` when the mock function is overloaded. |
+
+Except `Ref()`, these matchers make a *copy* of `value` in case it's modified or
+destructed later. If the compiler complains that `value` doesn't have a public
+copy constructor, try wrapping it in `std::ref()`, e.g.
+`Eq(std::ref(non_copyable_value))`. If you do that, make sure
+`non_copyable_value` is not changed afterwards, or the meaning of your matcher
+will be changed.
+
+`IsTrue` and `IsFalse` are useful when you need to use a matcher, or for types
+that can be explicitly converted to Boolean, but are not implicitly converted to
+Boolean. In other cases, you can use the basic
+[`EXPECT_TRUE` and `EXPECT_FALSE`](assertions.md#boolean) assertions.
+
+## Floating-Point Matchers {#FpMatchers}
+
+| Matcher | Description |
+| :------------------------------- | :--------------------------------- |
+| `DoubleEq(a_double)` | `argument` is a `double` value approximately equal to `a_double`, treating two NaNs as unequal. |
+| `FloatEq(a_float)` | `argument` is a `float` value approximately equal to `a_float`, treating two NaNs as unequal. |
+| `NanSensitiveDoubleEq(a_double)` | `argument` is a `double` value approximately equal to `a_double`, treating two NaNs as equal. |
+| `NanSensitiveFloatEq(a_float)` | `argument` is a `float` value approximately equal to `a_float`, treating two NaNs as equal. |
+| `IsNan()` | `argument` is any floating-point type with a NaN value. |
+
+The above matchers use ULP-based comparison (the same as used in googletest).
+They automatically pick a reasonable error bound based on the absolute value of
+the expected value. `DoubleEq()` and `FloatEq()` conform to the IEEE standard,
+which requires comparing two NaNs for equality to return false. The
+`NanSensitive*` version instead treats two NaNs as equal, which is often what a
+user wants.
+
+| Matcher | Description |
+| :------------------------------------------------ | :----------------------- |
+| `DoubleNear(a_double, max_abs_error)` | `argument` is a `double` value close to `a_double` (absolute error <= `max_abs_error`), treating two NaNs as unequal. |
+| `FloatNear(a_float, max_abs_error)` | `argument` is a `float` value close to `a_float` (absolute error <= `max_abs_error`), treating two NaNs as unequal. |
+| `NanSensitiveDoubleNear(a_double, max_abs_error)` | `argument` is a `double` value close to `a_double` (absolute error <= `max_abs_error`), treating two NaNs as equal. |
+| `NanSensitiveFloatNear(a_float, max_abs_error)` | `argument` is a `float` value close to `a_float` (absolute error <= `max_abs_error`), treating two NaNs as equal. |
+
+## String Matchers
+
+The `argument` can be either a C string or a C++ string object:
+
+| Matcher | Description |
+| :---------------------- | :------------------------------------------------- |
+| `ContainsRegex(string)` | `argument` matches the given regular expression. |
+| `EndsWith(suffix)` | `argument` ends with string `suffix`. |
+| `HasSubstr(string)` | `argument` contains `string` as a sub-string. |
+| `IsEmpty()` | `argument` is an empty string. |
+| `MatchesRegex(string)` | `argument` matches the given regular expression with the match starting at the first character and ending at the last character. |
+| `StartsWith(prefix)` | `argument` starts with string `prefix`. |
+| `StrCaseEq(string)` | `argument` is equal to `string`, ignoring case. |
+| `StrCaseNe(string)` | `argument` is not equal to `string`, ignoring case. |
+| `StrEq(string)` | `argument` is equal to `string`. |
+| `StrNe(string)` | `argument` is not equal to `string`. |
+| `WhenBase64Unescaped(m)` | `argument` is a base-64 escaped string whose unescaped string matches `m`. The web-safe format from [RFC 4648](https://www.rfc-editor.org/rfc/rfc4648#section-5) is supported. |
+
+`ContainsRegex()` and `MatchesRegex()` take ownership of the `RE` object. They
+use the regular expression syntax defined
+[here](../advanced.md#regular-expression-syntax). All of these matchers, except
+`ContainsRegex()` and `MatchesRegex()`, work for wide strings as well.
+
+## Container Matchers
+
+Most STL-style containers support `==`, so you can use `Eq(expected_container)`
+or simply `expected_container` to match a container exactly. If you want to
+write the elements in-line, match them more flexibly, or get more informative
+messages, you can use:
+
+| Matcher | Description |
+| :---------------------------------------- | :------------------------------- |
+| `BeginEndDistanceIs(m)` | `argument` is a container whose `begin()` and `end()` iterators are separated by a number of increments matching `m`. E.g. `BeginEndDistanceIs(2)` or `BeginEndDistanceIs(Lt(2))`. For containers that define a `size()` method, `SizeIs(m)` may be more efficient. |
+| `ContainerEq(container)` | The same as `Eq(container)` except that the failure message also includes which elements are in one container but not the other. |
+| `Contains(e)` | `argument` contains an element that matches `e`, which can be either a value or a matcher. |
+| `Contains(e).Times(n)` | `argument` contains elements that match `e`, which can be either a value or a matcher, and the number of matches is `n`, which can be either a value or a matcher. Unlike the plain `Contains` and `Each`, this allows you to check for arbitrary occurrences, including testing for absence with `Contains(e).Times(0)`. |
+| `Each(e)` | `argument` is a container where *every* element matches `e`, which can be either a value or a matcher. |
+| `ElementsAre(e0, e1, ..., en)` | `argument` has `n + 1` elements, where the *i*-th element matches `ei`, which can be a value or a matcher. |
+| `ElementsAreArray({e0, e1, ..., en})`, `ElementsAreArray(a_container)`, `ElementsAreArray(begin, end)`, `ElementsAreArray(array)`, or `ElementsAreArray(array, count)` | The same as `ElementsAre()` except that the expected element values/matchers come from an initializer list, STL-style container, iterator range, or C-style array. |
+| `IsEmpty()` | `argument` is an empty container (`container.empty()`). |
+| `IsSubsetOf({e0, e1, ..., en})`, `IsSubsetOf(a_container)`, `IsSubsetOf(begin, end)`, `IsSubsetOf(array)`, or `IsSubsetOf(array, count)` | `argument` matches `UnorderedElementsAre(x0, x1, ..., xk)` for some subset `{x0, x1, ..., xk}` of the expected matchers. |
+| `IsSupersetOf({e0, e1, ..., en})`, `IsSupersetOf(a_container)`, `IsSupersetOf(begin, end)`, `IsSupersetOf(array)`, or `IsSupersetOf(array, count)` | Some subset of `argument` matches `UnorderedElementsAre(`expected matchers`)`. |
+| `Pointwise(m, container)`, `Pointwise(m, {e0, e1, ..., en})` | `argument` contains the same number of elements as in `container`, and for all i, (the i-th element in `argument`, the i-th element in `container`) match `m`, which is a matcher on 2-tuples. E.g. `Pointwise(Le(), upper_bounds)` verifies that each element in `argument` doesn't exceed the corresponding element in `upper_bounds`. See more detail below. |
+| `SizeIs(m)` | `argument` is a container whose size matches `m`. E.g. `SizeIs(2)` or `SizeIs(Lt(2))`. |
+| `UnorderedElementsAre(e0, e1, ..., en)` | `argument` has `n + 1` elements, and under *some* permutation of the elements, each element matches an `ei` (for a different `i`), which can be a value or a matcher. |
+| `UnorderedElementsAreArray({e0, e1, ..., en})`, `UnorderedElementsAreArray(a_container)`, `UnorderedElementsAreArray(begin, end)`, `UnorderedElementsAreArray(array)`, or `UnorderedElementsAreArray(array, count)` | The same as `UnorderedElementsAre()` except that the expected element values/matchers come from an initializer list, STL-style container, iterator range, or C-style array. |
+| `UnorderedPointwise(m, container)`, `UnorderedPointwise(m, {e0, e1, ..., en})` | Like `Pointwise(m, container)`, but ignores the order of elements. |
+| `WhenSorted(m)` | When `argument` is sorted using the `<` operator, it matches container matcher `m`. E.g. `WhenSorted(ElementsAre(1, 2, 3))` verifies that `argument` contains elements 1, 2, and 3, ignoring order. |
+| `WhenSortedBy(comparator, m)` | The same as `WhenSorted(m)`, except that the given comparator instead of `<` is used to sort `argument`. E.g. `WhenSortedBy(std::greater<int>(), ElementsAre(3, 2, 1))`. |
+
+**Notes:**
+
+*   These matchers can also match:
+    1.  a native array passed by reference (e.g. in `Foo(const int (&a)[5])`),
+        and
+    2.  an array passed as a pointer and a count (e.g. in `Bar(const T* buffer,
+        int len)` -- see [Multi-argument Matchers](#MultiArgMatchers)).
+*   The array being matched may be multi-dimensional (i.e. its elements can be
+    arrays).
+*   `m` in `Pointwise(m, ...)` and `UnorderedPointwise(m, ...)` should be a
+    matcher for `::std::tuple<T, U>` where `T` and `U` are the element type of
+    the actual container and the expected container, respectively. For example,
+    to compare two `Foo` containers where `Foo` doesn't support `operator==`,
+    one might write:
+
+    ```cpp
+    MATCHER(FooEq, "") {
+      return std::get<0>(arg).Equals(std::get<1>(arg));
+    }
+    ...
+    EXPECT_THAT(actual_foos, Pointwise(FooEq(), expected_foos));
+    ```
+
+## Member Matchers
+
+| Matcher | Description |
+| :------------------------------ | :----------------------------------------- |
+| `Field(&class::field, m)` | `argument.field` (or `argument->field` when `argument` is a plain pointer) matches matcher `m`, where `argument` is an object of type _class_. |
+| `Field(field_name, &class::field, m)` | The same as the two-parameter version, but provides a better error message. |
+| `Key(e)` | `argument.first` matches `e`, which can be either a value or a matcher. E.g. `Contains(Key(Le(5)))` can verify that a `map` contains a key `<= 5`. |
+| `Pair(m1, m2)` | `argument` is an `std::pair` whose `first` field matches `m1` and `second` field matches `m2`. |
+| `FieldsAre(m...)` | `argument` is a compatible object where each field matches piecewise with the matchers `m...`. A compatible object is any that supports the `std::tuple_size<Obj>` + `get<I>(obj)` protocol. In C++17 and up this also supports types compatible with structured bindings, like aggregates. |
+| `Property(&class::property, m)` | `argument.property()` (or `argument->property()` when `argument` is a plain pointer) matches matcher `m`, where `argument` is an object of type _class_. The method `property()` must take no argument and be declared as `const`. |
+| `Property(property_name, &class::property, m)` | The same as the two-parameter version, but provides a better error message. |
+
+**Notes:**
+
+*   You can use `FieldsAre()` to match any type that supports structured
+    bindings, such as `std::tuple`, `std::pair`, `std::array`, and aggregate
+    types. For example:
+
+    ```cpp
+    std::tuple<int, std::string> my_tuple{7, "hello world"};
+    EXPECT_THAT(my_tuple, FieldsAre(Ge(0), HasSubstr("hello")));
+
+    struct MyStruct {
+      int value = 42;
+      std::string greeting = "aloha";
+    };
+    MyStruct s;
+    EXPECT_THAT(s, FieldsAre(42, "aloha"));
+    ```
+
+*   Don't use `Property()` against member functions that you do not own, because
+    taking addresses of functions is fragile and generally not part of the
+    contract of the function.
+
+## Matching the Result of a Function, Functor, or Callback
+
+| Matcher | Description |
+| :--------------- | :------------------------------------------------ |
+| `ResultOf(f, m)` | `f(argument)` matches matcher `m`, where `f` is a function or functor. |
+| `ResultOf(result_description, f, m)` | The same as the two-parameter version, but provides a better error message. |
+
+## Pointer Matchers
+
+| Matcher | Description |
+| :------------------------ | :---------------------------------------------- |
+| `Address(m)` | the result of `std::addressof(argument)` matches `m`. |
+| `Pointee(m)` | `argument` (either a smart pointer or a raw pointer) points to a value that matches matcher `m`. |
+| `Pointer(m)` | `argument` (either a smart pointer or a raw pointer) contains a pointer that matches `m`. `m` will match against the raw pointer regardless of the type of `argument`. |
+| `WhenDynamicCastTo<T>(m)` | when `argument` is passed through `dynamic_cast<T>()`, it matches matcher `m`. |
+
+## Multi-argument Matchers {#MultiArgMatchers}
+
+Technically, all matchers match a *single* value. A "multi-argument" matcher is
+just one that matches a *tuple*.
+The following matchers can be used to match a tuple `(x, y)`:
+
+Matcher | Description
+:------ | :----------
+`Eq()`  | `x == y`
+`Ge()`  | `x >= y`
+`Gt()`  | `x > y`
+`Le()`  | `x <= y`
+`Lt()`  | `x < y`
+`Ne()`  | `x != y`
+
+You can use the following selectors to pick a subset of the arguments (or
+reorder them) to participate in the matching:
+
+| Matcher | Description |
+| :------------------------- | :---------------------------------------------- |
+| `AllArgs(m)` | Equivalent to `m`. Useful as syntactic sugar in `.With(AllArgs(m))`. |
+| `Args<N1, N2, ..., Nk>(m)` | The tuple of the `k` selected (using 0-based indices) arguments matches `m`, e.g. `Args<1, 2>(Eq())`. |
+
+## Composite Matchers
+
+You can make a matcher from one or more other matchers:
+
+| Matcher | Description |
+| :------------------------------- | :-------------------------------------- |
+| `AllOf(m1, m2, ..., mn)` | `argument` matches all of the matchers `m1` to `mn`. |
+| `AllOfArray({m0, m1, ..., mn})`, `AllOfArray(a_container)`, `AllOfArray(begin, end)`, `AllOfArray(array)`, or `AllOfArray(array, count)` | The same as `AllOf()` except that the matchers come from an initializer list, STL-style container, iterator range, or C-style array. |
+| `AnyOf(m1, m2, ..., mn)` | `argument` matches at least one of the matchers `m1` to `mn`. |
+| `AnyOfArray({m0, m1, ..., mn})`, `AnyOfArray(a_container)`, `AnyOfArray(begin, end)`, `AnyOfArray(array)`, or `AnyOfArray(array, count)` | The same as `AnyOf()` except that the matchers come from an initializer list, STL-style container, iterator range, or C-style array. |
+| `Not(m)` | `argument` doesn't match matcher `m`. |
+| `Conditional(cond, m1, m2)` | Matches matcher `m1` if `cond` evaluates to true, else matches `m2`. |
+
+## Adapters for Matchers
+
+| Matcher | Description |
+| :---------------------- | :------------------------------------ |
+| `MatcherCast<T>(m)` | casts matcher `m` to type `Matcher<T>`. |
+| `SafeMatcherCast<T>(m)` | [safely casts](../gmock_cook_book.md#SafeMatcherCast) matcher `m` to type `Matcher<T>`. |
+| `Truly(predicate)` | `predicate(argument)` returns something considered by C++ to be true, where `predicate` is a function or functor. |
+
+`AddressSatisfies(callback)` and `Truly(callback)` take ownership of `callback`,
+which must be a permanent callback.
+
+## Using Matchers as Predicates {#MatchersAsPredicatesCheat}
+
+| Matcher | Description |
+| :---------------------------- | :------------------------------------------ |
+| `Matches(m)(value)` | evaluates to `true` if `value` matches `m`. You can use `Matches(m)` alone as a unary functor. |
+| `ExplainMatchResult(m, value, result_listener)` | evaluates to `true` if `value` matches `m`, explaining the result to `result_listener`. |
+| `Value(value, m)` | evaluates to `true` if `value` matches `m`. |
+
+## Defining Matchers
+
+| Macro | Description |
+| :----------------------------------- | :------------------------------------ |
+| `MATCHER(IsEven, "") { return (arg % 2) == 0; }` | Defines a matcher `IsEven()` to match an even number. |
+| `MATCHER_P(IsDivisibleBy, n, "") { *result_listener << "where the remainder is " << (arg % n); return (arg % n) == 0; }` | Defines a matcher `IsDivisibleBy(n)` to match a number divisible by `n`. |
+| `MATCHER_P2(IsBetween, a, b, absl::StrCat(negation ? "isn't" : "is", " between ", PrintToString(a), " and ", PrintToString(b))) { return a <= arg && arg <= b; }` | Defines a matcher `IsBetween(a, b)` to match a value in the range [`a`, `b`]. |
+
+**Notes:**
+
+1.  The `MATCHER*` macros cannot be used inside a function or class.
+2.  The matcher body must be *purely functional* (i.e. it cannot have any side
+    effect, and the result must not depend on anything other than the value
+    being matched and the matcher parameters).
+3.  You can use `PrintToString(x)` to convert a value `x` of any type to a
+    string.
+4.  You can use `ExplainMatchResult()` in a custom matcher to wrap another
+    matcher, for example:
+
+    ```cpp
+    MATCHER_P(NestedPropertyMatches, matcher, "") {
+      return ExplainMatchResult(matcher, arg.nested().property(), result_listener);
+    }
+    ```
+
+5.  You can use `DescribeMatcher<>` to describe another matcher. For example:
+
+    ```cpp
+    MATCHER_P(XAndYThat, matcher,
+              "X that " + DescribeMatcher<int>(matcher, negation) +
+                  (negation ? " or" : " and") + " Y that " +
+                  DescribeMatcher<double>(matcher, negation)) {
+      return ExplainMatchResult(matcher, arg.x(), result_listener) &&
+             ExplainMatchResult(matcher, arg.y(), result_listener);
+    }
+    ```
diff --git a/test/googletest/docs/reference/mocking.md b/test/googletest/docs/reference/mocking.md
new file mode 100644
index 0000000..ab37ebf
--- /dev/null
+++ b/test/googletest/docs/reference/mocking.md
@@ -0,0 +1,588 @@
+# Mocking Reference
+
+This page lists the facilities provided by GoogleTest for creating and working
+with mock objects. To use them, add `#include <gmock/gmock.h>`.
+
+## Macros {#macros}
+
+GoogleTest defines the following macros for working with mocks.
+
+### MOCK_METHOD {#MOCK_METHOD}
+
+`MOCK_METHOD(`*`return_type`*`,`*`method_name`*`, (`*`args...`*`));` \
+`MOCK_METHOD(`*`return_type`*`,`*`method_name`*`, (`*`args...`*`),
+(`*`specs...`*`));`
+
+Defines a mock method *`method_name`* with arguments `(`*`args...`*`)` and
+return type *`return_type`* within a mock class.
+
+The parameters of `MOCK_METHOD` mirror the method declaration. The optional
+fourth parameter *`specs...`* is a comma-separated list of qualifiers. The
+following qualifiers are accepted:
+
+| Qualifier | Meaning |
+| -------------------------- | -------------------------------------------- |
+| `const` | Makes the mocked method a `const` method. Required if overriding a `const` method. |
+| `override` | Marks the method with `override`. Recommended if overriding a `virtual` method. |
+| `noexcept` | Marks the method with `noexcept`. Required if overriding a `noexcept` method. |
+| `Calltype(`*`calltype`*`)` | Sets the call type for the method, for example `Calltype(STDMETHODCALLTYPE)`. Useful on Windows. |
+| `ref(`*`qualifier`*`)` | Marks the method with the given reference qualifier, for example `ref(&)` or `ref(&&)`. Required if overriding a method that has a reference qualifier. |
+
+Note that commas in arguments prevent `MOCK_METHOD` from parsing the arguments
+correctly if they are not appropriately surrounded by parentheses. See the
+following example:
+
+```cpp
+class MyMock {
+ public:
+  // The following 2 lines will not compile due to commas in the arguments:
+  MOCK_METHOD(std::pair<bool, int>, GetPair, ());              // Error!
+  MOCK_METHOD(bool, CheckMap, (std::map<int, double>, bool));  // Error!
+
+  // One solution - wrap arguments that contain commas in parentheses:
+  MOCK_METHOD((std::pair<bool, int>), GetPair, ());
+  MOCK_METHOD(bool, CheckMap, ((std::map<int, double>), bool));
+
+  // Another solution - use type aliases:
+  using BoolAndInt = std::pair<bool, int>;
+  MOCK_METHOD(BoolAndInt, GetPair, ());
+  using MapIntDouble = std::map<int, double>;
+  MOCK_METHOD(bool, CheckMap, (MapIntDouble, bool));
+};
+```
+
+`MOCK_METHOD` must be used in the `public:` section of a mock class definition,
+regardless of whether the method being mocked is `public`, `protected`, or
+`private` in the base class.
+
+### EXPECT_CALL {#EXPECT_CALL}
+
+`EXPECT_CALL(`*`mock_object`*`,`*`method_name`*`(`*`matchers...`*`))`
+
+Creates an [expectation](../gmock_for_dummies.md#setting-expectations) that the
+method *`method_name`* of the object *`mock_object`* is called with arguments
+that match the given matchers *`matchers...`*. `EXPECT_CALL` must precede any
+code that exercises the mock object.
+
+The parameter *`matchers...`* is a comma-separated list of
+[matchers](../gmock_for_dummies.md#matchers-what-arguments-do-we-expect) that
+correspond to each argument of the method *`method_name`*. The expectation will
+apply only to calls of *`method_name`* whose arguments match all of the
+matchers. If `(`*`matchers...`*`)` is omitted, the expectation behaves as if
+each argument's matcher were a [wildcard matcher (`_`)](matchers.md#wildcard).
+See the [Matchers Reference](matchers.md) for a list of all built-in matchers.
+
+The following chainable clauses can be used to modify the expectation, and they
+must be used in the following order:
+
+```cpp
+EXPECT_CALL(mock_object, method_name(matchers...))
+    .With(multi_argument_matcher)  // Can be used at most once
+    .Times(cardinality)            // Can be used at most once
+    .InSequence(sequences...)      // Can be used any number of times
+    .After(expectations...)        // Can be used any number of times
+    .WillOnce(action)              // Can be used any number of times
+    .WillRepeatedly(action)        // Can be used at most once
+    .RetiresOnSaturation();        // Can be used at most once
+```
+
+See details for each modifier clause below.
+
+#### With {#EXPECT_CALL.With}
+
+`.With(`*`multi_argument_matcher`*`)`
+
+Restricts the expectation to apply only to mock function calls whose arguments
+as a whole match the multi-argument matcher *`multi_argument_matcher`*.
+
+GoogleTest passes all of the arguments as one tuple into the matcher. The
+parameter *`multi_argument_matcher`* must thus be a matcher of type
+`Matcher<std::tuple<A1, ..., An>>`, where `A1, ..., An` are the types of the
+function arguments.
+
+For example, the following code sets the expectation that
+`my_mock.SetPosition()` is called with any two arguments, the first argument
+being less than the second:
+
+```cpp
+using ::testing::_;
+using ::testing::Lt;
+...
+EXPECT_CALL(my_mock, SetPosition(_, _))
+    .With(Lt());
+```
+
+GoogleTest provides some built-in matchers for 2-tuples, including the `Lt()`
+matcher above. See [Multi-argument Matchers](matchers.md#MultiArgMatchers).
+
+The `With` clause can be used at most once on an expectation and must be the
+first clause.
+
+#### Times {#EXPECT_CALL.Times}
+
+`.Times(`*`cardinality`*`)`
+
+Specifies how many times the mock function call is expected.
+ +The parameter *`cardinality`* represents the number of expected calls and can be +one of the following, all defined in the `::testing` namespace: + +| Cardinality | Meaning | +| ------------------- | --------------------------------------------------- | +| `AnyNumber()` | The function can be called any number of times. | +| `AtLeast(n)` | The function call is expected at least *n* times. | +| `AtMost(n)` | The function call is expected at most *n* times. | +| `Between(m, n)` | The function call is expected between *m* and *n* times, inclusive. | +| `Exactly(n)` or `n` | The function call is expected exactly *n* times. If *n* is 0, the call should never happen. | + +If the `Times` clause is omitted, GoogleTest infers the cardinality as follows: + +* If neither [`WillOnce`](#EXPECT_CALL.WillOnce) nor + [`WillRepeatedly`](#EXPECT_CALL.WillRepeatedly) are specified, the inferred + cardinality is `Times(1)`. +* If there are *n* `WillOnce` clauses and no `WillRepeatedly` clause, where + *n* >= 1, the inferred cardinality is `Times(n)`. +* If there are *n* `WillOnce` clauses and one `WillRepeatedly` clause, where + *n* >= 0, the inferred cardinality is `Times(AtLeast(n))`. + +The `Times` clause can be used at most once on an expectation. + +#### InSequence {#EXPECT_CALL.InSequence} + +`.InSequence(`*`sequences...`*`)` + +Specifies that the mock function call is expected in a certain sequence. + +The parameter *`sequences...`* is any number of [`Sequence`](#Sequence) objects. +Expected calls assigned to the same sequence are expected to occur in the order +the expectations are declared. + +For example, the following code sets the expectation that the `Reset()` method +of `my_mock` is called before both `GetSize()` and `Describe()`, and `GetSize()` +and `Describe()` can occur in any order relative to each other: + +```cpp +using ::testing::Sequence; +Sequence s1, s2; +... +EXPECT_CALL(my_mock, Reset()) + .InSequence(s1, s2); +EXPECT_CALL(my_mock, GetSize()) + .InSequence(s1); +EXPECT_CALL(my_mock, Describe()) + .InSequence(s2); +``` + +The `InSequence` clause can be used any number of times on an expectation. + +See also the [`InSequence` class](#InSequence). + +#### After {#EXPECT_CALL.After} + +`.After(`*`expectations...`*`)` + +Specifies that the mock function call is expected to occur after one or more +other calls. + +The parameter *`expectations...`* can be up to five +[`Expectation`](#Expectation) or [`ExpectationSet`](#ExpectationSet) objects. +The mock function call is expected to occur after all of the given expectations. + +For example, the following code sets the expectation that the `Describe()` +method of `my_mock` is called only after both `InitX()` and `InitY()` have been +called. + +```cpp +using ::testing::Expectation; +... +Expectation init_x = EXPECT_CALL(my_mock, InitX()); +Expectation init_y = EXPECT_CALL(my_mock, InitY()); +EXPECT_CALL(my_mock, Describe()) + .After(init_x, init_y); +``` + +The `ExpectationSet` object is helpful when the number of prerequisites for an +expectation is large or variable, for example: + +```cpp +using ::testing::ExpectationSet; +... +ExpectationSet all_inits; +// Collect all expectations of InitElement() calls +for (int i = 0; i < element_count; i++) { + all_inits += EXPECT_CALL(my_mock, InitElement(i)); +} +EXPECT_CALL(my_mock, Describe()) + .After(all_inits); // Expect Describe() call after all InitElement() calls +``` + +The `After` clause can be used any number of times on an expectation. 
+ +#### WillOnce {#EXPECT_CALL.WillOnce} + +`.WillOnce(`*`action`*`)` + +Specifies the mock function's actual behavior when invoked, for a single +matching function call. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +The use of `WillOnce` implicitly sets a cardinality on the expectation when +`Times` is not specified. See [`Times`](#EXPECT_CALL.Times). + +Each matching function call will perform the next action in the order declared. +For example, the following code specifies that `my_mock.GetNumber()` is expected +to be called exactly 3 times and will return `1`, `2`, and `3` respectively on +the first, second, and third calls: + +```cpp +using ::testing::Return; +... +EXPECT_CALL(my_mock, GetNumber()) + .WillOnce(Return(1)) + .WillOnce(Return(2)) + .WillOnce(Return(3)); +``` + +The `WillOnce` clause can be used any number of times on an expectation. Unlike +`WillRepeatedly`, the action fed to each `WillOnce` call will be called at most +once, so may be a move-only type and/or have an `&&`-qualified call operator. + +#### WillRepeatedly {#EXPECT_CALL.WillRepeatedly} + +`.WillRepeatedly(`*`action`*`)` + +Specifies the mock function's actual behavior when invoked, for all subsequent +matching function calls. Takes effect after the actions specified in the +[`WillOnce`](#EXPECT_CALL.WillOnce) clauses, if any, have been performed. + +The parameter *`action`* represents the +[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function +call will perform. See the [Actions Reference](actions.md) for a list of +built-in actions. + +The use of `WillRepeatedly` implicitly sets a cardinality on the expectation +when `Times` is not specified. See [`Times`](#EXPECT_CALL.Times). + +If any `WillOnce` clauses have been specified, matching function calls will +perform those actions before the action specified by `WillRepeatedly`. See the +following example: + +```cpp +using ::testing::Return; +... +EXPECT_CALL(my_mock, GetName()) + .WillRepeatedly(Return("John Doe")); // Return "John Doe" on all calls + +EXPECT_CALL(my_mock, GetNumber()) + .WillOnce(Return(42)) // Return 42 on the first call + .WillRepeatedly(Return(7)); // Return 7 on all subsequent calls +``` + +The `WillRepeatedly` clause can be used at most once on an expectation. + +#### RetiresOnSaturation {#EXPECT_CALL.RetiresOnSaturation} + +`.RetiresOnSaturation()` + +Indicates that the expectation will no longer be active after the expected +number of matching function calls has been reached. + +The `RetiresOnSaturation` clause is only meaningful for expectations with an +upper-bounded cardinality. The expectation will *retire* (no longer match any +function calls) after it has been *saturated* (the upper bound has been +reached). See the following example: + +```cpp +using ::testing::_; +using ::testing::AnyNumber; +... +EXPECT_CALL(my_mock, SetNumber(_)) // Expectation 1 + .Times(AnyNumber()); +EXPECT_CALL(my_mock, SetNumber(7)) // Expectation 2 + .Times(2) + .RetiresOnSaturation(); +``` + +In the above example, the first two calls to `my_mock.SetNumber(7)` match +expectation 2, which then becomes inactive and no longer matches any calls. A +third call to `my_mock.SetNumber(7)` would then match expectation 1. 
Without
+`RetiresOnSaturation()` on expectation 2, a third call to `my_mock.SetNumber(7)`
+would match expectation 2 again, producing a failure since the limit of 2 calls
+was exceeded.
+
+The `RetiresOnSaturation` clause can be used at most once on an expectation and
+must be the last clause.
+
+### ON_CALL {#ON_CALL}
+
+`ON_CALL(`*`mock_object`*`,`*`method_name`*`(`*`matchers...`*`))`
+
+Defines what happens when the method *`method_name`* of the object
+*`mock_object`* is called with arguments that match the given matchers
+*`matchers...`*. Requires a modifier clause to specify the method's behavior.
+*Does not* set any expectations that the method will be called.
+
+The parameter *`matchers...`* is a comma-separated list of
+[matchers](../gmock_for_dummies.md#matchers-what-arguments-do-we-expect) that
+correspond to each argument of the method *`method_name`*. The `ON_CALL`
+specification will apply only to calls of *`method_name`* whose arguments match
+all of the matchers. If `(`*`matchers...`*`)` is omitted, the behavior is as if
+each argument's matcher were a [wildcard matcher (`_`)](matchers.md#wildcard).
+See the [Matchers Reference](matchers.md) for a list of all built-in matchers.
+
+The following chainable clauses can be used to set the method's behavior, and
+they must be used in the following order:
+
+```cpp
+ON_CALL(mock_object, method_name(matchers...))
+    .With(multi_argument_matcher)  // Can be used at most once
+    .WillByDefault(action);        // Required
+```
+
+See details for each modifier clause below.
+
+#### With {#ON_CALL.With}
+
+`.With(`*`multi_argument_matcher`*`)`
+
+Restricts the specification to only mock function calls whose arguments as a
+whole match the multi-argument matcher *`multi_argument_matcher`*.
+
+GoogleTest passes all of the arguments as one tuple into the matcher. The
+parameter *`multi_argument_matcher`* must thus be a matcher of type
+`Matcher<std::tuple<A1, ..., An>>`, where `A1, ..., An` are the types of the
+function arguments.
+
+For example, the following code sets the default behavior when
+`my_mock.SetPosition()` is called with any two arguments, the first argument
+being less than the second:
+
+```cpp
+using ::testing::_;
+using ::testing::Lt;
+using ::testing::Return;
+...
+ON_CALL(my_mock, SetPosition(_, _))
+    .With(Lt())
+    .WillByDefault(Return(true));
+```
+
+GoogleTest provides some built-in matchers for 2-tuples, including the `Lt()`
+matcher above. See [Multi-argument Matchers](matchers.md#MultiArgMatchers).
+
+The `With` clause can be used at most once with each `ON_CALL` statement.
+
+#### WillByDefault {#ON_CALL.WillByDefault}
+
+`.WillByDefault(`*`action`*`)`
+
+Specifies the default behavior of a matching mock function call.
+
+The parameter *`action`* represents the
+[action](../gmock_for_dummies.md#actions-what-should-it-do) that the function
+call will perform. See the [Actions Reference](actions.md) for a list of
+built-in actions.
+
+For example, the following code specifies that by default, a call to
+`my_mock.Greet()` will return `"hello"`:
+
+```cpp
+using ::testing::Return;
+...
+ON_CALL(my_mock, Greet())
+    .WillByDefault(Return("hello"));
+```
+
+The action specified by `WillByDefault` is superseded by the actions specified
+on a matching `EXPECT_CALL` statement, if any. See the
+[`WillOnce`](#EXPECT_CALL.WillOnce) and
+[`WillRepeatedly`](#EXPECT_CALL.WillRepeatedly) clauses of `EXPECT_CALL`.
+
+The `WillByDefault` clause must be used exactly once with each `ON_CALL`
+statement.
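+
+To make the interplay between `ON_CALL` and `EXPECT_CALL` concrete, here is a
+minimal sketch in the style of the examples above (`my_mock` and its
+`GetNumber()` method are placeholders):
+
+```cpp
+using ::testing::AnyNumber;
+using ::testing::Return;
+...
+ON_CALL(my_mock, GetNumber())
+    .WillByDefault(Return(0));  // default action
+EXPECT_CALL(my_mock, GetNumber())
+    .Times(AnyNumber())
+    .WillOnce(Return(42));      // used for the first call only
+// The first call returns 42; once the WillOnce action is used up, further
+// calls perform the default action from ON_CALL and return 0.
+```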
+
+## Classes {#classes}
+
+GoogleTest defines the following classes for working with mocks.
+
+### DefaultValue {#DefaultValue}
+
+`::testing::DefaultValue<T>`
+
+Allows a user to specify the default value for a type `T` that is both copyable
+and publicly destructible (i.e. anything that can be used as a function return
+type). For mock functions with a return type of `T`, this default value is
+returned from function calls that do not specify an action.
+
+Provides the static methods `Set()`, `SetFactory()`, and `Clear()` to manage the
+default value:
+
+```cpp
+// Sets the default value to be returned. T must be copy constructible.
+DefaultValue<T>::Set(value);
+
+// Sets a factory. Will be invoked on demand. T must be move constructible.
+T MakeT();
+DefaultValue<T>::SetFactory(&MakeT);
+
+// Unsets the default value.
+DefaultValue<T>::Clear();
+```
+
+### NiceMock {#NiceMock}
+
+`::testing::NiceMock<T>`
+
+Represents a mock object that suppresses warnings on
+[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The
+template parameter `T` is any mock class, except for another `NiceMock`,
+`NaggyMock`, or `StrictMock`.
+
+Usage of `NiceMock<T>` is analogous to usage of `T`. `NiceMock<T>` is a
+subclass of `T`, so it can be used wherever an object of type `T` is accepted.
+In addition, `NiceMock<T>` can be constructed with any arguments that a
+constructor of `T` accepts.
+
+For example, the following code suppresses warnings on the mock `my_mock` of
+type `MockClass` if a method other than `DoSomething()` is called:
+
+```cpp
+using ::testing::NiceMock;
+...
+NiceMock<MockClass> my_mock("some", "args");
+EXPECT_CALL(my_mock, DoSomething());
+... code that uses my_mock ...
+```
+
+`NiceMock<T>` only works for mock methods defined using the `MOCK_METHOD` macro
+directly in the definition of class `T`. If a mock method is defined in a base
+class of `T`, a warning might still be generated.
+
+`NiceMock<T>` might not work correctly if the destructor of `T` is not virtual.
+
+### NaggyMock {#NaggyMock}
+
+`::testing::NaggyMock<T>`
+
+Represents a mock object that generates warnings on
+[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The
+template parameter `T` is any mock class, except for another `NiceMock`,
+`NaggyMock`, or `StrictMock`.
+
+Usage of `NaggyMock<T>` is analogous to usage of `T`. `NaggyMock<T>` is a
+subclass of `T`, so it can be used wherever an object of type `T` is accepted.
+In addition, `NaggyMock<T>` can be constructed with any arguments that a
+constructor of `T` accepts.
+
+For example, the following code generates warnings on the mock `my_mock` of type
+`MockClass` if a method other than `DoSomething()` is called:
+
+```cpp
+using ::testing::NaggyMock;
+...
+NaggyMock<MockClass> my_mock("some", "args");
+EXPECT_CALL(my_mock, DoSomething());
+... code that uses my_mock ...
+```
+
+Mock objects of type `T` by default behave the same way as `NaggyMock<T>`.
+
+### StrictMock {#StrictMock}
+
+`::testing::StrictMock<T>`
+
+Represents a mock object that generates test failures on
+[uninteresting calls](../gmock_cook_book.md#uninteresting-vs-unexpected). The
+template parameter `T` is any mock class, except for another `NiceMock`,
+`NaggyMock`, or `StrictMock`.
+
+Usage of `StrictMock<T>` is analogous to usage of `T`. `StrictMock<T>` is a
+subclass of `T`, so it can be used wherever an object of type `T` is accepted.
+In addition, `StrictMock<T>` can be constructed with any arguments that a
+constructor of `T` accepts.
+
+For example, the following code generates a test failure on the mock `my_mock`
+of type `MockClass` if a method other than `DoSomething()` is called:
+
+```cpp
+using ::testing::StrictMock;
+...
+StrictMock<MockClass> my_mock("some", "args");
+EXPECT_CALL(my_mock, DoSomething());
+... code that uses my_mock ...
+```
+
+`StrictMock<T>` only works for mock methods defined using the `MOCK_METHOD`
+macro directly in the definition of class `T`. If a mock method is defined in a
+base class of `T`, a failure might not be generated.
+
+`StrictMock<T>` might not work correctly if the destructor of `T` is not
+virtual.
+
+### Sequence {#Sequence}
+
+`::testing::Sequence`
+
+Represents a chronological sequence of expectations. See the
+[`InSequence`](#EXPECT_CALL.InSequence) clause of `EXPECT_CALL` for usage.
+
+### InSequence {#InSequence}
+
+`::testing::InSequence`
+
+An object of this type causes all expectations encountered in its scope to be
+put in an anonymous sequence.
+
+This allows more convenient expression of multiple expectations in a single
+sequence:
+
+```cpp
+using ::testing::InSequence;
+{
+  InSequence seq;
+
+  // The following are expected to occur in the order declared.
+  EXPECT_CALL(...);
+  EXPECT_CALL(...);
+  ...
+  EXPECT_CALL(...);
+}
+```
+
+The name of the `InSequence` object does not matter.
+
+### Expectation {#Expectation}
+
+`::testing::Expectation`
+
+Represents a mock function call expectation as created by
+[`EXPECT_CALL`](#EXPECT_CALL):
+
+```cpp
+using ::testing::Expectation;
+Expectation my_expectation = EXPECT_CALL(...);
+```
+
+Useful for specifying sequences of expectations; see the
+[`After`](#EXPECT_CALL.After) clause of `EXPECT_CALL`.
+
+### ExpectationSet {#ExpectationSet}
+
+`::testing::ExpectationSet`
+
+Represents a set of mock function call expectations.
+
+Use the `+=` operator to add [`Expectation`](#Expectation) objects to the set:
+
+```cpp
+using ::testing::ExpectationSet;
+ExpectationSet my_expectations;
+my_expectations += EXPECT_CALL(...);
+```
+
+Useful for specifying sequences of expectations; see the
+[`After`](#EXPECT_CALL.After) clause of `EXPECT_CALL`.
diff --git a/test/googletest/docs/reference/testing.md b/test/googletest/docs/reference/testing.md
new file mode 100644
index 0000000..3ed5211
--- /dev/null
+++ b/test/googletest/docs/reference/testing.md
@@ -0,0 +1,1453 @@
+# Testing Reference
+
+This page lists the facilities provided by GoogleTest for writing test programs.
+To use them, add `#include <gtest/gtest.h>`.
+
+## Macros
+
+GoogleTest defines the following macros for writing tests.
+
+### TEST {#TEST}
+
+<pre>
+TEST(<em>TestSuiteName</em>, <em>TestName</em>) {
+  ... statements ...
+}
+</pre>
+ +Defines an individual test named *`TestName`* in the test suite +*`TestSuiteName`*, consisting of the given statements. + +Both arguments *`TestSuiteName`* and *`TestName`* must be valid C++ identifiers +and must not contain underscores (`_`). Tests in different test suites can have +the same individual name. + +The statements within the test body can be any code under test. +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +### TEST_F {#TEST_F} + +
+<pre>
+TEST_F(<em>TestFixtureName</em>, <em>TestName</em>) {
+  ... statements ...
+}
+</pre>
+ +Defines an individual test named *`TestName`* that uses the test fixture class +*`TestFixtureName`*. The test suite name is *`TestFixtureName`*. + +Both arguments *`TestFixtureName`* and *`TestName`* must be valid C++ +identifiers and must not contain underscores (`_`). *`TestFixtureName`* must be +the name of a test fixture class—see +[Test Fixtures](../primer.md#same-data-multiple-tests). + +The statements within the test body can be any code under test. +[Assertions](assertions.md) used within the test body determine the outcome of +the test. + +### TEST_P {#TEST_P} + +
+<pre>
+TEST_P(<em>TestFixtureName</em>, <em>TestName</em>) {
+  ... statements ...
+}
+</pre>
+
+Defines an individual value-parameterized test named *`TestName`* that uses the
+test fixture class *`TestFixtureName`*. The test suite name is
+*`TestFixtureName`*.
+
+Both arguments *`TestFixtureName`* and *`TestName`* must be valid C++
+identifiers and must not contain underscores (`_`). *`TestFixtureName`* must be
+the name of a value-parameterized test fixture class—see
+[Value-Parameterized Tests](../advanced.md#value-parameterized-tests).
+
+The statements within the test body can be any code under test. Within the test
+body, the test parameter can be accessed with the `GetParam()` function (see
+[`WithParamInterface`](#WithParamInterface)). For example:
+
+```cpp
+TEST_P(MyTestSuite, DoesSomething) {
+  ...
+  EXPECT_TRUE(DoSomething(GetParam()));
+  ...
+}
+```
+
+[Assertions](assertions.md) used within the test body determine the outcome of
+the test.
+
+See also [`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P).
+
+### INSTANTIATE_TEST_SUITE_P {#INSTANTIATE_TEST_SUITE_P}
+
+`INSTANTIATE_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`param_generator`*`)` \
+`INSTANTIATE_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`param_generator`*`,`*`name_generator`*`)`
+
+Instantiates the value-parameterized test suite *`TestSuiteName`* (defined with
+[`TEST_P`](#TEST_P)).
+
+The argument *`InstantiationName`* is a unique name for the instantiation of the
+test suite, to distinguish between multiple instantiations. In test output, the
+instantiation name is added as a prefix to the test suite name
+*`TestSuiteName`*. If *`InstantiationName`* is empty
+(`INSTANTIATE_TEST_SUITE_P(, ...)`), no prefix is added.
+
+The argument *`param_generator`* is one of the following GoogleTest-provided
+functions that generate the test parameters, all defined in the `::testing`
+namespace:
+
+| Parameter Generator | Behavior |
+| ------------------- | ---------------------------------------------------- |
+| `Range(begin, end [, step])` | Yields values `{begin, begin+step, begin+step+step, ...}`. The values do not include `end`. `step` defaults to 1. |
+| `Values(v1, v2, ..., vN)` | Yields values `{v1, v2, ..., vN}`. |
+| `ValuesIn(container)` or `ValuesIn(begin,end)` | Yields values from a C-style array, an STL-style container, or an iterator range `[begin, end)`. |
+| `Bool()` | Yields sequence `{false, true}`. |
+| `Combine(g1, g2, ..., gN)` | Yields as `std::tuple` *n*-tuples all combinations (Cartesian product) of the values generated by the given *n* generators `g1`, `g2`, ..., `gN`. |
+| `ConvertGenerator<T>(g)` | Yields values generated by generator `g`, `static_cast` to `T`. |
+
+The optional last argument *`name_generator`* is a function or functor that
+generates custom test name suffixes based on the test parameters. The function
+must accept an argument of type
+[`TestParamInfo<class ParamType>`](#TestParamInfo) and return a `std::string`.
+The test name suffix can only contain alphanumeric characters and underscores.
+GoogleTest provides [`PrintToStringParamName`](#PrintToStringParamName), or a
+custom function can be used for more control:
+
+```cpp
+INSTANTIATE_TEST_SUITE_P(
+    MyInstantiation, MyTestSuite,
+    testing::Values(...),
+    [](const testing::TestParamInfo<MyTestSuite::ParamType>& info) {
+      // Can use info.param here to generate the test suffix
+      std::string name = ...
+      return name;
+    });
+```
+
+For more information, see
+[Value-Parameterized Tests](../advanced.md#value-parameterized-tests).
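+
+As a complete, hypothetical end-to-end sketch (the `IsEvenTest` suite and its
+values below are illustrative, not part of the API), the following defines a
+small value-parameterized suite and instantiates it with three values:
+
+```cpp
+// Fixture for the parameterized tests; the parameter type is int.
+class IsEvenTest : public testing::TestWithParam<int> {};
+
+TEST_P(IsEvenTest, AcceptsEvenValue) {
+  EXPECT_EQ(GetParam() % 2, 0);
+}
+
+// Runs the tests as EvenValues/IsEvenTest.AcceptsEvenValue/0, .../1, .../2.
+INSTANTIATE_TEST_SUITE_P(EvenValues, IsEvenTest, testing::Values(0, 2, 4));
+```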
+
+See also
+[`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST`](#GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST).
+
+### TYPED_TEST_SUITE {#TYPED_TEST_SUITE}
+
+`TYPED_TEST_SUITE(`*`TestFixtureName`*`,`*`Types`*`)` \
+`TYPED_TEST_SUITE(`*`TestFixtureName`*`,`*`Types`*`,`*`NameGenerator`*`)`
+
+Defines a typed test suite based on the test fixture *`TestFixtureName`*. The
+test suite name is *`TestFixtureName`*.
+
+The argument *`TestFixtureName`* is a fixture class template, parameterized by a
+type, for example:
+
+```cpp
+template <typename T>
+class MyFixture : public testing::Test {
+ public:
+  ...
+  using List = std::list<T>;
+  static T shared_;
+  T value_;
+};
+```
+
+The argument *`Types`* is a [`Types`](#Types) object representing the list of
+types to run the tests on, for example:
+
+```cpp
+using MyTypes = ::testing::Types<char, int, unsigned int>;
+TYPED_TEST_SUITE(MyFixture, MyTypes);
+```
+
+The type alias (`using` or `typedef`) is necessary for the `TYPED_TEST_SUITE`
+macro to parse correctly.
+
+The optional third argument *`NameGenerator`* allows specifying a class that
+exposes a templated static function `GetName(int)`. For example:
+
+```cpp
+class NameGenerator {
+ public:
+  template <typename T>
+  static std::string GetName(int) {
+    if constexpr (std::is_same_v<T, char>) return "char";
+    if constexpr (std::is_same_v<T, int>) return "int";
+    if constexpr (std::is_same_v<T, unsigned int>) return "unsignedInt";
+  }
+};
+TYPED_TEST_SUITE(MyFixture, MyTypes, NameGenerator);
+```
+
+See also [`TYPED_TEST`](#TYPED_TEST) and
+[Typed Tests](../advanced.md#typed-tests) for more information.
+
+### TYPED_TEST {#TYPED_TEST}
+
+<pre>
+TYPED_TEST(<em>TestSuiteName</em>, <em>TestName</em>) {
+  ... statements ...
+}
+</pre>
+ +Defines an individual typed test named *`TestName`* in the typed test suite +*`TestSuiteName`*. The test suite must be defined with +[`TYPED_TEST_SUITE`](#TYPED_TEST_SUITE). + +Within the test body, the special name `TypeParam` refers to the type parameter, +and `TestFixture` refers to the fixture class. See the following example: + +```cpp +TYPED_TEST(MyFixture, Example) { + // Inside a test, refer to the special name TypeParam to get the type + // parameter. Since we are inside a derived class template, C++ requires + // us to visit the members of MyFixture via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the 'TestFixture::' + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the 'typename TestFixture::' + // prefix. The 'typename' is required to satisfy the compiler. + typename TestFixture::List values; + + values.push_back(n); + ... +} +``` + +For more information, see [Typed Tests](../advanced.md#typed-tests). + +### TYPED_TEST_SUITE_P {#TYPED_TEST_SUITE_P} + +`TYPED_TEST_SUITE_P(`*`TestFixtureName`*`)` + +Defines a type-parameterized test suite based on the test fixture +*`TestFixtureName`*. The test suite name is *`TestFixtureName`*. + +The argument *`TestFixtureName`* is a fixture class template, parameterized by a +type. See [`TYPED_TEST_SUITE`](#TYPED_TEST_SUITE) for an example. + +See also [`TYPED_TEST_P`](#TYPED_TEST_P) and +[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more +information. + +### TYPED_TEST_P {#TYPED_TEST_P} + +
+<pre>
+TYPED_TEST_P(<em>TestSuiteName</em>, <em>TestName</em>) {
+  ... statements ...
+}
+</pre>
+
+Defines an individual type-parameterized test named *`TestName`* in the
+type-parameterized test suite *`TestSuiteName`*. The test suite must be defined
+with [`TYPED_TEST_SUITE_P`](#TYPED_TEST_SUITE_P).
+
+Within the test body, the special name `TypeParam` refers to the type parameter,
+and `TestFixture` refers to the fixture class. See [`TYPED_TEST`](#TYPED_TEST)
+for an example.
+
+See also [`REGISTER_TYPED_TEST_SUITE_P`](#REGISTER_TYPED_TEST_SUITE_P) and
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more
+information.
+
+### REGISTER_TYPED_TEST_SUITE_P {#REGISTER_TYPED_TEST_SUITE_P}
+
+`REGISTER_TYPED_TEST_SUITE_P(`*`TestSuiteName`*`,`*`TestNames...`*`)`
+
+Registers the type-parameterized tests *`TestNames...`* of the test suite
+*`TestSuiteName`*. The test suite and tests must be defined with
+[`TYPED_TEST_SUITE_P`](#TYPED_TEST_SUITE_P) and [`TYPED_TEST_P`](#TYPED_TEST_P).
+
+For example:
+
+```cpp
+// Define the test suite and tests.
+TYPED_TEST_SUITE_P(MyFixture);
+TYPED_TEST_P(MyFixture, HasPropertyA) { ... }
+TYPED_TEST_P(MyFixture, HasPropertyB) { ... }
+
+// Register the tests in the test suite.
+REGISTER_TYPED_TEST_SUITE_P(MyFixture, HasPropertyA, HasPropertyB);
+```
+
+See also [`INSTANTIATE_TYPED_TEST_SUITE_P`](#INSTANTIATE_TYPED_TEST_SUITE_P) and
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more
+information.
+
+### INSTANTIATE_TYPED_TEST_SUITE_P {#INSTANTIATE_TYPED_TEST_SUITE_P}
+
+`INSTANTIATE_TYPED_TEST_SUITE_P(`*`InstantiationName`*`,`*`TestSuiteName`*`,`*`Types`*`)`
+
+Instantiates the type-parameterized test suite *`TestSuiteName`*. The test suite
+must be registered with
+[`REGISTER_TYPED_TEST_SUITE_P`](#REGISTER_TYPED_TEST_SUITE_P).
+
+The argument *`InstantiationName`* is a unique name for the instantiation of the
+test suite, to distinguish between multiple instantiations. In test output, the
+instantiation name is added as a prefix to the test suite name
+*`TestSuiteName`*. If *`InstantiationName`* is empty
+(`INSTANTIATE_TYPED_TEST_SUITE_P(, ...)`), no prefix is added.
+
+The argument *`Types`* is a [`Types`](#Types) object representing the list of
+types to run the tests on, for example:
+
+```cpp
+using MyTypes = ::testing::Types<char, int, unsigned int>;
+INSTANTIATE_TYPED_TEST_SUITE_P(MyInstantiation, MyFixture, MyTypes);
+```
+
+The type alias (`using` or `typedef`) is necessary for the
+`INSTANTIATE_TYPED_TEST_SUITE_P` macro to parse correctly.
+
+For more information, see
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests).
+
+### FRIEND_TEST {#FRIEND_TEST}
+
+`FRIEND_TEST(`*`TestSuiteName`*`,`*`TestName`*`)`
+
+Within a class body, declares an individual test as a friend of the class,
+enabling the test to access private class members.
+
+If the class is defined in a namespace, then in order to be friends of the
+class, test fixtures and tests must be defined in the exact same namespace,
+without inline or anonymous namespaces.
+
+For example, if the class definition looks like the following:
+
+```cpp
+namespace my_namespace {
+
+class MyClass {
+  friend class MyClassTest;
+  FRIEND_TEST(MyClassTest, HasPropertyA);
+  FRIEND_TEST(MyClassTest, HasPropertyB);
+  ... definition of class MyClass ...
+};
+
+} // namespace my_namespace
+```
+
+Then the test code should look like:
+
+```cpp
+namespace my_namespace {
+
+class MyClassTest : public testing::Test {
+  ...
+};
+
+TEST_F(MyClassTest, HasPropertyA) { ... }
+TEST_F(MyClassTest, HasPropertyB) { ...
} + +} // namespace my_namespace +``` + +See [Testing Private Code](../advanced.md#testing-private-code) for more +information. + +### SCOPED_TRACE {#SCOPED_TRACE} + +`SCOPED_TRACE(`*`message`*`)` + +Causes the current file name, line number, and the given message *`message`* to +be added to the failure message for each assertion failure that occurs in the +scope. + +For more information, see +[Adding Traces to Assertions](../advanced.md#adding-traces-to-assertions). + +See also the [`ScopedTrace` class](#ScopedTrace). + +### GTEST_SKIP {#GTEST_SKIP} + +`GTEST_SKIP()` + +Prevents further test execution at runtime. + +Can be used in individual test cases or in the `SetUp()` methods of test +environments or test fixtures (classes derived from the +[`Environment`](#Environment) or [`Test`](#Test) classes). If used in a global +test environment `SetUp()` method, it skips all tests in the test program. If +used in a test fixture `SetUp()` method, it skips all tests in the corresponding +test suite. + +Similar to assertions, `GTEST_SKIP` allows streaming a custom message into it. + +See [Skipping Test Execution](../advanced.md#skipping-test-execution) for more +information. + +### GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST {#GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST} + +`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(`*`TestSuiteName`*`)` + +Allows the value-parameterized test suite *`TestSuiteName`* to be +uninstantiated. + +By default, every [`TEST_P`](#TEST_P) call without a corresponding +[`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P) call causes a failing +test in the test suite `GoogleTestVerification`. +`GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST` suppresses this failure for the +given test suite. + +## Classes and types + +GoogleTest defines the following classes and types to help with writing tests. + +### AssertionResult {#AssertionResult} + +`testing::AssertionResult` + +A class for indicating whether an assertion was successful. + +When the assertion wasn't successful, the `AssertionResult` object stores a +non-empty failure message that can be retrieved with the object's `message()` +method. + +To create an instance of this class, use one of the factory functions +[`AssertionSuccess()`](#AssertionSuccess) or +[`AssertionFailure()`](#AssertionFailure). + +### AssertionException {#AssertionException} + +`testing::AssertionException` + +Exception which can be thrown from +[`TestEventListener::OnTestPartResult`](#TestEventListener::OnTestPartResult). + +### EmptyTestEventListener {#EmptyTestEventListener} + +`testing::EmptyTestEventListener` + +Provides an empty implementation of all methods in the +[`TestEventListener`](#TestEventListener) interface, such that a subclass only +needs to override the methods it cares about. + +### Environment {#Environment} + +`testing::Environment` + +Represents a global test environment. See +[Global Set-Up and Tear-Down](../advanced.md#global-set-up-and-tear-down). + +#### Protected Methods {#Environment-protected} + +##### SetUp {#Environment::SetUp} + +`virtual void Environment::SetUp()` + +Override this to define how to set up the environment. + +##### TearDown {#Environment::TearDown} + +`virtual void Environment::TearDown()` + +Override this to define how to tear down the environment. + +### ScopedTrace {#ScopedTrace} + +`testing::ScopedTrace` + +An instance of this class causes a trace to be included in every test failure +message generated by code in the scope of the lifetime of the `ScopedTrace` +instance. 
The effect is undone with the destruction of the instance.
+
+The `ScopedTrace` constructor has the following form:
+
+```cpp
+template <typename T>
+ScopedTrace(const char* file, int line, const T& message)
+```
+
+Example usage:
+
+```cpp
+testing::ScopedTrace trace("file.cc", 123, "message");
+```
+
+The resulting trace includes the given source file path and line number, and the
+given message. The `message` argument can be anything streamable to
+`std::ostream`.
+
+See also [`SCOPED_TRACE`](#SCOPED_TRACE).
+
+### Test {#Test}
+
+`testing::Test`
+
+The abstract class that all tests inherit from. `Test` is not copyable.
+
+#### Public Methods {#Test-public}
+
+##### SetUpTestSuite {#Test::SetUpTestSuite}
+
+`static void Test::SetUpTestSuite()`
+
+Performs shared setup for all tests in the test suite. GoogleTest calls
+`SetUpTestSuite()` before running the first test in the test suite.
+
+##### TearDownTestSuite {#Test::TearDownTestSuite}
+
+`static void Test::TearDownTestSuite()`
+
+Performs shared teardown for all tests in the test suite. GoogleTest calls
+`TearDownTestSuite()` after running the last test in the test suite.
+
+##### HasFatalFailure {#Test::HasFatalFailure}
+
+`static bool Test::HasFatalFailure()`
+
+Returns true if and only if the current test has a fatal failure.
+
+##### HasNonfatalFailure {#Test::HasNonfatalFailure}
+
+`static bool Test::HasNonfatalFailure()`
+
+Returns true if and only if the current test has a nonfatal failure.
+
+##### HasFailure {#Test::HasFailure}
+
+`static bool Test::HasFailure()`
+
+Returns true if and only if the current test has any failure, either fatal or
+nonfatal.
+
+##### IsSkipped {#Test::IsSkipped}
+
+`static bool Test::IsSkipped()`
+
+Returns true if and only if the current test was skipped.
+
+##### RecordProperty {#Test::RecordProperty}
+
+`static void Test::RecordProperty(const std::string& key, const std::string&
+value)` \
+`static void Test::RecordProperty(const std::string& key, int value)`
+
+Logs a property for the current test, test suite, or entire invocation of the
+test program. Only the last value for a given key is logged.
+
+The key must be a valid XML attribute name, and cannot conflict with the ones
+already used by GoogleTest (`name`, `file`, `line`, `status`, `time`,
+`classname`, `type_param`, and `value_param`).
+
+`RecordProperty` is `public static` so it can be called from utility functions
+that are not members of the test fixture.
+
+Calls to `RecordProperty` made during the lifespan of the test (from the moment
+its constructor starts to the moment its destructor finishes) are output in XML
+as attributes of the `<testcase>` element. Properties recorded from a fixture's
+`SetUpTestSuite` or `TearDownTestSuite` methods are logged as attributes of the
+corresponding `<testsuite>` element. Calls to `RecordProperty` made in the
+global context (before or after invocation of `RUN_ALL_TESTS` or from the
+`SetUp`/`TearDown` methods of registered `Environment` objects) are output as
+attributes of the `<testsuites>` element.
+
+#### Protected Methods {#Test-protected}
+
+##### SetUp {#Test::SetUp}
+
+`virtual void Test::SetUp()`
+
+Override this to perform test fixture setup. GoogleTest calls `SetUp()` before
+running each individual test.
+
+##### TearDown {#Test::TearDown}
+
+`virtual void Test::TearDown()`
+
+Override this to perform test fixture teardown. GoogleTest calls `TearDown()`
+after running each individual test.
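+
+As a minimal, hypothetical sketch (the `QueueTest` fixture below is
+illustrative, not part of the API), overriding both hooks looks like this:
+
+```cpp
+class QueueTest : public testing::Test {
+ protected:
+  void SetUp() override { q_.push_back(1); }  // Runs before each test body.
+  void TearDown() override { q_.clear(); }    // Runs after each test body.
+  std::vector<int> q_;
+};
+
+TEST_F(QueueTest, StartsWithOneElement) {
+  EXPECT_EQ(q_.size(), 1u);
+}
+```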
+
+### TestWithParam {#TestWithParam}
+
+`testing::TestWithParam<T>`
+
+A convenience class which inherits from both [`Test`](#Test) and
+[`WithParamInterface<T>`](#WithParamInterface).
+
+### TestSuite {#TestSuite}
+
+`testing::TestSuite`
+
+Represents a test suite. `TestSuite` is not copyable.
+
+#### Public Methods {#TestSuite-public}
+
+##### name {#TestSuite::name}
+
+`const char* TestSuite::name() const`
+
+Gets the name of the test suite.
+
+##### type_param {#TestSuite::type_param}
+
+`const char* TestSuite::type_param() const`
+
+Returns the name of the parameter type, or `NULL` if this is not a typed or
+type-parameterized test suite. See [Typed Tests](../advanced.md#typed-tests) and
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests).
+
+##### should_run {#TestSuite::should_run}
+
+`bool TestSuite::should_run() const`
+
+Returns true if any test in this test suite should run.
+
+##### successful_test_count {#TestSuite::successful_test_count}
+
+`int TestSuite::successful_test_count() const`
+
+Gets the number of successful tests in this test suite.
+
+##### skipped_test_count {#TestSuite::skipped_test_count}
+
+`int TestSuite::skipped_test_count() const`
+
+Gets the number of skipped tests in this test suite.
+
+##### failed_test_count {#TestSuite::failed_test_count}
+
+`int TestSuite::failed_test_count() const`
+
+Gets the number of failed tests in this test suite.
+
+##### reportable_disabled_test_count {#TestSuite::reportable_disabled_test_count}
+
+`int TestSuite::reportable_disabled_test_count() const`
+
+Gets the number of disabled tests that will be reported in the XML report.
+
+##### disabled_test_count {#TestSuite::disabled_test_count}
+
+`int TestSuite::disabled_test_count() const`
+
+Gets the number of disabled tests in this test suite.
+
+##### reportable_test_count {#TestSuite::reportable_test_count}
+
+`int TestSuite::reportable_test_count() const`
+
+Gets the number of tests to be printed in the XML report.
+
+##### test_to_run_count {#TestSuite::test_to_run_count}
+
+`int TestSuite::test_to_run_count() const`
+
+Gets the number of tests in this test suite that should run.
+
+##### total_test_count {#TestSuite::total_test_count}
+
+`int TestSuite::total_test_count() const`
+
+Gets the number of all tests in this test suite.
+
+##### Passed {#TestSuite::Passed}
+
+`bool TestSuite::Passed() const`
+
+Returns true if and only if the test suite passed.
+
+##### Failed {#TestSuite::Failed}
+
+`bool TestSuite::Failed() const`
+
+Returns true if and only if the test suite failed.
+
+##### elapsed_time {#TestSuite::elapsed_time}
+
+`TimeInMillis TestSuite::elapsed_time() const`
+
+Returns the elapsed time, in milliseconds.
+
+##### start_timestamp {#TestSuite::start_timestamp}
+
+`TimeInMillis TestSuite::start_timestamp() const`
+
+Gets the time of the test suite start, in ms from the start of the UNIX epoch.
+
+##### GetTestInfo {#TestSuite::GetTestInfo}
+
+`const TestInfo* TestSuite::GetTestInfo(int i) const`
+
+Returns the [`TestInfo`](#TestInfo) for the `i`-th test among all the tests. `i`
+can range from 0 to `total_test_count() - 1`. If `i` is not in that range,
+returns `NULL`.
+
+##### ad_hoc_test_result {#TestSuite::ad_hoc_test_result}
+
+`const TestResult& TestSuite::ad_hoc_test_result() const`
+
+Returns the [`TestResult`](#TestResult) that holds test properties recorded
+during execution of `SetUpTestSuite` and `TearDownTestSuite`.
+
+### TestInfo {#TestInfo}
+
+`testing::TestInfo`
+
+Stores information about a test.
+
+#### Public Methods {#TestInfo-public}
+
+##### test_suite_name {#TestInfo::test_suite_name}
+
+`const char* TestInfo::test_suite_name() const`
+
+Returns the test suite name.
+
+##### name {#TestInfo::name}
+
+`const char* TestInfo::name() const`
+
+Returns the test name.
+
+##### type_param {#TestInfo::type_param}
+
+`const char* TestInfo::type_param() const`
+
+Returns the name of the parameter type, or `NULL` if this is not a typed or
+type-parameterized test. See [Typed Tests](../advanced.md#typed-tests) and
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests).
+
+##### value_param {#TestInfo::value_param}
+
+`const char* TestInfo::value_param() const`
+
+Returns the text representation of the value parameter, or `NULL` if this is not
+a value-parameterized test. See
+[Value-Parameterized Tests](../advanced.md#value-parameterized-tests).
+
+##### file {#TestInfo::file}
+
+`const char* TestInfo::file() const`
+
+Returns the file name where this test is defined.
+
+##### line {#TestInfo::line}
+
+`int TestInfo::line() const`
+
+Returns the line where this test is defined.
+
+##### is_in_another_shard {#TestInfo::is_in_another_shard}
+
+`bool TestInfo::is_in_another_shard() const`
+
+Returns true if this test should not be run because it's in another shard.
+
+##### should_run {#TestInfo::should_run}
+
+`bool TestInfo::should_run() const`
+
+Returns true if this test should run, that is if the test is not disabled (or it
+is disabled but the `also_run_disabled_tests` flag has been specified) and its
+full name matches the user-specified filter.
+
+GoogleTest allows the user to filter the tests by their full names. Only the
+tests that match the filter will run. See
+[Running a Subset of the Tests](../advanced.md#running-a-subset-of-the-tests)
+for more information.
+
+##### is_reportable {#TestInfo::is_reportable}
+
+`bool TestInfo::is_reportable() const`
+
+Returns true if and only if this test will appear in the XML report.
+
+##### result {#TestInfo::result}
+
+`const TestResult* TestInfo::result() const`
+
+Returns the result of the test. See [`TestResult`](#TestResult).
+
+### TestParamInfo {#TestParamInfo}
+
+`testing::TestParamInfo<T>`
+
+Describes a parameter to a value-parameterized test. The type `T` is the type of
+the parameter.
+
+Contains the fields `param` and `index` which hold the value of the parameter
+and its integer index respectively.
+
+### UnitTest {#UnitTest}
+
+`testing::UnitTest`
+
+This class contains information about the test program.
+
+`UnitTest` is a singleton class. The only instance is created when
+`UnitTest::GetInstance()` is first called. This instance is never deleted.
+
+`UnitTest` is not copyable.
+
+#### Public Methods {#UnitTest-public}
+
+##### GetInstance {#UnitTest::GetInstance}
+
+`static UnitTest* UnitTest::GetInstance()`
+
+Gets the singleton `UnitTest` object. The first time this method is called, a
+`UnitTest` object is constructed and returned. Consecutive calls will return the
+same object.
+
+##### original_working_dir {#UnitTest::original_working_dir}
+
+`const char* UnitTest::original_working_dir() const`
+
+Returns the working directory when the first [`TEST()`](#TEST) or
+[`TEST_F()`](#TEST_F) was executed. The `UnitTest` object owns the string.
+
+##### current_test_suite {#UnitTest::current_test_suite}
+
+`const TestSuite* UnitTest::current_test_suite() const`
+
+Returns the [`TestSuite`](#TestSuite) object for the test that's currently
+running, or `NULL` if no test is running.
+ +##### current_test_info {#UnitTest::current_test_info} + +`const TestInfo* UnitTest::current_test_info() const` + +Returns the [`TestInfo`](#TestInfo) object for the test that's currently +running, or `NULL` if no test is running. + +##### random_seed {#UnitTest::random_seed} + +`int UnitTest::random_seed() const` + +Returns the random seed used at the start of the current test run. + +##### successful_test_suite_count {#UnitTest::successful_test_suite_count} + +`int UnitTest::successful_test_suite_count() const` + +Gets the number of successful test suites. + +##### failed_test_suite_count {#UnitTest::failed_test_suite_count} + +`int UnitTest::failed_test_suite_count() const` + +Gets the number of failed test suites. + +##### total_test_suite_count {#UnitTest::total_test_suite_count} + +`int UnitTest::total_test_suite_count() const` + +Gets the number of all test suites. + +##### test_suite_to_run_count {#UnitTest::test_suite_to_run_count} + +`int UnitTest::test_suite_to_run_count() const` + +Gets the number of all test suites that contain at least one test that should +run. + +##### successful_test_count {#UnitTest::successful_test_count} + +`int UnitTest::successful_test_count() const` + +Gets the number of successful tests. + +##### skipped_test_count {#UnitTest::skipped_test_count} + +`int UnitTest::skipped_test_count() const` + +Gets the number of skipped tests. + +##### failed_test_count {#UnitTest::failed_test_count} + +`int UnitTest::failed_test_count() const` + +Gets the number of failed tests. + +##### reportable_disabled_test_count {#UnitTest::reportable_disabled_test_count} + +`int UnitTest::reportable_disabled_test_count() const` + +Gets the number of disabled tests that will be reported in the XML report. + +##### disabled_test_count {#UnitTest::disabled_test_count} + +`int UnitTest::disabled_test_count() const` + +Gets the number of disabled tests. + +##### reportable_test_count {#UnitTest::reportable_test_count} + +`int UnitTest::reportable_test_count() const` + +Gets the number of tests to be printed in the XML report. + +##### total_test_count {#UnitTest::total_test_count} + +`int UnitTest::total_test_count() const` + +Gets the number of all tests. + +##### test_to_run_count {#UnitTest::test_to_run_count} + +`int UnitTest::test_to_run_count() const` + +Gets the number of tests that should run. + +##### start_timestamp {#UnitTest::start_timestamp} + +`TimeInMillis UnitTest::start_timestamp() const` + +Gets the time of the test program start, in ms from the start of the UNIX epoch. + +##### elapsed_time {#UnitTest::elapsed_time} + +`TimeInMillis UnitTest::elapsed_time() const` + +Gets the elapsed time, in milliseconds. + +##### Passed {#UnitTest::Passed} + +`bool UnitTest::Passed() const` + +Returns true if and only if the unit test passed (i.e. all test suites passed). + +##### Failed {#UnitTest::Failed} + +`bool UnitTest::Failed() const` + +Returns true if and only if the unit test failed (i.e. some test suite failed or +something outside of all tests failed). + +##### GetTestSuite {#UnitTest::GetTestSuite} + +`const TestSuite* UnitTest::GetTestSuite(int i) const` + +Gets the [`TestSuite`](#TestSuite) object for the `i`-th test suite among all +the test suites. `i` can range from 0 to `total_test_suite_count() - 1`. If `i` +is not in that range, returns `NULL`. 
+ +##### ad_hoc_test_result {#UnitTest::ad_hoc_test_result} + +`const TestResult& UnitTest::ad_hoc_test_result() const` + +Returns the [`TestResult`](#TestResult) containing information on test failures +and properties logged outside of individual test suites. + +##### listeners {#UnitTest::listeners} + +`TestEventListeners& UnitTest::listeners()` + +Returns the list of event listeners that can be used to track events inside +GoogleTest. See [`TestEventListeners`](#TestEventListeners). + +### TestEventListener {#TestEventListener} + +`testing::TestEventListener` + +The interface for tracing execution of tests. The methods below are listed in +the order the corresponding events are fired. + +#### Public Methods {#TestEventListener-public} + +##### OnTestProgramStart {#TestEventListener::OnTestProgramStart} + +`virtual void TestEventListener::OnTestProgramStart(const UnitTest& unit_test)` + +Fired before any test activity starts. + +##### OnTestIterationStart {#TestEventListener::OnTestIterationStart} + +`virtual void TestEventListener::OnTestIterationStart(const UnitTest& unit_test, +int iteration)` + +Fired before each iteration of tests starts. There may be more than one +iteration if `GTEST_FLAG(repeat)` is set. `iteration` is the iteration index, +starting from 0. + +##### OnEnvironmentsSetUpStart {#TestEventListener::OnEnvironmentsSetUpStart} + +`virtual void TestEventListener::OnEnvironmentsSetUpStart(const UnitTest& +unit_test)` + +Fired before environment set-up for each iteration of tests starts. + +##### OnEnvironmentsSetUpEnd {#TestEventListener::OnEnvironmentsSetUpEnd} + +`virtual void TestEventListener::OnEnvironmentsSetUpEnd(const UnitTest& +unit_test)` + +Fired after environment set-up for each iteration of tests ends. + +##### OnTestSuiteStart {#TestEventListener::OnTestSuiteStart} + +`virtual void TestEventListener::OnTestSuiteStart(const TestSuite& test_suite)` + +Fired before the test suite starts. + +##### OnTestStart {#TestEventListener::OnTestStart} + +`virtual void TestEventListener::OnTestStart(const TestInfo& test_info)` + +Fired before the test starts. + +##### OnTestPartResult {#TestEventListener::OnTestPartResult} + +`virtual void TestEventListener::OnTestPartResult(const TestPartResult& +test_part_result)` + +Fired after a failed assertion or a `SUCCEED()` invocation. If you want to throw +an exception from this function to skip to the next test, it must be an +[`AssertionException`](#AssertionException) or inherited from it. + +##### OnTestEnd {#TestEventListener::OnTestEnd} + +`virtual void TestEventListener::OnTestEnd(const TestInfo& test_info)` + +Fired after the test ends. + +##### OnTestSuiteEnd {#TestEventListener::OnTestSuiteEnd} + +`virtual void TestEventListener::OnTestSuiteEnd(const TestSuite& test_suite)` + +Fired after the test suite ends. + +##### OnEnvironmentsTearDownStart {#TestEventListener::OnEnvironmentsTearDownStart} + +`virtual void TestEventListener::OnEnvironmentsTearDownStart(const UnitTest& +unit_test)` + +Fired before environment tear-down for each iteration of tests starts. + +##### OnEnvironmentsTearDownEnd {#TestEventListener::OnEnvironmentsTearDownEnd} + +`virtual void TestEventListener::OnEnvironmentsTearDownEnd(const UnitTest& +unit_test)` + +Fired after environment tear-down for each iteration of tests ends. + +##### OnTestIterationEnd {#TestEventListener::OnTestIterationEnd} + +`virtual void TestEventListener::OnTestIterationEnd(const UnitTest& unit_test, +int iteration)` + +Fired after each iteration of tests finishes. 
+
+##### OnTestProgramEnd {#TestEventListener::OnTestProgramEnd}
+
+`virtual void TestEventListener::OnTestProgramEnd(const UnitTest& unit_test)`
+
+Fired after all test activities have ended.
+
+### TestEventListeners {#TestEventListeners}
+
+`testing::TestEventListeners`
+
+Lets users add listeners to track events in GoogleTest.
+
+#### Public Methods {#TestEventListeners-public}
+
+##### Append {#TestEventListeners::Append}
+
+`void TestEventListeners::Append(TestEventListener* listener)`
+
+Appends an event listener to the end of the list. GoogleTest assumes ownership
+of the listener (i.e. it will delete the listener when the test program
+finishes).
+
+##### Release {#TestEventListeners::Release}
+
+`TestEventListener* TestEventListeners::Release(TestEventListener* listener)`
+
+Removes the given event listener from the list and returns it. It then becomes
+the caller's responsibility to delete the listener. Returns `NULL` if the
+listener is not found in the list.
+
+##### default_result_printer {#TestEventListeners::default_result_printer}
+
+`TestEventListener* TestEventListeners::default_result_printer() const`
+
+Returns the standard listener responsible for the default console output. Can be
+removed from the listeners list to shut down default console output. Note that
+removing this object from the listener list with
+[`Release()`](#TestEventListeners::Release) transfers its ownership to the
+caller and makes this function return `NULL` the next time.
+
+##### default_xml_generator {#TestEventListeners::default_xml_generator}
+
+`TestEventListener* TestEventListeners::default_xml_generator() const`
+
+Returns the standard listener responsible for the default XML output controlled
+by the `--gtest_output=xml` flag. Can be removed from the listeners list by
+users who want to shut down the default XML output controlled by this flag and
+substitute it with a custom one. Note that removing this object from the
+listener list with [`Release()`](#TestEventListeners::Release) transfers its
+ownership to the caller and makes this function return `NULL` the next time.
+
+### TestPartResult {#TestPartResult}
+
+`testing::TestPartResult`
+
+A copyable object representing the result of a test part (i.e. an assertion or
+an explicit `FAIL()`, `ADD_FAILURE()`, or `SUCCEED()`).
+
+#### Public Methods {#TestPartResult-public}
+
+##### type {#TestPartResult::type}
+
+`Type TestPartResult::type() const`
+
+Gets the outcome of the test part.
+
+The return type `Type` is an enum defined as follows:
+
+```cpp
+enum Type {
+  kSuccess,          // Succeeded.
+  kNonFatalFailure,  // Failed but the test can continue.
+  kFatalFailure,     // Failed and the test should be terminated.
+  kSkip              // Skipped.
+};
+```
+
+##### file_name {#TestPartResult::file_name}
+
+`const char* TestPartResult::file_name() const`
+
+Gets the name of the source file where the test part took place, or `NULL` if
+it's unknown.
+
+##### line_number {#TestPartResult::line_number}
+
+`int TestPartResult::line_number() const`
+
+Gets the line in the source file where the test part took place, or `-1` if it's
+unknown.
+
+##### summary {#TestPartResult::summary}
+
+`const char* TestPartResult::summary() const`
+
+Gets the summary of the failure message.
+
+##### message {#TestPartResult::message}
+
+`const char* TestPartResult::message() const`
+
+Gets the message associated with the test part.
+
+##### skipped {#TestPartResult::skipped}
+
+`bool TestPartResult::skipped() const`
+
+Returns true if and only if the test part was skipped.
+ +##### passed {#TestPartResult::passed} + +`bool TestPartResult::passed() const` + +Returns true if and only if the test part passed. + +##### nonfatally_failed {#TestPartResult::nonfatally_failed} + +`bool TestPartResult::nonfatally_failed() const` + +Returns true if and only if the test part non-fatally failed. + +##### fatally_failed {#TestPartResult::fatally_failed} + +`bool TestPartResult::fatally_failed() const` + +Returns true if and only if the test part fatally failed. + +##### failed {#TestPartResult::failed} + +`bool TestPartResult::failed() const` + +Returns true if and only if the test part failed. + +### TestProperty {#TestProperty} + +`testing::TestProperty` + +A copyable object representing a user-specified test property which can be +output as a key/value string pair. + +#### Public Methods {#TestProperty-public} + +##### key {#key} + +`const char* key() const` + +Gets the user-supplied key. + +##### value {#value} + +`const char* value() const` + +Gets the user-supplied value. + +##### SetValue {#SetValue} + +`void SetValue(const std::string& new_value)` + +Sets a new value, overriding the previous one. + +### TestResult {#TestResult} + +`testing::TestResult` + +Contains information about the result of a single test. + +`TestResult` is not copyable. + +#### Public Methods {#TestResult-public} + +##### total_part_count {#TestResult::total_part_count} + +`int TestResult::total_part_count() const` + +Gets the number of all test parts. This is the sum of the number of successful +test parts and the number of failed test parts. + +##### test_property_count {#TestResult::test_property_count} + +`int TestResult::test_property_count() const` + +Returns the number of test properties. + +##### Passed {#TestResult::Passed} + +`bool TestResult::Passed() const` + +Returns true if and only if the test passed (i.e. no test part failed). + +##### Skipped {#TestResult::Skipped} + +`bool TestResult::Skipped() const` + +Returns true if and only if the test was skipped. + +##### Failed {#TestResult::Failed} + +`bool TestResult::Failed() const` + +Returns true if and only if the test failed. + +##### HasFatalFailure {#TestResult::HasFatalFailure} + +`bool TestResult::HasFatalFailure() const` + +Returns true if and only if the test fatally failed. + +##### HasNonfatalFailure {#TestResult::HasNonfatalFailure} + +`bool TestResult::HasNonfatalFailure() const` + +Returns true if and only if the test has a non-fatal failure. + +##### elapsed_time {#TestResult::elapsed_time} + +`TimeInMillis TestResult::elapsed_time() const` + +Returns the elapsed time, in milliseconds. + +##### start_timestamp {#TestResult::start_timestamp} + +`TimeInMillis TestResult::start_timestamp() const` + +Gets the time of the test case start, in ms from the start of the UNIX epoch. + +##### GetTestPartResult {#TestResult::GetTestPartResult} + +`const TestPartResult& TestResult::GetTestPartResult(int i) const` + +Returns the [`TestPartResult`](#TestPartResult) for the `i`-th test part result +among all the results. `i` can range from 0 to `total_part_count() - 1`. If `i` +is not in that range, aborts the program. + +##### GetTestProperty {#TestResult::GetTestProperty} + +`const TestProperty& TestResult::GetTestProperty(int i) const` + +Returns the [`TestProperty`](#TestProperty) object for the `i`-th test property. +`i` can range from 0 to `test_property_count() - 1`. If `i` is not in that +range, aborts the program. 
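+
+As a hedged sketch of how these accessors compose (the `FailurePrinter` class
+below is hypothetical, not part of the API), the following listener walks a
+finished test's `TestResult` and prints each failed part:
+
+```cpp
+class FailurePrinter : public testing::EmptyTestEventListener {
+  void OnTestEnd(const testing::TestInfo& test_info) override {
+    const testing::TestResult* result = test_info.result();
+    for (int i = 0; i < result->total_part_count(); ++i) {
+      const testing::TestPartResult& part = result->GetTestPartResult(i);
+      if (part.failed()) {
+        // file_name() may be NULL when the location is unknown.
+        printf("%s:%d: %s\n",
+               part.file_name() ? part.file_name() : "(unknown file)",
+               part.line_number(), part.summary());
+      }
+    }
+  }
+};
+```
+
+Such a listener would be registered via
+`testing::UnitTest::GetInstance()->listeners().Append(new FailurePrinter);`
+(see [`TestEventListeners`](#TestEventListeners) above).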
+
+### TimeInMillis {#TimeInMillis}
+
+`testing::TimeInMillis`
+
+An integer type representing time in milliseconds.
+
+### Types {#Types}
+
+`testing::Types<T...>`
+
+Represents a list of types for use in typed tests and type-parameterized tests.
+
+The template argument `T...` can be any number of types, for example:
+
+```
+testing::Types<char, int, unsigned int>
+```
+
+See [Typed Tests](../advanced.md#typed-tests) and
+[Type-Parameterized Tests](../advanced.md#type-parameterized-tests) for more
+information.
+
+### WithParamInterface {#WithParamInterface}
+
+`testing::WithParamInterface<T>`
+
+The pure interface class that all value-parameterized tests inherit from.
+
+A value-parameterized test fixture class must inherit from both [`Test`](#Test)
+and `WithParamInterface<T>`. In most cases that just means inheriting from
+[`TestWithParam<T>`](#TestWithParam), but more complicated test hierarchies may
+need to inherit from `Test` and `WithParamInterface<T>` at different levels.
+
+This interface defines the type alias `ParamType` for the parameter type `T` and
+has support for accessing the test parameter value via the `GetParam()` method:
+
+```
+static const ParamType& GetParam()
+```
+
+For more information, see
+[Value-Parameterized Tests](../advanced.md#value-parameterized-tests).
+
+## Functions
+
+GoogleTest defines the following functions to help with writing and running
+tests.
+
+### InitGoogleTest {#InitGoogleTest}
+
+`void testing::InitGoogleTest(int* argc, char** argv)` \
+`void testing::InitGoogleTest(int* argc, wchar_t** argv)` \
+`void testing::InitGoogleTest()`
+
+Initializes GoogleTest. This must be called before calling
+[`RUN_ALL_TESTS()`](#RUN_ALL_TESTS). In particular, it parses the command line
+for the flags that GoogleTest recognizes. Whenever a GoogleTest flag is seen, it
+is removed from `argv`, and `*argc` is decremented. Keep in mind that `argv`
+must terminate with a `NULL` pointer (i.e. `argv[argc]` is `NULL`), which is
+already the case with the default `argv` passed to `main`.
+
+No value is returned. Instead, the GoogleTest flag variables are updated.
+
+The `InitGoogleTest(int* argc, wchar_t** argv)` overload can be used in Windows
+programs compiled in `UNICODE` mode.
+
+The argument-less `InitGoogleTest()` overload can be used on Arduino/embedded
+platforms where there is no `argc`/`argv`.
+
+### AddGlobalTestEnvironment {#AddGlobalTestEnvironment}
+
+`Environment* testing::AddGlobalTestEnvironment(Environment* env)`
+
+Adds a test environment to the test program. Must be called before
+[`RUN_ALL_TESTS()`](#RUN_ALL_TESTS) is called. See
+[Global Set-Up and Tear-Down](../advanced.md#global-set-up-and-tear-down) for
+more information.
+
+See also [`Environment`](#Environment).
+
+### RegisterTest {#RegisterTest}
+
+```cpp
+template <int&... ExplicitParameterBarrier, typename Factory>
+TestInfo* testing::RegisterTest(const char* test_suite_name, const char* test_name,
+                                const char* type_param, const char* value_param,
+                                const char* file, int line, Factory factory)
+```
+
+Dynamically registers a test with the framework.
+
+The `factory` argument is a factory callable (move-constructible) object or
+function pointer that creates a new instance of the `Test` object. It hands
+ownership of the created object to the caller. The signature of the callable is
+`Fixture*()`, where `Fixture` is the test fixture class for the test. All tests
+registered with the same `test_suite_name` must return the same fixture type.
+This is checked at runtime.
+
+The framework will infer the fixture class from the factory and will call the
+`SetUpTestSuite` and `TearDownTestSuite` methods for it.
+
+Must be called before [`RUN_ALL_TESTS()`](#RUN_ALL_TESTS) is invoked, otherwise
+behavior is undefined.
+
+See
+[Registering tests programmatically](../advanced.md#registering-tests-programmatically)
+for more information.
+
+### RUN_ALL_TESTS {#RUN_ALL_TESTS}
+
+`int RUN_ALL_TESTS()`
+
+Use this function in `main()` to run all tests. It returns `0` if all tests are
+successful, or `1` otherwise.
+
+`RUN_ALL_TESTS()` should be invoked after the command line has been parsed by
+[`InitGoogleTest()`](#InitGoogleTest).
+
+This function was formerly a macro; thus, it is in the global namespace and has
+an all-caps name.
+
+### AssertionSuccess {#AssertionSuccess}
+
+`AssertionResult testing::AssertionSuccess()`
+
+Creates a successful assertion result. See
+[`AssertionResult`](#AssertionResult).
+
+### AssertionFailure {#AssertionFailure}
+
+`AssertionResult testing::AssertionFailure()`
+
+Creates a failed assertion result. Use the `<<` operator to store a failure
+message:
+
+```cpp
+testing::AssertionFailure() << "My failure message";
+```
+
+See [`AssertionResult`](#AssertionResult).
+
+### StaticAssertTypeEq {#StaticAssertTypeEq}
+
+`testing::StaticAssertTypeEq<T1, T2>()`
+
+Compile-time assertion for type equality. Compiles if and only if `T1` and `T2`
+are the same type. The value it returns is irrelevant.
+
+See [Type Assertions](../advanced.md#type-assertions) for more information.
+
+### PrintToString {#PrintToString}
+
+`std::string testing::PrintToString(x)`
+
+Prints any value `x` using GoogleTest's value printer.
+
+See
+[Teaching GoogleTest How to Print Your Values](../advanced.md#teaching-googletest-how-to-print-your-values)
+for more information.
+
+### PrintToStringParamName {#PrintToStringParamName}
+
+`std::string testing::PrintToStringParamName(TestParamInfo<T>& info)`
+
+A built-in parameterized test name generator which returns the result of
+[`PrintToString`](#PrintToString) called on `info.param`. Does not work when the
+test parameter is a `std::string` or C string. See
+[Specifying Names for Value-Parameterized Test Parameters](../advanced.md#specifying-names-for-value-parameterized-test-parameters)
+for more information.
+
+See also [`TestParamInfo`](#TestParamInfo) and
+[`INSTANTIATE_TEST_SUITE_P`](#INSTANTIATE_TEST_SUITE_P).
diff --git a/test/googletest/docs/samples.md b/test/googletest/docs/samples.md
new file mode 100644
index 0000000..dedc590
--- /dev/null
+++ b/test/googletest/docs/samples.md
@@ -0,0 +1,22 @@
+# Googletest Samples
+
+If you're like us, you'd like to look at
+[googletest samples](https://github.com/google/googletest/blob/main/googletest/samples).
+The sample directory has a number of well-commented samples showing how to use a
+variety of googletest features.
+
+* Sample #1 shows the basic steps of using googletest to test C++ functions.
+* Sample #2 shows a more complex unit test for a class with multiple member
+  functions.
+* Sample #3 uses a test fixture.
+* Sample #4 teaches you how to use googletest and `googletest.h` together to
+  get the best of both libraries.
+* Sample #5 puts shared testing logic in a base test fixture, and reuses it in
+  derived fixtures.
+* Sample #6 demonstrates type-parameterized tests.
+* Sample #7 teaches the basics of value-parameterized tests.
+* Sample #8 shows using `Combine()` in value-parameterized tests.
+* Sample #9 shows use of the listener API to modify Google Test's console + output and the use of its reflection API to inspect test results. +* Sample #10 shows use of the listener API to implement a primitive memory + leak checker. diff --git a/test/googletest/fake_fuchsia_sdk.bzl b/test/googletest/fake_fuchsia_sdk.bzl new file mode 100644 index 0000000..2024dc6 --- /dev/null +++ b/test/googletest/fake_fuchsia_sdk.bzl @@ -0,0 +1,33 @@ +"""Provides a fake @fuchsia_sdk implementation that's used when the real one isn't available. + +This is needed since bazel queries on targets that depend on //:gtest (eg: +`bazel query "deps(set(//googletest/test:gtest_all_test))"`) will fail if @fuchsia_sdk is not +defined when bazel is evaluating the transitive closure of the query target. + +See https://github.com/google/googletest/issues/4472. +""" + +def _fake_fuchsia_sdk_impl(repo_ctx): + for stub_target in repo_ctx.attr._stub_build_targets: + stub_package = stub_target + stub_target_name = stub_target.split("/")[-1] + repo_ctx.file("%s/BUILD.bazel" % stub_package, """ +filegroup( + name = "%s", +) +""" % stub_target_name) + +fake_fuchsia_sdk = repository_rule( + doc = "Used to create a fake @fuchsia_sdk repository with stub build targets.", + implementation = _fake_fuchsia_sdk_impl, + attrs = { + "_stub_build_targets": attr.string_list( + doc = "The stub build targets to initialize.", + default = [ + "pkg/fdio", + "pkg/syslog", + "pkg/zx", + ], + ), + }, +) diff --git a/test/googletest/googlemock/CMakeLists.txt b/test/googletest/googlemock/CMakeLists.txt new file mode 100644 index 0000000..99b2411 --- /dev/null +++ b/test/googletest/googlemock/CMakeLists.txt @@ -0,0 +1,210 @@ +######################################################################## +# Note: CMake support is community-based. The maintainers do not use CMake +# internally. +# +# CMake build script for Google Mock. +# +# To run the tests for Google Mock itself on Linux, use 'make test' or +# ctest. You can select which tests to run using 'ctest -R regex'. +# For more options, run 'ctest --help'. + +option(gmock_build_tests "Build all of Google Mock's own tests." OFF) + +# A directory to find Google Test sources. +if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/gtest/CMakeLists.txt") + set(gtest_dir gtest) +else() + set(gtest_dir ../googletest) +endif() + +# Defines pre_project_set_up_hermetic_build() and set_up_hermetic_build(). +include("${gtest_dir}/cmake/hermetic_build.cmake" OPTIONAL) + +if (COMMAND pre_project_set_up_hermetic_build) + # Google Test also calls hermetic setup functions from add_subdirectory, + # although its changes will not affect things at the current scope. + pre_project_set_up_hermetic_build() +endif() + +######################################################################## +# +# Project-wide settings + +# Name of the project. +# +# CMake files in this project can refer to the root source directory +# as ${gmock_SOURCE_DIR} and to the root binary directory as +# ${gmock_BINARY_DIR}. +# Language "C" is required for find_package(Threads). +cmake_minimum_required(VERSION 3.13) +project(gmock VERSION ${GOOGLETEST_VERSION} LANGUAGES CXX C) + +if (COMMAND set_up_hermetic_build) + set_up_hermetic_build() +endif() + +# Instructs CMake to process Google Test's CMakeLists.txt and add its +# targets to the current scope. We are placing Google Test's binary +# directory in a subdirectory of our own as VC compilation may break +# if they are the same (the default). 
+add_subdirectory("${gtest_dir}" "${gmock_BINARY_DIR}/${gtest_dir}")
+
+# These commands only run if this is the main project
+if(CMAKE_PROJECT_NAME STREQUAL "gmock" OR CMAKE_PROJECT_NAME STREQUAL "googletest-distribution")
+  # BUILD_SHARED_LIBS is a standard CMake variable, but we declare it here to
+  # make it prominent in the GUI.
+  option(BUILD_SHARED_LIBS "Build shared libraries (DLLs)." OFF)
+else()
+  mark_as_advanced(gmock_build_tests)
+endif()
+
+# Although Google Test's CMakeLists.txt calls this function, the
+# changes there don't affect the current scope. Therefore we have to
+# call it again here.
+config_compiler_and_linker()  # from ${gtest_dir}/cmake/internal_utils.cmake
+
+# Adds Google Mock's and Google Test's header directories to the search path.
+# Get Google Test's include dirs from the target, gtest_SOURCE_DIR is broken
+# when using fetch-content with the name "GTest".
+get_target_property(gtest_include_dirs gtest INCLUDE_DIRECTORIES)
+set(gmock_build_include_dirs
+  "${gmock_SOURCE_DIR}/include"
+  "${gmock_SOURCE_DIR}"
+  "${gtest_include_dirs}")
+include_directories(${gmock_build_include_dirs})
+
+########################################################################
+#
+# Defines the gmock & gmock_main libraries. User tests should link
+# with one of them.
+
+# Google Mock libraries. We build them using more strict warnings than what
+# are used for other targets, to ensure that Google Mock can be compiled by
+# a user aggressive about warnings.
+if (MSVC)
+  cxx_library(gmock
+    "${cxx_strict}"
+    "${gtest_dir}/src/gtest-all.cc"
+    src/gmock-all.cc)
+
+  cxx_library(gmock_main
+    "${cxx_strict}"
+    "${gtest_dir}/src/gtest-all.cc"
+    src/gmock-all.cc
+    src/gmock_main.cc)
+else()
+  cxx_library(gmock "${cxx_strict}" src/gmock-all.cc)
+  target_link_libraries(gmock PUBLIC gtest)
+  set_target_properties(gmock PROPERTIES VERSION ${GOOGLETEST_VERSION})
+  cxx_library(gmock_main "${cxx_strict}" src/gmock_main.cc)
+  target_link_libraries(gmock_main PUBLIC gmock)
+  set_target_properties(gmock_main PROPERTIES VERSION ${GOOGLETEST_VERSION})
+endif()
+
+string(REPLACE ";" "$<SEMICOLON>" dirs "${gmock_build_include_dirs}")
+target_include_directories(gmock SYSTEM INTERFACE
+  "$<BUILD_INTERFACE:${dirs}>"
+  "$<INSTALL_INTERFACE:$<INSTALL_PREFIX>/${CMAKE_INSTALL_INCLUDEDIR}>")
+target_include_directories(gmock_main SYSTEM INTERFACE
+  "$<BUILD_INTERFACE:${dirs}>"
+  "$<INSTALL_INTERFACE:$<INSTALL_PREFIX>/${CMAKE_INSTALL_INCLUDEDIR}>")
+
+########################################################################
+#
+# Install rules.
+install_project(gmock gmock_main)
+
+########################################################################
+#
+# Google Mock's own tests.
+#
+# You can skip this section if you aren't interested in testing
+# Google Mock itself.
+#
+# The tests are not built by default. To build them, set the
+# gmock_build_tests option to ON. You can do it by running ccmake
+# or specifying the -Dgmock_build_tests=ON flag when running cmake.
+
+if (gmock_build_tests)
+  # This must be set in the root directory for the tests to be run by
+  # 'make test' or ctest.
+  enable_testing()
+
+  if (MINGW OR CYGWIN)
+    add_compile_options("-Wa,-mbig-obj")
+  endif()
+
+  ############################################################
+  # C++ tests built with standard compiler flags.
+
+  cxx_test(gmock-actions_test gmock_main)
+  cxx_test(gmock-cardinalities_test gmock_main)
+  cxx_test(gmock_ex_test gmock_main)
+  cxx_test(gmock-function-mocker_test gmock_main)
+  cxx_test(gmock-internal-utils_test gmock_main)
+  cxx_test(gmock-matchers-arithmetic_test gmock_main)
+  cxx_test(gmock-matchers-comparisons_test gmock_main)
+  cxx_test(gmock-matchers-containers_test gmock_main)
+  cxx_test(gmock-matchers-misc_test gmock_main)
+  cxx_test(gmock-more-actions_test gmock_main)
+  cxx_test(gmock-nice-strict_test gmock_main)
+  cxx_test(gmock-port_test gmock_main)
+  cxx_test(gmock-spec-builders_test gmock_main)
+  cxx_test(gmock_link_test gmock_main test/gmock_link2_test.cc)
+  cxx_test(gmock_test gmock_main)
+
+  if (DEFINED GTEST_HAS_PTHREAD)
+    cxx_test(gmock_stress_test gmock)
+  endif()
+
+  # gmock_all_test is commented to save time building and running tests.
+  # Uncomment if necessary.
+  # cxx_test(gmock_all_test gmock_main)
+
+  ############################################################
+  # C++ tests built with non-standard compiler flags.
+
+  if (MSVC)
+    cxx_library(gmock_main_no_exception "${cxx_no_exception}"
+      "${gtest_dir}/src/gtest-all.cc" src/gmock-all.cc src/gmock_main.cc)
+
+    cxx_library(gmock_main_no_rtti "${cxx_no_rtti}"
+      "${gtest_dir}/src/gtest-all.cc" src/gmock-all.cc src/gmock_main.cc)
+
+  else()
+    cxx_library(gmock_main_no_exception "${cxx_no_exception}" src/gmock_main.cc)
+    target_link_libraries(gmock_main_no_exception PUBLIC gmock)
+
+    cxx_library(gmock_main_no_rtti "${cxx_no_rtti}" src/gmock_main.cc)
+    target_link_libraries(gmock_main_no_rtti PUBLIC gmock)
+  endif()
+  cxx_test_with_flags(gmock-more-actions_no_exception_test "${cxx_no_exception}"
+    gmock_main_no_exception test/gmock-more-actions_test.cc)
+
+  cxx_test_with_flags(gmock_no_rtti_test "${cxx_no_rtti}"
+    gmock_main_no_rtti test/gmock-spec-builders_test.cc)
+
+  cxx_shared_library(shared_gmock_main "${cxx_default}"
+    "${gtest_dir}/src/gtest-all.cc" src/gmock-all.cc src/gmock_main.cc)
+
+  # Tests that a binary can be built with Google Mock as a shared library. On
+  # some system configurations, it may not be possible to run the binary
+  # without knowing more details about the system configuration. We do not try
+  # to run this binary. To get more robust shared library coverage, configure
+  # with -DBUILD_SHARED_LIBS=ON.
+  cxx_executable_with_flags(shared_gmock_test_ "${cxx_default}"
+    shared_gmock_main test/gmock-spec-builders_test.cc)
+  set_target_properties(shared_gmock_test_
+    PROPERTIES
+    COMPILE_DEFINITIONS "GTEST_LINKED_AS_SHARED_LIBRARY=1")
+
+  ############################################################
+  # Python tests.
+
+  cxx_executable(gmock_leak_test_ test gmock_main)
+  py_test(gmock_leak_test)
+
+  cxx_executable(gmock_output_test_ test gmock)
+  py_test(gmock_output_test)
+endif()
diff --git a/test/googletest/googlemock/README.md b/test/googletest/googlemock/README.md
new file mode 100644
index 0000000..e1103b1
--- /dev/null
+++ b/test/googletest/googlemock/README.md
@@ -0,0 +1,40 @@
+# Googletest Mocking (gMock) Framework
+
+### Overview
+
+Google's framework for writing and using C++ mock classes. It can help you
+derive better designs of your system and write better tests.
+
+It is inspired by:
+
+* [jMock](http://www.jmock.org/)
+* [EasyMock](https://easymock.org/)
+* [Hamcrest](https://code.google.com/p/hamcrest/)
+
+It is designed with C++'s specifics in mind.
+
+gMock:
+
+- Provides a declarative syntax for defining mocks.
+- Can define partial (hybrid) mocks, which are a cross of real and mock
+  objects.
+- Handles functions of arbitrary types and overloaded functions.
+- Comes with a rich set of matchers for validating function arguments.
+- Uses an intuitive syntax for controlling the behavior of a mock.
+- Does automatic verification of expectations (no record-and-replay needed).
+- Allows arbitrary (partial) ordering constraints on function calls to be
+  expressed.
+- Lets a user extend it by defining new matchers and actions.
+- Does not use exceptions.
+- Is easy to learn and use.
+
+Details and examples can be found here:
+
+* [gMock for Dummies](https://google.github.io/googletest/gmock_for_dummies.html)
+* [Legacy gMock FAQ](https://google.github.io/googletest/gmock_faq.html)
+* [gMock Cookbook](https://google.github.io/googletest/gmock_cook_book.html)
+* [gMock Cheat Sheet](https://google.github.io/googletest/gmock_cheat_sheet.html)
+
+GoogleMock is a part of the
+[GoogleTest C++ testing framework](https://github.com/google/googletest/) and is
+subject to the same requirements.
diff --git a/test/googletest/googlemock/cmake/gmock.pc.in b/test/googletest/googlemock/cmake/gmock.pc.in
new file mode 100644
index 0000000..23c67b5
--- /dev/null
+++ b/test/googletest/googlemock/cmake/gmock.pc.in
@@ -0,0 +1,10 @@
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: gmock
+Description: GoogleMock (without main() function)
+Version: @PROJECT_VERSION@
+URL: https://github.com/google/googletest
+Requires: gtest = @PROJECT_VERSION@
+Libs: -L${libdir} -lgmock @CMAKE_THREAD_LIBS_INIT@
+Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@
diff --git a/test/googletest/googlemock/cmake/gmock_main.pc.in b/test/googletest/googlemock/cmake/gmock_main.pc.in
new file mode 100644
index 0000000..66ffea7
--- /dev/null
+++ b/test/googletest/googlemock/cmake/gmock_main.pc.in
@@ -0,0 +1,10 @@
+libdir=@CMAKE_INSTALL_FULL_LIBDIR@
+includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@
+
+Name: gmock_main
+Description: GoogleMock (with main() function)
+Version: @PROJECT_VERSION@
+URL: https://github.com/google/googletest
+Requires: gmock = @PROJECT_VERSION@
+Libs: -L${libdir} -lgmock_main @CMAKE_THREAD_LIBS_INIT@
+Cflags: -I${includedir} @GTEST_HAS_PTHREAD_MACRO@
diff --git a/test/googletest/googlemock/docs/README.md b/test/googletest/googlemock/docs/README.md
new file mode 100644
index 0000000..1bc57b7
--- /dev/null
+++ b/test/googletest/googlemock/docs/README.md
@@ -0,0 +1,4 @@
+# Content Moved
+
+We are working on updates to the GoogleTest documentation, which has moved to
+the top-level [docs](../../docs) directory.
diff --git a/test/googletest/googlemock/include/gmock/gmock-actions.h b/test/googletest/googlemock/include/gmock/gmock-actions.h
new file mode 100644
index 0000000..aa47079
--- /dev/null
+++ b/test/googletest/googlemock/include/gmock/gmock-actions.h
@@ -0,0 +1,2360 @@
+// Copyright 2007, Google Inc.
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+// Google Mock - a framework for writing C++ mock classes.
+//
+// The ACTION* family of macros can be used in a namespace scope to
+// define custom actions easily. The syntax:
+//
+//   ACTION(name) { statements; }
+//
+// will define an action with the given name that executes the
+// statements. The value returned by the statements will be used as
+// the return value of the action. Inside the statements, you can
+// refer to the K-th (0-based) argument of the mock function by
+// 'argK', and refer to its type by 'argK_type'. For example:
+//
+//   ACTION(IncrementArg1) {
+//     arg1_type temp = arg1;
+//     return ++(*temp);
+//   }
+//
+// allows you to write
+//
+//   ...WillOnce(IncrementArg1());
+//
+// You can also refer to the entire argument tuple and its type by
+// 'args' and 'args_type', and refer to the mock function type and its
+// return type by 'function_type' and 'return_type'.
+//
+// Note that you don't need to specify the types of the mock function
+// arguments. However rest assured that your code is still type-safe:
+// you'll get a compiler error if *arg1 doesn't support the ++
+// operator, or if the type of ++(*arg1) isn't compatible with the
+// mock function's return type, for example.
+//
+// Sometimes you'll want to parameterize the action. For that you can use
+// another macro:
+//
+//   ACTION_P(name, param_name) { statements; }
+//
+// For example:
+//
+//   ACTION_P(Add, n) { return arg0 + n; }
+//
+// will allow you to write:
+//
+//   ...WillOnce(Add(5));
+//
+// Note that you don't need to provide the type of the parameter
+// either. If you need to reference the type of a parameter named
+// 'foo', you can write 'foo_type'. For example, in the body of
+// ACTION_P(Add, n) above, you can write 'n_type' to refer to the type
+// of 'n'.
+//
+// We also provide ACTION_P2, ACTION_P3, ..., up to ACTION_P10 to support
+// multi-parameter actions.
+//
+// For the purpose of typing, you can view
+//
+//   ACTION_Pk(Foo, p1, ..., pk) { ... }
+//
+// as shorthand for
+//
+//   template <typename p1_type, ..., typename pk_type>
+//   FooActionPk<p1_type, ..., pk_type> Foo(p1_type p1, ..., pk_type pk) { ... }
+//
+// In particular, you can provide the template type arguments
+// explicitly when invoking Foo(), as in Foo<long, bool>(5, false);
+// although usually you can rely on the compiler to infer the types
+// for you automatically. You can assign the result of expression
+// Foo(p1, ..., pk) to a variable of type FooActionPk<p1_type, ...,
+// pk_type>. This can be useful when composing actions.
+//
+// You can also overload actions with different numbers of parameters:
+//
+//   ACTION_P(Plus, a) { ... }
+//   ACTION_P2(Plus, a, b) { ... }
+//
+// While it's tempting to always use the ACTION* macros when defining
+// a new action, you should also consider implementing ActionInterface
+// or using MakePolymorphicAction() instead, especially if you need to
+// use the action a lot. While these approaches require more work,
+// they give you more control on the types of the mock function
+// arguments and the action parameters, which in general leads to
+// better compiler error messages that pay off in the long run. They
+// also allow overloading actions based on parameter types (as opposed
+// to just based on the number of parameters).
+//
+// CAVEAT:
+//
+// ACTION*() can only be used in a namespace scope as templates cannot be
+// declared inside of a local class. Users can, however, define any local
+// functors (e.g. a lambda) that can be used as actions.
+//
+// MORE INFORMATION:
+//
+// To learn more about using these macros, please search for 'ACTION' on
+// https://github.com/google/googletest/blob/main/docs/gmock_cook_book.md
+
+// IWYU pragma: private, include "gmock/gmock.h"
+// IWYU pragma: friend gmock/.*
+
+#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_
+#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_
+
+#ifndef _WIN32_WCE
+#include <errno.h>
+#endif
+
+#include <algorithm>
+#include <exception>
+#include <functional>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+#include "gmock/internal/gmock-internal-utils.h"
+#include "gmock/internal/gmock-port.h"
+#include "gmock/internal/gmock-pp.h"
+
+GTEST_DISABLE_MSC_WARNINGS_PUSH_(4100)
+
+namespace testing {
+
+// To implement an action Foo, define:
+//   1. a class FooAction that implements the ActionInterface interface, and
+//   2. a factory function that creates an Action object from a
+//      const FooAction*.
+//
+// The two-level delegation design follows that of Matcher, providing
+// consistency for extension developers. It also eases ownership
+// management as Action objects can now be copied like plain values.
+
+namespace internal {
+
+// BuiltInDefaultValueGetter<T, true>::Get() returns a
+// default-constructed T value. BuiltInDefaultValueGetter<T, false>::Get()
+// crashes with an error.
+//
+// This primary template is used when kDefaultConstructible is true.
+template <typename T, bool kDefaultConstructible>
+struct BuiltInDefaultValueGetter {
+  static T Get() { return T(); }
+};
+template <typename T>
+struct BuiltInDefaultValueGetter<T, false> {
+  static T Get() {
+    Assert(false, __FILE__, __LINE__,
+           "Default action undefined for the function return type.");
+#if defined(__GNUC__) || defined(__clang__)
+    __builtin_unreachable();
+#elif defined(_MSC_VER)
+    __assume(0);
+#else
+    return Invalid<T>();
+    // The above statement will never be reached, but is required in
+    // order for this function to compile.
+#endif
+  }
+};
+
+// BuiltInDefaultValue<T>::Get() returns the "built-in" default value
+// for type T, which is NULL when T is a raw pointer type, 0 when T is
+// a numeric type, false when T is bool, or "" when T is string or
+// std::string. In addition, in C++11 and above, it returns a
+// default-constructed T value if T is default constructible. For any
+// other type T, the built-in default T value is undefined, and the
+// function will abort the process.
+template <typename T>
+class BuiltInDefaultValue {
+ public:
+  // This function returns true if and only if type T has a built-in default
+  // value.
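+  //
+  // (Illustrative, non-upstream examples of the behavior described above:
+  //   BuiltInDefaultValue<int>::Get()          // 0
+  //   BuiltInDefaultValue<bool>::Get()         // false
+  //   BuiltInDefaultValue<std::string>::Get()  // ""
+  //   BuiltInDefaultValue<int*>::Get()         // nullptr
+  // )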
+  static bool Exists() { return ::std::is_default_constructible<T>::value; }
+
+  static T Get() {
+    return BuiltInDefaultValueGetter<
+        T, ::std::is_default_constructible<T>::value>::Get();
+  }
+};
+
+// This partial specialization says that we use the same built-in
+// default value for T and const T.
+template <typename T>
+class BuiltInDefaultValue<const T> {
+ public:
+  static bool Exists() { return BuiltInDefaultValue<T>::Exists(); }
+  static T Get() { return BuiltInDefaultValue<T>::Get(); }
+};
+
+// This partial specialization defines the default values for pointer
+// types.
+template <typename T>
+class BuiltInDefaultValue<T*> {
+ public:
+  static bool Exists() { return true; }
+  static T* Get() { return nullptr; }
+};
+
+// The following specializations define the default values for
+// specific types we care about.
+#define GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(type, value) \
+  template <>                                                     \
+  class BuiltInDefaultValue<type> {                               \
+   public:                                                        \
+    static bool Exists() { return true; }                         \
+    static type Get() { return value; }                           \
+  }
+
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(void, );  // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(::std::string, "");
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(bool, false);
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned char, '\0');
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed char, '\0');
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(char, '\0');
+
+// There's no need for a default action for signed wchar_t, as that
+// type is the same as wchar_t for gcc, and invalid for MSVC.
+//
+// There's also no need for a default action for unsigned wchar_t, as
+// that type is the same as unsigned int for gcc, and invalid for
+// MSVC.
+#if GMOCK_WCHAR_T_IS_NATIVE_
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(wchar_t, 0U);  // NOLINT
+#endif
+
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned short, 0U);  // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed short, 0);     // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned int, 0U);
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed int, 0);
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned long, 0UL);     // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed long, 0L);        // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(unsigned long long, 0);  // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(signed long long, 0);    // NOLINT
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(float, 0);
+GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_(double, 0);
+
+#undef GMOCK_DEFINE_DEFAULT_ACTION_FOR_RETURN_TYPE_
+
+// Partial implementations of metaprogramming types from the standard library
+// not available in C++11.
+
+template <typename P>
+struct negation
+    // NOLINTNEXTLINE
+    : std::integral_constant<bool, bool(!P::value)> {};
+
+// Base case: with zero predicates the answer is always true.
+template <typename...>
+struct conjunction : std::true_type {};
+
+// With a single predicate, the answer is that predicate.
+template <typename P1>
+struct conjunction<P1> : P1 {};
+
+// With multiple predicates the answer is the first predicate if that is false,
+// and we recurse otherwise.
+template <typename P1, typename... Ps>
+struct conjunction<P1, Ps...>
+    : std::conditional<bool(P1::value), conjunction<Ps...>, P1>::type {};
+
+template <typename...>
+struct disjunction : std::false_type {};
+
+template <typename P1>
+struct disjunction<P1> : P1 {};
+
+template <typename P1, typename... Ps>
+struct disjunction<P1, Ps...>
+    // NOLINTNEXTLINE
+    : std::conditional<!bool(P1::value), disjunction<Ps...>, P1>::type {};
+
+template <typename...>
+using void_t = void;
+
+// Detects whether an expression of type `From` can be implicitly converted to
+// `To` according to [conv].
In C++17, [conv]/3 defines this as follows: +// +// An expression e can be implicitly converted to a type T if and only if +// the declaration T t=e; is well-formed, for some invented temporary +// variable t ([dcl.init]). +// +// [conv]/2 implies we can use function argument passing to detect whether this +// initialization is valid. +// +// Note that this is distinct from is_convertible, which requires this be valid: +// +// To test() { +// return declval(); +// } +// +// In particular, is_convertible doesn't give the correct answer when `To` and +// `From` are the same non-moveable type since `declval` will be an rvalue +// reference, defeating the guaranteed copy elision that would otherwise make +// this function work. +// +// REQUIRES: `From` is not cv void. +template +struct is_implicitly_convertible { + private: + // A function that accepts a parameter of type T. This can be called with type + // U successfully only if U is implicitly convertible to T. + template + static void Accept(T); + + // A function that creates a value of type T. + template + static T Make(); + + // An overload be selected when implicit conversion from T to To is possible. + template (Make()))> + static std::true_type TestImplicitConversion(int); + + // A fallback overload selected in all other cases. + template + static std::false_type TestImplicitConversion(...); + + public: + using type = decltype(TestImplicitConversion(0)); + static constexpr bool value = type::value; +}; + +// Like std::invoke_result_t from C++17, but works only for objects with call +// operators (not e.g. member function pointers, which we don't need specific +// support for in OnceAction because std::function deals with them). +template +using call_result_t = decltype(std::declval()(std::declval()...)); + +template +struct is_callable_r_impl : std::false_type {}; + +// Specialize the struct for those template arguments where call_result_t is +// well-formed. When it's not, the generic template above is chosen, resulting +// in std::false_type. +template +struct is_callable_r_impl>, R, F, Args...> + : std::conditional< + std::is_void::value, // + std::true_type, // + is_implicitly_convertible, R>>::type {}; + +// Like std::is_invocable_r from C++17, but works only for objects with call +// operators. See the note on call_result_t. +template +using is_callable_r = is_callable_r_impl; + +// Like std::as_const from C++17. +template +typename std::add_const::type& as_const(T& t) { + return t; +} + +} // namespace internal + +// Specialized for function types below. +template +class OnceAction; + +// An action that can only be used once. +// +// This is accepted by WillOnce, which doesn't require the underlying action to +// be copy-constructible (only move-constructible), and promises to invoke it as +// an rvalue reference. This allows the action to work with move-only types like +// std::move_only_function in a type-safe manner. +// +// For example: +// +// // Assume we have some API that needs to accept a unique pointer to some +// // non-copyable object Foo. +// void AcceptUniquePointer(std::unique_ptr foo); +// +// // We can define an action that provides a Foo to that API. Because It +// // has to give away its unique pointer, it must not be called more than +// // once, so its call operator is &&-qualified. +// struct ProvideFoo { +// std::unique_ptr foo; +// +// void operator()() && { +// AcceptUniquePointer(std::move(Foo)); +// } +// }; +// +// // This action can be used with WillOnce. 
+// EXPECT_CALL(mock, Call) +// .WillOnce(ProvideFoo{std::make_unique(...)}); +// +// // But a call to WillRepeatedly will fail to compile. This is correct, +// // since the action cannot correctly be used repeatedly. +// EXPECT_CALL(mock, Call) +// .WillRepeatedly(ProvideFoo{std::make_unique(...)}); +// +// A less-contrived example would be an action that returns an arbitrary type, +// whose &&-qualified call operator is capable of dealing with move-only types. +template +class OnceAction final { + private: + // True iff we can use the given callable type (or lvalue reference) directly + // via StdFunctionAdaptor. + template + using IsDirectlyCompatible = internal::conjunction< + // It must be possible to capture the callable in StdFunctionAdaptor. + std::is_constructible::type, Callable>, + // The callable must be compatible with our signature. + internal::is_callable_r::type, + Args...>>; + + // True iff we can use the given callable type via StdFunctionAdaptor once we + // ignore incoming arguments. + template + using IsCompatibleAfterIgnoringArguments = internal::conjunction< + // It must be possible to capture the callable in a lambda. + std::is_constructible::type, Callable>, + // The callable must be invocable with zero arguments, returning something + // convertible to Result. + internal::is_callable_r::type>>; + + public: + // Construct from a callable that is directly compatible with our mocked + // signature: it accepts our function type's arguments and returns something + // convertible to our result type. + template ::type>>, + IsDirectlyCompatible> // + ::value, + int>::type = 0> + OnceAction(Callable&& callable) // NOLINT + : function_(StdFunctionAdaptor::type>( + {}, std::forward(callable))) {} + + // As above, but for a callable that ignores the mocked function's arguments. + template ::type>>, + // Exclude callables for which the overload above works. + // We'd rather provide the arguments if possible. + internal::negation>, + IsCompatibleAfterIgnoringArguments>::value, + int>::type = 0> + OnceAction(Callable&& callable) // NOLINT + // Call the constructor above with a callable + // that ignores the input arguments. + : OnceAction(IgnoreIncomingArguments::type>{ + std::forward(callable)}) {} + + // We are naturally copyable because we store only an std::function, but + // semantically we should not be copyable. + OnceAction(const OnceAction&) = delete; + OnceAction& operator=(const OnceAction&) = delete; + OnceAction(OnceAction&&) = default; + + // Invoke the underlying action callable with which we were constructed, + // handing it the supplied arguments. + Result Call(Args... args) && { + return function_(std::forward(args)...); + } + + private: + // An adaptor that wraps a callable that is compatible with our signature and + // being invoked as an rvalue reference so that it can be used as an + // StdFunctionAdaptor. This throws away type safety, but that's fine because + // this is only used by WillOnce, which we know calls at most once. + // + // Once we have something like std::move_only_function from C++23, we can do + // away with this. + template + class StdFunctionAdaptor final { + public: + // A tag indicating that the (otherwise universal) constructor is accepting + // the callable itself, instead of e.g. stealing calls for the move + // constructor. 
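+    // (Illustrative, non-upstream note: this tag-dispatch idiom keeps the
+    // templated constructor from being selected for copy or move
+    // construction, since callers must pass the tag explicitly, e.g.
+    //   StdFunctionAdaptor<F> adaptor(CallableTag{}, std::move(callable));
+    // )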
+ struct CallableTag final {}; + + template + explicit StdFunctionAdaptor(CallableTag, F&& callable) + : callable_(std::make_shared(std::forward(callable))) {} + + // Rather than explicitly returning Result, we return whatever the wrapped + // callable returns. This allows for compatibility with existing uses like + // the following, when the mocked function returns void: + // + // EXPECT_CALL(mock_fn_, Call) + // .WillOnce([&] { + // [...] + // return 0; + // }); + // + // Such a callable can be turned into std::function. If we use an + // explicit return type of Result here then it *doesn't* work with + // std::function, because we'll get a "void function should not return a + // value" error. + // + // We need not worry about incompatible result types because the SFINAE on + // OnceAction already checks this for us. std::is_invocable_r_v itself makes + // the same allowance for void result types. + template + internal::call_result_t operator()( + ArgRefs&&... args) const { + return std::move(*callable_)(std::forward(args)...); + } + + private: + // We must put the callable on the heap so that we are copyable, which + // std::function needs. + std::shared_ptr callable_; + }; + + // An adaptor that makes a callable that accepts zero arguments callable with + // our mocked arguments. + template + struct IgnoreIncomingArguments { + internal::call_result_t operator()(Args&&...) { + return std::move(callable)(); + } + + Callable callable; + }; + + std::function function_; +}; + +// When an unexpected function call is encountered, Google Mock will +// let it return a default value if the user has specified one for its +// return type, or if the return type has a built-in default value; +// otherwise Google Mock won't know what value to return and will have +// to abort the process. +// +// The DefaultValue class allows a user to specify the +// default value for a type T that is both copyable and publicly +// destructible (i.e. anything that can be used as a function return +// type). The usage is: +// +// // Sets the default value for type T to be foo. +// DefaultValue::Set(foo); +template +class DefaultValue { + public: + // Sets the default value for type T; requires T to be + // copy-constructable and have a public destructor. + static void Set(T x) { + delete producer_; + producer_ = new FixedValueProducer(x); + } + + // Provides a factory function to be called to generate the default value. + // This method can be used even if T is only move-constructible, but it is not + // limited to that case. + typedef T (*FactoryFunction)(); + static void SetFactory(FactoryFunction factory) { + delete producer_; + producer_ = new FactoryValueProducer(factory); + } + + // Unsets the default value for type T. + static void Clear() { + delete producer_; + producer_ = nullptr; + } + + // Returns true if and only if the user has set the default value for type T. + static bool IsSet() { return producer_ != nullptr; } + + // Returns true if T has a default return value set by the user or there + // exists a built-in default value. + static bool Exists() { + return IsSet() || internal::BuiltInDefaultValue::Exists(); + } + + // Returns the default value for type T if the user has set one; + // otherwise returns the built-in default value. Requires that Exists() + // is true, which ensures that the return value is well-defined. + static T Get() { + return producer_ == nullptr ? 
internal::BuiltInDefaultValue::Get() + : producer_->Produce(); + } + + private: + class ValueProducer { + public: + virtual ~ValueProducer() = default; + virtual T Produce() = 0; + }; + + class FixedValueProducer : public ValueProducer { + public: + explicit FixedValueProducer(T value) : value_(value) {} + T Produce() override { return value_; } + + private: + const T value_; + FixedValueProducer(const FixedValueProducer&) = delete; + FixedValueProducer& operator=(const FixedValueProducer&) = delete; + }; + + class FactoryValueProducer : public ValueProducer { + public: + explicit FactoryValueProducer(FactoryFunction factory) + : factory_(factory) {} + T Produce() override { return factory_(); } + + private: + const FactoryFunction factory_; + FactoryValueProducer(const FactoryValueProducer&) = delete; + FactoryValueProducer& operator=(const FactoryValueProducer&) = delete; + }; + + static ValueProducer* producer_; +}; + +// This partial specialization allows a user to set default values for +// reference types. +template +class DefaultValue { + public: + // Sets the default value for type T&. + static void Set(T& x) { // NOLINT + address_ = &x; + } + + // Unsets the default value for type T&. + static void Clear() { address_ = nullptr; } + + // Returns true if and only if the user has set the default value for type T&. + static bool IsSet() { return address_ != nullptr; } + + // Returns true if T has a default return value set by the user or there + // exists a built-in default value. + static bool Exists() { + return IsSet() || internal::BuiltInDefaultValue::Exists(); + } + + // Returns the default value for type T& if the user has set one; + // otherwise returns the built-in default value if there is one; + // otherwise aborts the process. + static T& Get() { + return address_ == nullptr ? internal::BuiltInDefaultValue::Get() + : *address_; + } + + private: + static T* address_; +}; + +// This specialization allows DefaultValue::Get() to +// compile. +template <> +class DefaultValue { + public: + static bool Exists() { return true; } + static void Get() {} +}; + +// Points to the user-set default value for type T. +template +typename DefaultValue::ValueProducer* DefaultValue::producer_ = nullptr; + +// Points to the user-set default value for type T&. +template +T* DefaultValue::address_ = nullptr; + +// Implement this interface to define an action for function type F. +template +class ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + ActionInterface() = default; + virtual ~ActionInterface() = default; + + // Performs the action. This method is not const, as in general an + // action can have side effects and be stateful. For example, a + // get-the-next-element-from-the-collection action will need to + // remember the current element. + virtual Result Perform(const ArgumentTuple& args) = 0; + + private: + ActionInterface(const ActionInterface&) = delete; + ActionInterface& operator=(const ActionInterface&) = delete; +}; + +template +class Action; + +// An Action is a copyable and IMMUTABLE (except by assignment) +// object that represents an action to be taken when a mock function of type +// R(Args...) is called. The implementation of Action is just a +// std::shared_ptr to const ActionInterface. Don't inherit from Action! You +// can view an object implementing ActionInterface as a concrete action +// (including its current state), and an Action object as a handle to it. 
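+//
+// (Illustrative usage sketch, not part of the upstream header: an Action can
+// be constructed directly from any compatible callable and performed on a
+// tuple of arguments, e.g.
+//
+//   Action<int(int)> twice = [](int n) { return 2 * n; };
+//   int result = twice.Perform(std::make_tuple(21));  // result == 42
+// )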
+template +class Action { + private: + using F = R(Args...); + + // Adapter class to allow constructing Action from a legacy ActionInterface. + // New code should create Actions from functors instead. + struct ActionAdapter { + // Adapter must be copyable to satisfy std::function requirements. + ::std::shared_ptr> impl_; + + template + typename internal::Function::Result operator()(InArgs&&... args) { + return impl_->Perform( + ::std::forward_as_tuple(::std::forward(args)...)); + } + }; + + template + using IsCompatibleFunctor = std::is_constructible, G>; + + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + // Constructs a null Action. Needed for storing Action objects in + // STL containers. + Action() = default; + + // Construct an Action from a specified callable. + // This cannot take std::function directly, because then Action would not be + // directly constructible from lambda (it would require two conversions). + template < + typename G, + typename = typename std::enable_if, std::is_constructible, + G>>::value>::type> + Action(G&& fun) { // NOLINT + Init(::std::forward(fun), IsCompatibleFunctor()); + } + + // Constructs an Action from its implementation. + explicit Action(ActionInterface* impl) + : fun_(ActionAdapter{::std::shared_ptr>(impl)}) {} + + // This constructor allows us to turn an Action object into an + // Action, as long as F's arguments can be implicitly converted + // to Func's and Func's return type can be implicitly converted to F's. + template + Action(const Action& action) // NOLINT + : fun_(action.fun_) {} + + // Returns true if and only if this is the DoDefault() action. + bool IsDoDefault() const { return fun_ == nullptr; } + + // Performs the action. Note that this method is const even though + // the corresponding method in ActionInterface is not. The reason + // is that a const Action means that it cannot be re-bound to + // another concrete action, not that the concrete action it binds to + // cannot change state. (Think of the difference between a const + // pointer and a pointer to const.) + Result Perform(ArgumentTuple args) const { + if (IsDoDefault()) { + internal::IllegalDoDefault(__FILE__, __LINE__); + } + return internal::Apply(fun_, ::std::move(args)); + } + + // An action can be used as a OnceAction, since it's obviously safe to call it + // once. + operator OnceAction() const { // NOLINT + // Return a OnceAction-compatible callable that calls Perform with the + // arguments it is provided. We could instead just return fun_, but then + // we'd need to handle the IsDoDefault() case separately. + struct OA { + Action action; + + R operator()(Args... args) && { + return action.Perform( + std::forward_as_tuple(std::forward(args)...)); + } + }; + + return OA{*this}; + } + + private: + template + friend class Action; + + template + void Init(G&& g, ::std::true_type) { + fun_ = ::std::forward(g); + } + + template + void Init(G&& g, ::std::false_type) { + fun_ = IgnoreArgs::type>{::std::forward(g)}; + } + + template + struct IgnoreArgs { + template + Result operator()(const InArgs&...) const { + return function_impl(); + } + + FunctionImpl function_impl; + }; + + // fun_ is an empty function if and only if this is the DoDefault() action. + ::std::function fun_; +}; + +// The PolymorphicAction class template makes it easy to implement a +// polymorphic action (i.e. an action that can be used in mock +// functions of than one type, e.g. Return()). 
+// +// To define a polymorphic action, a user first provides a COPYABLE +// implementation class that has a Perform() method template: +// +// class FooAction { +// public: +// template +// Result Perform(const ArgumentTuple& args) const { +// // Processes the arguments and returns a result, using +// // std::get(args) to get the N-th (0-based) argument in the tuple. +// } +// ... +// }; +// +// Then the user creates the polymorphic action using +// MakePolymorphicAction(object) where object has type FooAction. See +// the definition of Return(void) and SetArgumentPointee(value) for +// complete examples. +template +class PolymorphicAction { + public: + explicit PolymorphicAction(const Impl& impl) : impl_(impl) {} + + template + operator Action() const { + return Action(new MonomorphicImpl(impl_)); + } + + private: + template + class MonomorphicImpl : public ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + explicit MonomorphicImpl(const Impl& impl) : impl_(impl) {} + + Result Perform(const ArgumentTuple& args) override { + return impl_.template Perform(args); + } + + private: + Impl impl_; + }; + + Impl impl_; +}; + +// Creates an Action from its implementation and returns it. The +// created Action object owns the implementation. +template +Action MakeAction(ActionInterface* impl) { + return Action(impl); +} + +// Creates a polymorphic action from its implementation. This is +// easier to use than the PolymorphicAction constructor as it +// doesn't require you to explicitly write the template argument, e.g. +// +// MakePolymorphicAction(foo); +// vs +// PolymorphicAction(foo); +template +inline PolymorphicAction MakePolymorphicAction(const Impl& impl) { + return PolymorphicAction(impl); +} + +namespace internal { + +// Helper struct to specialize ReturnAction to execute a move instead of a copy +// on return. Useful for move-only types, but could be used on any type. +template +struct ByMoveWrapper { + explicit ByMoveWrapper(T value) : payload(std::move(value)) {} + T payload; +}; + +// The general implementation of Return(R). Specializations follow below. +template +class ReturnAction final { + public: + explicit ReturnAction(R value) : value_(std::move(value)) {} + + template >, // + negation>, // + std::is_convertible, // + std::is_move_constructible>::value>::type> + operator OnceAction() && { // NOLINT + return Impl(std::move(value_)); + } + + template >, // + negation>, // + std::is_convertible, // + std::is_copy_constructible>::value>::type> + operator Action() const { // NOLINT + return Impl(value_); + } + + private: + // Implements the Return(x) action for a mock function that returns type U. + template + class Impl final { + public: + // The constructor used when the return value is allowed to move from the + // input value (i.e. we are converting to OnceAction). + explicit Impl(R&& input_value) + : state_(new State(std::move(input_value))) {} + + // The constructor used when the return value is not allowed to move from + // the input value (i.e. we are converting to Action). + explicit Impl(const R& input_value) : state_(new State(input_value)) {} + + U operator()() && { return std::move(state_->value); } + U operator()() const& { return state_->value; } + + private: + // We put our state on the heap so that the compiler-generated copy/move + // constructors work correctly even when U is a reference-like type. 
This is + // necessary only because we eagerly create State::value (see the note on + // that symbol for details). If we instead had only the input value as a + // member then the default constructors would work fine. + // + // For example, when R is std::string and U is std::string_view, value is a + // reference to the string backed by input_value. The copy constructor would + // copy both, so that we wind up with a new input_value object (with the + // same contents) and a reference to the *old* input_value object rather + // than the new one. + struct State { + explicit State(const R& input_value_in) + : input_value(input_value_in), + // Make an implicit conversion to Result before initializing the U + // object we store, avoiding calling any explicit constructor of U + // from R. + // + // This simulates the language rules: a function with return type U + // that does `return R()` requires R to be implicitly convertible to + // U, and uses that path for the conversion, even U Result has an + // explicit constructor from R. + value(ImplicitCast_(internal::as_const(input_value))) {} + + // As above, but for the case where we're moving from the ReturnAction + // object because it's being used as a OnceAction. + explicit State(R&& input_value_in) + : input_value(std::move(input_value_in)), + // For the same reason as above we make an implicit conversion to U + // before initializing the value. + // + // Unlike above we provide the input value as an rvalue to the + // implicit conversion because this is a OnceAction: it's fine if it + // wants to consume the input value. + value(ImplicitCast_(std::move(input_value))) {} + + // A copy of the value originally provided by the user. We retain this in + // addition to the value of the mock function's result type below in case + // the latter is a reference-like type. See the std::string_view example + // in the documentation on Return. + R input_value; + + // The value we actually return, as the type returned by the mock function + // itself. + // + // We eagerly initialize this here, rather than lazily doing the implicit + // conversion automatically each time Perform is called, for historical + // reasons: in 2009-11, commit a070cbd91c (Google changelist 13540126) + // made the Action conversion operator eagerly convert the R value to + // U, but without keeping the R alive. This broke the use case discussed + // in the documentation for Return, making reference-like types such as + // std::string_view not safe to use as U where the input type R is a + // value-like type such as std::string. + // + // The example the commit gave was not very clear, nor was the issue + // thread (https://github.com/google/googlemock/issues/86), but it seems + // the worry was about reference-like input types R that flatten to a + // value-like type U when being implicitly converted. An example of this + // is std::vector::reference, which is often a proxy type with an + // reference to the underlying vector: + // + // // Helper method: have the mock function return bools according + // // to the supplied script. + // void SetActions(MockFunction& mock, + // const std::vector& script) { + // for (size_t i = 0; i < script.size(); ++i) { + // EXPECT_CALL(mock, Call(i)).WillOnce(Return(script[i])); + // } + // } + // + // TEST(Foo, Bar) { + // // Set actions using a temporary vector, whose operator[] + // // returns proxy objects that references that will be + // // dangling once the call to SetActions finishes and the + // // vector is destroyed. 
+ // MockFunction mock; + // SetActions(mock, {false, true}); + // + // EXPECT_FALSE(mock.AsStdFunction()(0)); + // EXPECT_TRUE(mock.AsStdFunction()(1)); + // } + // + // This eager conversion helps with a simple case like this, but doesn't + // fully make these types work in general. For example the following still + // uses a dangling reference: + // + // TEST(Foo, Baz) { + // MockFunction()> mock; + // + // // Return the same vector twice, and then the empty vector + // // thereafter. + // auto action = Return(std::initializer_list{ + // "taco", "burrito", + // }); + // + // EXPECT_CALL(mock, Call) + // .WillOnce(action) + // .WillOnce(action) + // .WillRepeatedly(Return(std::vector{})); + // + // EXPECT_THAT(mock.AsStdFunction()(), + // ElementsAre("taco", "burrito")); + // EXPECT_THAT(mock.AsStdFunction()(), + // ElementsAre("taco", "burrito")); + // EXPECT_THAT(mock.AsStdFunction()(), IsEmpty()); + // } + // + U value; + }; + + const std::shared_ptr state_; + }; + + R value_; +}; + +// A specialization of ReturnAction when R is ByMoveWrapper for some T. +// +// This version applies the type system-defeating hack of moving from T even in +// the const call operator, checking at runtime that it isn't called more than +// once, since the user has declared their intent to do so by using ByMove. +template +class ReturnAction> final { + public: + explicit ReturnAction(ByMoveWrapper wrapper) + : state_(new State(std::move(wrapper.payload))) {} + + T operator()() const { + GTEST_CHECK_(!state_->called) + << "A ByMove() action must be performed at most once."; + + state_->called = true; + return std::move(state_->value); + } + + private: + // We store our state on the heap so that we are copyable as required by + // Action, despite the fact that we are stateful and T may not be copyable. + struct State { + explicit State(T&& value_in) : value(std::move(value_in)) {} + + T value; + bool called = false; + }; + + const std::shared_ptr state_; +}; + +// Implements the ReturnNull() action. +class ReturnNullAction { + public: + // Allows ReturnNull() to be used in any pointer-returning function. In C++11 + // this is enforced by returning nullptr, and in non-C++11 by asserting a + // pointer type on compile time. + template + static Result Perform(const ArgumentTuple&) { + return nullptr; + } +}; + +// Implements the Return() action. +class ReturnVoidAction { + public: + // Allows Return() to be used in any void-returning function. + template + static void Perform(const ArgumentTuple&) { + static_assert(std::is_void::value, "Result should be void."); + } +}; + +// Implements the polymorphic ReturnRef(x) action, which can be used +// in any function that returns a reference to the type of x, +// regardless of the argument types. +template +class ReturnRefAction { + public: + // Constructs a ReturnRefAction object from the reference to be returned. + explicit ReturnRefAction(T& ref) : ref_(ref) {} // NOLINT + + // This template type conversion operator allows ReturnRef(x) to be + // used in ANY function that returns a reference to x's type. + template + operator Action() const { + typedef typename Function::Result Result; + // Asserts that the function return type is a reference. This + // catches the user error of using ReturnRef(x) when Return(x) + // should be used, and generates some helpful error message. 
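+    // (Illustrative, non-upstream example of the distinction described above:
+    // for a mock method declared as
+    //   MOCK_METHOD(Bar&, GetBar, (), (override));
+    // write .WillOnce(ReturnRef(bar_)); a method returning Bar by value must
+    // use Return(bar_) instead.)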
+ static_assert(std::is_reference::value, + "use Return instead of ReturnRef to return a value"); + return Action(new Impl(ref_)); + } + + private: + // Implements the ReturnRef(x) action for a particular function type F. + template + class Impl : public ActionInterface { + public: + typedef typename Function::Result Result; + typedef typename Function::ArgumentTuple ArgumentTuple; + + explicit Impl(T& ref) : ref_(ref) {} // NOLINT + + Result Perform(const ArgumentTuple&) override { return ref_; } + + private: + T& ref_; + }; + + T& ref_; +}; + +// Implements the polymorphic ReturnRefOfCopy(x) action, which can be +// used in any function that returns a reference to the type of x, +// regardless of the argument types. +template +class ReturnRefOfCopyAction { + public: + // Constructs a ReturnRefOfCopyAction object from the reference to + // be returned. + explicit ReturnRefOfCopyAction(const T& value) : value_(value) {} // NOLINT + + // This template type conversion operator allows ReturnRefOfCopy(x) to be + // used in ANY function that returns a reference to x's type. + template + operator Action() const { + typedef typename Function::Result Result; + // Asserts that the function return type is a reference. This + // catches the user error of using ReturnRefOfCopy(x) when Return(x) + // should be used, and generates some helpful error message. + static_assert(std::is_reference::value, + "use Return instead of ReturnRefOfCopy to return a value"); + return Action(new Impl(value_)); + } + + private: + // Implements the ReturnRefOfCopy(x) action for a particular function type F. + template + class Impl : public ActionInterface { + public: + typedef typename Function::Result Result; + typedef typename Function::ArgumentTuple ArgumentTuple; + + explicit Impl(const T& value) : value_(value) {} // NOLINT + + Result Perform(const ArgumentTuple&) override { return value_; } + + private: + T value_; + }; + + const T value_; +}; + +// Implements the polymorphic ReturnRoundRobin(v) action, which can be +// used in any function that returns the element_type of v. +template +class ReturnRoundRobinAction { + public: + explicit ReturnRoundRobinAction(std::vector values) { + GTEST_CHECK_(!values.empty()) + << "ReturnRoundRobin requires at least one element."; + state_->values = std::move(values); + } + + template + T operator()(Args&&...) const { + return state_->Next(); + } + + private: + struct State { + T Next() { + T ret_val = values[i++]; + if (i == values.size()) i = 0; + return ret_val; + } + + std::vector values; + size_t i = 0; + }; + std::shared_ptr state_ = std::make_shared(); +}; + +// Implements the polymorphic DoDefault() action. +class DoDefaultAction { + public: + // This template type conversion operator allows DoDefault() to be + // used in any function. + template + operator Action() const { + return Action(); + } // NOLINT +}; + +// Implements the Assign action to set a given pointer referent to a +// particular value. +template +class AssignAction { + public: + AssignAction(T1* ptr, T2 value) : ptr_(ptr), value_(value) {} + + template + void Perform(const ArgumentTuple& /* args */) const { + *ptr_ = value_; + } + + private: + T1* const ptr_; + const T2 value_; +}; + +#ifndef GTEST_OS_WINDOWS_MOBILE + +// Implements the SetErrnoAndReturn action to simulate return from +// various system calls and libc functions. 
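+// (Illustrative usage sketch, not upstream text: given a mock such as
+//   MOCK_METHOD(int, Open, (const char*), ());
+// an expectation like
+//   EXPECT_CALL(mock, Open(_)).WillOnce(SetErrnoAndReturn(ENOENT, -1));
+// makes the call set errno to ENOENT and return -1, mimicking a failing
+// POSIX-style function.)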
+template +class SetErrnoAndReturnAction { + public: + SetErrnoAndReturnAction(int errno_value, T result) + : errno_(errno_value), result_(result) {} + template + Result Perform(const ArgumentTuple& /* args */) const { + errno = errno_; + return result_; + } + + private: + const int errno_; + const T result_; +}; + +#endif // !GTEST_OS_WINDOWS_MOBILE + +// Implements the SetArgumentPointee(x) action for any function +// whose N-th argument (0-based) is a pointer to x's type. +template +struct SetArgumentPointeeAction { + A value; + + template + void operator()(const Args&... args) const { + *::std::get(std::tie(args...)) = value; + } +}; + +// Implements the Invoke(object_ptr, &Class::Method) action. +template +struct InvokeMethodAction { + Class* const obj_ptr; + const MethodPtr method_ptr; + + template + auto operator()(Args&&... args) const + -> decltype((obj_ptr->*method_ptr)(std::forward(args)...)) { + return (obj_ptr->*method_ptr)(std::forward(args)...); + } +}; + +// Implements the InvokeWithoutArgs(f) action. The template argument +// FunctionImpl is the implementation type of f, which can be either a +// function pointer or a functor. InvokeWithoutArgs(f) can be used as an +// Action as long as f's type is compatible with F. +template +struct InvokeWithoutArgsAction { + FunctionImpl function_impl; + + // Allows InvokeWithoutArgs(f) to be used as any action whose type is + // compatible with f. + template + auto operator()(const Args&...) -> decltype(function_impl()) { + return function_impl(); + } +}; + +// Implements the InvokeWithoutArgs(object_ptr, &Class::Method) action. +template +struct InvokeMethodWithoutArgsAction { + Class* const obj_ptr; + const MethodPtr method_ptr; + + using ReturnType = + decltype((std::declval()->*std::declval())()); + + template + ReturnType operator()(const Args&...) const { + return (obj_ptr->*method_ptr)(); + } +}; + +// Implements the IgnoreResult(action) action. +template +class IgnoreResultAction { + public: + explicit IgnoreResultAction(const A& action) : action_(action) {} + + template + operator Action() const { + // Assert statement belongs here because this is the best place to verify + // conditions on F. It produces the clearest error messages + // in most compilers. + // Impl really belongs in this scope as a local class but can't + // because MSVC produces duplicate symbols in different translation units + // in this case. Until MS fixes that bug we put Impl into the class scope + // and put the typedef both here (for use in assert statement) and + // in the Impl class. But both definitions must be the same. + typedef typename internal::Function::Result Result; + + // Asserts at compile time that F returns void. + static_assert(std::is_void::value, "Result type should be void."); + + return Action(new Impl(action_)); + } + + private: + template + class Impl : public ActionInterface { + public: + typedef typename internal::Function::Result Result; + typedef typename internal::Function::ArgumentTuple ArgumentTuple; + + explicit Impl(const A& action) : action_(action) {} + + void Perform(const ArgumentTuple& args) override { + // Performs the action and ignores its result. + action_.Perform(args); + } + + private: + // Type OriginalFunction is the same as F except that its return + // type is IgnoredValue. 
+ typedef + typename internal::Function::MakeResultIgnoredValue OriginalFunction; + + const Action action_; + }; + + const A action_; +}; + +template +struct WithArgsAction { + InnerAction inner_action; + + // The signature of the function as seen by the inner action, given an out + // action with the given result and argument types. + template + using InnerSignature = + R(typename std::tuple_element>::type...); + + // Rather than a call operator, we must define conversion operators to + // particular action types. This is necessary for embedded actions like + // DoDefault(), which rely on an action conversion operators rather than + // providing a call operator because even with a particular set of arguments + // they don't have a fixed return type. + + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>...)>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + struct OA { + OnceAction> inner_action; + + R operator()(Args&&... args) && { + return std::move(inner_action) + .Call(std::get( + std::forward_as_tuple(std::forward(args)...))...); + } + }; + + return OA{std::move(inner_action)}; + } + + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>...)>>::value, + int>::type = 0> + operator Action() const { // NOLINT + Action> converted(inner_action); + + return [converted](Args&&... args) -> R { + return converted.Perform(std::forward_as_tuple( + std::get(std::forward_as_tuple(std::forward(args)...))...)); + }; + } +}; + +template +class DoAllAction; + +// Base case: only a single action. +template +class DoAllAction { + public: + struct UserConstructorTag {}; + + template + explicit DoAllAction(UserConstructorTag, T&& action) + : final_action_(std::forward(action)) {} + + // Rather than a call operator, we must define conversion operators to + // particular action types. This is necessary for embedded actions like + // DoDefault(), which rely on an action conversion operators rather than + // providing a call operator because even with a particular set of arguments + // they don't have a fixed return type. + + // We support conversion to OnceAction whenever the sub-action does. + template >::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + return std::move(final_action_); + } + + // We also support conversion to OnceAction whenever the sub-action supports + // conversion to Action (since any Action can also be a OnceAction). + template < + typename R, typename... Args, + typename std::enable_if< + conjunction< + negation< + std::is_convertible>>, + std::is_convertible>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + return Action(std::move(final_action_)); + } + + // We support conversion to Action whenever the sub-action does. + template < + typename R, typename... Args, + typename std::enable_if< + std::is_convertible>::value, + int>::type = 0> + operator Action() const { // NOLINT + return final_action_; + } + + private: + FinalAction final_action_; +}; + +// Recursive case: support N actions by calling the initial action and then +// calling through to the base class containing N-1 actions. +template +class DoAllAction + : private DoAllAction { + private: + using Base = DoAllAction; + + // The type of reference that should be provided to an initial action for a + // mocked function parameter of type T. + // + // There are two quirks here: + // + // * Unlike most forwarding functions, we pass scalars through by value. 
+ // This isn't strictly necessary because an lvalue reference would work + // fine too and be consistent with other non-reference types, but it's + // perhaps less surprising. + // + // For example if the mocked function has signature void(int), then it + // might seem surprising for the user's initial action to need to be + // convertible to Action. This is perhaps less + // surprising for a non-scalar type where there may be a performance + // impact, or it might even be impossible, to pass by value. + // + // * More surprisingly, `const T&` is often not a const reference type. + // By the reference collapsing rules in C++17 [dcl.ref]/6, if T refers to + // U& or U&& for some non-scalar type U, then InitialActionArgType is + // U&. In other words, we may hand over a non-const reference. + // + // So for example, given some non-scalar type Obj we have the following + // mappings: + // + // T InitialActionArgType + // ------- ----------------------- + // Obj const Obj& + // Obj& Obj& + // Obj&& Obj& + // const Obj const Obj& + // const Obj& const Obj& + // const Obj&& const Obj& + // + // In other words, the initial actions get a mutable view of an non-scalar + // argument if and only if the mock function itself accepts a non-const + // reference type. They are never given an rvalue reference to an + // non-scalar type. + // + // This situation makes sense if you imagine use with a matcher that is + // designed to write through a reference. For example, if the caller wants + // to fill in a reference argument and then return a canned value: + // + // EXPECT_CALL(mock, Call) + // .WillOnce(DoAll(SetArgReferee<0>(17), Return(19))); + // + template + using InitialActionArgType = + typename std::conditional::value, T, const T&>::type; + + public: + struct UserConstructorTag {}; + + template + explicit DoAllAction(UserConstructorTag, T&& initial_action, + U&&... other_actions) + : Base({}, std::forward(other_actions)...), + initial_action_(std::forward(initial_action)) {} + + // We support conversion to OnceAction whenever both the initial action and + // the rest support conversion to OnceAction. + template < + typename R, typename... Args, + typename std::enable_if< + conjunction...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + // Return an action that first calls the initial action with arguments + // filtered through InitialActionArgType, then forwards arguments directly + // to the base class to deal with the remaining actions. + struct OA { + OnceAction...)> initial_action; + OnceAction remaining_actions; + + R operator()(Args... args) && { + std::move(initial_action) + .Call(static_cast>(args)...); + + return std::move(remaining_actions).Call(std::forward(args)...); + } + }; + + return OA{ + std::move(initial_action_), + std::move(static_cast(*this)), + }; + } + + // We also support conversion to OnceAction whenever the initial action + // supports conversion to Action (since any Action can also be a OnceAction). + // + // The remaining sub-actions must also be compatible, but we don't need to + // special case them because the base class deals with them. + template < + typename R, typename... 
Args, + typename std::enable_if< + conjunction< + negation...)>>>, + std::is_convertible...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator OnceAction() && { // NOLINT + return DoAll( + Action...)>(std::move(initial_action_)), + std::move(static_cast(*this))); + } + + // We support conversion to Action whenever both the initial action and the + // rest support conversion to Action. + template < + typename R, typename... Args, + typename std::enable_if< + conjunction< + std::is_convertible...)>>, + std::is_convertible>>::value, + int>::type = 0> + operator Action() const { // NOLINT + // Return an action that first calls the initial action with arguments + // filtered through InitialActionArgType, then forwards arguments directly + // to the base class to deal with the remaining actions. + struct OA { + Action...)> initial_action; + Action remaining_actions; + + R operator()(Args... args) const { + initial_action.Perform(std::forward_as_tuple( + static_cast>(args)...)); + + return remaining_actions.Perform( + std::forward_as_tuple(std::forward(args)...)); + } + }; + + return OA{ + initial_action_, + static_cast(*this), + }; + } + + private: + InitialAction initial_action_; +}; + +template +struct ReturnNewAction { + T* operator()() const { + return internal::Apply( + [](const Params&... unpacked_params) { + return new T(unpacked_params...); + }, + params); + } + std::tuple params; +}; + +template +struct ReturnArgAction { + template ::type> + auto operator()(Args&&... args) const + -> decltype(std::get( + std::forward_as_tuple(std::forward(args)...))) { + return std::get(std::forward_as_tuple(std::forward(args)...)); + } +}; + +template +struct SaveArgAction { + Ptr pointer; + + template + void operator()(const Args&... args) const { + *pointer = std::get(std::tie(args...)); + } +}; + +template +struct SaveArgPointeeAction { + Ptr pointer; + + template + void operator()(const Args&... args) const { + *pointer = *std::get(std::tie(args...)); + } +}; + +template +struct SetArgRefereeAction { + T value; + + template + void operator()(Args&&... args) const { + using argk_type = + typename ::std::tuple_element>::type; + static_assert(std::is_lvalue_reference::value, + "Argument must be a reference type."); + std::get(std::tie(args...)) = value; + } +}; + +template +struct SetArrayArgumentAction { + I1 first; + I2 last; + + template + void operator()(const Args&... args) const { + auto value = std::get(std::tie(args...)); + for (auto it = first; it != last; ++it, (void)++value) { + *value = *it; + } + } +}; + +template +struct DeleteArgAction { + template + void operator()(const Args&... args) const { + delete std::get(std::tie(args...)); + } +}; + +template +struct ReturnPointeeAction { + Ptr pointer; + template + auto operator()(const Args&...) const -> decltype(*pointer) { + return *pointer; + } +}; + +#if GTEST_HAS_EXCEPTIONS +template +struct ThrowAction { + T exception; + // We use a conversion operator to adapt to any return type. + template + operator Action() const { // NOLINT + T copy = exception; + return [copy](Args...) -> R { throw copy; }; + } +}; +struct RethrowAction { + std::exception_ptr exception; + template + operator Action() const { // NOLINT + return [ex = exception](Args...) -> R { std::rethrow_exception(ex); }; + } +}; +#endif // GTEST_HAS_EXCEPTIONS + +} // namespace internal + +// An Unused object can be implicitly constructed from ANY value. +// This is handy when defining actions that ignore some or all of the +// mock function arguments. 
For example, given +// +// MOCK_METHOD3(Foo, double(const string& label, double x, double y)); +// MOCK_METHOD3(Bar, double(int index, double x, double y)); +// +// instead of +// +// double DistanceToOriginWithLabel(const string& label, double x, double y) { +// return sqrt(x*x + y*y); +// } +// double DistanceToOriginWithIndex(int index, double x, double y) { +// return sqrt(x*x + y*y); +// } +// ... +// EXPECT_CALL(mock, Foo("abc", _, _)) +// .WillOnce(Invoke(DistanceToOriginWithLabel)); +// EXPECT_CALL(mock, Bar(5, _, _)) +// .WillOnce(Invoke(DistanceToOriginWithIndex)); +// +// you could write +// +// // We can declare any uninteresting argument as Unused. +// double DistanceToOrigin(Unused, double x, double y) { +// return sqrt(x*x + y*y); +// } +// ... +// EXPECT_CALL(mock, Foo("abc", _, _)).WillOnce(Invoke(DistanceToOrigin)); +// EXPECT_CALL(mock, Bar(5, _, _)).WillOnce(Invoke(DistanceToOrigin)); +typedef internal::IgnoredValue Unused; + +// Creates an action that does actions a1, a2, ..., sequentially in +// each invocation. All but the last action will have a readonly view of the +// arguments. +template +internal::DoAllAction::type...> DoAll( + Action&&... action) { + return internal::DoAllAction::type...>( + {}, std::forward(action)...); +} + +// WithArg(an_action) creates an action that passes the k-th +// (0-based) argument of the mock function to an_action and performs +// it. It adapts an action accepting one argument to one that accepts +// multiple arguments. For convenience, we also provide +// WithArgs(an_action) (defined below) as a synonym. +template +internal::WithArgsAction::type, k> WithArg( + InnerAction&& action) { + return {std::forward(action)}; +} + +// WithArgs(an_action) creates an action that passes +// the selected arguments of the mock function to an_action and +// performs it. It serves as an adaptor between actions with +// different argument lists. +template +internal::WithArgsAction::type, k, ks...> +WithArgs(InnerAction&& action) { + return {std::forward(action)}; +} + +// WithoutArgs(inner_action) can be used in a mock function with a +// non-empty argument list to perform inner_action, which takes no +// argument. In other words, it adapts an action accepting no +// argument to one that accepts (and ignores) arguments. +template +internal::WithArgsAction::type> WithoutArgs( + InnerAction&& action) { + return {std::forward(action)}; +} + +// Creates an action that returns a value. +// +// The returned type can be used with a mock function returning a non-void, +// non-reference type U as follows: +// +// * If R is convertible to U and U is move-constructible, then the action can +// be used with WillOnce. +// +// * If const R& is convertible to U and U is copy-constructible, then the +// action can be used with both WillOnce and WillRepeatedly. +// +// The mock expectation contains the R value from which the U return value is +// constructed (a move/copy of the argument to Return). This means that the R +// value will survive at least until the mock object's expectations are cleared +// or the mock object is destroyed, meaning that U can safely be a +// reference-like type such as std::string_view: +// +// // The mock function returns a view of a copy of the string fed to +// // Return. The view is valid even after the action is performed. 
+//   MockFunction<std::string_view()> mock;
+//   EXPECT_CALL(mock, Call).WillOnce(Return(std::string("taco")));
+//   const std::string_view result = mock.AsStdFunction()();
+//   EXPECT_EQ("taco", result);
+//
+template <typename R>
+internal::ReturnAction<R> Return(R value) {
+  return internal::ReturnAction<R>(std::move(value));
+}
+
+// Creates an action that returns NULL.
+inline PolymorphicAction<internal::ReturnNullAction> ReturnNull() {
+  return MakePolymorphicAction(internal::ReturnNullAction());
+}
+
+// Creates an action that returns from a void function.
+inline PolymorphicAction<internal::ReturnVoidAction> Return() {
+  return MakePolymorphicAction(internal::ReturnVoidAction());
+}
+
+// Creates an action that returns the reference to a variable.
+template <typename R>
+inline internal::ReturnRefAction<R> ReturnRef(R& x) {  // NOLINT
+  return internal::ReturnRefAction<R>(x);
+}
+
+// Prevent using ReturnRef on reference to temporary.
+template <typename R, R* = nullptr>
+internal::ReturnRefAction<R> ReturnRef(R&&) = delete;
+
+// Creates an action that returns the reference to a copy of the
+// argument.  The copy is created when the action is constructed and
+// lives as long as the action.
+template <typename R>
+inline internal::ReturnRefOfCopyAction<R> ReturnRefOfCopy(const R& x) {
+  return internal::ReturnRefOfCopyAction<R>(x);
+}
+
+// DEPRECATED: use Return(x) directly with WillOnce.
+//
+// Modifies the parent action (a Return() action) to perform a move of the
+// argument instead of a copy.
+// Return(ByMove()) actions can only be executed once and will assert this
+// invariant.
+template <typename R>
+internal::ByMoveWrapper<R> ByMove(R x) {
+  return internal::ByMoveWrapper<R>(std::move(x));
+}
+
+// Creates an action that returns an element of `vals`. Calling this action
+// will repeatedly return the next value from `vals` until it reaches the end
+// and will restart from the beginning.
+template <typename T>
+internal::ReturnRoundRobinAction<T> ReturnRoundRobin(std::vector<T> vals) {
+  return internal::ReturnRoundRobinAction<T>(std::move(vals));
+}
+
+// Creates an action that returns an element of `vals`. Calling this action
+// will repeatedly return the next value from `vals` until it reaches the end
+// and will restart from the beginning.
+template <typename T>
+internal::ReturnRoundRobinAction<T> ReturnRoundRobin(
+    std::initializer_list<T> vals) {
+  return internal::ReturnRoundRobinAction<T>(std::vector<T>(vals));
+}
+
+// Creates an action that does the default action for the given mock function.
+inline internal::DoDefaultAction DoDefault() {
+  return internal::DoDefaultAction();
+}
+
+// Creates an action that sets the variable pointed to by the N-th
+// (0-based) function argument to 'value'.
+template <size_t N, typename T>
+internal::SetArgumentPointeeAction<N, T> SetArgPointee(T value) {
+  return {std::move(value)};
+}
+
+// The following version is DEPRECATED.
+template <size_t N, typename T>
+internal::SetArgumentPointeeAction<N, T> SetArgumentPointee(T value) {
+  return {std::move(value)};
+}
+
+// Creates an action that sets a pointer referent to a given value.
+template <typename T1, typename T2>
+PolymorphicAction<internal::AssignAction<T1, T2>> Assign(T1* ptr, T2 val) {
+  return MakePolymorphicAction(internal::AssignAction<T1, T2>(ptr, val));
+}
+
+#ifndef GTEST_OS_WINDOWS_MOBILE
+
+// Creates an action that sets errno and returns the appropriate error.
+template <typename T>
+PolymorphicAction<internal::SetErrnoAndReturnAction<T>> SetErrnoAndReturn(
+    int errval, T result) {
+  return MakePolymorphicAction(
+      internal::SetErrnoAndReturnAction<T>(errval, result));
+}
+
+#endif  // !GTEST_OS_WINDOWS_MOBILE
+
+// Various overloads for Invoke().
+
+// Legacy function.
+// Actions can now be implicitly constructed from callables. No need to create
+// wrapper objects.
+// This function exists for backwards compatibility.
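+//
+// As an illustrative sketch (the MockFunction instantiation and the lambda
+// below are hypothetical examples, not part of this header), the legacy
+// wrapper and a bare callable are interchangeable:
+//
+//   MockFunction<int(int, int)> mock;
+//   EXPECT_CALL(mock, Call).WillOnce(Invoke([](int a, int b) { return a + b; }));
+//   // Equivalent in new code, without the wrapper:
+//   EXPECT_CALL(mock, Call).WillOnce([](int a, int b) { return a + b; });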
+template <typename FunctionImpl>
+typename std::decay<FunctionImpl>::type Invoke(FunctionImpl&& function_impl) {
+  return std::forward<FunctionImpl>(function_impl);
+}
+
+// Creates an action that invokes the given method on the given object
+// with the mock function's arguments.
+template <class Class, typename MethodPtr>
+internal::InvokeMethodAction<Class, MethodPtr> Invoke(Class* obj_ptr,
+                                                      MethodPtr method_ptr) {
+  return {obj_ptr, method_ptr};
+}
+
+// Creates an action that invokes 'function_impl' with no argument.
+template <typename FunctionImpl>
+internal::InvokeWithoutArgsAction<typename std::decay<FunctionImpl>::type>
+InvokeWithoutArgs(FunctionImpl function_impl) {
+  return {std::move(function_impl)};
+}
+
+// Creates an action that invokes the given method on the given object
+// with no argument.
+template <class Class, typename MethodPtr>
+internal::InvokeMethodWithoutArgsAction<Class, MethodPtr> InvokeWithoutArgs(
+    Class* obj_ptr, MethodPtr method_ptr) {
+  return {obj_ptr, method_ptr};
+}
+
+// Creates an action that performs an_action and throws away its
+// result.  In other words, it changes the return type of an_action to
+// void.  an_action MUST NOT return void, or the code won't compile.
+template <typename A>
+inline internal::IgnoreResultAction<A> IgnoreResult(const A& an_action) {
+  return internal::IgnoreResultAction<A>(an_action);
+}
+
+// Creates a reference wrapper for the given L-value.  If necessary,
+// you can explicitly specify the type of the reference.  For example,
+// suppose 'derived' is an object of type Derived, ByRef(derived)
+// would wrap a Derived&.  If you want to wrap a const Base& instead,
+// where Base is a base class of Derived, just write:
+//
+//   ByRef<const Base>(derived)
+//
+// N.B. ByRef is redundant with std::ref, std::cref and std::reference_wrapper.
+// However, it may still be used for consistency with ByMove().
+template <typename T>
+inline ::std::reference_wrapper<T> ByRef(T& l_value) {  // NOLINT
+  return ::std::reference_wrapper<T>(l_value);
+}
+
+// The ReturnNew<T>(a1, a2, ..., a_k) action returns a pointer to a new
+// instance of type T, constructed on the heap with constructor arguments
+// a1, a2, ..., and a_k. The caller assumes ownership of the returned value.
+template <typename T, typename... Params>
+internal::ReturnNewAction<T, typename std::decay<Params>::type...> ReturnNew(
+    Params&&... params) {
+  return {std::forward_as_tuple(std::forward<Params>(params)...)};
+}
+
+// Action ReturnArg<k>() returns the k-th argument of the mock function.
+template <size_t k>
+internal::ReturnArgAction<k> ReturnArg() {
+  return {};
+}
+
+// Action SaveArg<k>(pointer) saves the k-th (0-based) argument of the
+// mock function to *pointer.
+template <size_t k, typename Ptr>
+internal::SaveArgAction<k, Ptr> SaveArg(Ptr pointer) {
+  return {pointer};
+}
+
+// Action SaveArgPointee<k>(pointer) saves the value pointed to
+// by the k-th (0-based) argument of the mock function to *pointer.
+template <size_t k, typename Ptr>
+internal::SaveArgPointeeAction<k, Ptr> SaveArgPointee(Ptr pointer) {
+  return {pointer};
+}
+
+// Action SetArgReferee<k>(value) assigns 'value' to the variable
+// referenced by the k-th (0-based) argument of the mock function.
+template <size_t k, typename T>
+internal::SetArgRefereeAction<k, typename std::decay<T>::type> SetArgReferee(
+    T&& value) {
+  return {std::forward<T>(value)};
+}
+
+// Action SetArrayArgument<k>(first, last) copies the elements in
+// source range [first, last) to the array pointed to by the k-th
+// (0-based) argument, which can be either a pointer or an
+// iterator. The action does not take ownership of the elements in the
+// source range.
+template <size_t k, typename I1, typename I2>
+internal::SetArrayArgumentAction<k, I1, I2> SetArrayArgument(I1 first,
+                                                             I2 last) {
+  return {first, last};
+}
+
+// Action DeleteArg<k>() deletes the k-th (0-based) argument of the mock
+// function.
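+//
+// An illustrative sketch (the MockFunction instantiation is a hypothetical
+// example, not part of this header): an expectation that takes ownership of
+// a heap-allocated argument and frees it could be written as
+//
+//   MockFunction<void(std::string*)> mock;
+//   EXPECT_CALL(mock, Call(_)).WillOnce(DeleteArg<0>());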
+template +internal::DeleteArgAction DeleteArg() { + return {}; +} + +// This action returns the value pointed to by 'pointer'. +template +internal::ReturnPointeeAction ReturnPointee(Ptr pointer) { + return {pointer}; +} + +#if GTEST_HAS_EXCEPTIONS +// Action Throw(exception) can be used in a mock function of any type +// to throw the given exception. Any copyable value can be thrown, +// except for std::exception_ptr, which is likely a mistake if +// thrown directly. +template +typename std::enable_if< + !std::is_base_of::type>::value, + internal::ThrowAction::type>>::type +Throw(T&& exception) { + return {std::forward(exception)}; +} +// Action Rethrow(exception_ptr) can be used in a mock function of any type +// to rethrow any exception_ptr. Note that the same object is thrown each time. +inline internal::RethrowAction Rethrow(std::exception_ptr exception) { + return {std::move(exception)}; +} +#endif // GTEST_HAS_EXCEPTIONS + +namespace internal { + +// A macro from the ACTION* family (defined later in gmock-generated-actions.h) +// defines an action that can be used in a mock function. Typically, +// these actions only care about a subset of the arguments of the mock +// function. For example, if such an action only uses the second +// argument, it can be used in any mock function that takes >= 2 +// arguments where the type of the second argument is compatible. +// +// Therefore, the action implementation must be prepared to take more +// arguments than it needs. The ExcessiveArg type is used to +// represent those excessive arguments. In order to keep the compiler +// error messages tractable, we define it in the testing namespace +// instead of testing::internal. However, this is an INTERNAL TYPE +// and subject to change without notice, so a user MUST NOT USE THIS +// TYPE DIRECTLY. +struct ExcessiveArg {}; + +// Builds an implementation of an Action<> for some particular signature, using +// a class defined by an ACTION* macro. +template +struct ActionImpl; + +template +struct ImplBase { + struct Holder { + // Allows each copy of the Action<> to get to the Impl. + explicit operator const Impl&() const { return *ptr; } + std::shared_ptr ptr; + }; + using type = typename std::conditional::value, + Impl, Holder>::type; +}; + +template +struct ActionImpl : ImplBase::type { + using Base = typename ImplBase::type; + using function_type = R(Args...); + using args_type = std::tuple; + + ActionImpl() = default; // Only defined if appropriate for Base. + explicit ActionImpl(std::shared_ptr impl) : Base{std::move(impl)} {} + + R operator()(Args&&... arg) const { + static constexpr size_t kMaxArgs = + sizeof...(Args) <= 10 ? sizeof...(Args) : 10; + return Apply(std::make_index_sequence{}, + std::make_index_sequence<10 - kMaxArgs>{}, + args_type{std::forward(arg)...}); + } + + template + R Apply(std::index_sequence, std::index_sequence, + const args_type& args) const { + // Impl need not be specific to the signature of action being implemented; + // only the implementing function body needs to have all of the specific + // types instantiated. Up to 10 of the args that are provided by the + // args_type get passed, followed by a dummy of unspecified type for the + // remainder up to 10 explicit args. 
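+    // For instance, for an action used with a two-argument signature the
+    // return statement below conceptually expands to (an illustration only,
+    // with the template argument list abbreviated):
+    //
+    //   gmock_PerformImpl<function_type, R, args_type, arg0_type, arg1_type>(
+    //       args, std::get<0>(args), std::get<1>(args),
+    //       kExcessArg, kExcessArg, ...);  // eight ExcessiveArg placeholders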
+ static constexpr ExcessiveArg kExcessArg{}; + return static_cast(*this) + .template gmock_PerformImpl< + /*function_type=*/function_type, /*return_type=*/R, + /*args_type=*/args_type, + /*argN_type=*/ + typename std::tuple_element::type...>( + /*args=*/args, std::get(args)..., + ((void)excess_id, kExcessArg)...); + } +}; + +// Stores a default-constructed Impl as part of the Action<>'s +// std::function<>. The Impl should be trivial to copy. +template +::testing::Action MakeAction() { + return ::testing::Action(ActionImpl()); +} + +// Stores just the one given instance of Impl. +template +::testing::Action MakeAction(std::shared_ptr impl) { + return ::testing::Action(ActionImpl(std::move(impl))); +} + +#define GMOCK_INTERNAL_ARG_UNUSED(i, data, el) \ + , GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED const arg##i##_type& arg##i +#define GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_ \ + GTEST_INTERNAL_ATTRIBUTE_MAYBE_UNUSED const args_type& args GMOCK_PP_REPEAT( \ + GMOCK_INTERNAL_ARG_UNUSED, , 10) + +#define GMOCK_INTERNAL_ARG(i, data, el) , const arg##i##_type& arg##i +#define GMOCK_ACTION_ARG_TYPES_AND_NAMES_ \ + const args_type& args GMOCK_PP_REPEAT(GMOCK_INTERNAL_ARG, , 10) + +#define GMOCK_INTERNAL_TEMPLATE_ARG(i, data, el) , typename arg##i##_type +#define GMOCK_ACTION_TEMPLATE_ARGS_NAMES_ \ + GMOCK_PP_TAIL(GMOCK_PP_REPEAT(GMOCK_INTERNAL_TEMPLATE_ARG, , 10)) + +#define GMOCK_INTERNAL_TYPENAME_PARAM(i, data, param) , typename param##_type +#define GMOCK_ACTION_TYPENAME_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPENAME_PARAM, , params)) + +#define GMOCK_INTERNAL_TYPE_PARAM(i, data, param) , param##_type +#define GMOCK_ACTION_TYPE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPE_PARAM, , params)) + +#define GMOCK_INTERNAL_TYPE_GVALUE_PARAM(i, data, param) \ + , param##_type gmock_p##i +#define GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_TYPE_GVALUE_PARAM, , params)) + +#define GMOCK_INTERNAL_GVALUE_PARAM(i, data, param) \ + , std::forward(gmock_p##i) +#define GMOCK_ACTION_GVALUE_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_GVALUE_PARAM, , params)) + +#define GMOCK_INTERNAL_INIT_PARAM(i, data, param) \ + , param(::std::forward(gmock_p##i)) +#define GMOCK_ACTION_INIT_PARAMS_(params) \ + GMOCK_PP_TAIL(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_INIT_PARAM, , params)) + +#define GMOCK_INTERNAL_FIELD_PARAM(i, data, param) param##_type param; +#define GMOCK_ACTION_FIELD_PARAMS_(params) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_FIELD_PARAM, , params) + +#define GMOCK_INTERNAL_ACTION(name, full_name, params) \ + template \ + class full_name { \ + public: \ + explicit full_name(GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) \ + : impl_(std::make_shared( \ + GMOCK_ACTION_GVALUE_PARAMS_(params))) {} \ + full_name(const full_name&) = default; \ + full_name(full_name&&) noexcept = default; \ + template \ + operator ::testing::Action() const { \ + return ::testing::internal::MakeAction(impl_); \ + } \ + \ + private: \ + class gmock_Impl { \ + public: \ + explicit gmock_Impl(GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) \ + : GMOCK_ACTION_INIT_PARAMS_(params) {} \ + template \ + return_type gmock_PerformImpl(GMOCK_ACTION_ARG_TYPES_AND_NAMES_) const; \ + GMOCK_ACTION_FIELD_PARAMS_(params) \ + }; \ + std::shared_ptr impl_; \ + }; \ + template \ + inline full_name name( \ + GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) GTEST_MUST_USE_RESULT_; \ + template \ + inline full_name name( \ + GMOCK_ACTION_TYPE_GVALUE_PARAMS_(params)) 
{ \ + return full_name( \ + GMOCK_ACTION_GVALUE_PARAMS_(params)); \ + } \ + template \ + template \ + return_type \ + full_name::gmock_Impl::gmock_PerformImpl( \ + GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_) const + +} // namespace internal + +// Similar to GMOCK_INTERNAL_ACTION, but no bound parameters are stored. +#define ACTION(name) \ + class name##Action { \ + public: \ + explicit name##Action() noexcept {} \ + name##Action(const name##Action&) noexcept {} \ + template \ + operator ::testing::Action() const { \ + return ::testing::internal::MakeAction(); \ + } \ + \ + private: \ + class gmock_Impl { \ + public: \ + template \ + return_type gmock_PerformImpl(GMOCK_ACTION_ARG_TYPES_AND_NAMES_) const; \ + }; \ + }; \ + inline name##Action name() GTEST_MUST_USE_RESULT_; \ + inline name##Action name() { return name##Action(); } \ + template \ + return_type name##Action::gmock_Impl::gmock_PerformImpl( \ + GMOCK_ACTION_ARG_TYPES_AND_NAMES_UNUSED_) const + +#define ACTION_P(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP, (__VA_ARGS__)) + +#define ACTION_P2(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP2, (__VA_ARGS__)) + +#define ACTION_P3(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP3, (__VA_ARGS__)) + +#define ACTION_P4(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP4, (__VA_ARGS__)) + +#define ACTION_P5(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP5, (__VA_ARGS__)) + +#define ACTION_P6(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP6, (__VA_ARGS__)) + +#define ACTION_P7(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP7, (__VA_ARGS__)) + +#define ACTION_P8(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP8, (__VA_ARGS__)) + +#define ACTION_P9(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP9, (__VA_ARGS__)) + +#define ACTION_P10(name, ...) \ + GMOCK_INTERNAL_ACTION(name, name##ActionP10, (__VA_ARGS__)) + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4100 + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_ACTIONS_H_ diff --git a/test/googletest/googlemock/include/gmock/gmock-cardinalities.h b/test/googletest/googlemock/include/gmock/gmock-cardinalities.h new file mode 100644 index 0000000..533e604 --- /dev/null +++ b/test/googletest/googlemock/include/gmock/gmock-cardinalities.h @@ -0,0 +1,159 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// This file implements some commonly used cardinalities. More +// cardinalities can be defined by the user implementing the +// CardinalityInterface interface if necessary. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ + +#include + +#include +#include // NOLINT + +#include "gmock/internal/gmock-port.h" +#include "gtest/gtest.h" + +GTEST_DISABLE_MSC_WARNINGS_PUSH_(4251 \ +/* class A needs to have dll-interface to be used by clients of class B */) + +namespace testing { + +// To implement a cardinality Foo, define: +// 1. a class FooCardinality that implements the +// CardinalityInterface interface, and +// 2. a factory function that creates a Cardinality object from a +// const FooCardinality*. +// +// The two-level delegation design follows that of Matcher, providing +// consistency for extension developers. It also eases ownership +// management as Cardinality objects can now be copied like plain values. + +// The implementation of a cardinality. +class CardinalityInterface { + public: + virtual ~CardinalityInterface() = default; + + // Conservative estimate on the lower/upper bound of the number of + // calls allowed. + virtual int ConservativeLowerBound() const { return 0; } + virtual int ConservativeUpperBound() const { return INT_MAX; } + + // Returns true if and only if call_count calls will satisfy this + // cardinality. + virtual bool IsSatisfiedByCallCount(int call_count) const = 0; + + // Returns true if and only if call_count calls will saturate this + // cardinality. + virtual bool IsSaturatedByCallCount(int call_count) const = 0; + + // Describes self to an ostream. + virtual void DescribeTo(::std::ostream* os) const = 0; +}; + +// A Cardinality is a copyable and IMMUTABLE (except by assignment) +// object that specifies how many times a mock function is expected to +// be called. The implementation of Cardinality is just a std::shared_ptr +// to const CardinalityInterface. Don't inherit from Cardinality! +class GTEST_API_ Cardinality { + public: + // Constructs a null cardinality. Needed for storing Cardinality + // objects in STL containers. + Cardinality() = default; + + // Constructs a Cardinality from its implementation. + explicit Cardinality(const CardinalityInterface* impl) : impl_(impl) {} + + // Conservative estimate on the lower/upper bound of the number of + // calls allowed. + int ConservativeLowerBound() const { return impl_->ConservativeLowerBound(); } + int ConservativeUpperBound() const { return impl_->ConservativeUpperBound(); } + + // Returns true if and only if call_count calls will satisfy this + // cardinality. 
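+  // For instance, the cardinality created by Between(2, 4) (declared below)
+  // is satisfied by call counts 2 through 4 only; this note is an
+  // illustration, not part of the interface contract.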
+ bool IsSatisfiedByCallCount(int call_count) const { + return impl_->IsSatisfiedByCallCount(call_count); + } + + // Returns true if and only if call_count calls will saturate this + // cardinality. + bool IsSaturatedByCallCount(int call_count) const { + return impl_->IsSaturatedByCallCount(call_count); + } + + // Returns true if and only if call_count calls will over-saturate this + // cardinality, i.e. exceed the maximum number of allowed calls. + bool IsOverSaturatedByCallCount(int call_count) const { + return impl_->IsSaturatedByCallCount(call_count) && + !impl_->IsSatisfiedByCallCount(call_count); + } + + // Describes self to an ostream + void DescribeTo(::std::ostream* os) const { impl_->DescribeTo(os); } + + // Describes the given actual call count to an ostream. + static void DescribeActualCallCountTo(int actual_call_count, + ::std::ostream* os); + + private: + std::shared_ptr impl_; +}; + +// Creates a cardinality that allows at least n calls. +GTEST_API_ Cardinality AtLeast(int n); + +// Creates a cardinality that allows at most n calls. +GTEST_API_ Cardinality AtMost(int n); + +// Creates a cardinality that allows any number of calls. +GTEST_API_ Cardinality AnyNumber(); + +// Creates a cardinality that allows between min and max calls. +GTEST_API_ Cardinality Between(int min, int max); + +// Creates a cardinality that allows exactly n calls. +GTEST_API_ Cardinality Exactly(int n); + +// Creates a cardinality from its implementation. +inline Cardinality MakeCardinality(const CardinalityInterface* c) { + return Cardinality(c); +} + +} // namespace testing + +GTEST_DISABLE_MSC_WARNINGS_POP_() // 4251 + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_CARDINALITIES_H_ diff --git a/test/googletest/googlemock/include/gmock/gmock-function-mocker.h b/test/googletest/googlemock/include/gmock/gmock-function-mocker.h new file mode 100644 index 0000000..d2cb13c --- /dev/null +++ b/test/googletest/googlemock/include/gmock/gmock-function-mocker.h @@ -0,0 +1,519 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// This file implements MOCK_METHOD. + +// IWYU pragma: private, include "gmock/gmock.h" +// IWYU pragma: friend gmock/.* + +#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ +#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ + +#include +#include // IWYU pragma: keep +#include // IWYU pragma: keep + +#include "gmock/gmock-spec-builders.h" +#include "gmock/internal/gmock-internal-utils.h" +#include "gmock/internal/gmock-pp.h" + +namespace testing { +namespace internal { +template +using identity_t = T; + +template +struct ThisRefAdjuster { + template + using AdjustT = typename std::conditional< + std::is_const::type>::value, + typename std::conditional::value, + const T&, const T&&>::type, + typename std::conditional::value, T&, + T&&>::type>::type; + + template + static AdjustT Adjust(const MockType& mock) { + return static_cast>(const_cast(mock)); + } +}; + +constexpr bool PrefixOf(const char* a, const char* b) { + return *a == 0 || (*a == *b && internal::PrefixOf(a + 1, b + 1)); +} + +template +constexpr bool StartsWith(const char (&prefix)[N], const char (&str)[M]) { + return N <= M && internal::PrefixOf(prefix, str); +} + +template +constexpr bool EndsWith(const char (&suffix)[N], const char (&str)[M]) { + return N <= M && internal::PrefixOf(suffix, str + M - N); +} + +template +constexpr bool Equals(const char (&a)[N], const char (&b)[M]) { + return N == M && internal::PrefixOf(a, b); +} + +template +constexpr bool ValidateSpec(const char (&spec)[N]) { + return internal::Equals("const", spec) || + internal::Equals("override", spec) || + internal::Equals("final", spec) || + internal::Equals("noexcept", spec) || + (internal::StartsWith("noexcept(", spec) && + internal::EndsWith(")", spec)) || + internal::Equals("ref(&)", spec) || + internal::Equals("ref(&&)", spec) || + (internal::StartsWith("Calltype(", spec) && + internal::EndsWith(")", spec)); +} + +} // namespace internal + +// The style guide prohibits "using" statements in a namespace scope +// inside a header file. However, the FunctionMocker class template +// is meant to be defined in the ::testing namespace. The following +// line is just a trick for working around a bug in MSVC 8.0, which +// cannot handle it if we define FunctionMocker in ::testing. +using internal::FunctionMocker; +} // namespace testing + +#define MOCK_METHOD(...) \ + GMOCK_INTERNAL_WARNING_PUSH() \ + GMOCK_INTERNAL_WARNING_CLANG(ignored, "-Wunused-member-function") \ + GMOCK_PP_VARIADIC_CALL(GMOCK_INTERNAL_MOCK_METHOD_ARG_, __VA_ARGS__) \ + GMOCK_INTERNAL_WARNING_POP() + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_1(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_2(...) 
\ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_3(_Ret, _MethodName, _Args) \ + GMOCK_INTERNAL_MOCK_METHOD_ARG_4(_Ret, _MethodName, _Args, ()) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_4(_Ret, _MethodName, _Args, _Spec) \ + GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Args); \ + GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Spec); \ + GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE( \ + GMOCK_PP_NARG0 _Args, GMOCK_INTERNAL_SIGNATURE(_Ret, _Args)); \ + GMOCK_INTERNAL_ASSERT_VALID_SPEC(_Spec) \ + GMOCK_INTERNAL_MOCK_METHOD_IMPL( \ + GMOCK_PP_NARG0 _Args, _MethodName, GMOCK_INTERNAL_HAS_CONST(_Spec), \ + GMOCK_INTERNAL_HAS_OVERRIDE(_Spec), GMOCK_INTERNAL_HAS_FINAL(_Spec), \ + GMOCK_INTERNAL_GET_NOEXCEPT_SPEC(_Spec), \ + GMOCK_INTERNAL_GET_CALLTYPE_SPEC(_Spec), \ + GMOCK_INTERNAL_GET_REF_SPEC(_Spec), \ + (GMOCK_INTERNAL_SIGNATURE(_Ret, _Args))) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_5(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_6(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHOD_ARG_7(...) \ + GMOCK_INTERNAL_WRONG_ARITY(__VA_ARGS__) + +#define GMOCK_INTERNAL_WRONG_ARITY(...) \ + static_assert( \ + false, \ + "MOCK_METHOD must be called with 3 or 4 arguments. _Ret, " \ + "_MethodName, _Args and optionally _Spec. _Args and _Spec must be " \ + "enclosed in parentheses. If _Ret is a type with unprotected commas, " \ + "it must also be enclosed in parentheses.") + +#define GMOCK_INTERNAL_ASSERT_PARENTHESIS(_Tuple) \ + static_assert( \ + GMOCK_PP_IS_ENCLOSED_PARENS(_Tuple), \ + GMOCK_PP_STRINGIZE(_Tuple) " should be enclosed in parentheses.") + +#define GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE(_N, ...) \ + static_assert( \ + std::is_function<__VA_ARGS__>::value, \ + "Signature must be a function type, maybe return type contains " \ + "unprotected comma."); \ + static_assert( \ + ::testing::tuple_size::ArgumentTuple>::value == _N, \ + "This method does not take " GMOCK_PP_STRINGIZE( \ + _N) " arguments. 
Parenthesize all types with unprotected commas.") + +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC(_Spec) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT, ~, _Spec) + +#define GMOCK_INTERNAL_MOCK_METHOD_IMPL(_N, _MethodName, _Constness, \ + _Override, _Final, _NoexceptSpec, \ + _CallType, _RefSpec, _Signature) \ + typename ::testing::internal::Function::Result \ + GMOCK_INTERNAL_EXPAND(_CallType) \ + _MethodName(GMOCK_PP_REPEAT(GMOCK_INTERNAL_PARAMETER, _Signature, _N)) \ + GMOCK_PP_IF(_Constness, const, ) \ + _RefSpec _NoexceptSpec GMOCK_PP_IF(_Override, override, ) \ + GMOCK_PP_IF(_Final, final, ) { \ + GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .SetOwnerAndName(this, #_MethodName); \ + return GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .Invoke(GMOCK_PP_REPEAT(GMOCK_INTERNAL_FORWARD_ARG, _Signature, _N)); \ + } \ + ::testing::MockSpec gmock_##_MethodName( \ + GMOCK_PP_REPEAT(GMOCK_INTERNAL_MATCHER_PARAMETER, _Signature, _N)) \ + GMOCK_PP_IF(_Constness, const, ) _RefSpec { \ + GMOCK_MOCKER_(_N, _Constness, _MethodName).RegisterOwner(this); \ + return GMOCK_MOCKER_(_N, _Constness, _MethodName) \ + .With(GMOCK_PP_REPEAT(GMOCK_INTERNAL_MATCHER_ARGUMENT, , _N)); \ + } \ + ::testing::MockSpec gmock_##_MethodName( \ + const ::testing::internal::WithoutMatchers&, \ + GMOCK_PP_IF(_Constness, const, )::testing::internal::Function< \ + GMOCK_PP_REMOVE_PARENS(_Signature)>*) const _RefSpec _NoexceptSpec { \ + return ::testing::internal::ThisRefAdjuster::Adjust(*this) \ + .gmock_##_MethodName(GMOCK_PP_REPEAT( \ + GMOCK_INTERNAL_A_MATCHER_ARGUMENT, _Signature, _N)); \ + } \ + mutable ::testing::FunctionMocker \ + GMOCK_MOCKER_(_N, _Constness, _MethodName) + +#define GMOCK_INTERNAL_EXPAND(...) __VA_ARGS__ + +// Valid modifiers. +#define GMOCK_INTERNAL_HAS_CONST(_Tuple) \ + GMOCK_PP_HAS_COMMA(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_CONST, ~, _Tuple)) + +#define GMOCK_INTERNAL_HAS_OVERRIDE(_Tuple) \ + GMOCK_PP_HAS_COMMA( \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_OVERRIDE, ~, _Tuple)) + +#define GMOCK_INTERNAL_HAS_FINAL(_Tuple) \ + GMOCK_PP_HAS_COMMA(GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_DETECT_FINAL, ~, _Tuple)) + +#define GMOCK_INTERNAL_GET_NOEXCEPT_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_NOEXCEPT_SPEC_IF_NOEXCEPT, ~, _Tuple) + +#define GMOCK_INTERNAL_NOEXCEPT_SPEC_IF_NOEXCEPT(_i, _, _elem) \ + GMOCK_PP_IF( \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem)), \ + _elem, ) + +#define GMOCK_INTERNAL_GET_CALLTYPE_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_CALLTYPE_SPEC_IF_CALLTYPE, ~, _Tuple) + +#define GMOCK_INTERNAL_CALLTYPE_SPEC_IF_CALLTYPE(_i, _, _elem) \ + GMOCK_PP_IF( \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem)), \ + GMOCK_PP_CAT(GMOCK_INTERNAL_UNPACK_, _elem), ) + +#define GMOCK_INTERNAL_GET_REF_SPEC(_Tuple) \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_REF_SPEC_IF_REF, ~, _Tuple) + +#define GMOCK_INTERNAL_REF_SPEC_IF_REF(_i, _, _elem) \ + GMOCK_PP_IF(GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_REF(_i, _, _elem)), \ + GMOCK_PP_CAT(GMOCK_INTERNAL_UNPACK_, _elem), ) + +#ifdef GMOCK_INTERNAL_STRICT_SPEC_ASSERT +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT(_i, _, _elem) \ + static_assert( \ + ::testing::internal::ValidateSpec(GMOCK_PP_STRINGIZE(_elem)), \ + "Token \'" GMOCK_PP_STRINGIZE( \ + _elem) "\' cannot be recognized as a valid specification " \ + "modifier. 
Is a ',' missing?"); +#else +#define GMOCK_INTERNAL_ASSERT_VALID_SPEC_ELEMENT(_i, _, _elem) \ + static_assert( \ + (GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CONST(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_OVERRIDE(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_FINAL(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_REF(_i, _, _elem)) + \ + GMOCK_PP_HAS_COMMA(GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem))) == 1, \ + GMOCK_PP_STRINGIZE( \ + _elem) " cannot be recognized as a valid specification modifier."); +#endif // GMOCK_INTERNAL_STRICT_SPEC_ASSERT + +// Modifiers implementation. +#define GMOCK_INTERNAL_DETECT_CONST(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_CONST_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_CONST_I_const , + +#define GMOCK_INTERNAL_DETECT_OVERRIDE(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_OVERRIDE_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_OVERRIDE_I_override , + +#define GMOCK_INTERNAL_DETECT_FINAL(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_FINAL_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_FINAL_I_final , + +#define GMOCK_INTERNAL_DETECT_NOEXCEPT(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_NOEXCEPT_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_NOEXCEPT_I_noexcept , + +#define GMOCK_INTERNAL_DETECT_REF(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_REF_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_REF_I_ref , + +#define GMOCK_INTERNAL_UNPACK_ref(x) x + +#define GMOCK_INTERNAL_DETECT_CALLTYPE(_i, _, _elem) \ + GMOCK_PP_CAT(GMOCK_INTERNAL_DETECT_CALLTYPE_I_, _elem) + +#define GMOCK_INTERNAL_DETECT_CALLTYPE_I_Calltype , + +#define GMOCK_INTERNAL_UNPACK_Calltype(...) __VA_ARGS__ + +// Note: The use of `identity_t` here allows _Ret to represent return types that +// would normally need to be specified in a different way. For example, a method +// returning a function pointer must be written as +// +// fn_ptr_return_t (*method(method_args_t...))(fn_ptr_args_t...) +// +// But we only support placing the return type at the beginning. To handle this, +// we wrap all calls in identity_t, so that a declaration will be expanded to +// +// identity_t method(method_args_t...) +// +// This allows us to work around the syntactic oddities of function/method +// types. +#define GMOCK_INTERNAL_SIGNATURE(_Ret, _Args) \ + ::testing::internal::identity_t( \ + GMOCK_PP_FOR_EACH(GMOCK_INTERNAL_GET_TYPE, _, _Args)) + +#define GMOCK_INTERNAL_GET_TYPE(_i, _, _elem) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_PP_IF(GMOCK_PP_IS_BEGIN_PARENS(_elem), GMOCK_PP_REMOVE_PARENS, \ + GMOCK_PP_IDENTITY) \ + (_elem) + +#define GMOCK_INTERNAL_PARAMETER(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_INTERNAL_ARG_O(_i, GMOCK_PP_REMOVE_PARENS(_Signature)) \ + gmock_a##_i + +#define GMOCK_INTERNAL_FORWARD_ARG(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + ::std::forward(gmock_a##_i) + +#define GMOCK_INTERNAL_MATCHER_PARAMETER(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + GMOCK_INTERNAL_MATCHER_O(_i, GMOCK_PP_REMOVE_PARENS(_Signature)) \ + gmock_a##_i + +#define GMOCK_INTERNAL_MATCHER_ARGUMENT(_i, _1, _2) \ + GMOCK_PP_COMMA_IF(_i) \ + gmock_a##_i + +#define GMOCK_INTERNAL_A_MATCHER_ARGUMENT(_i, _Signature, _) \ + GMOCK_PP_COMMA_IF(_i) \ + ::testing::A() + +#define GMOCK_INTERNAL_ARG_O(_i, ...) \ + typename ::testing::internal::Function<__VA_ARGS__>::template Arg<_i>::type + +#define GMOCK_INTERNAL_MATCHER_O(_i, ...) 
\ + const ::testing::Matcher::template Arg<_i>::type>& + +#define MOCK_METHOD0(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 0, __VA_ARGS__) +#define MOCK_METHOD1(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 1, __VA_ARGS__) +#define MOCK_METHOD2(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 2, __VA_ARGS__) +#define MOCK_METHOD3(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 3, __VA_ARGS__) +#define MOCK_METHOD4(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 4, __VA_ARGS__) +#define MOCK_METHOD5(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 5, __VA_ARGS__) +#define MOCK_METHOD6(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 6, __VA_ARGS__) +#define MOCK_METHOD7(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 7, __VA_ARGS__) +#define MOCK_METHOD8(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 8, __VA_ARGS__) +#define MOCK_METHOD9(m, ...) GMOCK_INTERNAL_MOCK_METHODN(, , m, 9, __VA_ARGS__) +#define MOCK_METHOD10(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, , m, 10, __VA_ARGS__) + +#define MOCK_CONST_METHOD0(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 0, __VA_ARGS__) +#define MOCK_CONST_METHOD1(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 1, __VA_ARGS__) +#define MOCK_CONST_METHOD2(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 2, __VA_ARGS__) +#define MOCK_CONST_METHOD3(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 3, __VA_ARGS__) +#define MOCK_CONST_METHOD4(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 4, __VA_ARGS__) +#define MOCK_CONST_METHOD5(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 5, __VA_ARGS__) +#define MOCK_CONST_METHOD6(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 6, __VA_ARGS__) +#define MOCK_CONST_METHOD7(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 7, __VA_ARGS__) +#define MOCK_CONST_METHOD8(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 8, __VA_ARGS__) +#define MOCK_CONST_METHOD9(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 9, __VA_ARGS__) +#define MOCK_CONST_METHOD10(m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, , m, 10, __VA_ARGS__) + +#define MOCK_METHOD0_T(m, ...) MOCK_METHOD0(m, __VA_ARGS__) +#define MOCK_METHOD1_T(m, ...) MOCK_METHOD1(m, __VA_ARGS__) +#define MOCK_METHOD2_T(m, ...) MOCK_METHOD2(m, __VA_ARGS__) +#define MOCK_METHOD3_T(m, ...) MOCK_METHOD3(m, __VA_ARGS__) +#define MOCK_METHOD4_T(m, ...) MOCK_METHOD4(m, __VA_ARGS__) +#define MOCK_METHOD5_T(m, ...) MOCK_METHOD5(m, __VA_ARGS__) +#define MOCK_METHOD6_T(m, ...) MOCK_METHOD6(m, __VA_ARGS__) +#define MOCK_METHOD7_T(m, ...) MOCK_METHOD7(m, __VA_ARGS__) +#define MOCK_METHOD8_T(m, ...) MOCK_METHOD8(m, __VA_ARGS__) +#define MOCK_METHOD9_T(m, ...) MOCK_METHOD9(m, __VA_ARGS__) +#define MOCK_METHOD10_T(m, ...) MOCK_METHOD10(m, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_T(m, ...) MOCK_CONST_METHOD0(m, __VA_ARGS__) +#define MOCK_CONST_METHOD1_T(m, ...) MOCK_CONST_METHOD1(m, __VA_ARGS__) +#define MOCK_CONST_METHOD2_T(m, ...) MOCK_CONST_METHOD2(m, __VA_ARGS__) +#define MOCK_CONST_METHOD3_T(m, ...) MOCK_CONST_METHOD3(m, __VA_ARGS__) +#define MOCK_CONST_METHOD4_T(m, ...) MOCK_CONST_METHOD4(m, __VA_ARGS__) +#define MOCK_CONST_METHOD5_T(m, ...) MOCK_CONST_METHOD5(m, __VA_ARGS__) +#define MOCK_CONST_METHOD6_T(m, ...) MOCK_CONST_METHOD6(m, __VA_ARGS__) +#define MOCK_CONST_METHOD7_T(m, ...) MOCK_CONST_METHOD7(m, __VA_ARGS__) +#define MOCK_CONST_METHOD8_T(m, ...) MOCK_CONST_METHOD8(m, __VA_ARGS__) +#define MOCK_CONST_METHOD9_T(m, ...) MOCK_CONST_METHOD9(m, __VA_ARGS__) +#define MOCK_CONST_METHOD10_T(m, ...) MOCK_CONST_METHOD10(m, __VA_ARGS__) + +#define MOCK_METHOD0_WITH_CALLTYPE(ct, m, ...) 
\ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 0, __VA_ARGS__) +#define MOCK_METHOD1_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 1, __VA_ARGS__) +#define MOCK_METHOD2_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 2, __VA_ARGS__) +#define MOCK_METHOD3_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 3, __VA_ARGS__) +#define MOCK_METHOD4_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 4, __VA_ARGS__) +#define MOCK_METHOD5_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 5, __VA_ARGS__) +#define MOCK_METHOD6_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 6, __VA_ARGS__) +#define MOCK_METHOD7_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 7, __VA_ARGS__) +#define MOCK_METHOD8_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 8, __VA_ARGS__) +#define MOCK_METHOD9_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 9, __VA_ARGS__) +#define MOCK_METHOD10_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(, ct, m, 10, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 0, __VA_ARGS__) +#define MOCK_CONST_METHOD1_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 1, __VA_ARGS__) +#define MOCK_CONST_METHOD2_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 2, __VA_ARGS__) +#define MOCK_CONST_METHOD3_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 3, __VA_ARGS__) +#define MOCK_CONST_METHOD4_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 4, __VA_ARGS__) +#define MOCK_CONST_METHOD5_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 5, __VA_ARGS__) +#define MOCK_CONST_METHOD6_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 6, __VA_ARGS__) +#define MOCK_CONST_METHOD7_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 7, __VA_ARGS__) +#define MOCK_CONST_METHOD8_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 8, __VA_ARGS__) +#define MOCK_CONST_METHOD9_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 9, __VA_ARGS__) +#define MOCK_CONST_METHOD10_WITH_CALLTYPE(ct, m, ...) \ + GMOCK_INTERNAL_MOCK_METHODN(const, ct, m, 10, __VA_ARGS__) + +#define MOCK_METHOD0_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD0_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD1_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD1_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD2_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD2_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD3_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD3_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD4_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD4_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD5_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD5_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD6_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD6_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD7_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD7_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD8_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD8_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD9_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_METHOD9_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_METHOD10_T_WITH_CALLTYPE(ct, m, ...) 
\ + MOCK_METHOD10_WITH_CALLTYPE(ct, m, __VA_ARGS__) + +#define MOCK_CONST_METHOD0_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD0_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD1_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD1_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD2_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD2_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD3_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD3_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD4_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD4_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD5_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD5_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD6_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD6_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD7_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD7_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD8_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD8_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD9_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD9_WITH_CALLTYPE(ct, m, __VA_ARGS__) +#define MOCK_CONST_METHOD10_T_WITH_CALLTYPE(ct, m, ...) \ + MOCK_CONST_METHOD10_WITH_CALLTYPE(ct, m, __VA_ARGS__) + +#define GMOCK_INTERNAL_MOCK_METHODN(constness, ct, Method, args_num, ...) \ + GMOCK_INTERNAL_ASSERT_VALID_SIGNATURE( \ + args_num, ::testing::internal::identity_t<__VA_ARGS__>); \ + GMOCK_INTERNAL_MOCK_METHOD_IMPL( \ + args_num, Method, GMOCK_PP_NARG0(constness), 0, 0, , ct, , \ + (::testing::internal::identity_t<__VA_ARGS__>)) + +#define GMOCK_MOCKER_(arity, constness, Method) \ + GTEST_CONCAT_TOKEN_(gmock##constness##arity##_##Method##_, __LINE__) + +#endif // GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_FUNCTION_MOCKER_H_ diff --git a/test/googletest/googlemock/include/gmock/gmock-matchers.h b/test/googletest/googlemock/include/gmock/gmock-matchers.h new file mode 100644 index 0000000..16cf9c3 --- /dev/null +++ b/test/googletest/googlemock/include/gmock/gmock-matchers.h @@ -0,0 +1,5677 @@ +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Google Mock - a framework for writing C++ mock classes. +// +// The MATCHER* family of macros can be used in a namespace scope to +// define custom matchers easily. +// +// Basic Usage +// =========== +// +// The syntax +// +// MATCHER(name, description_string) { statements; } +// +// defines a matcher with the given name that executes the statements, +// which must return a bool to indicate if the match succeeds. Inside +// the statements, you can refer to the value being matched by 'arg', +// and refer to its type by 'arg_type'. +// +// The description string documents what the matcher does, and is used +// to generate the failure message when the match fails. Since a +// MATCHER() is usually defined in a header file shared by multiple +// C++ source files, we require the description to be a C-string +// literal to avoid possible side effects. It can be empty, in which +// case we'll use the sequence of words in the matcher name as the +// description. +// +// For example: +// +// MATCHER(IsEven, "") { return (arg % 2) == 0; } +// +// allows you to write +// +// // Expects mock_foo.Bar(n) to be called where n is even. +// EXPECT_CALL(mock_foo, Bar(IsEven())); +// +// or, +// +// // Verifies that the value of some_expression is even. +// EXPECT_THAT(some_expression, IsEven()); +// +// If the above assertion fails, it will print something like: +// +// Value of: some_expression +// Expected: is even +// Actual: 7 +// +// where the description "is even" is automatically calculated from the +// matcher name IsEven. +// +// Argument Type +// ============= +// +// Note that the type of the value being matched (arg_type) is +// determined by the context in which you use the matcher and is +// supplied to you by the compiler, so you don't need to worry about +// declaring it (nor can you). This allows the matcher to be +// polymorphic. For example, IsEven() can be used to match any type +// where the value of "(arg % 2) == 0" can be implicitly converted to +// a bool. In the "Bar(IsEven())" example above, if method Bar() +// takes an int, 'arg_type' will be int; if it takes an unsigned long, +// 'arg_type' will be unsigned long; and so on. +// +// Parameterizing Matchers +// ======================= +// +// Sometimes you'll want to parameterize the matcher. For that you +// can use another macro: +// +// MATCHER_P(name, param_name, description_string) { statements; } +// +// For example: +// +// MATCHER_P(HasAbsoluteValue, value, "") { return abs(arg) == value; } +// +// will allow you to write: +// +// EXPECT_THAT(Blah("a"), HasAbsoluteValue(n)); +// +// which may lead to this message (assuming n is 10): +// +// Value of: Blah("a") +// Expected: has absolute value 10 +// Actual: -9 +// +// Note that both the matcher description and its parameter are +// printed, making the message human-friendly. +// +// In the matcher definition body, you can write 'foo_type' to +// reference the type of a parameter named 'foo'. 
For example, in the
+// body of MATCHER_P(HasAbsoluteValue, value) above, you can write
+// 'value_type' to refer to the type of 'value'.
+//
+// We also provide MATCHER_P2, MATCHER_P3, ..., up to MATCHER_P10 to
+// support multi-parameter matchers.
+//
+// Describing Parameterized Matchers
+// =================================
+//
+// The last argument to MATCHER*() is a string-typed expression.  The
+// expression can reference all of the matcher's parameters and a
+// special bool-typed variable named 'negation'.  When 'negation' is
+// false, the expression should evaluate to the matcher's description;
+// otherwise it should evaluate to the description of the negation of
+// the matcher.  For example,
+//
+//   using testing::PrintToString;
+//
+//   MATCHER_P2(InClosedRange, low, hi,
+//              std::string(negation ? "is not" : "is") + " in range [" +
+//              PrintToString(low) + ", " + PrintToString(hi) + "]") {
+//     return low <= arg && arg <= hi;
+//   }
+//   ...
+//   EXPECT_THAT(3, InClosedRange(4, 6));
+//   EXPECT_THAT(3, Not(InClosedRange(2, 4)));
+//
+// would generate two failures that contain the text:
+//
+//   Expected: is in range [4, 6]
+//   ...
+//   Expected: is not in range [2, 4]
+//
+// If you specify "" as the description, the failure message will
+// contain the sequence of words in the matcher name followed by the
+// parameter values printed as a tuple.  For example,
+//
+//   MATCHER_P2(InClosedRange, low, hi, "") { ... }
+//   ...
+//   EXPECT_THAT(3, InClosedRange(4, 6));
+//   EXPECT_THAT(3, Not(InClosedRange(2, 4)));
+//
+// would generate two failures that contain the text:
+//
+//   Expected: in closed range (4, 6)
+//   ...
+//   Expected: not (in closed range (2, 4))
+//
+// Types of Matcher Parameters
+// ===========================
+//
+// For the purpose of typing, you can view
+//
+//   MATCHER_Pk(Foo, p1, ..., pk, description_string) { ... }
+//
+// as shorthand for
+//
+//   template <typename p1_type, ..., typename pk_type>
+//   FooMatcherPk<p1_type, ..., pk_type>
+//   Foo(p1_type p1, ..., pk_type pk) { ... }
+//
+// When you write Foo(v1, ..., vk), the compiler infers the types of
+// the parameters v1, ..., and vk for you.  If you are not happy with
+// the result of the type inference, you can specify the types by
+// explicitly instantiating the template, as in Foo<long, bool>(5,
+// false).  As said earlier, you don't get to (or need to) specify
+// 'arg_type' as that's determined by the context in which the matcher
+// is used.  You can assign the result of expression Foo(p1, ..., pk)
+// to a variable of type FooMatcherPk<p1_type, ..., pk_type>.  This
+// can be useful when composing matchers.
+//
+// While you can instantiate a matcher template with reference types,
+// passing the parameters by pointer usually makes your code more
+// readable.  If, however, you still want to pass a parameter by
+// reference, be aware that in the failure message generated by the
+// matcher you will see the value of the referenced object but not its
+// address.
+//
+// Explaining Match Results
+// ========================
+//
+// Sometimes the matcher description alone isn't enough to explain why
+// the match has failed or succeeded.  For example, when expecting a
+// long string, it can be very helpful to also print the diff between
+//
+// Explaining Match Results
+// ========================
+//
+// Sometimes the matcher description alone isn't enough to explain why
+// the match has failed or succeeded.  For example, when expecting a
+// long string, it can be very helpful to also print the diff between
+// the expected string and the actual one.  To achieve that, you can
+// optionally stream additional information to a special variable
+// named result_listener, whose type is a pointer to class
+// MatchResultListener:
+//
+//   MATCHER_P(EqualsLongString, str, "") {
+//     if (arg == str) return true;
+//
+//     *result_listener << "the difference: "
+//                      << DiffStrings(str, arg);
+//     return false;
+//   }
+//
+// Overloading Matchers
+// ====================
+//
+// You can overload matchers with different numbers of parameters:
+//
+//   MATCHER_P(Blah, a, description_string1) { ... }
+//   MATCHER_P2(Blah, a, b, description_string2) { ... }
+//
+// Caveats
+// =======
+//
+// When defining a new matcher, you should also consider implementing
+// MatcherInterface or using MakePolymorphicMatcher().  These
+// approaches require more work than the MATCHER* macros, but also
+// give you more control on the types of the value being matched and
+// the matcher parameters, which may lead to better compiler error
+// messages when the matcher is used incorrectly.  They also allow
+// overloading matchers based on parameter types (as opposed to just
+// based on the number of parameters).
+//
+// MATCHER*() can only be used in a namespace scope as templates cannot be
+// declared inside of a local class.
+//
+// More Information
+// ================
+//
+// To learn more about using these macros, please search for 'MATCHER'
+// on
+// https://github.com/google/googletest/blob/main/docs/gmock_cook_book.md
+//
+// This file also implements some commonly used argument matchers.  More
+// matchers can be defined by the user implementing the
+// MatcherInterface<T> interface if necessary.
+//
+// See googletest/include/gtest/gtest-matchers.h for the definition of class
+// Matcher, class MatcherInterface, and others.
+
+// IWYU pragma: private, include "gmock/gmock.h"
+// IWYU pragma: friend gmock/.*
+
+#ifndef GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MATCHERS_H_
+#define GOOGLEMOCK_INCLUDE_GMOCK_GMOCK_MATCHERS_H_
+
+#include <algorithm>
+#include <cmath>
+#include <exception>
+#include <functional>
+#include <initializer_list>
+#include <ios>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <ostream>  // NOLINT
+#include <sstream>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "gmock/internal/gmock-internal-utils.h"
+#include "gmock/internal/gmock-port.h"
+#include "gmock/internal/gmock-pp.h"
+#include "gtest/gtest.h"
+
+// MSVC warning C5046 is new as of VS2017 version 15.8.
+#if defined(_MSC_VER) && _MSC_VER >= 1915
+#define GMOCK_MAYBE_5046_ 5046
+#else
+#define GMOCK_MAYBE_5046_
+#endif
+
+GTEST_DISABLE_MSC_WARNINGS_PUSH_(
+    4251 GMOCK_MAYBE_5046_ /* class A needs to have dll-interface to be used by
+                              clients of class B */
+    /* Symbol involving type with internal linkage not defined */)
+
+namespace testing {
+
+// To implement a matcher Foo for type T, define:
+//   1. a class FooMatcherImpl that implements the
+//      MatcherInterface<T> interface, and
+//   2. a factory function that creates a Matcher<T> object from a
+//      FooMatcherImpl*.
+//
+// The two-level delegation design makes it possible to allow a user
+// to write "v" instead of "Eq(v)" where a Matcher is expected, which
+// is impossible if we pass matchers by pointers.  It also eases
+// ownership management as Matcher objects can now be copied like
+// plain values.
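+//
+// A minimal sketch of that two-level pattern (IsPositive and its
+// implementation class are illustrative, not part of this file):
+//
+//   class IsPositiveImpl : public MatcherInterface<int> {
+//    public:
+//     bool MatchAndExplain(int n, MatchResultListener*) const override {
+//       return n > 0;
+//     }
+//     void DescribeTo(::std::ostream* os) const override {
+//       *os << "is positive";
+//     }
+//     void DescribeNegationTo(::std::ostream* os) const override {
+//       *os << "isn't positive";
+//     }
+//   };
+//
+//   // The factory returns a Matcher<int> that owns the impl object.
+//   Matcher<int> IsPositive() { return MakeMatcher(new IsPositiveImpl); }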
+
+// A match result listener that stores the explanation in a string.
+class StringMatchResultListener : public MatchResultListener {
+ public:
+  StringMatchResultListener() : MatchResultListener(&ss_) {}
+
+  // Returns the explanation accumulated so far.
+  std::string str() const { return ss_.str(); }
+
+  // Clears the explanation accumulated so far.
+  void Clear() { ss_.str(""); }
+
+ private:
+  ::std::stringstream ss_;
+
+  StringMatchResultListener(const StringMatchResultListener&) = delete;
+  StringMatchResultListener& operator=(const StringMatchResultListener&) =
+      delete;
+};
+
+// Anything inside the 'internal' namespace IS INTERNAL IMPLEMENTATION
+// and MUST NOT BE USED IN USER CODE!!!
+namespace internal {
+
+// The MatcherCastImpl class template is a helper for implementing
+// MatcherCast().  We need this helper in order to partially
+// specialize the implementation of MatcherCast() (C++ allows
+// class/struct templates to be partially specialized, but not
+// function templates.).
+
+// This general version is used when MatcherCast()'s argument is a
+// polymorphic matcher (i.e. something that can be converted to a
+// Matcher<T> but is not one yet; for example, Eq(value)) or a value (for
+// example, "hello").
+template <typename T, typename M>
+class MatcherCastImpl {
+ public:
+  static Matcher<T> Cast(const M& polymorphic_matcher_or_value) {
+    // M can be a polymorphic matcher, in which case we want to use
+    // its conversion operator to create Matcher<T>.  Or it can be a value
+    // that should be passed to the Matcher<T>'s constructor.
+    //
+    // We can't call Matcher<T>(polymorphic_matcher_or_value) when M is a
+    // polymorphic matcher because it'll be ambiguous if T has an implicit
+    // constructor from M (this usually happens when T has an implicit
+    // constructor from any type).
+    //
+    // It won't work to unconditionally implicit_cast
+    // polymorphic_matcher_or_value to Matcher<T> because it won't trigger
+    // a user-defined conversion from M to T if one exists (assuming M is
+    // a value).
+    return CastImpl(polymorphic_matcher_or_value,
+                    std::is_convertible<M, Matcher<T>>{},
+                    std::is_convertible<M, T>{});
+  }
+
+ private:
+  template <bool Ignore>
+  static Matcher<T> CastImpl(const M& polymorphic_matcher_or_value,
+                             std::true_type /* convertible_to_matcher */,
+                             std::integral_constant<bool, Ignore>) {
+    // M is implicitly convertible to Matcher<T>, which means that either
+    // M is a polymorphic matcher or Matcher<T> has an implicit constructor
+    // from M.  In both cases using the implicit conversion will produce a
+    // matcher.
+    //
+    // Even if T has an implicit constructor from M, it won't be called because
+    // creating Matcher<T> would require a chain of two user-defined
+    // conversions (first to create T from M and then to create Matcher<T>
+    // from T).
+    return polymorphic_matcher_or_value;
+  }
+
+  // M can't be implicitly converted to Matcher<T>, so M isn't a polymorphic
+  // matcher.  It's a value of a type implicitly convertible to T.  Use direct
+  // initialization to create a matcher.
+  static Matcher<T> CastImpl(const M& value,
+                             std::false_type /* convertible_to_matcher */,
+                             std::true_type /* convertible_to_T */) {
+    return Matcher<T>(ImplicitCast_<T>(value));
+  }
+
+  // M can't be implicitly converted to either Matcher<T> or T.  Attempt to use
+  // polymorphic matcher Eq(value) in this case.
+  //
+  // Note that we first attempt to perform an implicit cast on the value and
+  // only fall back to the polymorphic Eq() matcher afterwards because the
+  // latter calls bool operator==(const Lhs& lhs, const Rhs& rhs) in the end
+  // which might be undefined even when Rhs is implicitly convertible to Lhs
+  // (e.g. std::pair<int, int> vs. std::pair<const int, int>).
+  //
+  // We don't define this method inline as we need the declaration of Eq().
+  static Matcher<T> CastImpl(const M& value,
+                             std::false_type /* convertible_to_matcher */,
+                             std::false_type /* convertible_to_T */);
+};
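+
+// Usage sketch for the public MatcherCast<T>() built on this helper
+// (the calls below are illustrative, not additional API):
+//
+//   Matcher<long> m1 = MatcherCast<long>(Eq(5));                // polymorphic
+//   Matcher<long> m2 = MatcherCast<long>(Matcher<int>(Eq(5)));  // Matcher<U>
+//   Matcher<long> m3 = MatcherCast<long>(5);                    // plain value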
+
+// This more specialized version is used when MatcherCast()'s argument
+// is already a Matcher.  This only compiles when type T can be
+// statically converted to type U.
+template <typename T, typename U>
+class MatcherCastImpl<T, Matcher<U>> {
+ public:
+  static Matcher<T> Cast(const Matcher<U>& source_matcher) {
+    return Matcher<T>(new Impl(source_matcher));
+  }
+
+ private:
+  class Impl : public MatcherInterface<T> {
+   public:
+    explicit Impl(const Matcher<U>& source_matcher)
+        : source_matcher_(source_matcher) {}
+
+    // We delegate the matching logic to the source matcher.
+    bool MatchAndExplain(T x, MatchResultListener* listener) const override {
+      using FromType = typename std::remove_cv<typename std::remove_pointer<
+          typename std::remove_reference<T>::type>::type>::type;
+      using ToType = typename std::remove_cv<typename std::remove_pointer<
+          typename std::remove_reference<U>::type>::type>::type;
+      // Do not allow implicitly converting base*/& to derived*/&.
+      static_assert(
+          // Do not trigger if only one of them is a pointer. That implies a
+          // regular conversion and not a down_cast.
+          (std::is_pointer<typename std::remove_reference<T>::type>::value !=
+           std::is_pointer<typename std::remove_reference<U>::type>::value) ||
+              std::is_same<FromType, ToType>::value ||
+              !std::is_base_of<FromType, ToType>::value,
+          "Can't implicitly convert from <base> to <derived>");
+
+      // Do the cast to `U` explicitly if necessary.
+      // Otherwise, let implicit conversions do the trick.
+      using CastType =
+          typename std::conditional<std::is_convertible<T&, const U&>::value,
+                                    T&, U>::type;
+
+      return source_matcher_.MatchAndExplain(static_cast<CastType>(x),
+                                             listener);
+    }
+
+    void DescribeTo(::std::ostream* os) const override {
+      source_matcher_.DescribeTo(os);
+    }
+
+    void DescribeNegationTo(::std::ostream* os) const override {
+      source_matcher_.DescribeNegationTo(os);
+    }
+
+   private:
+    const Matcher<U> source_matcher_;
+  };
+};
+
+// This even more specialized version is used for efficiently casting
+// a matcher to its own type.
+template <typename T>
+class MatcherCastImpl<T, Matcher<T>> {
+ public:
+  static Matcher<T> Cast(const Matcher<T>& matcher) { return matcher; }
+};
+
+// Template specialization for parameterless Matcher.
+template <typename Derived>
+class MatcherBaseImpl {
+ public:
+  MatcherBaseImpl() = default;
+
+  template <typename T>
+  operator ::testing::Matcher<T>() const {  // NOLINT(runtime/explicit)
+    return ::testing::Matcher<T>(new
+                                 typename Derived::template gmock_Impl<T>());
+  }
+};
+
+// Template specialization for Matcher with parameters.
+template