Skip to content

Commit 9927cdd

Browse files
authored
Merge pull request #218 from pq-code-package/rej_uniform_reform
CBMC: Refactor rej_uniform and poly_uniform functions
2 parents 3932986 + 1d08d81 commit 9927cdd

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

mldsa/poly.c

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -259,38 +259,49 @@ int poly_chknorm(const poly *a, int32_t B)
259259
* performing rejection sampling on array of random bytes.
260260
*
261261
* Arguments: - int32_t *a: pointer to output array (allocated)
262-
* - unsigned int len: number of coefficients to be sampled
263-
* - const uint8_t *buf: array of random bytes
264-
* - unsigned int buflen: length of array of random bytes
262+
* - unsigned int target: requested number of coefficients to
263+
*sample
264+
* - unsigned int offset: number of coefficients already sampled
265+
* - const uint8_t *buf: array of random bytes to sample from
266+
* - unsigned int buflen: length of array of random bytes (must be
267+
* multiple of 3)
265268
*
266269
* Returns number of sampled coefficients. Can be smaller than len if not enough
267270
* random bytes were given.
268271
**************************************************/
272+
273+
/* Reference: `rej_uniform()` in the reference implementation [@REF].
274+
* - Our signature differs from the reference implementation
275+
* in that it adds the offset and always expects the base of the
276+
* target buffer. This avoids shifting the buffer base in the
277+
* caller, which appears tricky to reason about. */
269278
#define POLY_UNIFORM_NBLOCKS \
270279
((768 + STREAM128_BLOCKBYTES - 1) / STREAM128_BLOCKBYTES)
271-
static unsigned int rej_uniform(int32_t *a, unsigned int len,
272-
const uint8_t *buf, unsigned int buflen)
280+
static unsigned int rej_uniform(int32_t *a, unsigned int target,
281+
unsigned int offset, const uint8_t *buf,
282+
unsigned int buflen)
273283
__contract__(
274-
requires(len <= buflen && len <= MLDSA_N)
284+
requires(offset <= target && target <= MLDSA_N)
275285
requires(buflen <= (POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES) && buflen % 3 == 0)
276-
requires(memory_no_alias(a, sizeof(int32_t) * len))
286+
requires(memory_no_alias(a, sizeof(int32_t) * target))
277287
requires(memory_no_alias(buf, buflen))
278-
assigns(memory_slice(a, sizeof(int32_t) * len))
279-
ensures(return_value <= len)
288+
requires(array_bound(a, 0, offset, 0, MLDSA_Q))
289+
assigns(memory_slice(a, sizeof(int32_t) * target))
290+
ensures(offset <= return_value && return_value <= target)
280291
ensures(array_bound(a, 0, return_value, 0, MLDSA_Q))
281292
)
282293
{
283294
unsigned int ctr, pos;
284295
uint32_t t;
285296

286-
ctr = pos = 0;
297+
ctr = offset;
298+
pos = 0;
287299
/* pos + 3 cannot overflow due to the assumption
288300
buflen <= (POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES) */
289-
while (ctr < len && pos + 3 <= buflen)
301+
while (ctr < target && pos + 3 <= buflen)
290302
__loop__(
291-
invariant(ctr <= len && pos <= buflen)
292-
invariant(array_bound(a, 0, ctr, 0, MLDSA_Q))
293-
)
303+
invariant(offset <= ctr && ctr <= target && pos <= buflen)
304+
invariant(array_bound(a, 0, ctr, 0, MLDSA_Q)))
294305
{
295306
t = buf[pos++];
296307
t |= (uint32_t)buf[pos++] << 8;
@@ -306,29 +317,28 @@ __contract__(
306317
return ctr;
307318
}
308319

320+
/* Reference: poly_uniform() in the reference implementation [@REF].
321+
* - Simplified from reference by removing buffer tail handling
322+
* since buflen % 3 = 0 always holds true (STREAM128_BLOCKBYTES =
323+
* 168).
324+
* - Modified rej_uniform interface to track offset directly. */
309325
void poly_uniform(poly *a, const uint8_t seed[MLDSA_SEEDBYTES], uint16_t nonce)
310326
{
311-
unsigned int i, ctr, off;
327+
unsigned int ctr;
312328
unsigned int buflen = POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES;
313-
uint8_t buf[POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES + 2];
329+
uint8_t buf[POLY_UNIFORM_NBLOCKS * STREAM128_BLOCKBYTES];
314330
stream128_state state;
315331

316332
stream128_init(&state, seed, nonce);
317333
stream128_squeezeblocks(buf, POLY_UNIFORM_NBLOCKS, &state);
318334

319-
ctr = rej_uniform(a->coeffs, MLDSA_N, buf, buflen);
335+
ctr = rej_uniform(a->coeffs, MLDSA_N, 0, buf, buflen);
320336

321337
while (ctr < MLDSA_N)
322338
{
323-
off = buflen % 3;
324-
for (i = 0; i < off; ++i)
325-
{
326-
buf[i] = buf[buflen - off + i];
327-
}
328-
329-
stream128_squeezeblocks(buf + off, 1, &state);
330-
buflen = STREAM128_BLOCKBYTES + off;
331-
ctr += rej_uniform(a->coeffs + ctr, MLDSA_N - ctr, buf, buflen);
339+
stream128_squeezeblocks(buf, 1, &state);
340+
buflen = STREAM128_BLOCKBYTES;
341+
ctr = rej_uniform(a->coeffs, MLDSA_N, ctr, buf, buflen);
332342
}
333343
}
334344

proofs/cbmc/rej_uniform/rej_uniform_harness.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33

44
#include "poly.h"
55

6-
unsigned rej_uniform(int32_t *a, unsigned int len, const uint8_t *buf,
7-
unsigned int buflen);
6+
static unsigned int rej_uniform(int32_t *a, unsigned int target,
7+
unsigned int offset, const uint8_t *buf,
8+
unsigned int buflen);
89

910
void harness(void)
1011
{
1112
int32_t *a;
12-
unsigned int len;
13+
unsigned int target;
14+
unsigned int offset;
1315
const uint8_t *buf;
1416
unsigned int buflen;
1517

16-
rej_uniform(a, len, buf, buflen);
18+
rej_uniform(a, target, offset, buf, buflen);
1719
}

0 commit comments

Comments
 (0)