Skip to content

Commit 2859fd6

Browse files
[Matrix] add support for joint_matrix_copy
1 parent 47d8489 commit 2859fd6

File tree

3 files changed

+195
-0
lines changed

3 files changed

+195
-0
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,34 @@ joint_matrix_load(Group sg,
323323
#endif // defined(__SYCL_DEVICE_ONLY__)
324324
}
325325

326+
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
327+
use Use1, use Use2, layout Layout>
328+
void joint_matrix_copy(
329+
Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout> &src,
330+
joint_matrix<Group, T2, Use2, Rows, Cols, Layout> &dest) {
331+
#if defined(__SYCL_DEVICE_ONLY__)
332+
#if defined(__NVPTX__)
333+
// cuda code
334+
#else
335+
using storage_element_type =
336+
typename oneapi::detail::jm_type_interpretation_helper_trait<
337+
T2>::storage_element_type;
338+
auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, src);
339+
auto wi_data_d =
340+
sycl::ext::intel::experimental::matrix::get_wi_data(sg, dest);
341+
for (int i = 0; i < wi_data_c.length(); i++) {
342+
wi_data_d[i] = static_cast<storage_element_type>(wi_data_c[i]);
343+
}
344+
#endif // defined(__NVPTX__)
345+
#else
346+
std::ignore = sg;
347+
std::ignore = dest;
348+
std::ignore = src;
349+
throw runtime_error("joint matrix is not supported on host device.",
350+
PI_ERROR_INVALID_DEVICE);
351+
#endif // defined(__SYCL_DEVICE_ONLY__)
352+
}
353+
326354
template <typename Group, typename T, size_t NumRows, size_t NumCols,
327355
access::address_space Space, access::decorated IsDecorated>
328356
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//==--------- joint_matrix_copy.cpp - DPC++ joint_matrix --------------- ==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// REQUIRES: matrix
9+
10+
// RUN: %{build} -o %t.out -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4
11+
// RUN: %{run} %t.out
12+
13+
#include <iostream>
14+
#include <sycl/sycl.hpp>
15+
16+
using namespace sycl;
17+
using namespace sycl::ext::oneapi::experimental::matrix;
18+
19+
#define SG_SZ 16
20+
21+
#include "joint_matrix_copy_impl.hpp"
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
#define TM 8
2+
#define TN SG_SZ
3+
#define TK 32
4+
5+
// Thin, non-owning wrapper around a row-major NUM_ROWS x NUM_COLS buffer.
// The dimensions are carried in the type so matrix_multiply can check shape
// compatibility at compile time; the pointed-to data is borrowed, never freed.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  T *mat;

  // Non-owning: the caller keeps ownership of `data`.
  explicit big_matrix(T *data) : mat(data) {}
  T *get_data() { return mat; }
  void set_data(T *data) { mat = data; }
};
14+
15+
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
16+
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
17+
size_t NUM_COLS_C>
18+
void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
19+
big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
20+
big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
21+
size_t M = NUM_ROWS_C;
22+
size_t N = NUM_COLS_C;
23+
size_t K = NUM_COLS_A;
24+
// B => K/4 x N*4, A => M x K, C => M, N
25+
// stride should be X's cols, e.g., B's stirde = N*4
26+
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
27+
size_t NDRangeM = M / TM;
28+
size_t NDRangeN = N / TN;
29+
buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
30+
buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
31+
buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));
32+
33+
queue q;
34+
q.submit([&](handler &cgh) {
35+
auto accC = bufC.get_access<access::mode::read_write>(cgh);
36+
auto accA = bufA.get_access<access::mode::read_write>(cgh);
37+
auto accB = bufB.get_access<access::mode::read_write>(cgh);
38+
39+
cgh.parallel_for<class imatrix>(
40+
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
41+
[accA, accB, accC, M, N,
42+
K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
43+
// The submatrix API has to be accessed by all the workitems in a
44+
// subgroup these functions will be called once by the subgroup no
45+
// code divergence between the workitems
46+
const auto global_idx = spmd_item.get_global_id(0);
47+
const auto global_idy = spmd_item.get_global_id(1);
48+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
49+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
50+
51+
sub_group sg = spmd_item.get_sub_group();
52+
joint_matrix<sub_group, int8_t, use::a, TM, TK, layout::row_major>
53+
sub_a;
54+
// For B, we assume B has been already VNNIed.
55+
joint_matrix<sub_group, int8_t, use::b, TK, TN,
56+
ext::intel::experimental::matrix::layout::packed>
57+
sub_b;
58+
joint_matrix<sub_group, int32_t, use::accumulator, TM, TN> sub_c;
59+
joint_matrix<sub_group, float, use::a, TM, TK, layout::row_major> sub_d;
60+
61+
joint_matrix_fill(sg, sub_c, 0);
62+
for (int k = 0; k < K / TK; k += 1) {
63+
joint_matrix_load(
64+
sg, sub_a,
65+
accA.template get_multi_ptr<access::decorated::no>() +
66+
(sg_startx * TM) * K + k * TK,
67+
K);
68+
joint_matrix_copy(sg, sub_a, sub_d);
69+
joint_matrix_copy(sg, sub_d, sub_a);
70+
joint_matrix_load(
71+
sg, sub_b,
72+
accB.template get_multi_ptr<access::decorated::no>() +
73+
(k * TK / 4) * (N * 4) + sg_starty / SG_SZ * TN * 4,
74+
N * 4);
75+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
76+
}
77+
joint_matrix_store(
78+
sg, sub_c,
79+
accC.template get_multi_ptr<access::decorated::no>() +
80+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
81+
N, layout::row_major);
82+
}); // parallel for
83+
}).wait();
84+
}
85+
86+
// Full-problem dimensions: two sub-group tiles in each direction.
static constexpr size_t MATRIX_M = TM * 2;
static constexpr size_t MATRIX_N = TN * 2;
static constexpr size_t MATRIX_K = TK * 2;
int8_t A[MATRIX_M][MATRIX_K];         // row-major input
int8_t B[MATRIX_K / 4][MATRIX_N * 4]; // VNNI-packed input
int32_t C[MATRIX_M][MATRIX_N];        // device result
int32_t D[MATRIX_M][MATRIX_N];        // host reference result
93+
94+
// Host reference for the VNNI-packed int8 GEMM.
// A_mem and B_mem are viewed as M x K and K x N matrices of int32 words,
// where each word packs four int8 values; C_mem (M x N, int32) is
// accumulated in place: C[m][n] += sum of the four packed int8 products
// for every (m, k) x (k, n) word pair.
void matrix_multiply_ref(int32_t *A_mem, int32_t *B_mem, int32_t *C_mem, int M,
                         int N, int K) {
  for (int m = 0; m < M; m++)
    for (int n = 0; n < N; n++) {
      for (int k = 0; k < K; k++) {
        // Reinterpret through int8_t, not plain char: char's signedness is
        // implementation-defined (unsigned on e.g. ARM/PowerPC ABIs), and
        // the packed lanes must be treated as SIGNED 8-bit values to match
        // the device's int8 x int8 -> int32 MAD.
        const int8_t *va = reinterpret_cast<const int8_t *>(A_mem + m * K + k);
        const int8_t *vb = reinterpret_cast<const int8_t *>(B_mem + k * N + n);
        int acc = *(C_mem + m * N + n);
        for (int i = 0; i < 4; i++) {
          acc += (va[i] * vb[i]);
        }
        *(C_mem + m * N + n) = acc;
      }
    }
}
110+
111+
int main() {
  // Deterministic input patterns so device and host compute the same thing.
  for (int i = 0; i < MATRIX_M; i++)
    for (int j = 0; j < MATRIX_K; j++)
      A[i][j] = i + 2 * j;

  for (int i = 0; i < MATRIX_K / 4; i++)
    for (int j = 0; j < MATRIX_N * 4; j++)
      B[i][j] = i + j;

  for (int i = 0; i < MATRIX_M; i++)
    for (int j = 0; j < MATRIX_N; j++) {
      C[i][j] = 0;
      D[i][j] = 0;
    }

  big_matrix<int32_t, MATRIX_M, MATRIX_N> MC((int32_t *)&C);
  big_matrix<int32_t, MATRIX_M, MATRIX_N> MD((int32_t *)&D);
  big_matrix<int8_t, MATRIX_M, MATRIX_K> MA((int8_t *)&A);
  big_matrix<int8_t, MATRIX_K / 4, MATRIX_N * 4> MB((int8_t *)&B);

  // Device GEMM (C) vs. host reference (D); K is passed in int32 words.
  matrix_multiply(MC, MA, MB);
  matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
                      MATRIX_N, MATRIX_K / 4);

  bool res = true;
  for (int i = 0; i < MATRIX_M; i++)
    for (int j = 0; j < MATRIX_N; j++)
      if (C[i][j] != D[i][j])
        res = false;

  std::cout << (res ? "passed" : "failed") << std::endl;
  return !res;
}

0 commit comments

Comments
 (0)