Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests to check for failures in MLKEM/Kyber #71

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/kyber_py/drbg/aes256_ctr_drbg.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,6 @@ def ctr_drbg_update(self, provided_data):
self.key = tmp[:32]
self.V = tmp[32:]

def reseed(self, additional_information=b""):
"""
Reseed the DRBG for when reseed_ctr hits the
limit.
"""
seed_material = self.__instantiate(additional_information)
self.ctr_drbg_update(seed_material)
self.reseed_ctr = 1

def random_bytes(self, num_bytes, additional=None):
if self.reseed_ctr >= self.reseed_interval:
raise Warning("The DRBG has been exhausted! Reseed!")
Expand Down
18 changes: 1 addition & 17 deletions src/kyber_py/kyber/kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,25 +46,9 @@ def set_drbg_seed(self, seed):
self.random_bytes = self._drbg.random_bytes
except ImportError as e:
print(f"Error importing AES from pycryptodome: {e = }")
print(
"Have you tried installing requirements: pip -r install requirements"
)

def reseed_drbg(self, seed):
"""
Reseeds the DRBG, errors if a DRBG is not set.

Note:
currently requires pycryptodome for AES impl.

:param bytes seed: random bytes to use as a new seed of the DRBG
"""
if self._drbg is None:
raise Warning(
"Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`"
"Cannot set DRBG seed due to missing dependencies, try installing requirements: pip -r install requirements"
)
else:
self._drbg.reseed(seed)

@staticmethod
def _xof(bytes32, i, j):
Expand Down
45 changes: 20 additions & 25 deletions src/kyber_py/ml_kem/ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,25 +51,9 @@ def set_drbg_seed(self, seed):
self.random_bytes = self._drbg.random_bytes
except ImportError as e:
print(f"Error importing AES from pycryptodome: {e = }")
print(
"Have you tried installing requirements: pip -r install requirements"
)

def reseed_drbg(self, seed):
"""
Reseeds the DRBG, errors if a DRBG is not set.

Note:
currently requires pycryptodome for AES impl.

:param bytes seed: random bytes to use as a new seed of the DRBG
"""
if self._drbg is None:
raise Warning(
"Cannot reseed DRBG without first initialising. Try using `set_drbg_seed`"
"Cannot set DRBG seed due to missing dependencies, try installing requirements: pip -r install requirements"
)
else:
self._drbg.reseed(seed)

@staticmethod
def _xof(bytes32, i, j):
Expand Down Expand Up @@ -201,12 +185,13 @@ def _pke_encrypt(self, ek_pke, m, r):

# NOTE:
# Perform the input validation checks for ML-KEM
assert (
len(ek_pke) == 384 * self.k + 32
), "Type check failed, ek_pke has the wrong length"
assert (
t_hat.encode(12) == t_hat_bytes
), "Modulus check failed, t_hat does not encode correctly"
if len(ek_pke) != 384 * self.k + 32:
raise ValueError("Type check failed, ek_pke has the wrong length")

if t_hat.encode(12) != t_hat_bytes:
raise ValueError(
"Modulus check failed, t_hat does not encode correctly"
)

# Generate A_hat^T from seed rho
A_hat_T = self._generate_matrix_from_seed(rho, transpose=True)
Expand Down Expand Up @@ -286,8 +271,14 @@ def encaps(self, ek):
m = self.random_bytes(32)
K, r = self._G(m + self._H(ek))

# Perform the underlying pke encryption
c = self._pke_encrypt(ek, m, r)
# Perform the underlying pke encryption, raises a ValueError if
# ek fails either the TypeCheck or ModulusCheck
try:
c = self._pke_encrypt(ek, m, r)
except ValueError as e:
raise ValueError(
f"Valildation of encapsulation key failed: {e = }"
)

return (K, c)

Expand Down Expand Up @@ -325,6 +316,10 @@ def decaps(self, c, dk):
# Re-encrypt the recovered message
K_prime, r_prime = self._G(m_prime + h)
K_bar = self._J(z + c)

# Here the public encapsulation key is read from the private
# key and so we never expect this to fail the TypeCheck or
# ModulusCheck
c_prime = self._pke_encrypt(ek_pke, m_prime, r_prime)

# If c != c_prime, return K_bar as garbage
Expand Down
13 changes: 13 additions & 0 deletions tests/test_kyber.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,16 @@ def generic_test_kyber(self, Kyber, count):
pk, sk = Kyber.keygen()
for _ in range(count):
key, c = Kyber.encaps(pk)

# Correct decaps works
_key = Kyber.decaps(c, sk)
self.assertEqual(key, _key)

# Incorrect ct does not work
_bad_ct = bytes([0] * len(c))
_bad = Kyber.decaps(_bad_ct, sk)
self.assertNotEqual(key, _bad)

def test_kyber512(self):
self.generic_test_kyber(Kyber512, 5)

Expand All @@ -46,6 +53,12 @@ def test_kyber768(self):
def test_kyber1024(self):
self.generic_test_kyber(Kyber1024, 5)

def test_xof_failure(self):
self.assertRaises(ValueError, lambda: Kyber512._xof(b"1", b"2", b"3"))

def test_prf_failure(self):
self.assertRaises(ValueError, lambda: Kyber512._prf(b"1", b"2", 32))


class TestKyberDeterministic(unittest.TestCase):
"""
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ml_kem.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ def test_ML_KEM_768(self):
def test_ML_KEM_1024(self):
self.generic_test_ML_KEM(ML_KEM_1024, 5)

def test_encaps_type_check_failure(self):
"""
Send an ecaps key of the wrong length
"""
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(b"1"))

def test_encaps_modulus_check_failure(self):
"""
We create a vector of polynomials with non-canonical values for
coefficents to fail the modulus check
"""
(ek, _) = ML_KEM_512.keygen()
rho = ek[-32:]

bad_f_hat = ML_KEM_512.R([3329] * 256)
bad_t_hat = ML_KEM_512.M.vector([bad_f_hat, bad_f_hat])
bad_t_hat_bytes = bad_t_hat.encode(12)

bad_ek = bad_t_hat_bytes + rho

self.assertEqual(len(bad_ek), len(ek))
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(bad_ek))

def test_xof_failure(self):
self.assertRaises(
ValueError, lambda: ML_KEM_512._xof(b"1", b"2", b"3")
)

def test_prf_failure(self):
self.assertRaises(ValueError, lambda: ML_KEM_512._prf(2, b"1", b"2"))


# As there are 1000 KATs in the file, execution of all of them takes
# a lot of time, run just 100
Expand Down