@@ -22,221 +22,6 @@ namespace internal {
2222TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values (
2323 uint8_t * packed,
2424 const uint8_t * unpacked) {
25- // Given 4 unpacked uint6 values: 01abcd, 23efgh, 45ijkl, 67mnop
26- // this function packs them as:
27- // b54: 67|45|23|01 (to hold upper 2 bits on all values)
28- // b3210_0: efgh|abcd (lower 4 bits for first 2 values)
29- // b3210_1: mnop|ijkl (lower 4 bits for last 2 values)
30-
31- // These are stored in packed as: b54, b3210_0, b3210_1
32- //
33- // Input is 4 bytes
34- // Output is 6 * 4 bits/8 = 3 bytes
35-
36- // b54
37- packed[0 ] = ((unpacked[0 ] & 48 ) >> 4 ) | ((unpacked[1 ] & 48 ) >> 2 ) |
38- ((unpacked[2 ] & 48 )) | ((unpacked[3 ] & 48 ) << 2 );
39-
40- // b3210_0
41- packed[1 ] = (unpacked[0 ] & 15 ) | ((unpacked[1 ] & 15 ) << 4 );
42-
43- // b3210_1
44- packed[2 ] = (unpacked[2 ] & 15 ) | ((unpacked[3 ] & 15 ) << 4 );
45- }
46-
47- TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values (
48- uint8_t * unpacked,
49- const uint8_t * packed) {
50- // Unpacks data packed by pack_4_uint6_values
51- //
52- // Input is 24 bits = 3 bytes
53- // Output is 4 bytes
54-
55- uint8_t b54 = packed[0 ];
56- uint8_t b3210_0 = packed[1 ];
57- uint8_t b3210_1 = packed[2 ];
58-
59- unpacked[0 ] = ((b54 & 3 ) << 4 ) | (b3210_0 & 15 );
60- unpacked[1 ] = ((b54 & 12 ) << 2 ) | (b3210_0 >> 4 );
61-
62- unpacked[2 ] = (b54 & 48 ) | (b3210_1 & 15 );
63- unpacked[3 ] = ((b54 & 192 ) >> 2 ) | (b3210_1 >> 4 );
64- }
65-
66- TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values (
67- uint8_t * packed,
68- const uint8x16_t & unpacked0,
69- const uint8x16_t & unpacked1) {
70- // This function is a vectorized version of pack_8_uint6_values
71- // To understand it, please see pack_8_uint6_values first.
72- // Before each code section, there is a comment indicating the
73- // code in pack_8_uint6_values that is being vectorized
74- //
75- // Input is 32 bytes
76- // Output is 6*32= 192 bits = 24 bytes
77-
78- uint8x8_t b54;
79- uint8x8_t mask;
80-
81- // // b54
82- // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
83- // ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
84- mask = vdup_n_u8 (48 );
85- b54 = vshr_n_u8 (vand_u8 (vget_low_u8 (unpacked0), mask), 4 );
86- b54 = vorr_u8 (b54, vshr_n_u8 (vand_u8 (vget_high_u8 (unpacked0), mask), 2 ));
87-
88- b54 = vorr_u8 (b54, vand_u8 (vget_low_u8 (unpacked1), mask));
89- b54 = vorr_u8 (b54, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), mask), 2 ));
90-
91- vst1_u8 (packed, b54);
92-
93- mask = vdup_n_u8 (15 );
94- uint8x8_t b3210;
95-
96- // b3210_0
97- // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
98- b3210 = vand_u8 (vget_low_u8 (unpacked0), mask);
99- b3210 = vorr_u8 (b3210, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked0), mask), 4 ));
100- vst1_u8 (packed + 8 , b3210);
101-
102- // b3210_1
103- // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
104- b3210 = vand_u8 (vget_low_u8 (unpacked1), mask);
105- b3210 = vorr_u8 (b3210, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), mask), 4 ));
106- vst1_u8 (packed + 16 , b3210);
107- }
108-
109- TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values (
110- uint8x16_t & unpacked0,
111- uint8x16_t & unpacked1,
112- const uint8_t * packed) {
113- // Unpacks data packed by pack_32_uint6_values
114- //
115- // This function vectorizes vec_unpack_4_uint6_values
116- // To understand it, please see vec_unpack_4_uint6_values first.
117- // Before each code section, there is a comment indicating the
118- // code in vec_unpack_4_uint6_values that is being vectorized
119-
120- // Input is 24 bytes
121- // Output is 32 bytes
122-
123- uint8x8_t b54 = vld1_u8 (packed);
124- uint8x8_t b3210;
125- uint8x8_t unpacked_tmp0;
126- uint8x8_t unpacked_tmp1;
127-
128- // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
129- // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
130- b3210 = vld1_u8 (packed + 8 );
131-
132- unpacked_tmp0 = vshl_n_u8 (vand_u8 (b54, vdup_n_u8 (3 )), 4 );
133- unpacked_tmp0 = vorr_u8 (unpacked_tmp0, vand_u8 (b3210, vdup_n_u8 (15 )));
134-
135- unpacked_tmp1 = vshl_n_u8 (vand_u8 (b54, vdup_n_u8 (12 )), 2 );
136- unpacked_tmp1 = vorr_u8 (unpacked_tmp1, vshr_n_u8 (b3210, 4 ));
137-
138- unpacked0 = vcombine_u8 (unpacked_tmp0, unpacked_tmp1);
139-
140- // unpacked[2] = (b54 & 48) | (b3210_1 & 15);
141- // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
142- b3210 = vld1_u8 (packed + 16 );
143-
144- unpacked_tmp0 = vand_u8 (b54, vdup_n_u8 (48 ));
145- unpacked_tmp0 = vorr_u8 (unpacked_tmp0, vand_u8 (b3210, vdup_n_u8 (15 )));
146-
147- unpacked_tmp1 = vshr_n_u8 (vand_u8 (b54, vdup_n_u8 (192 )), 2 );
148- unpacked_tmp1 = vorr_u8 (unpacked_tmp1, vshr_n_u8 (b3210, 4 ));
149-
150- unpacked1 = vcombine_u8 (unpacked_tmp0, unpacked_tmp1);
151- }
152-
153- TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values (
154- uint8_t * packed,
155- const uint8x16_t & unpacked0,
156- const uint8x16_t & unpacked1,
157- const uint8x16_t & unpacked2,
158- const uint8x16_t & unpacked3) {
159- // This function is a vectorized version of pack_4_uint6_values
160- // To understand it, please see pack_4_uint6_values first.
161- // Before each code section, there is a comment indicating the
162- // code in pack_4_uint6_values that is being vectorized
163- //
164- // Input is 64 bytes
165- // Output is 6*64= 384 bits = 48 bytes
166-
167- uint8x16_t b54;
168- uint8x16_t mask;
169-
170- // b54
171- // packed[0] = ((unpacked[0] & 48) >> 4) | ((unpacked[1] & 48) >> 2) |
172- // ((unpacked[2] & 48)) | ((unpacked[3] & 48) << 2);
173- mask = vdupq_n_u8 (48 );
174- b54 = vshrq_n_u8 (vandq_u8 (unpacked0, mask), 4 );
175- b54 = vorrq_u8 (b54, vshrq_n_u8 (vandq_u8 (unpacked1, mask), 2 ));
176- b54 = vorrq_u8 (b54, vandq_u8 (unpacked2, mask));
177- b54 = vorrq_u8 (b54, vshlq_n_u8 (vandq_u8 (unpacked3, mask), 2 ));
178-
179- vst1q_u8 (packed, b54);
180-
181- mask = vdupq_n_u8 (15 );
182- uint8x16_t b3210;
183-
184- // b3210_0
185- // packed[1] = (unpacked[0] & 15) | ((unpacked[1] & 15) << 4);
186- b3210 = vandq_u8 (unpacked0, mask);
187- b3210 = vorrq_u8 (b3210, vshlq_n_u8 (vandq_u8 (unpacked1, mask), 4 ));
188- vst1q_u8 (packed + 16 , b3210);
189-
190- // b3210_1
191- // packed[2] = (unpacked[2] & 15) | ((unpacked[3] & 15) << 4);
192- b3210 = vandq_u8 (unpacked2, mask);
193- b3210 = vorrq_u8 (b3210, vshlq_n_u8 (vandq_u8 (unpacked3, mask), 4 ));
194- vst1q_u8 (packed + 32 , b3210);
195- }
196-
197- TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values (
198- uint8x16_t & unpacked0,
199- uint8x16_t & unpacked1,
200- uint8x16_t & unpacked2,
201- uint8x16_t & unpacked3,
202- const uint8_t * packed) {
203- // Unpacks data packed by pack_64_uint6_values
204- //
205- // This function vectorizes vec_unpack_4_uint6_values
206- // To understand it, please see vec_unpack_4_uint6_values first.
207- // Before each code section, there is a comment indicating the
208- // code in vec_unpack_4_uint6_values that is being vectorized
209-
210- // Input is 48 bytes
211- // Output is 64 bytes
212-
213- uint8x16_t b54 = vld1q_u8 (packed);
214- uint8x16_t b3210;
215-
216- // unpacked[0] = ((b54 & 3) << 4) | (b3210_0 & 15);
217- // unpacked[1] = ((b54 & 12) << 2) | (b3210_0 >> 4);
218- b3210 = vld1q_u8 (packed + 16 );
219-
220- unpacked0 = vshlq_n_u8 (vandq_u8 (b54, vdupq_n_u8 (3 )), 4 );
221- unpacked0 = vorrq_u8 (unpacked0, vandq_u8 (b3210, vdupq_n_u8 (15 )));
222-
223- unpacked1 = vshlq_n_u8 (vandq_u8 (b54, vdupq_n_u8 (12 )), 2 );
224- unpacked1 = vorrq_u8 (unpacked1, vshrq_n_u8 (b3210, 4 ));
225-
226- // unpacked[2] = (b54 & 48) | (b3210_1 & 15);
227- // unpacked[3] = ((b54 & 192) >> 2) | (b3210_1 >> 4);
228- b3210 = vld1q_u8 (packed + 32 );
229-
230- unpacked2 = vandq_u8 (b54, vdupq_n_u8 (48 ));
231- unpacked2 = vorrq_u8 (unpacked2, vandq_u8 (b3210, vdupq_n_u8 (15 )));
232-
233- unpacked3 = vshrq_n_u8 (vandq_u8 (b54, vdupq_n_u8 (192 )), 2 );
234- unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (b3210, 4 ));
235- }
236-
237- TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2 (
238- uint8_t * packed,
239- const uint8_t * unpacked) {
24025 // Given 4 unpacked uint6 values: abcdef, ghijkl, mnopqr, 123456
24126 // this function packs them as:
24227 // packed[0]: 56 | abcdef
@@ -254,9 +39,9 @@ TORCHAO_ALWAYS_INLINE inline void pack_4_uint6_values_v2(
25439 packed[2 ] |= ((unpacked[3 ] & 0b11'0000u ) << 2 );
25540}
25641
257- TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2 (
258- uint8_t * unpacked,
259- const uint8_t * packed) {
42+ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values (
43+ uint8_t * unpacked,
44+ const uint8_t * packed) {
26045 // Unpacks data packed by pack_4_uint6_values_v2
26146 //
26247 // Input is 24 bits = 3 bytes
@@ -266,17 +51,17 @@ TORCHAO_ALWAYS_INLINE inline void unpack_4_uint6_values_v2(
26651 unpacked[2 ] = packed[2 ] & 0b111111u ;
26752 // Last value is packed in the upper 2 bits of the three bytes
26853 unpacked[3 ] = ((packed[0 ] & 0b1100'0000u ) >> 6 ) |
269- ((packed[1 ] & 0b1100'0000u ) >> 4 ) |
270- ((packed[2 ] & 0b1100'0000u ) >> 2 );
54+ ((packed[1 ] & 0b1100'0000u ) >> 4 ) | ((packed[2 ] & 0b1100'0000u ) >> 2 );
27155}
27256
273- TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2 (
274- uint8_t * packed,
275- const uint8x16_t & unpacked0,
276- const uint8x16_t & unpacked1) {
57+ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values (
58+ uint8_t * packed,
59+ const uint8x16_t & unpacked0,
60+ const uint8x16_t & unpacked1) {
27761 // This function is a vectorized version of pack_4_uint6_values_v2.
278- // To understand the following code, please see pack_4_uint6_values_v2 first and
279- // consider the following mapping for the unpacked parameter of that function:
62+ // To understand the following code, please see pack_4_uint6_values_v2 first
63+ // and consider the following mapping for the unpacked parameter of that
64+ // function:
28065 //
28166 // unpacked[0] -> vget_low_u8(unpacked0)
28267 // unpacked[1] -> vget_high_u8(unpacked0)
@@ -293,23 +78,26 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_uint6_values_v2(
29378 // packed[0] = unpacked[0]
29479 // packed[0] |= ((unpacked[3] & 0b00'0011u) << 6)
29580 r = vget_low_u8 (unpacked0);
296- r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'0011u )), 6 ));
81+ r = vorr_u8 (
82+ r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'0011u )), 6 ));
29783 vst1_u8 (packed, r);
29884
29985 // packed[1] = unpacked[1]
30086 // packed[1] |= ((unpacked[3] & 0b00'1100u) << 4)
30187 r = vget_high_u8 (unpacked0);
302- r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'1100u )), 4 ));
88+ r = vorr_u8 (
89+ r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b00'1100u )), 4 ));
30390 vst1_u8 (packed + 8 , r);
30491
30592 // packed[2] = unpacked[2]
30693 // packed[2] |= ((unpacked[3] & 0b11'0000u) << 2)
30794 r = vget_low_u8 (unpacked1);
308- r = vorr_u8 (r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b11'0000u )), 2 ));
95+ r = vorr_u8 (
96+ r, vshl_n_u8 (vand_u8 (vget_high_u8 (unpacked1), vdup_n_u8 (0b11'0000u )), 2 ));
30997 vst1_u8 (packed + 16 , r);
31098}
31199
312- TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2 (
100+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values (
313101 uint8x16_t & unpacked0,
314102 uint8x16_t & unpacked1,
315103 const uint8_t * packed) {
@@ -331,18 +119,18 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_uint6_values_v2(
331119 // ((packed[2] & 0b1100'0000u) >> 2);
332120 const uint8x8_t high = vdup_n_u8 (0b1100'0000u );
333121 uint8x8_t unpacked3;
334- unpacked3 = vorr_u8 (vshr_n_u8 ( vand_u8 (packed0, high), 6 ),
335- vshr_n_u8 (vand_u8 (packed1 , high), 4 ));
336- unpacked3 = vorr_u8 (unpacked3,
337- vshr_n_u8 (vand_u8 (packed2, high), 2 ));
122+ unpacked3 = vorr_u8 (
123+ vshr_n_u8 (vand_u8 (packed0 , high), 6 ),
124+ vshr_n_u8 ( vand_u8 (packed1, high), 4 ));
125+ unpacked3 = vorr_u8 (unpacked3, vshr_n_u8 (vand_u8 (packed2, high), 2 ));
338126
339127 // unpacked[i] = packed[i] & 0b11'1111u;
340128 const uint8x8_t mask = vdup_n_u8 (0b11'1111u );
341129 unpacked0 = vcombine_u8 (vand_u8 (packed0, mask), vand_u8 (packed1, mask));
342130 unpacked1 = vcombine_u8 (vand_u8 (packed2, mask), unpacked3);
343131}
344132
345- TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2 (
133+ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values (
346134 uint8_t * packed,
347135 const uint8x16_t & unpacked0,
348136 const uint8x16_t & unpacked1,
@@ -376,7 +164,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint6_values_v2(
376164 vst1q_u8 (packed + 32 , r);
377165}
378166
379- TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2 (
167+ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values (
380168 uint8x16_t & unpacked0,
381169 uint8x16_t & unpacked1,
382170 uint8x16_t & unpacked2,
@@ -399,10 +187,10 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint6_values_v2(
399187 // ((packed[1] & 0b1100'0000u) >> 4) |
400188 // ((packed[2] & 0b1100'0000u) >> 2);
401189 const uint8x16_t high = vdupq_n_u8 (0b1100'0000u );
402- unpacked3 = vorrq_u8 (vshrq_n_u8 ( vandq_u8 (unpacked0, high), 6 ),
403- vshrq_n_u8 (vandq_u8 (unpacked1 , high), 4 ));
404- unpacked3 = vorrq_u8 (unpacked3,
405- vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
190+ unpacked3 = vorrq_u8 (
191+ vshrq_n_u8 (vandq_u8 (unpacked0 , high), 6 ),
192+ vshrq_n_u8 ( vandq_u8 (unpacked1, high), 4 ));
193+ unpacked3 = vorrq_u8 (unpacked3, vshrq_n_u8 (vandq_u8 (unpacked2, high), 2 ));
406194
407195 // unpacked[i] = packed[i] & 0b11'1111u;
408196 const uint8x16_t mask = vdupq_n_u8 (0b11'1111u );
0 commit comments