Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,31 +92,71 @@ endif()
if(AUTO_DETECT_BACKENDS)
message(STATUS "Auto-detecting available backends...")

# The Python that scikit-build's build-isolated environment hands
# us does not have `torch` (only `[build-system].requires` is
# installed). Fall back to a list of common system interpreters so
# the auto-detection finds `torch` when it is in the install env
# but not the build env. The first interpreter that successfully
# imports `torch` wins and is reused by the `WITH_TORCH` block
# below for include / library lookups.
find_package(Python COMPONENTS Interpreter QUIET)

if(Python_FOUND)
set(_torch_python_candidates "${Python_EXECUTABLE}")
foreach(_candidate
python3
python
/usr/bin/python3
/usr/local/bin/python3
/opt/conda/bin/python
/opt/conda/bin/python3)
find_program(_resolved_${_candidate} ${_candidate})
if(_resolved_${_candidate} AND
NOT _resolved_${_candidate} STREQUAL "${Python_EXECUTABLE}")
list(APPEND _torch_python_candidates "${_resolved_${_candidate}}")
endif()
endforeach()

foreach(_py ${_torch_python_candidates})
if(NOT _py)
continue()
endif()

execute_process(
COMMAND ${Python_EXECUTABLE} -c "import torch"
COMMAND "${_py}" -c "import torch"
RESULT_VARIABLE _torch_import_result
OUTPUT_QUIET
ERROR_QUIET
)

if(_torch_import_result EQUAL 0)
set(WITH_TORCH ON)
message(STATUS "Auto-detected PyTorch.")
set(_TORCH_PYTHON "${_py}")
break()
endif()
endforeach()

if(_TORCH_PYTHON)
set(WITH_TORCH ON)
message(STATUS "Auto-detected PyTorch (via ${_TORCH_PYTHON}).")
endif()
endif()

if(WITH_TORCH)
find_package(Python COMPONENTS Interpreter REQUIRED)

# Prefer the interpreter that the auto-detect block already
# confirmed has `torch` (this is the system Python on hosts that
# use scikit-build's build-isolation, where the build interpreter
# does not have `torch`). Fall back to `Python_EXECUTABLE` for
# explicit `-DWITH_TORCH=ON` invocations.
if(NOT _TORCH_PYTHON)
set(_TORCH_PYTHON "${Python_EXECUTABLE}")
endif()

# Query `torch` paths directly instead of using `find_package(Torch)`,
# which pulls in Caffe2's CMake config and may fail on platforms with
# non-standard CUDA toolchains.
execute_process(
COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
COMMAND ${_TORCH_PYTHON} -c "from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))"
OUTPUT_VARIABLE TORCH_INCLUDE_DIRS
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE _torch_result
Expand All @@ -127,7 +167,7 @@ if(WITH_TORCH)
endif()

execute_process(
COMMAND ${Python_EXECUTABLE} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))"
COMMAND ${_TORCH_PYTHON} -c "from torch.utils.cpp_extension import library_paths; print(';'.join(library_paths()))"
OUTPUT_VARIABLE _torch_lib_dirs
OUTPUT_STRIP_TRAILING_WHITESPACE
)
Expand All @@ -144,7 +184,7 @@ if(WITH_TORCH)
# the bundled `NEEDED` entries (otherwise: `undefined reference to
# _gfortran_etime@GFORTRAN_8` etc.).
execute_process(
COMMAND ${Python_EXECUTABLE} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')"
COMMAND ${_TORCH_PYTHON} -c "import os, torch; d = os.path.dirname(torch.__file__); p = os.path.join(os.path.dirname(d), 'torch.libs'); print(p if os.path.isdir(p) else '')"
OUTPUT_VARIABLE TORCH_BUNDLED_LIBS_DIR
OUTPUT_STRIP_TRAILING_WHITESPACE
)
Expand All @@ -163,7 +203,7 @@ if(WITH_TORCH)
# A mismatch causes linker errors (e.g. undefined reference to
# `c10::Device::Device(std::string const&)`).
execute_process(
COMMAND ${Python_EXECUTABLE} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
COMMAND ${_TORCH_PYTHON} -c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
OUTPUT_VARIABLE TORCH_CXX11_ABI
OUTPUT_STRIP_TRAILING_WHITESPACE
RESULT_VARIABLE _torch_abi_result
Expand Down Expand Up @@ -218,6 +258,8 @@ if(WITH_ILUVATAR)
# tries to link. Disable automatic CUDA runtime linking and link
# manually via `find_package(CUDAToolkit)` instead.
set(CMAKE_CUDA_RUNTIME_LIBRARY NONE)
set(CMAKE_CUDA_COMPILER_ID Clang CACHE STRING "Iluvatar CUDA compiler id" FORCE)
set(CMAKE_CUDA_COMPILER_FORCED ON CACHE BOOL "Skip Iluvatar CUDA compiler detection" FORCE)
message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}")
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
Expand Down Expand Up @@ -314,10 +356,12 @@ if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE A
add_compile_definitions(WITH_CPU=1)
endif()

if(WITH_METAX OR WITH_MOORE)
if(WITH_TORCH OR WITH_METAX OR WITH_MOORE)
set(PYBIND11_ENABLE_EXTRAS OFF)
endif()

add_subdirectory(src)

add_subdirectory(examples)
if(NOT GENERATE_PYTHON_BINDINGS)
add_subdirectory(examples)
endif()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core", "pybind11", "libclang"]
requires = ["scikit-build-core", "pybind11", "libclang", "pyyaml"]
build-backend = "scikit_build_core.build"

[project]
Expand Down
Loading
Loading