@@ -279,6 +279,29 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
279
279
}
280
280
#endif // __SYCL_DEVICE_ONLY__
281
281
282
+ template <typename T, typename T0, typename T1, typename T2, int N, int N1,
283
+ int N2>
284
+ SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t <T, N> __esimd_dpas (
285
+ __SEIEED::vector_type_t <T0, N> src0, __SEIEED::vector_type_t <T1, N1> src1,
286
+ __SEIEED::vector_type_t <T2, N2> src2, int src1_precision,
287
+ int src2_precision, int depth, int repeat, int sign_res, int sign_acc);
288
+
289
+ template <typename T, typename T1, typename T2, int N, int N1, int N2>
290
+ SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t <T, N>
291
+ __esimd_dpas2 (__SEIEED::vector_type_t <T1, N1> src1,
292
+ __SEIEED::vector_type_t <T2, N2> src2, int dpas_info);
293
+
294
+ template <typename T, typename T1, typename T2, int N, int N1, int N2>
295
+ SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t <T, N>
296
+ __esimd_dpasw (__SEIEED::vector_type_t <T, N> src0,
297
+ __SEIEED::vector_type_t <T1, N1> src1,
298
+ __SEIEED::vector_type_t <T2, N2> src2, int dpas_info);
299
+
300
+ template <typename T, typename T1, typename T2, int N, int N1, int N2>
301
+ SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t <T, N>
302
+ __esimd_dpasw2 (__SEIEED::vector_type_t <T1, N1> src1,
303
+ __SEIEED::vector_type_t <T2, N2> src2, int dpas_info);
304
+
282
305
#ifdef __SYCL_DEVICE_ONLY__
283
306
284
307
// lane-id for reusing scalar math functions.
@@ -1198,6 +1221,276 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
1198
1221
return retv;
1199
1222
}
1200
1223
1224
+ inline constexpr __SEIEE::uint
1225
+ __esimd_dpas_bits_precision (__SEIEE::argument_type precisionType) {
1226
+ return precisionType == __SEIEE::argument_type::TF32 ? 32
1227
+ : precisionType == __SEIEE::argument_type::BF16 ||
1228
+ precisionType == __SEIEE::argument_type::FP16
1229
+ ? 16
1230
+ : precisionType == __SEIEE::argument_type::S8 ||
1231
+ precisionType == __SEIEE::argument_type::U8
1232
+ ? 8
1233
+ : precisionType == __SEIEE::argument_type::S4 ||
1234
+ precisionType == __SEIEE::argument_type::U4
1235
+ ? 4
1236
+ : precisionType == __SEIEE::argument_type::S2 ||
1237
+ precisionType == __SEIEE::argument_type::U2
1238
+ ? 2
1239
+ : 1 ;
1240
+ }
1241
+
1242
// Software implementation of the dpas computation. It is called by the
// __esimd_dpas / __esimd_dpas2 wrappers when __SYCL_EXPLICIT_SIMD_PLUGIN__ is
// defined.
//
// Template parameters:
//   src1_precision, src2_precision - precision tags describing how the packed
//       elements inside src1 / src2 are interpreted.
//   systolic_depth - number of systolic steps; statically required to be 8.
//   repeat_count   - number of output rows; statically required to be 1..8.
//   RT             - element type of the accumulator input and of the result.
//   T1, T2         - storage element types of src1 / src2 (must satisfy
//                    is_dword_type).
//   SZ, N1, N2     - element counts of the result, src1 and src2.
//
// Parameters:
//   src0 - pointer to the accumulator input, or nullptr; when it is nullptr
//          every accumulator lane starts from 0.
//   src1 - packed operand, indexed as src1[U * SIMDSize + n].
//   src2 - packed operand, indexed from the row r and the systolic step.
//
// Returns a vector in which elements [r * SIMDSize + n], for r in
// [0, repeat_count) and n in [0, SIMDSize), hold the computed values.
// NOTE(review): destSizeChk only requires SZ >= SIMDSize * repeat_count, and
// elements past SIMDSize * repeat_count are never assigned here.
+ template <__SEIEE::argument_type src1_precision,
1243
+ __SEIEE::argument_type src2_precision, int systolic_depth,
1244
+ int repeat_count, typename RT, typename T1, typename T2,
1245
+ __SEIEE::uint SZ, __SEIEE::uint N1, __SEIEE::uint N2>
1246
+ inline __SEIEED::vector_type_t <RT, SZ>
1247
+ __esimd_dpas_inner (const __SEIEED::vector_type_t <RT, SZ> *src0,
1248
+ const __SEIEED::vector_type_t <T1, N1> &src1,
1249
+ const __SEIEED::vector_type_t <T2, N2> &src2) {
1250
+ __SEIEED::vector_type_t <RT, SZ> retv;
1251
+
1252
// Saturation flag forwarded to satur<RT>::saturate when results are stored.
+ __SEIEE::uint sat1 =
1253
+ __SEIEEED::SetSatur<T1, __SEIEEED::is_inttype<RT>::value>::set () ||
1254
+ __SEIEEED::SetSatur<T2, __SEIEEED::is_inttype<RT>::value>::set ();
1255
+
1256
// Number of packed elements processed per channel in one systolic step:
// 2 if either operand is BF16/FP16, 4 if either is S8/U8, otherwise 8.
+ constexpr __SEIEE::uint ops_per_chan =
1257
+ src1_precision == __SEIEE::argument_type::BF16 ||
1258
+ src1_precision == __SEIEE::argument_type::FP16 ||
1259
+ src2_precision == __SEIEE::argument_type::BF16 ||
1260
+ src2_precision == __SEIEE::argument_type::FP16
1261
+ ? 2
1262
+ : src1_precision == __SEIEE::argument_type::S8 ||
1263
+ src1_precision == __SEIEE::argument_type::U8 ||
1264
+ src2_precision == __SEIEE::argument_type::S8 ||
1265
+ src2_precision == __SEIEE::argument_type::U8
1266
+ ? 4
1267
+ : 8 ;
1268
+
1269
// Loop scratch variables. NOTE(review): `temp` is not referenced again in
// this function.
+ __SEIEE::uint V = 0 , U = 0 , k = 0 , temp = 0 , src1_ops_per_dword = 0 , p = 0 ;
1270
+
1271
// Bit width of one packed element in each operand.
+ constexpr auto src1_el_bits = __esimd_dpas_bits_precision (src1_precision);
1272
+ constexpr auto src2_el_bits = __esimd_dpas_bits_precision (src2_precision);
1273
+
1274
// 1 when the operand precision is one of the signed integer tags (S2/S4/S8);
// passed to extract<> as its last argument.
+ uint32_t src1_signed = src1_precision == __SEIEE::argument_type::S2 ||
1275
+ src1_precision == __SEIEE::argument_type::S4 ||
1276
+ src1_precision == __SEIEE::argument_type::S8
1277
+ ? 1
1278
+ : 0 ;
1279
+
1280
+ uint32_t src2_signed = src2_precision == __SEIEE::argument_type::S2 ||
1281
+ src2_precision == __SEIEE::argument_type::S4 ||
1282
+ src2_precision == __SEIEE::argument_type::S8
1283
+ ? 1
1284
+ : 0 ;
1285
+
1286
// Row width selection: 16 lanes when ESIMD_XE_HPC or ESIMD_XE_HPG is defined
// (both set isPvc), 8 lanes otherwise.
+ #if defined(ESIMD_XE_HPC) || defined(ESIMD_XE_HPG)
1287
+ constexpr bool isPvc = true ;
1288
+ constexpr size_t SIMDSize = 16 ;
1289
+ #else
1290
+ constexpr bool isPvc = false ;
1291
+ constexpr size_t SIMDSize = 8 ;
1292
+ #endif
1293
+
1294
// Compile-time validity predicates, checked by the static_asserts below.
+ constexpr bool
1295
+ pvcHfDest = isPvc && std::is_same<RT, __SEIEEED::half>::value,
1296
+ pvcBfDest = isPvc && std::is_same<RT, short >::value,
1297
+ pvcBfOrHfDest = pvcBfDest || pvcHfDest,
1298
+
1299
+ pvcBfDestChecks = pvcBfDest &&
1300
+ src1_precision == __SEIEE::argument_type::BF16 &&
1301
+ src2_precision == __SEIEE::argument_type::BF16,
1302
+
1303
+ pvcHfDestChecks =
1304
+ pvcHfDest && ((src1_precision == __SEIEE::argument_type::FP16 &&
1305
+ src2_precision == __SEIEE::argument_type::FP16) ||
1306
+ (src1_precision == __SEIEE::argument_type::BF16 &&
1307
+ src2_precision == __SEIEE::argument_type::BF16)),
1308
+
1309
+ destTypeChk =
1310
+ (!pvcBfOrHfDest && __SEIEEED::is_fp_or_dword_type<RT>::value) ||
1311
+ (pvcBfOrHfDest && (pvcBfDestChecks || pvcHfDestChecks)),
1312
+
1313
+ srcTypeChk = __SEIEEED::is_dword_type<T1>::value &&
1314
+ __SEIEEED::is_dword_type<T2>::value,
1315
+
1316
+ destSizeChk = SZ >= /* TODO: ==*/ SIMDSize * repeat_count,
1317
+
1318
+ systolicDepthAndRepeatCountChk =
1319
+ systolic_depth == 8 && repeat_count >= 1 && repeat_count <= 8 ,
1320
+
1321
+ src1CountChk =
1322
+ N1 == ((src1_el_bits * systolic_depth * ops_per_chan * SZ) /
1323
+ (repeat_count * sizeof (T1) * 8 )),
1324
+ src2CountChk =
1325
+ N2 >= ((src2_el_bits * systolic_depth * ops_per_chan * repeat_count) /
1326
+ (sizeof (T2) * 8 ))
1327
+ /* TODO: ==; fix PVCIGEMM24*/
1328
+ ;
1329
+
1330
// pvcBfOrHfDest already includes isPvc, so it is false whenever !isPvc.
+ if constexpr (!isPvc)
1331
+ static_assert (!pvcBfOrHfDest, " dpas: hfloat and bfloat16 destination "
1332
+ " element type is only supported on PVC." );
1333
+ static_assert (destTypeChk, " dpas: unsupported dest and accumulator type." );
1334
+ static_assert (srcTypeChk, " dpas: unsupported src element type." );
1335
+ static_assert (destSizeChk,
1336
+ " dpas: destination size must be SIMDSize x repeat_count." );
1337
+ static_assert (systolicDepthAndRepeatCountChk,
1338
+ " dpas: only systolic_depth = 8 and repeat_count of 1 to 8 are "
1339
+ " supported." );
1340
+ static_assert (src1CountChk, " dpas: invalid size for src1." );
1341
+ static_assert (src2CountChk, " dpas: invalid size for src2." );
1342
+
1343
// Accumulator element type: float for the bf16/half destination case,
// otherwise the restype_ex combination of RT with (T1, T2).
+ using TmpAccEl = typename std::conditional<
1344
+ pvcBfOrHfDest, float ,
1345
+ typename __SEIEEED::restype_ex<
1346
+ RT, typename __SEIEEED::restype_ex<T1, T2>::type>::type>::type;
1347
+
1348
// One row of accumulators, reused for every repeat iteration.
+ __SEIEED::vector_type_t <TmpAccEl, SIMDSize> simdAcc;
1349
+
1350
// Each iteration of r produces one output row of SIMDSize elements.
+ for (uint r = 0 ; r < repeat_count; r++) {
1351
+ V = r;
1352
+ k = 0 ;
1353
+
1354
// Seed the accumulator row from src0 (or with zeros when src0 is null).
+ for (uint n = 0 ; n < SIMDSize; n++) {
1355
+ if (src0 != nullptr ) {
1356
+ auto src0El = src0[0 ][r * SIMDSize + n];
1357
+
1358
// For the bf16 destination the 16-bit element is moved into the upper half
// of a 32-bit word and that word is reinterpreted as the accumulator type.
+ if (pvcBfDest) {
1359
+ const auto tmp = (uint32_t )(src0El) << 16 ;
1360
+ simdAcc[n] = reinterpret_cast <const TmpAccEl &>(tmp);
1361
+ } else
1362
+ simdAcc[n] = src0El;
1363
+ } else
1364
+ simdAcc[n] = 0 ;
1365
+ }
1366
+
1367
+ for (uint s = 0 ; s < systolic_depth; s++) {
// Number of systolic steps served by one 32-bit word of src1.
// NOTE(review): this is 0 when ops_per_chan * src1_el_bits > 32, which would
// make the log2 and the modulo below ill-defined - confirm such precision
// combinations cannot reach this code.
1368
+ src1_ops_per_dword = 32 / (ops_per_chan * src1_el_bits);
1369
+ // U = s / src1_ops_per_dword;
// Shift by log2 replaces the division shown above; the two agree when
// src1_ops_per_dword is a power of two.
1370
+ U = s >> uint (log2 (src1_ops_per_dword));
1371
+
1372
+ for (uint n = 0 ; n < SIMDSize; n++) {
1373
+ for (uint d = 0 ; d < ops_per_chan; d++) {
// Index of the packed element inside the selected src1 word.
1374
+ p = d + (s % src1_ops_per_dword) * ops_per_chan;
1375
+ uint32_t extension_temp = false ;
1376
+
1377
// Floating-point paths are selected on src2_precision only.
// BF16: each 16-bit element is shifted into the upper half of a 32-bit word
// and reinterpreted as float before multiplying.
+ if (src2_precision == __SEIEE::argument_type::BF16) {
1378
+ const auto s1 =
1379
+ extract<uint32_t >(src1_el_bits, p * src1_el_bits,
1380
+ src1[U * SIMDSize + n], extension_temp)
1381
+ << 16 ;
1382
+ const auto s2 =
1383
+ extract<uint32_t >(src2_el_bits, d * src2_el_bits,
1384
+ src2[V * 8 + k / ops_per_chan], src2_signed)
1385
+ << 16 ;
1386
+ simdAcc[n] += reinterpret_cast <const float &>(s2) *
1387
+ reinterpret_cast <const float &>(s1);
// FP16: the extracted 16-bit patterns are reinterpreted as half values.
1388
+ } else if (src2_precision == __SEIEE::argument_type::FP16) {
1389
+ const auto s1 =
1390
+ extract<short >(src1_el_bits, p * src1_el_bits,
1391
+ src1[U * SIMDSize + n], extension_temp);
1392
+ const auto s2 =
1393
+ extract<short >(src2_el_bits, d * src2_el_bits,
1394
+ src2[V * 8 + k / ops_per_chan], src2_signed);
1395
+ simdAcc[n] += reinterpret_cast <const __SEIEEED::half &>(s1) *
1396
+ reinterpret_cast <const __SEIEEED::half &>(s2);
// Remaining precisions: both elements are extracted into int and multiplied.
// `src` is the number of systolic steps packed in one T2 element and `off`
// the bit offset of the current step inside that element.
1397
+ } else {
1398
+ int src = (sizeof (T2) * 8 ) / (ops_per_chan * src2_el_bits);
1399
+ int off = s % src * (ops_per_chan * src2_el_bits);
1400
+ int src1_tmp = extract<T1>(src1_el_bits, p * src1_el_bits,
1401
+ src1[U * SIMDSize + n], src1_signed);
1402
+ int src2_tmp = extract<T2>(src2_el_bits, d * src2_el_bits + off,
1403
+ src2[(V * 8 + k / ops_per_chan) / src],
1404
+ src2_signed);
1405
+ simdAcc[n] += src1_tmp * src2_tmp;
1406
+ }
1407
+ }
1408
+ }
1409
+
1410
// k / ops_per_chan therefore equals the systolic step index s.
+ k += ops_per_chan;
1411
+
1412
+ } // Systolic phase.
1413
+
1414
// Store the finished row into the result vector.
+ for (uint n = 0 ; n < SIMDSize; n++) {
1415
+ if constexpr (pvcBfDest) {
1416
+ // TODO: make abstraction, support saturation, review rounding algo for
1417
+ // corner cases.
// Narrow float to its upper 16 bits. For normal values, 1 << 16 is added
// first when bit 15 is set and either any of the low 15 bits or bit 16 is
// set, so the kept half is rounded rather than truncated.
1418
+ auto tmpFloat = simdAcc[n];
1419
+ auto tmpUint = reinterpret_cast <uint32_t &>(tmpFloat);
1420
+ if (std::isnormal (tmpFloat) && tmpUint & 1ull << 15 &&
1421
+ (tmpUint & 0x7fff || tmpUint & 1ull << 16 )) {
1422
+ tmpUint += 1ull << 16 ;
1423
+ }
1424
+ retv[r * SIMDSize + n] =
1425
+ static_cast <short >(reinterpret_cast <uint32_t &>(tmpUint) >> 16 );
1426
+ } else
1427
+ retv[r * SIMDSize + n] =
1428
+ __SEIEEED::satur<RT>::saturate (simdAcc[n], sat1);
1429
+ }
1430
+
1431
+ } // Repeat.
1432
+
1433
+ return retv;
1434
+ }
1435
+
1436
+ template <__SEIEE::argument_type src1_precision,
1437
+ __SEIEE::argument_type src2_precision, int systolic_depth,
1438
+ int repeat_count, typename T, typename T0, typename T1, typename T2,
1439
+ int N, int N1, int N2>
1440
+ inline __SEIEED::vector_type_t <T, N>
1441
+ __esimd_dpas (__SEIEED::vector_type_t <T0, N> src0,
1442
+ __SEIEED::vector_type_t <T1, N1> src1,
1443
+ __SEIEED::vector_type_t <T2, N2> src2) {
1444
+ #ifdef __SYCL_EXPLICIT_SIMD_PLUGIN__
1445
+ return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
1446
+ repeat_count, T, T1, T2, N, N1, N2>(
1447
+ std::addressof (src0), src1, src2);
1448
+ #else // __SYCL_EXPLICIT_SIMD_PLUGIN__
1449
+ throw cl::sycl::feature_not_supported ();
1450
+ return __SEIEED::vector_type_t <T, N>();
1451
+ #endif // __SYCL_EXPLICIT_SIMD_PLUGIN__
1452
+ }
1453
+
1454
+ template <__SEIEE::argument_type src1_precision,
1455
+ __SEIEE::argument_type src2_precision, int systolic_depth,
1456
+ int repeat_count, typename T, typename T1, typename T2, int N, int N1,
1457
+ int N2>
1458
+ inline __SEIEED::vector_type_t <T, N>
1459
+ __esimd_dpas2 (__SEIEED::vector_type_t <T1, N1> src1,
1460
+ __SEIEED::vector_type_t <T2, N2> src2) {
1461
+ #ifdef __SYCL_EXPLICIT_SIMD_PLUGIN__
1462
+ return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
1463
+ repeat_count, T, T1, T2, N, N1, N2>(nullptr , src1,
1464
+ src2);
1465
+ #else // __SYCL_EXPLICIT_SIMD_PLUGIN__
1466
+ throw cl::sycl::feature_not_supported ();
1467
+ return __SEIEED::vector_type_t <T, N>();
1468
+ #endif // __SYCL_EXPLICIT_SIMD_PLUGIN__
1469
+ }
1470
+
1471
+ template <__SEIEE::argument_type src1_precision,
1472
+ __SEIEE::argument_type src2_precision, int systolic_depth,
1473
+ int repeat_count, typename T, typename T1, typename T2, int N, int N1,
1474
+ int N2>
1475
+ inline __SEIEED::vector_type_t <T, N>
1476
+ __esimd_dpasw (__SEIEED::vector_type_t <T, N> src0,
1477
+ __SEIEED::vector_type_t <T1, N1> src1,
1478
+ __SEIEED::vector_type_t <T2, N2> src2) {
1479
+ throw cl::sycl::feature_not_supported ();
1480
+ return __SEIEED::vector_type_t <T, N>();
1481
+ }
1482
+
1483
+ template <__SEIEE::argument_type src1_precision,
1484
+ __SEIEE::argument_type src2_precision, int systolic_depth,
1485
+ int repeat_count, typename T, typename T1, typename T2, int N, int N1,
1486
+ int N2>
1487
+ inline __SEIEED::vector_type_t <T, N>
1488
+ __esimd_dpasw2 (__SEIEED::vector_type_t <T1, N1> src1,
1489
+ __SEIEED::vector_type_t <T2, N2> src2) {
1490
+ throw cl::sycl::feature_not_supported ();
1491
+ return __SEIEED::vector_type_t <T, N>();
1492
+ }
1493
+
1201
1494
#endif // #ifdef __SYCL_DEVICE_ONLY__
1202
1495
1203
1496
#undef __ESIMD_raw_vec_t
0 commit comments