Skip to content

Commit 5881938

Browse files
authored
[SYCL][ESIMD] Add support for dpas API (#5637)
Signed-off-by: Sergey Dmitriev <[email protected]>
1 parent 5e6995e commit 5881938

File tree

6 files changed

+862
-0
lines changed

6 files changed

+862
-0
lines changed

llvm/lib/SYCLLowerIR/ESIMD/LowerESIMD.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,11 @@ class ESIMDIntrinDescTable {
446446
{"raw_send2_noresult",
447447
{"raw.send2.noresult",
448448
{a(0), a(1), ai1(2), a(3), a(4), a(5), a(6), a(7)}}},
449+
{"dpas",
450+
{"dpas2", {a(0), a(1), a(2), a(3), a(4), a(5), a(6), a(7), a(8)}}},
451+
{"dpas2", {"dpas.nosrc0", {a(0), a(1), a(2)}}},
452+
{"dpasw", {"dpasw", {a(0), a(1), a(2), a(3)}}},
453+
{"dpasw2", {"dpasw.nosrc0", {a(0), a(1), a(2)}}},
449454
{"nbarrier", {"nbarrier", {a(0), a(1), a(2)}}},
450455
{"raw_send_nbarrier_signal",
451456
{"raw.send.noresult", {a(0), ai1(4), a(1), a(2), a(3)}}},

sycl/include/sycl/ext/intel/experimental/esimd/common.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,20 @@ static inline constexpr saturation_off_tag saturation_off{};
9595
/// Type tag object representing "saturation on" behavior.
9696
static inline constexpr saturation_on_tag saturation_on{};
9797

98+
enum class argument_type {
99+
U1 = 0, // unsigned 1 bit
100+
S1 = 1, // signed 1 bit
101+
U2 = 2, // unsigned 2 bits
102+
S2 = 3, // signed 2 bits
103+
U4 = 4, // unsigned 4 bits
104+
S4 = 5, // signed 4 bits
105+
U8 = 6, // unsigned 8 bits
106+
S8 = 7, // signed 8 bits
107+
BF16 = 8, // bfloat 16
108+
FP16 = 9, // half float
109+
TF32 = 11 // tensorfloat 32
110+
};
111+
98112
/// Represents a pixel's channel.
99113
enum class rgba_channel : uint8_t { R, G, B, A };
100114

sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,29 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
279279
}
280280
#endif // __SYCL_DEVICE_ONLY__
281281

282+
template <typename T, typename T0, typename T1, typename T2, int N, int N1,
283+
int N2>
284+
SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t<T, N> __esimd_dpas(
285+
__SEIEED::vector_type_t<T0, N> src0, __SEIEED::vector_type_t<T1, N1> src1,
286+
__SEIEED::vector_type_t<T2, N2> src2, int src1_precision,
287+
int src2_precision, int depth, int repeat, int sign_res, int sign_acc);
288+
289+
template <typename T, typename T1, typename T2, int N, int N1, int N2>
290+
SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t<T, N>
291+
__esimd_dpas2(__SEIEED::vector_type_t<T1, N1> src1,
292+
__SEIEED::vector_type_t<T2, N2> src2, int dpas_info);
293+
294+
template <typename T, typename T1, typename T2, int N, int N1, int N2>
295+
SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t<T, N>
296+
__esimd_dpasw(__SEIEED::vector_type_t<T, N> src0,
297+
__SEIEED::vector_type_t<T1, N1> src1,
298+
__SEIEED::vector_type_t<T2, N2> src2, int dpas_info);
299+
300+
template <typename T, typename T1, typename T2, int N, int N1, int N2>
301+
SYCL_EXTERNAL SYCL_ESIMD_FUNCTION __SEIEED::vector_type_t<T, N>
302+
__esimd_dpasw2(__SEIEED::vector_type_t<T1, N1> src1,
303+
__SEIEED::vector_type_t<T2, N2> src2, int dpas_info);
304+
282305
#ifdef __SYCL_DEVICE_ONLY__
283306

284307
// lane-id for reusing scalar math functions.
@@ -1198,6 +1221,276 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, N)
11981221
return retv;
11991222
}
12001223

1224+
inline constexpr __SEIEE::uint
1225+
__esimd_dpas_bits_precision(__SEIEE::argument_type precisionType) {
1226+
return precisionType == __SEIEE::argument_type::TF32 ? 32
1227+
: precisionType == __SEIEE::argument_type::BF16 ||
1228+
precisionType == __SEIEE::argument_type::FP16
1229+
? 16
1230+
: precisionType == __SEIEE::argument_type::S8 ||
1231+
precisionType == __SEIEE::argument_type::U8
1232+
? 8
1233+
: precisionType == __SEIEE::argument_type::S4 ||
1234+
precisionType == __SEIEE::argument_type::U4
1235+
? 4
1236+
: precisionType == __SEIEE::argument_type::S2 ||
1237+
precisionType == __SEIEE::argument_type::U2
1238+
? 2
1239+
: 1;
1240+
}
1241+
1242+
template <__SEIEE::argument_type src1_precision,
1243+
__SEIEE::argument_type src2_precision, int systolic_depth,
1244+
int repeat_count, typename RT, typename T1, typename T2,
1245+
__SEIEE::uint SZ, __SEIEE::uint N1, __SEIEE::uint N2>
1246+
inline __SEIEED::vector_type_t<RT, SZ>
1247+
__esimd_dpas_inner(const __SEIEED::vector_type_t<RT, SZ> *src0,
1248+
const __SEIEED::vector_type_t<T1, N1> &src1,
1249+
const __SEIEED::vector_type_t<T2, N2> &src2) {
1250+
__SEIEED::vector_type_t<RT, SZ> retv;
1251+
1252+
__SEIEE::uint sat1 =
1253+
__SEIEEED::SetSatur<T1, __SEIEEED::is_inttype<RT>::value>::set() ||
1254+
__SEIEEED::SetSatur<T2, __SEIEEED::is_inttype<RT>::value>::set();
1255+
1256+
constexpr __SEIEE::uint ops_per_chan =
1257+
src1_precision == __SEIEE::argument_type::BF16 ||
1258+
src1_precision == __SEIEE::argument_type::FP16 ||
1259+
src2_precision == __SEIEE::argument_type::BF16 ||
1260+
src2_precision == __SEIEE::argument_type::FP16
1261+
? 2
1262+
: src1_precision == __SEIEE::argument_type::S8 ||
1263+
src1_precision == __SEIEE::argument_type::U8 ||
1264+
src2_precision == __SEIEE::argument_type::S8 ||
1265+
src2_precision == __SEIEE::argument_type::U8
1266+
? 4
1267+
: 8;
1268+
1269+
__SEIEE::uint V = 0, U = 0, k = 0, temp = 0, src1_ops_per_dword = 0, p = 0;
1270+
1271+
constexpr auto src1_el_bits = __esimd_dpas_bits_precision(src1_precision);
1272+
constexpr auto src2_el_bits = __esimd_dpas_bits_precision(src2_precision);
1273+
1274+
uint32_t src1_signed = src1_precision == __SEIEE::argument_type::S2 ||
1275+
src1_precision == __SEIEE::argument_type::S4 ||
1276+
src1_precision == __SEIEE::argument_type::S8
1277+
? 1
1278+
: 0;
1279+
1280+
uint32_t src2_signed = src2_precision == __SEIEE::argument_type::S2 ||
1281+
src2_precision == __SEIEE::argument_type::S4 ||
1282+
src2_precision == __SEIEE::argument_type::S8
1283+
? 1
1284+
: 0;
1285+
1286+
#if defined(ESIMD_XE_HPC) || defined(ESIMD_XE_HPG)
1287+
constexpr bool isPvc = true;
1288+
constexpr size_t SIMDSize = 16;
1289+
#else
1290+
constexpr bool isPvc = false;
1291+
constexpr size_t SIMDSize = 8;
1292+
#endif
1293+
1294+
constexpr bool
1295+
pvcHfDest = isPvc && std::is_same<RT, __SEIEEED::half>::value,
1296+
pvcBfDest = isPvc && std::is_same<RT, short>::value,
1297+
pvcBfOrHfDest = pvcBfDest || pvcHfDest,
1298+
1299+
pvcBfDestChecks = pvcBfDest &&
1300+
src1_precision == __SEIEE::argument_type::BF16 &&
1301+
src2_precision == __SEIEE::argument_type::BF16,
1302+
1303+
pvcHfDestChecks =
1304+
pvcHfDest && ((src1_precision == __SEIEE::argument_type::FP16 &&
1305+
src2_precision == __SEIEE::argument_type::FP16) ||
1306+
(src1_precision == __SEIEE::argument_type::BF16 &&
1307+
src2_precision == __SEIEE::argument_type::BF16)),
1308+
1309+
destTypeChk =
1310+
(!pvcBfOrHfDest && __SEIEEED::is_fp_or_dword_type<RT>::value) ||
1311+
(pvcBfOrHfDest && (pvcBfDestChecks || pvcHfDestChecks)),
1312+
1313+
srcTypeChk = __SEIEEED::is_dword_type<T1>::value &&
1314+
__SEIEEED::is_dword_type<T2>::value,
1315+
1316+
destSizeChk = SZ >= /*TODO: ==*/SIMDSize * repeat_count,
1317+
1318+
systolicDepthAndRepeatCountChk =
1319+
systolic_depth == 8 && repeat_count >= 1 && repeat_count <= 8,
1320+
1321+
src1CountChk =
1322+
N1 == ((src1_el_bits * systolic_depth * ops_per_chan * SZ) /
1323+
(repeat_count * sizeof(T1) * 8)),
1324+
src2CountChk =
1325+
N2 >= ((src2_el_bits * systolic_depth * ops_per_chan * repeat_count) /
1326+
(sizeof(T2) * 8))
1327+
/*TODO: ==; fix PVCIGEMM24*/
1328+
;
1329+
1330+
if constexpr (!isPvc)
1331+
static_assert(!pvcBfOrHfDest, "dpas: hfloat and bfloat16 destination "
1332+
"element type is only supported on PVC.");
1333+
static_assert(destTypeChk, "dpas: unsupported dest and accumulator type.");
1334+
static_assert(srcTypeChk, "dpas: unsupported src element type.");
1335+
static_assert(destSizeChk,
1336+
"dpas: destination size must be SIMDSize x repeat_count.");
1337+
static_assert(systolicDepthAndRepeatCountChk,
1338+
"dpas: only systolic_depth = 8 and repeat_count of 1 to 8 are "
1339+
"supported.");
1340+
static_assert(src1CountChk, "dpas: invalid size for src1.");
1341+
static_assert(src2CountChk, "dpas: invalid size for src2.");
1342+
1343+
using TmpAccEl = typename std::conditional<
1344+
pvcBfOrHfDest, float,
1345+
typename __SEIEEED::restype_ex<
1346+
RT, typename __SEIEEED::restype_ex<T1, T2>::type>::type>::type;
1347+
1348+
__SEIEED::vector_type_t<TmpAccEl, SIMDSize> simdAcc;
1349+
1350+
for (uint r = 0; r < repeat_count; r++) {
1351+
V = r;
1352+
k = 0;
1353+
1354+
for (uint n = 0; n < SIMDSize; n++) {
1355+
if (src0 != nullptr) {
1356+
auto src0El = src0[0][r * SIMDSize + n];
1357+
1358+
if (pvcBfDest) {
1359+
const auto tmp = (uint32_t)(src0El) << 16;
1360+
simdAcc[n] = reinterpret_cast<const TmpAccEl &>(tmp);
1361+
} else
1362+
simdAcc[n] = src0El;
1363+
} else
1364+
simdAcc[n] = 0;
1365+
}
1366+
1367+
for (uint s = 0; s < systolic_depth; s++) {
1368+
src1_ops_per_dword = 32 / (ops_per_chan * src1_el_bits);
1369+
// U = s / src1_ops_per_dword;
1370+
U = s >> uint(log2(src1_ops_per_dword));
1371+
1372+
for (uint n = 0; n < SIMDSize; n++) {
1373+
for (uint d = 0; d < ops_per_chan; d++) {
1374+
p = d + (s % src1_ops_per_dword) * ops_per_chan;
1375+
uint32_t extension_temp = false;
1376+
1377+
if (src2_precision == __SEIEE::argument_type::BF16) {
1378+
const auto s1 =
1379+
extract<uint32_t>(src1_el_bits, p * src1_el_bits,
1380+
src1[U * SIMDSize + n], extension_temp)
1381+
<< 16;
1382+
const auto s2 =
1383+
extract<uint32_t>(src2_el_bits, d * src2_el_bits,
1384+
src2[V * 8 + k / ops_per_chan], src2_signed)
1385+
<< 16;
1386+
simdAcc[n] += reinterpret_cast<const float &>(s2) *
1387+
reinterpret_cast<const float &>(s1);
1388+
} else if (src2_precision == __SEIEE::argument_type::FP16) {
1389+
const auto s1 =
1390+
extract<short>(src1_el_bits, p * src1_el_bits,
1391+
src1[U * SIMDSize + n], extension_temp);
1392+
const auto s2 =
1393+
extract<short>(src2_el_bits, d * src2_el_bits,
1394+
src2[V * 8 + k / ops_per_chan], src2_signed);
1395+
simdAcc[n] += reinterpret_cast<const __SEIEEED::half &>(s1) *
1396+
reinterpret_cast<const __SEIEEED::half &>(s2);
1397+
} else {
1398+
int src = (sizeof(T2) * 8) / (ops_per_chan * src2_el_bits);
1399+
int off = s % src * (ops_per_chan * src2_el_bits);
1400+
int src1_tmp = extract<T1>(src1_el_bits, p * src1_el_bits,
1401+
src1[U * SIMDSize + n], src1_signed);
1402+
int src2_tmp = extract<T2>(src2_el_bits, d * src2_el_bits + off,
1403+
src2[(V * 8 + k / ops_per_chan) / src],
1404+
src2_signed);
1405+
simdAcc[n] += src1_tmp * src2_tmp;
1406+
}
1407+
}
1408+
}
1409+
1410+
k += ops_per_chan;
1411+
1412+
} // Systolic phase.
1413+
1414+
for (uint n = 0; n < SIMDSize; n++) {
1415+
if constexpr (pvcBfDest) {
1416+
// TODO: make abstraction, support saturation, review rounding algo for
1417+
// corner cases.
1418+
auto tmpFloat = simdAcc[n];
1419+
auto tmpUint = reinterpret_cast<uint32_t &>(tmpFloat);
1420+
if (std::isnormal(tmpFloat) && tmpUint & 1ull << 15 &&
1421+
(tmpUint & 0x7fff || tmpUint & 1ull << 16)) {
1422+
tmpUint += 1ull << 16;
1423+
}
1424+
retv[r * SIMDSize + n] =
1425+
static_cast<short>(reinterpret_cast<uint32_t &>(tmpUint) >> 16);
1426+
} else
1427+
retv[r * SIMDSize + n] =
1428+
__SEIEEED::satur<RT>::saturate(simdAcc[n], sat1);
1429+
}
1430+
1431+
} // Repeat.
1432+
1433+
return retv;
1434+
}
1435+
1436+
template <__SEIEE::argument_type src1_precision,
1437+
__SEIEE::argument_type src2_precision, int systolic_depth,
1438+
int repeat_count, typename T, typename T0, typename T1, typename T2,
1439+
int N, int N1, int N2>
1440+
inline __SEIEED::vector_type_t<T, N>
1441+
__esimd_dpas(__SEIEED::vector_type_t<T0, N> src0,
1442+
__SEIEED::vector_type_t<T1, N1> src1,
1443+
__SEIEED::vector_type_t<T2, N2> src2) {
1444+
#ifdef __SYCL_EXPLICIT_SIMD_PLUGIN__
1445+
return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
1446+
repeat_count, T, T1, T2, N, N1, N2>(
1447+
std::addressof(src0), src1, src2);
1448+
#else // __SYCL_EXPLICIT_SIMD_PLUGIN__
1449+
throw cl::sycl::feature_not_supported();
1450+
return __SEIEED::vector_type_t<T, N>();
1451+
#endif // __SYCL_EXPLICIT_SIMD_PLUGIN__
1452+
}
1453+
1454+
template <__SEIEE::argument_type src1_precision,
1455+
__SEIEE::argument_type src2_precision, int systolic_depth,
1456+
int repeat_count, typename T, typename T1, typename T2, int N, int N1,
1457+
int N2>
1458+
inline __SEIEED::vector_type_t<T, N>
1459+
__esimd_dpas2(__SEIEED::vector_type_t<T1, N1> src1,
1460+
__SEIEED::vector_type_t<T2, N2> src2) {
1461+
#ifdef __SYCL_EXPLICIT_SIMD_PLUGIN__
1462+
return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
1463+
repeat_count, T, T1, T2, N, N1, N2>(nullptr, src1,
1464+
src2);
1465+
#else // __SYCL_EXPLICIT_SIMD_PLUGIN__
1466+
throw cl::sycl::feature_not_supported();
1467+
return __SEIEED::vector_type_t<T, N>();
1468+
#endif // __SYCL_EXPLICIT_SIMD_PLUGIN__
1469+
}
1470+
1471+
template <__SEIEE::argument_type src1_precision,
1472+
__SEIEE::argument_type src2_precision, int systolic_depth,
1473+
int repeat_count, typename T, typename T1, typename T2, int N, int N1,
1474+
int N2>
1475+
inline __SEIEED::vector_type_t<T, N>
1476+
__esimd_dpasw(__SEIEED::vector_type_t<T, N> src0,
1477+
__SEIEED::vector_type_t<T1, N1> src1,
1478+
__SEIEED::vector_type_t<T2, N2> src2) {
1479+
throw cl::sycl::feature_not_supported();
1480+
return __SEIEED::vector_type_t<T, N>();
1481+
}
1482+
1483+
template <__SEIEE::argument_type src1_precision,
1484+
__SEIEE::argument_type src2_precision, int systolic_depth,
1485+
int repeat_count, typename T, typename T1, typename T2, int N, int N1,
1486+
int N2>
1487+
inline __SEIEED::vector_type_t<T, N>
1488+
__esimd_dpasw2(__SEIEED::vector_type_t<T1, N1> src1,
1489+
__SEIEED::vector_type_t<T2, N2> src2) {
1490+
throw cl::sycl::feature_not_supported();
1491+
return __SEIEED::vector_type_t<T, N>();
1492+
}
1493+
12011494
#endif // #ifdef __SYCL_DEVICE_ONLY__
12021495

12031496
#undef __ESIMD_raw_vec_t

sycl/include/sycl/ext/intel/experimental/esimd/detail/util.hpp

100755100644
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,36 @@ using is_fp_or_dword_type =
111111
typename std::bool_constant<is_fp_type<T>::value ||
112112
is_dword_type<T>::value>;
113113

114+
/// Compile-time checks if first template parameter is equal for any other
115+
template <typename...> struct is_one_of {
116+
static constexpr bool value = false;
117+
};
118+
119+
template <typename Checked, typename First, typename... Other>
120+
struct is_one_of<Checked, First, Other...> {
121+
static constexpr bool value =
122+
std::is_same<typename std::remove_const<Checked>::type,
123+
typename std::remove_const<First>::type>::value ||
124+
is_one_of<Checked, Other...>::value;
125+
};
126+
template <typename Checked, typename... T>
127+
inline constexpr bool is_one_of_v = is_one_of<Checked, T...>::value;
128+
129+
/// Compile-time checks if compile-time known element of enum class is equal
130+
/// for any other compile-time known elements of enum
131+
template <typename enumClass, enumClass... E> struct is_one_of_enum {
132+
static constexpr bool value = false;
133+
};
134+
135+
template <typename enumClass, enumClass Checked, enumClass First,
136+
enumClass... Else>
137+
struct is_one_of_enum<enumClass, Checked, First, Else...> {
138+
static constexpr bool value =
139+
(Checked == First) || is_one_of_enum<enumClass, Checked, Else...>::value;
140+
};
141+
template <typename enumClass, enumClass... T>
142+
inline constexpr bool is_one_of_enum_v = is_one_of_enum<enumClass, T...>::value;
143+
114144
/// Convert types into vector types
115145
template <typename T> struct simd_type { using type = simd<T, 1>; };
116146
template <typename T, int N> struct simd_type<raw_vector_type<T, N>> {

0 commit comments

Comments
 (0)