Skip to content

Make CertificateRequest et al work with ML-DSA #114471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/libraries/Common/src/System/Security/Cryptography/Helpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,48 @@ internal static int HashOidToByteLength(string hashOid)
};
}

internal static bool HashAlgorithmRequired(string? keyAlgorithm)
{
// This list could either be written as "ML-DSA and friends return false",
// or "RSA and friends return true".
//
// The consequences of returning true is that the hashAlgorithm parameter
// gets pre-validated to not be null or empty, which means false positives
// impact new ML-DSA-like algorithms.
//
// The consequences of returning false is that the hashAlgorithm parameter
// is not pre-validated. That just means that in a false negative the user
// gets probably the same exception, but from a different callstack.
//
// False positives or negatives are not possible with the simple Build that takes
// only an X509Certificate2, as we control the destiny there entirely, it's only
// for the power user scenario of the X509SignatureGenerator that this is a concern.
//
// Since the false-positive is worse than the false-negative, the list is written
// as explicit-true, implicit-false.
return keyAlgorithm switch
{
Oids.Rsa or
Oids.RsaPss or
Oids.EcPublicKey or
Oids.Dsa => true,
_ => false,
};
}

internal static CryptographicException CreateAlgorithmUnknownException(AsnWriter encodedId)
{
#if NET10_0_OR_GREATER
return encodedId.Encode(static encoded =>
new CryptographicException(
SR.Format(SR.Cryptography_UnknownAlgorithmIdentifier, Convert.ToHexString(encoded))));
return encodedId.Encode(static encoded => CreateAlgorithmUnknownException(Convert.ToHexString(encoded)));
#else
return new CryptographicException(
SR.Format(SR.Cryptography_UnknownAlgorithmIdentifier,
HexConverter.ToString(encodedId.Encode(), HexConverter.Casing.Upper)));
return CreateAlgorithmUnknownException(HexConverter.ToString(encodedId.Encode(), HexConverter.Casing.Upper));
#endif
}

internal static CryptographicException CreateAlgorithmUnknownException(string algorithmId)
{
throw new CryptographicException(
SR.Format(SR.Cryptography_UnknownAlgorithmIdentifier, algorithmId));
}
}
}
14 changes: 12 additions & 2 deletions src/libraries/Common/src/System/Security/Cryptography/MLDsa.cs
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,12 @@ public static MLDsa ImportSubjectPublicKeyInfo(ReadOnlySpan<byte> source)
AsnValueReader reader = new AsnValueReader(source, AsnEncodingRules.DER);
SubjectPublicKeyInfoAsn.Decode(ref reader, manager.Memory, out SubjectPublicKeyInfoAsn spki);

MLDsaAlgorithm algorithm = MLDsaAlgorithm.GetMLDsaAlgorithmFromOid(spki.Algorithm.Algorithm);
MLDsaAlgorithm? algorithm = MLDsaAlgorithm.GetMLDsaAlgorithmFromOid(spki.Algorithm.Algorithm);

if (algorithm is null)
{
throw Helpers.CreateAlgorithmUnknownException(spki.Algorithm.Algorithm);
}

if (spki.Algorithm.Parameters.HasValue)
{
Expand Down Expand Up @@ -803,7 +808,12 @@ public static MLDsa ImportPkcs8PrivateKey(ReadOnlySpan<byte> source)
AsnValueReader reader = new AsnValueReader(source, AsnEncodingRules.DER);
PrivateKeyInfoAsn.Decode(ref reader, manager.Memory, out PrivateKeyInfoAsn pki);

MLDsaAlgorithm algorithm = MLDsaAlgorithm.GetMLDsaAlgorithmFromOid(pki.PrivateKeyAlgorithm.Algorithm);
MLDsaAlgorithm? algorithm = MLDsaAlgorithm.GetMLDsaAlgorithmFromOid(pki.PrivateKeyAlgorithm.Algorithm);

if (algorithm is null)
{
throw Helpers.CreateAlgorithmUnknownException(pki.PrivateKeyAlgorithm.Algorithm);
}

if (pki.PrivateKeyAlgorithm.Parameters.HasValue)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,22 +107,15 @@ private MLDsaAlgorithm(string name, int secretKeySizeInBytes, int publicKeySizeI
/// </value>
public static MLDsaAlgorithm MLDsa87 { get; } = new MLDsaAlgorithm("ML-DSA-87", 4896, 2592, 4627, Oids.MLDsa87);

internal static MLDsaAlgorithm GetMLDsaAlgorithmFromOid(string oid)
internal static MLDsaAlgorithm? GetMLDsaAlgorithmFromOid(string? oid)
{
return oid switch
{
Oids.MLDsa44 => MLDsa44,
Oids.MLDsa65 => MLDsa65,
Oids.MLDsa87 => MLDsa87,
_ => ThrowAlgorithmUnknown(oid),
_ => null,
};
}

[DoesNotReturn]
private static MLDsaAlgorithm ThrowAlgorithmUnknown(string algorithmId)
{
throw new CryptographicException(
SR.Format(SR.Cryptography_UnknownAlgorithmIdentifier, algorithmId));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ protected override void ExportMLDsaSecretKeyCore(Span<byte> destination) =>
protected override void ExportMLDsaPrivateSeedCore(Span<byte> destination) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa GenerateKeyImpl(MLDsaAlgorithm algorithm) =>
internal static partial MLDsaImplementation GenerateKeyImpl(MLDsaAlgorithm algorithm) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportPkcs8PrivateKeyValue(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportPkcs8PrivateKeyValue(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ protected override void ExportMLDsaSecretKeyCore(Span<byte> destination) =>
protected override void ExportMLDsaPrivateSeedCore(Span<byte> destination) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa GenerateKeyImpl(MLDsaAlgorithm algorithm) =>
internal static partial MLDsaImplementation GenerateKeyImpl(MLDsaAlgorithm algorithm) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportPkcs8PrivateKeyValue(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportPkcs8PrivateKeyValue(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();

internal static partial MLDsa ImportSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
internal static partial MLDsaImplementation ImportSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source) =>
throw new PlatformNotSupportedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,41 @@ private MLDsaImplementation(MLDsaAlgorithm algorithm)

internal static partial bool SupportsAny();

internal static partial MLDsa GenerateKeyImpl(MLDsaAlgorithm algorithm);
internal static partial MLDsa ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsa ImportPkcs8PrivateKeyValue(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsa ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsa ImportSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsaImplementation GenerateKeyImpl(MLDsaAlgorithm algorithm);
internal static partial MLDsaImplementation ImportPublicKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsaImplementation ImportPkcs8PrivateKeyValue(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsaImplementation ImportSecretKey(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);
internal static partial MLDsaImplementation ImportSeed(MLDsaAlgorithm algorithm, ReadOnlySpan<byte> source);

/// <summary>
/// Duplicates an ML-DSA private key by export/import.
/// Only intended to be used when the key type is unknown.
/// </summary>
internal static MLDsaImplementation DuplicatePrivateKey(MLDsa key)
{
// The implementation type and any platform types (e.g. MLDsaOpenSsl)
// should inherently know how to clone themselves without the crudeness
// of export/import.
Debug.Assert(key is not MLDsaImplementation);

MLDsaAlgorithm alg = key.Algorithm;
byte[] rented = CryptoPool.Rent(alg.SecretKeySizeInBytes);
int written = 0;

try
{
written = key.ExportMLDsaPrivateSeed(rented);
return ImportSeed(alg, new ReadOnlySpan<byte>(rented, 0, written));
}
catch (CryptographicException)
{
written = key.ExportMLDsaSecretKey(rented);
return ImportSecretKey(alg, new ReadOnlySpan<byte>(rented, 0, written));
}
finally
{
CryptoPool.Return(rented, written);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using Xunit;

namespace System.Security.Cryptography.Tests
{
internal sealed class MLDsaTestImplementation : MLDsa
{
internal delegate void ExportAction(Span<byte> destination);
internal delegate void SignAction(ReadOnlySpan<byte> data, ReadOnlySpan<byte> context, Span<byte> destination);
internal delegate bool VerifyFunc(ReadOnlySpan<byte> data, ReadOnlySpan<byte> context, ReadOnlySpan<byte> signature);

internal ExportAction ExportMLDsaPrivateSeedHook { get; set; }
internal ExportAction ExportMLDsaPublicKeyHook { get; set; }
internal ExportAction ExportMLDsaSecretKeyHook { get; set; }
internal SignAction SignDataHook { get; set; }
internal VerifyFunc VerifyDataHook { get; set; }
internal Action<bool> DisposeHook { get; set; } = _ => { };

private MLDsaTestImplementation(MLDsaAlgorithm algorithm) : base(algorithm)
{
}

protected override void Dispose(bool disposing) => DisposeHook(disposing);

protected override void ExportMLDsaPrivateSeedCore(Span<byte> destination) => ExportMLDsaPrivateSeedHook(destination);
protected override void ExportMLDsaPublicKeyCore(Span<byte> destination) => ExportMLDsaPublicKeyHook(destination);
protected override void ExportMLDsaSecretKeyCore(Span<byte> destination) => ExportMLDsaSecretKeyHook(destination);

protected override void SignDataCore(ReadOnlySpan<byte> data, ReadOnlySpan<byte> context, Span<byte> destination) =>
SignDataHook(data, context, destination);

protected override bool VerifyDataCore(ReadOnlySpan<byte> data, ReadOnlySpan<byte> context, ReadOnlySpan<byte> signature) =>
VerifyDataHook(data, context, signature);

internal static MLDsaTestImplementation CreateOverriddenCoreMethodsFail(MLDsaAlgorithm algorithm)
{
return new MLDsaTestImplementation(algorithm)
{
ExportMLDsaPrivateSeedHook = _ => Assert.Fail(),
ExportMLDsaPublicKeyHook = _ => Assert.Fail(),
ExportMLDsaSecretKeyHook = _ => Assert.Fail(),
SignDataHook = (_, _, _) => Assert.Fail(),
VerifyDataHook = (_, _, _) => { Assert.Fail(); return false; },
};
}

internal static MLDsaTestImplementation CreateNoOp(MLDsaAlgorithm algorithm)
{
return new MLDsaTestImplementation(algorithm)
{
ExportMLDsaPrivateSeedHook = d => d.Clear(),
ExportMLDsaPublicKeyHook = d => d.Clear(),
ExportMLDsaSecretKeyHook = d => d.Clear(),
SignDataHook = (data, context, destination) => destination.Clear(),
VerifyDataHook = (data, context, signature) => signature.IndexOfAnyExcept((byte)0) == -1,
};
}

internal static MLDsaTestImplementation Wrap(MLDsa other)
{
return new MLDsaTestImplementation(other.Algorithm)
{
ExportMLDsaPrivateSeedHook = d => other.ExportMLDsaPrivateSeed(d),
ExportMLDsaPublicKeyHook = d => other.ExportMLDsaPublicKey(d),
ExportMLDsaSecretKeyHook = d => other.ExportMLDsaSecretKey(d),
SignDataHook = (data, context, destination) => other.SignData(data, destination, context),
VerifyDataHook = (data, context, signature) => other.VerifyData(data, signature, context),
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
using System.Runtime.InteropServices;
using Xunit;

// PQC types are used throughout, but only when the caller requests them.
#pragma warning disable SYSLIB5006

namespace System.Security.Cryptography.X509Certificates.Tests.Common
{
// This class represents only a portion of what is required to be a proper Certificate Authority.
Expand Down Expand Up @@ -994,6 +997,7 @@ internal static X509Certificate2 CloneWithPrivateKey(X509Certificate2 cert, obje
{
RSA rsa => cert.CopyWithPrivateKey(rsa),
ECDsa ecdsa => cert.CopyWithPrivateKey(ecdsa),
MLDsa mldsa => cert.CopyWithPrivateKey(mldsa),
DSA dsa => cert.CopyWithPrivateKey(dsa),
_ => throw new InvalidOperationException(
$"Had no handler for key of type {key?.GetType().FullName ?? "null"}")
Expand All @@ -1008,6 +1012,9 @@ internal sealed class KeyFactory
internal static KeyFactory ECDsa { get; } =
new(() => Cryptography.ECDsa.Create(ECCurve.NamedCurves.nistP384));

internal static KeyFactory MLDsa { get; } =
new(() => Cryptography.MLDsa.GenerateKey(MLDsaAlgorithm.MLDsa65));

private Func<IDisposable> _factory;

private KeyFactory(Func<IDisposable> factory)
Expand Down Expand Up @@ -1047,6 +1054,7 @@ internal KeyHolder(X509Certificate2 cert)
_key =
cert.GetRSAPrivateKey() ??
cert.GetECDsaPrivateKey() ??
cert.GetMLDsaPrivateKey() ??
(IDisposable)cert.GetDSAPrivateKey() ??
throw new NotSupportedException();
}
Expand All @@ -1067,6 +1075,7 @@ internal CertificateRequest CreateRequest(string subject)
{
RSA rsa => new CertificateRequest(subject, rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1),
ECDsa ecdsa => new CertificateRequest(subject, ecdsa, HashAlgorithmName.SHA256),
MLDsa mldsa => new CertificateRequest(subject, mldsa),
_ => throw new NotSupportedException(),
};
}
Expand All @@ -1077,6 +1086,7 @@ internal X509SignatureGenerator GetGenerator()
{
RSA rsa => X509SignatureGenerator.CreateForRSA(rsa, RSASignaturePadding.Pkcs1),
ECDsa ecdsa => X509SignatureGenerator.CreateForECDsa(ecdsa),
MLDsa mldsa => X509SignatureGenerator.CreateForMLDsa(mldsa),
_ => throw new NotSupportedException(),
};
}
Expand Down
Loading
Loading