Skip to content

Commit 7c86911

Browse files
authored
chacha20: Remove mutable borrows from AVX2 backend (#268)
The use of `&mut StateWord` everywhere caused a `vmovdqa` to be inserted after almost every operation, and also caused the diagonalization to use `vpermilps` instead of seeing the optimisation to `vpshufd`. The new `State` struct helps to manage the passing-around of owned `StateWord`s.
1 parent 818c4ac commit 7c86911

File tree

1 file changed

+146
-118
lines changed

1 file changed

+146
-118
lines changed

chacha20/src/backend/avx2.rs

Lines changed: 146 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -35,90 +35,115 @@ union StateWord {
3535

3636
impl StateWord {
3737
#[inline]
38+
#[must_use]
3839
#[target_feature(enable = "avx2")]
39-
unsafe fn add_assign_epi32(&mut self, rhs: &Self) {
40-
self.avx = [
41-
_mm256_add_epi32(self.avx[0], rhs.avx[0]),
42-
_mm256_add_epi32(self.avx[1], rhs.avx[1]),
43-
];
40+
unsafe fn add_epi32(self, rhs: Self) -> Self {
41+
StateWord {
42+
avx: [
43+
_mm256_add_epi32(self.avx[0], rhs.avx[0]),
44+
_mm256_add_epi32(self.avx[1], rhs.avx[1]),
45+
],
46+
}
4447
}
4548

4649
#[inline]
50+
#[must_use]
4751
#[target_feature(enable = "avx2")]
48-
unsafe fn xor_assign(&mut self, rhs: &Self) {
49-
self.avx = [
50-
_mm256_xor_si256(self.avx[0], rhs.avx[0]),
51-
_mm256_xor_si256(self.avx[1], rhs.avx[1]),
52-
];
52+
unsafe fn xor(self, rhs: Self) -> Self {
53+
StateWord {
54+
avx: [
55+
_mm256_xor_si256(self.avx[0], rhs.avx[0]),
56+
_mm256_xor_si256(self.avx[1], rhs.avx[1]),
57+
],
58+
}
5359
}
5460

5561
#[inline]
62+
#[must_use]
5663
#[target_feature(enable = "avx2")]
57-
unsafe fn shuffle_epi32<const MASK: i32>(&mut self) {
58-
self.avx = [
59-
_mm256_shuffle_epi32(self.avx[0], MASK),
60-
_mm256_shuffle_epi32(self.avx[1], MASK),
61-
];
64+
unsafe fn shuffle_epi32<const MASK: i32>(self) -> Self {
65+
StateWord {
66+
avx: [
67+
_mm256_shuffle_epi32(self.avx[0], MASK),
68+
_mm256_shuffle_epi32(self.avx[1], MASK),
69+
],
70+
}
6271
}
6372

6473
#[inline]
74+
#[must_use]
6575
#[target_feature(enable = "avx2")]
66-
unsafe fn rol<const BY: i32, const REST: i32>(&mut self) {
67-
self.avx = [
68-
_mm256_xor_si256(
69-
_mm256_slli_epi32(self.avx[0], BY),
70-
_mm256_srli_epi32(self.avx[0], REST),
71-
),
72-
_mm256_xor_si256(
73-
_mm256_slli_epi32(self.avx[1], BY),
74-
_mm256_srli_epi32(self.avx[1], REST),
75-
),
76-
];
76+
unsafe fn rol<const BY: i32, const REST: i32>(self) -> Self {
77+
StateWord {
78+
avx: [
79+
_mm256_xor_si256(
80+
_mm256_slli_epi32(self.avx[0], BY),
81+
_mm256_srli_epi32(self.avx[0], REST),
82+
),
83+
_mm256_xor_si256(
84+
_mm256_slli_epi32(self.avx[1], BY),
85+
_mm256_srli_epi32(self.avx[1], REST),
86+
),
87+
],
88+
}
7789
}
7890

7991
#[inline]
92+
#[must_use]
8093
#[target_feature(enable = "avx2")]
81-
unsafe fn rol_8(&mut self) {
82-
self.avx = [
83-
_mm256_shuffle_epi8(
84-
self.avx[0],
85-
_mm256_set_epi8(
86-
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8,
87-
11, 6, 5, 4, 7, 2, 1, 0, 3,
94+
unsafe fn rol_8(self) -> Self {
95+
StateWord {
96+
avx: [
97+
_mm256_shuffle_epi8(
98+
self.avx[0],
99+
_mm256_set_epi8(
100+
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10,
101+
9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3,
102+
),
88103
),
89-
),
90-
_mm256_shuffle_epi8(
91-
self.avx[1],
92-
_mm256_set_epi8(
93-
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10, 9, 8,
94-
11, 6, 5, 4, 7, 2, 1, 0, 3,
104+
_mm256_shuffle_epi8(
105+
self.avx[1],
106+
_mm256_set_epi8(
107+
14, 13, 12, 15, 10, 9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3, 14, 13, 12, 15, 10,
108+
9, 8, 11, 6, 5, 4, 7, 2, 1, 0, 3,
109+
),
95110
),
96-
),
97-
];
111+
],
112+
}
98113
}
99114

100115
#[inline]
116+
#[must_use]
101117
#[target_feature(enable = "avx2")]
102-
unsafe fn rol_16(&mut self) {
103-
self.avx = [
104-
_mm256_shuffle_epi8(
105-
self.avx[0],
106-
_mm256_set_epi8(
107-
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11,
108-
10, 5, 4, 7, 6, 1, 0, 3, 2,
118+
unsafe fn rol_16(self) -> Self {
119+
StateWord {
120+
avx: [
121+
_mm256_shuffle_epi8(
122+
self.avx[0],
123+
_mm256_set_epi8(
124+
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8,
125+
11, 10, 5, 4, 7, 6, 1, 0, 3, 2,
126+
),
109127
),
110-
),
111-
_mm256_shuffle_epi8(
112-
self.avx[1],
113-
_mm256_set_epi8(
114-
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8, 11,
115-
10, 5, 4, 7, 6, 1, 0, 3, 2,
128+
_mm256_shuffle_epi8(
129+
self.avx[1],
130+
_mm256_set_epi8(
131+
13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2, 13, 12, 15, 14, 9, 8,
132+
11, 10, 5, 4, 7, 6, 1, 0, 3, 2,
133+
),
116134
),
117-
),
118-
];
135+
],
136+
}
119137
}
120138
}
121139

140+
struct State {
141+
a: StateWord,
142+
b: StateWord,
143+
c: StateWord,
144+
d: StateWord,
145+
}
146+
122147
/// The ChaCha20 core function (AVX2 accelerated implementation for x86/x86_64)
123148
// TODO(tarcieri): zeroize?
124149
#[derive(Clone)]
@@ -152,10 +177,14 @@ impl<R: Rounds> Core<R> {
152177
#[inline]
153178
pub fn generate(&self, counter: u64, output: &mut [u8]) {
154179
unsafe {
155-
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
156-
let mut v3 = iv_setup(self.iv, counter);
157-
self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
158-
store(v0, v1, v2, v3, output);
180+
let state = State {
181+
a: self.v0,
182+
b: self.v1,
183+
c: self.v2,
184+
d: iv_setup(self.iv, counter),
185+
};
186+
let state = self.rounds(state);
187+
store(state.a, state.b, state.c, state.d, output);
159188
}
160189
}
161190

@@ -166,14 +195,22 @@ impl<R: Rounds> Core<R> {
166195
debug_assert_eq!(output.len(), BUFFER_SIZE);
167196

168197
unsafe {
169-
let (mut v0, mut v1, mut v2) = (self.v0, self.v1, self.v2);
170-
let mut v3 = iv_setup(self.iv, counter);
171-
self.rounds(&mut v0, &mut v1, &mut v2, &mut v3);
198+
let state = State {
199+
a: self.v0,
200+
b: self.v1,
201+
c: self.v2,
202+
d: iv_setup(self.iv, counter),
203+
};
204+
let state = self.rounds(state);
172205

173206
for i in 0..BLOCKS {
174207
for (chunk, a) in output[i * BLOCK_SIZE..(i + 1) * BLOCK_SIZE]
175208
.chunks_mut(0x10)
176-
.zip([v0, v1, v2, v3].iter().map(|s| s.blocks[i]))
209+
.zip(
210+
[state.a, state.b, state.c, state.d]
211+
.iter()
212+
.map(|s| s.blocks[i]),
213+
)
177214
{
178215
let b = _mm_loadu_si128(chunk.as_ptr() as *const __m128i);
179216
let out = _mm_xor_si128(a, b);
@@ -185,23 +222,19 @@ impl<R: Rounds> Core<R> {
185222

186223
#[inline]
187224
#[target_feature(enable = "avx2")]
188-
unsafe fn rounds(
189-
&self,
190-
v0: &mut StateWord,
191-
v1: &mut StateWord,
192-
v2: &mut StateWord,
193-
v3: &mut StateWord,
194-
) {
195-
let v3_orig = *v3;
225+
unsafe fn rounds(&self, mut state: State) -> State {
226+
let d_orig = state.d;
196227

197228
for _ in 0..(R::COUNT / 2) {
198-
double_quarter_round(v0, v1, v2, v3);
229+
state = double_quarter_round(state);
199230
}
200231

201-
v0.add_assign_epi32(&self.v0);
202-
v1.add_assign_epi32(&self.v1);
203-
v2.add_assign_epi32(&self.v2);
204-
v3.add_assign_epi32(&v3_orig);
232+
State {
233+
a: state.a.add_epi32(self.v0),
234+
b: state.b.add_epi32(self.v1),
235+
c: state.c.add_epi32(self.v2),
236+
d: state.d.add_epi32(d_orig),
237+
}
205238
}
206239
}
207240

@@ -264,16 +297,9 @@ unsafe fn store(v0: StateWord, v1: StateWord, v2: StateWord, v3: StateWord, outp
264297

265298
#[inline]
266299
#[target_feature(enable = "avx2")]
267-
unsafe fn double_quarter_round(
268-
a: &mut StateWord,
269-
b: &mut StateWord,
270-
c: &mut StateWord,
271-
d: &mut StateWord,
272-
) {
273-
add_xor_rot(a, b, c, d);
274-
rows_to_cols(a, b, c, d);
275-
add_xor_rot(a, b, c, d);
276-
cols_to_rows(a, b, c, d);
300+
unsafe fn double_quarter_round(state: State) -> State {
301+
let state = add_xor_rot(state);
302+
cols_to_rows(add_xor_rot(rows_to_cols(state)))
277303
}
278304

279305
/// The goal of this function is to transform the state words from:
@@ -313,16 +339,18 @@ unsafe fn double_quarter_round(
313339
/// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643
314340
#[inline]
315341
#[target_feature(enable = "avx2")]
316-
unsafe fn rows_to_cols(
317-
a: &mut StateWord,
318-
_b: &mut StateWord,
319-
c: &mut StateWord,
320-
d: &mut StateWord,
321-
) {
342+
unsafe fn rows_to_cols(state: State) -> State {
322343
// c = ROR256_B(c); d = ROR256_C(d); a = ROR256_D(a);
323-
c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
324-
d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
325-
a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
344+
let c = state.c.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
345+
let d = state.d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
346+
let a = state.a.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
347+
348+
State {
349+
a,
350+
b: state.b,
351+
c,
352+
d,
353+
}
326354
}
327355

328356
/// The goal of this function is to transform the state words from:
@@ -344,38 +372,38 @@ unsafe fn rows_to_cols(
344372
/// reversing the transformation of [`rows_to_cols`].
345373
#[inline]
346374
#[target_feature(enable = "avx2")]
347-
unsafe fn cols_to_rows(
348-
a: &mut StateWord,
349-
_b: &mut StateWord,
350-
c: &mut StateWord,
351-
d: &mut StateWord,
352-
) {
375+
unsafe fn cols_to_rows(state: State) -> State {
353376
// c = ROR256_D(c); d = ROR256_C(d); a = ROR256_B(a);
354-
c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
355-
d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
356-
a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
377+
let c = state.c.shuffle_epi32::<0b_10_01_00_11>(); // _MM_SHUFFLE(2, 1, 0, 3)
378+
let d = state.d.shuffle_epi32::<0b_01_00_11_10>(); // _MM_SHUFFLE(1, 0, 3, 2)
379+
let a = state.a.shuffle_epi32::<0b_00_11_10_01>(); // _MM_SHUFFLE(0, 3, 2, 1)
380+
381+
State {
382+
a,
383+
b: state.b,
384+
c,
385+
d,
386+
}
357387
}
358388

359389
#[inline]
360390
#[target_feature(enable = "avx2")]
361-
unsafe fn add_xor_rot(a: &mut StateWord, b: &mut StateWord, c: &mut StateWord, d: &mut StateWord) {
391+
unsafe fn add_xor_rot(state: State) -> State {
362392
// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_16(d);
363-
a.add_assign_epi32(b);
364-
d.xor_assign(a);
365-
d.rol_16();
393+
let a = state.a.add_epi32(state.b);
394+
let d = state.d.xor(a).rol_16();
366395

367396
// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_12(b);
368-
c.add_assign_epi32(d);
369-
b.xor_assign(c);
370-
b.rol::<12, 20>();
397+
let c = state.c.add_epi32(d);
398+
let b = state.b.xor(c).rol::<12, 20>();
371399

372400
// a = ADD256_32(a,b); d = XOR256(d,a); d = ROL256_8(d);
373-
a.add_assign_epi32(b);
374-
d.xor_assign(a);
375-
d.rol_8();
401+
let a = a.add_epi32(b);
402+
let d = d.xor(a).rol_8();
376403

377404
// c = ADD256_32(c,d); b = XOR256(b,c); b = ROL256_7(b);
378-
c.add_assign_epi32(d);
379-
b.xor_assign(c);
380-
b.rol::<7, 25>();
405+
let c = c.add_epi32(d);
406+
let b = b.xor(c).rol::<7, 25>();
407+
408+
State { a, b, c, d }
381409
}

0 commit comments

Comments
 (0)