Skip to content

Commit dddf223

Browse files
EgorBo and jakobbotsch
authored
Unroll SequenceEqual(ref byte, ref byte, nuint) in JIT (#83945)
Co-authored-by: Jakob Botsch Nielsen <[email protected]>
1 parent 8ca896c commit dddf223

File tree

8 files changed

+266
-22
lines changed

8 files changed

+266
-22
lines changed

src/coreclr/jit/importercalls.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3782,6 +3782,7 @@ GenTree* Compiler::impIntrinsic(GenTree* newobjThis,
37823782
break;
37833783
}
37843784

3785+
case NI_System_SpanHelpers_SequenceEqual:
37853786
case NI_System_Buffer_Memmove:
37863787
{
37873788
// We'll try to unroll this in lower for constant input.
@@ -8139,6 +8140,13 @@ NamedIntrinsic Compiler::lookupNamedIntrinsic(CORINFO_METHOD_HANDLE method)
81398140
result = NI_System_Span_get_Length;
81408141
}
81418142
}
8143+
else if (strcmp(className, "SpanHelpers") == 0)
8144+
{
8145+
if (strcmp(methodName, "SequenceEqual") == 0)
8146+
{
8147+
result = NI_System_SpanHelpers_SequenceEqual;
8148+
}
8149+
}
81428150
else if (strcmp(className, "String") == 0)
81438151
{
81448152
if (strcmp(methodName, "Equals") == 0)

src/coreclr/jit/lower.cpp

Lines changed: 194 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,185 @@ GenTree* Lowering::LowerCallMemmove(GenTreeCall* call)
18651865
return nullptr;
18661866
}
18671867

1868+
//------------------------------------------------------------------------
1869+
// LowerCallMemcmp: Replace SpanHelpers.SequenceEqual(left, right, CNS_SIZE)
1870+
// with a series of merged comparisons (via GT_IND nodes)
1871+
//
1872+
// Arguments:
1873+
// tree - GenTreeCall node to unroll as memcmp
1874+
//
1875+
// Return Value:
1876+
// nullptr if no changes were made
1877+
//
1878+
GenTree* Lowering::LowerCallMemcmp(GenTreeCall* call)
1879+
{
1880+
JITDUMP("Considering Memcmp [%06d] for unrolling.. ", comp->dspTreeID(call))
1881+
assert(comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_SpanHelpers_SequenceEqual);
1882+
assert(call->gtArgs.CountUserArgs() == 3);
1883+
assert(TARGET_POINTER_SIZE == 8);
1884+
1885+
if (!comp->opts.OptimizationEnabled())
1886+
{
1887+
JITDUMP("Optimizations aren't allowed - bail out.\n")
1888+
return nullptr;
1889+
}
1890+
1891+
if (comp->info.compHasNextCallRetAddr)
1892+
{
1893+
JITDUMP("compHasNextCallRetAddr=true so we won't be able to remove the call - bail out.\n")
1894+
return nullptr;
1895+
}
1896+
1897+
GenTree* lengthArg = call->gtArgs.GetUserArgByIndex(2)->GetNode();
1898+
if (lengthArg->IsIntegralConst())
1899+
{
1900+
ssize_t cnsSize = lengthArg->AsIntCon()->IconValue();
1901+
JITDUMP("Size=%ld.. ", (LONG)cnsSize);
1902+
// TODO-CQ: drop the whole thing in case of 0
1903+
if (cnsSize > 0)
1904+
{
1905+
GenTree* lArg = call->gtArgs.GetUserArgByIndex(0)->GetNode();
1906+
GenTree* rArg = call->gtArgs.GetUserArgByIndex(1)->GetNode();
1907+
// TODO: Add SIMD path for [16..128] via GT_HWINTRINSIC nodes
1908+
if (cnsSize <= 16)
1909+
{
1910+
unsigned loadWidth = 1 << BitOperations::Log2((unsigned)cnsSize);
1911+
var_types loadType;
1912+
if (loadWidth == 1)
1913+
{
1914+
loadType = TYP_UBYTE;
1915+
}
1916+
else if (loadWidth == 2)
1917+
{
1918+
loadType = TYP_USHORT;
1919+
}
1920+
else if (loadWidth == 4)
1921+
{
1922+
loadType = TYP_INT;
1923+
}
1924+
else if ((loadWidth == 8) || (loadWidth == 16))
1925+
{
1926+
loadWidth = 8;
1927+
loadType = TYP_LONG;
1928+
}
1929+
else
1930+
{
1931+
unreached();
1932+
}
1933+
var_types actualLoadType = genActualType(loadType);
1934+
1935+
GenTree* result = nullptr;
1936+
1937+
// loadWidth == cnsSize means a single load is enough for both args
1938+
if ((loadWidth == (unsigned)cnsSize) && (loadWidth <= 8))
1939+
{
1940+
// We're going to emit something like the following:
1941+
//
1942+
// bool result = *(int*)leftArg == *(int*)rightArg
1943+
//
1944+
// ^ in the given example we unroll for length=4
1945+
//
1946+
GenTree* lIndir = comp->gtNewIndir(loadType, lArg);
1947+
GenTree* rIndir = comp->gtNewIndir(loadType, rArg);
1948+
result = comp->gtNewOperNode(GT_EQ, TYP_INT, lIndir, rIndir);
1949+
1950+
BlockRange().InsertAfter(lArg, lIndir);
1951+
BlockRange().InsertAfter(rArg, rIndir);
1952+
BlockRange().InsertBefore(call, result);
1953+
}
1954+
else
1955+
{
1956+
// First, make both args multi-use:
1957+
LIR::Use lArgUse;
1958+
LIR::Use rArgUse;
1959+
bool lFoundUse = BlockRange().TryGetUse(lArg, &lArgUse);
1960+
bool rFoundUse = BlockRange().TryGetUse(rArg, &rArgUse);
1961+
assert(lFoundUse && rFoundUse);
1962+
GenTree* lArgClone = comp->gtNewLclvNode(lArgUse.ReplaceWithLclVar(comp), genActualType(lArg));
1963+
GenTree* rArgClone = comp->gtNewLclvNode(rArgUse.ReplaceWithLclVar(comp), genActualType(rArg));
1964+
BlockRange().InsertBefore(call, lArgClone, rArgClone);
1965+
1966+
// We're going to emit something like the following:
1967+
//
1968+
// bool result = ((*(int*)leftArg ^ *(int*)rightArg) |
1969+
// (*(int*)(leftArg + 1) ^ *((int*)(rightArg + 1)))) == 0;
1970+
//
1971+
// ^ in the given example we unroll for length=5
1972+
//
1973+
// In IR:
1974+
//
1975+
// * EQ int
1976+
// +--* OR int
1977+
// | +--* XOR int
1978+
// | | +--* IND int
1979+
// | | | \--* LCL_VAR byref V1
1980+
// | | \--* IND int
1981+
// | | \--* LCL_VAR byref V2
1982+
// | \--* XOR int
1983+
// | +--* IND int
1984+
// | | \--* ADD byref
1985+
// | | +--* LCL_VAR byref V1
1986+
// | | \--* CNS_INT int 1
1987+
// | \--* IND int
1988+
// | \--* ADD byref
1989+
// | +--* LCL_VAR byref V2
1990+
// | \--* CNS_INT int 1
1991+
// \--* CNS_INT int 0
1992+
//
1993+
GenTree* l1Indir = comp->gtNewIndir(loadType, lArgUse.Def());
1994+
GenTree* r1Indir = comp->gtNewIndir(loadType, rArgUse.Def());
1995+
GenTree* lXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l1Indir, r1Indir);
1996+
GenTree* l2Offs = comp->gtNewIconNode(cnsSize - loadWidth, TYP_I_IMPL);
1997+
GenTree* l2AddOffs = comp->gtNewOperNode(GT_ADD, lArg->TypeGet(), lArgClone, l2Offs);
1998+
GenTree* l2Indir = comp->gtNewIndir(loadType, l2AddOffs);
1999+
GenTree* r2Offs = comp->gtCloneExpr(l2Offs); // offset is the same
2000+
GenTree* r2AddOffs = comp->gtNewOperNode(GT_ADD, rArg->TypeGet(), rArgClone, r2Offs);
2001+
GenTree* r2Indir = comp->gtNewIndir(loadType, r2AddOffs);
2002+
GenTree* rXor = comp->gtNewOperNode(GT_XOR, actualLoadType, l2Indir, r2Indir);
2003+
GenTree* resultOr = comp->gtNewOperNode(GT_OR, actualLoadType, lXor, rXor);
2004+
GenTree* zeroCns = comp->gtNewIconNode(0, actualLoadType);
2005+
result = comp->gtNewOperNode(GT_EQ, TYP_INT, resultOr, zeroCns);
2006+
2007+
BlockRange().InsertAfter(rArgClone, l1Indir, r1Indir, l2Offs, l2AddOffs);
2008+
BlockRange().InsertAfter(l2AddOffs, l2Indir, r2Offs, r2AddOffs, r2Indir);
2009+
BlockRange().InsertAfter(r2Indir, lXor, rXor, resultOr, zeroCns);
2010+
BlockRange().InsertAfter(zeroCns, result);
2011+
}
2012+
2013+
JITDUMP("\nUnrolled to:\n");
2014+
DISPTREE(result);
2015+
2016+
LIR::Use use;
2017+
if (BlockRange().TryGetUse(call, &use))
2018+
{
2019+
use.ReplaceWith(result);
2020+
}
2021+
BlockRange().Remove(lengthArg);
2022+
BlockRange().Remove(call);
2023+
2024+
// Remove all non-user args (e.g. r2r cell)
2025+
for (CallArg& arg : call->gtArgs.Args())
2026+
{
2027+
if (!arg.IsUserArg())
2028+
{
2029+
arg.GetNode()->SetUnusedValue();
2030+
}
2031+
}
2032+
return lArg;
2033+
}
2034+
}
2035+
else
2036+
{
2037+
JITDUMP("Size is either 0 or too big to unroll.\n")
2038+
}
2039+
}
2040+
else
2041+
{
2042+
JITDUMP("size is not a constant.\n")
2043+
}
2044+
return nullptr;
2045+
}
2046+
18682047
// do lowering steps for a call
18692048
// this includes:
18702049
// - adding the placement nodes (either stack or register variety) for arguments
@@ -1883,19 +2062,26 @@ GenTree* Lowering::LowerCall(GenTree* node)
18832062
// All runtime lookups are expected to be expanded in fgExpandRuntimeLookups
18842063
assert(!call->IsExpRuntimeLookup());
18852064

2065+
#if defined(TARGET_AMD64) || defined(TARGET_ARM64)
18862066
if (call->gtCallMoreFlags & GTF_CALL_M_SPECIAL_INTRINSIC)
18872067
{
1888-
#if defined(TARGET_AMD64) || defined(TARGET_ARM64)
1889-
if (comp->lookupNamedIntrinsic(call->gtCallMethHnd) == NI_System_Buffer_Memmove)
2068+
GenTree* newNode = nullptr;
2069+
NamedIntrinsic ni = comp->lookupNamedIntrinsic(call->gtCallMethHnd);
2070+
if (ni == NI_System_Buffer_Memmove)
18902071
{
1891-
GenTree* newNode = LowerCallMemmove(call);
1892-
if (newNode != nullptr)
1893-
{
1894-
return newNode->gtNext;
1895-
}
2072+
newNode = LowerCallMemmove(call);
2073+
}
2074+
else if (ni == NI_System_SpanHelpers_SequenceEqual)
2075+
{
2076+
newNode = LowerCallMemcmp(call);
2077+
}
2078+
2079+
if (newNode != nullptr)
2080+
{
2081+
return newNode->gtNext;
18962082
}
1897-
#endif
18982083
}
2084+
#endif
18992085

19002086
call->ClearOtherRegs();
19012087
LowerArgsForCall(call);

src/coreclr/jit/lower.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class Lowering final : public Phase
128128
// ------------------------------
129129
GenTree* LowerCall(GenTree* call);
130130
GenTree* LowerCallMemmove(GenTreeCall* call);
131+
GenTree* LowerCallMemcmp(GenTreeCall* call);
131132
void LowerCFGCall(GenTreeCall* call);
132133
void MoveCFGCallArg(GenTreeCall* call, GenTree* node);
133134
#ifndef TARGET_64BIT

src/coreclr/jit/namedintrinsiclist.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ enum NamedIntrinsic : unsigned short
104104
NI_System_String_StartsWith,
105105
NI_System_Span_get_Item,
106106
NI_System_Span_get_Length,
107+
NI_System_SpanHelpers_SequenceEqual,
107108
NI_System_ReadOnlySpan_get_Item,
108109
NI_System_ReadOnlySpan_get_Length,
109110

src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,12 +1429,11 @@ public static unsafe bool SequenceEqual<T>(this Span<T> span, ReadOnlySpan<T> ot
14291429

14301430
if (RuntimeHelpers.IsBitwiseEquatable<T>())
14311431
{
1432-
nuint size = (nuint)sizeof(T);
14331432
return length == other.Length &&
14341433
SpanHelpers.SequenceEqual(
14351434
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
14361435
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
1437-
((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
1436+
((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
14381437
}
14391438

14401439
return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length);
@@ -2164,12 +2163,11 @@ public static unsafe bool SequenceEqual<T>(this ReadOnlySpan<T> span, ReadOnlySp
21642163
int length = span.Length;
21652164
if (RuntimeHelpers.IsBitwiseEquatable<T>())
21662165
{
2167-
nuint size = (nuint)sizeof(T);
21682166
return length == other.Length &&
21692167
SpanHelpers.SequenceEqual(
21702168
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
21712169
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
2172-
((uint)length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
2170+
((uint)length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
21732171
}
21742172

21752173
return length == other.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(other), length);
@@ -2207,11 +2205,10 @@ public static unsafe bool SequenceEqual<T>(this ReadOnlySpan<T> span, ReadOnlySp
22072205
// If no comparer was supplied and the type is bitwise equatable, take the fast path doing a bitwise comparison.
22082206
if (RuntimeHelpers.IsBitwiseEquatable<T>())
22092207
{
2210-
nuint size = (nuint)sizeof(T);
22112208
return SpanHelpers.SequenceEqual(
22122209
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
22132210
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
2214-
((uint)span.Length) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
2211+
((uint)span.Length) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this API in such a case so we choose not to take the overhead of checking.
22152212
}
22162213

22172214
// Otherwise, compare each element using EqualityComparer<T>.Default.Equals in a way that will enable it to devirtualize.
@@ -2277,12 +2274,11 @@ public static unsafe bool StartsWith<T>(this Span<T> span, ReadOnlySpan<T> value
22772274
int valueLength = value.Length;
22782275
if (RuntimeHelpers.IsBitwiseEquatable<T>())
22792276
{
2280-
nuint size = (nuint)sizeof(T);
22812277
return valueLength <= span.Length &&
22822278
SpanHelpers.SequenceEqual(
22832279
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
22842280
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
2285-
((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
2281+
((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
22862282
}
22872283

22882284
return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength);
@@ -2298,12 +2294,11 @@ public static unsafe bool StartsWith<T>(this ReadOnlySpan<T> span, ReadOnlySpan<
22982294
int valueLength = value.Length;
22992295
if (RuntimeHelpers.IsBitwiseEquatable<T>())
23002296
{
2301-
nuint size = (nuint)sizeof(T);
23022297
return valueLength <= span.Length &&
23032298
SpanHelpers.SequenceEqual(
23042299
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
23052300
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
2306-
((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
2301+
((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
23072302
}
23082303

23092304
return valueLength <= span.Length && SpanHelpers.SequenceEqual(ref MemoryMarshal.GetReference(span), ref MemoryMarshal.GetReference(value), valueLength);
@@ -2319,12 +2314,11 @@ public static unsafe bool EndsWith<T>(this Span<T> span, ReadOnlySpan<T> value)
23192314
int valueLength = value.Length;
23202315
if (RuntimeHelpers.IsBitwiseEquatable<T>())
23212316
{
2322-
nuint size = (nuint)sizeof(T);
23232317
return valueLength <= spanLength &&
23242318
SpanHelpers.SequenceEqual(
23252319
ref Unsafe.As<T, byte>(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)),
23262320
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
2327-
((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
2321+
((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
23282322
}
23292323

23302324
return valueLength <= spanLength &&
@@ -2344,12 +2338,11 @@ public static unsafe bool EndsWith<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T>
23442338
int valueLength = value.Length;
23452339
if (RuntimeHelpers.IsBitwiseEquatable<T>())
23462340
{
2347-
nuint size = (nuint)sizeof(T);
23482341
return valueLength <= spanLength &&
23492342
SpanHelpers.SequenceEqual(
23502343
ref Unsafe.As<T, byte>(ref Unsafe.Add(ref MemoryMarshal.GetReference(span), (nint)(uint)(spanLength - valueLength) /* force zero-extension */)),
23512344
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(value)),
2352-
((uint)valueLength) * size); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
2345+
((uint)valueLength) * (nuint)sizeof(T)); // If this multiplication overflows, the Span we got overflows the entire address range. There's no happy outcome for this api in such a case so we choose not to take the overhead of checking.
23532346
}
23542347

23552348
return valueLength <= spanLength &&

src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ internal static unsafe int IndexOfNullByte(byte* searchSpace)
566566

567567
// Optimized byte-based SequenceEquals. The "length" parameter for this one is declared a nuint rather than int as we also use it for types other than byte
568568
// where the length can exceed 2Gb once scaled by sizeof(T).
569+
[Intrinsic] // Unrolled for constant length
569570
public static unsafe bool SequenceEqual(ref byte first, ref byte second, nuint length)
570571
{
571572
bool result;

0 commit comments

Comments (0)