
Commit b187dd4

Add an IChatClient implementation to OnnxRuntimeGenAI
1 parent: 20f907e

4 files changed (+444, -1 lines)

src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj

+4
@@ -121,4 +121,8 @@
     <PackageReference Include="System.Memory" Version="4.5.5" />
   </ItemGroup>
 
+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.3.0-preview.1.25114.11" />
+  </ItemGroup>
+
 </Project>
New file (+313)

@@ -0,0 +1,313 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

#nullable enable

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with an ONNX Runtime GenAI <see cref="Model"/>.</summary>
public sealed class OnnxRuntimeGenAIChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly OnnxRuntimeGenAIChatClientOptions _options;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;
    /// <summary>Metadata for the chat client.</summary>
    private readonly ChatClientMetadata _metadata;

    /// <summary>Cached information about the last generation to speed up a subsequent generation.</summary>
    /// <remarks>Only one is cached. Interlocked operations are used to take and return an instance from this cache.</remarks>
    private CachedGenerator? _cachedGenerator;

    /// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
    /// <param name="options">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is <see langword="null"/>.</exception>
    public OnnxRuntimeGenAIChatClient(OnnxRuntimeGenAIChatClientOptions options, string modelPath)
    {
        if (options is null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _options = options;

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        _metadata = new("onnx", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
    /// <param name="options">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is <see langword="null"/>.</exception>
    public OnnxRuntimeGenAIChatClient(OnnxRuntimeGenAIChatClientOptions options, Model model, bool ownsModel = true)
    {
        if (options is null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _options = options;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        _metadata = new("onnx");
    }

    /// <inheritdoc/>
    public void Dispose()
    {
        if (Interlocked.Exchange(ref _cachedGenerator, null) is CachedGenerator cachedGenerator)
        {
            cachedGenerator.Dispose();
        }

        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public Task<ChatResponse> GetResponseAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
        GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(cancellationToken: cancellationToken);

    /// <inheritdoc/>
    public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        // Check to see whether there's a cached generator. If there is, and if its ID matches what we got from the client,
        // we can use it; otherwise, we need to create a new one.
        CachedGenerator? generator = Interlocked.Exchange(ref _cachedGenerator, null);
        if (generator is null ||
            generator.ChatThreadId is null ||
            generator.ChatThreadId != options?.ChatThreadId)
        {
            generator?.Dispose();

            using GeneratorParams p = new(_model); // we can dispose of this after we create the generator
            UpdateGeneratorParamsFromOptions(p, options);
            generator = new(new Generator(_model, p));
        }

        // If caching is enabled, generate a new ID to represent the state of the generator when we finish this response.
        generator.ChatThreadId = _options.EnableCaching ? Guid.NewGuid().ToString("N") : null;

        // Format and tokenize the messages.
        using Sequences tokens = _tokenizer.Encode(_options.PromptFormatter(chatMessages));
        try
        {
            generator.Generator.AppendTokenSequences(tokens);
            int inputTokens = tokens[0].Length, outputTokens = 0;

            // Loop while we still want to produce more tokens.
            using var tokenizerStream = _tokenizer.CreateStream();
            while (!generator.Generator.IsDone())
            {
                // If we've reached a max output token limit, stop.
                if (options?.MaxOutputTokens is int maxOutputTokens &&
                    outputTokens >= maxOutputTokens)
                {
                    break;
                }

                // Avoid blocking the calling thread with expensive compute.
                await YieldAwaiter.Instance;

                // Generate the next token.
                generator.Generator.GenerateNextToken();
                string next = tokenizerStream.Decode(GetLastToken(generator.Generator.GetSequence(0)));

                // Workaround until C# 13 is adopted and ref struct locals are usable in async methods.
                static int GetLastToken(ReadOnlySpan<int> span) => span[span.Length - 1];

                // If this token is a stop token, bail.
                if (IsStop(next, options))
                {
                    break;
                }

                // Yield the next token in the stream.
                outputTokens++;
                yield return new()
                {
                    CreatedAt = DateTimeOffset.UtcNow,
                    Role = ChatRole.Assistant,
                    Text = next,
                };
            }

            // Yield a final update containing metadata.
            yield return new()
            {
                ChatThreadId = generator.ChatThreadId,
                Contents = [new UsageContent(new()
                {
                    InputTokenCount = inputTokens,
                    OutputTokenCount = outputTokens,
                    TotalTokenCount = inputTokens + outputTokens,
                })],
                CreatedAt = DateTimeOffset.UtcNow,
                FinishReason = options is not null && options.MaxOutputTokens <= outputTokens ? ChatFinishReason.Length : ChatFinishReason.Stop,
                ModelId = _metadata.ModelId,
                ResponseId = Guid.NewGuid().ToString(),
                Role = ChatRole.Assistant,
            };
        }
        finally
        {
            // Cache the generator for subsequent use if it's cacheable and there isn't already a generator cached.
            if (generator.ChatThreadId is null ||
                Interlocked.CompareExchange(ref _cachedGenerator, generator, null) != null)
            {
                generator.Dispose();
            }
        }
    }

    /// <inheritdoc/>
    object? IChatClient.GetService(Type serviceType, object? serviceKey)
    {
        if (serviceType is null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        return
            serviceKey is not null ? null :
            serviceType == typeof(ChatClientMetadata) ? _metadata :
            serviceType == typeof(Model) ? _model :
            serviceType == typeof(Tokenizer) ? _tokenizer :
            serviceType.IsInstanceOfType(this) ? this :
            null;
    }

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions? options) =>
        options?.StopSequences?.Contains(token) is true ||
        _options.StopSequences.Contains(token);

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    private static void UpdateGeneratorParamsFromOptions(GeneratorParams generatorParams, ChatOptions? options)
    {
        if (options is null)
        {
            return;
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.PresencePenalty.HasValue)
        {
            generatorParams.SetSearchOption("repetition_penalty", options.PresencePenalty.Value);
        }

        if (options.TopP.HasValue || options.TopK.HasValue)
        {
            if (options.TopP.HasValue)
            {
                generatorParams.SetSearchOption("top_p", options.TopP.Value);
            }

            if (options.TopK.HasValue)
            {
                generatorParams.SetSearchOption("top_k", options.TopK.Value);
            }

            generatorParams.SetSearchOption("do_sample", true);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                if (entry.Value is bool b)
                {
                    generatorParams.SetSearchOption(entry.Key, b);
                }
                else if (entry.Value is not null)
                {
                    try
                    {
                        double d = Convert.ToDouble(entry.Value);
                        generatorParams.SetSearchOption(entry.Key, d);
                    }
                    catch
                    {
                        // Ignore values we can't convert.
                    }
                }
            }
        }
    }

    private sealed class CachedGenerator(Generator generator) : IDisposable
    {
        public Generator Generator { get; } = generator;

        public string? ChatThreadId { get; set; }

        public void Dispose() => Generator?.Dispose();
    }

    /// <summary>Polyfill for <c>Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding)</c>.</summary>
    private sealed class YieldAwaiter : INotifyCompletion
    {
        public static YieldAwaiter Instance { get; } = new();
        public YieldAwaiter GetAwaiter() => this;
        public bool IsCompleted => false;
        public void OnCompleted(Action continuation) => Task.Run(continuation);
        public void UnsafeOnCompleted(Action continuation) => Task.Run(continuation);
        public void GetResult() { }
    }
}
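
For context, a minimal consumption sketch follows (not part of the commit). It assumes the OnnxRuntimeGenAIChatClientOptions type referenced above, defined elsewhere in this commit and not shown in this excerpt, exposes settable PromptFormatter, StopSequences, and EnableCaching members matching how the client reads them; the model path and prompt template below are placeholders.

using System;
using System.Linq;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

// Assumed options shape: PromptFormatter flattens the chat history into the single
// prompt string the model expects, and StopSequences ends generation early.
OnnxRuntimeGenAIChatClientOptions clientOptions = new()
{
    PromptFormatter = messages =>
        string.Join("\n", messages.Select(m => $"<|{m.Role}|>\n{m.Text}<|end|>")) + "\n<|assistant|>\n",
    StopSequences = ["<|end|>"],
};

// Placeholder model path; this constructor loads the model and creates the tokenizer.
using IChatClient client = new OnnxRuntimeGenAIChatClient(clientOptions, @"path/to/model");

// Each generated token arrives as its own update; the final update carries usage data.
await foreach (ChatResponseUpdate update in client.GetStreamingResponseAsync(
    [new ChatMessage(ChatRole.User, "What is ONNX Runtime GenAI?")]))
{
    Console.Write(update.Text);
}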

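A second sketch, under the same assumptions, shows the generator-reuse path. With EnableCaching set, the final streaming update carries a freshly generated ChatThreadId; echoing that value back through ChatOptions.ChatThreadId on the next call makes the cache check match, so the client reuses the cached Generator rather than constructing a new one. It also exercises the ChatOptions mapping onto search options: Temperature to "temperature", TopP/TopK to "top_p"/"top_k" (plus "do_sample"), Seed to "random_seed", and PresencePenalty to "repetition_penalty".

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.Extensions.AI;
using Microsoft.ML.OnnxRuntimeGenAI;

OnnxRuntimeGenAIChatClientOptions clientOptions = new()
{
    PromptFormatter = messages =>
        string.Join("\n", messages.Select(m => $"{m.Role}: {m.Text}")),
    EnableCaching = true, // assumed member; turns on ChatThreadId-based generator reuse
};

using IChatClient client = new OnnxRuntimeGenAIChatClient(clientOptions, @"path/to/model");

// Applied to GeneratorParams whenever a new generator has to be created.
ChatOptions chatOptions = new() { Temperature = 0.7f, TopP = 0.9f, MaxOutputTokens = 256 };

List<ChatMessage> history = [new(ChatRole.User, "Summarize ONNX Runtime GenAI in one sentence.")];

string? threadId = null;
StringBuilder reply = new();
await foreach (ChatResponseUpdate update in client.GetStreamingResponseAsync(history, chatOptions))
{
    reply.Append(update.Text);
    threadId ??= update.ChatThreadId; // populated only on the final metadata update
}

// Second turn: echo the thread ID back so the cached generator is reused.
history.Add(new(ChatRole.Assistant, reply.ToString()));
history.Add(new(ChatRole.User, "Now in one word."));
chatOptions.ChatThreadId = threadId;

await foreach (ChatResponseUpdate update in client.GetStreamingResponseAsync(history, chatOptions))
{
    Console.Write(update.Text);
}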