Skip to content

Commit 9875a83

Browse files
ZhiweiYan-96 pytorchmergebot
authored and committed
[Intel GPU] oneDNN GPU GEMM support (pytorch#117202)
# Motivation This PR is a part of RFC pytorch#114848, and it is a successor PR of pytorch#116249 and pytorch#116019. This PR depends on the oneDNN compilation in pytorch#116249. Some runtime support is also needed from pytorch#116019. ATen operators like `addmm` and `baddmm` are defined in `Blas.cpp` in `aten/src/ATen/native/mkldnn/xpu/`. Alongside the files that provide the core functionality, `BlasImpl.h`, `Utils.h`, and other files provide basic utilities for them. For instance, `Utils.h` provides common memory descriptor query utilities for `Matmul.h`, and these utility functions will also be used in other primitives, like `convolution`. `BlasImpl.h` is a header file that provides helpers for handling shape info processing in matmul-related operators. It not only helps basic GEMM operators like `addmm` and `baddmm` but also helps fusion operators used in `torch.compile`, like `linear_pointwise` in pytorch#117824. In the next stage, we will continue to complete the oneDNN support by enabling `matmul fusion` and `convolution` related code. Co-authored-by: xiaolil1 <[email protected]> Co-authored-by: lei,zhenyuan <[email protected]> Pull Request resolved: pytorch#117202 Approved by: https://github.com/EikanWang, https://github.com/jgong5, https://github.com/malfet ghstack dependencies: pytorch#117098, pytorch#117112
1 parent 6330aca commit 9875a83

File tree

6 files changed

+1623
-6
lines changed

6 files changed

+1623
-6
lines changed

aten/CMakeLists.txt

+6
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@ cmake_policy(SET CMP0012 NEW)
1818
#############################################
1919

2020
set(ATen_CPU_SRCS)
21+
set(ATen_XPU_SRCS)
22+
set(ATen_XPU_INCLUDE)
2123
set(ATen_CPU_TEST_SRCS)
2224
set(ATen_CPU_INCLUDE)
2325
set(ATen_THIRD_PARTY_INCLUDE)
@@ -39,6 +41,7 @@ set(ATen_XPU_INCLUDE)
3941
set(ATen_XPU_TEST_SRCS)
4042
set(ATen_VULKAN_TEST_SRCS)
4143
set(ATen_CPU_DEPENDENCY_LIBS)
44+
set(ATen_XPU_DEPENDENCY_LIBS)
4245
set(ATen_CUDA_DEPENDENCY_LIBS)
4346
set(ATen_HIP_DEPENDENCY_LIBS)
4447
set(ATen_PUBLIC_CUDA_DEPENDENCY_LIBS)
@@ -105,6 +108,8 @@ add_subdirectory(src/ATen)
105108
# Pass source, includes, and libs to parent
106109
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
107110
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
111+
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
112+
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
108113
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
109114
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
110115
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
@@ -130,6 +135,7 @@ set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
130135
set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
131136
set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
132137
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
138+
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
133139
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
134140
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
135141
set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)

aten/src/ATen/CMakeLists.txt

+18
Original file line number | Diff line number | Diff line change
@@ -85,6 +85,8 @@ file(GLOB miopen_cpp "miopen/*.cpp")
8585
file(GLOB mkl_cpp "mkl/*.cpp")
8686
file(GLOB mkldnn_cpp "mkldnn/*.cpp")
8787

88+
file(GLOB mkldnn_xpu_cpp "native/mkldnn/xpu/*.cpp" "native/mkldnn/xpu/detail/*.cpp")
89+
8890
file(GLOB native_cpp "native/*.cpp")
8991
file(GLOB native_mkl_cpp "native/mkl/*.cpp")
9092
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
@@ -238,6 +240,20 @@ else()
238240
set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
239241
endif()
240242

243+
if(USE_XPU)
244+
list(APPEND ATen_XPU_SRCS ${mkldnn_xpu_cpp})
245+
list(APPEND ATen_XPU_DEPENDENCY_LIBS xpu_mkldnn)
246+
247+
list(APPEND ATen_XPU_DEPENDENCY_LIBS ${OCL_LIBRARY})
248+
list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu)
249+
list(APPEND ATen_XPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/mkldnn/xpu/detail)
250+
list(APPEND ATen_XPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/ideep/mkl-dnn/include)
251+
list(APPEND ATen_XPU_INCLUDE ${XPU_MKLDNN_INCLUDE})
252+
253+
list(APPEND ATen_XPU_INCLUDE ${SYCL_INCLUDE_DIR})
254+
list(APPEND ATen_XPU_DEPENDENCY_LIBS ${SYCL_LIBRARY})
255+
endif()
256+
241257
# Metal
242258
if(USE_PYTORCH_METAL_EXPORT)
243259
# Add files needed from exporting metal models(optimized_for_mobile)
@@ -629,6 +645,7 @@ list(APPEND ATen_MOBILE_BENCHMARK_SRCS
629645
# Pass source, includes, and libs to parent
630646
set(ATen_CORE_SRCS ${ATen_CORE_SRCS} PARENT_SCOPE)
631647
set(ATen_CPU_SRCS ${ATen_CPU_SRCS} PARENT_SCOPE)
648+
set(ATen_XPU_SRCS ${ATen_XPU_SRCS} PARENT_SCOPE)
632649
set(ATen_CUDA_CU_SRCS ${ATen_CUDA_CU_SRCS} PARENT_SCOPE)
633650
set(ATen_CUDA_CPP_SRCS ${ATen_CUDA_CPP_SRCS} PARENT_SCOPE)
634651
set(ATen_CUDA_LINALG_SRCS ${ATen_CUDA_LINALG_SRCS} PARENT_SCOPE)
@@ -658,6 +675,7 @@ set(ATen_XPU_INCLUDE ${ATen_XPU_INCLUDE} PARENT_SCOPE)
658675
set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
659676
set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
660677
set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
678+
set(ATen_XPU_DEPENDENCY_LIBS ${ATen_XPU_DEPENDENCY_LIBS} PARENT_SCOPE)
661679
set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
662680
set(FLASH_ATTENTION_CUDA_SOURCES ${FLASH_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)
663681
set(MEM_EFF_ATTENTION_CUDA_SOURCES ${MEM_EFF_ATTENTION_CUDA_SOURCES} PARENT_SCOPE)

0 commit comments

Comments
 (0)