Skip to content

Commit 7301e50

Browse files
authored
Fboemer/ntt fix (#58)
* Fix NTT AVX512 implementation
1 parent bb4b6a6 commit 7301e50

21 files changed

+229
-122
lines changed

CHANGES.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
## Version 1.2.1
44
- Fixes a bug in AVX512 floating-point implementation of element-wise vector-vector modular multiplication (https://github.com/microsoft/SEAL/issues/385)
5-
- Fixes a bug in the NTT default allocator (https://gitlab.com/palisade/palisade-development/-/issues/323#note_662270512)
5+
- Fixes a bug in the NTT default constructor (https://gitlab.com/palisade/palisade-development/-/issues/329)
6+
- Fixes a bug in the AVX512 NTT (https://github.com/intel/hexl/pull/58)
67
- Improves performance of EltwiseFMAModAVX512 on ICX (https://github.com/intel/hexl/pull/42)
78
- Improves performance of the native NTT
89
- Adds reference implementations for the radix-4 NTT
10+
- Enables support for pre-built easylogging (https://github.com/intel/hexl/pull/57)
911

1012
## Version 1.2.0
1113
- Large performance improvement in large (N >= 16384) AVX512 NTTs via recursive implementations

benchmark/bench-eltwise-fma-mod.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ static void BM_EltwiseFMAModAddNative(benchmark::State& state) { // NOLINT
3535

3636
BENCHMARK(BM_EltwiseFMAModAddNative)
3737
->Unit(benchmark::kMicrosecond)
38-
->ArgsProduct({{1024, 8192, 16384}, {false, true}});
38+
->ArgsProduct({{1024, 4096, 16384}, {false, true}});
3939

4040
//=================================================================
4141

@@ -59,7 +59,7 @@ static void BM_EltwiseFMAModAVX512DQ(benchmark::State& state) { // NOLINT
5959

6060
BENCHMARK(BM_EltwiseFMAModAVX512DQ)
6161
->Unit(benchmark::kMicrosecond)
62-
->ArgsProduct({{1024, 8192, 16384}, {false, true}});
62+
->ArgsProduct({{1024, 4096, 16384}, {false, true}});
6363
#endif
6464

6565
//=================================================================
@@ -84,7 +84,7 @@ static void BM_EltwiseFMAModAVX512IFMA(benchmark::State& state) { // NOLINT
8484

8585
BENCHMARK(BM_EltwiseFMAModAVX512IFMA)
8686
->Unit(benchmark::kMicrosecond)
87-
->ArgsProduct({{1024, 8192, 16384}, {false, true}});
87+
->ArgsProduct({{1024, 4096, 16384}, {false, true}});
8888

8989
#endif
9090

benchmark/bench-eltwise-mult-mod.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static void BM_EltwiseMultMod(benchmark::State& state) { // NOLINT
3636

3737
BENCHMARK(BM_EltwiseMultMod)
3838
->Unit(benchmark::kMicrosecond)
39-
->ArgsProduct({{1024, 8192, 16384}, {48, 60}, {1, 2, 4}});
39+
->ArgsProduct({{1024, 4096, 16384}, {48, 60}, {1, 2, 4}});
4040

4141
//=================================================================
4242

benchmark/bench-ntt.cpp

+15-15
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace hexl {
2222

2323
static void BM_FwdNTTNativeRadix2(benchmark::State& state) { // NOLINT
2424
size_t ntt_size = state.range(0);
25-
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
25+
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];
2626

2727
AlignedVector64<uint64_t> input(ntt_size, 1);
2828
NTT ntt(ntt_size, modulus);
@@ -43,7 +43,7 @@ BENCHMARK(BM_FwdNTTNativeRadix2)
4343

4444
static void BM_FwdNTTNativeRadix4(benchmark::State& state) { // NOLINT
4545
size_t ntt_size = state.range(0);
46-
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
46+
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];
4747

4848
AlignedVector64<uint64_t> input(ntt_size, 1);
4949
NTT ntt(ntt_size, modulus);
@@ -67,7 +67,7 @@ BENCHMARK(BM_FwdNTTNativeRadix4)
6767
static void BM_FwdNTT_AVX512IFMA(benchmark::State& state) { // NOLINT
6868
size_t ntt_size = state.range(0);
6969
size_t modulus_bits = 49;
70-
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
70+
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];
7171

7272
AlignedVector64<uint64_t> input(ntt_size, 1);
7373
NTT ntt(ntt_size, modulus);
@@ -96,7 +96,7 @@ BENCHMARK(BM_FwdNTT_AVX512IFMA)
9696
static void BM_FwdNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT
9797
size_t ntt_size = state.range(0);
9898
size_t modulus_bits = 49;
99-
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
99+
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];
100100

101101
AlignedVector64<uint64_t> input(ntt_size, 1);
102102
NTT ntt(ntt_size, modulus);
@@ -132,7 +132,7 @@ static void BM_FwdNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT
132132
size_t ntt_size = state.range(0);
133133
uint64_t output_mod_factor = state.range(1);
134134
size_t modulus_bits = 29;
135-
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
135+
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];
136136

137137
AlignedVector64<uint64_t> input(ntt_size, 1);
138138
NTT ntt(ntt_size, modulus);
@@ -163,7 +163,7 @@ static void BM_FwdNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT
163163
size_t ntt_size = state.range(0);
164164
uint64_t output_mod_factor = state.range(1);
165165
size_t modulus_bits = 55;
166-
size_t modulus = GeneratePrimes(1, modulus_bits, ntt_size)[0];
166+
size_t modulus = GeneratePrimes(1, modulus_bits, true, ntt_size)[0];
167167

168168
AlignedVector64<uint64_t> input(ntt_size, 1);
169169
NTT ntt(ntt_size, modulus);
@@ -195,7 +195,7 @@ BENCHMARK(BM_FwdNTT_AVX512DQ_64)
195195
// state[0] is the degree
196196
static void BM_FwdNTTInPlace(benchmark::State& state) { // NOLINT
197197
size_t ntt_size = state.range(0);
198-
size_t modulus = GeneratePrimes(1, 61, ntt_size)[0];
198+
size_t modulus = GeneratePrimes(1, 61, true, ntt_size)[0];
199199

200200
AlignedVector64<uint64_t> input(ntt_size, 1);
201201
NTT ntt(ntt_size, modulus);
@@ -216,7 +216,7 @@ BENCHMARK(BM_FwdNTTInPlace)
216216
// state[0] is the degree
217217
static void BM_FwdNTTCopy(benchmark::State& state) { // NOLINT
218218
size_t ntt_size = state.range(0);
219-
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
219+
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];
220220

221221
AlignedVector64<uint64_t> input(ntt_size, 1);
222222
AlignedVector64<uint64_t> output(ntt_size, 1);
@@ -236,7 +236,7 @@ BENCHMARK(BM_FwdNTTCopy)
236236
// state[0] is the degree
237237
static void BM_InvNTTCopy(benchmark::State& state) { // NOLINT
238238
size_t ntt_size = state.range(0);
239-
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
239+
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];
240240

241241
AlignedVector64<uint64_t> input(ntt_size, 1);
242242
AlignedVector64<uint64_t> output(ntt_size, 1);
@@ -259,7 +259,7 @@ BENCHMARK(BM_InvNTTCopy)
259259

260260
static void BM_InvNTTNativeRadix2(benchmark::State& state) { // NOLINT
261261
size_t ntt_size = state.range(0);
262-
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
262+
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];
263263

264264
AlignedVector64<uint64_t> input(ntt_size, 1);
265265
NTT ntt(ntt_size, modulus);
@@ -284,7 +284,7 @@ BENCHMARK(BM_InvNTTNativeRadix2)
284284

285285
static void BM_InvNTTNativeRadix4(benchmark::State& state) { // NOLINT
286286
size_t ntt_size = state.range(0);
287-
size_t modulus = GeneratePrimes(1, 45, ntt_size)[0];
287+
size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0];
288288

289289
AlignedVector64<uint64_t> input(ntt_size, 1);
290290
NTT ntt(ntt_size, modulus);
@@ -311,7 +311,7 @@ BENCHMARK(BM_InvNTTNativeRadix4)
311311
// state[0] is the degree
312312
static void BM_InvNTT_AVX512IFMA(benchmark::State& state) { // NOLINT
313313
size_t ntt_size = state.range(0);
314-
size_t modulus = GeneratePrimes(1, 49, ntt_size)[0];
314+
size_t modulus = GeneratePrimes(1, 49, true, ntt_size)[0];
315315

316316
AlignedVector64<uint64_t> input(ntt_size, 1);
317317
NTT ntt(ntt_size, modulus);
@@ -337,7 +337,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMA)
337337
// state[0] is the degree
338338
static void BM_InvNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT
339339
size_t ntt_size = state.range(0);
340-
size_t modulus = GeneratePrimes(1, 49, ntt_size)[0];
340+
size_t modulus = GeneratePrimes(1, 49, true, ntt_size)[0];
341341

342342
AlignedVector64<uint64_t> input(ntt_size, 1);
343343
NTT ntt(ntt_size, modulus);
@@ -367,7 +367,7 @@ BENCHMARK(BM_InvNTT_AVX512IFMALazy)
367367
static void BM_InvNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT
368368
size_t ntt_size = state.range(0);
369369
uint64_t output_mod_factor = state.range(1);
370-
size_t modulus = GeneratePrimes(1, 30, ntt_size)[0];
370+
size_t modulus = GeneratePrimes(1, 30, true, ntt_size)[0];
371371

372372
AlignedVector64<uint64_t> input(ntt_size, 1);
373373
NTT ntt(ntt_size, modulus);
@@ -395,7 +395,7 @@ BENCHMARK(BM_InvNTT_AVX512DQ_32)
395395
static void BM_InvNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT
396396
size_t ntt_size = state.range(0);
397397
uint64_t output_mod_factor = state.range(1);
398-
size_t modulus = GeneratePrimes(1, 61, ntt_size)[0];
398+
size_t modulus = GeneratePrimes(1, 61, true, ntt_size)[0];
399399

400400
AlignedVector64<uint64_t> input(ntt_size, 1);
401401
NTT ntt(ntt_size, modulus);

hexl/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ endif()
7676
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
7777
target_compile_options(hexl PRIVATE -Wall -Wconversion -Wshadow -pedantic -Wextra
7878
-Wno-unknown-pragmas -march=native -O3 -fomit-frame-pointer
79+
-Wno-sign-conversion
80+
-Wno-implicit-int-conversion
7981
)
8082
# Avoid 3rd-party dependency warnings when including HEXL as a dependency
8183
target_compile_options(hexl PUBLIC

hexl/include/hexl/number-theory/number-theory.hpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,17 @@ inline unsigned char AddUInt64(uint64_t operand1, uint64_t operand2,
176176
/// @brief Returns whether or not the input is prime
177177
bool IsPrime(uint64_t n);
178178

179-
/// @brief Generates a list of num_primes primes in the range [2^(bit_size,
179+
/// @brief Generates a list of num_primes primes in the range [2^(bit_size),
180180
// 2^(bit_size+1)]. Ensures each prime q satisfies
181181
// q % (2*ntt_size) == 1
182182
/// @param[in] num_primes Number of primes to generate
183183
/// @param[in] bit_size Bit size of each prime
184+
/// @param[in] prefer_small_primes When true, returns primes starting from
185+
/// 2^(bit_size); when false, returns primes starting from 2^(bit_size+1)
184186
/// @param[in] ntt_size N such that each prime q satisfies q % (2N) == 1. N must
185-
/// be a power of two
187+
/// be a power of two less than 2^bit_size.
186188
std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
189+
bool prefer_small_primes,
187190
size_t ntt_size = 1);
188191

189192
/// @brief Returns input mod modulus, computed via 64-bit Barrett reduction

hexl/ntt/fwd-ntt-avx512.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ void ForwardTransformToBitReverseAVX512(
266266
const uint64_t* W = &root_of_unity_powers[W_idx];
267267
const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx];
268268

269-
if (input_mod_factor <= 2) {
269+
if ((input_mod_factor <= 2) && (recursion_depth == 0)) {
270270
FwdT8<BitShift, true>(operand, v_neg_modulus, v_twice_mod, t, m, W,
271271
W_precon);
272272
} else {

hexl/ntt/inv-ntt-avx512.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ void InverseTransformFromBitReverseAVX512(
260260
// t = 1
261261
const uint64_t* W = &inv_root_of_unity_powers[W_idx];
262262
const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx];
263-
if (input_mod_factor == 1) {
263+
if ((input_mod_factor == 1) && (recursion_depth == 0)) {
264264
InvT1<BitShift, true>(operand, v_neg_modulus, v_twice_mod, m, W,
265265
W_precon);
266266
} else {

hexl/number-theory/number-theory.cpp

+27-5
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ bool IsPrime(uint64_t n) {
223223
}
224224

225225
std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
226+
bool prefer_small_primes,
226227
size_t ntt_size) {
227228
HEXL_CHECK(num_primes > 0, "num_primes == 0");
228229
HEXL_CHECK(IsPowerOfTwo(ntt_size),
@@ -231,18 +232,39 @@ std::vector<uint64_t> GeneratePrimes(size_t num_primes, size_t bit_size,
231232
"log2(ntt_size) " << Log2(ntt_size)
232233
<< " should be less than bit_size " << bit_size);
233234

234-
uint64_t value = (1ULL << bit_size) + 1;
235+
int64_t prime_lower_bound = (1LL << bit_size) + 1LL;
236+
int64_t prime_upper_bound = (1LL << (bit_size + 1LL)) - 1LL;
237+
238+
// Keep signed to enable negative step
239+
int64_t prime_candidate =
240+
prefer_small_primes
241+
? prime_lower_bound
242+
: prime_upper_bound - (prime_upper_bound % (2 * ntt_size)) + 1;
243+
HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate");
244+
245+
// Ensure prime % 2 * ntt_size == 1
246+
int64_t prime_candidate_step =
247+
(prefer_small_primes ? 1 : -1) * 2 * static_cast<int64_t>(ntt_size);
248+
249+
auto continue_condition = [&](int64_t local_candidate_prime) {
250+
if (prefer_small_primes) {
251+
return local_candidate_prime < prime_upper_bound;
252+
} else {
253+
return local_candidate_prime > prime_lower_bound;
254+
}
255+
};
235256

236257
std::vector<uint64_t> ret;
237258

238-
while (value < (1ULL << (bit_size + 1))) {
239-
if (IsPrime(value)) {
240-
ret.emplace_back(value);
259+
while (continue_condition(prime_candidate)) {
260+
if (IsPrime(prime_candidate)) {
261+
HEXL_CHECK(prime_candidate % (2 * ntt_size) == 1, "bad prime candidate");
262+
ret.emplace_back(static_cast<uint64_t>(prime_candidate));
241263
if (ret.size() == num_primes) {
242264
return ret;
243265
}
244266
}
245-
value += 2 * ntt_size;
267+
prime_candidate += prime_candidate_step;
246268
}
247269

248270
HEXL_CHECK(false, "Failed to find enough primes");

test/test-eltwise-add-mod-avx512.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ TEST(EltwiseAddMod, vector_vector_avx512_big) {
5252
GTEST_SKIP();
5353
}
5454

55-
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
55+
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];
5656

5757
std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
5858
modulus - 3, modulus - 3, modulus - 4, modulus - 4};
@@ -72,7 +72,7 @@ TEST(EltwiseAddMod, vector_scalar_avx512_big) {
7272
GTEST_SKIP();
7373
}
7474

75-
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
75+
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];
7676

7777
std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
7878
modulus - 3, modulus - 3, modulus - 4, modulus - 4};

test/test-eltwise-add-mod.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ TEST(EltwiseAddMod, vector_scalar_native_small) {
8181
}
8282

8383
TEST(EltwiseAddMod, vector_vector_native_big) {
84-
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
84+
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];
8585

8686
std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
8787
modulus - 3, modulus - 3, modulus - 4, modulus - 4};
@@ -97,7 +97,7 @@ TEST(EltwiseAddMod, vector_vector_native_big) {
9797
}
9898

9999
TEST(EltwiseAddMod, vector_scalar_native_big) {
100-
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
100+
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];
101101

102102
std::vector<uint64_t> op1{modulus - 1, modulus - 1, modulus - 2, modulus - 2,
103103
modulus - 3, modulus - 3, modulus - 4, modulus - 4};

test/test-eltwise-cmp-sub-mod-avx512.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ TEST(EltwiseCmpSubMod, AVX512) {
3131

3232
for (size_t cmp = 0; cmp < 8; ++cmp) {
3333
for (size_t bits = 48; bits <= 51; ++bits) {
34-
uint64_t modulus = GeneratePrimes(1, bits, 1024)[0];
34+
uint64_t modulus = GeneratePrimes(1, bits, true, 1024)[0];
3535
std::uniform_int_distribution<uint64_t> distrib(0, modulus - 1);
3636

3737
for (size_t trial = 0; trial < 200; ++trial) {

test/test-eltwise-fma-mod-avx512.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ TEST(EltwiseFMAMod, AVX512IFMA) {
229229
constexpr uint64_t input_mod_factor = 8;
230230

231231
for (size_t bits = 48; bits <= 51; ++bits) {
232-
uint64_t modulus = GeneratePrimes(1, bits, length)[0];
232+
uint64_t modulus = GeneratePrimes(1, bits, true, length)[0];
233233
std::uniform_int_distribution<uint64_t> distrib(
234234
0, input_mod_factor * modulus - 1);
235235

test/test-eltwise-mult-mod-avx512.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ TEST(EltwiseMultMod, avx512_int2) {
3939
if (!has_avx512dq) {
4040
GTEST_SKIP();
4141
}
42-
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
42+
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];
4343

4444
std::vector<uint64_t> op1{modulus - 3, 1, 1, 1, 1, 1, 1, 1};
4545
std::vector<uint64_t> op2{modulus - 4, 1, 1, 1, 1, 1, 1, 1};

test/test-eltwise-mult-mod.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ TEST(EltwiseMultModInPlace, 8_bounds) {
7777
#endif
7878

7979
TEST(EltwiseMultModInPlace, 9) {
80-
uint64_t modulus = GeneratePrimes(1, 51, 1024)[0];
80+
uint64_t modulus = GeneratePrimes(1, 51, true, 1024)[0];
8181

8282
std::vector<uint64_t> op1{modulus - 3, 1, 2, 3, 4, 5, 6, 7, 8};
8383
std::vector<uint64_t> op2{modulus - 4, 8, 7, 6, 5, 4, 3, 2, 1};
@@ -105,7 +105,7 @@ TEST(EltwiseMultMod, native_mult2) {
105105
}
106106

107107
TEST(EltwiseMultMod, native2_big) {
108-
uint64_t modulus = GeneratePrimes(1, 60, 1024)[0];
108+
uint64_t modulus = GeneratePrimes(1, 60, true, 1024)[0];
109109

110110
std::vector<uint64_t> op1{modulus - 3, 1, 1, 1, 1, 1, 1, 1};
111111
std::vector<uint64_t> op2{modulus - 4, 1, 1, 1, 1, 1, 1, 1};
@@ -119,7 +119,7 @@ TEST(EltwiseMultMod, native2_big) {
119119
}
120120

121121
TEST(EltwiseMultMod, 8big) {
122-
uint64_t modulus = GeneratePrimes(1, 48, 1024)[0];
122+
uint64_t modulus = GeneratePrimes(1, 48, true, 1024)[0];
123123

124124
std::vector<uint64_t> op1{modulus - 1, 1, 1, 1, 1, 1, 1, 1};
125125
std::vector<uint64_t> op2{modulus - 1, 1, 1, 1, 1, 1, 1, 1};
@@ -198,7 +198,7 @@ TEST(EltwiseMultMod, 8_bounds) {
198198
#endif
199199

200200
TEST(EltwiseMultMod, 9) {
201-
uint64_t modulus = GeneratePrimes(1, 51, 1024)[0];
201+
uint64_t modulus = GeneratePrimes(1, 51, true, 1024)[0];
202202

203203
std::vector<uint64_t> op1{modulus - 3, 1, 2, 3, 4, 5, 6, 7, 8};
204204
std::vector<uint64_t> op2{modulus - 4, 8, 7, 6, 5, 4, 3, 2, 1};

0 commit comments

Comments
 (0)