Skip to content

Commit d6c4921

Browse files
committed
Add an IChatClient implementation to OnnxRuntimeGenAI
1 parent 47132b6 commit d6c4921

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed

src/csharp/ChatClient.cs

+206
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
using Microsoft.Extensions.AI;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Reflection;
5+
using System.Runtime.CompilerServices;
6+
using System.Text;
7+
using System.Threading;
8+
using System.Threading.Tasks;
9+
10+
namespace Microsoft.ML.OnnxRuntimeGenAI;
11+
12+
/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
public sealed class ChatClient : IChatClient, IDisposable
{
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public ChatClient(string modelPath)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        // A model loaded from a path is created by this instance, so it's owned (and disposed) by it.
        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public ChatClient(Model model, bool ownsModel = true)
    {
        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        // Keep the provider name consistent with the path-based constructor.
        Metadata = new(typeof(ChatClient).Namespace);
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        // The tokenizer is always created by this instance and thus always disposed here;
        // the model is only disposed when this instance owns it.
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        // Generation is synchronous and CPU/GPU-bound in the native layer; offload it so the
        // returned task doesn't block the caller's thread.
        return Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
            generatorParams.SetInputSequences(tokens);

            // Model.Generate runs the full generation loop itself; no explicit Generator is needed here.
            using Sequences outputSequences = _model.Generate(generatorParams);

            return new ChatCompletion(new ChatMessage(ChatRole.Assistant, _tokenizer.Decode(outputSequences[0])))
            {
                CompletionId = Guid.NewGuid().ToString(),
                CreatedAt = DateTimeOffset.UtcNow,
                ModelId = Metadata.ModelId,
            };
        }, cancellationToken);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
        generatorParams.SetInputSequences(tokens);

        using Generator generator = new(_model, generatorParams);
        using var tokenizerStream = _tokenizer.CreateStream();

        // All updates in one streamed response share a single completion ID.
        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            // Each token step is synchronous native work; offload it per-iteration so the
            // async enumerator yields control between tokens and observes cancellation.
            string next = await Task.Run(() =>
            {
                generator.ComputeLogits();
                generator.GenerateNextToken();

                // Decode only the most recently generated token for this update.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public TService GetService<TService>(object key = null) where TService : class =>
        typeof(TService) == typeof(Model) ? (TService)(object)_model :
        typeof(TService) == typeof(Tokenizer) ? (TService)(object)_tokenizer :
        this as TService;

    /// <summary>Creates a prompt string from the supplied chat history.</summary>
    /// <remarks>
    /// NOTE(review): this uses a Phi-style "&lt;|role|&gt;" template with a single trailing
    /// "&lt;|end|&gt;" marker — confirm it matches the chat template of the loaded model.
    /// </remarks>
    private string CreatePrompt(IEnumerable<ChatMessage> messages)
    {
        StringBuilder prompt = new();

        foreach (var message in messages)
        {
            foreach (var content in message.Contents)
            {
                switch (content)
                {
                    // Only textual content participates in the prompt; other content kinds are ignored.
                    case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(tc.Text);
                        break;
                }
            }
        }

        return prompt.Append("<|end|>\n<|assistant|>").ToString();
    }

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    /// <param name="numInputTokens">The number of tokens in the input prompt, used to compute "max_length".</param>
    /// <param name="generatorParams">The parameters object to populate.</param>
    /// <param name="options">The caller-supplied options; may be null, in which case nothing is set.</param>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.MaxOutputTokens.HasValue)
        {
            // "max_length" is the total sequence length (prompt + generation), so the
            // prompt's token count is added to the requested output budget.
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        // Pass through any additional numeric/boolean search options verbatim;
        // unsupported value types are silently ignored.
        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}

src/csharp/Microsoft.ML.OnnxRuntimeGenAI.csproj

+4
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,8 @@
121121
<PackageReference Include="System.Memory" Version="4.5.5" />
122122
</ItemGroup>
123123

124+
<ItemGroup>
125+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24507.7" />
126+
</ItemGroup>
127+
124128
</Project>

0 commit comments

Comments
 (0)