Skip to content

Commit

Permalink
Allow Microsoft.WinGet.Client to run in any PowerShell session runnin…
Browse files Browse the repository at this point in the history
…g as system (#3816)

This PR adds support for running the Microsoft.WinGet.Client module in
system context without the need to start pwsh.exe -MTA. Running as MTA
is required using inproc Microsoft.Management.Deployment.

If the module is running inproc and the current thread is not an MTA, it
will create a new MTA thread and execute there. Otherwise, non inproc or
already an MTA will use the current thread.

This was done by sharing AsyncCommand (renamed to PowerShellCmdlet) from
Microsoft.WinGet.Configuration. Originally, I wanted to create a new
shared lib, but decided to just share the files between the projects.

All cmdlets must inherit from PowerShellCmdlet. All cmdlets that use
Microsoft.Management.Deployment must inherit ManagementDeploymentCommand
and use Execute at the command engine entry point. As a safe mechanism,
if any call to PackageManagerWrapper will verify the thread is not an
STA if running inproc and fail.

I verified the cmdlets work in system context locally. A future PR will
add running our pester tests using psexe.exe
  • Loading branch information
msftrubengu authored Nov 10, 2023
1 parent 28a3073 commit 09c8771
Show file tree
Hide file tree
Showing 59 changed files with 618 additions and 439 deletions.

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions src/PowerShell/CommonFiles/StreamType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// -----------------------------------------------------------------------------
// <copyright file="StreamType.cs" company="Microsoft Corporation">
// Copyright (c) Microsoft Corporation. Licensed under the MIT License.
// </copyright>
// -----------------------------------------------------------------------------

namespace Microsoft.WinGet.Common.Command
{
/// <summary>
/// The write stream type of the cmdlet.
/// </summary>
public enum StreamType
{
/// <summary>
/// Debug.
/// </summary>
Debug,

/// <summary>
/// Verbose.
/// </summary>
Verbose,

/// <summary>
/// Warning.
/// </summary>
Warning,

/// <summary>
/// Error.
/// </summary>
Error,

/// <summary>
/// Progress.
/// </summary>
Progress,

/// <summary>
/// Object.
/// </summary>
Object,

/// <summary>
/// Information.
/// </summary>
Information,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace Microsoft.WinGet.Client.Engine.Commands
using Microsoft.WinGet.Client.Engine.Commands.Common;
using Microsoft.WinGet.Client.Engine.Common;
using Microsoft.WinGet.Client.Engine.Helpers;
using Microsoft.WinGet.Common.Command;

/// <summary>
/// Commands that just calls winget.exe underneath.
Expand Down Expand Up @@ -55,11 +56,11 @@ public void GetSettings(bool asPlainText)

if (asPlainText)
{
this.PsCmdlet.WriteObject(result.StdOut);
this.Write(StreamType.Object, result.StdOut);
}
else
{
this.PsCmdlet.WriteObject(Utilities.ConvertToHashtable(result.StdOut));
this.Write(StreamType.Object, Utilities.ConvertToHashtable(result.StdOut));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,48 +6,23 @@

namespace Microsoft.WinGet.Client.Engine.Commands.Common
{
using System.Collections.Generic;
using System.Management.Automation;
using Microsoft.WinGet.Client.Engine.Common;
using Microsoft.WinGet.Client.Engine.Exceptions;
using Microsoft.WinGet.SharedLib.Exceptions;
using Microsoft.WinGet.Common.Command;
using Microsoft.WinGet.SharedLib.PolicySettings;

/// <summary>
/// Base class for all Cmdlets.
/// </summary>
public abstract class BaseCommand
public abstract class BaseCommand : PowerShellCmdlet
{
/// <summary>
/// Initializes a new instance of the <see cref="BaseCommand"/> class.
/// </summary>
/// <param name="psCmdlet">PSCmdlet.</param>
internal BaseCommand(PSCmdlet psCmdlet)
: base()
: base(psCmdlet, new HashSet<Policy> { Policy.WinGet, Policy.WinGetCommandLineInterfaces })
{
// The inproc COM API may deadlock on an STA thread.
if (Utilities.UsesInProcWinget && Utilities.ThreadIsSTA)
{
throw new SingleThreadedApartmentException();
}

GroupPolicy groupPolicy = GroupPolicy.GetInstance();

if (!groupPolicy.IsEnabled(Policy.WinGet))
{
throw new GroupPolicyException(Policy.WinGet, GroupPolicyFailureType.BlockedByPolicy);
}

if (!groupPolicy.IsEnabled(Policy.WinGetCommandLineInterfaces))
{
throw new GroupPolicyException(Policy.WinGetCommandLineInterfaces, GroupPolicyFailureType.BlockedByPolicy);
}

this.PsCmdlet = psCmdlet;
}

/// <summary>
/// Gets the caller PSCmdlet.
/// </summary>
protected PSCmdlet PsCmdlet { get; private set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,33 @@ internal FinderCommand(PSCmdlet psCmdlet)
/// Gets or sets the field that is matched against the identifier of a package.
/// </summary>
[Filter(Field = PackageMatchField.Id)]
protected string Id { get; set; }
protected string? Id { get; set; }

/// <summary>
/// Gets or sets the field that is matched against the name of a package.
/// </summary>
[Filter(Field = PackageMatchField.Name)]
protected string Name { get; set; }
protected string? Name { get; set; }

/// <summary>
/// Gets or sets the field that is matched against the moniker of a package.
/// </summary>
[Filter(Field = PackageMatchField.Moniker)]
protected string Moniker { get; set; }
protected string? Moniker { get; set; }

/// <summary>
/// Gets or sets the name of the source to search for packages. If null, then all sources are searched.
/// </summary>
protected string Source { get; set; }
protected string? Source { get; set; }

/// <summary>
/// Gets or sets how to match against package fields.
/// </summary>
protected string[] Query { get; set; }
#pragma warning disable SA1011 // Closing square brackets should be spaced correctly
protected string[]? Query { get; set; }
#pragma warning restore SA1011 // Closing square brackets should be spaced correctly

private string QueryAsJoinedString
private string? QueryAsJoinedString
{
get
{
Expand Down Expand Up @@ -98,7 +100,7 @@ protected IReadOnlyList<MatchResult> FindPackages(
protected virtual void SetQueryInFindPackagesOptions(
ref FindPackagesOptions options,
string match,
string value)
string? value)
{
var selector = ManagementDeploymentFactory.Instance.CreatePackageMatchFilter();
selector.Field = PackageMatchField.CatalogDefault;
Expand All @@ -111,7 +113,7 @@ private void AddFilterToFindPackagesOptionsIfNotNull(
ref FindPackagesOptions options,
PackageMatchField field,
PackageFieldMatchOption match,
string value)
string? value)
{
if (value != null)
{
Expand Down Expand Up @@ -187,7 +189,7 @@ private void AddAttributedFiltersToFindPackagesOptions(
if (info.GetCustomAttribute(typeof(FilterAttribute), true) is FilterAttribute attribute)
{
PackageMatchField field = attribute.Field;
string value = info.GetValue(this, null) as string;
string? value = info.GetValue(this, null) as string;
this.AddFilterToFindPackagesOptionsIfNotNull(ref options, field, match, value);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ internal FinderExtendedCommand(PSCmdlet psCmdlet)
/// Gets or sets the filter that is matched against the tags of the package.
/// </summary>
[Filter(Field = PackageMatchField.Tag)]
protected string Tag { get; set; }
protected string? Tag { get; set; }

/// <summary>
/// Gets or sets the filter that is matched against the commands of the package.
/// </summary>
[Filter(Field = PackageMatchField.Command)]
protected string Command { get; set; }
protected string? Command { get; set; }

/// <summary>
/// Gets or sets the maximum number of results returned.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ internal InstallCommand(PSCmdlet psCmdlet)
/// <summary>
/// Gets or sets the override arguments to be passed on to the installer.
/// </summary>
protected string Override { get; set; }
protected string? Override { get; set; }

/// <summary>
/// Gets or sets the arguments to be passed on to the installer in addition to the defaults.
/// </summary>
protected string Custom { get; set; }
protected string? Custom { get; set; }

/// <summary>
/// Gets or sets the installation location.
/// </summary>
protected string Location { get; set; }
protected string? Location { get; set; }

/// <summary>
/// Gets or sets a value indicating whether to skip the installer hash validation check.
Expand All @@ -54,7 +54,7 @@ internal InstallCommand(PSCmdlet psCmdlet)
/// <summary>
/// Gets or sets the optional HTTP Header to pass on to the REST Source.
/// </summary>
protected string Header { get; set; }
protected string? Header { get; set; }

/// <summary>
/// Gets the install options from the configured parameters.
Expand All @@ -65,7 +65,7 @@ internal InstallCommand(PSCmdlet psCmdlet)
/// <param name="version">The <see cref="PackageVersionId" /> to install.</param>
/// <param name="mode">Package install mode as string.</param>
/// <returns>An <see cref="InstallOptions" /> instance.</returns>
protected virtual InstallOptions GetInstallOptions(PackageVersionId version, string mode)
protected virtual InstallOptions GetInstallOptions(PackageVersionId? version, string mode)
{
InstallOptions options = ManagementDeploymentFactory.Instance.CreateInstallOptions();
options.AllowHashMismatch = this.AllowHashMismatch;
Expand Down Expand Up @@ -115,10 +115,11 @@ protected InstallResult RegisterCallbacksAndWait(
IAsyncOperationWithProgress<InstallResult, InstallProgress> operation,
string activity)
{
WriteProgressAdapter adapter = new (this.PsCmdlet);
var activityId = this.GetNewProgressActivityId();
WriteProgressAdapter adapter = new (this);
operation.Progress = (context, progress) =>
{
ProgressRecord record = new (1, activity, progress.State.ToString())
ProgressRecord record = new (activityId, activity, progress.State.ToString())
{
RecordType = ProgressRecordType.Processing,
};
Expand All @@ -137,7 +138,7 @@ protected InstallResult RegisterCallbacksAndWait(
};
operation.Completed = (context, status) =>
{
adapter.WriteProgress(new ProgressRecord(1, activity, status.ToString())
adapter.WriteProgress(new ProgressRecord(activityId, activity, status.ToString())
{
RecordType = ProgressRecordType.Completed,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ namespace Microsoft.WinGet.Client.Engine.Commands.Common
using System.Management.Automation;
using System.Runtime.InteropServices;
using Microsoft.Management.Deployment;
using Microsoft.WinGet.Client.Engine.Common;
using Microsoft.WinGet.Client.Engine.Exceptions;
using Microsoft.WinGet.Client.Engine.Helpers;
using Microsoft.WinGet.Client.Engine.Properties;
using Microsoft.WinGet.Resources;

/// <summary>
/// This is the base class for all of the commands in this module that use the COM APIs.
Expand All @@ -39,13 +40,30 @@ internal ManagementDeploymentCommand(PSCmdlet psCmdlet)
#endif
}

/// <summary>
/// Executes the cmdlet. All cmdlets that uses the COM APIs MUST use this method.
/// The inproc COM API may deadlock on an STA thread.
/// </summary>
/// <typeparam name="TResult">The type of result of the cmdlet.</typeparam>
/// <param name="func">Cmdlet function.</param>
/// <returns>The result of the cmdlet.</returns>
protected TResult Execute<TResult>(Func<TResult> func)
{
if (Utilities.UsesInProcWinget)
{
return this.RunOnMTA(func);
}

return func();
}

/// <summary>
/// Retrieves the specified source or all sources if <paramref name="source" /> is null.
/// </summary>
/// <returns>A list of <see cref="PackageCatalogReference" /> instances.</returns>
/// <param name="source">The name of the source to retrieve. If null, then all sources are returned.</param>
/// <exception cref="ArgumentException">The source does not exist.</exception>
protected IReadOnlyList<PackageCatalogReference> GetPackageCatalogReferences(string source)
protected IReadOnlyList<PackageCatalogReference> GetPackageCatalogReferences(string? source)
{
if (string.IsNullOrEmpty(source))
{
Expand All @@ -55,8 +73,8 @@ protected IReadOnlyList<PackageCatalogReference> GetPackageCatalogReferences(str
{
return new List<PackageCatalogReference>()
{
PackageManagerWrapper.Instance.GetPackageCatalogByName(source)
?? throw new InvalidSourceException(source),
PackageManagerWrapper.Instance.GetPackageCatalogByName(source!)
?? throw new InvalidSourceException(source!),
};
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,40 @@ internal PackageCommand(PSCmdlet psCmdlet)
/// <remarks>
/// Must match the name of the <see cref="CatalogPackage" /> field on the <see cref="MatchResult" /> class.
/// </remarks>
protected PSCatalogPackage CatalogPackage { get; set; } = null;
protected PSCatalogPackage? CatalogPackage { get; set; } = null;

/// <summary>
/// Gets or sets the version to install.
/// </summary>
protected string Version { get; set; }
protected string? Version { get; set; }

/// <summary>
/// Gets or sets the path to the logging file.
/// </summary>
protected string Log { get; set; }
protected string? Log { get; set; }

/// <summary>
/// Executes a command targeting a specific package version.
/// </summary>
/// <typeparam name="TResult">Type of callback's result.</typeparam>
/// <param name="behavior">The <see cref="CompositeSearchBehavior" /> value.</param>
/// <param name="match">The match option.</param>
/// <param name="callback">The method to call after retrieving the package and version to operate upon.</param>
protected void GetPackageAndExecute(
/// <returns>Result of the callback.</returns>
protected TResult? GetPackageAndExecute<TResult>(
CompositeSearchBehavior behavior,
PackageFieldMatchOption match,
Action<CatalogPackage, PackageVersionId> callback)
Func<CatalogPackage, PackageVersionId?, TResult> callback)
where TResult : class
{
CatalogPackage package = this.GetCatalogPackage(behavior, match);
PackageVersionId version = this.GetPackageVersionId(package);
if (this.PsCmdlet.ShouldProcess(package.ToString(version)))
PackageVersionId? version = this.GetPackageVersionId(package);
if (this.ShouldProcess(package.ToString(version)))
{
callback(package, version);
return callback(package, version);
}

return null;
}

/// <summary>
Expand All @@ -79,7 +84,7 @@ protected void GetPackageAndExecute(
protected override void SetQueryInFindPackagesOptions(
ref FindPackagesOptions options,
string match,
string value)
string? value)
{
var matchOption = PSEnumHelpers.ToPackageFieldMatchOption(match);
foreach (PackageMatchField field in new PackageMatchField[] { PackageMatchField.Id, PackageMatchField.Name, PackageMatchField.Moniker })
Expand Down Expand Up @@ -120,7 +125,7 @@ private CatalogPackage GetCatalogPackage(CompositeSearchBehavior behavior, Packa
}
}

private PackageVersionId GetPackageVersionId(CatalogPackage package)
private PackageVersionId? GetPackageVersionId(CatalogPackage package)
{
if (this.Version != null)
{
Expand Down
Loading

0 comments on commit 09c8771

Please sign in to comment.