From 6821038048e1251c30f819fb0aeeed8cd3099388 Mon Sep 17 00:00:00 2001 From: Giacomo Pope Date: Wed, 24 Jul 2024 10:57:56 +0100 Subject: [PATCH] tweak tests for higher coverage --- src/kyber_py/kyber/kyber.py | 11 +- src/kyber_py/ml_kem/ml_kem.py | 11 +- src/kyber_py/modules/modules.py | 6 +- src/kyber_py/polynomials/polynomials.py | 37 +++-- tests/test_kyber.py | 9 +- tests/test_ml_kem.py | 4 - tests/test_module.py | 20 +++ tests/test_module_generic.py | 4 - tests/test_polynomial.py | 186 ++++++++++++++++++++++++ tests/test_polynomial_generic.py | 7 +- 10 files changed, 235 insertions(+), 60 deletions(-) create mode 100644 tests/test_module.py create mode 100644 tests/test_polynomial.py diff --git a/src/kyber_py/kyber/kyber.py b/src/kyber_py/kyber/kyber.py index 1a7ebec..1762945 100644 --- a/src/kyber_py/kyber/kyber.py +++ b/src/kyber_py/kyber/kyber.py @@ -5,13 +5,11 @@ class Kyber: - def __init__(self, parameter_set, seed=None): + def __init__(self, parameter_set): """ Initialise Kyber with specified lattice parameters. :param dict params: the lattice parameters - :param bytes seed: the optional seed for a DRBG, must be unique and - unpredictable """ self.k = parameter_set["k"] self.eta_1 = parameter_set["eta_1"] @@ -22,13 +20,10 @@ def __init__(self, parameter_set, seed=None): self.M = ModuleKyber() self.R = self.M.ring - # Use system randomness by default + # Use system randomness by default, for deterministic randomness + # use the method `set_drbg_seed()` self.random_bytes = os.urandom - # If a seed is supplied, use deterministic randomness - if seed is not None: - self.set_drbg_seed(seed) - def set_drbg_seed(self, seed): """ Change entropy source to a DRBG and seed it with provided value. diff --git a/src/kyber_py/ml_kem/ml_kem.py b/src/kyber_py/ml_kem/ml_kem.py index 28c334f..f56f277 100644 --- a/src/kyber_py/ml_kem/ml_kem.py +++ b/src/kyber_py/ml_kem/ml_kem.py @@ -9,13 +9,11 @@ class ML_KEM: - def __init__(self, params, seed=None): + def __init__(self, params): """ Initialise the ML-KEM with specified lattice parameters. :param dict params: the lattice parameters - :param bytes seed: the optional seed for a DRBG, must be unique and - unpredictable """ # ml-kem params self.k = params["k"] @@ -27,13 +25,10 @@ def __init__(self, params, seed=None): self.M = ModuleKyber() self.R = self.M.ring - # Use system randomness by default + # Use system randomness by default, for deterministic randomness + # use the method `set_drbg_seed()` self.random_bytes = os.urandom - # If a seed is supplied, use deterministic randomness - if seed is not None: - self.set_drbg_seed(seed) - def set_drbg_seed(self, seed): """ Change entropy source to a DRBG and seed it with provided value. diff --git a/src/kyber_py/modules/modules.py b/src/kyber_py/modules/modules.py index 3a69e76..e33f5d2 100644 --- a/src/kyber_py/modules/modules.py +++ b/src/kyber_py/modules/modules.py @@ -29,11 +29,7 @@ def decode_vector(self, input_bytes, k, d, is_ntt=False): class MatrixKyber(Matrix): def __init__(self, parent, matrix_data, transpose=False): - self.parent = parent - self._data = matrix_data - self._transpose = transpose - if not self._check_dimensions(): - raise ValueError("Inconsistent row lengths in matrix") + super().__init__(parent, matrix_data, transpose=transpose) def encode(self, d): output = b"" diff --git a/src/kyber_py/polynomials/polynomials.py b/src/kyber_py/polynomials/polynomials.py index 61eec28..bb03465 100644 --- a/src/kyber_py/polynomials/polynomials.py +++ b/src/kyber_py/polynomials/polynomials.py @@ -17,12 +17,12 @@ def __init__(self): root_of_unity = 17 self.ntt_zetas = [ - pow(root_of_unity, self.br(i, 7), 3329) for i in range(128) + pow(root_of_unity, self._br(i, 7), 3329) for i in range(128) ] self.ntt_f = pow(128, -1, 3329) @staticmethod - def br(i, k): + def _br(i, k): """ bit reversal of an unsigned k-bit integer """ @@ -105,7 +105,7 @@ def __call__(self, coefficients, is_ntt=False): return element(self, [coefficients]) if not isinstance(coefficients, list): raise TypeError( - f"Polynomials should be constructed from a list of integers, of length at most d = {256}" + f"Polynomials should be constructed from a list of integers, of length at most n = {256}" ) return element(self, coefficients) @@ -122,7 +122,7 @@ def encode(self, d): bit_string = "".join(format(c, f"0{d}b")[::-1] for c in self.coeffs) return bitstring_to_bytes(bit_string) - def compress_ele(self, x, d): + def _compress_ele(self, x, d): """ Compute round((2^d / q) * x) % 2^d """ @@ -130,7 +130,7 @@ def compress_ele(self, x, d): y = (t * x + 1664) // 3329 # 1664 = 3329 // 2 return y % t - def decompress_ele(self, x, d): + def _decompress_ele(self, x, d): """ Compute round((q / 2^d) * x) """ @@ -143,7 +143,7 @@ def compress(self, d): Compress the polynomial by compressing each coefficient NOTE: This is lossy compression """ - self.coeffs = [self.compress_ele(c, d) for c in self.coeffs] + self.coeffs = [self._compress_ele(c, d) for c in self.coeffs] return self def decompress(self, d): @@ -153,7 +153,7 @@ def decompress(self, d): x' = decompress(compress(x)), which x' != x, but is close in magnitude. """ - self.coeffs = [self.decompress_ele(c, d) for c in self.coeffs] + self.coeffs = [self._decompress_ele(c, d) for c in self.coeffs] return self def to_ntt(self): @@ -182,7 +182,7 @@ def to_ntt(self): return self.parent(coeffs, is_ntt=True) def from_ntt(self): - raise TypeError(f"Polynomial is of type: {type(self)}") + raise TypeError(f"Polynomial not in the NTT domain: {type(self) = }") class PolynomialKyberNTT(PolynomialKyber): @@ -191,7 +191,9 @@ def __init__(self, parent, coefficients): self.coeffs = self._parse_coefficients(coefficients) def to_ntt(self): - raise TypeError(f"Polynomial is of type: {type(self)}") + raise TypeError( + f"Polynomial is already in the NTT domain: {type(self) = }" + ) def from_ntt(self): """ @@ -222,7 +224,7 @@ def from_ntt(self): return self.parent(coeffs, is_ntt=False) @staticmethod - def ntt_base_multiplication(a0, a1, b0, b1, zeta): + def _ntt_base_multiplication(a0, a1, b0, b1, zeta): """ Base case for ntt multiplication """ @@ -230,18 +232,18 @@ def ntt_base_multiplication(a0, a1, b0, b1, zeta): r1 = (a1 * b0 + a0 * b1) % 3329 return r0, r1 - def ntt_coefficient_multiplication(self, f_coeffs, g_coeffs): + def _ntt_coefficient_multiplication(self, f_coeffs, g_coeffs): new_coeffs = [] zetas = self.parent.ntt_zetas for i in range(64): - r0, r1 = self.ntt_base_multiplication( + r0, r1 = self._ntt_base_multiplication( f_coeffs[4 * i + 0], f_coeffs[4 * i + 1], g_coeffs[4 * i + 0], g_coeffs[4 * i + 1], zetas[64 + i], ) - r2, r3 = self.ntt_base_multiplication( + r2, r3 = self._ntt_base_multiplication( f_coeffs[4 * i + 2], f_coeffs[4 * i + 3], g_coeffs[4 * i + 2], @@ -251,15 +253,12 @@ def ntt_coefficient_multiplication(self, f_coeffs, g_coeffs): new_coeffs += [r0, r1, r2, r3] return new_coeffs - def ntt_multiplication(self, other): + def _ntt_multiplication(self, other): """ Number Theoretic Transform multiplication. Only implemented (currently) for n = 256 """ - if not isinstance(other, type(self)): - raise ValueError - - new_coeffs = self.ntt_coefficient_multiplication( + new_coeffs = self._ntt_coefficient_multiplication( self.coeffs, other.coeffs ) return new_coeffs @@ -274,7 +273,7 @@ def __sub__(self, other): def __mul__(self, other): if isinstance(other, type(self)): - new_coeffs = self.ntt_multiplication(other) + new_coeffs = self._ntt_multiplication(other) elif isinstance(other, int): new_coeffs = [(c * other) % 3329 for c in self.coeffs] else: diff --git a/tests/test_kyber.py b/tests/test_kyber.py index e115711..eefad76 100644 --- a/tests/test_kyber.py +++ b/tests/test_kyber.py @@ -1,6 +1,5 @@ import unittest import os -from itertools import islice import pytest from kyber_py.kyber import Kyber512, Kyber768, Kyber1024 from kyber_py.drbg.aes256_ctr_drbg import AES256_CTR_DRBG @@ -143,13 +142,9 @@ def test_generic_kyber_known_answer(Kyber, seed, data): # Assert encapsulation matches ss, ct = Kyber.encaps(pk) - assert ct == data["ct"] assert ss == data["ss"] + assert ct == data["ct"] # Assert decapsulation matches _ss = Kyber.decaps(ct, sk) - assert ss == data["ss"] - - -if __name__ == "__main__": - unittest.main() + assert _ss == data["ss"] diff --git a/tests/test_ml_kem.py b/tests/test_ml_kem.py index 6398616..f061807 100644 --- a/tests/test_ml_kem.py +++ b/tests/test_ml_kem.py @@ -136,7 +136,3 @@ def test_mlkem_known_answer(ML_KEM, seed, kat_vals): # Assert decapsulation with faulty ciphertext ss_n = ML_KEM.decaps(data["ct_n"], dk) assert ss_n == data["ss_n"] - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_module.py b/tests/test_module.py new file mode 100644 index 0000000..ab1f627 --- /dev/null +++ b/tests/test_module.py @@ -0,0 +1,20 @@ +import unittest +from random import randint +from kyber_py.modules.modules import ModuleKyber + + +class TestModuleKyber(unittest.TestCase): + M = ModuleKyber() + R = M.ring + + def test_decode_vector(self): + for _ in range(100): + k = randint(1, 5) + v = self.M.random_element(k, 1) + v_bytes = v.encode(12) + self.assertEqual(v, self.M.decode_vector(v_bytes, k, 12)) + + def test_recode_vector_wrong_length(self): + self.assertRaises( + ValueError, lambda: self.M.decode_vector(b"1", 2, 12) + ) diff --git a/tests/test_module_generic.py b/tests/test_module_generic.py index 3fac4df..8b23333 100644 --- a/tests/test_module_generic.py +++ b/tests/test_module_generic.py @@ -182,7 +182,3 @@ def test_print(self): su = "[1 + 2*x, 3 + 4*x + 5*x^2 + 6*x^3]" self.assertEqual(str(A), sA) self.assertEqual(str(u), su) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_polynomial.py b/tests/test_polynomial.py new file mode 100644 index 0000000..6dc8796 --- /dev/null +++ b/tests/test_polynomial.py @@ -0,0 +1,186 @@ +import unittest +from random import randint +from kyber_py.polynomials.polynomials import PolynomialRingKyber + + +class TestModuleKyber(unittest.TestCase): + R = PolynomialRingKyber() + + def test_decode(self): + for _ in range(10): + f = self.R.random_element() + f_bytes = f.encode(12) + self.assertEqual(f, self.R.decode(f_bytes, 12)) + + def test_decode_wrong_length(self): + self.assertRaises(ValueError, lambda: self.R.decode(b"1", 12)) + + def test_call(self): + self.assertEqual(1, self.R(1)) + self.assertRaises(TypeError, lambda: self.R("a")) + + +class TestPolynomial(unittest.TestCase): + R = PolynomialRingKyber() + + def test_ntt_transform(self): + f = self.R.random_element() + g_hat = self.R.random_element().to_ntt() + + self.assertEqual(f, f.to_ntt().from_ntt()) + self.assertEqual(g_hat, g_hat.from_ntt().to_ntt()) + + self.assertRaises(TypeError, lambda: f.from_ntt()) + self.assertRaises(TypeError, lambda: g_hat.to_ntt()) + + def test_add_failure(self): + f1 = self.R.random_element() + self.assertRaises(NotImplementedError, lambda: f1 + "a") + + def test_sub_failure(self): + f1 = self.R.random_element() + self.assertRaises(NotImplementedError, lambda: f1 - "a") + + def test_mul_failure(self): + f1 = self.R.random_element() + self.assertRaises(NotImplementedError, lambda: f1 * "a") + + def test_pow_failure(self): + f1 = self.R.random_element() + self.assertRaises(TypeError, lambda: f1 ** "a") + + def test_add_polynomials(self): + zero = self.R(0) + for _ in range(10): + f1 = self.R.random_element() + f2 = self.R.random_element() + f3 = self.R.random_element() + + self.assertEqual(f1 + zero, f1) + self.assertEqual(f1 + f2, f2 + f1) + self.assertEqual(f1 + (f2 + f3), (f1 + f2) + f3) + + f2 = f1 + f2 += f1 + self.assertEqual(f1 + f1, f2) + + def test_sub_polynomials(self): + zero = self.R(0) + for _ in range(10): + f1 = self.R.random_element() + f2 = self.R.random_element() + f3 = self.R.random_element() + + self.assertEqual(f1 - zero, f1) + self.assertEqual(f3 - f3, zero) + self.assertEqual(f3 - 0, f3) + self.assertEqual(0 - f3, -f3) + self.assertEqual(f1 - f2, -(f2 - f1)) + self.assertEqual(f1 - (f2 - f3), (f1 - f2) + f3) + + f2 = f1 + f2 -= f1 + self.assertEqual(f2, zero) + + def test_mul_polynomials(self): + zero = self.R(0) + one = self.R(1) + for _ in range(10): + f1 = self.R.random_element() + f2 = self.R.random_element() + f3 = self.R.random_element() + + self.assertEqual(f1 * zero, zero) + self.assertEqual(f1 * one, f1) + self.assertEqual(f1 * f2, f2 * f1) + self.assertEqual(f1 * (f2 * f3), (f1 * f2) * f3) + self.assertEqual(2 * f1, f1 + f1) + self.assertEqual(2 * f1, f1 * 2) + + f2 = f1 + f2 *= f2 + self.assertEqual(f1 * f1, f2) + + def test_pow_polynomials(self): + one = self.R(1) + for _ in range(10): + f1 = self.R.random_element() + + self.assertEqual(one, f1**0) + self.assertEqual(f1, f1**1) + self.assertEqual(f1 * f1, f1**2) + self.assertEqual(f1 * f1 * f1, f1**3) + self.assertRaises(ValueError, lambda: f1 ** (-1)) + + def test_add_failure_ntt(self): + f1 = self.R.random_element().to_ntt() + self.assertRaises(NotImplementedError, lambda: f1 + "a") + + def test_sub_failure_ntt(self): + f1 = self.R.random_element().to_ntt() + self.assertRaises(NotImplementedError, lambda: f1 - "a") + + def test_mul_failure_ntt(self): + f1 = self.R.random_element().to_ntt() + self.assertRaises(NotImplementedError, lambda: f1 * "a") + + def test_pow_failure_ntt(self): + f1 = self.R.random_element().to_ntt() + self.assertRaises(TypeError, lambda: f1 ** "a") + + def test_add_polynomials_ntt(self): + zero_hat = self.R(0).to_ntt() + for _ in range(10): + f1_hat = self.R.random_element().to_ntt() + f2_hat = self.R.random_element().to_ntt() + f3_hat = self.R.random_element().to_ntt() + + self.assertEqual(f1_hat + zero_hat, f1_hat) + self.assertEqual(f1_hat + f2_hat, f2_hat + f1_hat) + self.assertEqual( + f1_hat + (f2_hat + f3_hat), (f1_hat + f2_hat) + f3_hat + ) + + f2_hat = f1_hat + f2_hat += f1_hat + self.assertEqual(f1_hat + f1_hat, f2_hat) + + def test_sub_polynomials_ntt(self): + zero_hat = self.R(0).to_ntt() + for _ in range(10): + f1_hat = self.R.random_element().to_ntt() + f2_hat = self.R.random_element().to_ntt() + f3_hat = self.R.random_element().to_ntt() + + self.assertEqual(f1_hat - zero_hat, f1_hat) + self.assertEqual(f3_hat - f3_hat, zero_hat) + self.assertEqual(f3_hat - 0, f3_hat) + self.assertEqual(0 - f3_hat, -f3_hat) + self.assertEqual( + f1_hat - (f2_hat - f3_hat), (f1_hat - f2_hat) + f3_hat + ) + + f2_hat = f1_hat + f2_hat -= f1_hat + self.assertEqual(f2_hat, zero_hat) + + def test_mul_polynomials_ntt(self): + zero_hat = self.R(0).to_ntt() + one_hat = self.R(1).to_ntt() + for _ in range(10): + f1_hat = self.R.random_element().to_ntt() + f2_hat = self.R.random_element().to_ntt() + f3_hat = self.R.random_element().to_ntt() + + self.assertEqual(f1_hat * zero_hat, zero_hat) + self.assertEqual(f1_hat * one_hat, f1_hat) + self.assertEqual(f1_hat * f2_hat, f2_hat * f1_hat) + self.assertEqual( + f1_hat * (f2_hat * f3_hat), (f1_hat * f2_hat) * f3_hat + ) + self.assertEqual(2 * f1_hat, f1_hat + f1_hat) + self.assertEqual(2 * f1_hat, f1_hat * 2) + + f2_hat = f1_hat + f2_hat *= f2_hat + self.assertEqual(f1_hat * f1_hat, f2_hat) diff --git a/tests/test_polynomial_generic.py b/tests/test_polynomial_generic.py index cc71c70..500f4a1 100644 --- a/tests/test_polynomial_generic.py +++ b/tests/test_polynomial_generic.py @@ -68,6 +68,7 @@ def test_equality(self): self.assertTrue(self.R(0) == 0) self.assertTrue(self.R(1) == self.R.q + 1) self.assertTrue(self.R(self.R.q - 1) == -1) + self.assertFalse(self.R(self.R.q - 1) == "a") def test_add_failure(self): f1 = self.R.random_element() @@ -153,9 +154,5 @@ def test_print(self): self.assertEqual(str(self.R(1)), "1") self.assertEqual(str(self.R.gen()), "x") self.assertEqual( - str(self.R([1, 2, 3, 4, 5])), "1 + 2*x + 3*x^2 + 4*x^3 + 5*x^4" + str(self.R([1, 2, 3, 4, 1])), "1 + 2*x + 3*x^2 + 4*x^3 + x^4" ) - - -if __name__ == "__main__": - unittest.main()