Commit 87d55a8

Add an IChatClient implementation to OnnxRuntimeGenAI
1 parent 44a8f22 commit 87d55a8

4 files changed (+355 −1 lines)

src/csharp/ChatClient.cs

+247 lines
@@ -0,0 +1,247 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes a new instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, string modelPath)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _config = configuration;

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes a new instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _config = configuration;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        // Generate a single completion ID for the whole response.
        var completionId = Guid.NewGuid().ToString();

        StringBuilder text = new();
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = completionId,
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null) =>
        key is not null ? null :
        serviceType == typeof(Model) ? _model :
        serviceType == typeof(Tokenizer) ? _tokenizer :
        serviceType?.IsInstanceOfType(this) is true ? this :
        null;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}
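
With this file in place, the client plugs into any Microsoft.Extensions.AI consumer. A minimal usage sketch, not part of this commit: the model path is a placeholder, and the Phi-3-style template markers are an assumption to adapt to your model.

using System;
using System.Linq;
using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Hypothetical Phi-3-style configuration; adjust the markers and path for your model.
ChatClientConfiguration config = new(
    ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
    messages =>
    {
        StringBuilder prompt = new();
        foreach (var message in messages)
            foreach (var content in message.Contents.OfType<TextContent>())
                prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
        return prompt.Append("<|assistant|>\n").ToString();
    });

using IChatClient client = new ChatClient(config, @"models/phi3"); // placeholder path

// Stream tokens to the console as they are generated.
await foreach (var update in client.CompleteStreamingAsync(
    [new ChatMessage(ChatRole.User, "What is 2 + 3?")],
    new ChatOptions { MaxOutputTokens = 32, Temperature = 0f }))
{
    Console.Write(update.Text);
}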

src/csharp/ChatClientConfiguration.cs

+73 lines
@@ -0,0 +1,73 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
/// <remarks>
/// Every model has different requirements for stop sequences and prompt formatting. For best results,
/// the configuration should be tailored to the exact nature of the model being used. For example,
/// when using a Phi3 model, a configuration like the following may be used:
/// <code>
/// static ChatClientConfiguration CreateForPhi3() =&gt;
///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
///         (IEnumerable&lt;ChatMessage&gt; messages) =&gt;
///         {
///             StringBuilder prompt = new();
///
///             foreach (var message in messages)
///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
///
///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
///         });
/// </code>
/// </remarks>
public sealed class ChatClientConfiguration
{
    private string[] _stopSequences;
    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;

    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
    /// <param name="stopSequences">The stop sequences used by the model.</param>
    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
    public ChatClientConfiguration(
        string[] stopSequences,
        Func<IEnumerable<ChatMessage>, string> promptFormatter)
    {
        if (stopSequences is null)
        {
            throw new ArgumentNullException(nameof(stopSequences));
        }

        if (promptFormatter is null)
        {
            throw new ArgumentNullException(nameof(promptFormatter));
        }

        StopSequences = stopSequences;
        PromptFormatter = promptFormatter;
    }

    /// <summary>
    /// Gets or sets stop sequences to use during generation.
    /// </summary>
    /// <remarks>
    /// These will apply in addition to any stop sequences that are a part of the <see cref="ChatOptions.StopSequences"/>.
    /// </remarks>
    public string[] StopSequences
    {
        get => _stopSequences;
        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
    }

    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
    {
        get => _promptFormatter;
        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
    }
}
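
As the remarks note, the formatter and stop sequences are entirely model-specific. A second illustration, not part of this commit: a model trained on the ChatML template could be configured roughly as follows (same usings as the Phi3 example; the CreateForChatML name is hypothetical, and the markers should be verified against the model's card).

// Hypothetical ChatML-style configuration; verify the markers against the model card.
static ChatClientConfiguration CreateForChatML() =>
    new(["<|im_start|>", "<|im_end|>"],
        messages =>
        {
            StringBuilder prompt = new();

            foreach (var message in messages)
                foreach (var content in message.Contents.OfType<TextContent>())
                    prompt.Append("<|im_start|>").Append(message.Role.Value).Append('\n')
                          .Append(content.Text).Append("<|im_end|>\n");

            // Leave the prompt open for the assistant's reply.
            return prompt.Append("<|im_start|>assistant\n").ToString();
        });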

src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj

+4 lines
@@ -121,4 +121,8 @@
     <PackageReference Include="System.Memory" Version="4.5.5" />
   </ItemGroup>
 
+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.1-preview.1.24570.5" />
+  </ItemGroup>
+
 </Project>
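
For consumers referencing the package in their own projects, the same dependency can be added from the command line, e.g. `dotnet add package Microsoft.Extensions.AI.Abstractions --version 9.0.1-preview.1.24570.5` (version matching the reference above).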

test/csharp/TestOnnxRuntimeGenAIAPI.cs

+31 −1 lines
@@ -2,12 +2,15 @@
 // Licensed under the MIT License.
 
 using System;
+using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Runtime.InteropServices;
-using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading.Tasks;
 using Xunit;
 using Xunit.Abstractions;
+using Microsoft.Extensions.AI;
 
 namespace Microsoft.ML.OnnxRuntimeGenAI.Tests
 {
@@ -349,6 +352,33 @@ public void TestTopKTopPSearch()
             }
         }
 
+        [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClient")]
+        public async Task TestChatClient()
+        {
+            ChatClientConfiguration config = new(
+                ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
+                (IEnumerable<ChatMessage> messages) =>
+                {
+                    StringBuilder prompt = new();
+
+                    foreach (var message in messages)
+                        foreach (var content in message.Contents.OfType<TextContent>())
+                            prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");
+
+                    return prompt.Append("<|assistant|>\n").ToString();
+                });
+
+            using var client = new ChatClient(config, _phi2Path);
+
+            var completion = await client.CompleteAsync("What is 2 + 3?", new()
+            {
+                MaxOutputTokens = 20,
+                Temperature = 0f,
+            });
+
+            Assert.Contains("5", completion.ToString());
+        }
+
         [IgnoreOnModelAbsenceFact(DisplayName = "TestTokenizerBatchEncodeDecode")]
         public void TestTokenizerBatchEncodeDecode()
         {
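
The new test exercises only the non-streaming CompleteAsync path. A streaming counterpart is sketched below; it is not part of this commit, the TestChatClientStreaming name is hypothetical, and it assumes the same _phi2Path model and configuration as TestChatClient.

        [IgnoreOnModelAbsenceFact(DisplayName = "TestChatClientStreaming")]
        public async Task TestChatClientStreaming()
        {
            // Same hypothetical Phi-2 configuration as TestChatClient above.
            ChatClientConfiguration config = new(
                ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
                (IEnumerable<ChatMessage> messages) =>
                {
                    StringBuilder prompt = new();

                    foreach (var message in messages)
                        foreach (var content in message.Contents.OfType<TextContent>())
                            prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(content.Text).Append("<|end|>\n");

                    return prompt.Append("<|assistant|>\n").ToString();
                });

            using var client = new ChatClient(config, _phi2Path);

            // Accumulate the streamed updates and check the final text.
            StringBuilder result = new();
            await foreach (var update in client.CompleteStreamingAsync(
                [new ChatMessage(ChatRole.User, "What is 2 + 3?")],
                new ChatOptions { MaxOutputTokens = 20, Temperature = 0f }))
            {
                result.Append(update.Text);
            }

            Assert.Contains("5", result.ToString());
        }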
