Skip to content

Commit dba6cd8

Browse files
authored
perf: get rid of MemoryStream in KeyRingBasedDataProtector (#59322)
1 parent 08f5e1a commit dba6cd8

File tree

5 files changed

+188
-38
lines changed

5 files changed

+188
-38
lines changed

src/DataProtection/DataProtection/src/KeyManagement/KeyRingBasedDataProtector.cs

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System;
5+
using System.Buffers.Binary;
6+
using System.Buffers;
57
using System.Collections.Generic;
68
using System.Diagnostics;
79
using System.Diagnostics.CodeAnalysis;
@@ -14,6 +16,8 @@
1416
using Microsoft.AspNetCore.DataProtection.KeyManagement.Internal;
1517
using Microsoft.AspNetCore.Shared;
1618
using Microsoft.Extensions.Logging;
19+
using System.Buffers.Text;
20+
using Microsoft.AspNetCore.DataProtection.Internal;
1721

1822
namespace Microsoft.AspNetCore.DataProtection.KeyManagement;
1923

@@ -313,39 +317,13 @@ private static void WriteBigEndianInteger(byte* ptr, uint value)
313317
ptr[3] = (byte)(value);
314318
}
315319

316-
private struct AdditionalAuthenticatedDataTemplate
320+
internal struct AdditionalAuthenticatedDataTemplate
317321
{
318322
private byte[] _aadTemplate;
319323

320-
public AdditionalAuthenticatedDataTemplate(IEnumerable<string> purposes)
324+
public AdditionalAuthenticatedDataTemplate(string[] purposes)
321325
{
322-
const int MEMORYSTREAM_DEFAULT_CAPACITY = 0x100; // matches MemoryStream.EnsureCapacity
323-
var ms = new MemoryStream(MEMORYSTREAM_DEFAULT_CAPACITY);
324-
325-
// additionalAuthenticatedData := { magicHeader (32-bit) || keyId || purposeCount (32-bit) || (purpose)* }
326-
// purpose := { utf8ByteCount (7-bit encoded) || utf8Text }
327-
328-
using (var writer = new PurposeBinaryWriter(ms))
329-
{
330-
writer.WriteBigEndian(MAGIC_HEADER_V0);
331-
Debug.Assert(ms.Position == sizeof(uint));
332-
var posPurposeCount = writer.Seek(sizeof(Guid), SeekOrigin.Current); // skip over where the key id will be stored; we'll fill it in later
333-
writer.Seek(sizeof(uint), SeekOrigin.Current); // skip over where the purposeCount will be stored; we'll fill it in later
334-
335-
uint purposeCount = 0;
336-
foreach (string purpose in purposes)
337-
{
338-
Debug.Assert(purpose != null);
339-
writer.Write(purpose); // prepends length as a 7-bit encoded integer
340-
purposeCount++;
341-
}
342-
343-
// Once we have written all the purposes, go back and fill in 'purposeCount'
344-
writer.Seek(checked((int)posPurposeCount), SeekOrigin.Begin);
345-
writer.WriteBigEndian(purposeCount);
346-
}
347-
348-
_aadTemplate = ms.ToArray();
326+
_aadTemplate = BuildAadTemplateBytes(purposes);
349327
}
350328

351329
public byte[] GetAadForKey(Guid keyId, bool isProtecting)
@@ -381,19 +359,57 @@ public byte[] GetAadForKey(Guid keyId, bool isProtecting)
381359
}
382360
}
383361

384-
private sealed class PurposeBinaryWriter : BinaryWriter
362+
internal static byte[] BuildAadTemplateBytes(string[] purposes)
385363
{
386-
public PurposeBinaryWriter(MemoryStream stream) : base(stream, EncodingUtil.SecureUtf8Encoding, leaveOpen: true) { }
364+
// additionalAuthenticatedData := { magicHeader (32-bit) || keyId || purposeCount (32-bit) || (purpose)* }
365+
// purpose := { utf8ByteCount (7-bit encoded) || utf8Text }
366+
367+
var keySize = sizeof(Guid);
368+
int totalPurposeLen = 4 + keySize + 4;
387369

388-
// Writes a big-endian 32-bit integer to the underlying stream.
389-
public void WriteBigEndian(uint value)
370+
int[]? lease = null;
371+
var targetLength = purposes.Length;
372+
Span<int> purposeLengthsPool = targetLength <= 32 ? stackalloc int[targetLength] : (lease = ArrayPool<int>.Shared.Rent(targetLength)).AsSpan(0, targetLength);
373+
for (int i = 0; i < targetLength; i++)
390374
{
391-
var outStream = BaseStream; // property accessor also performs a flush
392-
outStream.WriteByte((byte)(value >> 24));
393-
outStream.WriteByte((byte)(value >> 16));
394-
outStream.WriteByte((byte)(value >> 8));
395-
outStream.WriteByte((byte)(value));
375+
string purpose = purposes[i];
376+
377+
int purposeLength = EncodingUtil.SecureUtf8Encoding.GetByteCount(purpose);
378+
purposeLengthsPool[i] = purposeLength;
379+
380+
var encoded7BitUIntLength = purposeLength.Measure7BitEncodedUIntLength();
381+
totalPurposeLen += purposeLength /* length of actual string */ + encoded7BitUIntLength /* length of 'string length' 7-bit encoded int */;
396382
}
383+
384+
byte[] targetArr = new byte[totalPurposeLen];
385+
var targetSpan = targetArr.AsSpan();
386+
387+
// index 0: magic header
388+
BinaryPrimitives.WriteUInt32BigEndian(targetSpan.Slice(0), MAGIC_HEADER_V0);
389+
// index 4: key (skipped for now, will be populated in `GetAadForKey()`)
390+
// index 4 + keySize: purposeCount
391+
BinaryPrimitives.WriteInt32BigEndian(targetSpan.Slice(4 + keySize), targetLength);
392+
393+
int index = 4 /* MAGIC_HEADER_V0 */ + keySize + 4 /* purposeLength */; // starting from first purpose
394+
for (int i = 0; i < targetLength; i++)
395+
{
396+
string purpose = purposes[i];
397+
398+
// writing `utf8ByteCount (7-bit encoded integer)`
399+
// we have already calculated the lengths of the purpose strings, so just get it from the pool
400+
index += targetSpan.Slice(index).Write7BitEncodedInt(purposeLengthsPool[i]);
401+
402+
// write the utf8text for the purpose
403+
index += EncodingUtil.SecureUtf8Encoding.GetBytes(purpose, charIndex: 0, charCount: purpose.Length, bytes: targetArr, byteIndex: index);
404+
}
405+
406+
if (lease is not null)
407+
{
408+
ArrayPool<int>.Shared.Return(lease);
409+
}
410+
Debug.Assert(index == targetArr.Length);
411+
412+
return targetArr;
397413
}
398414
}
399415

src/DataProtection/DataProtection/src/Microsoft.AspNetCore.DataProtection.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
<Compile Include="$(SharedSourceRoot)TrimmingAttributes.cs" LinkBase="Shared" />
2222
<Compile Include="$(SharedSourceRoot)ThrowHelpers\ArgumentNullThrowHelper.cs" LinkBase="Shared" />
2323
<Compile Include="$(SharedSourceRoot)CallerArgument\CallerArgumentExpressionAttribute.cs" LinkBase="Shared" />
24+
<Compile Include="$(SharedSourceRoot)Encoding\Int7BitEncodingUtils.cs" LinkBase="Shared" />
2425
</ItemGroup>
2526

2627
<ItemGroup>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Text;
7+
using Microsoft.AspNetCore.DataProtection.Internal;
8+
using Microsoft.AspNetCore.Shared;
9+
10+
namespace Microsoft.AspNetCore.DataProtection.Tests.Internal;
11+
12+
public class Int7BitEncodingUtilsTests
13+
{
14+
[Theory]
15+
[InlineData(0, 1)]
16+
[InlineData(1, 1)]
17+
[InlineData(0b0_1111111, 1)]
18+
[InlineData(0b1_0000000, 2)]
19+
[InlineData(0b1111111_1111111, 2)]
20+
[InlineData(0b1_0000000_0000000, 3)]
21+
[InlineData(0b1111111_1111111_1111111, 3)]
22+
[InlineData(0b1_0000000_0000000_0000000, 4)]
23+
[InlineData(0b1111111_1111111_1111111_1111111, 4)]
24+
[InlineData(0b1_0000000_0000000_0000000_0000000, 5)]
25+
[InlineData(uint.MaxValue, 5)]
26+
public void Measure7BitEncodedUIntLength_ReturnsExceptedLength(uint value, int expectedSize)
27+
{
28+
var actualSize = value.Measure7BitEncodedUIntLength();
29+
Assert.Equal(expectedSize, actualSize);
30+
}
31+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Reflection;
7+
using System.Runtime.CompilerServices;
8+
using System.Security.Claims;
9+
using System.Text;
10+
using System.Text.Unicode;
11+
using Microsoft.AspNetCore.DataProtection.KeyManagement;
12+
using Microsoft.AspNetCore.DataProtection.KeyManagement.Internal;
13+
using Microsoft.Extensions.Logging;
14+
using Microsoft.Extensions.Logging.Abstractions;
15+
using Moq;
16+
17+
namespace Microsoft.AspNetCore.DataProtection.Tests.KeyManagement;
18+
public class AdditionalAuthenticatedDataTemplateTests
19+
{
20+
[Fact]
21+
public void AdditionalAuthenticatedDataTemplateBuildAadTemplateBytes_ReturnsSameResultAsPreviousImplementation()
22+
{
23+
var actualBytes = KeyRingBasedDataProtector.AdditionalAuthenticatedDataTemplate.BuildAadTemplateBytes([
24+
"my sample string",
25+
"©®±µ¶", // exotic unicode characters (https://en.wikipedia.org/wiki/List_of_Unicode_characters)
26+
"my other sample string",
27+
// more than 128 utf-8 bytes string
28+
"CfDJ8H5oH_fp1QNBmvs-OWXxsVoV30hrXeI4-PI4p1VZytjsgd0DTstMdtTZbFtm2dKHvsBlDCv7TiEWKztZf8fb48pUgBgUE2SeYV3eOUXvSfNWU0D8SmHLy5KEnwKKkZKqudDhCnjQSIU7mhDliJJN1e4",
29+
"."
30+
]);
31+
32+
// expected bytes are formed by running former code with the same input
33+
// former code can be found in https://github.com/dotnet/aspnetcore/pull/59322
34+
var expectedBytesInBase64 = "CfDJ8AAAAAAAAAAAAAAAAAAAAAAAAAAFEG15IHNhbXBsZSBzdHJpbmcKwqnCrsKxwrXCthZteSBvdGhlciBzYW1wbGUgc3RyaW5nmwFDZkRKOEg1b0hfZnAxUU5CbXZzLU9XWHhzVm9WMzBoclhlSTQtUEk0cDFWWnl0anNnZDBEVHN0TWR0VFpiRnRtMmRLSHZzQmxEQ3Y3VGlFV0t6dFpmOGZiNDhwVWdCZ1VFMlNlWVYzZU9VWHZTZk5XVTBEOFNtSEx5NUtFbndLS2taS3F1ZERoQ25qUVNJVTdtaERsaUpKTjFlNAEu";
35+
36+
var actualBytesInBase64 = Convert.ToBase64String(actualBytes);
37+
Assert.Equal(expectedBytesInBase64, actualBytesInBase64);
38+
}
39+
40+
[Fact]
41+
public void AdditionalAuthenticatedDataTemplateBuildAadTemplateBytes_ThrowsOnIllegalUtf8Text()
42+
{
43+
Assert.Throws<EncoderFallbackException>(() =>
44+
{
45+
var actualBytes = KeyRingBasedDataProtector.AdditionalAuthenticatedDataTemplate.BuildAadTemplateBytes([
46+
"😀"[0] + "X",
47+
]);
48+
});
49+
}
50+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
10+
namespace Microsoft.AspNetCore.Shared;
11+
12+
internal static class Int7BitEncodingUtils
13+
{
14+
public static int Measure7BitEncodedUIntLength(this int value)
15+
=> Measure7BitEncodedUIntLength((uint)value);
16+
17+
public static int Measure7BitEncodedUIntLength(this uint value)
18+
{
19+
#if NET10_0_OR_GREATER
20+
return ((31 - System.Numerics.BitOperations.LeadingZeroCount(value | 1)) / 7) + 1;
21+
#else
22+
int count = 1;
23+
while ((value >>= 7) != 0)
24+
{
25+
count++;
26+
}
27+
return count;
28+
#endif
29+
}
30+
31+
public static int Write7BitEncodedInt(this Span<byte> target, int value)
32+
=> Write7BitEncodedInt(target, (uint)value);
33+
34+
public static int Write7BitEncodedInt(this Span<byte> target, uint uValue)
35+
{
36+
// Write out an int 7 bits at a time. The high bit of the byte,
37+
// when on, tells reader to continue reading more bytes.
38+
//
39+
// Using the constants 0x7F and ~0x7F below offers smaller
40+
// codegen than using the constant 0x80.
41+
42+
int index = 0;
43+
while (uValue > 0x7Fu)
44+
{
45+
target[index++] = (byte)(uValue | ~0x7Fu);
46+
uValue >>= 7;
47+
}
48+
49+
target[index++] = (byte)uValue;
50+
return index;
51+
}
52+
}

0 commit comments

Comments
 (0)