Skip to content

Add Neon mld_polyvecl_pointwise_acc_montgomery_l{4,5,7}_native #281

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions mldsa/native/aarch64/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
/* Set of primitives that this backend replaces */
#define MLD_USE_NATIVE_NTT
#define MLD_USE_NATIVE_INTT
#define MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY

/* Identifier for this backend so that source and assembly files
* in the build can be appropriately guarded. */
Expand All @@ -31,6 +32,27 @@ static MLD_INLINE void mld_intt_native(int32_t data[MLDSA_N])
mld_aarch64_intt_zetas_layer123456);
}

/* Native-backend hook for the l = 4 variant of
 * polyvecl_pointwise_acc_montgomery.
 *
 * Forwards its three arguments, unchanged and in order, to the AArch64
 * assembly routine mld_polyvecl_pointwise_acc_montgomery_l4_asm.
 *
 * w: array of MLDSA_N int32_t, non-const (presumably the result
 *    polynomial - confirm against the assembly routine)
 * u: array of 4 * MLDSA_N const int32_t
 * v: array of 4 * MLDSA_N const int32_t
 */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
int32_t w[MLDSA_N], const int32_t u[4 * MLDSA_N],
const int32_t v[4 * MLDSA_N])
{
mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, u, v);
}

/* Native-backend hook for the l = 5 variant of
 * polyvecl_pointwise_acc_montgomery.
 *
 * Forwards its three arguments, unchanged and in order, to the AArch64
 * assembly routine mld_polyvecl_pointwise_acc_montgomery_l5_asm.
 *
 * w: array of MLDSA_N int32_t, non-const (presumably the result
 *    polynomial - confirm against the assembly routine)
 * u: array of 5 * MLDSA_N const int32_t
 * v: array of 5 * MLDSA_N const int32_t
 */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
int32_t w[MLDSA_N], const int32_t u[5 * MLDSA_N],
const int32_t v[5 * MLDSA_N])
{
mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, u, v);
}

/* Native-backend hook for the l = 7 variant of
 * polyvecl_pointwise_acc_montgomery.
 *
 * Forwards its three arguments, unchanged and in order, to the AArch64
 * assembly routine mld_polyvecl_pointwise_acc_montgomery_l7_asm.
 *
 * w: array of MLDSA_N int32_t, non-const (presumably the result
 *    polynomial - confirm against the assembly routine)
 * u: array of 7 * MLDSA_N const int32_t
 * v: array of 7 * MLDSA_N const int32_t
 */
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
int32_t w[MLDSA_N], const int32_t u[7 * MLDSA_N],
const int32_t v[7 * MLDSA_N])
{
mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, u, v);
}

#endif /* !__ASSEMBLER__ */

#endif /* !MLD_NATIVE_AARCH64_META_H */
15 changes: 15 additions & 0 deletions mldsa/native/aarch64/src/arith_native_aarch64.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,19 @@ void mld_ntt_asm(int32_t *, const int32_t *, const int32_t *);
#define mld_intt_asm MLD_NAMESPACE(intt_asm)
void mld_intt_asm(int32_t *, const int32_t *, const int32_t *);

/* Assembly routine for the l = 4 pointwise multiply-accumulate.
 * The symbol name is routed through MLD_NAMESPACE.
 * Arguments are one mutable int32_t pointer followed by two const int32_t
 * pointers; the native wrapper in meta.h passes them as (w, u, v). */
#define mld_polyvecl_pointwise_acc_montgomery_l4_asm \
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *, const int32_t *,
const int32_t *);

/* Assembly routine for the l = 5 pointwise multiply-accumulate.
 * The symbol name is routed through MLD_NAMESPACE.
 * Arguments are one mutable int32_t pointer followed by two const int32_t
 * pointers; the native wrapper in meta.h passes them as (w, u, v). */
#define mld_polyvecl_pointwise_acc_montgomery_l5_asm \
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *, const int32_t *,
const int32_t *);

/* Assembly routine for the l = 7 pointwise multiply-accumulate.
 * The symbol name is routed through MLD_NAMESPACE.
 * Arguments are one mutable int32_t pointer followed by two const int32_t
 * pointers; the native wrapper in meta.h passes them as (w, u, v). */
#define mld_polyvecl_pointwise_acc_montgomery_l7_asm \
MLD_NAMESPACE(polyvecl_pointwise_acc_montgomery_l7_asm)
void mld_polyvecl_pointwise_acc_montgomery_l7_asm(int32_t *, const int32_t *,
const int32_t *);

#endif /* !MLD_NATIVE_AARCH64_SRC_ARITH_NATIVE_AARCH64_H */
126 changes: 126 additions & 0 deletions mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l4.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

/*
 * montgomery_reduce_long dst, accl, acch
 *
 * Reduces four 64-bit accumulators, held as two .2d vectors (accl holds
 * lanes 0-1, acch holds lanes 2-3), to four 32-bit results in dst.4s.
 *
 * Per lane x:
 *     t   = (low 32 bits of x) * modulus_twisted   (mod 2^32)
 *     x  += t * modulus                            (signed widening)
 *     dst = high 32 bits of x
 *
 * With modulus_twisted = -modulus^-1 mod 2^32 the low half of the updated
 * x is zero, so the high half is the Montgomery-reduced value.
 *
 * Uses the file-level aliases t0, modulus and modulus_twisted.
 * Clobbers: t0, accl, acch.
 */
.macro montgomery_reduce_long dst, accl, acch
        // Even 32-bit elements = low half of every 64-bit lane.
        uzp1    t0.4s, \accl\().4s, \acch\().4s
        mul     t0.4s, t0.4s, modulus_twisted.4s
        // t0[0..1] belong to accl, t0[2..3] belong to acch.
        smlal   \accl\().2d, t0.2s, modulus.2s
        smlal2  \acch\().2d, t0.4s, modulus.4s
        // Odd 32-bit elements = high half of every 64-bit lane.
        uzp2    \dst\().4s, \accl\().4s, \acch\().4s
.endm

/*
 * load_polys va, vb, pa, pb
 *
 * Loads 16 bytes (four 32-bit coefficients) from [pa] into q_<va> and from
 * [pb] into q_<vb>, then advances both pointers by 16 (post-index).
 * va and vb are alias stems; the q_ prefix selects the 128-bit view.
 */
.macro load_polys va, vb, pa, pb
        ldr     q_\()\va, [\pa], #16
        ldr     q_\()\vb, [\pb], #16
.endm

/*
 * pmull prodl, prodh, lhs, rhs
 *
 * Signed widening lane-wise product of two .4s vectors:
 *     prodl.2d = lhs[0..1] * rhs[0..1]
 *     prodh.2d = lhs[2..3] * rhs[2..3]
 * Overwrites prodl and prodh (starts a new accumulation).
 *
 * NOTE(review): this macro shares its name with the A64 PMULL
 * (polynomial multiply long) mnemonic; the instruction is not used here.
 */
.macro pmull prodl, prodh, lhs, rhs
        smull   \prodl\().2d, \lhs\().2s, \rhs\().2s
        smull2  \prodh\().2d, \lhs\().4s, \rhs\().4s
.endm

/*
 * pmlal accl, acch, lhs, rhs
 *
 * Signed widening lane-wise multiply-accumulate of two .4s vectors:
 *     accl.2d += lhs[0..1] * rhs[0..1]
 *     acch.2d += lhs[2..3] * rhs[2..3]
 */
.macro pmlal accl, acch, lhs, rhs
        smlal   \accl\().2d, \lhs\().2s, \rhs\().2s
        smlal2  \acch\().2d, \lhs\().4s, \rhs\().4s
.endm

/*
 * Register roles.
 *
 * Only x0-x9 and v0-v7 are used. All of them are caller-saved under
 * AAPCS64, so this routine needs no stack frame and no save/restore of
 * d8-d15.
 */
out_ptr         .req x0         // 1st argument: result pointer
a0_ptr          .req x1         // 2nd argument: first operand vector
b0_ptr          .req x2         // 3rd argument: second operand vector
a1_ptr          .req x3
b1_ptr          .req x4
a2_ptr          .req x5
b2_ptr          .req x6
a3_ptr          .req x7
b3_ptr          .req x8
count           .req x9         // loop counter
wtmp            .req w9         // constant scratch; shares x9 with count,
                                // only live before count is initialised

modulus         .req v0         // q in every 32-bit lane
modulus_twisted .req v1         // -q^-1 mod 2^32 in every 32-bit lane

aa              .req v2
bb              .req v3
res             .req v4
resl            .req v5         // 64-bit accumulators, lanes 0-1
resh            .req v6         // 64-bit accumulators, lanes 2-3
t0              .req v7         // scratch of montgomery_reduce_long

q_aa            .req q2
q_bb            .req q3
q_res           .req q4

/*
 * C prototype (arith_native_aarch64.h):
 *   void mld_polyvecl_pointwise_acc_montgomery_l4_asm(int32_t *,
 *                                                     const int32_t *,
 *                                                     const int32_t *);
 *
 * In:  x0 = out, x1 = a, x2 = b
 *      a and b are read as 4 consecutive polynomials of MLDSA_N int32
 *      coefficients each; out receives MLDSA_N int32 coefficients.
 * Out: out[i] = montgomery_reduce(a0[i]*b0[i] + a1[i]*b1[i]
 *                                 + a2[i]*b2[i] + a3[i]*b3[i])
 *      with the sum of products kept in 64 bits before reduction.
 * Clobbers: x0-x9, v0-v7. No stack use; NZCV is left untouched.
 */
.text
.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l4_asm)
.balign 4
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l4_asm)
        // load q = 8380417
        movz    wtmp, #57345
        movk    wtmp, #127, lsl #16
        dup     modulus.4s, wtmp

        // load -q^-1 = 4236238847
        movz    wtmp, #57343
        movk    wtmp, #64639, lsl #16
        dup     modulus_twisted.4s, wtmp

        // Base of each vector entry. One polynomial occupies
        // MLDSA_N * 4 bytes (MLDSA_N int32 coefficients).
        add     a1_ptr, a0_ptr, #(1 * MLDSA_N * 4)
        add     a2_ptr, a0_ptr, #(2 * MLDSA_N * 4)
        add     a3_ptr, a0_ptr, #(3 * MLDSA_N * 4)

        add     b1_ptr, b0_ptr, #(1 * MLDSA_N * 4)
        add     b2_ptr, b0_ptr, #(2 * MLDSA_N * 4)
        add     b3_ptr, b0_ptr, #(3 * MLDSA_N * 4)

        // Four coefficients are produced per iteration.
        mov     count, #(MLDSA_N / 4)
l4_loop_start:
        load_polys aa, bb, a0_ptr, b0_ptr
        pmull      resl, resh, aa, bb
        load_polys aa, bb, a1_ptr, b1_ptr
        pmlal      resl, resh, aa, bb
        load_polys aa, bb, a2_ptr, b2_ptr
        pmlal      resl, resh, aa, bb
        load_polys aa, bb, a3_ptr, b3_ptr
        pmlal      resl, resh, aa, bb

        montgomery_reduce_long res, resl, resh

        str     q_res, [out_ptr], #16

        // cbnz tests the register directly, so the flag-setting form of
        // the subtraction is not needed.
        sub     count, count, #1
        cbnz    count, l4_loop_start

        ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
132 changes: 132 additions & 0 deletions mldsa/native/aarch64/src/mld_polyvecl_pointwise_acc_montgomery_l5.S
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/* Copyright (c) The mldsa-native project authors
* SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT
*/

#include "../../../common.h"
#if defined(MLD_ARITH_BACKEND_AARCH64)

/*
 * montgomery_reduce_long dst, accl, acch
 *
 * Reduces four 64-bit accumulators, held as two .2d vectors (accl holds
 * lanes 0-1, acch holds lanes 2-3), to four 32-bit results in dst.4s.
 *
 * Per lane x:
 *     t   = (low 32 bits of x) * modulus_twisted   (mod 2^32)
 *     x  += t * modulus                            (signed widening)
 *     dst = high 32 bits of x
 *
 * With modulus_twisted = -modulus^-1 mod 2^32 the low half of the updated
 * x is zero, so the high half is the Montgomery-reduced value.
 *
 * Uses the file-level aliases t0, modulus and modulus_twisted.
 * Clobbers: t0, accl, acch.
 */
.macro montgomery_reduce_long dst, accl, acch
        // Even 32-bit elements = low half of every 64-bit lane.
        uzp1    t0.4s, \accl\().4s, \acch\().4s
        mul     t0.4s, t0.4s, modulus_twisted.4s
        // t0[0..1] belong to accl, t0[2..3] belong to acch.
        smlal   \accl\().2d, t0.2s, modulus.2s
        smlal2  \acch\().2d, t0.4s, modulus.4s
        // Odd 32-bit elements = high half of every 64-bit lane.
        uzp2    \dst\().4s, \accl\().4s, \acch\().4s
.endm

/*
 * load_polys va, vb, pa, pb
 *
 * Loads 16 bytes (four 32-bit coefficients) from [pa] into q_<va> and from
 * [pb] into q_<vb>, then advances both pointers by 16 (post-index).
 * va and vb are alias stems; the q_ prefix selects the 128-bit view.
 */
.macro load_polys va, vb, pa, pb
        ldr     q_\()\va, [\pa], #16
        ldr     q_\()\vb, [\pb], #16
.endm

/*
 * pmull prodl, prodh, lhs, rhs
 *
 * Signed widening lane-wise product of two .4s vectors:
 *     prodl.2d = lhs[0..1] * rhs[0..1]
 *     prodh.2d = lhs[2..3] * rhs[2..3]
 * Overwrites prodl and prodh (starts a new accumulation).
 *
 * NOTE(review): this macro shares its name with the A64 PMULL
 * (polynomial multiply long) mnemonic; the instruction is not used here.
 */
.macro pmull prodl, prodh, lhs, rhs
        smull   \prodl\().2d, \lhs\().2s, \rhs\().2s
        smull2  \prodh\().2d, \lhs\().4s, \rhs\().4s
.endm

/*
 * pmlal accl, acch, lhs, rhs
 *
 * Signed widening lane-wise multiply-accumulate of two .4s vectors:
 *     accl.2d += lhs[0..1] * rhs[0..1]
 *     acch.2d += lhs[2..3] * rhs[2..3]
 */
.macro pmlal accl, acch, lhs, rhs
        smlal   \accl\().2d, \lhs\().2s, \rhs\().2s
        smlal2  \acch\().2d, \lhs\().4s, \rhs\().4s
.endm

/*
 * Register roles.
 *
 * Only x0-x11 and v0-v7 are used. All of them are caller-saved under
 * AAPCS64, so this routine needs no stack frame and no save/restore of
 * d8-d15.
 */
out_ptr         .req x0         // 1st argument: result pointer
a0_ptr          .req x1         // 2nd argument: first operand vector
b0_ptr          .req x2         // 3rd argument: second operand vector
a1_ptr          .req x3
b1_ptr          .req x4
a2_ptr          .req x5
b2_ptr          .req x6
a3_ptr          .req x7
b3_ptr          .req x8
a4_ptr          .req x9
b4_ptr          .req x10
count           .req x11        // loop counter
wtmp            .req w11        // constant scratch; shares x11 with count,
                                // only live before count is initialised

modulus         .req v0         // q in every 32-bit lane
modulus_twisted .req v1         // -q^-1 mod 2^32 in every 32-bit lane

aa              .req v2
bb              .req v3
res             .req v4
resl            .req v5         // 64-bit accumulators, lanes 0-1
resh            .req v6         // 64-bit accumulators, lanes 2-3
t0              .req v7         // scratch of montgomery_reduce_long

q_aa            .req q2
q_bb            .req q3
q_res           .req q4

/*
 * C prototype (arith_native_aarch64.h):
 *   void mld_polyvecl_pointwise_acc_montgomery_l5_asm(int32_t *,
 *                                                     const int32_t *,
 *                                                     const int32_t *);
 *
 * In:  x0 = out, x1 = a, x2 = b
 *      a and b are read as 5 consecutive polynomials of MLDSA_N int32
 *      coefficients each; out receives MLDSA_N int32 coefficients.
 * Out: out[i] = montgomery_reduce(a0[i]*b0[i] + a1[i]*b1[i] + a2[i]*b2[i]
 *                                 + a3[i]*b3[i] + a4[i]*b4[i])
 *      with the sum of products kept in 64 bits before reduction.
 * Clobbers: x0-x11, v0-v7. No stack use; NZCV is left untouched.
 */
.text
.global MLD_ASM_NAMESPACE(polyvecl_pointwise_acc_montgomery_l5_asm)
.balign 4
MLD_ASM_FN_SYMBOL(polyvecl_pointwise_acc_montgomery_l5_asm)
        // load q = 8380417
        movz    wtmp, #57345
        movk    wtmp, #127, lsl #16
        dup     modulus.4s, wtmp

        // load -q^-1 = 4236238847
        movz    wtmp, #57343
        movk    wtmp, #64639, lsl #16
        dup     modulus_twisted.4s, wtmp

        // Base of each vector entry. One polynomial occupies
        // MLDSA_N * 4 bytes (MLDSA_N int32 coefficients).
        add     a1_ptr, a0_ptr, #(1 * MLDSA_N * 4)
        add     a2_ptr, a0_ptr, #(2 * MLDSA_N * 4)
        add     a3_ptr, a0_ptr, #(3 * MLDSA_N * 4)
        add     a4_ptr, a0_ptr, #(4 * MLDSA_N * 4)

        add     b1_ptr, b0_ptr, #(1 * MLDSA_N * 4)
        add     b2_ptr, b0_ptr, #(2 * MLDSA_N * 4)
        add     b3_ptr, b0_ptr, #(3 * MLDSA_N * 4)
        add     b4_ptr, b0_ptr, #(4 * MLDSA_N * 4)

        // Four coefficients are produced per iteration.
        mov     count, #(MLDSA_N / 4)
l5_loop_start:
        load_polys aa, bb, a0_ptr, b0_ptr
        pmull      resl, resh, aa, bb
        load_polys aa, bb, a1_ptr, b1_ptr
        pmlal      resl, resh, aa, bb
        load_polys aa, bb, a2_ptr, b2_ptr
        pmlal      resl, resh, aa, bb
        load_polys aa, bb, a3_ptr, b3_ptr
        pmlal      resl, resh, aa, bb
        load_polys aa, bb, a4_ptr, b4_ptr
        pmlal      resl, resh, aa, bb

        montgomery_reduce_long res, resl, resh

        str     q_res, [out_ptr], #16

        // cbnz tests the register directly, so the flag-setting form of
        // the subtraction is not needed.
        sub     count, count, #1
        cbnz    count, l5_loop_start

        ret
#endif /* MLD_ARITH_BACKEND_AARCH64 */
Loading
Loading