Enable SYCL NVIDIA and AMD backends #2192
Conversation
Tested with `CXX=icpx CC=icx AR=llvm-ar ./build.sh -Dsycl=nvidia` on Ubuntu 24.04 with CUDA 12.9 and oneAPI 2025.1. The CUDA Compute Capability can optionally be specified with `-Dcc_cuda`. If not specified, the default CUDA target of the DPC++ compiler is used, which means SYCL device code is precompiled for the lowest supported CC. When executed on a GPU with a different CC, it is recompiled at runtime for the specific architecture. In addition to the meson.build changes, remove a redundant free(nullptr) that caused crashes in the SYCL NVIDIA backend.
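For example, pinning the Compute Capability at build time avoids the runtime recompilation described above (the `90` value here is illustrative; use the CC of the actual target GPU):

```shell
# Hypothetical example: build SYCL device code for CC 9.0 (e.g. H100) ahead of time
CXX=icpx CC=icx AR=llvm-ar ./build.sh -Dsycl=nvidia -Dcc_cuda=90
```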
Tested with `CXX=icpx CC=icx AR=llvm-ar ./build.sh -Dsycl=amd -Damd_gfx=90a` on Ubuntu 22.04 with ROCm 6.3.3 and oneAPI 2025.1. The new amd_gfx option is required because DPC++ does not support Just-In-Time compilation for AMD GPU code; it has to be precompiled for the right architecture when building the application. Fix the SYCL AMD fp16 backend, which was missing calls to the fp16 hipBLAS functions where needed. Also fix the hardcoded sub-group / warp / wavefront size of 32: some AMD GPUs have a wavefront size of 64, which has to be used instead.
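The wavefront-size fix boils down to a compile-time selection. The sketch below mirrors the macro introduced in this PR (`__HIP_PLATFORM_AMD__`, `__GFX8__`, and `__GFX9__` are real HIP/ROCm predefined macros); the `main()` harness is only for illustration, and on a plain host compile the 32-wide default is chosen:

```cpp
#include <cassert>

// Select the sub-group (warp/wavefront) size at compile time instead of
// hardcoding 32. GFX8/GFX9 AMD GPUs execute 64-wide wavefronts.
#if defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
#define SYCL_SUB_GROUP_SIZE 64
#else
#define SYCL_SUB_GROUP_SIZE 32
#endif

int main() {
  // Neither HIP macro is defined in a plain host build, so 32 is selected.
  assert(SYCL_SUB_GROUP_SIZE == 32);
  return 0;
}
```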
Looks good, thank you. I'll take a second look tomorrow before merging. BTW, do you have any performance numbers?
Pull Request Overview
Enables SYCL support for both NVIDIA and AMD backends, fixes fp16 GEMM calls, and makes subgroup sizes dynamic.
- Added `amd_gfx` Meson option and configured HIPBLAS/CUBLAS flags in `meson.build` for AMD/NVIDIA targets.
- Implemented `USE_HIPBLAS` paths in SYCL layers and GEMM routines, replacing single-precision calls with half-precision (`hipblasHgemm`).
- Introduced `SYCL_SUB_GROUP_SIZE` macro in common kernels and replaced hardcoded `32` in subgroup attributes; removed a redundant `sycl::free(nullptr)`.
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/neural/backends/sycl/layers.cc.dp.cpp | Added USE_HIPBLAS branches for fp16 GEMM, fixed redundant free, corrected SYCL submit closure |
| src/neural/backends/sycl/fp16_kernels.dp.cpp | Defined SYCL_SUB_GROUP_SIZE based on AMD vs NVIDIA, updated subgroup annotations |
| src/neural/backends/sycl/common_kernels.dp.cpp | Defined SYCL_SUB_GROUP_SIZE and updated subgroup annotations for multiple kernels |
| meson_options.txt | Introduced amd_gfx build option for specifying AMD GPU arch |
| meson.build | Configured library deps and compiler/link flags for SYCL AMD/NVIDIA backends |
Comments suppressed due to low confidence (1)
src/neural/backends/sycl/layers.cc.dp.cpp:361
- [nitpick] Consider adding unit or integration tests for the new `USE_HIPBLAS` fp16 GEMM paths to validate correctness on AMD hardware and prevent regressions.

```cpp
#elif defined(USE_HIPBLAS)
```
```cpp
#include "winograd_helper.h"
#include <cmath>

#if defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
```
Copilot AI (Jun 19, 2025)
[nitpick] The SYCL_SUB_GROUP_SIZE macro is defined locally here; since it’s duplicated in multiple files, consider extracting it to a shared header to avoid divergence and improve maintainability.
good point, moved to sycl_common.h
Is it possible that we may also need to set SYCL_SUB_GROUP_SIZE for future architectures?
Then it may make sense to use this pattern:
```cpp
#if __has_include("params_override.h")
#include "params_override.h"
#endif
#ifndef SYCL_SUB_GROUP_SIZE
#if defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
#define SYCL_SUB_GROUP_SIZE 64
#else
#define SYCL_SUB_GROUP_SIZE 32
#endif
#endif
```
If you'd like an interesting read, the full context for this very specific definition is ROCm/ROCm#4121 where an AMD engineer recommended this claiming it "will work without needing revisiting in the foreseeable future". I also can't imagine the other major vendors not supporting a sub-group size of 32 any time soon. Supporting other SYCL devices than AMD/Intel/NVIDIA GPUs would require big changes to the code as it is currently, so I'm quite confident this won't be needed for now. I would suggest implementing this pattern if/when a need for this comes up, but if you think it's useful I don't mind adding it.
Seems fine then.
```cpp
#endif
#include "winograd_helper.h"

#if defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
```
Copilot AI (Jun 19, 2025)
[nitpick] This defines SYCL_SUB_GROUP_SIZE again; extracting it to a common header would reduce duplication and ensure consistency across kernels.
```cpp
#elif defined(USE_HIPBLAS)
  hipblasHandle_t handle = hipBlasContextManager::gethipBlasHandle_t();
  hipblasHandle_t handle = hipBlasContextManager::gethipBlasHandle_t();
  if (fp16) {
```
Copilot AI (Jun 19, 2025)
[nitpick] The if (fp16) blocks across multiple GEMM routines duplicate conversion and submission logic; extracting common fp16-path code into a helper could reduce repetition and simplify future updates.
Refactoring the code structure is beyond the scope of this PR, which brings the minimal changes required to enable the backends while keeping everything else untouched. The new blocks in the HIP path follow the same style as the existing blocks for the CUDA path.
We’ve seen around 4000-5000 nodes/second with SYCL (fp32) on Intel Data Center GPU Max 1100, AMD MI210 and NVIDIA H100. We know that for the biggest and fastest GPUs the performance of the NVIDIA and AMD backends is limited by the CPU threading performance of the cuBLAS/hipBLAS task submissions. This can be mitigated with a “native command” SYCL extension, and we’re currently working on integrating that into the Velocity-Bench version in oneapi-src/Velocity-Bench#98. We’d like to upstream this work once it’s merged and well tested in Velocity-Bench, and this should bring the performance much closer to native CUDA/HIP code.
Building on top of #2152 from @KateBlueSky and @borg323, fix issues in the SYCL NVIDIA and AMD backends, and add the build configuration to enable them.
Tested with oneAPI 2025.1 on Ubuntu 22.04/24.04 with two NVIDIA and two AMD GPU models (one workstation/gaming and one data centre model from each vendor). Tested with `./lc0 bench` using both `sycl` and `sycl-fp16` backends and the t3-512x15x16h-distill-swa-2767500 network. The build works with the commands:
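Quoting the build invocations from the commit messages above (NVIDIA first, then AMD; the `90a` gfx value matches the MI210-class hardware used for testing):

```shell
CXX=icpx CC=icx AR=llvm-ar ./build.sh -Dsycl=nvidia
CXX=icpx CC=icx AR=llvm-ar ./build.sh -Dsycl=amd -Damd_gfx=90a
```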
The (already existing) `cc_cuda` setting is optional; if not specified, the default CUDA target of the DPC++ compiler is used, which means SYCL device code is precompiled for the lowest supported CC. When executed on a GPU with a different CC, it is recompiled at runtime for the specific architecture. The new `amd_gfx` option is required because DPC++ does not support Just-In-Time compilation for AMD GPU code; it has to be precompiled for the right architecture when building the application.

Code fixes include:
- removing a redundant `free(nullptr)` that caused crashes in the SYCL NVIDIA backend,
- calling the fp16 hipBLAS functions (`hipblasHgemm`) in the SYCL AMD fp16 backend where needed, and
- replacing the hardcoded sub-group / warp / wavefront size of 32 with a `SYCL_SUB_GROUP_SIZE` macro (64 on AMD GFX8/GFX9 GPUs).