forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathfp8_to_fp16.h
117 lines (97 loc) · 3.87 KB
/
fp8_to_fp16.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/***************************************************************************************************
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cutlass/half.h>
#include <cute/util/sycl_vec.hpp>
using half_t = cutlass::half_t;
using uchar16 = cute::intel::uchar16;
using ushort16 = cute::intel::ushort16;
// Widens 16 packed unsigned 8-bit lanes to unsigned 16-bit lanes.
// Each lane is zero-extended; lane order is preserved.
static inline ushort16 convert_ushort16(uchar16 x) {
  ushort16 widened;
#pragma unroll
  for (int lane = 15; lane >= 0; --lane) {
    const uint16_t value = x[lane];  // implicit zero extension 8 -> 16 bits
    widened[lane] = value;
  }
  return widened;
}
// Converts 16 packed FP8 E4M3 values to the bit patterns of the equal FP16 values.
//
// E4M3 layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits; it has no
// infinities and S.1111.111 encodes NaN. Every E4M3 value is exactly representable
// in FP16, so the conversion is exact. Each output lane equals the scalar
// E4M3_to_FP16() result for the corresponding input lane.
//
// xin:     16 E4M3 bytes.
// returns: 16 FP16 bit patterns (sign bit 15, exponent bits 10-14, mantissa bits 0-9).
static inline ushort16 E4M3_to_FP16_vec16(uchar16 xin) {
  // FP16 encodings of the E4M3 values whose exponent field is 0, i.e. m * 2^-9
  // for mantissa m = 0..7. All of them are zero or normal numbers in FP16.
  const uint16_t subnormal_bits[8] = {0x0000, 0x1800, 0x1C00, 0x1E00,
                                      0x2000, 0x2100, 0x2200, 0x2300};

  uchar16 xa = xin & 0x7F;   // magnitude: exponent + mantissa
  uchar16 sgn_x = xin ^ xa;  // isolated sign bit: 0x80 or 0x00

  // 0x80 in lanes holding NaN (xa == 0x7F), 0x00 elsewhere; relies on 8-bit wrap-around.
  uchar16 nan_mask = (0x7E - xa) & 0x80;
  // NaN: 0x7F -> 0xBF, so that after re-biasing the FP16 exponent is all ones.
  xa += (nan_mask >> 1);
  // Re-bias the exponent from 7 (E4M3) to 15 (FP16): +8 in the exponent field,
  // which starts at bit 3 of the byte, hence +0x40.
  xa += 0x40;

  // Shift into FP16 position: E4M3 mantissa bits 0-2 become FP16 mantissa bits 7-9.
  ushort16 result = convert_ushort16(xa) << 7;

  // Zero/subnormal lanes (exponent field 0) have no implicit leading 1, so the
  // re-bias above is wrong for them; substitute the exact FP16 encoding instead.
#pragma unroll
  for (int i = 0; i < 16; ++i) {
    const unsigned mag = xin[i] & 0x7Fu;
    if (mag < 8u) {
      result[i] = subnormal_bits[mag];
    }
  }

  // Apply the sign: 0x80 << 8 == 0x8000.
  result ^= convert_ushort16(sgn_x) << 8;
  return result;
}
// Converts one FP8 E4M3 value to the bit pattern of the equal IEEE FP16 value.
//
// E4M3 layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits; it has no
// infinities and S.1111.111 encodes NaN. Every E4M3 value is exactly representable
// in FP16, so the conversion is exact.
//
// Only integer operations are used. This avoids union type punning, which is
// undefined behaviour in C++, and does not require _Float16 support.
//
// xin:     E4M3 byte.
// returns: FP16 bit pattern; NaN inputs map to 0x7F80 (0xFF80 when the sign bit is set).
static inline unsigned short E4M3_to_FP16(unsigned char xin) {
  // FP16 encodings of the E4M3 values whose exponent field is 0, i.e. m * 2^-9
  // for mantissa m = 0..7. All of them are zero or normal numbers in FP16.
  const unsigned short subnormal_bits[8] = {0x0000, 0x1800, 0x1C00, 0x1E00,
                                            0x2000, 0x2100, 0x2200, 0x2300};
  const unsigned sign = (xin & 0x80u) << 8;  // 0x8000 or 0
  const unsigned mag = xin & 0x7Fu;          // exponent + mantissa
  unsigned bits;
  if (mag < 8u) {
    // zero or subnormal: no implicit leading 1
    bits = subnormal_bits[mag];
  } else if (mag == 0x7Fu) {
    // NaN: all-ones FP16 exponent, quiet-NaN mantissa 0x380
    bits = 0x7F80u;
  } else {
    // Normal: re-bias the exponent by +8 (bit 3 of the byte, hence +0x40), then
    // shift so the 3 mantissa bits land in FP16 mantissa bits 7-9.
    bits = (mag + 0x40u) << 7;
  }
  return (unsigned short)(bits | sign);
}