Skip to content

Commit aa0efb9

Browse files
committed
Add Neon mld_polyvecl_pointwise_acc_montgomery_l{4,5,7}_native
These are basically written from scratch inspired by the same functions in mlkem-native. Resolves #257
1 parent a1ad592 commit aa0efb9

7 files changed

+523
-2
lines changed

mldsa/native/aarch64/meta.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
/* Set of primitives that this backend replaces */
1111
#define MLD_USE_NATIVE_NTT
1212
#define MLD_USE_NATIVE_INTT
13+
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY
1314

1415
/* Identifier for this backend so that source and assembly files
1516
* in the build can be appropriately guarded. */
@@ -31,6 +32,27 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
3132
mld_aarch64_intt_zetas_layer123456);
3233
}
3334

35+
/* Native hook for MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY, l = 4:
 * computes w[i] = montgomery_reduce(sum_{k<4} u[k*MLDSA_N+i] * v[k*MLDSA_N+i])
 * by delegating to the Neon assembly routine (see the l4 .S file in this
 * commit). u and v are the flattened 4-entry polynomial vectors. */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
36+
int32_t w[MLDSA_N], const int32_t u[4 * MLDSA_N],
37+
const int32_t v[4 * MLDSA_N])
38+
{
39+
mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, u, v);
40+
}
41+
42+
/* Native hook for MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY, l = 5:
 * computes w[i] = montgomery_reduce(sum_{k<5} u[k*MLDSA_N+i] * v[k*MLDSA_N+i])
 * by delegating to the Neon assembly routine (see the l5 .S file in this
 * commit). u and v are the flattened 5-entry polynomial vectors. */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
43+
int32_t w[MLDSA_N], const int32_t u[5 * MLDSA_N],
44+
const int32_t v[5 * MLDSA_N])
45+
{
46+
mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, u, v);
47+
}
48+
49+
/* Native hook for MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY, l = 7:
 * delegates to the l7 Neon assembly routine. NOTE(review): the l7 .S file is
 * not visible in this diff view -- presumably the same pointwise-accumulate +
 * Montgomery-reduce structure as the l4/l5 routines; confirm against the
 * commit's remaining files. */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
50+
int32_t w[MLDSA_N], const int32_t u[7 * MLDSA_N],
51+
const int32_t v[7 * MLDSA_N])
52+
{
53+
mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, u, v);
54+
}
55+
3456
#endif /* !__ASSEMBLER__ */
3557

3658
#endif /* !MLD_NATIVE_AARCH64_META_H */

mldsa/native/aarch64/src/arith_native_aarch64.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,19 @@ void mld_ntt_asm(int32_t *, const int32_t *, const int32_t *);
3232
#define mld_intt_asm MLD_NAMESPACE(intt_asm)
3333
void mld_intt_asm(int32_t *, const int32_t *, const int32_t *);
3434

35+
/* Prototypes for the AArch64 assembly implementations of the pointwise
 * accumulated Montgomery multiplication, one per ML-DSA vector length
 * (l = 4, 5, 7).  Arguments are (out, a, b): out receives MLDSA_N
 * coefficients; a and b are the flattened l-entry polynomial vectors.
 * Each symbol is prefixed via MLD_NAMESPACE. */
#define mld_polyvecl_pointwise_acc_montgomery_l4_asm \
36+
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
37+
void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *, const int32_t *,
38+
const int32_t *);
39+
40+
#define mld_polyvecl_pointwise_acc_montgomery_l5_asm \
41+
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
42+
void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *, const int32_t *,
43+
const int32_t *);
44+
45+
#define mld_polyvecl_pointwise_acc_montgomery_l7_asm \
46+
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l7_asm)
47+
void mld_polyvecl_pointwise_acc_montgomery_l7_asm(int32_t *, const int32_t *,
48+
const int32_t *);
49+
3550
#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/* Copyright (c) The mldsa-native project authors
2+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
3+
*/
4+
5+
#include "../../../common.h"
6+
#if defined(MLD_ARITH_BACKEND_AARCH64)
7+
8+
// Montgomery reduction of four 64-bit products, split across two vectors:
// \inl holds products 0-1, \inh holds products 2-3 (as 64-bit lanes).
// t0 = low 32 bits of each product; t0 *= -q^-1 (mod 2^32, in
// modulus_twisted); adding t0 * q clears the low 32 bits of each product,
// and uzp2 extracts the high 32-bit halves as the reduced results.
// Clobbers t0 and modifies \inl/\inh in place.
.macro montgomery_reduce_long res, inl, inh
9+
uzp1 t0.4s, \inl\().4s, \inh\().4s
10+
mul t0.4s, t0.4s, modulus_twisted.4s
11+
smlal \inl\().2d, t0.2s, modulus.2s
12+
smlal2 \inh\().2d, t0.4s, modulus.4s
13+
uzp2 \res\().4s, \inl\().4s, \inh\().4s
14+
.endm
15+
16+
// Load 4 coefficients (16 bytes) from each of two polynomials, with
// post-increment of both pointers.
.macro load_polys a, b, a_ptr, b_ptr
17+
ldr q_\()\a, [\a_ptr], #16
18+
ldr q_\()\b, [\b_ptr], #16
19+
.endm
20+
21+
// Widening 32x32 -> 64-bit signed multiply of four lanes: low lane pair
// into \dl, high lane pair into \dh.
.macro pmull dl, dh, a, b
22+
smull \dl\().2d, \a\().2s, \b\().2s
23+
smull2 \dh\().2d, \a\().4s, \b\().4s
24+
.endm
25+
26+
// Multiply-accumulate variant of pmull (accumulates into \dl/\dh).
.macro pmlal dl, dh, a, b
27+
smlal \dl\().2d, \a\().2s, \b\().2s
28+
smlal2 \dh\().2d, \a\().4s, \b\().4s
29+
.endm
30+
31+
// Spill the AAPCS64 callee-saved SIMD registers (low 64 bits of v8-v15).
// NOTE(review): the routine below only writes v0-v7, so this spill is
// conservative; kept for uniformity with the other backend functions.
.macro save_vregs
32+
sub sp, sp, #(16*4)
33+
stp d8, d9, [sp, #16*0]
34+
stp d10, d11, [sp, #16*1]
35+
stp d12, d13, [sp, #16*2]
36+
stp d14, d15, [sp, #16*3]
37+
.endm
38+
39+
// Restore the registers spilled by save_vregs and pop the frame.
.macro restore_vregs
40+
ldp d8, d9, [sp, #16*0]
41+
ldp d10, d11, [sp, #16*1]
42+
ldp d12, d13, [sp, #16*2]
43+
ldp d14, d15, [sp, #16*3]
44+
add sp, sp, #(16*4)
45+
.endm
46+
47+
// Function prologue/epilogue wrappers.
.macro push_stack
48+
save_vregs
49+
.endm
50+
51+
.macro pop_stack
52+
restore_vregs
53+
.endm
54+
55+
// GPR role map (AAPCS64: x0-x2 carry the three C arguments).
out_ptr .req x0
56+
a0_ptr .req x1
57+
b0_ptr .req x2
58+
// Bases of the remaining vector entries, computed in the prologue.
a1_ptr .req x3
59+
b1_ptr .req x4
60+
a2_ptr .req x5
61+
b2_ptr .req x6
62+
a3_ptr .req x7
63+
b3_ptr .req x8
64+
// count and wtmp alias the same register (x9/w9): wtmp is only used to
// materialize constants before the loop starts.
count .req x9
65+
wtmp .req w9
66+
67+
// Broadcast constants: q and -q^-1 (mod 2^32).
modulus .req v0
68+
modulus_twisted .req v1
69+
70+
// Working registers for one batch of 4 coefficients.
aa .req v2
71+
bb .req v3
72+
res .req v4
73+
resl .req v5
74+
resh .req v6
75+
t0 .req v7
76+
77+
// Q-register views of the same registers, for 16-byte loads/stores.
q_aa .req q2
78+
q_bb .req q3
79+
q_res .req q4
80+
81+
// ---------------------------------------------------------------------------
// void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *out,
//         const int32_t *a, const int32_t *b)
// For each coefficient index i:
//   out[i] = montgomery_reduce(sum_{k=0..3} a[k*MLDSA_N+i] * b[k*MLDSA_N+i])
// Four coefficients are processed per loop iteration.
// ---------------------------------------------------------------------------
.text
82+
.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
83+
.balign 4
84+
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l4_asm)
85+
push_stack
86+
87+
// load q = 8380417
88+
movz wtmp, #57345
89+
movk wtmp, #127, lsl #16
90+
dup modulus.4s, wtmp
91+
92+
// load -q^-1 = 4236238847
93+
movz wtmp, #57343
94+
movk wtmp, #64639, lsl #16
95+
dup modulus_twisted.4s, wtmp
96+
97+
// Computed bases of vector entries
// NOTE(review): the 1024-byte stride hard-codes MLDSA_N == 256
// (256 coefficients * 4 bytes), while the loop count below uses the
// MLDSA_N macro -- keep the two consistent.
98+
add a1_ptr, a0_ptr, #(1 * 1024)
99+
add a2_ptr, a0_ptr, #(2 * 1024)
100+
add a3_ptr, a0_ptr, #(3 * 1024)
101+
102+
add b1_ptr, b0_ptr, #(1 * 1024)
103+
add b2_ptr, b0_ptr, #(2 * 1024)
104+
add b3_ptr, b0_ptr, #(3 * 1024)
105+
106+
mov count, #(MLDSA_N / 4)
107+
l4_loop_start:
108+
// Accumulate the four pointwise products into 64-bit lanes (resl/resh),
// then Montgomery-reduce back to 32 bits and store 4 results.
load_polys aa, bb, a0_ptr, b0_ptr
109+
pmull resl, resh, aa, bb
110+
load_polys aa, bb, a1_ptr, b1_ptr
111+
pmlal resl, resh, aa, bb
112+
load_polys aa, bb, a2_ptr, b2_ptr
113+
pmlal resl, resh, aa, bb
114+
load_polys aa, bb, a3_ptr, b3_ptr
115+
pmlal resl, resh, aa, bb
116+
117+
montgomery_reduce_long res, resl, resh
118+
119+
str q_res, [out_ptr], #16
120+
121+
// NOTE(review): the flags set by subs are unused (cbnz tests the
// register), so a plain sub would suffice.
subs count, count, #1
122+
cbnz count, l4_loop_start
123+
124+
pop_stack
125+
ret
126+
#endif /* MLD_ARITH_BACKEND_AARCH64 */
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/* Copyright (c) The mldsa-native project authors
2+
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
3+
*/
4+
5+
#include "../../../common.h"
6+
#if defined(MLD_ARITH_BACKEND_AARCH64)
7+
8+
// Montgomery reduction of four 64-bit products, split across two vectors:
// \inl holds products 0-1, \inh holds products 2-3 (as 64-bit lanes).
// t0 = low 32 bits of each product; t0 *= -q^-1 (mod 2^32, in
// modulus_twisted); adding t0 * q clears the low 32 bits of each product,
// and uzp2 extracts the high 32-bit halves as the reduced results.
// Clobbers t0 and modifies \inl/\inh in place.
.macro montgomery_reduce_long res, inl, inh
9+
uzp1 t0.4s, \inl\().4s, \inh\().4s
10+
mul t0.4s, t0.4s, modulus_twisted.4s
11+
smlal \inl\().2d, t0.2s, modulus.2s
12+
smlal2 \inh\().2d, t0.4s, modulus.4s
13+
uzp2 \res\().4s, \inl\().4s, \inh\().4s
14+
.endm
15+
16+
// Load 4 coefficients (16 bytes) from each of two polynomials, with
// post-increment of both pointers.
.macro load_polys a, b, a_ptr, b_ptr
17+
ldr q_\()\a, [\a_ptr], #16
18+
ldr q_\()\b, [\b_ptr], #16
19+
.endm
20+
21+
// Widening 32x32 -> 64-bit signed multiply of four lanes: low lane pair
// into \dl, high lane pair into \dh.
.macro pmull dl, dh, a, b
22+
smull \dl\().2d, \a\().2s, \b\().2s
23+
smull2 \dh\().2d, \a\().4s, \b\().4s
24+
.endm
25+
26+
// Multiply-accumulate variant of pmull (accumulates into \dl/\dh).
.macro pmlal dl, dh, a, b
27+
smlal \dl\().2d, \a\().2s, \b\().2s
28+
smlal2 \dh\().2d, \a\().4s, \b\().4s
29+
.endm
30+
31+
// Spill the AAPCS64 callee-saved SIMD registers (low 64 bits of v8-v15).
// NOTE(review): the routine below only writes v0-v7, so this spill is
// conservative; kept for uniformity with the other backend functions.
.macro save_vregs
32+
sub sp, sp, #(16*4)
33+
stp d8, d9, [sp, #16*0]
34+
stp d10, d11, [sp, #16*1]
35+
stp d12, d13, [sp, #16*2]
36+
stp d14, d15, [sp, #16*3]
37+
.endm
38+
39+
// Restore the registers spilled by save_vregs and pop the frame.
.macro restore_vregs
40+
ldp d8, d9, [sp, #16*0]
41+
ldp d10, d11, [sp, #16*1]
42+
ldp d12, d13, [sp, #16*2]
43+
ldp d14, d15, [sp, #16*3]
44+
add sp, sp, #(16*4)
45+
.endm
46+
47+
// Function prologue/epilogue wrappers.
.macro push_stack
48+
save_vregs
49+
.endm
50+
51+
.macro pop_stack
52+
restore_vregs
53+
.endm
54+
55+
// GPR role map (AAPCS64: x0-x2 carry the three C arguments).
out_ptr .req x0
56+
a0_ptr .req x1
57+
b0_ptr .req x2
58+
// Bases of the remaining vector entries, computed in the prologue.
a1_ptr .req x3
59+
b1_ptr .req x4
60+
a2_ptr .req x5
61+
b2_ptr .req x6
62+
a3_ptr .req x7
63+
b3_ptr .req x8
64+
a4_ptr .req x9
65+
b4_ptr .req x10
66+
// count and wtmp alias the same register (x11/w11; x9/x10 are taken by
// the fifth vector entry here): wtmp is only used to materialize
// constants before the loop starts.
count .req x11
67+
wtmp .req w11
68+
69+
// Broadcast constants: q and -q^-1 (mod 2^32).
modulus .req v0
70+
modulus_twisted .req v1
71+
72+
// Working registers for one batch of 4 coefficients.
aa .req v2
73+
bb .req v3
74+
res .req v4
75+
resl .req v5
76+
resh .req v6
77+
t0 .req v7
78+
79+
// Q-register views of the same registers, for 16-byte loads/stores.
q_aa .req q2
80+
q_bb .req q3
81+
q_res .req q4
82+
83+
// ---------------------------------------------------------------------------
// void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *out,
//         const int32_t *a, const int32_t *b)
// For each coefficient index i:
//   out[i] = montgomery_reduce(sum_{k=0..4} a[k*MLDSA_N+i] * b[k*MLDSA_N+i])
// Four coefficients are processed per loop iteration.
// ---------------------------------------------------------------------------
.text
84+
.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
85+
.balign 4
86+
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l5_asm)
87+
push_stack
88+
89+
// load q = 8380417
90+
movz wtmp, #57345
91+
movk wtmp, #127, lsl #16
92+
dup modulus.4s, wtmp
93+
94+
// load -q^-1 = 4236238847
95+
movz wtmp, #57343
96+
movk wtmp, #64639, lsl #16
97+
dup modulus_twisted.4s, wtmp
98+
99+
// Computed bases of vector entries
// NOTE(review): the 1024-byte stride hard-codes MLDSA_N == 256
// (256 coefficients * 4 bytes), while the loop count below uses the
// MLDSA_N macro -- keep the two consistent.
100+
add a1_ptr, a0_ptr, #(1 * 1024)
101+
add a2_ptr, a0_ptr, #(2 * 1024)
102+
add a3_ptr, a0_ptr, #(3 * 1024)
103+
add a4_ptr, a0_ptr, #(4 * 1024)
104+
105+
add b1_ptr, b0_ptr, #(1 * 1024)
106+
add b2_ptr, b0_ptr, #(2 * 1024)
107+
add b3_ptr, b0_ptr, #(3 * 1024)
108+
add b4_ptr, b0_ptr, #(4 * 1024)
109+
110+
mov count, #(MLDSA_N / 4)
111+
l5_loop_start:
112+
// Accumulate the five pointwise products into 64-bit lanes (resl/resh),
// then Montgomery-reduce back to 32 bits and store 4 results.
load_polys aa, bb, a0_ptr, b0_ptr
113+
pmull resl, resh, aa, bb
114+
load_polys aa, bb, a1_ptr, b1_ptr
115+
pmlal resl, resh, aa, bb
116+
load_polys aa, bb, a2_ptr, b2_ptr
117+
pmlal resl, resh, aa, bb
118+
load_polys aa, bb, a3_ptr, b3_ptr
119+
pmlal resl, resh, aa, bb
120+
load_polys aa, bb, a4_ptr, b4_ptr
121+
pmlal resl, resh, aa, bb
122+
123+
montgomery_reduce_long res, resl, resh
124+
125+
str q_res, [out_ptr], #16
126+
127+
// NOTE(review): the flags set by subs are unused (cbnz tests the
// register), so a plain sub would suffice.
subs count, count, #1
128+
cbnz count, l5_loop_start
129+
130+
pop_stack
131+
ret
132+
#endif /* MLD_ARITH_BACKEND_AARCH64 */

0 commit comments

Comments
 (0)