@@ -602,6 +602,316 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
602
602
#endif
603
603
}
604
604
605
+ using namespace sycl ::ext::oneapi::experimental::matrix;
606
+
607
+ // Begin out-of-bounds API
608
+
609
+ template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
610
+ layout Layout, typename T2>
611
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked (
612
+ Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
613
+ const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
614
+ size_t CoordY) {
615
+ #if defined(__SYCL_DEVICE_ONLY__)
616
+ using storage_element_type =
617
+ typename oneapi::detail::jm_type_interpretation_helper_trait<
618
+ T>::storage_element_type;
619
+ Res.spvm = __spirv_CompositeConstructCheckedINTEL<
620
+ storage_element_type, T, NumRows, NumCols,
621
+ spv_matrix_use_traits<Use>::value,
622
+ spv_matrix_layout_traits<Layout>::value>(
623
+ static_cast <storage_element_type>(Value), Stride, Height, Width, CoordX,
624
+ CoordY);
625
+ #else
626
+ std::ignore = Res;
627
+ std::ignore = Value;
628
+ std::ignore = Stride;
629
+ std::ignore = Height;
630
+ std::ignore = Width;
631
+ std::ignore = CoordX;
632
+ std::ignore = CoordY;
633
+ throw runtime_error (" joint matrix is not supported on host device." ,
634
+ PI_ERROR_INVALID_DEVICE);
635
+ #endif // defined(__SYCL_DEVICE_ONLY__)
636
+ }
637
+
638
+ template <
639
+ typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
640
+ access::address_space Space, access::decorated IsDecorated,
641
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value, bool > =
642
+ true >
643
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked (
644
+ Group sg,
645
+ joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
646
+ &Res,
647
+ multi_ptr<T, Space, IsDecorated> Src, size_t Stride, layout Layout,
648
+ size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
649
+ #if defined(__SYCL_DEVICE_ONLY__)
650
+ static_assert (Space != access ::address_space::private_space,
651
+ " Joint Matrix doesn't support load from private memory!" );
652
+ std::ignore = sg;
653
+ using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
654
+ DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
655
+ Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
656
+ DecorT, S, NumRows, NumCols,
657
+ spv_matrix_use_traits<use::accumulator>::value,
658
+ spv_matrix_layout_traits<layout::dynamic>::value>(
659
+ Ptr , Stride, Height, Width, CoordX, CoordY,
660
+ sycl::detail::joint_matrix_layout_to_spv (Layout),
661
+ spv_scope_traits<Group>::value);
662
+ #else
663
+ std::ignore = sg;
664
+ std::ignore = Res;
665
+ std::ignore = Src;
666
+ std::ignore = Stride;
667
+ std::ignore = Height;
668
+ std::ignore = Width;
669
+ std::ignore = Layout;
670
+ std::ignore = CoordX;
671
+ std::ignore = CoordY;
672
+ throw runtime_error (" joint matrix is not supported on host device." ,
673
+ PI_ERROR_INVALID_DEVICE);
674
+ #endif // defined(__SYCL_DEVICE_ONLY__)
675
+ }
676
+
677
+ template <
678
+ typename Group, typename S, typename T, use Use, size_t NumRows,
679
+ size_t NumCols, layout Layout, access::address_space Space,
680
+ access::decorated IsDecorated,
681
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value ||
682
+ (std::is_same<S, precision::tf32>::value &&
683
+ std::is_same<std::remove_const_t <T>, float >::value),
684
+ bool > = true >
685
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked (
686
+ Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
687
+ multi_ptr<T, Space, IsDecorated> Src, size_t Stride, size_t Height,
688
+ size_t Width, size_t CoordX, size_t CoordY) {
689
+ #if defined(__SYCL_DEVICE_ONLY__)
690
+ static_assert (Space != access ::address_space::private_space,
691
+ " Joint Matrix doesn't support load from private memory!" );
692
+ std::ignore = sg;
693
+ using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
694
+ DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
695
+ Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
696
+ DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
697
+ spv_matrix_layout_traits<Layout>::value>(
698
+ Ptr , Stride, Height, Width, CoordX, CoordY,
699
+ spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
700
+ #else
701
+ std::ignore = sg;
702
+ std::ignore = Res;
703
+ std::ignore = Src;
704
+ std::ignore = Stride;
705
+ std::ignore = Height;
706
+ std::ignore = Width;
707
+ std::ignore = CoordX;
708
+ std::ignore = CoordY;
709
+ throw runtime_error (" joint matrix is not supported on host device." ,
710
+ PI_ERROR_INVALID_DEVICE);
711
+ #endif // defined(__SYCL_DEVICE_ONLY__)
712
+ }
713
+
714
+ template <typename Group, typename T, size_t NumRows, size_t NumCols,
715
+ access::address_space Space, access::decorated IsDecorated>
716
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked (
717
+ Group sg,
718
+ joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
719
+ &Src,
720
+ multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, layout Layout,
721
+ size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
722
+ #if defined(__SYCL_DEVICE_ONLY__)
723
+ static_assert (Space != access ::address_space::private_space,
724
+ " Joint Matrix doesn't support store to private memory!" );
725
+ std::ignore = sg;
726
+ using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
727
+ DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
728
+ __spirv_JointMatrixStoreCheckedINTEL<
729
+ DecorT, T, NumRows, NumCols,
730
+ spv_matrix_use_traits<use::accumulator>::value,
731
+ spv_matrix_layout_traits<layout::dynamic>::value>(
732
+ Ptr , Src.spvm , Stride, Height, Width, CoordX, CoordY,
733
+ sycl::detail::joint_matrix_layout_to_spv (Layout),
734
+ spv_scope_traits<Group>::value);
735
+ #else
736
+ std::ignore = sg;
737
+ std::ignore = Src;
738
+ std::ignore = Dst;
739
+ std::ignore = Stride;
740
+ std::ignore = Height;
741
+ std::ignore = Width;
742
+ std::ignore = Layout;
743
+ std::ignore = CoordX;
744
+ std::ignore = CoordY;
745
+ throw runtime_error (" joint matrix is not supported on host device." ,
746
+ PI_ERROR_INVALID_DEVICE);
747
+ #endif // defined(__SYCL_DEVICE_ONLY__)
748
+ }
749
+
750
+ template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
751
+ size_t NumCols, layout Layout, access::address_space Space,
752
+ access::decorated IsDecorated,
753
+ std::enable_if_t <Use == use::a || Use == use::b, bool > = true >
754
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked (
755
+ Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
756
+ multi_ptr<T, Space, IsDecorated> Dst, size_t Stride, size_t Height,
757
+ size_t Width, size_t CoordX, size_t CoordY) {
758
+ #if defined(__SYCL_DEVICE_ONLY__)
759
+ static_assert (Space != access ::address_space::private_space,
760
+ " Joint Matrix doesn't support store to private memory!" );
761
+ std::ignore = sg;
762
+ using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
763
+ DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
764
+ __spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
765
+ spv_matrix_use_traits<Use>::value,
766
+ spv_matrix_layout_traits<Layout>::value>(
767
+ Ptr , Src.spvm , Stride, Height, Width, CoordX, CoordY,
768
+ spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
769
+ #else
770
+ std::ignore = sg;
771
+ std::ignore = Src;
772
+ std::ignore = Dst;
773
+ std::ignore = Stride;
774
+ std::ignore = Height;
775
+ std::ignore = Width;
776
+ std::ignore = CoordX;
777
+ std::ignore = CoordY;
778
+ throw runtime_error (" joint matrix is not supported on host device." ,
779
+ PI_ERROR_INVALID_DEVICE);
780
+ #endif // defined(__SYCL_DEVICE_ONLY__)
781
+ }
782
+
783
+ // Annotated pointer overloads:
784
+ template <typename Group, typename S, typename T, size_t NumRows,
785
+ size_t NumCols, typename PropertyListT,
786
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value,
787
+ bool > = true >
788
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked (
789
+ Group sg,
790
+ joint_matrix<Group, S, use::accumulator, NumRows, NumCols, layout::dynamic>
791
+ &Res,
792
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Src,
793
+ size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
794
+ size_t CoordY) {
795
+ #if defined(__SYCL_DEVICE_ONLY__)
796
+ std::ignore = sg;
797
+ T *Ptr = Src.get ();
798
+ Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
799
+ T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
800
+ spv_matrix_layout_traits<layout::dynamic>::value>(
801
+ Ptr , Stride, Height, Width, CoordX, CoordY,
802
+ sycl::detail::joint_matrix_layout_to_spv (Layout),
803
+ spv_scope_traits<Group>::value);
804
+ #else
805
+ std::ignore = sg;
806
+ std::ignore = Res;
807
+ std::ignore = Src;
808
+ std::ignore = Stride;
809
+ std::ignore = Height;
810
+ std::ignore = Width;
811
+ std::ignore = Layout;
812
+ std::ignore = CoordX;
813
+ std::ignore = CoordY;
814
+ throw runtime_error (" joint matrix is not supported on host device." ,
815
+ PI_ERROR_INVALID_DEVICE);
816
+ #endif // defined(__SYCL_DEVICE_ONLY__)
817
+ }
818
+
819
+ template <
820
+ typename Group, typename S, typename T, use Use, size_t NumRows,
821
+ size_t NumCols, layout Layout, typename PropertyListT,
822
+ std::enable_if_t <std::is_same<S, std::remove_const_t <T>>::value ||
823
+ (std::is_same<S, precision::tf32>::value &&
824
+ std::is_same<std::remove_const_t <T>, float >::value),
825
+ bool > = true >
826
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked (
827
+ Group sg, joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &Res,
828
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Src,
829
+ size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
830
+ #if defined(__SYCL_DEVICE_ONLY__)
831
+ std::ignore = sg;
832
+ T *Ptr = Src.get ();
833
+ Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
834
+ T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
835
+ spv_matrix_layout_traits<Layout>::value>(
836
+ Ptr , Stride, Height, Width, CoordX, CoordY,
837
+ spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
838
+ #else
839
+ std::ignore = sg;
840
+ std::ignore = Res;
841
+ std::ignore = Src;
842
+ std::ignore = Stride;
843
+ std::ignore = Height;
844
+ std::ignore = Width;
845
+ std::ignore = CoordX;
846
+ std::ignore = CoordY;
847
+ throw runtime_error (" joint matrix is not supported on host device." ,
848
+ PI_ERROR_INVALID_DEVICE);
849
+ #endif // defined(__SYCL_DEVICE_ONLY__)
850
+ }
851
+
852
+ template <typename Group, typename T, size_t NumRows, size_t NumCols,
853
+ typename PropertyListT>
854
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked (
855
+ Group sg,
856
+ joint_matrix<Group, T, use::accumulator, NumRows, NumCols, layout::dynamic>
857
+ &Src,
858
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Dst,
859
+ size_t Stride, layout Layout, size_t Height, size_t Width, size_t CoordX,
860
+ size_t CoordY) {
861
+ #if defined(__SYCL_DEVICE_ONLY__)
862
+ std::ignore = sg;
863
+ T *Ptr = Dst.get ();
864
+ __spirv_JointMatrixStoreCheckedINTEL<
865
+ T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
866
+ spv_matrix_layout_traits<layout::dynamic>::value>(
867
+ Ptr , Src.spvm , Stride, Height, Width, CoordX, CoordY,
868
+ sycl::detail::joint_matrix_layout_to_spv (Layout),
869
+ spv_scope_traits<Group>::value);
870
+ #else
871
+ std::ignore = sg;
872
+ std::ignore = Src;
873
+ std::ignore = Dst;
874
+ std::ignore = Stride;
875
+ std::ignore = Height;
876
+ std::ignore = Width;
877
+ std::ignore = Layout;
878
+ std::ignore = CoordX;
879
+ std::ignore = CoordY;
880
+ throw runtime_error (" joint matrix is not supported on host device." ,
881
+ PI_ERROR_INVALID_DEVICE);
882
+ #endif // defined(__SYCL_DEVICE_ONLY__)
883
+ }
884
+
885
+ template <typename Group, typename T, typename Tp, use Use, size_t NumRows,
886
+ size_t NumCols, layout Layout, typename PropertyListT,
887
+ std::enable_if_t <Use == use::a || Use == use::b, bool > = true >
888
+ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked (
889
+ Group sg, const joint_matrix<Group, Tp, Use, NumRows, NumCols, Layout> &Src,
890
+ ext::oneapi::experimental::annotated_ptr<T, PropertyListT> Dst,
891
+ size_t Stride, size_t Height, size_t Width, size_t CoordX, size_t CoordY) {
892
+ #if defined(__SYCL_DEVICE_ONLY__)
893
+ std::ignore = sg;
894
+ T *Ptr = Dst.get ();
895
+ __spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
896
+ spv_matrix_use_traits<Use>::value,
897
+ spv_matrix_layout_traits<Layout>::value>(
898
+ Ptr , Src.spvm , Stride, Height, Width, CoordX, CoordY,
899
+ spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
900
+ #else
901
+ std::ignore = sg;
902
+ std::ignore = Src;
903
+ std::ignore = Dst;
904
+ std::ignore = Stride;
905
+ std::ignore = Height;
906
+ std::ignore = Width;
907
+ std::ignore = CoordX;
908
+ std::ignore = CoordY;
909
+ throw runtime_error (" joint matrix is not supported on host device." ,
910
+ PI_ERROR_INVALID_DEVICE);
911
+ #endif // defined(__SYCL_DEVICE_ONLY__)
912
+ }
913
+ // End out-of-bounds API
914
+
605
915
} // namespace intel::experimental::matrix
606
916
607
917
} // namespace ext
0 commit comments