Skip to content

Enable loading COM component in default ALC via runtime config setting #79026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
</type>
</assembly>

<assembly fullname="System.Private.CoreLib" feature="System.Runtime.InteropServices.BuiltInComInterop.IsSupported" featurevalue="true">
<!-- Enables the .NET COM host (.NET 8.0+) to load a COM component. -->
<type fullname="Internal.Runtime.InteropServices.ComActivator" >
<method name="GetClassFactoryForTypeInContext" />
<method name="RegisterClassForTypeInContext" />
<method name="UnregisterClassForTypeInContext" />
</type>
</assembly>

<assembly fullname="System.Private.CoreLib" feature="System.Runtime.InteropServices.EnableCppCLIHostActivation" featurevalue="true">
<!-- Enables the .NET IJW host (.NET 7.0+) to load an in-memory module as a .NET assembly. -->
<type fullname="Internal.Runtime.InteropServices.InMemoryAssemblyLoader">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,4 @@ internal unsafe struct ComActivationContextInternal
public char* TypeNameBuffer;
public IntPtr ClassFactoryDest;
}

//
// Types below are 'public' only to aid in testing of functionality.
// They should not be considered publicly consumable.
//

[StructLayout(LayoutKind.Sequential)]
internal partial struct ComActivationContext
{
public Guid ClassId;
public Guid InterfaceId;
public string AssemblyPath;
public string AssemblyName;
public string TypeName;
}

[ComImport]
[ComVisible(false)]
[Guid("00000001-0000-0000-C000-000000000046")]
[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
internal interface IClassFactory
{
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
ref Guid riid,
out IntPtr ppvObject);

void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,21 @@ internal struct LICINFO
public bool fLicVerified;
}

[ComImport]
[ComVisible(false)]
[Guid("00000001-0000-0000-C000-000000000046")]
[InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
internal interface IClassFactory
{
[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
void CreateInstance(
[MarshalAs(UnmanagedType.Interface)] object? pUnkOuter,
ref Guid riid,
out IntPtr ppvObject);

void LockServer([MarshalAs(UnmanagedType.Bool)] bool fLock);
}

[ComImport]
[ComVisible(false)]
[Guid("B196B28F-BAB4-101A-B69C-00AA00341D07")]
Expand Down Expand Up @@ -57,9 +72,17 @@ void CreateInstanceLic(
out IntPtr ppvObject);
}

internal partial struct ComActivationContext
[StructLayout(LayoutKind.Sequential)]
internal struct ComActivationContext
{
public static unsafe ComActivationContext Create(ref ComActivationContextInternal cxtInt)
public Guid ClassId;
public Guid InterfaceId;
public string AssemblyPath;
public string AssemblyName;
public string TypeName;
public bool IsolatedContext;

public static unsafe ComActivationContext Create(ref ComActivationContextInternal cxtInt, bool isolatedContext)
{
if (!Marshal.IsBuiltInComSupported)
{
Expand All @@ -72,7 +95,8 @@ public static unsafe ComActivationContext Create(ref ComActivationContextInterna
InterfaceId = cxtInt.InterfaceId,
AssemblyPath = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyPathBuffer))!,
AssemblyName = Marshal.PtrToStringUni(new IntPtr(cxtInt.AssemblyNameBuffer))!,
TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!
TypeName = Marshal.PtrToStringUni(new IntPtr(cxtInt.TypeNameBuffer))!,
IsolatedContext = isolatedContext
};
}
}
Expand All @@ -84,6 +108,9 @@ internal static class ComActivator
// unloadable COM server ALCs, this will need to be changed.
private static readonly Dictionary<string, AssemblyLoadContext> s_assemblyLoadContexts = new Dictionary<string, AssemblyLoadContext>(StringComparer.InvariantCultureIgnoreCase);

// COM component assembly paths loaded in the default ALC
private static readonly HashSet<string> s_loadedInDefaultContext = new HashSet<string>(StringComparer.InvariantCultureIgnoreCase);

/// <summary>
/// Entry point for unmanaged COM activation API from managed code
/// </summary>
Expand All @@ -107,7 +134,7 @@ private static object GetClassFactoryForType(ComActivationContext cxt)
throw new ArgumentException(null, nameof(cxt));
}

Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
Type classType = FindClassType(cxt);

if (LicenseInteropProxy.HasLicense(classType))
{
Expand Down Expand Up @@ -145,7 +172,7 @@ private static void ClassRegistrationScenarioForType(ComActivationContext cxt, b
throw new ArgumentException(null, nameof(cxt));
}

Type classType = FindClassType(cxt.ClassId, cxt.AssemblyPath, cxt.AssemblyName, cxt.TypeName);
Type classType = FindClassType(cxt);

Type? currentType = classType;
bool calledFunction = false;
Expand Down Expand Up @@ -213,17 +240,45 @@ private static void ClassRegistrationScenarioForType(ComActivationContext cxt, b
}

/// <summary>
/// Internal entry point for unmanaged COM activation API from native code
/// Gets a class factory for COM activation in an isolated load context
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
[UnmanagedCallersOnly]
private static unsafe int GetClassFactoryForTypeInternal(ComActivationContextInternal* pCxtInt)
{
if (!Marshal.IsBuiltInComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

#pragma warning disable IL2026 // suppressed in ILLink.Suppressions.LibraryBuild.xml
return GetClassFactoryForTypeImpl(pCxtInt, isolatedContext: true);
#pragma warning restore IL2026
}

/// <summary>
/// Gets a class factory for COM activation in the specified load context
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
/// <param name="loadContext">Load context - currently must be IntPtr.Zero (default context) or -1 (isolated context)</param>
[UnmanagedCallersOnly]
private static unsafe int GetClassFactoryForTypeInContext(ComActivationContextInternal* pCxtInt, IntPtr loadContext)
{
if (!Marshal.IsBuiltInComSupported)
throw new NotSupportedException(SR.NotSupported_COM);

if (loadContext != IntPtr.Zero && loadContext != (IntPtr)(-1))
throw new ArgumentOutOfRangeException(nameof(loadContext));

return GetClassFactoryForTypeLocal(pCxtInt, isolatedContext: loadContext != IntPtr.Zero);

// Use a local function for a targeted suppression of the requires unreferenced code warning
[UnconditionalSuppressMessage("ReflectionAnalysis", "IL2026:RequiresUnreferencedCode",
Justification = "The same feature switch applies to GetClassFactoryForTypeInternal and this function. We rely on the warning from GetClassFactoryForTypeInternal.")]
static int GetClassFactoryForTypeLocal(ComActivationContextInternal* pCxtInt, bool isolatedContext) => GetClassFactoryForTypeImpl(pCxtInt, isolatedContext);
}

[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
private static unsafe int GetClassFactoryForTypeImpl(ComActivationContextInternal* pCxtInt, bool isolatedContext)
{
ref ComActivationContextInternal cxtInt = ref *pCxtInt;

if (IsLoggingEnabled())
Expand All @@ -240,10 +295,8 @@ private static unsafe int GetClassFactoryForTypeInternal(ComActivationContextInt

try
{
var cxt = ComActivationContext.Create(ref cxtInt);
#pragma warning disable IL2026 // suppressed in ILLink.Suppressions.LibraryBuild.xml
var cxt = ComActivationContext.Create(ref cxtInt, isolatedContext);
object cf = GetClassFactoryForType(cxt);
#pragma warning restore IL2026
IntPtr nativeIUnknown = Marshal.GetIUnknownForObject(cf);
Marshal.WriteIntPtr(cxtInt.ClassFactoryDest, nativeIUnknown);
}
Expand All @@ -256,17 +309,37 @@ private static unsafe int GetClassFactoryForTypeInternal(ComActivationContextInt
}

/// <summary>
/// Internal entry point for registering a managed COM server API from native code
/// Registers a managed COM server in an isolated load context
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
[UnmanagedCallersOnly]
private static unsafe int RegisterClassForTypeInternal(ComActivationContextInternal* pCxtInt)
{
if (!Marshal.IsBuiltInComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return RegisterClassForTypeImpl(pCxtInt, isolatedContext: true);
}

/// <summary>
/// Registers a managed COM server in the specified load context
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
/// <param name="loadContext">Load context - currently must be IntPtr.Zero (default context) or -1 (isolated context)</param>
[UnmanagedCallersOnly]
private static unsafe int RegisterClassForTypeInContext(ComActivationContextInternal* pCxtInt, IntPtr loadContext)
{
if (!Marshal.IsBuiltInComSupported)
throw new NotSupportedException(SR.NotSupported_COM);

if (loadContext != IntPtr.Zero && loadContext != (IntPtr)(-1))
throw new ArgumentOutOfRangeException(nameof(loadContext));

return RegisterClassForTypeImpl(pCxtInt, isolatedContext: loadContext != IntPtr.Zero);
}

private static unsafe int RegisterClassForTypeImpl(ComActivationContextInternal* pCxtInt, bool isolatedContext)
{
ref ComActivationContextInternal cxtInt = ref *pCxtInt;

if (IsLoggingEnabled())
Expand All @@ -289,7 +362,7 @@ private static unsafe int RegisterClassForTypeInternal(ComActivationContextInter

try
{
var cxt = ComActivationContext.Create(ref cxtInt);
var cxt = ComActivationContext.Create(ref cxtInt, isolatedContext);
ClassRegistrationScenarioForTypeLocal(cxt, register: true);
}
catch (Exception e)
Expand All @@ -306,16 +379,36 @@ private static unsafe int RegisterClassForTypeInternal(ComActivationContextInter
}

/// <summary>
/// Internal entry point for unregistering a managed COM server API from native code
/// Unregisters a managed COM server in an isolated load context
/// </summary>
[UnmanagedCallersOnly]
private static unsafe int UnregisterClassForTypeInternal(ComActivationContextInternal* pCxtInt)
{
if (!Marshal.IsBuiltInComSupported)
{
throw new NotSupportedException(SR.NotSupported_COM);
}

return UnregisterClassForTypeImpl(pCxtInt, isolatedContext: true);
}

/// <summary>
/// Unregisters a managed COM server in the specified load context
/// </summary>
/// <param name="pCxtInt">Pointer to a <see cref="ComActivationContextInternal"/> instance</param>
/// <param name="loadContext">Load context - currently must be IntPtr.Zero (default context) or -1 (isolated context)</param>
[UnmanagedCallersOnly]
private static unsafe int UnregisterClassForTypeInContext(ComActivationContextInternal* pCxtInt, IntPtr loadContext)
{
if (!Marshal.IsBuiltInComSupported)
throw new NotSupportedException(SR.NotSupported_COM);

if (loadContext != IntPtr.Zero && loadContext != (IntPtr)(-1))
throw new ArgumentOutOfRangeException(nameof(loadContext));

return UnregisterClassForTypeImpl(pCxtInt, isolatedContext: loadContext != IntPtr.Zero);
}

private static unsafe int UnregisterClassForTypeImpl(ComActivationContextInternal* pCxtInt, bool isolatedContext)
{
ref ComActivationContextInternal cxtInt = ref *pCxtInt;

if (IsLoggingEnabled())
Expand All @@ -338,7 +431,7 @@ private static unsafe int UnregisterClassForTypeInternal(ComActivationContextInt

try
{
var cxt = ComActivationContext.Create(ref cxtInt);
var cxt = ComActivationContext.Create(ref cxtInt, isolatedContext);
ClassRegistrationScenarioForTypeLocal(cxt, register: false);
}
catch (Exception e)
Expand Down Expand Up @@ -370,14 +463,14 @@ private static void Log(string fmt, params object[] args)
}

[RequiresUnreferencedCode("Built-in COM support is not trim compatible", Url = "https://aka.ms/dotnet-illink/com")]
private static Type FindClassType(Guid clsid, string assemblyPath, string assemblyName, string typeName)
private static Type FindClassType(ComActivationContext cxt)
{
try
{
AssemblyLoadContext alc = GetALC(assemblyPath);
var assemblyNameLocal = new AssemblyName(assemblyName);
AssemblyLoadContext alc = GetALC(cxt.AssemblyPath, cxt.IsolatedContext);
var assemblyNameLocal = new AssemblyName(cxt.AssemblyName);
Assembly assem = alc.LoadFromAssemblyName(assemblyNameLocal);
Type? t = assem.GetType(typeName);
Type? t = assem.GetType(cxt.TypeName);
if (t != null)
{
return t;
Expand All @@ -387,7 +480,7 @@ private static Type FindClassType(Guid clsid, string assemblyPath, string assemb
{
if (IsLoggingEnabled())
{
Log($"COM Activation of {clsid} failed. {e}");
Log($"COM Activation of {cxt.ClassId} failed. {e}");
}
}

Expand All @@ -396,16 +489,39 @@ private static Type FindClassType(Guid clsid, string assemblyPath, string assemb
}

[RequiresUnreferencedCode("The trimmer might remove types which are needed by the assemblies loaded in this method.")]
private static AssemblyLoadContext GetALC(string assemblyPath)
private static AssemblyLoadContext GetALC(string assemblyPath, bool isolatedContext)
{
AssemblyLoadContext? alc;

lock (s_assemblyLoadContexts)
if (isolatedContext)
{
lock (s_assemblyLoadContexts)
{
if (!s_assemblyLoadContexts.TryGetValue(assemblyPath, out alc))
{
alc = new IsolatedComponentLoadContext(assemblyPath);
s_assemblyLoadContexts.Add(assemblyPath, alc);
}
}
}
else
{
if (!s_assemblyLoadContexts.TryGetValue(assemblyPath, out alc))
alc = AssemblyLoadContext.Default;
lock (s_loadedInDefaultContext)
{
alc = new IsolatedComponentLoadContext(assemblyPath);
s_assemblyLoadContexts.Add(assemblyPath, alc);
if (!s_loadedInDefaultContext.Contains(assemblyPath))
{
var resolver = new AssemblyDependencyResolver(assemblyPath);
AssemblyLoadContext.Default.Resolving +=
(context, assemblyName) =>
{
string? assemblyPath = resolver.ResolveAssemblyToPath(assemblyName);
return assemblyPath != null
? context.LoadFromAssemblyPath(assemblyPath)
: null;
};

s_loadedInDefaultContext.Add(assemblyPath);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Runtime.Loader;

namespace ComLibrary
{
Expand All @@ -25,6 +27,8 @@ public class Server : IServer
{
public Server()
{
Assembly asm = Assembly.GetExecutingAssembly();
Console.WriteLine($"{asm.GetName().Name}: AssemblyLoadContext = {AssemblyLoadContext.GetLoadContext(asm)}");
Console.WriteLine($"New instance of {nameof(Server)} created");
}
}
Expand Down
Loading