Skip to content

Commit 804a2b8

Browse files
mkannwischerhanno-becker
authored andcommitted
native chknorm: 0/0xFFFFFFFF -> 0/1
Previously the native chknorm functions would follow the C implementation, i.e., return 0 if all coefficients are within bound and 0xFFFFFFFF otherwise. This leads to problems with the run-time dispatch in #607 as we want to use -1 to signal that the current platforms lacks the required capabilities to run the native code, and we should fall back to the C implementation. This commit changes the backend API to return 1 in the failure mode. This will allow to implement the run-time dispatch. Signed-off-by: Matthias J. Kannwischer <[email protected]>
1 parent c4d64c6 commit 804a2b8

File tree

14 files changed

+22
-17
lines changed

14 files changed

+22
-17
lines changed

dev/aarch64_clean/meta.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ static MLD_INLINE void mld_poly_use_hint_88_native(int32_t *b, const int32_t *a,
136136
mld_poly_use_hint_88_asm(b, a, h);
137137
}
138138

139-
static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B)
139+
static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
140140
{
141141
return mld_poly_chknorm_asm(a, B);
142142
}

dev/aarch64_clean/src/arith_native_aarch64.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h);
8383
void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h);
8484

8585
#define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm)
86-
uint32_t mld_poly_chknorm_asm(const int32_t *a, int32_t B);
86+
int mld_poly_chknorm_asm(const int32_t *a, int32_t B);
8787

8888
#define mld_polyz_unpack_17_asm MLD_NAMESPACE(polyz_unpack_17_asm)
8989
void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf,

dev/aarch64_clean/src/poly_chknorm_asm.S

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ poly_chknorm_loop:
4848
subs count, count, #1
4949
bne poly_chknorm_loop
5050

51-
// Return 0xffffffff if any of the 4 lanes is 0xffffffff
51+
// Return 1 if any of the 4 lanes is 0xffffffff
5252
umaxv s21, flags.4s
5353
fmov w0, s21
54+
and w0, w0, #1
5455

5556
ret
5657

dev/x86_64/meta.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ static MLD_INLINE void mld_poly_use_hint_88_native(int32_t *b, const int32_t *a,
143143
(const __m256i *)h);
144144
}
145145

146-
static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B)
146+
static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
147147
{
148148
return mld_poly_chknorm_avx2((const __m256i *)a, B);
149149
}

dev/x86_64/src/arith_native_x86_64.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void mld_poly_use_hint_32_avx2(__m256i *b, const __m256i *a, const __m256i *h);
7070
void mld_poly_use_hint_88_avx2(__m256i *b, const __m256i *a, const __m256i *h);
7171

7272
#define mld_poly_chknorm_avx2 MLD_NAMESPACE(mld_poly_chknorm_avx2)
73-
uint32_t mld_poly_chknorm_avx2(const __m256i *a, int32_t B);
73+
int mld_poly_chknorm_avx2(const __m256i *a, int32_t B);
7474

7575
#define mld_polyz_unpack_17_avx2 MLD_NAMESPACE(mld_polyz_unpack_17_avx2)
7676
void mld_polyz_unpack_17_avx2(__m256i *r, const uint8_t *a);

dev/x86_64/src/poly_chknorm_avx2.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
#include <stdint.h>
2727
#include "arith_native_x86_64.h"
2828

29-
uint32_t mld_poly_chknorm_avx2(const __m256i *a, int32_t B)
29+
int mld_poly_chknorm_avx2(const __m256i *a, int32_t B)
3030
{
3131
unsigned int i;
3232
__m256i f, t;
@@ -41,7 +41,7 @@ uint32_t mld_poly_chknorm_avx2(const __m256i *a, int32_t B)
4141
t = _mm256_or_si256(t, f);
4242
}
4343

44-
return (uint32_t)(_mm256_testz_si256(t, t) - 1);
44+
return 1 - _mm256_testz_si256(t, t);
4545
}
4646

4747
#else /* MLD_ARITH_BACKEND_X86_64_DEFAULT && !MLD_CONFIG_MULTILEVEL_NO_SHARED \

mldsa/src/native/aarch64/meta.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ static MLD_INLINE void mld_poly_use_hint_88_native(int32_t *b, const int32_t *a,
136136
mld_poly_use_hint_88_asm(b, a, h);
137137
}
138138

139-
static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B)
139+
static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
140140
{
141141
return mld_poly_chknorm_asm(a, B);
142142
}

mldsa/src/native/aarch64/src/arith_native_aarch64.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ void mld_poly_use_hint_32_asm(int32_t *b, const int32_t *a, const int32_t *h);
8383
void mld_poly_use_hint_88_asm(int32_t *b, const int32_t *a, const int32_t *h);
8484

8585
#define mld_poly_chknorm_asm MLD_NAMESPACE(poly_chknorm_asm)
86-
uint32_t mld_poly_chknorm_asm(const int32_t *a, int32_t B);
86+
int mld_poly_chknorm_asm(const int32_t *a, int32_t B);
8787

8888
#define mld_polyz_unpack_17_asm MLD_NAMESPACE(polyz_unpack_17_asm)
8989
void mld_polyz_unpack_17_asm(int32_t *r, const uint8_t *buf,

mldsa/src/native/aarch64/src/poly_chknorm_asm.S

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Lpoly_chknorm_loop:
4343
b.ne Lpoly_chknorm_loop
4444
umaxv s21, v21.4s
4545
fmov w0, s21
46+
and w0, w0, #0x1
4647
ret
4748
.cfi_endproc
4849

mldsa/src/native/api.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,10 @@ static MLD_INLINE void mld_poly_use_hint_88_native(int32_t *b, const int32_t *a,
264264
* Arguments: - const int32_t *a: pointer to polynomial
265265
* - int32_t B: norm bound
266266
*
267-
* Returns 0 if the infinity norm is strictly smaller than B, and 0xFFFFFFFF
267+
* Returns 0 if the infinity norm is strictly smaller than B, and 1
268268
* otherwise. B must not be larger than MLDSA_Q - REDUCE32_RANGE_MAX.
269269
**************************************************/
270-
static MLD_INLINE uint32_t mld_poly_chknorm_native(const int32_t *a, int32_t B);
270+
static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B);
271271
#endif /* MLD_USE_NATIVE_POLY_CHKNORM */
272272

273273
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_17)

0 commit comments

Comments
 (0)