Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
singhk97 committed Sep 30, 2024
1 parent bf35a06 commit 9ab356b
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ public async Task Test_CreatePlanFromResponseAsync_ValidPlan_ShouldSucceed()
Message = new(ChatRole.Assistant)
{
Content = @"{
""type"": ""plan"",
""commands"": [
{
""type"": ""DO"",
""action"": ""test""
},
{
""type"": ""SAY"",
""response"": ""hello""
}
]
}",
""type"": ""plan"",
""commands"": [
{
""type"": ""DO"",
""action"": ""test""
},
{
""type"": ""SAY"",
""response"": ""hello""
}
]
}",
Context = new()
{
Intent = "test intent",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
using Microsoft.Bot.Builder;
using Microsoft.Teams.AI.AI.Augmentations;
using Microsoft.Teams.AI.AI.Models;
using Microsoft.Teams.AI.AI.Planners;
using Microsoft.Teams.AI.AI.Prompts;
using Microsoft.Teams.AI.State;
using Microsoft.Teams.AI.Tests.TestUtils;

namespace Microsoft.Teams.AI.Tests.AITests.Augmentations
{
public class ToolsAugmentationTests
{
[Fact]
public async void Test_CreatePlanFromResponse_NoActionCalls_CreateSayCommand()
{
// Arrange
ToolsAugmentation augmentation = new ToolsAugmentation();
TurnContext context = TurnStateConfig.CreateConfiguredTurnContext();
TurnState state = await TurnStateConfig.GetTurnStateWithConversationStateAsync(context);
PromptResponse response = new PromptResponse();
response.Message = new ChatMessage(ChatRole.Assistant) { Content = "testMessage" };

// Act
Plan? plan = await augmentation.CreatePlanFromResponseAsync(context, state, response);

// Assert
Assert.NotNull(plan);
Assert.Single(plan.Commands);

var sayCommand = plan.Commands[0] as PredictedSayCommand;
Assert.NotNull(sayCommand);
Assert.Equal("testMessage", sayCommand.Response.Content);
}

[Fact]
public async void Test_CreatePlanFromResponse_OneActionCall()
{
// Arrange
ToolsAugmentation augmentation = new ToolsAugmentation();
TurnContext context = TurnStateConfig.CreateConfiguredTurnContext();
TurnState state = await TurnStateConfig.GetTurnStateWithConversationStateAsync(context);
PromptResponse response = new PromptResponse();
response.Message = new ChatMessage(ChatRole.Assistant) {
Content = "testMessage",
ActionCalls = new List<ActionCall>() {
new ActionCall("id", new ActionFunction("testFunction", "{ \"key\": \"value\" }"))
}
};


// Act
Plan? plan = await augmentation.CreatePlanFromResponseAsync(context, state, response);

// Assert
Assert.NotNull(plan);
Assert.Single(plan.Commands);

var doCommand = plan.Commands[0] as PredictedDoCommand;
Assert.NotNull(doCommand);
Assert.Equal("testFunction", doCommand.Action);
Assert.Equal("value", doCommand.Parameters!["key"]);
}

[Fact]
public async void Test_CreatePlanFromTesponse_MultipleActionCalls()
{
// Arrange
ToolsAugmentation augmentation = new ToolsAugmentation();
TurnContext context = TurnStateConfig.CreateConfiguredTurnContext();
TurnState state = await TurnStateConfig.GetTurnStateWithConversationStateAsync(context);
PromptResponse response = new PromptResponse();
response.Message = new ChatMessage(ChatRole.Assistant)
{
Content = "testMessage",
ActionCalls = new List<ActionCall>() {
new ActionCall("id1", new ActionFunction("testFunction1", "{ \"key1\": \"value1\" }")),
new ActionCall("id2", new ActionFunction("testFunction2", "{ \"key2\": \"value2\" }")),
}
};


// Act
Plan? plan = await augmentation.CreatePlanFromResponseAsync(context, state, response);

// Assert
Assert.NotNull(plan);
Assert.Equal(2, plan.Commands.Count);

var doCommand1 = plan.Commands[0] as PredictedDoCommand;
Assert.NotNull(doCommand1);
Assert.Equal("testFunction1", doCommand1.Action);
Assert.Equal("value1", doCommand1.Parameters!["key1"]);

var doCommand2 = plan.Commands[1] as PredictedDoCommand;
Assert.NotNull(doCommand2);
Assert.Equal("testFunction2", doCommand2.Action);
Assert.Equal("value2", doCommand2.Parameters!["key2"]);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,16 +160,16 @@ public void Test_AssistantRole_ToOpenAISdkChatMessage_FunctionCall()
}

[Fact]
public void Test_AssistantRole_ToOpenAISdkChatMessage_ToolCall()
public void Test_AssistantRole_ToOpenAISdkChatMessage_ActionCall()
{
// Arrange
var chatMessage = new ChatMessage(ChatRole.Assistant)
{
Content = "test-content",
Name = "test-name",
ToolCalls = new List<ChatCompletionsToolCall>()
ActionCalls = new List<ActionCall>()
{
new ChatCompletionsFunctionToolCall("test-id", "test-tool-name", "test-tool-arg1")
new ActionCall("test-id", new ActionFunction("test-tool-name", "test-tool-arg1"))
}
};

Expand All @@ -183,7 +183,7 @@ public void Test_AssistantRole_ToOpenAISdkChatMessage_ToolCall()
// TODO: Uncomment when participant name issue is resolved.
//Assert.Equal("test-name", assistantMessage.ParticipantName);

Assert.Equal(1, assistantMessage.ToolCalls.Count);
Assert.Single(assistantMessage.ToolCalls);
ChatToolCall toolCall = assistantMessage.ToolCalls[0];
Assert.NotNull(toolCall);
Assert.Equal("test-id", toolCall.Id);
Expand Down Expand Up @@ -239,7 +239,7 @@ public void Test_ToolRole_ToOpenAISdkChatMessage()
{
Content = "test-content",
Name = "tool-name",
ToolCallId = "tool-call-id"
ActionCallId = "action-call-id"
};

// Act
Expand All @@ -249,7 +249,7 @@ public void Test_ToolRole_ToOpenAISdkChatMessage()
var toolMessage = result as ToolChatMessage;
Assert.NotNull(toolMessage);
Assert.Equal("test-content", toolMessage.Content[0].Text);
Assert.Equal("tool-call-id", toolMessage.ToolCallId);
Assert.Equal("action-call-id", toolMessage.ToolCallId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ public async Task Test_CompletePromptAsync_PromptResponse_DisableHistory()
Assert.NotNull(response.Message);
Assert.Equal(ChatRole.Assistant, response.Message.Role);
Assert.Equal("welcome", response.Message.Content);
Assert.Equal(1, memory.Values.Count);
Assert.Empty(memory.Values);
}

[Fact]
Expand Down Expand Up @@ -444,7 +444,7 @@ public async Task Test_CompletePromptAsync_PromptResponse_DisableRepair()
Assert.NotNull(response.Message);
Assert.Equal(ChatRole.Assistant, response.Message.Role);
Assert.Equal("welcome", response.Message.Content);
Assert.Equal(1, memory.Values.Count);
Assert.Single(memory.Values);
Assert.Equal("hello", memory.Values[options.InputVariable]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace Microsoft.Teams.AI.Tests.AITests.Models
{
internal sealed class ChatCompletionToolCallTests
public sealed class ChatCompletionToolCallTests
{
[Fact]
public void Test_ChatCompletionsToolCall_ToFunctionToolCall()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ public void Test_AssistantRole_ToOpenAISdkChatMessage_ToolCall()
{
Content = "test-content",
Name = "test-name",
ToolCalls = new List<ChatCompletionsToolCall>()
ActionCalls = new List<ActionCall>()
{
new ChatCompletionsFunctionToolCall("test-id", "test-tool-name", "test-tool-arg1")
new ActionCall("test-id", new ActionFunction("test-tool-name", "test-tool-arg1"))
}
};

Expand All @@ -114,10 +114,8 @@ public void Test_AssistantRole_ToOpenAISdkChatMessage_ToolCall()
var assistantMessage = result as AssistantChatMessage;
Assert.NotNull(assistantMessage);
Assert.Equal("test-content", assistantMessage.Content[0].Text);
// TODO: Uncomment when participant name issue is resolved.
//Assert.Equal("test-name", assistantMessage.ParticipantName);

Assert.Equal(1, assistantMessage.ToolCalls.Count);
Assert.Single(assistantMessage.ToolCalls);
ChatToolCall toolCall = assistantMessage.ToolCalls[0];
Assert.NotNull(toolCall);
Assert.Equal("test-id", toolCall.Id);
Expand Down Expand Up @@ -173,7 +171,7 @@ public void Test_ToolRole_ToOpenAISdkChatMessage()
{
Content = "test-content",
Name = "tool-name",
ToolCallId = "tool-call-id"
ActionCallId = "tool-call-id"
};

// Act
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
using ChatMessage = Microsoft.Teams.AI.AI.Models.ChatMessage;
using ChatRole = Microsoft.Teams.AI.AI.Models.ChatRole;
using Azure.Identity;
using Microsoft.Teams.AI.AI.Augmentations;

namespace Microsoft.Teams.AI.Tests.AITests.Models
{
Expand Down Expand Up @@ -207,5 +208,79 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat()
Assert.Equal("test-choice", result.Message.Content);
}

[Fact]
public async void Test_CompletePromptAsync_AzureOpenAI_Chat_WithTools()
{
// Arrange
var turnContextMock = new Mock<ITurnContext>();
var turnStateMock = new Mock<TurnState>();
var renderedPrompt = new RenderedPromptSection<List<ChatMessage>>(new List<ChatMessage>(), length: 256, tooLong: false);
var promptMock = new Mock<Prompt>(new List<PromptSection>(), -1, true, "\n\n");
promptMock.Setup((prompt) => prompt.RenderAsMessagesAsync(
It.IsAny<ITurnContext>(), It.IsAny<IMemory>(), It.IsAny<IPromptFunctions<List<string>>>(),
It.IsAny<ITokenizer>(), It.IsAny<int>(), It.IsAny<CancellationToken>())).ReturnsAsync(renderedPrompt);
var promptTemplate = new PromptTemplate("test-prompt", promptMock.Object)
{
Actions = new List<ChatCompletionAction>() { new ChatCompletionAction() { Name = "testAction" } },
Augmentation = new ToolsAugmentation(),
Configuration = new PromptTemplateConfiguration()
{
Augmentation = new AugmentationConfiguration() {
Type = AugmentationType.Tools
}
}
};
var options = new AzureOpenAIModelOptions("test-key", "test-deployment", "https://test.openai.azure.com/")
{
CompletionType = CompletionConfiguration.CompletionType.Chat,
LogRequests = true,
};
var clientMock = new Mock<OpenAIClient>();
var chatCompletion = ModelReaderWriter.Read<ChatCompletion>(BinaryData.FromString(@$"{{
""choices"": [
{{
""finish_reason"": ""stop"",
""message"": {{
""role"": ""assistant"",
""content"": null,
""tool_calls"": [
{{
""id"": ""call_abc123"",
""type"": ""function"",
""function"": {{
""name"": ""testAction"",
""arguments"": ""{{}}""
}}
}}
]
}}
}}
]
}}"));
var response = new TestResponse(200, string.Empty);
clientMock.Setup((client) =>
client
.GetChatClient(It.IsAny<string>())
.CompleteChatAsync(It.IsAny<IEnumerable<OAIChatMessage>>(), It.IsAny<ChatCompletionOptions>(), It.IsAny<CancellationToken>())
).ReturnsAsync(ClientResult.FromValue(chatCompletion!, response));

var openAIModel = new OpenAIModel(options, loggerFactory: new TestLoggerFactory());
openAIModel.GetType().GetField("_openAIClient", BindingFlags.Instance | BindingFlags.NonPublic)!.SetValue(openAIModel, clientMock.Object);

// Act
var result = await openAIModel.CompletePromptAsync(turnContextMock.Object, turnStateMock.Object, new PromptManager(), new GPTTokenizer(), promptTemplate);

// Assert
Assert.Equal(PromptResponseStatus.Success, result.Status);
Assert.NotNull(result.Message);

Assert.NotNull(result.Message.ActionCalls);
Assert.Single(result.Message.ActionCalls);
Assert.Equal("testAction", result.Message.ActionCalls[0].Function.Name);

Assert.Null(result.Error);
Assert.Equal(ChatRole.Assistant, result.Message.Role);
Assert.Null(result.Message.Content);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public void Test_ToJsonString_Complex()
// Arrange
Plan plan = new();
plan.Commands.Add(new PredictedSayCommand("Hello"));
plan.Commands.Add(new PredictedDoCommand("DoSomething", new() { { "prop", "value" } }));
plan.Commands.Add(new PredictedDoCommand("DoSomething", new Dictionary<string, object?>() { { "prop", "value" } }));

// Note: This is not a formatting error. It is formatted this way to match the expected string.
string expectedPlanJson = @"{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ public async void Test_RenderAsMessagesAsync_ShoulderRender()

// Assert
RenderedPromptSection<List<ChatMessage>> rendered = await section.RenderAsMessagesAsync(context.Object, memory, manager, tokenizer, 50);
Assert.Equal("you are a unit test bot", rendered.Output[2].GetContent<string>());
Assert.Equal("you are a unit test bot", rendered.Output[0].GetContent<string>());
Assert.Equal("hi", rendered.Output[1].GetContent<string>());
Assert.Equal("hi, how may I assist you?", rendered.Output[0].GetContent<string>());
Assert.Equal("hi, how may I assist you?", rendered.Output[2].GetContent<string>());
Assert.Equal(15, rendered.Length);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,13 @@ internal sealed class TestAsyncEnumerator<T> : IAsyncEnumerator<PageResult<T>> w
{
private readonly List<T> _items;
private readonly PipelineResponse _pipelineResponse;
private bool _movedOnToNext;

public TestAsyncEnumerator(List<T> items, PipelineResponse response)
{
_items = items;
_pipelineResponse = response;
_movedOnToNext = false;
}

public PageResult<T> Current => PageResult<T>.Create(_items, ContinuationToken.FromBytes(BinaryData.FromString("")), null, _pipelineResponse);
Expand All @@ -205,7 +207,16 @@ public ValueTask DisposeAsync()

public ValueTask<bool> MoveNextAsync()
{
return new ValueTask<bool>(false);
if (!_movedOnToNext)
{
return new ValueTask<bool>(true);
}
else
{
_movedOnToNext = true;
return new ValueTask<bool>(false);
}

}
}
}
Loading

0 comments on commit 9ab356b

Please sign in to comment.