Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix constrained call corner cases #111178

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,11 @@ public static unsafe IntPtr ResolveStaticDispatchOnType(RuntimeTypeHandle instan
return result;
}

public static unsafe IntPtr ResolveDispatchOnType(RuntimeTypeHandle instanceType, RuntimeTypeHandle interfaceType, int slot)
{
return RuntimeImports.RhResolveDispatchOnType(instanceType.ToMethodTable(), interfaceType.ToMethodTable(), checked((ushort)slot));
}

public static bool IsUnmanagedPointerType(RuntimeTypeHandle typeHandle)
{
return typeHandle.ToMethodTable()->IsPointer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,41 @@ internal override IntPtr Create(TypeBuilder builder)
}
}

/// <summary>
/// Used for non-generic instance constrained Methods
/// </summary>
private class NonGenericInstanceConstrainedMethodCell : GenericDictionaryCell
{
internal TypeDesc ConstraintType;
internal TypeDesc ConstrainedMethodType;
internal int ConstrainedMethodSlot;

internal override void Prepare(TypeBuilder builder)
{
if (ConstraintType.IsCanonicalSubtype(CanonicalFormKind.Any) || ConstrainedMethodType.IsCanonicalSubtype(CanonicalFormKind.Any))
Environment.FailFast("Unable to compute call information for a canonical type/method.");

builder.RegisterForPreparation(ConstraintType);
builder.RegisterForPreparation(ConstrainedMethodType);
}

internal override IntPtr Create(TypeBuilder builder)
{
IntPtr result = RuntimeAugments.ResolveDispatchOnType(
builder.GetRuntimeTypeHandle(ConstraintType),
builder.GetRuntimeTypeHandle(ConstrainedMethodType),
ConstrainedMethodSlot);

Debug.Assert(result != IntPtr.Zero);

return result;
}
}

/// <summary>
/// Used for generic static constrained Methods
/// </summary>
private class GenericStaticConstrainedMethodCell : GenericDictionaryCell
private class GenericConstrainedMethodCell : GenericDictionaryCell
{
internal DefType ConstraintType;
internal InstantiatedMethod ConstrainedMethod;
Expand Down Expand Up @@ -512,29 +543,45 @@ internal static GenericDictionaryCell ParseAndCreateCell(NativeLayoutInfoLoadCon
break;

case FixupSignatureKind.NonGenericStaticConstrainedMethod:
{
case FixupSignatureKind.NonGenericInstanceConstrainedMethod:
{
var constraintType = nativeLayoutInfoLoadContext.GetType(ref parser);
var constrainedMethodType = nativeLayoutInfoLoadContext.GetType(ref parser);
var constrainedMethodSlot = parser.GetUnsigned();
TypeLoaderLogger.WriteLine("NonGenericStaticConstrainedMethod: " + constraintType.ToString() + " Method " + constrainedMethodType.ToString() + ", slot #" + constrainedMethodSlot.LowLevelToString());

cell = new NonGenericStaticConstrainedMethodCell()
string kindString = kind == FixupSignatureKind.NonGenericStaticConstrainedMethod ? "NonGenericStaticConstrainedMethod: " : "NonGenericInstanceConstrainedMethod: ";

TypeLoaderLogger.WriteLine(kindString + constraintType.ToString() + " Method " + constrainedMethodType.ToString() + ", slot #" + constrainedMethodSlot.LowLevelToString());

if (kind == FixupSignatureKind.NonGenericStaticConstrainedMethod)
{
ConstraintType = constraintType,
ConstrainedMethodType = constrainedMethodType,
ConstrainedMethodSlot = (int)constrainedMethodSlot
};
cell = new NonGenericStaticConstrainedMethodCell()
{
ConstraintType = constraintType,
ConstrainedMethodType = constrainedMethodType,
ConstrainedMethodSlot = (int)constrainedMethodSlot
};
}
else
{
cell = new NonGenericInstanceConstrainedMethodCell()
{
ConstraintType = constraintType,
ConstrainedMethodType = constrainedMethodType,
ConstrainedMethodSlot = (int)constrainedMethodSlot
};
}
}
break;

case FixupSignatureKind.GenericStaticConstrainedMethod:
{
case FixupSignatureKind.GenericConstrainedMethod:
{
TypeDesc constraintType = nativeLayoutInfoLoadContext.GetType(ref parser);
MethodDesc constrainedMethod = nativeLayoutInfoLoadContext.GetMethod(ref parser);

TypeLoaderLogger.WriteLine("GenericStaticConstrainedMethod: " + constraintType.ToString() + " Method " + constrainedMethod.ToString());
TypeLoaderLogger.WriteLine("GenericConstrainedMethod: " + constraintType.ToString() + " Method " + constrainedMethod.ToString());

cell = new GenericStaticConstrainedMethodCell()
cell = new GenericConstrainedMethodCell()
{
ConstraintType = (DefType)constraintType,
ConstrainedMethod = (InstantiatedMethod)constrainedMethod,
Expand Down
4 changes: 3 additions & 1 deletion src/coreclr/tools/Common/Compiler/TypeExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,9 @@ public static MethodDesc TryResolveConstraintMethodApprox(this TypeDesc constrai
potentialInterfaceMethod.GetTypicalMethodDefinition(), (InstantiatedType)potentialInterfaceType);
}

method = canonType.ResolveInterfaceMethodToVirtualMethodOnType(potentialInterfaceMethod);
method = canonType.ResolveInterfaceMethodToVirtualMethodOnType(potentialInterfaceMethod)
// Do not lose track of `method` if we were able to resolve it previously
?? method;

// See code:#TryResolveConstraintMethodApprox_DoNotReturnParentMethod
if (method != null && !method.OwningType.IsValueType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ enum FixupSignatureKind : uint
// unused = 0x17,
// unused = 0x18,
// unused = 0x19,
// unused = 0x20,
NonGenericInstanceConstrainedMethod = 0x20,
NonGenericStaticConstrainedMethod = 0x21,
GenericStaticConstrainedMethod = 0x22,
GenericConstrainedMethod = 0x22,

NotYetSupported = 0xee,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,8 +936,6 @@ public override DefaultInterfaceMethodResolution ResolveVariantInterfaceMethodTo

public static DefaultInterfaceMethodResolution ResolveVariantInterfaceMethodToDefaultImplementationOnType(MethodDesc interfaceMethod, MetadataType currentType, out MethodDesc impl)
{
Debug.Assert(interfaceMethod.Signature.IsStatic);

MetadataType interfaceType = (MetadataType)interfaceMethod.OwningType;
bool foundInterface = IsInterfaceImplementedOnType(currentType, interfaceType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -848,30 +848,38 @@ public override ISymbolNode GetTarget(NodeFactory factory, GenericLookupResultCo
TypeDesc instantiatedConstraintType = _constraintType.GetNonRuntimeDeterminedTypeFromRuntimeDeterminedSubtypeViaSubstitution(dictionary.TypeInstantiation, dictionary.MethodInstantiation);
MethodDesc implMethod;

MethodDesc instantiatedConstrainedMethodDefinition = instantiatedConstrainedMethod.GetMethodDefinition();

if (instantiatedConstrainedMethod.OwningType.IsInterface)
{
if (instantiatedConstrainedMethod.Signature.IsStatic)
{
implMethod = instantiatedConstraintType.GetClosestDefType().ResolveVariantInterfaceMethodToStaticVirtualMethodOnType(instantiatedConstrainedMethod);
if (implMethod == null)
{
DefaultInterfaceMethodResolution resolution =
instantiatedConstraintType.GetClosestDefType().ResolveVariantInterfaceMethodToDefaultImplementationOnType(instantiatedConstrainedMethod, out implMethod);
if (resolution != DefaultInterfaceMethodResolution.DefaultImplementation)
{
// TODO: diamond/reabstraction
ThrowHelper.ThrowInvalidProgramException();
}
}
implMethod = instantiatedConstraintType.GetClosestDefType().ResolveVariantInterfaceMethodToStaticVirtualMethodOnType(instantiatedConstrainedMethodDefinition);
}
else
{
throw new NotImplementedException();
implMethod = instantiatedConstraintType.GetClosestDefType().ResolveVariantInterfaceMethodToVirtualMethodOnType(instantiatedConstrainedMethodDefinition);
}

if (implMethod == null)
{
DefaultInterfaceMethodResolution resolution =
instantiatedConstraintType.GetClosestDefType().ResolveVariantInterfaceMethodToDefaultImplementationOnType(instantiatedConstrainedMethodDefinition, out implMethod);
if (resolution != DefaultInterfaceMethodResolution.DefaultImplementation)
{
// TODO: diamond/reabstraction
ThrowHelper.ThrowInvalidProgramException();
}
}
}
else
{
implMethod = instantiatedConstraintType.GetClosestDefType().FindVirtualFunctionTargetMethodOnObjectType(instantiatedConstrainedMethod);
implMethod = instantiatedConstraintType.GetClosestDefType().FindVirtualFunctionTargetMethodOnObjectType(instantiatedConstrainedMethodDefinition);
}

if (instantiatedConstrainedMethod != instantiatedConstrainedMethodDefinition)
{
implMethod = implMethod.MakeInstantiatedMethod(instantiatedConstrainedMethod.Instantiation);
}

// AOT use of this generic lookup is restricted to finding methods on valuetypes (runtime usage of this slot in universal generics is more flexible)
Expand All @@ -880,21 +888,10 @@ public override ISymbolNode GetTarget(NodeFactory factory, GenericLookupResultCo
factory.MetadataManager.NoteOverridingMethod(_constrainedMethod, implMethod);

// TODO-SIZE: this is address taken only in the delegate target case
if (implMethod.Signature.IsStatic)
{
if (implMethod.GetCanonMethodTarget(CanonicalFormKind.Specific).IsSharedByGenericInstantiations)
return factory.ExactCallableAddressTakenAddress(implMethod);
else
return factory.AddressTakenMethodEntrypoint(implMethod);
}
else if (implMethod.HasInstantiation)
{
if (implMethod.GetCanonMethodTarget(CanonicalFormKind.Specific).IsSharedByGenericInstantiations)
return factory.ExactCallableAddressTakenAddress(implMethod);
}
else
{
return factory.AddressTakenMethodEntrypoint(implMethod.GetCanonMethodTarget(CanonicalFormKind.Specific));
}
return factory.AddressTakenMethodEntrypoint(implMethod);
}

public override void AppendMangledName(NameMangler nameMangler, Utf8StringBuilder sb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,6 @@ public NativeLayoutConstrainedMethodDictionarySlotNode(MethodDesc constrainedMet
_directCall = directCall;
Debug.Assert(_constrainedMethod.OwningType.IsInterface);
Debug.Assert(!_constrainedMethod.HasInstantiation || !directCall);
Debug.Assert(_constrainedMethod.Signature.IsStatic);
}

protected sealed override string GetName(NodeFactory factory) =>
Expand All @@ -1428,10 +1427,12 @@ protected sealed override FixupSignatureKind SignatureKind
{
get
{
if (_constrainedMethod.HasInstantiation)
return FixupSignatureKind.GenericStaticConstrainedMethod;
else
return FixupSignatureKind.NonGenericStaticConstrainedMethod;
return (_constrainedMethod.HasInstantiation, _constrainedMethod.Signature.IsStatic) switch
{
(true, _) => FixupSignatureKind.GenericConstrainedMethod,
(false, true) => FixupSignatureKind.NonGenericStaticConstrainedMethod,
(false, false) => FixupSignatureKind.NonGenericInstanceConstrainedMethod,
};
}
}

Expand Down Expand Up @@ -1477,13 +1478,13 @@ protected sealed override Vertex WriteSignatureVertex(NativeWriter writer, NodeF
Vertex constraintType = factory.NativeLayout.TypeSignatureVertex(_constraintType).WriteVertex(factory);
if (_constrainedMethod.HasInstantiation)
{
Debug.Assert(SignatureKind is FixupSignatureKind.GenericStaticConstrainedMethod);
Debug.Assert(SignatureKind is FixupSignatureKind.GenericConstrainedMethod);
Vertex constrainedMethodVertex = factory.NativeLayout.MethodEntry(_constrainedMethod).WriteVertex(factory);
return writer.GetTuple(constraintType, constrainedMethodVertex);
}
else
{
Debug.Assert(SignatureKind is FixupSignatureKind.NonGenericStaticConstrainedMethod);
Debug.Assert(SignatureKind is FixupSignatureKind.NonGenericStaticConstrainedMethod or FixupSignatureKind.NonGenericInstanceConstrainedMethod);
Vertex methodType = factory.NativeLayout.TypeSignatureVertex(_constrainedMethod.OwningType).WriteVertex(factory);
var canonConstrainedMethod = _constrainedMethod.GetCanonMethodTarget(CanonicalFormKind.Specific);
int interfaceSlot = VirtualMethodSlotHelper.GetVirtualMethodSlot(factory, canonConstrainedMethod, canonConstrainedMethod.OwningType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -550,42 +550,39 @@ private void ImportCall(ILOpcode opcode, int token)

bool allowInstParam = opcode != ILOpcode.ldvirtftn && opcode != ILOpcode.ldftn;

if (directCall && resolvedConstraint && exactContextNeedsRuntimeLookup)
if (directCall && resolvedConstraint && (exactContextNeedsRuntimeLookup || forceUseRuntimeLookup))
{
// We want to do a direct call to a shared method on a valuetype. We need to provide
// a generic context, but the JitInterface doesn't provide a way for us to do it from here.
// So we do the next best thing and ask RyuJIT to look up a fat pointer.
//
// We have the canonical version of the method - find the runtime determined version.
// This is simplified because we know the method is on a valuetype.
Debug.Assert(targetMethod.OwningType.IsValueType);

if (forceUseRuntimeLookup)
{
// The below logic would incorrectly resolve the lookup into the first match we found,
// but there was a compile-time ambiguity due to shared code. The correct fix should
// use the ConstrainedMethodUseLookupResult dictionary entry so that the exact
// dispatch can be computed with the help of the generic dictionary.
// We fail the compilation here to avoid bad codegen. This is not actually an invalid program.
// https://github.com/dotnet/runtimelab/issues/1431
ThrowHelper.ThrowInvalidProgramException();
var constrainedCallInfo = new ConstrainedCallInfo(_constrained, runtimeDeterminedMethod);
_dependencies.Add(GetGenericLookupHelper(ReadyToRunHelperId.ConstrainedDirectCall, constrainedCallInfo), reason);
}

MethodDesc targetOfLookup;
if (_constrained.IsRuntimeDeterminedType)
targetOfLookup = _compilation.TypeSystemContext.GetMethodForRuntimeDeterminedType(targetMethod.GetTypicalMethodDefinition(), (RuntimeDeterminedType)_constrained);
else if (_constrained.HasInstantiation)
targetOfLookup = _compilation.TypeSystemContext.GetMethodForInstantiatedType(targetMethod.GetTypicalMethodDefinition(), (InstantiatedType)_constrained);
else
targetOfLookup = targetMethod.GetMethodDefinition();
if (targetOfLookup.HasInstantiation)
{
targetOfLookup = targetOfLookup.MakeInstantiatedMethod(runtimeDeterminedMethod.Instantiation);
}
Debug.Assert(targetOfLookup.GetCanonMethodTarget(CanonicalFormKind.Specific) == targetMethod.GetCanonMethodTarget(CanonicalFormKind.Specific));
_dependencies.Add(GetGenericLookupHelper(ReadyToRunHelperId.MethodEntry, targetOfLookup), reason);
// We have the canonical version of the method - find the runtime determined version.
// This is simplified because we know the method is on a valuetype.
Debug.Assert(targetMethod.OwningType.IsValueType);

MethodDesc targetOfLookup;
if (_constrained.IsRuntimeDeterminedType)
targetOfLookup = _compilation.TypeSystemContext.GetMethodForRuntimeDeterminedType(targetMethod.GetTypicalMethodDefinition(), (RuntimeDeterminedType)_constrained);
else if (_constrained.HasInstantiation)
targetOfLookup = _compilation.TypeSystemContext.GetMethodForInstantiatedType(targetMethod.GetTypicalMethodDefinition(), (InstantiatedType)_constrained);
else
targetOfLookup = targetMethod.GetMethodDefinition();
if (targetOfLookup.HasInstantiation)
{
targetOfLookup = targetOfLookup.MakeInstantiatedMethod(runtimeDeterminedMethod.Instantiation);
}
Debug.Assert(targetOfLookup.GetCanonMethodTarget(CanonicalFormKind.Specific) == targetMethod.GetCanonMethodTarget(CanonicalFormKind.Specific));
_dependencies.Add(GetGenericLookupHelper(ReadyToRunHelperId.MethodEntry, targetOfLookup), reason);

targetForDelegate = targetOfLookup;
targetForDelegate = targetOfLookup;
}
}
else if (directCall && !allowInstParam && targetMethod.GetCanonMethodTarget(CanonicalFormKind.Specific).RequiresInstArg())
{
Expand Down
Loading
Loading