Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[C#] feat: Migrate AssistantsPlanner to official OpenAI .NET SDK #1895

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,9 @@ dotnet_diagnostic.VSTHRD111.severity = none # Use .ConfigureAwait(bool)

# CA1859: Use concrete types when possible for improved performance
dotnet_diagnostic.CA1859.severity = none

# CS0618: Type or member is obsolete
dotnet_diagnostic.CS0618.severity = none

# IDE0090: Use 'new(...)'
dotnet_diagnostic.IDE0090.severity = none
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using Moq;
using System.Reflection;
using Microsoft.Teams.AI.AI.Planners;
using Azure.AI.OpenAI.Assistants;
using OpenAI.Assistants;

namespace Microsoft.Teams.AI.Tests.AITests
{
Expand Down Expand Up @@ -99,7 +99,7 @@ public async Task Test_BeginTaskAsync_Assistant_WaitForPreviousRun()
testClient.RemainingMessages.Enqueue("welcome");

AssistantThread thread = await testClient.CreateThreadAsync(new(), CancellationToken.None);
await testClient.CreateRunAsync(thread.Id, AssistantsModelFactory.CreateRunOptions(), CancellationToken.None);
await testClient.CreateRunAsync(thread.Id, "", OpenAIModelFactory.CreateRunOptions(), CancellationToken.None);
turnState.ThreadId = thread.Id;

// Act
Expand Down Expand Up @@ -219,8 +219,7 @@ public async Task Test_ContinueTaskAsync_Assistant_RequiresAction()
var aiOptions = new AIOptions<AssistantsState>(planner);
var ai = new AI<AssistantsState>(aiOptions);

var functionToolCall = AssistantsModelFactory.RequiredFunctionToolCall("test-tool-id", "test-action", "{}");
var requiredAction = AssistantsModelFactory.SubmitToolOutputsAction(new List<RequiredToolCall>{ functionToolCall });
var requiredAction = OpenAIModelFactory.CreateRequiredAction("test-tool-id", "test-action", "{}");

testClient.RemainingActions.Enqueue(requiredAction);
testClient.RemainingRunStatus.Enqueue("requires_action");
Expand Down Expand Up @@ -267,8 +266,7 @@ public async Task Test_ContinueTaskAsync_Assistant_IgnoreRedundantAction()
var aiOptions = new AIOptions<AssistantsState>(planner);
var ai = new AI<AssistantsState>(aiOptions);

var functionToolCall = AssistantsModelFactory.RequiredFunctionToolCall("test-tool-id", "test-action", "{}");
var requiredAction = AssistantsModelFactory.SubmitToolOutputsAction(new List<RequiredToolCall> { functionToolCall });
var requiredAction = OpenAIModelFactory.CreateRequiredAction("test-tool-id", "test-action", "{}");

testClient.RemainingActions.Enqueue(requiredAction);
testClient.RemainingRunStatus.Enqueue("requires_action");
Expand Down Expand Up @@ -316,9 +314,9 @@ public async Task Test_ContinueTaskAsync_Assistant_MultipleMessages()
var ai = new AI<AssistantsState>(aiOptions);

testClient.RemainingRunStatus.Enqueue("completed");
testClient.RemainingMessages.Enqueue("message 2");
testClient.RemainingMessages.Enqueue("message 1");
testClient.RemainingMessages.Enqueue("welcome");
testClient.RemainingMessages.Enqueue("message 1");
testClient.RemainingMessages.Enqueue("message 2");

// Act
var plan = await planner.ContinueTaskAsync(turnContextMock.Object, turnState, ai, CancellationToken.None);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ public async void Test_CreateOAuthCard_WithSSOEnabled()

var turnState = await TurnStateConfig.GetTurnStateWithConversationStateAsync(turnContext);
var app = new TestApplication(new() { Adapter = testAdapter });
var authSettings = new OAuthSettings() {
var authSettings = new OAuthSettings()
{
ConnectionName = "connectionName",
Title = "title",
Text = "text",
Expand Down Expand Up @@ -131,7 +132,7 @@ public async void Test_VerifyStateRouteSelector_ReturnsTrue()
}

[Fact]
public async void Test_VerifyStateRouteSelector_IncorrectActivity_ReturnsFalse ()
public async void Test_VerifyStateRouteSelector_IncorrectActivity_ReturnsFalse()
{
// Arrange
var testAdapter = new SimpleAdapter();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
using Microsoft.Bot.Schema;
using Microsoft.Teams.AI.State;
using Microsoft.Teams.AI.Tests.TestUtils;
using Record = Microsoft.Teams.AI.State.Record;

namespace Microsoft.Teams.AI.Tests.Application
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Moq;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Record = Microsoft.Teams.AI.State.Record;

namespace Microsoft.Teams.AI.Tests.Application
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using Microsoft.Teams.AI.Tests.TestUtils;
using Moq;
using Newtonsoft.Json.Linq;
using Record = Microsoft.Teams.AI.State.Record;

namespace Microsoft.Teams.AI.Tests.Application
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,21 @@
using Moq;
using System.Reflection;
using Xunit.Abstractions;
using Microsoft.Extensions.Logging;
using Microsoft.Bot.Schema;
using Microsoft.Teams.AI.AI.Tokenizers;
using Microsoft.Teams.AI.AI.Prompts;
using Microsoft.Teams.AI.AI.Prompts.Sections;
using Microsoft.Extensions.Logging;

namespace Microsoft.Teams.AI.Tests.IntegrationTests
{
public sealed class OpenAIModelTests
{
private readonly IConfigurationRoot _configuration;
private readonly RedirectOutput _output;
#pragma warning disable IDE0052 // Remove unread private members
private readonly ILoggerFactory _loggerFactory;
#pragma warning restore IDE0052 // Remove unread private members

public OpenAIModelTests(ITestOutputHelper output)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
using OpenAI.Assistants;
using System.ClientModel;
using System.ClientModel.Primitives;

namespace Microsoft.Teams.AI.Tests.TestUtils
{
internal sealed class OpenAIModelFactory
{
public static RunCreationOptions CreateRunOptions()
{
return new RunCreationOptions();
}

public static RequiredAction CreateRequiredAction(string toolCallId, string functionName, string functionArguments)
{
return new TestRequiredAction(toolCallId, functionName, functionArguments);
}

public static Assistant CreateAssistant()
{
return ModelReaderWriter.Read<Assistant>(BinaryData.FromString(@$"{{
""id"": ""{Guid.NewGuid()}"",
""object"": ""assistant"",
""created_at"": {DateTime.Now.Second}
}}"))!;
}

public static AssistantThread CreateAssistantThread(string guid, DateTimeOffset offset)
{
return ModelReaderWriter.Read<AssistantThread>(BinaryData.FromString(@$"{{
""id"": ""{guid}"",
""created_at"": {offset.Second}
}}"))!;
}

public static ThreadMessage CreateThreadMessage(string threadId, string message)
{
var json = @$"{{
""id"": ""{Guid.NewGuid()}"",
""thread_id"": ""{threadId}"",
""created_at"": {DateTime.Now.Second},
""content"": [
{{
""type"": ""text"",
""text"": {{
""value"": ""{message}"",
""annotations"": []
}}
}}
]
}}";
return ModelReaderWriter.Read<ThreadMessage>(BinaryData.FromString(json))!;
}

public static ThreadRun CreateThreadRun(string threadId, string runStatus, string? runId = null, IList<RequiredAction> requiredActions = null!)
{
var raJson = "{}";
if (requiredActions != null && requiredActions.Count > 0)
{
var toolCalls = requiredActions.Select((requiredAction) =>
{
var ra = (TestRequiredAction)requiredAction;
return $@"{{
""id"": ""{ra.ToolCallId}"",
""type"": ""function"",
""function"": {{
""name"": ""{ra.FunctionName}"",
""arguments"": ""{ra.FunctionArguments}""
}}
}}";
});

raJson = $@"{{
""type"": ""submit_tool_outputs"",
""submit_tool_outputs"": {{
""tool_calls"": [
{string.Join(",", toolCalls)}
]
}}
}}
";
}

return ModelReaderWriter.Read<ThreadRun>(BinaryData.FromString(@$"{{
""id"": ""{runId ?? Guid.NewGuid().ToString()}"",
""thread_id"": ""{threadId}"",
""created_at"": {DateTime.Now.Second},
""status"": ""{runStatus}"",
""required_action"": {raJson}
}}"))!;
}
}

internal sealed class TestRequiredAction : RequiredAction
{
public new string FunctionName;

public new string FunctionArguments;

public new string ToolCallId;

public TestRequiredAction(string toolCallId, string functionName, string functionArguments)
{
this.FunctionName = functionName;
this.FunctionArguments = functionArguments;
this.ToolCallId = toolCallId;
}
}

internal sealed class TestAsyncPageableCollection<T> : AsyncPageableCollection<T> where T : class
{
public List<T> Items;

internal PipelineResponse _pipelineResponse;

public TestAsyncPageableCollection(List<T> items, PipelineResponse response)
{
Items = items;
_pipelineResponse = response;
}

#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously
public override async IAsyncEnumerable<ResultPage<T>> AsPages(string? continuationToken = null, int? pageSizeHint = null)
#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously
{
yield return ResultPage<T>.Create(Items, null, _pipelineResponse);
}
}
}
Loading
Loading