Skip to content

[SYCL][Matrix spec] keep deletion of assign op and copy ctor but change signature of joint_matrix_mad #11007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -304,15 +304,20 @@ q.submit([&](sycl::handler& cgh) {
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
joint_matrix_fill(sg, tC, 0);
for (int k = 0; k < K; k += tK) {
joint_matrix_load(sg, tA, accA + sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB, accB + k * N*4 + sg_starty/SG_SIZE*tN*4, N*4);
tC = joint_matrix_mad(sg, tA, tB, tC);
joint_matrix_load(sg, tA,
accA.template get_multi_ptr<sycl::access::decorated::no>() +
sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB,
accB.template get_multi_ptr<sycl::access::decorated::no>() +
k * N*4 + sg_starty/SG_SIZE*tN*4, N*4);
joint_matrix_mad(sg, tC, tA, tB, tC);
}
auto wi_data_c = ext::intel::experimental::matrix::get_wi_data(sg, tC);
for (int i = 0; i < wi_data_c.length(); i++)
wi_data_c[i] *= alpha;
joint_matrix_apply(sg, tC, [=](int8_t x) {
x *= alpha;
});
joint_matrix_store(sg, tC,
accC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
accC.template get_multi_ptr<sycl::access::decorated::no>()
+ sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
});
});
q.wait();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ rows for the row major layout, or between columns for the column major layout.
```c++
namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename Ta, typename Tb, typename Tc,
std::size_t M, std::size_t K, std::size_t N, layout LayoutA, layout
LayoutB, typename Td = Tc>
joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic>
joint_matrix_mad(Group g,
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
std::size_t M, std::size_t K, std::size_t N,
layout LayoutA, layout LayoutB>
void joint_matrix_mad(Group g,
joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic> &D,
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
const joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> &C);
Expand All @@ -287,7 +287,7 @@ joint_matrix_mad(Group g,
```
The matrix multiply and add function performs the multiply operation
on the matrices `A` and `B`, accumulates the result with `C` and stores
the result in the matrix `D`.

Each device supports only certain combinations of types for the `A`,
`B`, and `C` matrices. The application must use the query operations
Expand Down Expand Up @@ -505,6 +505,12 @@ range<2> L = {1, SG_SIZE};
int8_t *memA = malloc_shared<int8_t>(M*K, q);
int8_t *memB = malloc_shared<int8_t>(K*N, q);
int32_t *memC = malloc_shared<int32_t>(M*N, q);
auto pA = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(memA);
auto pB = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(memB);
auto pC = address_space_cast<sycl::access::address_space::global_space,
sycl::access::decorated::no>(memC);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
[[sycl::reqd_sub_group_size(SG_SIZE)]] {
const auto global_idx = item.get_global_id(0);
Expand All @@ -517,20 +523,15 @@ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
joint_matrix_fill(sg, tC, 0);
for (int k = 0; k < K; k += tK) {
joint_matrix_load(sg, tA,
multi_ptr<int8_t, sycl::access::address_space::global_space>(memA) +
sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB,
multi_ptr<int8_t, sycl::access::address_space::global_space>(memB) +
k * N + sg_starty/SG_SIZE*tN, N);
tC = joint_matrix_mad(sg, tA, tB, tC);
joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K);
joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N);
joint_matrix_mad(sg, tC, tA, tB, tC);
}
joint_matrix_apply(sg, tC, [=](int8_t x) {
x *= alpha;
});
joint_matrix_store(sg, tC,
multi_ptr<int32_t, sycl::access::address_space::global_space>(memC) +
sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN,
N, layout::row_major);
}).wait();
```

Expand Down