
Commit 17754c0

Add an IChatClient implementation to OnnxRuntimeGenAI
1 parent 8288683 commit 17754c0

4 files changed: +393 -1 lines changed

src/csharp/ChatClient.cs (new file, +284)

@@ -0,0 +1,284 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="modelPath"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, string modelPath)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _config = configuration;

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="model"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _config = configuration;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnx");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        int inputTokens = 0, outputTokens = 0;
        StringBuilder text = new();
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            inputTokens = tokens[0].Length;

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                outputTokens++;
                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = Guid.NewGuid().ToString(),
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
            Usage = new()
            {
                InputTokenCount = inputTokens,
                OutputTokenCount = outputTokens,
                TotalTokenCount = inputTokens + outputTokens,
            },
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        int inputTokens = tokens[0].Length, outputTokens = 0;
        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            outputTokens++;
            yield return new()
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }

        yield return new()
        {
            CompletionId = completionId,
            Contents = [new UsageContent(new()
            {
                InputTokenCount = inputTokens,
                OutputTokenCount = outputTokens,
                TotalTokenCount = inputTokens + outputTokens,
            })],
            CreatedAt = DateTimeOffset.UtcNow,
            Role = ChatRole.Assistant,
        };
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null) =>
        key is not null ? null :
        serviceType == typeof(Model) ? _model :
        serviceType == typeof(Tokenizer) ? _tokenizer :
        serviceType?.IsInstanceOfType(this) is true ? this :
        null;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.PresencePenalty.HasValue)
        {
            generatorParams.SetSearchOption("repetition_penalty", options.PresencePenalty.Value);
        }

        if (options.TopP.HasValue || options.TopK.HasValue)
        {
            if (options.TopP.HasValue)
            {
                generatorParams.SetSearchOption("top_p", options.TopP.Value);
            }

            if (options.TopK.HasValue)
            {
                generatorParams.SetSearchOption("top_k", options.TopK.Value);
            }
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                if (entry.Value is bool b)
                {
                    generatorParams.SetSearchOption(entry.Key, b);
                }
                else if (entry.Value is not null)
                {
                    try
                    {
                        double d = Convert.ToDouble(entry.Value);
                        generatorParams.SetSearchOption(entry.Key, d);
                    }
                    catch
                    {
                        // Ignore values we can't convert.
                    }
                }
            }
        }
    }
}
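
For context, here is a minimal usage sketch of the new client; it is not part of the commit. The Phi-3-style stop sequences and prompt format are illustrative assumptions, and "path/to/model" is a placeholder for a local ONNX model directory:

using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;
using System;
using System.Text;

// Hypothetical configuration for a Phi-3-style model; tailor the stop
// sequences and prompt format to the model actually being loaded.
ChatClientConfiguration config = new(
    ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"],
    messages =>
    {
        StringBuilder prompt = new();
        foreach (var message in messages)
        {
            prompt.Append("<|").Append(message.Role.Value).Append("|>\n")
                  .Append(message.Text).Append("<|end|>\n");
        }
        return prompt.Append("<|assistant|>\n").ToString();
    });

using IChatClient client = new ChatClient(config, "path/to/model");

// Non-streaming: one ChatCompletion with usage details attached.
ChatCompletion completion = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "What is ONNX Runtime?")],
    new ChatOptions { MaxOutputTokens = 256, Temperature = 0.7f });
Console.WriteLine(completion.Message.Text);

// Streaming: one StreamingChatCompletionUpdate per decoded token,
// followed by a final update carrying UsageContent.
await foreach (var update in client.CompleteStreamingAsync(
    [new ChatMessage(ChatRole.User, "Summarize that in one sentence.")]))
{
    Console.Write(update.Text);
}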

src/csharp/ChatClientConfiguration.cs (new file, +73)

@@ -0,0 +1,73 @@
using Microsoft.Extensions.AI;
using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides configuration options used when constructing a <see cref="ChatClient"/>.</summary>
/// <remarks>
/// Every model has different requirements for stop sequences and prompt formatting. For best results,
/// the configuration should be tailored to the exact nature of the model being used. For example,
/// when using a Phi3 model, a configuration like the following may be used:
/// <code>
/// static ChatClientConfiguration CreateForPhi3() =&gt;
///     new(["&lt;|system|&gt;", "&lt;|user|&gt;", "&lt;|assistant|&gt;", "&lt;|end|&gt;"],
///         (IEnumerable&lt;ChatMessage&gt; messages) =&gt;
///         {
///             StringBuilder prompt = new();
///
///             foreach (var message in messages)
///                 foreach (var content in message.Contents.OfType&lt;TextContent&gt;())
///                     prompt.Append("&lt;|").Append(message.Role.Value).Append("|&gt;\n").Append(content.Text).Append("&lt;|end|&gt;\n");
///
///             return prompt.Append("&lt;|assistant|&gt;\n").ToString();
///         });
/// </code>
/// </remarks>
public sealed class ChatClientConfiguration
{
    private string[] _stopSequences;
    private Func<IEnumerable<ChatMessage>, string> _promptFormatter;

    /// <summary>Initializes a new instance of the <see cref="ChatClientConfiguration"/> class.</summary>
    /// <param name="stopSequences">The stop sequences used by the model.</param>
    /// <param name="promptFormatter">The function to use to format a list of messages for input into the model.</param>
    /// <exception cref="ArgumentNullException"><paramref name="stopSequences"/> is null.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="promptFormatter"/> is null.</exception>
    public ChatClientConfiguration(
        string[] stopSequences,
        Func<IEnumerable<ChatMessage>, string> promptFormatter)
    {
        if (stopSequences is null)
        {
            throw new ArgumentNullException(nameof(stopSequences));
        }

        if (promptFormatter is null)
        {
            throw new ArgumentNullException(nameof(promptFormatter));
        }

        StopSequences = stopSequences;
        PromptFormatter = promptFormatter;
    }

    /// <summary>
    /// Gets or sets stop sequences to use during generation.
    /// </summary>
    /// <remarks>
    /// These will apply in addition to any stop sequences supplied via <see cref="ChatOptions.StopSequences"/>.
    /// </remarks>
    public string[] StopSequences
    {
        get => _stopSequences;
        set => _stopSequences = value ?? throw new ArgumentNullException(nameof(value));
    }

    /// <summary>Gets or sets the function that creates a prompt string from the chat history.</summary>
    public Func<IEnumerable<ChatMessage>, string> PromptFormatter
    {
        get => _promptFormatter;
        set => _promptFormatter = value ?? throw new ArgumentNullException(nameof(value));
    }
}
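
One design note: because a ChatClientConfiguration is plain data and the Model-based ChatClient constructor takes an ownsModel flag, a single loaded model can back several clients. A minimal sketch of that pattern, reusing the hypothetical config and placeholder path from the earlier example:

using Microsoft.ML.OnnxRuntimeGenAI;

// Load the model once; "path/to/model" is again a placeholder.
using Model sharedModel = new("path/to/model");

// ownsModel: false tells the client not to dispose the shared model,
// so disposing the client leaves sharedModel usable by other clients.
using (ChatClient client = new(config, sharedModel, ownsModel: false))
{
    // ... issue CompleteAsync / CompleteStreamingAsync calls ...
}

// sharedModel is still alive here; its own using disposes it last.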

src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj (+4)

@@ -121,4 +121,8 @@
     <PackageReference Include="System.Memory" Version="4.5.5" />
   </ItemGroup>

+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.1-preview.1.24570.5" />
+  </ItemGroup>
+
 </Project>
