Skip to content

Commit 68762b8

Browse files
committed
Add first attempt at generalized 2- and 3-layer merged forward
NTT functions. CBMC proofs of these new functions are all TBD. For now, we call these in a "3,2,1,1" pattern to make sure there are no unreferenced functions. Signed-off-by: Rod Chapman <[email protected]> Correct one call from fqmul() to mlk_fqmul() Signed-off-by: Rod Chapman <[email protected]>
1 parent cc6c398 commit 68762b8

File tree

1 file changed

+113
-82
lines changed

1 file changed

+113
-82
lines changed

mlkem/poly.c

Lines changed: 113 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -280,99 +280,134 @@ void mlk_poly_mulcache_compute(mlk_poly_mulcache *x, const mlk_poly *a)
280280
#endif /* MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE */
281281

282282
#if !defined(MLK_USE_NATIVE_NTT)
283-
/*
284-
* Computes a block of CT butterflies with a fixed twiddle factor,
285-
* using Montgomery multiplication.
286-
* Parameters:
287-
* - r: Pointer to base of polynomial (_not_ the base of butterfly block)
288-
* - zeta: Twiddle factor to use for the butterfly. This must be in
289-
* Montgomery form and signed canonical.
290-
* - start: Offset to the beginning of the butterfly block
291-
* - len: Index difference between coefficients subject to a butterfly
292-
* - bound: Ghost variable describing coefficient bound: Prior to `start`,
293-
* coefficients must be bound by `bound + MLKEM_Q`. Post `start`,
294-
* they must be bound by `bound`.
295-
* When this function returns, output coefficients in the index range
296-
* [start, start+2*len) have bound bumped to `bound + MLKEM_Q`.
297-
* Example:
298-
* - start=8, len=4
299-
* This would compute the following four butterflies
300-
* 8 -- 12
301-
* 9 -- 13
302-
* 10 -- 14
303-
* 11 -- 15
304-
* - start=4, len=2
305-
* This would compute the following two butterflies
306-
* 4 -- 6
307-
* 5 -- 7
308-
*/
309283

310284
/* Reference: Embedded in `ntt()` in the reference implementation. */
311-
static void mlk_ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta,
312-
unsigned start, unsigned len, int bound)
313-
__contract__(
314-
requires(start < MLKEM_N)
315-
requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N)
316-
requires(0 <= bound && bound < INT16_MAX - MLKEM_Q)
317-
requires(-MLKEM_Q_HALF < zeta && zeta < MLKEM_Q_HALF)
318-
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
319-
requires(array_abs_bound(r, 0, start, bound + MLKEM_Q))
320-
requires(array_abs_bound(r, start, MLKEM_N, bound))
321-
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
322-
ensures(array_abs_bound(r, 0, start + 2*len, bound + MLKEM_Q))
323-
ensures(array_abs_bound(r, start + 2 * len, MLKEM_N, bound)))
285+
/*
 * Single CT butterfly on two coefficients of r, updated in place:
 *   t1 = r[coeff1_index]
 *   t2 = mlk_fqmul(r[coeff2_index], zeta)
 *   r[coeff1_index] = t1 + t2
 *   r[coeff2_index] = t1 - t2
 * Parameters:
 * - r:            Pointer to base of polynomial
 * - coeff1_index: Index of the coefficient that receives the sum
 * - coeff2_index: Index of the coefficient that is multiplied by zeta and
 *                 receives the difference
 * - zeta:         Twiddle factor passed to mlk_fqmul()
 *
 * NOTE(review): no __contract__ is attached to this function yet; the commit
 * message states that CBMC proofs for the new functions are TBD.
 *
 * Reading this hunk: gutter numbers ending in '-' mark the removed loop body
 * of mlk_ntt_butterfly_block(); the four statements whose gutter numbers end
 * in '+' are the entire body of this function.
 */
static MLK_INLINE void mlk_ct_butterfly(int16_t r[MLKEM_N],
286+
const unsigned coeff1_index,
287+
const unsigned coeff2_index,
288+
const int16_t zeta)
324289
{
325-
/* `bound` is a ghost variable only needed in the CBMC specification */
326-
unsigned j;
327-
((void)bound);
328-
for (j = start; j < start + len; j++)
329-
__loop__(
330-
invariant(start <= j && j <= start + len)
331-
/*
332-
* Coefficients are updated in strided pairs, so the bounds for the
333-
* intermediate states alternate twice between the old and new bound
334-
*/
335-
invariant(array_abs_bound(r, 0, j, bound + MLKEM_Q))
336-
invariant(array_abs_bound(r, j, start + len, bound))
337-
invariant(array_abs_bound(r, start + len, j + len, bound + MLKEM_Q))
338-
invariant(array_abs_bound(r, j + len, MLKEM_N, bound)))
339-
{
340-
int16_t t;
341-
t = mlk_fqmul(r[j + len], zeta);
342-
r[j + len] = r[j] - t;
343-
r[j] = r[j] + t;
344-
}
/* t1 caches the old r[coeff1_index] so that both outputs use the same input */
290+
int16_t t1 = r[coeff1_index];
291+
int16_t t2 = mlk_fqmul(r[coeff2_index], zeta);
292+
r[coeff1_index] = t1 + t2;
293+
r[coeff2_index] = t1 - t2;
345294
}
346295

347-
/*
348-
* Compute one layer of forward NTT
349-
* Parameters:
350-
* - r: Pointer to base of polynomial
351-
* - layer: Variable indicating which layer is being applied.
352-
*/
353-
354-
/* Reference: Embedded in `ntt()` in the reference implementation. */
355-
static void mlk_ntt_layer(int16_t r[MLKEM_N], unsigned layer)
296+
/*
 * Compute one layer of forward NTT
 * Parameters:
 * - r:     Pointer to base of polynomial
 * - layer: Variable indicating which layer is being applied
 *          (1 <= layer <= 7 per the contract).
 * Per the contract, coefficients bounded by layer * MLKEM_Q on entry are
 * bounded by (layer + 1) * MLKEM_Q on return.
 */
static void mlk_ntt_1_layer(int16_t r[MLKEM_N], unsigned layer)
356297
__contract__(
357298
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
358299
requires(1 <= layer && layer <= 7)
359300
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
360301
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
361302
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 1) * MLKEM_Q)))
362303
{
363-
unsigned start, k, len;
364-
/* Twiddle factors for layer n are at indices 2^(n-1)..2^n-1. */
365-
k = 1u << (layer - 1);
366-
len = MLKEM_N >> layer;
/* len: index distance between the two coefficients of each butterfly */
304+
const unsigned len = MLKEM_N >> layer;
305+
unsigned start, k;
306+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
/*
 * NOTE(review): `1` is a signed int literal here, whereas the removed line
 * above used `1u`; consider `1u` to keep the shift unsigned like k.
 */
307+
k = 1 << (layer - 1);
367308
for (start = 0; start < MLKEM_N; start += 2 * len)
368309
__loop__(
369310
invariant(start < MLKEM_N + 2 * len)
370311
invariant(k <= MLKEM_N / 2 && 2 * len * k == start + MLKEM_N)
371312
invariant(array_abs_bound(r, 0, start, layer * MLKEM_Q + MLKEM_Q))
372313
invariant(array_abs_bound(r, start, MLKEM_N, layer * MLKEM_Q)))
373314
{
374-
int16_t zeta = zetas[k++];
375-
mlk_ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q);
/* One twiddle factor per block of 2 * len coefficients; k advances per block */
315+
const int16_t zeta = zetas[k++];
316+
unsigned j;
/*
 * Inlined replacement for the removed mlk_ntt_butterfly_block() call:
 * pairs index start + j with start + j + len, all using the same zeta.
 * NOTE(review): this inner loop carries no __loop__ invariants yet; the
 * commit message states that CBMC proofs are TBD.
 */
317+
for (j = 0; j < len; j++)
318+
{
319+
mlk_ct_butterfly(r, j + start, j + start + len, zeta);
320+
}
321+
}
322+
}
323+
324+
/*
 * Compute two consecutive layers of forward NTT (`layer` and `layer + 1`)
 * in a single pass over the polynomial.
 * Parameters:
 * - r:     Pointer to base of polynomial
 * - layer: First of the two layers being applied
 *          (1 <= layer <= 6 per the contract).
 * Per the contract, coefficients bounded by layer * MLKEM_Q on entry are
 * bounded by (layer + 2) * MLKEM_Q on return.
 *
 * Each inner iteration works on four coefficients
 *   ci<i> = start + j + i * (len / 2),  i = 0..3
 * and applies, in order:
 *   - two butterflies (ci0, ci2) and (ci1, ci3) with zetas[k]
 *   - one butterfly (ci0, ci1) with zetas[2 * k]
 *   - one butterfly (ci2, ci3) with zetas[2 * k + 1]
 *
 * NOTE(review): neither loop carries __loop__ invariants yet; the commit
 * message states that CBMC proofs for this function are TBD.
 */
static void mlk_ntt_2_layers(int16_t r[MLKEM_N], unsigned layer)
325+
__contract__(
326+
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
327+
requires(1 <= layer && layer <= 6)
328+
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
329+
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
330+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 2) * MLKEM_Q)))
331+
{
332+
const unsigned len = MLKEM_N >> layer;
333+
unsigned start, k;
334+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
/*
 * NOTE(review): `1` is a signed int literal; `1u` would keep the shift
 * unsigned, matching the type of k.
 */
335+
k = 1 << (layer - 1);
336+
for (start = 0; start < MLKEM_N; start += 2 * len)
337+
{
338+
unsigned j;
/*
 * Twiddles for this block: one for the first layer (index k) and the two
 * for the following layer (indices 2 * k and 2 * k + 1).
 */
339+
const int16_t this_layer_zeta = zetas[k];
340+
const int16_t next_layer_zeta1 = zetas[k * 2];
341+
const int16_t next_layer_zeta2 = zetas[k * 2 + 1];
342+
k++;
343+
344+
for (j = 0; j < len / 2; j++)
345+
{
/* Four coefficient indices spaced len / 2 apart */
346+
const unsigned ci0 = j + start;
347+
const unsigned ci1 = ci0 + len / 2;
348+
const unsigned ci2 = ci1 + len / 2;
349+
const unsigned ci3 = ci2 + len / 2;
350+
/* First the two butterflies of layer `layer`, then those of `layer + 1` */
351+
mlk_ct_butterfly(r, ci0, ci2, this_layer_zeta);
352+
mlk_ct_butterfly(r, ci1, ci3, this_layer_zeta);
353+
mlk_ct_butterfly(r, ci0, ci1, next_layer_zeta1);
354+
mlk_ct_butterfly(r, ci2, ci3, next_layer_zeta2);
355+
}
356+
}
357+
}
358+
359+
/*
 * Compute three consecutive layers of forward NTT (`layer`, `layer + 1`
 * and `layer + 2`) in a single pass over the polynomial.
 * Parameters:
 * - r:     Pointer to base of polynomial
 * - layer: First of the three layers being applied
 *          (1 <= layer <= 5 per the contract).
 * Per the contract, coefficients bounded by layer * MLKEM_Q on entry are
 * bounded by (layer + 3) * MLKEM_Q on return.
 *
 * Each inner iteration works on eight coefficients
 *   ci<i> = start + j + i * (len / 4),  i = 0..7
 * and applies twelve butterflies in three groups of four:
 *   - first group:  (ci0,ci4) (ci1,ci5) (ci2,ci6) (ci3,ci7), all zetas[k]
 *   - second group: (ci0,ci2) (ci1,ci3) with zetas[2 * k],
 *                   (ci4,ci6) (ci5,ci7) with zetas[2 * k + 1]
 *   - third group:  (ci0,ci1) (ci2,ci3) (ci4,ci5) (ci6,ci7) with
 *                   zetas[4 * k] .. zetas[4 * k + 3] respectively
 *
 * NOTE(review): neither loop carries __loop__ invariants yet; the commit
 * message states that CBMC proofs for this function are TBD.
 */
static void mlk_ntt_3_layers(int16_t r[MLKEM_N], unsigned layer)
360+
__contract__(
361+
requires(memory_no_alias(r, sizeof(int16_t) * MLKEM_N))
362+
requires(1 <= layer && layer <= 5)
363+
requires(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q))
364+
assigns(memory_slice(r, sizeof(int16_t) * MLKEM_N))
365+
ensures(array_abs_bound(r, 0, MLKEM_N, (layer + 3) * MLKEM_Q)))
366+
{
367+
const unsigned len = MLKEM_N >> layer;
368+
unsigned start, k;
369+
/* Twiddle factors for layer n start at index 2 ** (layer-1) */
/*
 * NOTE(review): `1` is a signed int literal; `1u` would keep the shift
 * unsigned, matching the type of k.
 */
370+
k = 1 << (layer - 1);
371+
for (start = 0; start < MLKEM_N; start += 2 * len)
372+
{
373+
unsigned j;
/*
 * Twiddles for this block: index k for the first layer, 2k and 2k + 1 for
 * the second, and 4k .. 4k + 3 for the third (twice each second-layer
 * index, plus 0 or 1).
 */
374+
const int16_t first_layer_zeta = zetas[k];
375+
const unsigned second_layer_zi1 = k * 2;
376+
const unsigned second_layer_zi2 = k * 2 + 1;
377+
const int16_t second_layer_zeta1 = zetas[second_layer_zi1];
378+
const int16_t second_layer_zeta2 = zetas[second_layer_zi2];
379+
const int16_t third_layer_zeta1 = zetas[second_layer_zi1 * 2];
380+
const int16_t third_layer_zeta2 = zetas[second_layer_zi1 * 2 + 1];
381+
const int16_t third_layer_zeta3 = zetas[second_layer_zi2 * 2];
382+
const int16_t third_layer_zeta4 = zetas[second_layer_zi2 * 2 + 1];
383+
k++;
384+
385+
for (j = 0; j < len / 4; j++)
386+
{
/* Eight coefficient indices spaced len / 4 apart */
387+
const unsigned ci0 = j + start;
388+
const unsigned ci1 = ci0 + len / 4;
389+
const unsigned ci2 = ci1 + len / 4;
390+
const unsigned ci3 = ci2 + len / 4;
391+
const unsigned ci4 = ci3 + len / 4;
392+
const unsigned ci5 = ci4 + len / 4;
393+
const unsigned ci6 = ci5 + len / 4;
394+
const unsigned ci7 = ci6 + len / 4;
395+
/* Layer `layer`: pairs four positions apart, single twiddle */
396+
mlk_ct_butterfly(r, ci0, ci4, first_layer_zeta);
397+
mlk_ct_butterfly(r, ci1, ci5, first_layer_zeta);
398+
mlk_ct_butterfly(r, ci2, ci6, first_layer_zeta);
399+
mlk_ct_butterfly(r, ci3, ci7, first_layer_zeta);
400+
/* Layer `layer + 1`: pairs two positions apart, one twiddle per half */
401+
mlk_ct_butterfly(r, ci0, ci2, second_layer_zeta1);
402+
mlk_ct_butterfly(r, ci1, ci3, second_layer_zeta1);
403+
mlk_ct_butterfly(r, ci4, ci6, second_layer_zeta2);
404+
mlk_ct_butterfly(r, ci5, ci7, second_layer_zeta2);
405+
/* Layer `layer + 2`: adjacent pairs, one twiddle per pair */
406+
mlk_ct_butterfly(r, ci0, ci1, third_layer_zeta1);
407+
mlk_ct_butterfly(r, ci2, ci3, third_layer_zeta2);
408+
mlk_ct_butterfly(r, ci4, ci5, third_layer_zeta3);
409+
mlk_ct_butterfly(r, ci6, ci7, third_layer_zeta4);
410+
}
376411
}
377412
}
378413

@@ -391,18 +426,14 @@ __contract__(
391426
MLK_INTERNAL_API
392427
void mlk_poly_ntt(mlk_poly *p)
393428
{
394-
unsigned layer;
395429
int16_t *r;
396430
mlk_assert_abs_bound(p, MLKEM_N, MLKEM_Q);
397431
r = p->coeffs;
398432

399-
for (layer = 1; layer <= 7; layer++)
400-
__loop__(
401-
invariant(1 <= layer && layer <= 8)
402-
invariant(array_abs_bound(r, 0, MLKEM_N, layer * MLKEM_Q)))
403-
{
404-
mlk_ntt_layer(r, layer);
405-
}
433+
mlk_ntt_3_layers(r, 1);
434+
mlk_ntt_2_layers(r, 4);
435+
mlk_ntt_1_layer(r, 6);
436+
mlk_ntt_1_layer(r, 7);
406437

407438
/* Check the stronger bound */
408439
mlk_assert_abs_bound(p, MLKEM_N, MLK_NTT_BOUND);

0 commit comments

Comments
 (0)