Skip to content

Commit 973fe29

Browse files
polyval: implement Karatsuba multiplication for arm64 (#181)
Improves performance by ~200 MB/s on a 2020 M1. Signed-off-by: Eric Lagergren <[email protected]>
1 parent f897eac commit 973fe29

File tree

1 file changed

+95
-32
lines changed

1 file changed

+95
-32
lines changed

polyval/src/backend/pmull.rs

+95-32
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,6 @@ impl Reset for Polyval {
6767
}
6868

6969
impl Polyval {
70-
/// Mask value used when performing reduction.
71-
/// This corresponds to POLYVAL's polynomial with the highest bit unset.
72-
const MASK: u128 = 1 << 127 | 1 << 126 | 1 << 121 | 1;
73-
7470
/// Get POLYVAL output.
7571
pub(crate) fn finalize(self) -> Tag {
7672
unsafe { mem::transmute(self.y) }
@@ -81,42 +77,109 @@ impl Polyval {
8177
#[inline]
8278
#[target_feature(enable = "neon")]
8379
unsafe fn mul(&mut self, x: &Block) {
84-
let h = self.h;
8580
let y = veorq_u8(self.y, vld1q_u8(x.as_ptr()));
86-
87-
// polynomial multiply
88-
let z = vdupq_n_u8(0);
89-
let r0 = pmull::<0, 0>(h, y);
90-
let r1 = pmull::<1, 1>(h, y);
91-
let t0 = pmull::<0, 1>(h, y);
92-
let t1 = pmull::<1, 0>(h, y);
93-
let t0 = veorq_u8(t0, t1);
94-
let t1 = vextq_u8(z, t0, 8);
95-
let r0 = veorq_u8(r0, t1);
96-
let t1 = vextq_u8(t0, z, 8);
97-
let r1 = veorq_u8(r1, t1);
98-
99-
// polynomial reduction
100-
let p = mem::transmute(Self::MASK);
101-
let t0 = pmull::<0, 1>(r0, p);
102-
let t1 = vextq_u8(t0, t0, 8);
103-
let r0 = veorq_u8(r0, t1);
104-
let t1 = pmull::<1, 1>(r0, p);
105-
let r0 = veorq_u8(r0, t1);
106-
107-
self.y = veorq_u8(r0, r1);
81+
let (h, m, l) = karatsuba1(self.h, y);
82+
let (h, l) = karatsuba2(h, m, l);
83+
self.y = mont_reduce(h, l);
10884
}
10985
}
11086

111-
/// Wrapper for the ARM64 `PMULL` instruction.
112-
#[inline(always)]
113-
unsafe fn pmull<const A_LANE: i32, const B_LANE: i32>(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
87+
/// Karatsuba decomposition for `x*y`.
88+
#[inline]
89+
#[target_feature(enable = "neon")]
90+
unsafe fn karatsuba1(x: uint8x16_t, y: uint8x16_t) -> (uint8x16_t, uint8x16_t, uint8x16_t) {
91+
// First Karatsuba step: decompose x and y.
92+
//
93+
// (x1*y0 + x0*y1) = (x1+x0) * (y1+x0) + (x1*y1) + (x0*y0)
94+
// M H L
95+
//
96+
// m = x.hi^x.lo * y.hi^y.lo
97+
let m = pmull(
98+
veorq_u8(x, vextq_u8(x, x, 8)), // x.hi^x.lo
99+
veorq_u8(y, vextq_u8(y, y, 8)), // y.hi^y.lo
100+
);
101+
let h = pmull2(x, y); // h = x.hi * y.hi
102+
let l = pmull(x, y); // l = x.lo * y.lo
103+
(h, m, l)
104+
}
105+
106+
/// Karatsuba combine.
107+
#[inline]
108+
#[target_feature(enable = "neon")]
109+
unsafe fn karatsuba2(h: uint8x16_t, m: uint8x16_t, l: uint8x16_t) -> (uint8x16_t, uint8x16_t) {
110+
// Second Karatsuba step: combine into a 2n-bit product.
111+
//
112+
// m0 ^= l0 ^ h0 // = m0^(l0^h0)
113+
// m1 ^= l1 ^ h1 // = m1^(l1^h1)
114+
// l1 ^= m0 // = l1^(m0^l0^h0)
115+
// h0 ^= l0 ^ m1 // = h0^(l0^m1^l1^h1)
116+
// h1 ^= l1 // = h1^(l1^m0^l0^h0)
117+
let t = {
118+
// {m0, m1} ^ {l1, h0}
119+
// = {m0^l1, m1^h0}
120+
let t0 = veorq_u8(m, vextq_u8(l, h, 8));
121+
122+
// {h0, h1} ^ {l0, l1}
123+
// = {h0^l0, h1^l1}
124+
let t1 = veorq_u8(h, l);
125+
126+
// {m0^l1, m1^h0} ^ {h0^l0, h1^l1}
127+
// = {m0^l1^h0^l0, m1^h0^h1^l1}
128+
veorq_u8(t0, t1)
129+
};
130+
131+
// {m0^l1^h0^l0, l0}
132+
let x01 = vextq_u8(
133+
vextq_u8(l, l, 8), // {l1, l0}
134+
t,
135+
8,
136+
);
137+
138+
// {h1, m1^h0^h1^l1}
139+
let x23 = vextq_u8(
140+
t,
141+
vextq_u8(h, h, 8), // {h1, h0}
142+
8,
143+
);
144+
145+
(x23, x01)
146+
}
147+
148+
#[inline]
149+
#[target_feature(enable = "neon")]
150+
unsafe fn mont_reduce(x23: uint8x16_t, x01: uint8x16_t) -> uint8x16_t {
151+
// Perform the Montgomery reduction over the 256-bit X.
152+
// [A1:A0] = X0 • poly
153+
// [B1:B0] = [X0 ⊕ A1 : X1 ⊕ A0]
154+
// [C1:C0] = B0 • poly
155+
// [D1:D0] = [B0 ⊕ C1 : B1 ⊕ C0]
156+
// Output: [D1 ⊕ X3 : D0 ⊕ X2]
157+
let poly = vreinterpretq_u8_p128(1 << 127 | 1 << 126 | 1 << 121 | 1 << 63 | 1 << 62 | 1 << 57);
158+
let a = pmull(x01, poly);
159+
let b = veorq_u8(x01, vextq_u8(a, a, 8));
160+
let c = pmull2(b, poly);
161+
veorq_u8(x23, veorq_u8(c, b))
162+
}
163+
164+
/// Multiplies the low bits in `a` and `b`.
165+
#[inline]
166+
#[target_feature(enable = "neon")]
167+
unsafe fn pmull(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
114168
mem::transmute(vmull_p64(
115-
vgetq_lane_u64(vreinterpretq_u64_u8(a), A_LANE),
116-
vgetq_lane_u64(vreinterpretq_u64_u8(b), B_LANE),
169+
vgetq_lane_u64(vreinterpretq_u64_u8(a), 0),
170+
vgetq_lane_u64(vreinterpretq_u64_u8(b), 0),
117171
))
118172
}
119173

174+
/// Multiplies the high bits in `a` and `b`.
175+
#[inline]
176+
#[target_feature(enable = "neon")]
177+
unsafe fn pmull2(a: uint8x16_t, b: uint8x16_t) -> uint8x16_t {
178+
mem::transmute(vmull_p64(
179+
vgetq_lane_u64(vreinterpretq_u64_u8(a), 1),
180+
vgetq_lane_u64(vreinterpretq_u64_u8(b), 1),
181+
))
182+
}
120183
// TODO(tarcieri): zeroize support
121184
// #[cfg(feature = "zeroize")]
122185
// impl Drop for Polyval {

0 commit comments

Comments
 (0)