Skip to content

Commit 4c17a7f

Browse files
authored
[SYCL][Matrix Headers] Add out of bounds load/store (intel#11210)
Spec is in intel#11172
1 parent caa4ed5 commit 4c17a7f

File tree

3 files changed

+353
-9
lines changed

3 files changed

+353
-9
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

+33
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,39 @@ extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreINTEL(
4545
std::size_t Stride, __spv::MatrixLayout Layout = L,
4646
__spv::Scope::Flag Sc = S, int MemOperand = 0);
4747

48+
template <typename T, typename Tp, std::size_t R, std::size_t C,
49+
__spv::MatrixUse U,
50+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
51+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
52+
extern __DPCPP_SYCL_EXTERNAL
53+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
54+
__spirv_CompositeConstructCheckedINTEL(const T Value, size_t Height,
55+
size_t Stride, size_t Width,
56+
size_t CoordX, size_t CoordY);
57+
58+
template <typename T, typename Tp, std::size_t R, std::size_t C,
59+
__spv::MatrixUse U,
60+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
61+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
62+
extern __DPCPP_SYCL_EXTERNAL
63+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
64+
__spirv_JointMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
65+
size_t Height, size_t Width,
66+
size_t CoordX, size_t CoordY,
67+
__spv::MatrixLayout Layout = L,
68+
__spv::Scope::Flag Sc = S,
69+
int MemOperand = 0);
70+
71+
template <typename T, typename Tp, std::size_t R, std::size_t C,
72+
__spv::MatrixUse U,
73+
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
74+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
75+
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL(
76+
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
77+
std::size_t Stride, size_t Height, size_t Width, size_t CoordX,
78+
size_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S,
79+
int MemOperand = 0);
80+
4881
template <typename T1, typename T2, std::size_t M, std::size_t K, std::size_t N,
4982
__spv::MatrixUse UA, __spv::MatrixUse UB, __spv::MatrixUse UC,
5083
__spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

+310
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,316 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
602602
#endif
603603
}
604604

605+
using namespace sycl::ext::oneapi::experimental::matrix;
606+
607+
// Begin out-of-bounds API
608+
609+
template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
610+
layout Layout, typename T2>
611+
inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
612+
Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
613+
const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
614+
size_t CoordY) {
615+
#if defined(__SYCL_DEVICE_ONLY__)
616+
using storage_element_type =
617+
typename oneapi::detail::jm_type_interpretation_helper_trait<
618+
T>::storage_element_type;
619+
Res.spvm = __spirv_CompositeConstructCheckedINTEL<
620+
storage_element_type, T, NumRows, NumCols,
621+
spv_matrix_use_traits<Use>::value,
622+
spv_matrix_layout_traits<Layout>::value>(
623+
static_cast<storage_element_type>(Value), Stride, Height, Width, CoordX,
624+
CoordY);
625+
#else
626+
std::ignore = Res;
627+
std::ignore = Value;
628+
std::ignore = Stride;
629+
std::ignore = Height;
630+
std::ignore = Width;
631+
std::ignore = CoordX;
632+
std::ignore = CoordY;
633+
throw runtime_error("joint matrix is not supported on host device.",
634+
PI_ERROR_INVALID_DEVICE);
635+
#endif // defined(__SYCL_DEVICE_ONLY__)
636+
}
637+
638+
template <
639+
typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
640+
access::address_space Space, access::decorated IsDecorated,
641+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
642+
true>
643+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
644+
Group sg,
645+
joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
646+
&Res,
647+
multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
648+
size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
649+
#if defined(__SYCL_DEVICE_ONLY__)
650+
static_assert(Space != access::address_space::private_space,
651+
"Joint Matrix doesn't support load from private memory!");
652+
std::ignore = sg;
653+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
654+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
655+
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
656+
DecorT, S, NumRows, NumCols,
657+
spv_matrix_use_traits<use::accumulator>::value,
658+
spv_matrix_layout_traits<layout::dynamic>::value>(
659+
Ptr, Stride, Height, Width, CoordX, CoordY,
660+
sycl::detail::joint_matrix_layout_to_spv(Layout),
661+
spv_scope_traits<Group>::value);
662+
#else
663+
std::ignore = sg;
664+
std::ignore = Res;
665+
std::ignore = Src;
666+
std::ignore = Stride;
667+
std::ignore = Height;
668+
std::ignore = Width;
669+
std::ignore = Layout;
670+
std::ignore = CoordX;
671+
std::ignore = CoordY;
672+
throw runtime_error("joint matrix is not supported on host device.",
673+
PI_ERROR_INVALID_DEVICE);
674+
#endif // defined(__SYCL_DEVICE_ONLY__)
675+
}
676+
677+
template <
678+
typename Group, typename S, typename T, use Use, size_t NumRows,
679+
size_t NumCols, layout Layout, access::address_space Space,
680+
access::decorated IsDecorated,
681+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
682+
(std::is_same<S, precision::tf32>::value &&
683+
std::is_same<std::remove_const_t<T>, float>::value),
684+
bool> = true>
685+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
686+
Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
687+
multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
688+
size_t Width, size_t CoordX, size_t CoordY) {
689+
#if defined(__SYCL_DEVICE_ONLY__)
690+
static_assert(Space != access::address_space::private_space,
691+
"Joint Matrix doesn't support load from private memory!");
692+
std::ignore = sg;
693+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
694+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
695+
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
696+
DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
697+
spv_matrix_layout_traits<Layout>::value>(
698+
Ptr, Stride, Height, Width, CoordX, CoordY,
699+
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
700+
#else
701+
std::ignore = sg;
702+
std::ignore = Res;
703+
std::ignore = Src;
704+
std::ignore = Stride;
705+
std::ignore = Height;
706+
std::ignore = Width;
707+
std::ignore = CoordX;
708+
std::ignore = CoordY;
709+
throw runtime_error("joint matrix is not supported on host device.",
710+
PI_ERROR_INVALID_DEVICE);
711+
#endif // defined(__SYCL_DEVICE_ONLY__)
712+
}
713+
714+
template <typename Group, typename T, size_t NumRows, size_t NumCols,
715+
access::address_space Space, access::decorated IsDecorated>
716+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
717+
Group sg,
718+
joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
719+
&Src,
720+
multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
721+
size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
722+
#if defined(__SYCL_DEVICE_ONLY__)
723+
static_assert(Space != access::address_space::private_space,
724+
"Joint Matrix doesn't support store to private memory!");
725+
std::ignore = sg;
726+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
727+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
728+
__spirv_JointMatrixStoreCheckedINTEL<
729+
DecorT, T, NumRows, NumCols,
730+
spv_matrix_use_traits<use::accumulator>::value,
731+
spv_matrix_layout_traits<layout::dynamic>::value>(
732+
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
733+
sycl::detail::joint_matrix_layout_to_spv(Layout),
734+
spv_scope_traits<Group>::value);
735+
#else
736+
std::ignore = sg;
737+
std::ignore = Src;
738+
std::ignore = Dst;
739+
std::ignore = Stride;
740+
std::ignore = Height;
741+
std::ignore = Width;
742+
std::ignore = Layout;
743+
std::ignore = CoordX;
744+
std::ignore = CoordY;
745+
throw runtime_error("joint matrix is not supported on host device.",
746+
PI_ERROR_INVALID_DEVICE);
747+
#endif // defined(__SYCL_DEVICE_ONLY__)
748+
}
749+
750+
template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
751+
size_t NumCols, layout Layout, access::address_space Space,
752+
access::decorated IsDecorated,
753+
std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
754+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
755+
Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
756+
multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
757+
size_t Width, size_t CoordX, size_t CoordY) {
758+
#if defined(__SYCL_DEVICE_ONLY__)
759+
static_assert(Space != access::address_space::private_space,
760+
"Joint Matrix doesn't support store to private memory!");
761+
std::ignore = sg;
762+
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
763+
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
764+
__spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
765+
spv_matrix_use_traits<Use>::value,
766+
spv_matrix_layout_traits<Layout>::value>(
767+
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
768+
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
769+
#else
770+
std::ignore = sg;
771+
std::ignore = Src;
772+
std::ignore = Dst;
773+
std::ignore = Stride;
774+
std::ignore = Height;
775+
std::ignore = Width;
776+
std::ignore = CoordX;
777+
std::ignore = CoordY;
778+
throw runtime_error("joint matrix is not supported on host device.",
779+
PI_ERROR_INVALID_DEVICE);
780+
#endif // defined(__SYCL_DEVICE_ONLY__)
781+
}
782+
783+
// Annotated pointer overloads:
784+
template <typename Group, typename S, typename T, size_t NumRows,
785+
size_t NumCols, typename PropertyListT,
786+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value,
787+
bool> = true>
788+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
789+
Group sg,
790+
joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
791+
&Res,
792+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Src,
793+
size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
794+
size_t CoordY) {
795+
#if defined(__SYCL_DEVICE_ONLY__)
796+
std::ignore = sg;
797+
T *Ptr = Src.get();
798+
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
799+
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
800+
spv_matrix_layout_traits<layout::dynamic>::value>(
801+
Ptr, Stride, Height, Width, CoordX, CoordY,
802+
sycl::detail::joint_matrix_layout_to_spv(Layout),
803+
spv_scope_traits<Group>::value);
804+
#else
805+
std::ignore = sg;
806+
std::ignore = Res;
807+
std::ignore = Src;
808+
std::ignore = Stride;
809+
std::ignore = Height;
810+
std::ignore = Width;
811+
std::ignore = Layout;
812+
std::ignore = CoordX;
813+
std::ignore = CoordY;
814+
throw runtime_error("joint matrix is not supported on host device.",
815+
PI_ERROR_INVALID_DEVICE);
816+
#endif // defined(__SYCL_DEVICE_ONLY__)
817+
}
818+
819+
template <
820+
typename Group, typename S, typename T, use Use, size_t NumRows,
821+
size_t NumCols, layout Layout, typename PropertyListT,
822+
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
823+
(std::is_same<S, precision::tf32>::value &&
824+
std::is_same<std::remove_const_t<T>, float>::value),
825+
bool> = true>
826+
inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
827+
Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
828+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Src,
829+
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
830+
#if defined(__SYCL_DEVICE_ONLY__)
831+
std::ignore = sg;
832+
T *Ptr = Src.get();
833+
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
834+
T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
835+
spv_matrix_layout_traits<Layout>::value>(
836+
Ptr, Stride, Height, Width, CoordX, CoordY,
837+
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
838+
#else
839+
std::ignore = sg;
840+
std::ignore = Res;
841+
std::ignore = Src;
842+
std::ignore = Stride;
843+
std::ignore = Height;
844+
std::ignore = Width;
845+
std::ignore = CoordX;
846+
std::ignore = CoordY;
847+
throw runtime_error("joint matrix is not supported on host device.",
848+
PI_ERROR_INVALID_DEVICE);
849+
#endif // defined(__SYCL_DEVICE_ONLY__)
850+
}
851+
852+
template <typename Group, typename T, size_t NumRows, size_t NumCols,
853+
typename PropertyListT>
854+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
855+
Group sg,
856+
joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
857+
&Src,
858+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Dst,
859+
size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
860+
size_t CoordY) {
861+
#if defined(__SYCL_DEVICE_ONLY__)
862+
std::ignore = sg;
863+
T *Ptr = Dst.get();
864+
__spirv_JointMatrixStoreCheckedINTEL<
865+
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
866+
spv_matrix_layout_traits<layout::dynamic>::value>(
867+
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
868+
sycl::detail::joint_matrix_layout_to_spv(Layout),
869+
spv_scope_traits<Group>::value);
870+
#else
871+
std::ignore = sg;
872+
std::ignore = Src;
873+
std::ignore = Dst;
874+
std::ignore = Stride;
875+
std::ignore = Height;
876+
std::ignore = Width;
877+
std::ignore = Layout;
878+
std::ignore = CoordX;
879+
std::ignore = CoordY;
880+
throw runtime_error("joint matrix is not supported on host device.",
881+
PI_ERROR_INVALID_DEVICE);
882+
#endif // defined(__SYCL_DEVICE_ONLY__)
883+
}
884+
885+
template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
886+
size_t NumCols, layout Layout, typename PropertyListT,
887+
std::enable_if_t<Use == use::a || Use == use::b, bool> = true>
888+
inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
889+
Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
890+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Dst,
891+
size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
892+
#if defined(__SYCL_DEVICE_ONLY__)
893+
std::ignore = sg;
894+
T *Ptr = Dst.get();
895+
__spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
896+
spv_matrix_use_traits<Use>::value,
897+
spv_matrix_layout_traits<Layout>::value>(
898+
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
899+
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
900+
#else
901+
std::ignore = sg;
902+
std::ignore = Src;
903+
std::ignore = Dst;
904+
std::ignore = Stride;
905+
std::ignore = Height;
906+
std::ignore = Width;
907+
std::ignore = CoordX;
908+
std::ignore = CoordY;
909+
throw runtime_error("joint matrix is not supported on host device.",
910+
PI_ERROR_INVALID_DEVICE);
911+
#endif // defined(__SYCL_DEVICE_ONLY__)
912+
}
913+
// End out-of-bounds API
914+
605915
} // namespace intel::experimental::matrix
606916

607917
} // namespace ext

sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp

+10-9
Original file line numberDiff line numberDiff line change
@@ -51,22 +51,23 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) {
5151
sub_b;
5252
joint_matrix<sub_group, float, use::accumulator, TM, TN> sub_c;
5353
// bounds-checked load where width and height are added
54-
joint_matrix_fill_checked(sg, sub_c, 1, M, N);
54+
ext::intel::experimental::matrix::joint_matrix_fill_checked(
55+
sg, sub_c, 1, N, M, N, sg_startx * TM, sg_starty / SG_SZ * TN);
5556
for (int k = 0; k < K; k += TK) {
5657
// bounds-checked load where width and height are added
57-
joint_matrix_load_checked(sg, sub_a, pA + (sg_startx * TM) * K + k,
58-
K, M, K);
58+
ext::intel::experimental::matrix::joint_matrix_load_checked(
59+
sg, sub_a, pA, K, M, K, sg_startx * TM, k);
5960
// Assume we alreay in vnni format.
6061
// bounds-checked load where width and height are added
61-
joint_matrix_load_checked(
62-
sg, sub_b, pB + k * N + sg_starty / SG_SZ * TN * vnniFactor,
63-
N * vnniFactor, K / vnniFactor, N * vnniFactor);
62+
ext::intel::experimental::matrix::joint_matrix_load_checked(
63+
sg, sub_b, pB, N * vnniFactor, K / vnniFactor, N * vnniFactor,
64+
k, sg_starty / SG_SZ * TN * vnniFactor);
6465
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
6566
}
6667
// bounds-checked store where width and height are added
67-
joint_matrix_store_checked(
68-
sg, sub_c, pC + (sg_startx * TM) * N + sg_starty / SG_SZ * TN, N,
69-
layout::row_major, M, N);
68+
ext::intel::experimental::matrix::joint_matrix_store_checked(
69+
sg, sub_c, pC, N, layout::row_major, M, N, sg_startx * TM,
70+
sg_starty / SG_SZ * TN);
7071
}); // parallel for
7172
}).wait();
7273
}

0 commit comments

Comments
 (0)