Skip to content

Commit d24bd31

Browse files
authored
Add built-in tools to AIShell (#394)
1 parent f932833 commit d24bd31

File tree

10 files changed

+787
-40
lines changed

10 files changed

+787
-40
lines changed

shell/AIShell.Abstraction/NamedPipe.cs

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,32 @@ public enum MessageType : int
3535
PostCode = 4,
3636
}
3737

38+
/// <summary>
39+
/// Context types that can be requested by AIShell from the connected PowerShell session.
40+
/// </summary>
41+
public enum ContextType : int
42+
{
43+
/// <summary>
44+
/// Ask for the current working directory of the shell.
45+
/// </summary>
46+
CurrentLocation = 0,
47+
48+
/// <summary>
49+
/// Ask for the command history of the shell session.
50+
/// </summary>
51+
CommandHistory = 1,
52+
53+
/// <summary>
54+
/// Ask for the content of the terminal window.
55+
/// </summary>
56+
TerminalContent = 2,
57+
58+
/// <summary>
59+
/// Ask for the environment variables of the shell session.
60+
/// </summary>
61+
EnvironmentVariables = 3,
62+
}
63+
3864
/// <summary>
3965
/// Base class for all pipe messages.
4066
/// </summary>
@@ -108,12 +134,24 @@ public AskConnectionMessage(string pipeName)
108134
/// </summary>
109135
public sealed class AskContextMessage : PipeMessage
110136
{
137+
/// <summary>
138+
/// Gets the type of context information requested.
139+
/// </summary>
140+
public ContextType ContextType { get; }
141+
142+
/// <summary>
143+
/// Gets the argument value associated with the current context query operation.
144+
/// </summary>
145+
public string[] Arguments { get; }
146+
111147
/// <summary>
112148
/// Creates an instance of <see cref="AskContextMessage"/>.
113149
/// </summary>
114-
public AskContextMessage()
150+
public AskContextMessage(ContextType contextType, string[] arguments = null)
115151
: base(MessageType.AskContext)
116152
{
153+
ContextType = contextType;
154+
Arguments = arguments ?? null;
117155
}
118156
}
119157

@@ -125,21 +163,20 @@ public sealed class PostContextMessage : PipeMessage
125163
/// <summary>
126164
/// Represents a none instance to be used when the shell has no context information to return.
127165
/// </summary>
128-
public static readonly PostContextMessage None = new([]);
166+
public static readonly PostContextMessage None = new(contextInfo: null);
129167

130168
/// <summary>
131-
/// Gets the command history.
169+
/// Gets the information of the requested context.
132170
/// </summary>
133-
public List<string> CommandHistory { get; }
171+
public string ContextInfo { get; }
134172

135173
/// <summary>
136174
/// Creates an instance of <see cref="PostContextMessage"/>.
137175
/// </summary>
138-
public PostContextMessage(List<string> commandHistory)
176+
public PostContextMessage(string contextInfo)
139177
: base(MessageType.PostContext)
140178
{
141-
ArgumentNullException.ThrowIfNull(commandHistory);
142-
CommandHistory = commandHistory;
179+
ContextInfo = contextInfo;
143180
}
144181
}
145182

shell/AIShell.Integration/AIShell.psm1

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ if ($null -eq $runspace) {
1313
}
1414

1515
## Create the channel singleton when loading the module.
16-
$null = [AIShell.Integration.Channel]::CreateSingleton($runspace, [Microsoft.PowerShell.PSConsoleReadLine])
16+
$null = [AIShell.Integration.Channel]::CreateSingleton($runspace, $ExecutionContext, [Microsoft.PowerShell.PSConsoleReadLine])

shell/AIShell.Integration/Channel.cs

Lines changed: 213 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
using System.Diagnostics;
1+
using System.Collections.ObjectModel;
2+
using System.Diagnostics;
23
using System.Reflection;
34
using System.Text;
45
using System.Management.Automation;
6+
using System.Management.Automation.Host;
57
using System.Management.Automation.Runspaces;
68
using AIShell.Abstraction;
9+
using Microsoft.PowerShell.Commands;
10+
using System.Text.Json;
711

812
namespace AIShell.Integration;
913

@@ -15,28 +19,35 @@ public class Channel : IDisposable
1519
private readonly string _shellPipeName;
1620
private readonly Type _psrlType;
1721
private readonly Runspace _runspace;
22+
private readonly EngineIntrinsics _intrinsics;
1823
private readonly MethodInfo _psrlInsert, _psrlRevertLine, _psrlAcceptLine;
1924
private readonly FieldInfo _psrlHandleResizing, _psrlReadLineReady;
2025
private readonly object _psrlSingleton;
2126
private readonly ManualResetEvent _connSetupWaitHandler;
2227
private readonly Predictor _predictor;
2328
private readonly ScriptBlock _onIdleAction;
29+
private readonly List<HistoryInfo> _commandHistory;
2430

31+
private PathInfo _currentLocation;
2532
private ShellClientPipe _clientPipe;
2633
private ShellServerPipe _serverPipe;
2734
private bool? _setupSuccess;
2835
private Exception _exception;
2936
private Thread _serverThread;
3037
private CodePostData _pendingPostCodeData;
3138

32-
private Channel(Runspace runspace, Type psConsoleReadLineType)
39+
private Channel(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType)
3340
{
3441
ArgumentNullException.ThrowIfNull(runspace);
3542
ArgumentNullException.ThrowIfNull(psConsoleReadLineType);
3643

3744
_runspace = runspace;
45+
_intrinsics = intrinsics;
3846
_psrlType = psConsoleReadLineType;
3947
_connSetupWaitHandler = new ManualResetEvent(false);
48+
_currentLocation = _intrinsics.SessionState.Path.CurrentLocation;
49+
_runspace.AvailabilityChanged += RunspaceAvailableAction;
50+
_intrinsics.InvokeCommand.LocationChangedAction += LocationChangedAction;
4051

4152
_shellPipeName = new StringBuilder(MaxNamedPipeNameSize)
4253
.Append("pwsh_aish.")
@@ -57,13 +68,14 @@ private Channel(Runspace runspace, Type psConsoleReadLineType)
5768
_psrlReadLineReady = _psrlType.GetField("_readLineReady", fieldFlags);
5869
_psrlHandleResizing = _psrlType.GetField("_handlePotentialResizing", fieldFlags);
5970

71+
_commandHistory = [];
6072
_predictor = new Predictor();
6173
_onIdleAction = ScriptBlock.Create("[AIShell.Integration.Channel]::Singleton.OnIdleHandler()");
6274
}
6375

64-
public static Channel CreateSingleton(Runspace runspace, Type psConsoleReadLineType)
76+
public static Channel CreateSingleton(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType)
6577
{
66-
return Singleton ??= new Channel(runspace, psConsoleReadLineType);
78+
return Singleton ??= new Channel(runspace, intrinsics, psConsoleReadLineType);
6779
}
6880

6981
public static Channel Singleton { get; private set; }
@@ -127,6 +139,95 @@ private async void ThreadProc()
127139
await _serverPipe.StartProcessingAsync(ConnectionTimeout, CancellationToken.None);
128140
}
129141

142+
private void LocationChangedAction(object sender, LocationChangedEventArgs e)
143+
{
144+
_currentLocation = e.NewPath;
145+
}
146+
147+
private void RunspaceAvailableAction(object sender, RunspaceAvailabilityEventArgs e)
148+
{
149+
if (sender is null || e.RunspaceAvailability is not RunspaceAvailability.Available)
150+
{
151+
return;
152+
}
153+
154+
// It's safe to get states of the PowerShell Runspace now because it's available and this event
155+
// is handled synchronously.
156+
// We may want to invoke command or script here, and we have to unregister ourself before doing
157+
// that, because the invocation would change the availability of the Runspace, which will cause
158+
// the 'AvailabilityChanged' to be fired again and re-enter our handler.
159+
// We register ourself back after we are done with the processing.
160+
var pwshRunspace = (Runspace)sender;
161+
pwshRunspace.AvailabilityChanged -= RunspaceAvailableAction;
162+
163+
try
164+
{
165+
using var ps = PowerShell.Create();
166+
ps.Runspace = pwshRunspace;
167+
168+
var results = ps
169+
.AddCommand("Get-History")
170+
.AddParameter("Count", 5)
171+
.InvokeAndCleanup<HistoryInfo>();
172+
173+
if (results.Count is 0 ||
174+
(_commandHistory.Count > 0 && _commandHistory[^1].Id == results[^1].Id))
175+
{
176+
// No command history yet, or no change since the last update.
177+
return;
178+
}
179+
180+
lock (_commandHistory)
181+
{
182+
_commandHistory.Clear();
183+
_commandHistory.AddRange(results);
184+
}
185+
}
186+
finally
187+
{
188+
pwshRunspace.AvailabilityChanged += RunspaceAvailableAction;
189+
}
190+
}
191+
192+
private string CaptureScreen()
193+
{
194+
if (!OperatingSystem.IsWindows())
195+
{
196+
return null;
197+
}
198+
199+
try
200+
{
201+
PSHostRawUserInterface rawUI = _intrinsics.Host.UI.RawUI;
202+
Coordinates start = new(0, 0), end = rawUI.CursorPosition;
203+
end.X = rawUI.BufferSize.Width - 1;
204+
205+
BufferCell[,] content = rawUI.GetBufferContents(new Rectangle(start, end));
206+
StringBuilder line = new(), buffer = new();
207+
208+
int rows = content.GetLength(0);
209+
int columns = content.GetLength(1);
210+
211+
for (int row = 0; row < rows; row++)
212+
{
213+
line.Clear();
214+
for (int column = 0; column < columns; column++)
215+
{
216+
line.Append(content[row, column].Character);
217+
}
218+
219+
line.TrimEnd();
220+
buffer.Append(line).Append('\n');
221+
}
222+
223+
return buffer.Length is 0 ? string.Empty : buffer.ToString();
224+
}
225+
catch
226+
{
227+
return null;
228+
}
229+
}
230+
130231
internal void PostQuery(PostQueryMessage message)
131232
{
132233
ThrowIfNotConnected();
@@ -138,6 +239,8 @@ public void Dispose()
138239
Reset();
139240
_connSetupWaitHandler.Dispose();
140241
_predictor.Unregister();
242+
_runspace.AvailabilityChanged -= RunspaceAvailableAction;
243+
_intrinsics.InvokeCommand.LocationChangedAction -= LocationChangedAction;
141244
GC.SuppressFinalize(this);
142245
}
143246

@@ -257,8 +360,76 @@ private void OnPostCode(PostCodeMessage postCodeMessage)
257360

258361
private PostContextMessage OnAskContext(AskContextMessage askContextMessage)
259362
{
260-
// Not implemented yet.
261-
return null;
363+
const string RedactedValue = "***<sensitive data redacted>***";
364+
365+
ContextType type = askContextMessage.ContextType;
366+
string[] arguments = askContextMessage.Arguments;
367+
368+
string contextInfo;
369+
switch (type)
370+
{
371+
case ContextType.CurrentLocation:
372+
contextInfo = JsonSerializer.Serialize(
373+
new { Provider = _currentLocation.Provider.Name, _currentLocation.Path });
374+
break;
375+
376+
case ContextType.CommandHistory:
377+
lock (_commandHistory)
378+
{
379+
contextInfo = JsonSerializer.Serialize(
380+
_commandHistory.Select(o => new { o.Id, o.CommandLine }));
381+
}
382+
break;
383+
384+
case ContextType.TerminalContent:
385+
contextInfo = CaptureScreen();
386+
break;
387+
388+
case ContextType.EnvironmentVariables:
389+
if (arguments is { Length: > 0 })
390+
{
391+
var varsCopy = new Dictionary<string, string>();
392+
foreach (string name in arguments)
393+
{
394+
if (!string.IsNullOrEmpty(name))
395+
{
396+
varsCopy.Add(name, Environment.GetEnvironmentVariable(name) is string value
397+
? EnvVarMayBeSensitive(name) ? RedactedValue : value
398+
: $"[env variable '{arguments}' is undefined]");
399+
}
400+
}
401+
402+
contextInfo = varsCopy.Count > 0
403+
? JsonSerializer.Serialize(varsCopy)
404+
: "The specified environment variable names are invalid";
405+
}
406+
else
407+
{
408+
var vars = Environment.GetEnvironmentVariables();
409+
var varsCopy = new Dictionary<string, string>();
410+
411+
foreach (string key in vars.Keys)
412+
{
413+
varsCopy.Add(key, EnvVarMayBeSensitive(key) ? RedactedValue : (string)vars[key]);
414+
}
415+
416+
contextInfo = JsonSerializer.Serialize(varsCopy);
417+
}
418+
break;
419+
420+
default:
421+
throw new InvalidDataException($"Unknown context type '{type}'");
422+
}
423+
424+
return new PostContextMessage(contextInfo);
425+
426+
static bool EnvVarMayBeSensitive(string key)
427+
{
428+
return key.Contains("key", StringComparison.OrdinalIgnoreCase) ||
429+
key.Contains("token", StringComparison.OrdinalIgnoreCase) ||
430+
key.Contains("pass", StringComparison.OrdinalIgnoreCase) ||
431+
key.Contains("secret", StringComparison.OrdinalIgnoreCase);
432+
}
262433
}
263434

264435
private void OnAskConnection(ShellClientPipe clientPipe, Exception exception)
@@ -334,3 +505,39 @@ public void Dispose()
334505
}
335506

336507
internal record CodePostData(string CodeToInsert, List<PredictionCandidate> PredictionCandidates);
508+
509+
internal static class ExtensionMethods
510+
{
511+
internal static Collection<T> InvokeAndCleanup<T>(this PowerShell ps)
512+
{
513+
var results = ps.Invoke<T>();
514+
ps.Commands.Clear();
515+
516+
return results;
517+
}
518+
519+
internal static void InvokeAndCleanup(this PowerShell ps)
520+
{
521+
ps.Invoke();
522+
ps.Commands.Clear();
523+
}
524+
525+
internal static void TrimEnd(this StringBuilder sb)
526+
{
527+
// end will point to the first non-trimmed character on the right.
528+
int end = sb.Length - 1;
529+
for (; end >= 0; end--)
530+
{
531+
if (!char.IsWhiteSpace(sb[end]))
532+
{
533+
break;
534+
}
535+
}
536+
537+
int index = end + 1;
538+
if (index < sb.Length)
539+
{
540+
sb.Remove(index, sb.Length - index);
541+
}
542+
}
543+
}

0 commit comments

Comments
 (0)