Skip to content

Commit

Permalink
[C#] port: TeamsAttachmentDownloader, vision support...etc (#1723)
Browse files Browse the repository at this point in the history
## Linked issues

closes: #1011 #1182 (issue number)

## Details

Implements the TeamsAttachmentDownloader and OpenAI/AOAI Vision model
support

#### Change details

* Bumped `Azure.AI.OpenAI` to version `1.0.0-beta.17`
* Implemented TeamsAttachmentDownloader
* Auxiliary changes in the `Application/` and `AI/` code.
* Implemented and tested CardGazer sample.

##### Breaking Change

* `ChatMessage.Content` is changed from `string?` to `object?` since it
can hold a `string` or `IEnumerable<MessageContentParts>`

## Attestation Checklist

- [x] My code follows the style guidelines of this project

- I have checked for/fixed spelling, linting, and other errors
- I have commented my code for clarity
- I have made corresponding changes to the documentation (updating the
doc strings in the code is sufficient)
- My changes generate no new warnings
- I have added tests that validate my changes and provide sufficient
test coverage. I have tested with:
  - Local testing
  - E2E testing in Teams
- New and existing unit tests pass locally with my changes
  • Loading branch information
singhk97 authored Jun 11, 2024
1 parent b69535c commit a82f18a
Show file tree
Hide file tree
Showing 68 changed files with 1,810 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ public string DoCommand([ActionName] string action)
[Action(AIConstants.SayCommandActionName)]
public string SayCommand([ActionParameters] PredictedSayCommand command)
{
    // Capture the SAY response text so tests can later inspect what was "said".
    var text = command.Response.GetContent<string>();
    SayActionRecord.Add(text);
    return string.Empty;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Microsoft.Teams.AI.AI.Models;

namespace Microsoft.Teams.AI.Tests.AITests
{
    /// <summary>
    /// Tests for <see cref="ChatMessage.GetContent{T}"/>.
    /// </summary>
    public class ChatMessageTests
    {
        [Fact]
        public void Test_Get_Content()
        {
            // Arrange: a message whose content holds a plain string.
            var msg = new ChatMessage(ChatRole.Assistant) { Content = "test" };

            // Act
            string content = msg.GetContent<string>();

            // Assert
            Assert.Equal("test", content);
        }

        [Fact]
        public void Test_Get_Content_TypeMismatch_ThrowsException()
        {
            // Arrange: content is a string, so requesting it as a bool must fail.
            var msg = new ChatMessage(ChatRole.Assistant) { Content = "test" };

            // Act & Assert
            Assert.Throws<InvalidCastException>(() => msg.GetContent<bool>());
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ public async Task Test_CompletePromptAsync_PromptResponse_Success()
}
});

memory.SetValue("temp.input", "hello");

// Act
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager());

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -161,7 +163,7 @@ public async Task Test_CompletePromptAsync_PromptResponse_Exception()
TestMemory memory = new();

// Act
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager());

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -211,8 +213,10 @@ public async Task Test_CompletePromptAsync_PromptResponse_Repair()
Valid = true
});

memory.SetValue("temp.input", "hello");

// Act
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager());

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -277,8 +281,10 @@ public async Task Test_CompletePromptAsync_PromptResponse_RepairNotSuccess()
Valid = true
});

memory.SetValue("temp.input", "hello");

// Act
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager());

// Assert
Assert.NotNull(response);
Expand Down Expand Up @@ -344,8 +350,10 @@ public async Task Test_CompletePromptAsync_PromptResponse_Repair_ExceedMaxRepair
Valid = true
});

memory.SetValue("temp.input", "hello");

// Act
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager());

// Assert
Assert.NotNull(response);
Expand All @@ -368,7 +376,6 @@ public async Task Test_CompletePromptAsync_PromptResponse_DisableHistory()
LLMClientOptions<object> options = new(promptCompletionModel, promptTemplate)
{
HistoryVariable = string.Empty,
InputVariable = string.Empty
};
LLMClient<object> client = new(options, null);
TestMemory memory = new();
Expand All @@ -391,7 +398,7 @@ public async Task Test_CompletePromptAsync_PromptResponse_DisableHistory()
Assert.NotNull(response.Message);
Assert.Equal(ChatRole.Assistant, response.Message.Role);
Assert.Equal("welcome", response.Message.Content);
Assert.Equal(0, memory.Values.Count);
Assert.Equal(1, memory.Values.Count);
}

[Fact]
Expand Down Expand Up @@ -425,8 +432,10 @@ public async Task Test_CompletePromptAsync_PromptResponse_DisableRepair()
Valid = false
});

memory.SetValue("temp.input", "hello");

// Act
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager(), "hello");
var response = await client.CompletePromptAsync(new Mock<ITurnContext>().Object, memory, new PromptManager());

// Assert
Assert.NotNull(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ public class ChatMessageExtensionsTests
public void Test_InvalidRole_ToAzureSdkChatMessage()
{
// Arrange
var chatMessage = new ChatMessage(new AI.Models.ChatRole("InvalidRole"));
var chatMessage = new ChatMessage(new AI.Models.ChatRole("InvalidRole"))
{
Content = "test"
};

// Act
var ex = Assert.Throws<TeamsAIException>(() => chatMessage.ToChatRequestMessage());
Expand All @@ -20,7 +23,7 @@ public void Test_InvalidRole_ToAzureSdkChatMessage()
}

[Fact]
public void Test_UserRole_ToAzureSdkChatMessage()
public void Test_UserRole_StringContent_ToAzureSdkChatMessage()
{
// Arrange
var chatMessage = new ChatMessage(AI.Models.ChatRole.User)
Expand All @@ -39,6 +42,32 @@ public void Test_UserRole_ToAzureSdkChatMessage()
Assert.Equal("author", ((ChatRequestUserMessage)result).Name);
}

[Fact]
public void Test_UserRole_MultiModalContent_ToAzureSdkChatMessage()
{
    // Arrange: a user message whose content is a list of parts (text + image).
    var messageContentParts = new List<MessageContentParts>() { new TextContentPart() { Text = "test" }, new ImageContentPart { ImageUrl = "https://www.testurl.com" } };
    var chatMessage = new ChatMessage(AI.Models.ChatRole.User)
    {
        Content = messageContentParts,
        Name = "author"
    };

    // Act
    var result = chatMessage.ToChatRequestMessage();

    // Assert
    Assert.Equal(Azure.AI.OpenAI.ChatRole.User, result.Role);
    // Assert.IsType both checks the runtime type and returns the cast instance
    // (avoids the xUnit2007 warning raised by comparing typeof(...) values).
    var userMessage = Assert.IsType<ChatRequestUserMessage>(result);

    // Multi-modal content is mapped onto MultimodalContentItems; the plain
    // string Content stays null. Assert.Null avoids the xUnit2003 warning
    // produced by Assert.Equal(null, ...).
    Assert.Null(userMessage.Content);
    Assert.Equal("test", ((ChatMessageTextContentItem)userMessage.MultimodalContentItems[0]).Text);
    Assert.IsType<ChatMessageImageContentItem>(userMessage.MultimodalContentItems[1]);
    Assert.Equal("author", userMessage.Name);
}

[Fact]
public void Test_AssistantRole_ToAzureSdkChatMessage()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Text_RequestFailed()
// Assert
Assert.Equal(PromptResponseStatus.Error, result.Status);
Assert.NotNull(result.Error);
Assert.Equal("The text completion API returned an error status of InternalServerError: Service request failed.\r\nStatus: 500 (exception)\r\n\r\nHeaders:\r\n", result.Error.Message);
Assert.True(result.Error.Message.StartsWith("The text completion API returned an error status of InternalServerError: Service request failed.\r\nStatus: 500 (exception)"));
}

[Fact]
Expand Down Expand Up @@ -273,7 +273,7 @@ public async void Test_CompletePromptAsync_AzureOpenAI_Chat_RequestFailed()
// Assert
Assert.Equal(PromptResponseStatus.Error, result.Status);
Assert.NotNull(result.Error);
Assert.Equal("The chat completion API returned an error status of InternalServerError: Service request failed.\r\nStatus: 500 (exception)\r\n\r\nHeaders:\r\n", result.Error.Message);
Assert.True(result.Error.Message.StartsWith("The chat completion API returned an error status of InternalServerError: Service request failed.\r\nStatus: 500 (exception)"));
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class ConversationHistorySectionTests
[Fact]
public async void Test_RenderAsTextAsync_ShouldRender()
{
// Arrange
ConversationHistorySection section = new("history");
Mock<ITurnContext> context = new();
MemoryFork memory = new();
Expand All @@ -26,23 +27,57 @@ public async void Test_RenderAsTextAsync_ShouldRender()
new(ChatRole.Assistant) { Content = "hi, how may I assist you?" }
});

// Act
RenderedPromptSection<string> rendered = await section.RenderAsTextAsync(context.Object, memory, manager, tokenizer, 50);

// Assert
Assert.Equal("assistant: hi, how may I assist you?\nuser: hi\nyou are a unit test bot", rendered.Output);
Assert.Equal(21, rendered.Length);
}

[Fact]
public async void Test_RenderAsTextAsync_ShouldRenderEmpty()
{
    // Arrange: memory has no "history" entry, so the section has nothing to render.
    Mock<ITurnContext> context = new();
    MemoryFork memory = new();
    GPTTokenizer tokenizer = new();
    PromptManager manager = new();
    ConversationHistorySection section = new("history");

    // Act
    var rendered = await section.RenderAsTextAsync(context.Object, memory, manager, tokenizer, 50);

    // Assert: empty output and a zero token count.
    Assert.Equal("", rendered.Output);
    Assert.Equal(0, rendered.Length);
}


[Fact]
// Renamed from "ShoulderRender" (typo) to "ShouldRender".
public async void Test_RenderAsMessagesAsync_ShouldRender()
{
    // Arrange
    ConversationHistorySection section = new("history");
    Mock<ITurnContext> context = new();
    MemoryFork memory = new();
    GPTTokenizer tokenizer = new();
    PromptManager manager = new();

    // Seeding memory is part of Arrange (the original mislabeled it "Act").
    memory.SetValue("history", new List<ChatMessage>()
    {
        new(ChatRole.System) { Content = "you are a unit test bot" },
        new(ChatRole.User) { Content = "hi" },
        new(ChatRole.Assistant) { Content = "hi, how may I assist you?" }
    });

    // Act
    RenderedPromptSection<List<ChatMessage>> rendered = await section.RenderAsMessagesAsync(context.Object, memory, manager, tokenizer, 50);

    // Assert: the section emits the history in reverse order of insertion
    // (most recent first), per the index positions asserted below.
    Assert.Equal("you are a unit test bot", rendered.Output[2].GetContent<string>());
    Assert.Equal("hi", rendered.Output[1].GetContent<string>());
    Assert.Equal("hi, how may I assist you?", rendered.Output[0].GetContent<string>());
    Assert.Equal(15, rendered.Length);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public override async Task<RenderedPromptSection<List<ChatMessage>>> RenderAsMes

return await Task.FromResult(this.TruncateMessages(messages, tokenizer, maxTokens));
}

// Test shim: exposes the section's message-to-text conversion for unit tests.
public string GetMessage(ChatMessage message) => this.GetMessageText(message);
}

public class PromptSectionTests
Expand Down Expand Up @@ -53,5 +58,29 @@ public async void Test_RenderAsTextAsync_ShouldTruncate()
Assert.Equal("Hello World", rendered.Output);
Assert.Equal(2, rendered.Length);
}

[Fact]
public void Test_GetMessage()
{
    // Arrange: a multi-part message; the asserted result shows the text
    // parts are joined with a single space.
    var parts = new List<MessageContentParts>
    {
        new TextContentPart { Text = "Hello" },
        new TextContentPart { Text = "World" }
    };
    ChatMessage message = new(ChatRole.User) { Content = parts };

    // Act
    string msg = new TestSection().GetMessage(message);

    // Assert
    Assert.Equal("Hello World", msg);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using Microsoft.Bot.Builder;
using Microsoft.Teams.AI.AI.Models;
using Microsoft.Teams.AI.AI.Prompts.Sections;
using Microsoft.Teams.AI.AI.Prompts;
using Microsoft.Teams.AI.AI.Tokenizers;
using Microsoft.Teams.AI.State;
using Moq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Teams.AI.Application;
using static System.Net.Mime.MediaTypeNames;

namespace Microsoft.Teams.AI.Tests.AITests.PromptsTests.SectionsTests
{
    /// <summary>
    /// Tests for <see cref="UserInputMessageSection"/>, which renders the
    /// user's text input plus any attached input files as a multi-modal message.
    /// </summary>
    public class UserInputMessageSectionTest
    {
        [Fact]
        // Renamed from "ShoulderRender" (typo); async Task instead of async void
        // so xUnit can properly observe exceptions.
        public async Task Test_RenderAsMessagesAsync_ShouldRender()
        {
            // Arrange
            UserInputMessageSection section = new();
            Mock<ITurnContext> context = new();
            MemoryFork memory = new();
            GPTTokenizer tokenizer = new();
            PromptManager manager = new();

            memory.SetValue("input", "hi");
            memory.SetValue("inputFiles", new List<InputFile>()
            {
                new(BinaryData.FromString("testData"), "image/png")
            });

            // Act
            RenderedPromptSection<List<ChatMessage>> rendered = await section.RenderAsMessagesAsync(context.Object, memory, manager, tokenizer, 200);

            // Assert
            var messageContentParts = rendered.Output[0].GetContent<List<MessageContentParts>>();

            Assert.Equal("hi", ((TextContentPart)messageContentParts[0]).Text);

            // "dGVzdERhdGE=" is the base64 encoding of "testData".
            var imageUrl = "data:image/png;base64,dGVzdERhdGE=";
            Assert.Equal(imageUrl, ((ImageContentPart)messageContentParts[1]).ImageUrl);

            Assert.Equal(86, rendered.Length);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public async Task OpenAIModel_CompletePrompt(string input, string expectedAnswer

// Assert
Assert.Equal(PromptResponseStatus.Success, result.Status);
Assert.Contains(expectedAnswer, result.Message!.Content);
Assert.Contains(expectedAnswer, result.Message!.GetContent<string>());
}
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
Expand All @@ -7,10 +7,11 @@

<IsPackable>false</IsPackable>
<IsTestProject>true</IsTestProject>
<PlatformTarget>x64</PlatformTarget>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.15" />
<PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.17" />
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="7.0.0" />
<PackageReference Include="Microsoft.Bot.Builder" Version="4.22.3" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ await _actions[AIConstants.TooManyStepsActionName]

// Copy the actions output to the input
turnState.Temp!.Input = output;
turnState.Temp.InputFiles = new();
}

// Check for looping
Expand Down
Loading

0 comments on commit a82f18a

Please sign in to comment.