@@ -115,42 +115,6 @@ inline __SYCL_ALWAYS_INLINE
115
115
116
116
} // namespace native
117
117
118
- namespace detail {
119
-
120
- template <typename T> struct is_bf16_storage_type {
121
- static constexpr int value = false ;
122
- };
123
-
124
- template <> struct is_bf16_storage_type <uint16_t > {
125
- static constexpr int value = true ;
126
- };
127
-
128
- template <> struct is_bf16_storage_type <uint32_t > {
129
- static constexpr int value = true ;
130
- };
131
-
132
- template <int N> struct is_bf16_storage_type <vec<uint16_t , N>> {
133
- static constexpr int value = true ;
134
- };
135
-
136
- template <int N> struct is_bf16_storage_type <vec<uint32_t , N>> {
137
- static constexpr int value = true ;
138
- };
139
-
140
- } // namespace detail
141
-
142
- template <typename T>
143
- std::enable_if_t <experimental::detail::is_bf16_storage_type<T>::value, T>
144
- fabs (T x) {
145
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
146
- return __clc_fabs (x);
147
- #else
148
- (void )x;
149
- throw runtime_error (" bfloat16 is not currently supported on the host device." ,
150
- PI_INVALID_DEVICE);
151
- #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
152
- }
153
-
154
118
template <typename T>
155
119
std::enable_if_t <std::is_same<T, bfloat16>::value, T> fabs (T x) {
156
120
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -162,9 +126,8 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fabs(T x) {
162
126
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
163
127
}
164
128
165
- template <typename T, size_t N>
166
- std::enable_if_t <std::is_same<T, bfloat16>::value, sycl::marray<T, N>>
167
- fabs (sycl::marray<T, N> x) {
129
+ template <size_t N>
130
+ sycl::marray<bfloat16, N> fabs (sycl::marray<bfloat16, N> x) {
168
131
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
169
132
sycl::marray<bfloat16, N> res;
170
133
auto x_storage = reinterpret_cast <uint32_t const *>(&x);
@@ -184,19 +147,6 @@ fabs(sycl::marray<T, N> x) {
184
147
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
185
148
}
186
149
187
- template <typename T>
188
- std::enable_if_t <experimental::detail::is_bf16_storage_type<T>::value, T>
189
- fmin (T x, T y) {
190
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
191
- return __clc_fmin (x, y);
192
- #else
193
- (void )x;
194
- (void )y;
195
- throw runtime_error (" bfloat16 is not currently supported on the host device." ,
196
- PI_INVALID_DEVICE);
197
- #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
198
- }
199
-
200
150
template <typename T>
201
151
std::enable_if_t <std::is_same<T, bfloat16>::value, T> fmin (T x, T y) {
202
152
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -209,9 +159,9 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmin(T x, T y) {
209
159
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
210
160
}
211
161
212
- template <typename T, size_t N>
213
- std:: enable_if_t <std::is_same<T, bfloat16>::value, sycl::marray<T , N>>
214
- fmin (sycl::marray<T, N> x, sycl::marray<T , N> y) {
162
+ template <size_t N>
163
+ sycl::marray<bfloat16, N> fmin ( sycl::marray<bfloat16 , N> x,
164
+ sycl::marray<bfloat16 , N> y) {
215
165
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
216
166
sycl::marray<bfloat16, N> res;
217
167
auto x_storage = reinterpret_cast <uint32_t const *>(&x);
@@ -235,19 +185,6 @@ fmin(sycl::marray<T, N> x, sycl::marray<T, N> y) {
235
185
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
236
186
}
237
187
238
- template <typename T>
239
- std::enable_if_t <experimental::detail::is_bf16_storage_type<T>::value, T>
240
- fmax (T x, T y) {
241
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
242
- return __clc_fmax (x, y);
243
- #else
244
- (void )x;
245
- (void )y;
246
- throw runtime_error (" bfloat16 is not currently supported on the host device." ,
247
- PI_INVALID_DEVICE);
248
- #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
249
- }
250
-
251
188
template <typename T>
252
189
std::enable_if_t <std::is_same<T, bfloat16>::value, T> fmax (T x, T y) {
253
190
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -260,9 +197,9 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fmax(T x, T y) {
260
197
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
261
198
}
262
199
263
- template <typename T, size_t N>
264
- std:: enable_if_t <std::is_same<T, bfloat16>::value, sycl::marray<T , N>>
265
- fmax (sycl::marray<T, N> x, sycl::marray<T , N> y) {
200
+ template <size_t N>
201
+ sycl::marray<bfloat16, N> fmax ( sycl::marray<bfloat16 , N> x,
202
+ sycl::marray<bfloat16 , N> y) {
266
203
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
267
204
sycl::marray<bfloat16, N> res;
268
205
auto x_storage = reinterpret_cast <uint32_t const *>(&x);
@@ -285,20 +222,6 @@ fmax(sycl::marray<T, N> x, sycl::marray<T, N> y) {
285
222
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
286
223
}
287
224
288
- template <typename T>
289
- std::enable_if_t <experimental::detail::is_bf16_storage_type<T>::value, T>
290
- fma (T x, T y, T z) {
291
- #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
292
- return __clc_fma (x, y, z);
293
- #else
294
- (void )x;
295
- (void )y;
296
- (void )z;
297
- throw runtime_error (" bfloat16 is not currently supported on the host device." ,
298
- PI_INVALID_DEVICE);
299
- #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
300
- }
301
-
302
225
template <typename T>
303
226
std::enable_if_t <std::is_same<T, bfloat16>::value, T> fma (T x, T y, T z) {
304
227
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
@@ -312,9 +235,10 @@ std::enable_if_t<std::is_same<T, bfloat16>::value, T> fma(T x, T y, T z) {
312
235
#endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
313
236
}
314
237
315
- template <typename T, size_t N>
316
- std::enable_if_t <std::is_same<T, bfloat16>::value, sycl::marray<T, N>>
317
- fma (sycl::marray<T, N> x, sycl::marray<T, N> y, sycl::marray<T, N> z) {
238
+ template <size_t N>
239
+ sycl::marray<bfloat16, N> fma (sycl::marray<bfloat16, N> x,
240
+ sycl::marray<bfloat16, N> y,
241
+ sycl::marray<bfloat16, N> z) {
318
242
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
319
243
sycl::marray<bfloat16, N> res;
320
244
auto x_storage = reinterpret_cast <uint32_t const *>(&x);
0 commit comments