@@ -2056,60 +2056,99 @@ class joint_matrix {
20562056 const size_t num_elements;
20572057};
20582058
2059+ // / Loads 1 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2060+ // / Requires the sub-group size of kernel calling this function to be 32
2061+ // / \tparam [in] T The type of result variable
2062+ // / \param [in] addr The address of the matrix in shared memory
2063+ // / \param [in] m The local memory to store the matrix
2064+ // / \param [in] item_ct1 The sycl::nd_item object
2065+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
2066+ // / \param [in] mat The matrix index to be loaded
20592067template <typename T>
20602068void ldmatrix (uintptr_t addr, T *m, const sycl::nd_item<3 > &item_ct1,
20612069 bool trans = false , unsigned mat = 0 ) {
2062- int lane = item_ct1.get_local_id (2 );
2070+ int lane = item_ct1.get_local_id (2 ) % 32 ;
20632071
2064- int group = lane / 8 ;
2065- int sub = lane % 8 ;
2066- int src_base = group * 2 ;
2072+ int lane_group8_row = lane / 8 ;
2073+ int lane_group8_col = lane % 8 ;
20672074
20682075 if (!trans) {
20692076 // calculate the source lane
2070- int src_lane = (sub / 4 ) ? (src_base + 1 ) : src_base;
2077+ int src_lane = 2 * lane_group8_row;
2078+ if (lane_group8_col >= 4 )
2079+ src_lane += 1 ;
20712080
20722081 // Broadcast the address from the source lane
20732082 auto recv_addr_uintp = dpct::select_from_sub_group (
20742083 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
2084+
2085+ // Cast the received address from uintptr_t to the type of 'm'
20752086 auto recv_addr = reinterpret_cast <T *>(recv_addr_uintp);
20762087
20772088 // Non-transposed load
2078- *m = recv_addr[sub % 4 ];
2089+ *m = recv_addr[lane_group8_col % 4 ];
20792090 } else {
20802091 // calculate the source lane
20812092 int src_lane = (lane % 4 ) * 2 ;
20822093
2083- // Broadcast the address from the source lane:
2094+ // Broadcast the address from the source lane
20842095 auto recv_addr_uintp_1 = dpct::select_from_sub_group (
20852096 item_ct1.get_sub_group (), addr, mat * 8 + src_lane);
20862097 auto recv_addr_uintp_2 = dpct::select_from_sub_group (
20872098 item_ct1.get_sub_group (), addr, mat * 8 + src_lane + 1 );
2099+
2100+ // Cast the received address from uintptr_t to 'half *'
20882101 auto recv_addr_1 = reinterpret_cast <sycl::half *>(recv_addr_uintp_1);
20892102 auto recv_addr_2 = reinterpret_cast <sycl::half *>(recv_addr_uintp_2);
20902103
20912104 // Transposed load
2092- int index = ( lane / 4 ) ;
2105+ int index = lane / 4 ;
20932106 sycl::half val0 = recv_addr_1[index];
20942107 sycl::half val1 = recv_addr_2[index];
2108+
2109+ // Combine the two 16-bits into one 32-bit value
20952110 sycl::half2 val = sycl::half2 (val0, val1);
20962111 *m = *reinterpret_cast <T *>(&val);
20972112 }
20982113}
20992114
2115+ // / Loads 2 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2116+ // / Requires the sub-group size of kernel calling this function to be 32
2117+ // / \tparam [in] T The type of result variable
2118+ // / \param [in] addr The address of the matrix in shared memory
2119+ // / \param [in] m1 The local memory to store data of 1st matrix
2120+ // / \param [in] m2 The local memory to store data of 2nd matrix
2121+ // / \param [in] item_ct1 The sycl::nd_item object
2122+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
21002123template <typename T>
21012124void ldmatrix (uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3 > &item_ct1,
21022125 bool trans = false ) {
2126+ // Load 1st matrix
21032127 ldmatrix (addr, m1, item_ct1, trans, 0 );
2128+ // Load 2nd matrix
21042129 ldmatrix (addr, m2, item_ct1, trans, 1 );
21052130}
21062131
2132+ // / Loads 4 8x8 b16 matrix from shared memory to local memory (32-bits per wi)
2133+ // / Requires the sub-group size of kernel calling this function to be 32
2134+ // / \tparam [in] T The type of result variable
2135+ // / \param [in] addr The address of the matrix in shared memory
2136+ // / \param [in] m1 The local memory to store data of 1st matrix
2137+ // / \param [in] m2 The local memory to store data of 2nd matrix
2138+ // / \param [in] m3 The local memory to store data of 3rd matrix
2139+ // / \param [in] m4 The local memory to store data of 4th matrix
2140+ // / \param [in] item_ct1 The sycl::nd_item object
2141+ // / \param [in] trans Indicates whether the matrix to be loaded transposed
21072142template <typename T>
21082143void ldmatrix (uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
21092144 const sycl::nd_item<3 > &item_ct1, bool trans = false ) {
2145+ // Load 1st matrix
21102146 ldmatrix (addr, m1, item_ct1, trans, 0 );
2147+ // Load 2nd matrix
21112148 ldmatrix (addr, m2, item_ct1, trans, 1 );
2149+ // Load 3rd matrix
21122150 ldmatrix (addr, m3, item_ct1, trans, 2 );
2151+ // Load 4th matrix
21132152 ldmatrix (addr, m4, item_ct1, trans, 3 );
21142153}
21152154
0 commit comments