Refactoring work to move to Azure.AI.OpenAI v2.1.0 (#328)
daxian-dbw authored Jan 18, 2025
1 parent 80bccd7 commit 33009f1
Showing 5 changed files with 190 additions and 255 deletions.
7 changes: 4 additions & 3 deletions shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj
@@ -21,9 +21,10 @@
   </PropertyGroup>
 
   <ItemGroup>
-    <PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.17" />
-    <PackageReference Include="Azure.Core" Version="1.39.0" />
-    <PackageReference Include="SharpToken" Version="2.0.3" />
+    <PackageReference Include="Azure.AI.OpenAI" Version="2.1.0" />
+    <PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.1" />
+    <PackageReference Include="Microsoft.ML.Tokenizers.Data.O200kBase" Version="1.0.1" />
+    <PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />
   </ItemGroup>
 
   <ItemGroup>
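The new package references are consumed by the agent's service layer. The fifth changed file, presumably ChatService.cs (it defines the GetStreamingChatResponseAsync and CalibrateChatHistory methods used below), did not render in this view, so the client construction is not visible here. As a rough, hedged sketch of constructing a chat client against Azure.AI.OpenAI 2.1.0 (the endpoint, apiKey, and deployment names are placeholders, not values from this commit):

// Illustrative sketch only; not code from this commit.
using System;
using System.ClientModel;
using Azure.AI.OpenAI;
using OpenAI.Chat;

internal static class ClientSketch
{
    internal static ChatClient Create(string endpoint, string apiKey, string deployment)
    {
        // AzureOpenAIClient replaces the v1 OpenAIClient; ChatClient comes from the underlying OpenAI package.
        var azureClient = new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey));
        return azureClient.GetChatClient(deployment);
    }
}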
36 changes: 22 additions & 14 deletions shell/agents/AIShell.OpenAI.Agent/Agent.cs
@@ -1,7 +1,8 @@
+using System.ClientModel;
 using System.Text;
 using System.Text.Json;
-using Azure.AI.OpenAI;
 using AIShell.Abstraction;
+using OpenAI.Chat;
 
 namespace AIShell.OpenAI.Agent;
 
@@ -106,37 +107,44 @@ public async Task<bool> ChatAsync(string input, IShell shell)
             return checkPass;
         }
 
-        string responseContent = null;
-        StreamingResponse<StreamingChatCompletionsUpdate> response = await host.RunWithSpinnerAsync(
-            () => _chatService.GetStreamingChatResponseAsync(input, token)
-        ).ConfigureAwait(false);
+        IAsyncEnumerator<StreamingChatCompletionUpdate> response = await host
+            .RunWithSpinnerAsync(
+                () => _chatService.GetStreamingChatResponseAsync(input, token)
+            ).ConfigureAwait(false);
 
         if (response is not null)
         {
+            StreamingChatCompletionUpdate update = null;
             using var streamingRender = host.NewStreamRender(token);
 
             try
             {
-                await foreach (StreamingChatCompletionsUpdate chatUpdate in response)
+                do
                 {
-                    if (string.IsNullOrEmpty(chatUpdate.ContentUpdate))
+                    update = response.Current;
+                    if (update.ContentUpdate.Count > 0)
                     {
-                        continue;
+                        streamingRender.Refresh(update.ContentUpdate[0].Text);
                     }
-
-                    streamingRender.Refresh(chatUpdate.ContentUpdate);
                 }
+                while (await response.MoveNextAsync().ConfigureAwait(continueOnCapturedContext: false));
             }
            catch (OperationCanceledException)
            {
                 // Ignore the cancellation exception.
+                update = null;
            }
 
-            responseContent = streamingRender.AccumulatedContent;
+            if (update is null)
+            {
+                _chatService.CalibrateChatHistory(usage: null, response: null);
+            }
+            else
+            {
+                string responseContent = streamingRender.AccumulatedContent;
+                _chatService.CalibrateChatHistory(update.Usage, new AssistantChatMessage(responseContent));
+            }
         }
 
-        _chatService.AddResponseToHistory(responseContent);
-
         return checkPass;
     }

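The GetStreamingChatResponseAsync method called above lives in the un-rendered service file, so its new shape is not shown in this diff. A hedged sketch of how such a method might produce the IAsyncEnumerator<StreamingChatCompletionUpdate> that the do/while loop expects (current item first, MoveNextAsync afterwards); the _client field and parameter names are assumptions:

// Hypothetical sketch; the real implementation may differ.
// Assumes: using System.ClientModel; using OpenAI.Chat; and a ChatClient field named _client.
internal async Task<IAsyncEnumerator<StreamingChatCompletionUpdate>> GetStreamingChatResponseAsync(
    IEnumerable<ChatMessage> messages, CancellationToken cancellationToken)
{
    // CompleteChatStreamingAsync returns AsyncCollectionResult<StreamingChatCompletionUpdate>,
    // which implements IAsyncEnumerable<StreamingChatCompletionUpdate>.
    AsyncCollectionResult<StreamingChatCompletionUpdate> stream =
        _client.CompleteChatStreamingAsync(messages, options: null, cancellationToken: cancellationToken);

    IAsyncEnumerator<StreamingChatCompletionUpdate> enumerator = stream.GetAsyncEnumerator(cancellationToken);

    // Advance to the first update so the caller can read .Current immediately,
    // matching the do/while consumption pattern in Agent.cs.
    return await enumerator.MoveNextAsync().ConfigureAwait(false) ? enumerator : null;
}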
75 changes: 14 additions & 61 deletions shell/agents/AIShell.OpenAI.Agent/Helpers.cs
@@ -3,10 +3,7 @@
 using System.Text.Json;
 using System.Text.Json.Serialization;
 using System.Text.Json.Serialization.Metadata;
-
-using Azure;
-using Azure.Core;
-using Azure.Core.Pipeline;
+using System.ClientModel.Primitives;
 
 namespace AIShell.OpenAI.Agent;
 
@@ -134,69 +131,25 @@ public override JsonTypeInfo GetTypeInfo(Type type, JsonSerializerOptions option
     }
 }
 
-#nullable enable
-
-/// <summary>
-/// Used for setting user key for the Azure.OpenAI.Client.
-/// </summary>
-internal sealed class UserKeyPolicy : HttpPipelineSynchronousPolicy
-{
-    private readonly string _name;
-    private readonly AzureKeyCredential _credential;
-
-    /// <summary>
-    /// Initializes a new instance of the <see cref="UserKeyPolicy"/> class.
-    /// </summary>
-    /// <param name="credential">The <see cref="AzureKeyCredential"/> used to authenticate requests.</param>
-    /// <param name="name">The name of the key header used for the credential.</param>
-    public UserKeyPolicy(AzureKeyCredential credential, string name)
-    {
-        ArgumentNullException.ThrowIfNull(credential);
-        ArgumentException.ThrowIfNullOrEmpty(name);
-
-        _credential = credential;
-        _name = name;
-    }
-
-    /// <inheritdoc/>
-    public override void OnSendingRequest(HttpMessage message)
-    {
-        base.OnSendingRequest(message);
-        message.Request.Headers.SetValue(_name, _credential.Key);
-    }
-}
-
 /// <summary>
-/// Used for configuring the retry policy for Azure.OpenAI.Client.
+/// Initializes a new instance of the <see cref="ChatRetryPolicy"/> class.
 /// </summary>
-internal sealed class ChatRetryPolicy : RetryPolicy
+/// <param name="maxRetries">The maximum number of retries to attempt.</param>
+/// <param name="delayStrategy">The delay to use for computing the interval between retry attempts.</param>
+internal sealed class ChatRetryPolicy(int maxRetries = 2) : ClientRetryPolicy(maxRetries)
 {
     private const string RetryAfterHeaderName = "Retry-After";
     private const string RetryAfterMsHeaderName = "retry-after-ms";
     private const string XRetryAfterMsHeaderName = "x-ms-retry-after-ms";
 
-    /// <summary>
-    /// Initializes a new instance of the <see cref="ChatRetryPolicy"/> class.
-    /// </summary>
-    /// <param name="maxRetries">The maximum number of retries to attempt.</param>
-    /// <param name="delayStrategy">The delay to use for computing the interval between retry attempts.</param>
-    public ChatRetryPolicy(int maxRetries = 2, DelayStrategy? delayStrategy = default) : base(
-        maxRetries,
-        delayStrategy ?? DelayStrategy.CreateExponentialDelayStrategy(
-            initialDelay: TimeSpan.FromSeconds(0.8),
-            maxDelay: TimeSpan.FromSeconds(5)))
-    {
-        // By default, we retry 2 times at most, and use a delay strategy that waits 5 seconds at most between retries.
-    }
-
-    protected override bool ShouldRetry(HttpMessage message, Exception? exception) => ShouldRetryImpl(message, exception);
-    protected override ValueTask<bool> ShouldRetryAsync(HttpMessage message, Exception? exception) => new(ShouldRetryImpl(message, exception));
+    protected override bool ShouldRetry(PipelineMessage message, Exception exception) => ShouldRetryImpl(message, exception);
+    protected override ValueTask<bool> ShouldRetryAsync(PipelineMessage message, Exception exception) => new(ShouldRetryImpl(message, exception));
 
-    private bool ShouldRetryImpl(HttpMessage message, Exception? exception)
+    private bool ShouldRetryImpl(PipelineMessage message, Exception exception)
     {
         bool result = base.ShouldRetry(message, exception);
 
-        if (result && message.HasResponse)
+        if (result && message.Response is not null)
         {
             TimeSpan? retryAfter = GetRetryAfterHeaderValue(message.Response.Headers);
             if (retryAfter > TimeSpan.FromSeconds(5))
@@ -209,22 +162,22 @@ private bool ShouldRetryImpl(HttpMessage message, Exception? exception)
         return result;
     }
 
-    private static TimeSpan? GetRetryAfterHeaderValue(ResponseHeaders headers)
+    private static TimeSpan? GetRetryAfterHeaderValue(PipelineResponseHeaders headers)
     {
         if (headers.TryGetValue(RetryAfterMsHeaderName, out var retryAfterValue) ||
             headers.TryGetValue(XRetryAfterMsHeaderName, out retryAfterValue))
         {
-            if (int.TryParse(retryAfterValue, out var delaySeconds))
+            if (int.TryParse(retryAfterValue, out var delayInMS))
             {
-                return TimeSpan.FromMilliseconds(delaySeconds);
+                return TimeSpan.FromMilliseconds(delayInMS);
             }
         }
 
         if (headers.TryGetValue(RetryAfterHeaderName, out retryAfterValue))
         {
-            if (int.TryParse(retryAfterValue, out var delaySeconds))
+            if (int.TryParse(retryAfterValue, out var delayInSec))
             {
-                return TimeSpan.FromSeconds(delaySeconds);
+                return TimeSpan.FromSeconds(delayInSec);
             }
 
             if (DateTimeOffset.TryParse(retryAfterValue, out DateTimeOffset delayTime))
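Because ChatRetryPolicy now derives from System.ClientModel's ClientRetryPolicy instead of Azure.Core's RetryPolicy, it would be attached through the client options pipeline rather than an HttpPipeline. A hedged sketch of that wiring (variable names and values are assumptions; the actual hook-up is expected to be in the un-rendered service file):

// Illustrative only; not code from this commit.
// ClientPipelineOptions.RetryPolicy accepts any PipelinePolicy, including a ClientRetryPolicy subclass.
var clientOptions = new AzureOpenAIClientOptions
{
    RetryPolicy = new ChatRetryPolicy(maxRetries: 2),
};

var client = new AzureOpenAIClient(new Uri(endpoint), new ApiKeyCredential(apiKey), clientOptions);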
24 changes: 12 additions & 12 deletions shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
@@ -1,18 +1,17 @@
-using SharpToken;
+using Microsoft.ML.Tokenizers;
 
 namespace AIShell.OpenAI.Agent;
 
 internal class ModelInfo
 {
     // Models gpt4, gpt3.5, and the variants of them all use the 'cl100k_base' token encoding.
-    // But the gpt-4o model uses the 'o200k_base' token encoding. For reference:
-    // https://github.com/openai/tiktoken/blob/5d970c1100d3210b42497203d6b5c1e30cfda6cb/tiktoken/model.py#L7
-    // https://github.com/dmitry-brazhenko/SharpToken/blob/main/SharpToken/Lib/Model.cs#L8
+    // But gpt-4o and o1 models use the 'o200k_base' token encoding. For reference:
+    // https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/tiktoken/model.py
     private const string Gpt4oEncoding = "o200k_base";
     private const string Gpt34Encoding = "cl100k_base";
 
     private static readonly Dictionary<string, ModelInfo> s_modelMap;
-    private static readonly Dictionary<string, Task<GptEncoding>> s_encodingMap;
+    private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;
 
     static ModelInfo()
     {
@@ -21,6 +20,7 @@ static ModelInfo()
         // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
         s_modelMap = new(StringComparer.OrdinalIgnoreCase)
         {
+            ["o1"] = new(tokenLimit: 200_000, encoding: Gpt4oEncoding),
             ["gpt-4o"] = new(tokenLimit: 128_000, encoding: Gpt4oEncoding),
             ["gpt-4"] = new(tokenLimit: 8_192),
             ["gpt-4-32k"] = new(tokenLimit: 32_768),
@@ -35,8 +35,8 @@ static ModelInfo()
         // we don't block the startup and the values will be ready when we really need them.
         s_encodingMap = new(StringComparer.OrdinalIgnoreCase)
         {
-            [Gpt34Encoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt34Encoding)),
-            [Gpt4oEncoding] = Task.Run(() => GptEncoding.GetEncoding(Gpt4oEncoding))
+            [Gpt34Encoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt34Encoding)),
+            [Gpt4oEncoding] = Task.Run(() => (Tokenizer)TiktokenTokenizer.CreateForEncoding(Gpt4oEncoding))
         };
     }

@@ -45,24 +45,24 @@ private ModelInfo(int tokenLimit, string encoding = null)
         TokenLimit = tokenLimit;
         _encodingName = encoding ?? Gpt34Encoding;
 
-        // For gpt4 and gpt3.5-turbo, the following 2 properties are the same.
+        // For gpt4o, gpt4 and gpt3.5-turbo, the following 2 properties are the same.
         // See https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
         TokensPerMessage = 3;
         TokensPerName = 1;
     }
 
     private readonly string _encodingName;
-    private GptEncoding _gptEncoding;
+    private Tokenizer _gptEncoding;
 
     internal int TokenLimit { get; }
     internal int TokensPerMessage { get; }
    internal int TokensPerName { get; }
-    internal GptEncoding Encoding
+    internal Tokenizer Encoding
    {
         get {
-            _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<GptEncoding> value)
+            _gptEncoding ??= s_encodingMap.TryGetValue(_encodingName, out Task<Tokenizer> value)
                 ? value.Result
-                : GptEncoding.GetEncoding(_encodingName);
+                : TiktokenTokenizer.CreateForEncoding(_encodingName);
             return _gptEncoding;
         }
     }
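With SharpToken removed, token counting now goes through Microsoft.ML.Tokenizers, and the Data.O200kBase / Data.Cl100kBase packages added in the csproj supply the vocabulary data that TiktokenTokenizer.CreateForEncoding loads. A minimal usage sketch (standalone example, not code from this commit):

// Minimal sketch of the replacement tokenizer API.
using Microsoft.ML.Tokenizers;

internal static class TokenCountSketch
{
    internal static int Count(string text)
    {
        // 'o200k_base' requires the Microsoft.ML.Tokenizers.Data.O200kBase data package.
        Tokenizer tokenizer = TiktokenTokenizer.CreateForEncoding("o200k_base");
        return tokenizer.CountTokens(text);
    }
}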
