Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 6327396

Browse files
committed
Enable device-specific compilation feature MACRO using Option DEVICE with target_compile_definitions (GPU_ARCH, MMA_ENGINE)
Signed-off-by: Qun Gao <[email protected]>
1 parent 81f132d commit 6327396

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

tests/integration/gemm/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
include_directories(${CMAKE_SOURCE_DIR}/tests/integration/gemm)
22
if (DEVICE STREQUAL "mtl")
3+
add_subdirectory(int4_dequantization)
4+
add_subdirectory(int4_dequantization_bias)
35
else()
46
add_subdirectory(bf16)
57
add_subdirectory(stream_k)

tests/integration/gemm/int4_dequantization/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,9 @@ string(REPLACE " " "_" ProjectId ${ProjectId})
33
string(PREPEND ProjectId "gemm_")
44

55
FILE(GLOB src main.cpp)
6-
add_integration_test(${ProjectId} ${src})
6+
add_integration_test(${ProjectId} ${src})
7+
if (DEVICE STREQUAL "mtl")
8+
target_compile_definitions(${ProjectId} PRIVATE MMA_ENGINE=fpu GPU_ARCH=XeLpg)
9+
else()
10+
target_compile_definitions(${ProjectId} PRIVATE MMA_ENGINE=xmx GPU_ARCH=XeHpg)
11+
endif()

tests/integration/gemm/int4_dequantization/main.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,16 +234,16 @@ void dequantize_gemm_run(uint32_t iter) {
234234
data_type_zero_pt,
235235
gpu::xetla::group::quant_mode::S4_ASYM,
236236
dequant_s,
237-
mma_engine::xmx,
238-
gpu_arch::XeHpg>;
237+
mma_engine::MMA_ENGINE,
238+
gpu_arch::GPU_ARCH>;
239239
using gemm_t = xetla::group::
240240
gemm_t<compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t>;
241241

242242
using epilogue_t = xetla::group::epilogue_t<
243-
xetla::group::epilogue_policy_default<gpu_arch::XeHpg>,
243+
xetla::group::epilogue_policy_default<gpu_arch::GPU_ARCH>,
244244
tile_shape,
245245
mem_desc_c_t>;
246-
using group_swizzle = xetla::kernel::group_swizzle_default<gpu_arch::XeHpg>;
246+
using group_swizzle = xetla::kernel::group_swizzle_default<gpu_arch::GPU_ARCH>;
247247
using gemm_op_t = xetla::kernel::gemm_universal_t<
248248
gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing<
249249
group_swizzle,

tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ set(ProjectIdXe ${ProjectId})
55
string(PREPEND ProjectIdClient "gemm_client_")
66
string(PREPEND ProjectIdXe "gemm_xe_")
77

8-
FILE(GLOB src_client main_client.cpp)
98
if (DEVICE STREQUAL "mtl")
9+
FILE(GLOB src_client main_client.cpp)
1010
add_integration_test(${ProjectIdClient} ${src_client})
11+
else()
12+
FILE(GLOB src_xe main_xe.cpp)
13+
add_integration_test(${ProjectIdXe} ${src_xe})
1114
endif()
12-
FILE(GLOB src_xe main_xe.cpp)
13-
add_integration_test(${ProjectIdXe} ${src_xe})

0 commit comments

Comments
 (0)