Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Moved IChatHistoryReducer from Agents to SK packages #10285

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.Agents.History;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Agents;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,29 @@ internal static class ChatHistoryExtensions
/// <remarks>
/// For simplicity only a single system message is supported in these examples.
/// </remarks>
internal static ChatMessageContent? GetSystemMessage(this ChatHistory chatHistory)
internal static ChatMessageContent? GetSystemMessage(this IReadOnlyList<ChatMessageContent> history)
{
return chatHistory.FirstOrDefault(m => m.Role == AuthorRole.System);
return history.FirstOrDefault(m => m.Role == AuthorRole.System);
dmytrostruk marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
/// Extract a range of messages from the provided <see cref="ChatHistory"/>.
/// </summary>
/// <param name="chatHistory">The source history</param>
/// <param name="history">The source history</param>
/// <param name="startIndex">The index of the first messageContent to extract</param>
/// <param name="endIndex">The index of the first messageContent to extract, if null extract up to the end of the chat history</param>
/// <param name="systemMessage">An optional system messageContent to include</param>
/// <param name="summaryMessage">An optional summary messageContent to include</param>
/// <param name="messageFilter">An optional message filter</param>
public static IEnumerable<ChatMessageContent> Extract(
this ChatHistory chatHistory, int startIndex, int? endIndex = null,
this IReadOnlyList<ChatMessageContent> history,
int startIndex,
int? endIndex = null,
ChatMessageContent? systemMessage = null,
ChatMessageContent? summaryMessage = null,
Func<ChatMessageContent, bool>? messageFilter = null)
{
endIndex ??= chatHistory.Count - 1;
endIndex ??= history.Count - 1;
if (startIndex > endIndex)
{
yield break;
Expand All @@ -57,7 +59,7 @@ public static IEnumerable<ChatMessageContent> Extract(

for (int index = startIndex; index <= endIndex; ++index)
{
var messageContent = chatHistory[index];
var messageContent = history[index];

if (messageFilter?.Invoke(messageContent) ?? false)
{
Expand All @@ -71,24 +73,24 @@ public static IEnumerable<ChatMessageContent> Extract(
/// <summary>
/// Compute the index truncation where truncation should begin using the current truncation threshold.
/// </summary>
/// <param name="chatHistory">ChatHistory instance to be truncated</param>
/// <param name="truncatedSize"></param>
/// <param name="truncationThreshold"></param>
/// <param name="history">The source history.</param>
/// <param name="truncatedSize">Truncated size.</param>
/// <param name="truncationThreshold">Truncation threshold.</param>
/// <param name="hasSystemMessage">Flag indicating whether or not the chat history contains a system messageContent</param>
public static int ComputeTruncationIndex(this ChatHistory chatHistory, int truncatedSize, int truncationThreshold, bool hasSystemMessage)
public static int ComputeTruncationIndex(this IReadOnlyList<ChatMessageContent> history, int truncatedSize, int truncationThreshold, bool hasSystemMessage)
{
if (chatHistory.Count <= truncationThreshold)
if (history.Count <= truncationThreshold)
{
return -1;
}

// Compute the index of truncation target
var truncationIndex = chatHistory.Count - (truncatedSize - (hasSystemMessage ? 1 : 0));
var truncationIndex = history.Count - (truncatedSize - (hasSystemMessage ? 1 : 0));

// Skip function related content
while (truncationIndex < chatHistory.Count)
while (truncationIndex < history.Count)
{
if (chatHistory[truncationIndex].Items.Any(i => i is FunctionCallContent || i is FunctionResultContent))
if (history[truncationIndex].Items.Any(i => i is FunctionCallContent || i is FunctionResultContent))
{
truncationIndex++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,6 @@ namespace ChatCompletion;
/// </summary>
public class ChatHistoryReducerTests(ITestOutputHelper output) : BaseTest(output)
{
[Theory]
[InlineData(3, null, null, 5, 0)]
[InlineData(2, null, null, 1, 1)]
[InlineData(2, "SystemMessage", null, 2, 2)]
[InlineData(10, null, null, 3, 3)]
[InlineData(10, "SystemMessage", null, 3, 3)]
[InlineData(9, null, null, 5, 5)]
[InlineData(11, null, null, 5, 5)]
[InlineData(8, "SystemMessage", null, 5, 5)]
[InlineData(10, "SystemMessage", null, 5, 5)]
[InlineData(3, null, new int[] { 0 }, 3, 2)]
[InlineData(3, "SystemMessage", new int[] { 0 }, 4, 3)]
public async Task VerifyTruncatingChatHistoryReducerAsync(int messageCount, string? systemMessage, int[]? functionCallIndexes, int truncatedSize, int expectedSize)
{
// Arrange
var chatHistory = CreateHistoryWithUserInput(messageCount, systemMessage, functionCallIndexes);
var reducer = new TruncatingChatHistoryReducer(truncatedSize);

// Act
var reducedHistory = await reducer.ReduceAsync(chatHistory);

// Assert
VerifyReducedHistory(reducedHistory, ComputeExpectedMessages(chatHistory, expectedSize));
}

[Theory]
[InlineData(3, null, null, 100, 0)]
[InlineData(3, "SystemMessage", null, 100, 0)]
Expand All @@ -56,29 +31,6 @@ public async Task VerifyMaxTokensChatHistoryReducerAsync(int messageCount, strin
VerifyReducedHistory(reducedHistory, ComputeExpectedMessages(chatHistory, expectedSize));
}

[Theory]
[InlineData(3, null, null, 5, 10, 0)]
[InlineData(10, null, null, 5, 10, 6)]
[InlineData(10, "SystemMessage", null, 5, 10, 6)]
[InlineData(10, null, new int[] { 1 }, 5, 10, 6)]
[InlineData(10, "SystemMessage", new int[] { 2 }, 5, 10, 6)]
public async Task VerifySummarizingChatHistoryReducerAsync(int messageCount, string? systemMessage, int[]? functionCallIndexes, int truncatedSize, int truncationThreshold, int expectedSize)
{
// Arrange
Assert.NotNull(TestConfiguration.OpenAI.ChatModelId);
Assert.NotNull(TestConfiguration.OpenAI.ApiKey);
IChatCompletionService chatClient = new FakeChatCompletionService("The dialog consists of repetitive interaction where both the user and assistant exchange identical phrases in Latin.");

var chatHistory = CreateHistoryWithUserInput(messageCount, systemMessage, functionCallIndexes, true);
var reducer = new SummarizingChatHistoryReducer(chatClient, truncatedSize, truncationThreshold);

// Act
var reducedHistory = await reducer.ReduceAsync(chatHistory);

// Assert
VerifyReducedHistory(reducedHistory, ComputeExpectedMessages(chatHistory, expectedSize));
}

private static void VerifyReducedHistory(IEnumerable<ChatMessageContent>? reducedHistory, ChatMessageContent[]? expectedMessages)
{
if (expectedMessages is null)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,38 @@ public MaxTokensChatHistoryReducer(int maxTokenCount)
}

/// <inheritdoc/>
public Task<IEnumerable<ChatMessageContent>?> ReduceAsync(ChatHistory chatHistory, CancellationToken cancellationToken = default)
public Task<IEnumerable<ChatMessageContent>?> ReduceAsync(IReadOnlyList<ChatMessageContent> history, CancellationToken cancellationToken = default)
{
var systemMessage = chatHistory.GetSystemMessage();
var systemMessage = history.GetSystemMessage();

var truncationIndex = ComputeTruncationIndex(chatHistory, systemMessage);
var truncationIndex = ComputeTruncationIndex(history, systemMessage);

IEnumerable<ChatMessageContent>? truncatedHistory = null;

if (truncationIndex > 0)
{
truncatedHistory = chatHistory.Extract(truncationIndex, systemMessage: systemMessage);
truncatedHistory = history.Extract(truncationIndex, systemMessage: systemMessage);
}

return Task.FromResult<IEnumerable<ChatMessageContent>?>(truncatedHistory);
}

#region private

/// <summary>
/// Compute the index truncation where truncation should begin using the current truncation threshold.
/// </summary>
/// <param name="chatHistory">ChatHistory instance to be truncated</param>
/// <param name="history">Chat history to be truncated.</param>
/// <param name="systemMessage">The system message</param>
private int ComputeTruncationIndex(ChatHistory chatHistory, ChatMessageContent? systemMessage)
private int ComputeTruncationIndex(IReadOnlyList<ChatMessageContent> history, ChatMessageContent? systemMessage)
{
var truncationIndex = -1;

var totalTokenCount = (int)(systemMessage?.Metadata?["TokenCount"] ?? 0);
for (int i = chatHistory.Count - 1; i >= 0; i--)
for (int i = history.Count - 1; i >= 0; i--)
{
truncationIndex = i;
var tokenCount = (int)(chatHistory[i].Metadata?["TokenCount"] ?? 0);
var tokenCount = (int)(history[i].Metadata?["TokenCount"] ?? 0);
if (tokenCount + totalTokenCount > this._maxTokenCount)
{
break;
Expand All @@ -69,9 +70,9 @@ private int ComputeTruncationIndex(ChatHistory chatHistory, ChatMessageContent?
}

// Skip function related content
while (truncationIndex < chatHistory.Count)
while (truncationIndex < history.Count)
{
if (chatHistory[truncationIndex].Items.Any(i => i is FunctionCallContent || i is FunctionResultContent))
if (history[truncationIndex].Items.Any(i => i is FunctionCallContent || i is FunctionResultContent))
{
truncationIndex++;
}
Expand All @@ -83,5 +84,6 @@ private int ComputeTruncationIndex(ChatHistory chatHistory, ChatMessageContent?

return truncationIndex;
}

#endregion
}

This file was deleted.

Loading
Loading