@@ -35,90 +35,115 @@ union StateWord {
35
35
36
36
impl StateWord {
37
37
#[ inline]
38
+ #[ must_use]
38
39
#[ target_feature( enable = "avx2" ) ]
39
- unsafe fn add_assign_epi32 ( & mut self , rhs : & Self ) {
40
- self . avx = [
41
- _mm256_add_epi32 ( self . avx [ 0 ] , rhs. avx [ 0 ] ) ,
42
- _mm256_add_epi32 ( self . avx [ 1 ] , rhs. avx [ 1 ] ) ,
43
- ] ;
40
+ unsafe fn add_epi32 ( self , rhs : Self ) -> Self {
41
+ StateWord {
42
+ avx : [
43
+ _mm256_add_epi32 ( self . avx [ 0 ] , rhs. avx [ 0 ] ) ,
44
+ _mm256_add_epi32 ( self . avx [ 1 ] , rhs. avx [ 1 ] ) ,
45
+ ] ,
46
+ }
44
47
}
45
48
46
49
#[ inline]
50
+ #[ must_use]
47
51
#[ target_feature( enable = "avx2" ) ]
48
- unsafe fn xor_assign ( & mut self , rhs : & Self ) {
49
- self . avx = [
50
- _mm256_xor_si256 ( self . avx [ 0 ] , rhs. avx [ 0 ] ) ,
51
- _mm256_xor_si256 ( self . avx [ 1 ] , rhs. avx [ 1 ] ) ,
52
- ] ;
52
+ unsafe fn xor ( self , rhs : Self ) -> Self {
53
+ StateWord {
54
+ avx : [
55
+ _mm256_xor_si256 ( self . avx [ 0 ] , rhs. avx [ 0 ] ) ,
56
+ _mm256_xor_si256 ( self . avx [ 1 ] , rhs. avx [ 1 ] ) ,
57
+ ] ,
58
+ }
53
59
}
54
60
55
61
#[ inline]
62
+ #[ must_use]
56
63
#[ target_feature( enable = "avx2" ) ]
57
- unsafe fn shuffle_epi32 < const MASK : i32 > ( & mut self ) {
58
- self . avx = [
59
- _mm256_shuffle_epi32 ( self . avx [ 0 ] , MASK ) ,
60
- _mm256_shuffle_epi32 ( self . avx [ 1 ] , MASK ) ,
61
- ] ;
64
+ unsafe fn shuffle_epi32 < const MASK : i32 > ( self ) -> Self {
65
+ StateWord {
66
+ avx : [
67
+ _mm256_shuffle_epi32 ( self . avx [ 0 ] , MASK ) ,
68
+ _mm256_shuffle_epi32 ( self . avx [ 1 ] , MASK ) ,
69
+ ] ,
70
+ }
62
71
}
63
72
64
73
#[ inline]
74
+ #[ must_use]
65
75
#[ target_feature( enable = "avx2" ) ]
66
- unsafe fn rol < const BY : i32 , const REST : i32 > ( & mut self ) {
67
- self . avx = [
68
- _mm256_xor_si256 (
69
- _mm256_slli_epi32 ( self . avx [ 0 ] , BY ) ,
70
- _mm256_srli_epi32 ( self . avx [ 0 ] , REST ) ,
71
- ) ,
72
- _mm256_xor_si256 (
73
- _mm256_slli_epi32 ( self . avx [ 1 ] , BY ) ,
74
- _mm256_srli_epi32 ( self . avx [ 1 ] , REST ) ,
75
- ) ,
76
- ] ;
76
+ unsafe fn rol < const BY : i32 , const REST : i32 > ( self ) -> Self {
77
+ StateWord {
78
+ avx : [
79
+ _mm256_xor_si256 (
80
+ _mm256_slli_epi32 ( self . avx [ 0 ] , BY ) ,
81
+ _mm256_srli_epi32 ( self . avx [ 0 ] , REST ) ,
82
+ ) ,
83
+ _mm256_xor_si256 (
84
+ _mm256_slli_epi32 ( self . avx [ 1 ] , BY ) ,
85
+ _mm256_srli_epi32 ( self . avx [ 1 ] , REST ) ,
86
+ ) ,
87
+ ] ,
88
+ }
77
89
}
78
90
79
91
#[ inline]
92
+ #[ must_use]
80
93
#[ target_feature( enable = "avx2" ) ]
81
- unsafe fn rol_8 ( & mut self ) {
82
- self . avx = [
83
- _mm256_shuffle_epi8 (
84
- self . avx [ 0 ] ,
85
- _mm256_set_epi8 (
86
- 14 , 13 , 12 , 15 , 10 , 9 , 8 , 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 , 14 , 13 , 12 , 15 , 10 , 9 , 8 ,
87
- 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 ,
94
+ unsafe fn rol_8 ( self ) -> Self {
95
+ StateWord {
96
+ avx : [
97
+ _mm256_shuffle_epi8 (
98
+ self . avx [ 0 ] ,
99
+ _mm256_set_epi8 (
100
+ 14 , 13 , 12 , 15 , 10 , 9 , 8 , 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 , 14 , 13 , 12 , 15 , 10 ,
101
+ 9 , 8 , 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 ,
102
+ ) ,
88
103
) ,
89
- ) ,
90
- _mm256_shuffle_epi8 (
91
- self . avx [ 1 ] ,
92
- _mm256_set_epi8 (
93
- 14 , 13 , 12 , 15 , 10 , 9 , 8 , 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 , 14 , 13 , 12 , 15 , 10 , 9 , 8 ,
94
- 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 ,
104
+ _mm256_shuffle_epi8 (
105
+ self . avx [ 1 ] ,
106
+ _mm256_set_epi8 (
107
+ 14 , 13 , 12 , 15 , 10 , 9 , 8 , 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 , 14 , 13 , 12 , 15 , 10 ,
108
+ 9 , 8 , 11 , 6 , 5 , 4 , 7 , 2 , 1 , 0 , 3 ,
109
+ ) ,
95
110
) ,
96
- ) ,
97
- ] ;
111
+ ] ,
112
+ }
98
113
}
99
114
100
115
#[ inline]
116
+ #[ must_use]
101
117
#[ target_feature( enable = "avx2" ) ]
102
- unsafe fn rol_16 ( & mut self ) {
103
- self . avx = [
104
- _mm256_shuffle_epi8 (
105
- self . avx [ 0 ] ,
106
- _mm256_set_epi8 (
107
- 13 , 12 , 15 , 14 , 9 , 8 , 11 , 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 , 13 , 12 , 15 , 14 , 9 , 8 , 11 ,
108
- 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 ,
118
+ unsafe fn rol_16 ( self ) -> Self {
119
+ StateWord {
120
+ avx : [
121
+ _mm256_shuffle_epi8 (
122
+ self . avx [ 0 ] ,
123
+ _mm256_set_epi8 (
124
+ 13 , 12 , 15 , 14 , 9 , 8 , 11 , 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 , 13 , 12 , 15 , 14 , 9 , 8 ,
125
+ 11 , 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 ,
126
+ ) ,
109
127
) ,
110
- ) ,
111
- _mm256_shuffle_epi8 (
112
- self . avx [ 1 ] ,
113
- _mm256_set_epi8 (
114
- 13 , 12 , 15 , 14 , 9 , 8 , 11 , 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 , 13 , 12 , 15 , 14 , 9 , 8 , 11 ,
115
- 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 ,
128
+ _mm256_shuffle_epi8 (
129
+ self . avx [ 1 ] ,
130
+ _mm256_set_epi8 (
131
+ 13 , 12 , 15 , 14 , 9 , 8 , 11 , 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 , 13 , 12 , 15 , 14 , 9 , 8 ,
132
+ 11 , 10 , 5 , 4 , 7 , 6 , 1 , 0 , 3 , 2 ,
133
+ ) ,
116
134
) ,
117
- ) ,
118
- ] ;
135
+ ] ,
136
+ }
119
137
}
120
138
}
121
139
140
+ struct State {
141
+ a : StateWord ,
142
+ b : StateWord ,
143
+ c : StateWord ,
144
+ d : StateWord ,
145
+ }
146
+
122
147
/// The ChaCha20 core function (AVX2 accelerated implementation for x86/x86_64)
123
148
// TODO(tarcieri): zeroize?
124
149
#[ derive( Clone ) ]
@@ -152,10 +177,14 @@ impl<R: Rounds> Core<R> {
152
177
#[ inline]
153
178
pub fn generate ( & self , counter : u64 , output : & mut [ u8 ] ) {
154
179
unsafe {
155
- let ( mut v0, mut v1, mut v2) = ( self . v0 , self . v1 , self . v2 ) ;
156
- let mut v3 = iv_setup ( self . iv , counter) ;
157
- self . rounds ( & mut v0, & mut v1, & mut v2, & mut v3) ;
158
- store ( v0, v1, v2, v3, output) ;
180
+ let state = State {
181
+ a : self . v0 ,
182
+ b : self . v1 ,
183
+ c : self . v2 ,
184
+ d : iv_setup ( self . iv , counter) ,
185
+ } ;
186
+ let state = self . rounds ( state) ;
187
+ store ( state. a , state. b , state. c , state. d , output) ;
159
188
}
160
189
}
161
190
@@ -166,14 +195,22 @@ impl<R: Rounds> Core<R> {
166
195
debug_assert_eq ! ( output. len( ) , BUFFER_SIZE ) ;
167
196
168
197
unsafe {
169
- let ( mut v0, mut v1, mut v2) = ( self . v0 , self . v1 , self . v2 ) ;
170
- let mut v3 = iv_setup ( self . iv , counter) ;
171
- self . rounds ( & mut v0, & mut v1, & mut v2, & mut v3) ;
198
+ let state = State {
199
+ a : self . v0 ,
200
+ b : self . v1 ,
201
+ c : self . v2 ,
202
+ d : iv_setup ( self . iv , counter) ,
203
+ } ;
204
+ let state = self . rounds ( state) ;
172
205
173
206
for i in 0 ..BLOCKS {
174
207
for ( chunk, a) in output[ i * BLOCK_SIZE ..( i + 1 ) * BLOCK_SIZE ]
175
208
. chunks_mut ( 0x10 )
176
- . zip ( [ v0, v1, v2, v3] . iter ( ) . map ( |s| s. blocks [ i] ) )
209
+ . zip (
210
+ [ state. a , state. b , state. c , state. d ]
211
+ . iter ( )
212
+ . map ( |s| s. blocks [ i] ) ,
213
+ )
177
214
{
178
215
let b = _mm_loadu_si128 ( chunk. as_ptr ( ) as * const __m128i ) ;
179
216
let out = _mm_xor_si128 ( a, b) ;
@@ -185,23 +222,19 @@ impl<R: Rounds> Core<R> {
185
222
186
223
#[ inline]
187
224
#[ target_feature( enable = "avx2" ) ]
188
- unsafe fn rounds (
189
- & self ,
190
- v0 : & mut StateWord ,
191
- v1 : & mut StateWord ,
192
- v2 : & mut StateWord ,
193
- v3 : & mut StateWord ,
194
- ) {
195
- let v3_orig = * v3;
225
+ unsafe fn rounds ( & self , mut state : State ) -> State {
226
+ let d_orig = state. d ;
196
227
197
228
for _ in 0 ..( R :: COUNT / 2 ) {
198
- double_quarter_round ( v0 , v1 , v2 , v3 ) ;
229
+ state = double_quarter_round ( state ) ;
199
230
}
200
231
201
- v0. add_assign_epi32 ( & self . v0 ) ;
202
- v1. add_assign_epi32 ( & self . v1 ) ;
203
- v2. add_assign_epi32 ( & self . v2 ) ;
204
- v3. add_assign_epi32 ( & v3_orig) ;
232
+ State {
233
+ a : state. a . add_epi32 ( self . v0 ) ,
234
+ b : state. b . add_epi32 ( self . v1 ) ,
235
+ c : state. c . add_epi32 ( self . v2 ) ,
236
+ d : state. d . add_epi32 ( d_orig) ,
237
+ }
205
238
}
206
239
}
207
240
@@ -264,16 +297,9 @@ unsafe fn store(v0: StateWord, v1: StateWord, v2: StateWord, v3: StateWord, outp
264
297
265
298
#[ inline]
266
299
#[ target_feature( enable = "avx2" ) ]
267
- unsafe fn double_quarter_round (
268
- a : & mut StateWord ,
269
- b : & mut StateWord ,
270
- c : & mut StateWord ,
271
- d : & mut StateWord ,
272
- ) {
273
- add_xor_rot ( a, b, c, d) ;
274
- rows_to_cols ( a, b, c, d) ;
275
- add_xor_rot ( a, b, c, d) ;
276
- cols_to_rows ( a, b, c, d) ;
300
+ unsafe fn double_quarter_round ( state : State ) -> State {
301
+ let state = add_xor_rot ( state) ;
302
+ cols_to_rows ( add_xor_rot ( rows_to_cols ( state) ) )
277
303
}
278
304
279
305
/// The goal of this function is to transform the state words from:
@@ -313,16 +339,18 @@ unsafe fn double_quarter_round(
313
339
/// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643
314
340
#[ inline]
315
341
#[ target_feature( enable = "avx2" ) ]
316
- unsafe fn rows_to_cols (
317
- a : & mut StateWord ,
318
- _b : & mut StateWord ,
319
- c : & mut StateWord ,
320
- d : & mut StateWord ,
321
- ) {
342
+ unsafe fn rows_to_cols ( state : State ) -> State {
322
343
// c = ROR256_B(c); d = ROR256_C(d); a = ROR256_D(a);
323
- c. shuffle_epi32 :: < 0b_00_11_10_01 > ( ) ; // _MM_SHUFFLE(0, 3, 2, 1)
324
- d. shuffle_epi32 :: < 0b_01_00_11_10 > ( ) ; // _MM_SHUFFLE(1, 0, 3, 2)
325
- a. shuffle_epi32 :: < 0b_10_01_00_11 > ( ) ; // _MM_SHUFFLE(2, 1, 0, 3)
344
+ let c = state. c . shuffle_epi32 :: < 0b_00_11_10_01 > ( ) ; // _MM_SHUFFLE(0, 3, 2, 1)
345
+ let d = state. d . shuffle_epi32 :: < 0b_01_00_11_10 > ( ) ; // _MM_SHUFFLE(1, 0, 3, 2)
346
+ let a = state. a . shuffle_epi32 :: < 0b_10_01_00_11 > ( ) ; // _MM_SHUFFLE(2, 1, 0, 3)
347
+
348
+ State {
349
+ a,
350
+ b : state. b ,
351
+ c,
352
+ d,
353
+ }
326
354
}
327
355
328
356
/// The goal of this function is to transform the state words from:
@@ -344,38 +372,38 @@ unsafe fn rows_to_cols(
344
372
/// reversing the transformation of [`rows_to_cols`].
345
373
#[ inline]
346
374
#[ target_feature( enable = "avx2" ) ]
347
- unsafe fn cols_to_rows (
348
- a : & mut StateWord ,
349
- _b : & mut StateWord ,
350
- c : & mut StateWord ,
351
- d : & mut StateWord ,
352
- ) {
375
+ unsafe fn cols_to_rows ( state : State ) -> State {
353
376
// c = ROR256_D(c); d = ROR256_C(d); a = ROR256_B(a);
354
- c. shuffle_epi32 :: < 0b_10_01_00_11 > ( ) ; // _MM_SHUFFLE(2, 1, 0, 3)
355
- d. shuffle_epi32 :: < 0b_01_00_11_10 > ( ) ; // _MM_SHUFFLE(1, 0, 3, 2)
356
- a. shuffle_epi32 :: < 0b_00_11_10_01 > ( ) ; // _MM_SHUFFLE(0, 3, 2, 1)
377
+ let c = state. c . shuffle_epi32 :: < 0b_10_01_00_11 > ( ) ; // _MM_SHUFFLE(2, 1, 0, 3)
378
+ let d = state. d . shuffle_epi32 :: < 0b_01_00_11_10 > ( ) ; // _MM_SHUFFLE(1, 0, 3, 2)
379
+ let a = state. a . shuffle_epi32 :: < 0b_00_11_10_01 > ( ) ; // _MM_SHUFFLE(0, 3, 2, 1)
380
+
381
+ State {
382
+ a,
383
+ b : state. b ,
384
+ c,
385
+ d,
386
+ }
357
387
}
358
388
359
389
#[ inline]
360
390
#[ target_feature( enable = "avx2" ) ]
361
- unsafe fn add_xor_rot ( a : & mut StateWord , b : & mut StateWord , c : & mut StateWord , d : & mut StateWord ) {
391
+ unsafe fn add_xor_rot ( state : State ) -> State {
362
392
// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_16(d);
363
- a. add_assign_epi32 ( b) ;
364
- d. xor_assign ( a) ;
365
- d. rol_16 ( ) ;
393
+ let a = state. a . add_epi32 ( state. b ) ;
394
+ let d = state. d . xor ( a) . rol_16 ( ) ;
366
395
367
396
// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_12(b);
368
- c. add_assign_epi32 ( d) ;
369
- b. xor_assign ( c) ;
370
- b. rol :: < 12 , 20 > ( ) ;
397
+ let c = state. c . add_epi32 ( d) ;
398
+ let b = state. b . xor ( c) . rol :: < 12 , 20 > ( ) ;
371
399
372
400
// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_8(d);
373
- a. add_assign_epi32 ( b) ;
374
- d. xor_assign ( a) ;
375
- d. rol_8 ( ) ;
401
+ let a = a. add_epi32 ( b) ;
402
+ let d = d. xor ( a) . rol_8 ( ) ;
376
403
377
404
// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_7(b);
378
- c. add_assign_epi32 ( d) ;
379
- b. xor_assign ( c) ;
380
- b. rol :: < 7 , 25 > ( ) ;
405
+ let c = c. add_epi32 ( d) ;
406
+ let b = b. xor ( c) . rol :: < 7 , 25 > ( ) ;
407
+
408
+ State { a, b, c, d }
381
409
}
0 commit comments