Skip to content

Commit c5c7ac2

Browse files
[SYCL] Add marray support to common + some math functions (#8631)
This patch adds marray support to all functions from Table 179 of the SYCL 2020 spec, to the functions fabs, ilogb, fmax, fmin, ldexp, pown and rootn from Table 175, and to the function exp10 from Table 177. E2E tests: intel/llvm-test-suite#1656 --------- Co-authored-by: KornevNikita <[email protected]>
1 parent 7c7efee commit c5c7ac2

File tree

1 file changed

+182
-0
lines changed

1 file changed

+182
-0
lines changed

sycl/include/sycl/builtins.hpp

+182
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,27 @@ __SYCL_MATH_FUNCTION_OVERLOAD_FM(log2)
122122
__SYCL_MATH_FUNCTION_OVERLOAD_FM(log10)
123123
__SYCL_MATH_FUNCTION_OVERLOAD_FM(sqrt)
124124
__SYCL_MATH_FUNCTION_OVERLOAD_FM(rsqrt)
125+
__SYCL_MATH_FUNCTION_OVERLOAD_FM(fabs)
125126

126127
#undef __SYCL_MATH_FUNCTION_OVERLOAD_FM
127128
#undef __SYCL_MATH_FUNCTION_OVERLOAD_IMPL
128129

130+
template <typename T, size_t N>
131+
inline __SYCL_ALWAYS_INLINE
132+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<int, N>>
133+
ilogb(marray<T, N> x) __NOEXC {
134+
marray<int, N> res;
135+
for (size_t i = 0; i < N / 2; i++) {
136+
vec<int, 2> partial_res =
137+
__sycl_std::__invoke_ilogb<vec<int, 2>>(detail::to_vec2(x, i * 2));
138+
std::memcpy(&res[i * 2], &partial_res, sizeof(vec<int, 2>));
139+
}
140+
if (N % 2) {
141+
res[N - 1] = __sycl_std::__invoke_ilogb<int>(x[N - 1]);
142+
}
143+
return res;
144+
}
145+
129146
#define __SYCL_MATH_FUNCTION_2_OVERLOAD_IMPL(NAME) \
130147
marray<T, N> res; \
131148
for (size_t i = 0; i < N / 2; i++) { \
@@ -170,6 +187,98 @@ inline __SYCL_ALWAYS_INLINE
170187

171188
#undef __SYCL_MATH_FUNCTION_2_OVERLOAD_IMPL
172189

190+
#define __SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD(NAME) \
191+
template <typename T, size_t N> \
192+
inline __SYCL_ALWAYS_INLINE \
193+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>> \
194+
NAME(marray<T, N> x, T y) __NOEXC { \
195+
marray<T, N> res; \
196+
sycl::vec<T, 2> y_vec{y, y}; \
197+
for (size_t i = 0; i < N / 2; i++) { \
198+
auto partial_res = __sycl_std::__invoke_##NAME<vec<T, 2>>( \
199+
detail::to_vec2(x, i * 2), y_vec); \
200+
std::memcpy(&res[i * 2], &partial_res, sizeof(vec<T, 2>)); \
201+
} \
202+
if (N % 2) { \
203+
res[N - 1] = __sycl_std::__invoke_##NAME<T>(x[N - 1], y_vec[0]); \
204+
} \
205+
return res; \
206+
}
207+
208+
__SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD(fmax)
209+
// clang-format off
210+
__SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD(fmin)
211+
212+
#undef __SYCL_MATH_FUNCTION_2_SGENFLOAT_Y_OVERLOAD
213+
214+
template <typename T, size_t N>
215+
inline __SYCL_ALWAYS_INLINE
216+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
217+
ldexp(marray<T, N> x, marray<int, N> k) __NOEXC {
218+
// clang-format on
219+
marray<T, N> res;
220+
for (size_t i = 0; i < N; i++) {
221+
res[i] = __sycl_std::__invoke_ldexp<T>(x[i], k[i]);
222+
}
223+
return res;
224+
}
225+
226+
template <typename T, size_t N>
227+
inline __SYCL_ALWAYS_INLINE
228+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
229+
ldexp(marray<T, N> x, int k) __NOEXC {
230+
marray<T, N> res;
231+
for (size_t i = 0; i < N; i++) {
232+
res[i] = __sycl_std::__invoke_ldexp<T>(x[i], k);
233+
}
234+
return res;
235+
}
236+
237+
#define __SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL(NAME) \
238+
marray<T, N> res; \
239+
for (size_t i = 0; i < N; i++) { \
240+
res[i] = __sycl_std::__invoke_##NAME<T>(x[i], y[i]); \
241+
} \
242+
return res;
243+
244+
template <typename T, size_t N>
245+
inline __SYCL_ALWAYS_INLINE
246+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
247+
pown(marray<T, N> x, marray<int, N> y) __NOEXC {
248+
__SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL(pown)
249+
}
250+
251+
template <typename T, size_t N>
252+
inline __SYCL_ALWAYS_INLINE
253+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
254+
rootn(marray<T, N> x, marray<int, N> y) __NOEXC {
255+
__SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL(rootn)
256+
}
257+
258+
#undef __SYCL_MATH_FUNCTION_2_GENINT_Y_OVERLOAD_IMPL
259+
260+
#define __SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL(NAME) \
261+
marray<T, N> res; \
262+
for (size_t i = 0; i < N; i++) { \
263+
res[i] = __sycl_std::__invoke_##NAME<T>(x[i], y); \
264+
} \
265+
return res;
266+
267+
template <typename T, size_t N>
268+
inline __SYCL_ALWAYS_INLINE
269+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
270+
pown(marray<T, N> x, int y) __NOEXC {
271+
__SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL(pown)
272+
}
273+
274+
template <typename T, size_t N>
275+
inline __SYCL_ALWAYS_INLINE
276+
std::enable_if_t<detail::is_sgenfloat<T>::value, marray<T, N>>
277+
rootn(marray<T, N> x,
278+
int y) __NOEXC{__SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL(rootn)}
279+
280+
#undef __SYCL_MATH_FUNCTION_2_INT_Y_OVERLOAD_IMPL
281+
173282
#define __SYCL_MATH_FUNCTION_3_OVERLOAD(NAME) \
174283
template <typename T, size_t N> \
175284
inline __SYCL_ALWAYS_INLINE \
@@ -789,6 +898,78 @@ detail::enable_if_t<detail::is_svgenfloat<T>::value, T> sign(T x) __NOEXC {
789898
return __sycl_std::__invoke_sign<T>(x);
790899
}
791900

901+
// marray common functions

// Function-body helper shared by the marray overloads of the common
// functions. It expects a type T (providing a static size() and operator[])
// to be in scope, evaluates NAME(__VA_ARGS__) once per element index and
// stores the value in the matching element of the returned T.
//
// The loop variable is deliberately named `i`: callers spell their per-element
// arguments in terms of it (e.g. `x[i]`), so it is part of this macro's
// contract. It is a size_t so that the comparison against T::size() does not
// mix signed and unsigned operands.
//
// TODO: can be optimized in the way math functions are optimized (usage of
// vec<T, 2>)
#define __SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, ...)                 \
  T res;                                                                       \
  for (size_t i = 0; i < T::size(); i++) {                                     \
    res[i] = NAME(__VA_ARGS__);                                                \
  }                                                                            \
  return res;
911+
912+
#define __SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(NAME, ARG, ...) \
913+
template <typename T, \
914+
typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> \
915+
T NAME(ARG) __NOEXC { \
916+
__SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__) \
917+
}
918+
919+
__SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(degrees, T radians, radians[i])
920+
__SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(radians, T degrees, degrees[i])
921+
__SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD(sign, T x, x[i])
922+
923+
#undef __SYCL_MARRAY_COMMON_FUNCTION_UNOP_OVERLOAD
924+
925+
#define __SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(NAME, ARG1, ARG2, ...) \
926+
template <typename T, \
927+
typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> \
928+
T NAME(ARG1, ARG2) __NOEXC { \
929+
__SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__) \
930+
}
931+
932+
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(min, T x, T y, x[i], y[i])
933+
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(min, T x,
934+
detail::marray_element_type<T> y,
935+
x[i], y)
936+
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(max, T x, T y, x[i], y[i])
937+
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(max, T x,
938+
detail::marray_element_type<T> y,
939+
x[i], y)
940+
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(step, T edge, T x, edge[i], x[i])
941+
__SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD(
942+
step, detail::marray_element_type<T> edge, T x, edge, x[i])
943+
944+
#undef __SYCL_MARRAY_COMMON_FUNCTION_BINOP_OVERLOAD
945+
946+
#define __SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(NAME, ARG1, ARG2, ARG3, \
947+
...) \
948+
template <typename T, \
949+
typename = std::enable_if_t<detail::is_mgenfloat<T>::value>> \
950+
T NAME(ARG1, ARG2, ARG3) __NOEXC { \
951+
__SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL(NAME, __VA_ARGS__) \
952+
}
953+
954+
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(clamp, T x, T minval, T maxval,
955+
x[i], minval[i], maxval[i])
956+
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(
957+
clamp, T x, detail::marray_element_type<T> minval,
958+
detail::marray_element_type<T> maxval, x[i], minval, maxval)
959+
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(mix, T x, T y, T a, x[i], y[i],
960+
a[i])
961+
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(mix, T x, T y,
962+
detail::marray_element_type<T> a,
963+
x[i], y[i], a)
964+
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(smoothstep, T edge0, T edge1, T x,
965+
edge0[i], edge1[i], x[i])
966+
__SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD(
967+
smoothstep, detail::marray_element_type<T> edge0,
968+
detail::marray_element_type<T> edge1, T x, edge0, edge1, x[i])
969+
970+
#undef __SYCL_MARRAY_COMMON_FUNCTION_TEROP_OVERLOAD
971+
#undef __SYCL_MARRAY_COMMON_FUNCTION_OVERLOAD_IMPL
972+
792973
/* --------------- 4.13.4 Integer functions. --------------------------------*/
793974
// ugeninteger abs (geninteger x)
794975
template <typename T>
@@ -1724,6 +1905,7 @@ __SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(cos)
17241905
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(tan)
17251906
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp)
17261907
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp2)
1908+
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(exp10)
17271909
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log)
17281910
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log2)
17291911
__SYCL_HALF_PRECISION_MATH_FUNCTION_OVERLOAD(log10)

0 commit comments

Comments
 (0)