Skip to content

MLDsaOpenSsl + tests #114485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,40 @@ internal static partial class Interop
{
internal static partial class Crypto
{
// Must be kept in sync with PalMLDsaId in native shim.
internal enum PalMLDsaAlgorithmId
{
Unknown = 0,
MLDsa44 = 1,
MLDsa65 = 2,
MLDsa87 = 3,
}

[LibraryImport(Libraries.CryptoNative)]
private static partial int CryptoNative_MLDsaGetPalId(
SafeEvpPKeyHandle mldsa,
out PalMLDsaAlgorithmId mldsaId);

internal static PalMLDsaAlgorithmId MLDsaGetPalId(SafeEvpPKeyHandle key)
{
const int Success = 1;
const int Fail = 0;
int result = CryptoNative_MLDsaGetPalId(key, out PalMLDsaAlgorithmId mldsaId);

return result switch
{
Success => mldsaId,
Fail => throw CreateOpenSslCryptographicException(),
int other => throw FailThrow(other),
};

static Exception FailThrow(int result)
{
Debug.Fail($"Unexpected return value {result} from {nameof(CryptoNative_MLDsaGetPalId)}.");
return new CryptographicException();
}
}

[LibraryImport(Libraries.CryptoNative, StringMarshalling = StringMarshalling.Utf8)]
private static partial SafeEvpPKeyHandle CryptoNative_MLDsaGenerateKey(string keyType, ReadOnlySpan<byte> seed, int seedLength);

Expand Down Expand Up @@ -80,7 +114,7 @@ internal static void MLDsaSignPure(
Span<byte> destination)
{
int ret = CryptoNative_MLDsaSignPure(
pkey, pkey.ExtraHandle,
pkey, GetExtraHandle(pkey),
msg, msg.Length,
context, context.Length,
destination, destination.Length);
Expand All @@ -105,7 +139,7 @@ internal static bool MLDsaVerifyPure(
ReadOnlySpan<byte> signature)
{
int ret = CryptoNative_MLDsaVerifyPure(
pkey, pkey.ExtraHandle,
pkey, GetExtraHandle(pkey),
msg, msg.Length,
context, context.Length,
signature, signature.Length);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Security.Cryptography.Dsa.Tests;
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.DotNet.XUnitExtensions;
using Test.Cryptography;
using Xunit;

namespace System.Security.Cryptography.Tests
{
[ConditionalClass(typeof(MLDsa), nameof(MLDsa.IsSupported))]
public class MLDsaImplementationTests : MLDsaTestsBase
{
protected override MLDsa GenerateKey(MLDsaAlgorithm algorithm) => MLDsa.GenerateKey(algorithm);
protected override MLDsa ImportPrivateSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> seed) => MLDsa.ImportMLDsaPrivateSeed(algorithm, seed);
protected override MLDsa ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) => MLDsa.ImportMLDsaSecretKey(algorithm, source);
protected override MLDsa ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) => MLDsa.ImportMLDsaPublicKey(algorithm, source);

[Fact]
public static void GenerateImport_NullAlgorithm()
{
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.GenerateKey(null));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPrivateSeed(null, default));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaPublicKey(null, default));
AssertExtensions.Throws<ArgumentNullException>("algorithm", static () => MLDsa.ImportMLDsaSecretKey(null, default));
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public static void ImportMLDsaSecretKey_WrongSize(MLDsaAlgorithm algorithm)
{
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaSecretKey(algorithm, new byte[algorithm.SecretKeySizeInBytes - 1]));
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaSecretKey(algorithm, new byte[algorithm.SecretKeySizeInBytes + 1]));
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaSecretKey(algorithm, default));
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public static void ImportMLDsaPrivateSeed_WrongSize(MLDsaAlgorithm algorithm)
{
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaPrivateSeed(algorithm, new byte[algorithm.PrivateSeedSizeInBytes - 1]));
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaPrivateSeed(algorithm, new byte[algorithm.PrivateSeedSizeInBytes + 1]));
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaPrivateSeed(algorithm, default));
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public static void ImportMLDsaPublicKey_WrongSize(MLDsaAlgorithm algorithm)
{
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaPublicKey(algorithm, new byte[algorithm.PublicKeySizeInBytes - 1]));
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaPublicKey(algorithm, new byte[algorithm.PublicKeySizeInBytes + 1]));
AssertExtensions.Throws<ArgumentException>("source", () => MLDsa.ImportMLDsaPublicKey(algorithm, default));
}

[Fact]
public static void UseAfterDispose()
{
MLDsa mldsa = MLDsa.GenerateKey(MLDsaAlgorithm.MLDsa44);
mldsa.Dispose();
mldsa.Dispose(); // no throw

VerifyDisposed(mldsa);
}
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Microsoft.DotNet.XUnitExtensions;
using Test.Cryptography;
using Xunit;

namespace System.Security.Cryptography.Tests
{
[ConditionalClass(typeof(MLDsa), nameof(MLDsa.IsSupported))]
public abstract class MLDsaTestsBase
{
protected abstract MLDsa GenerateKey(MLDsaAlgorithm algorithm);
protected abstract MLDsa ImportPrivateSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> seed);
protected abstract MLDsa ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
protected abstract MLDsa ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public void AlgorithmIsAssigned(MLDsaAlgorithm algorithm)
{
using MLDsa mldsa = GenerateKey(algorithm);
Assert.Same(algorithm, mldsa.Algorithm);
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public void GenerateSignVerifyNoContext(MLDsaAlgorithm algorithm)
{
using MLDsa mldsa = GenerateKey(algorithm);
byte[] data = [ 1, 2, 3, 4, 5 ];
byte[] signature = new byte[mldsa.Algorithm.SignatureSizeInBytes];
Assert.Equal(signature.Length, mldsa.SignData(data, signature));

ExerciseSuccessfulVerify(mldsa, data, signature, []);
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public void GenerateSignVerifyWithContext(MLDsaAlgorithm algorithm)
{
using MLDsa mldsa = GenerateKey(algorithm);
byte[] context = [ 1, 1, 3, 5, 6 ];
byte[] data = [ 1, 2, 3, 4, 5 ];
byte[] signature = new byte[mldsa.Algorithm.SignatureSizeInBytes];
Assert.Equal(signature.Length, mldsa.SignData(data, signature, context));

ExerciseSuccessfulVerify(mldsa, data, signature, context);
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public void GenerateSignExportPublicVerifyWithPublicOnly(MLDsaAlgorithm algorithm)
{
byte[] publicKey;
byte[] data = [ 1, 2, 3, 4, 5 ];
byte[] signature;

using (MLDsa mldsa = GenerateKey(algorithm))
{
signature = new byte[algorithm.SignatureSizeInBytes];
Assert.Equal(signature.Length, mldsa.SignData(data, signature));
AssertExtensions.TrueExpression(mldsa.VerifyData(data, signature));

publicKey = new byte[algorithm.PublicKeySizeInBytes];
Assert.Equal(publicKey.Length, mldsa.ExportMLDsaPublicKey(publicKey));
}

using (MLDsa mldsaPub = ImportPublicKey(algorithm, publicKey))
{
ExerciseSuccessfulVerify(mldsaPub, data, signature, []);
}
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public void GenerateExportSecretKeySignAndVerify(MLDsaAlgorithm algorithm)
{
byte[] secretKey;
byte[] data = [ 1, 2, 3, 4, 5 ];
byte[] signature;

using (MLDsa mldsaTmp = GenerateKey(algorithm))
{
signature = new byte[algorithm.SignatureSizeInBytes];
Assert.Equal(signature.Length, mldsaTmp.SignData(data, signature));

secretKey = new byte[algorithm.SecretKeySizeInBytes];
Assert.Equal(secretKey.Length, mldsaTmp.ExportMLDsaSecretKey(secretKey));
}

using (MLDsa mldsa = ImportSecretKey(algorithm, secretKey))
{
AssertExtensions.TrueExpression(mldsa.VerifyData(data, signature));

signature.AsSpan().Fill(0);
Assert.Equal(signature.Length, mldsa.SignData(data, signature));

AssertExtensions.TrueExpression(mldsa.VerifyData(data, signature));
data[0] ^= 1;
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature));
}
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllMLDsaAlgorithms), MemberType = typeof(MLDsaTestsData))]
public void GenerateExportPrivateSeedSignAndVerify(MLDsaAlgorithm algorithm)
{
byte[] privateSeed;
byte[] data = [ 1, 2, 3, 4, 5 ];
byte[] signature;

using (MLDsa mldsaTmp = GenerateKey(algorithm))
{
signature = new byte[algorithm.SignatureSizeInBytes];
Assert.Equal(signature.Length, mldsaTmp.SignData(data, signature));

privateSeed = new byte[algorithm.PrivateSeedSizeInBytes];
Assert.Equal(privateSeed.Length, mldsaTmp.ExportMLDsaPrivateSeed(privateSeed));
}

using (MLDsa mldsa = ImportPrivateSeed(algorithm, privateSeed))
{
AssertExtensions.TrueExpression(mldsa.VerifyData(data, signature));

signature.AsSpan().Fill(0);
Assert.Equal(signature.Length, mldsa.SignData(data, signature));

ExerciseSuccessfulVerify(mldsa, data, signature, []);
}
}

[Fact]
public void ImportSecretKey_CannotReconstructSeed()
{
byte[] secretKey = new byte[MLDsaAlgorithm.MLDsa44.SecretKeySizeInBytes];
using (MLDsa mldsaOriginal = GenerateKey(MLDsaAlgorithm.MLDsa44))
{
Assert.Equal(secretKey.Length, mldsaOriginal.ExportMLDsaSecretKey(secretKey));
}

using (MLDsa mldsa = ImportSecretKey(MLDsaAlgorithm.MLDsa44, secretKey))
{
Assert.Throws<CryptographicException>(() => mldsa.ExportMLDsaPrivateSeed(new byte[MLDsaAlgorithm.MLDsa44.PrivateSeedSizeInBytes]));
}
}

[Fact]
public void ImportSeed_CanReconstructSecretKey()
{
byte[] secretKey = new byte[MLDsaAlgorithm.MLDsa44.SecretKeySizeInBytes];
byte[] seed = new byte[MLDsaAlgorithm.MLDsa44.PrivateSeedSizeInBytes];
using (MLDsa mldsaOriginal = GenerateKey(MLDsaAlgorithm.MLDsa44))
{
Assert.Equal(secretKey.Length, mldsaOriginal.ExportMLDsaSecretKey(secretKey));
Assert.Equal(seed.Length, mldsaOriginal.ExportMLDsaPrivateSeed(seed));
}

using (MLDsa mldsa = ImportPrivateSeed(MLDsaAlgorithm.MLDsa44, seed))
{
byte[] secretKey2 = new byte[MLDsaAlgorithm.MLDsa44.SecretKeySizeInBytes];
byte[] seed2 = new byte[MLDsaAlgorithm.MLDsa44.PrivateSeedSizeInBytes];

Assert.Equal(secretKey2.Length, mldsa.ExportMLDsaSecretKey(secretKey2));
Assert.Equal(seed2.Length, mldsa.ExportMLDsaPrivateSeed(seed2));

AssertExtensions.SequenceEqual(secretKey, secretKey2);
AssertExtensions.SequenceEqual(seed, seed2);
}
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllNistTestCases), MemberType = typeof(MLDsaTestsData))]
public void NistImportPublicKeyVerify(MLDsaNistTestCase testCase)
{
using MLDsa mldsa = ImportPublicKey(testCase.Algorithm, testCase.PublicKey);
Assert.Equal(testCase.ShouldPass, mldsa.VerifyData(testCase.Message, testCase.Signature, testCase.Context));
}

[Theory]
[MemberData(nameof(MLDsaTestsData.AllNistTestCases), MemberType = typeof(MLDsaTestsData))]
public void NistImportSecretKeyVerifyExportsAndSignature(MLDsaNistTestCase testCase)
{
using MLDsa mldsa = ImportSecretKey(testCase.Algorithm, testCase.SecretKey);

byte[] pubKey = new byte[testCase.Algorithm.PublicKeySizeInBytes];
Assert.Equal(pubKey.Length, mldsa.ExportMLDsaPublicKey(pubKey));
AssertExtensions.SequenceEqual(testCase.PublicKey, pubKey);

byte[] secretKey = new byte[testCase.Algorithm.SecretKeySizeInBytes];
Assert.Equal(secretKey.Length, mldsa.ExportMLDsaSecretKey(secretKey));

byte[] seed = new byte[testCase.Algorithm.PrivateSeedSizeInBytes];
Assert.Throws<CryptographicException>(() => mldsa.ExportMLDsaPrivateSeed(seed));

Assert.Equal(testCase.ShouldPass, mldsa.VerifyData(testCase.Message, testCase.Signature, testCase.Context));
}

protected static void ExerciseSuccessfulVerify(MLDsa mldsa, byte[] data, byte[] signature, byte[] context)
{
AssertExtensions.TrueExpression(mldsa.VerifyData(data, signature, context));
data[0] ^= 1;
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature, context));
data[0] ^= 1;

signature[0] ^= 1;
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature, context));
signature[0] ^= 1;

if (context.Length > 0)
{
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature, []));

context[0] ^= 1;
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature, context));
context[0] ^= 1;
}
else
{
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature, [0]));
AssertExtensions.FalseExpression(mldsa.VerifyData(data, signature, [1, 2, 3]));
}

AssertExtensions.TrueExpression(mldsa.VerifyData(data, signature, context));
}

protected static void VerifyDisposed(MLDsa mldsa)
{
PbeParameters pbeParams = new PbeParameters(PbeEncryptionAlgorithm.Aes128Cbc, HashAlgorithmName.SHA256, 10);

Assert.Throws<ObjectDisposedException>(() => mldsa.SignData([], new byte[mldsa.Algorithm.SignatureSizeInBytes]));
Assert.Throws<ObjectDisposedException>(() => mldsa.VerifyData([], new byte[mldsa.Algorithm.SignatureSizeInBytes]));

Assert.Throws<ObjectDisposedException>(() => mldsa.ExportMLDsaPrivateSeed(new byte[mldsa.Algorithm.PrivateSeedSizeInBytes]));
Assert.Throws<ObjectDisposedException>(() => mldsa.ExportMLDsaPublicKey(new byte[mldsa.Algorithm.PublicKeySizeInBytes]));
Assert.Throws<ObjectDisposedException>(() => mldsa.ExportMLDsaSecretKey(new byte[mldsa.Algorithm.SecretKeySizeInBytes]));

Assert.Throws<ObjectDisposedException>(() => mldsa.ExportPkcs8PrivateKey());
Assert.Throws<ObjectDisposedException>(() => mldsa.TryExportPkcs8PrivateKey(new byte[10000], out _));
Assert.Throws<ObjectDisposedException>(() => mldsa.ExportPkcs8PrivateKeyPem());

Assert.Throws<ObjectDisposedException>(() => mldsa.ExportEncryptedPkcs8PrivateKey([1, 2, 3], pbeParams));
Assert.Throws<ObjectDisposedException>(() => mldsa.ExportEncryptedPkcs8PrivateKey("123", pbeParams));
Assert.Throws<ObjectDisposedException>(() => mldsa.TryExportEncryptedPkcs8PrivateKey([1, 2, 3], pbeParams, new byte[10000], out _));
Assert.Throws<ObjectDisposedException>(() => mldsa.TryExportEncryptedPkcs8PrivateKey("123", pbeParams, new byte[10000], out _));

Assert.Throws<ObjectDisposedException>(() => mldsa.ExportEncryptedPkcs8PrivateKeyPem([1, 2, 3], pbeParams));
Assert.Throws<ObjectDisposedException>(() => mldsa.ExportEncryptedPkcs8PrivateKeyPem("123", pbeParams));

Assert.Throws<ObjectDisposedException>(() => mldsa.ExportSubjectPublicKeyInfo());
Assert.Throws<ObjectDisposedException>(() => mldsa.TryExportSubjectPublicKeyInfo(new byte[10000], out _));
Assert.Throws<ObjectDisposedException>(() => mldsa.ExportSubjectPublicKeyInfoPem());
}
}
}
Loading
Loading