Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
template <typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Group sg,
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
const joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
Expand All @@ -365,7 +365,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(dst);
__spirv_CooperativeMatrixStoreKHR<
DecorT, T, NumRows, NumCols,
DecorT, S, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, src.spvm, sycl::detail::joint_matrix_layout_to_spv(Layout), stride);
Expand All @@ -381,11 +381,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
#endif // defined(__SYCL_DEVICE_ONLY__)
}

template <typename Group, typename T, size_t NumRows, size_t NumCols,
template <typename Group, typename S, typename T, size_t NumRows, size_t NumCols,
typename PropertyListT>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Group sg,
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
const joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&src,
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dst,
Expand All @@ -402,7 +402,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
(void)sg;
T *Ptr = dst.get();
__spirv_CooperativeMatrixStoreKHR<
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
spv_matrix_layout_traits<layout::dynamic>::value>(
Ptr, src.spvm, sycl::detail::joint_matrix_layout_to_spv(Layout), stride);
#endif // defined(__NVPTX__)
Expand Down
Loading