// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

#nullable enable

namespace Microsoft.ML.OnnxRuntimeGenAI;

/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with an ONNX Runtime GenAI <see cref="Model"/>.</summary>
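/// <remarks>
/// A minimal usage sketch. The model path and message are placeholders, and <c>chatClientOptions</c> stands in
/// for an <see cref="OnnxRuntimeGenAIChatClientOptions"/> instance configured elsewhere:
/// <code>
/// using var client = new OnnxRuntimeGenAIChatClient(chatClientOptions, @"path/to/model");
///
/// ChatMessage[] messages = [new(ChatRole.User, "Hello!")];
/// await foreach (var update in client.GetStreamingResponseAsync(messages))
/// {
///     Console.Write(update.Text);
/// }
/// </code>
/// </remarks>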
public sealed class OnnxRuntimeGenAIChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly OnnxRuntimeGenAIChatClientOptions _options;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;
    /// <summary>Metadata for the chat client.</summary>
    private readonly ChatClientMetadata _metadata;

    /// <summary>Cached information about the last generation to speed up a subsequent generation.</summary>
    /// <remarks>Only one is cached. Interlocked operations are used to take and return an instance from this cache.</remarks>
    private CachedGenerator? _cachedGenerator;

    /// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
    /// <param name="options">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is <see langword="null"/>.</exception>
    public OnnxRuntimeGenAIChatClient(OnnxRuntimeGenAIChatClientOptions options, string modelPath)
    {
        if (options is null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _options = options;

        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        _metadata = new("onnx", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="OnnxRuntimeGenAIChatClient"/> class.</summary>
    /// <param name="options">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is <see langword="null"/>.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is <see langword="null"/>.</exception>
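    /// <remarks>
    /// A sketch (illustrative only) of sharing one <see cref="Model"/> across multiple clients. With
    /// <paramref name="ownsModel"/> set to <see langword="false"/>, disposing the clients leaves the shared
    /// model alive for the caller to dispose:
    /// <code>
    /// using var model = new Model(@"path/to/model");
    /// using var client1 = new OnnxRuntimeGenAIChatClient(chatClientOptions, model, ownsModel: false);
    /// using var client2 = new OnnxRuntimeGenAIChatClient(chatClientOptions, model, ownsModel: false);
    /// </code>
    /// </remarks>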
    public OnnxRuntimeGenAIChatClient(OnnxRuntimeGenAIChatClientOptions options, Model model, bool ownsModel = true)
    {
        if (options is null)
        {
            throw new ArgumentNullException(nameof(options));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _options = options;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        _metadata = new("onnx");
    }

    /// <inheritdoc/>
    public void Dispose()
    {
        if (Interlocked.Exchange(ref _cachedGenerator, null) is CachedGenerator cachedGenerator)
        {
            cachedGenerator.Dispose();
        }

        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public Task<ChatResponse> GetResponseAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default) =>
        GetStreamingResponseAsync(chatMessages, options, cancellationToken).ToChatResponseAsync(cancellationToken: cancellationToken);

    /// <inheritdoc/>
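    /// <remarks>
    /// When <see cref="OnnxRuntimeGenAIChatClientOptions.EnableCaching"/> is enabled, the final update carries a
    /// <see cref="ChatResponseUpdate.ChatThreadId"/>; passing that value back via <see cref="ChatOptions.ChatThreadId"/>
    /// on the next call lets this client reuse the cached <see cref="Generator"/> instead of creating a new one.
    /// A rough sketch of a caller doing this (variable names are illustrative):
    /// <code>
    /// string? chatThreadId = null;
    /// await foreach (var update in client.GetStreamingResponseAsync(messages, new ChatOptions { ChatThreadId = chatThreadId }))
    /// {
    ///     chatThreadId = update.ChatThreadId ?? chatThreadId;
    /// }
    /// // Subsequent calls pass the captured chatThreadId to continue from the cached generator state.
    /// </code>
    /// </remarks>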
    public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
        IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        // Check to see whether there's a cached generator. If there is, and if its ID matches what we got from the client,
        // we can use it; otherwise, we need to create a new one.
        CachedGenerator? generator = Interlocked.Exchange(ref _cachedGenerator, null);
        if (generator is null ||
            generator.ChatThreadId is null ||
            generator.ChatThreadId != options?.ChatThreadId)
        {
            generator?.Dispose();

            using GeneratorParams p = new(_model); // we can dispose of this after we create the generator
            UpdateGeneratorParamsFromOptions(p, options);
            generator = new(new Generator(_model, p));
        }

        // If caching is enabled, generate a new ID to represent the state of the generator when we finish this response.
        generator.ChatThreadId = _options.EnableCaching ? Guid.NewGuid().ToString("N") : null;

        // Format and tokenize the messages.
        using Sequences tokens = _tokenizer.Encode(_options.PromptFormatter(chatMessages));
        try
        {
            generator.Generator.AppendTokenSequences(tokens);
            int inputTokens = tokens[0].Length, outputTokens = 0;

            // Loop while we still want to produce more tokens.
            using var tokenizerStream = _tokenizer.CreateStream();
            while (!generator.Generator.IsDone())
            {
                // If we've reached a max output token limit, stop.
                if (options?.MaxOutputTokens is int maxOutputTokens &&
                    outputTokens >= maxOutputTokens)
                {
                    break;
                }

                // Avoid blocking the calling thread with expensive compute.
                await YieldAwaiter.Instance;

                // Generate the next token.
                generator.Generator.GenerateNextToken();
                string next = tokenizerStream.Decode(GetLastToken(generator.Generator.GetSequence(0)));

                // Workaround until C# 13 is adopted and ref locals are usable in async methods.
                static int GetLastToken(ReadOnlySpan<int> span) => span[span.Length - 1];

                // If this token is a stop token, bail.
                if (IsStop(next, options))
                {
                    break;
                }

                // Yield the next token in the stream.
                outputTokens++;
                yield return new()
                {
                    CreatedAt = DateTimeOffset.UtcNow,
                    Role = ChatRole.Assistant,
                    Text = next,
                };
            }

            // Yield a final update containing metadata.
            yield return new()
            {
                ChatThreadId = generator.ChatThreadId,
                Contents = [new UsageContent(new()
                {
                    InputTokenCount = inputTokens,
                    OutputTokenCount = outputTokens,
                    TotalTokenCount = inputTokens + outputTokens,
                })],
                CreatedAt = DateTimeOffset.UtcNow,
                FinishReason = options is not null && options.MaxOutputTokens <= outputTokens ? ChatFinishReason.Length : ChatFinishReason.Stop,
                ModelId = _metadata.ModelId,
                ResponseId = Guid.NewGuid().ToString(),
                Role = ChatRole.Assistant,
            };
        }
        finally
        {
            // Cache the generator for subsequent use if it's cacheable and there isn't already a generator cached.
            if (generator.ChatThreadId is null ||
                Interlocked.CompareExchange(ref _cachedGenerator, generator, null) != null)
            {
                generator.Dispose();
            }
        }
    }

    /// <inheritdoc/>
    object? IChatClient.GetService(Type serviceType, object? serviceKey)
    {
        if (serviceType is null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        return
            serviceKey is not null ? null :
            serviceType == typeof(ChatClientMetadata) ? _metadata :
            serviceType == typeof(Model) ? _model :
            serviceType == typeof(Tokenizer) ? _tokenizer :
            serviceType?.IsInstanceOfType(this) is true ? this :
            null;
    }

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    private bool IsStop(string token, ChatOptions? options) =>
        options?.StopSequences?.Contains(token) is true ||
        _options.StopSequences.Contains(token);

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
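    /// <remarks>
    /// A rough sketch of how <see cref="ChatOptions"/> values flow into search options (the values shown are
    /// illustrative, and the "max_length" entry assumes the underlying model accepts that search option):
    /// <code>
    /// var chatOptions = new ChatOptions
    /// {
    ///     Temperature = 0.7f,   // -> "temperature" = 0.7
    ///     TopP = 0.9f,          // -> "top_p" = 0.9 and "do_sample" = true
    ///     Seed = 42,            // -> "random_seed" = 42
    ///     AdditionalProperties = new() { ["max_length"] = 2048 }, // forwarded when convertible to bool/double
    /// };
    /// </code>
    /// Note that <see cref="ChatOptions.PresencePenalty"/> is mapped to the "repetition_penalty" search option,
    /// which is an approximation rather than an exact equivalent.
    /// </remarks>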
    private static void UpdateGeneratorParamsFromOptions(GeneratorParams generatorParams, ChatOptions? options)
    {
        if (options is null)
        {
            return;
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.PresencePenalty.HasValue)
        {
            generatorParams.SetSearchOption("repetition_penalty", options.PresencePenalty.Value);
        }

        if (options.TopP.HasValue || options.TopK.HasValue)
        {
            if (options.TopP.HasValue)
            {
                generatorParams.SetSearchOption("top_p", options.TopP.Value);
            }

            if (options.TopK.HasValue)
            {
                generatorParams.SetSearchOption("top_k", options.TopK.Value);
            }

            generatorParams.SetSearchOption("do_sample", true);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                if (entry.Value is bool b)
                {
                    generatorParams.SetSearchOption(entry.Key, b);
                }
                else if (entry.Value is not null)
                {
                    try
                    {
                        double d = Convert.ToDouble(entry.Value);
                        generatorParams.SetSearchOption(entry.Key, d);
                    }
                    catch
                    {
                        // Ignore values we can't convert.
                    }
                }
            }
        }
    }

    private sealed class CachedGenerator(Generator generator) : IDisposable
    {
        public Generator Generator { get; } = generator;

        public string? ChatThreadId { get; set; }

        public void Dispose() => Generator?.Dispose();
    }

    /// <summary>Polyfill for Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding).</summary>
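    /// <remarks>
    /// <see cref="IsCompleted"/> always returns <see langword="false"/>, so every await on <see cref="Instance"/> yields,
    /// and the continuation is queued to the thread pool via <see cref="Task.Run(Action)"/> rather than running inline.
    /// </remarks>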
    private sealed class YieldAwaiter : INotifyCompletion
    {
        public static YieldAwaiter Instance { get; } = new();
        public YieldAwaiter GetAwaiter() => this;
        public bool IsCompleted => false;
        public void OnCompleted(Action continuation) => Task.Run(continuation);
        public void UnsafeOnCompleted(Action continuation) => Task.Run(continuation);
        public void GetResult() { }
    }
}