1
+ using Microsoft . Extensions . AI ;
2
+ using System ;
3
+ using System . Collections . Generic ;
4
+ using System . Runtime . CompilerServices ;
5
+ using System . Text ;
6
+ using System . Threading ;
7
+ using System . Threading . Tasks ;
8
+
9
+ namespace Microsoft . ML . OnnxRuntimeGenAI ;
10
+
11
/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="modelPath"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, string modelPath)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _config = configuration;

        // This instance created the Model, so it is responsible for disposing it.
        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="model"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _config = configuration;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnx");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        // The tokenizer is always created (and thus owned) by this instance; the model
        // is only disposed when this instance owns it (see _ownsModel).
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        int inputTokens = 0, outputTokens = 0;
        StringBuilder text = new();

        // Generation is synchronous and CPU/GPU-bound in the native library; run it on the
        // thread pool so the caller's thread isn't blocked for the duration.
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            inputTokens = tokens[0].Length;

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                // Honor cancellation between tokens; the native loop itself is not cancelable.
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                // Decode only the most recently generated token via the streaming decoder.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                if (IsStop(next, options))
                {
                    break;
                }

                outputTokens++;
                text.Append(next);
            }
        }, cancellationToken);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = Guid.NewGuid().ToString(),
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
            Usage = new()
            {
                InputTokenCount = inputTokens,
                OutputTokenCount = outputTokens,
                TotalTokenCount = inputTokens + outputTokens,
            },
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        int inputTokens = tokens[0].Length, outputTokens = 0;
        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            // Honor cancellation between tokens, consistent with CompleteAsync; Task.Run
            // only checks the token before scheduling the work item.
            cancellationToken.ThrowIfCancellationRequested();

            // Each token is produced by blocking native code; hop to the thread pool so the
            // async enumerator's consumer isn't blocked while the next token is generated.
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            if (IsStop(next, options))
            {
                break;
            }

            outputTokens++;
            yield return new()
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }

        // Final update carries the aggregate usage for the whole completion.
        yield return new()
        {
            CompletionId = completionId,
            Contents = [new UsageContent(new()
            {
                InputTokenCount = inputTokens,
                OutputTokenCount = outputTokens,
                TotalTokenCount = inputTokens + outputTokens,
            })],
            CreatedAt = DateTimeOffset.UtcNow,
            Role = ChatRole.Assistant,
        };
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null)
    {
        // The IChatClient.GetService contract requires a non-null serviceType; previously a
        // null serviceType was silently answered with null.
        if (serviceType is null)
        {
            throw new ArgumentNullException(nameof(serviceType));
        }

        // Keyed services are not supported; otherwise surface the wrapped primitives or this instance.
        return
            key is not null ? null :
            serviceType == typeof(Model) ? _model :
            serviceType == typeof(Tokenizer) ? _tokenizer :
            serviceType.IsInstanceOfType(this) ? this :
            null;
    }

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    /// <param name="token">The decoded text of the most recently generated token.</param>
    /// <param name="options">The per-request options, whose <see cref="ChatOptions.StopSequences"/> are checked in addition to the configured ones.</param>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    /// <param name="numInputTokens">The number of tokens in the prompt, used to convert MaxOutputTokens into an absolute max_length.</param>
    /// <param name="generatorParams">The generator parameters to populate.</param>
    /// <param name="options">The chat options to map; no-op when null.</param>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            // max_length is a total budget (prompt + completion), so add the prompt length.
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.PresencePenalty.HasValue)
        {
            // NOTE(review): repetition_penalty is the closest knob the runtime exposes; it is not
            // an exact equivalent of an OpenAI-style presence penalty.
            generatorParams.SetSearchOption("repetition_penalty", options.PresencePenalty.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            // Pass through any extra properties as search options: booleans directly, anything
            // else best-effort coerced to double.
            foreach (var entry in props)
            {
                if (entry.Value is bool b)
                {
                    generatorParams.SetSearchOption(entry.Key, b);
                }
                else if (entry.Value is not null)
                {
                    try
                    {
                        double d = Convert.ToDouble(entry.Value);
                        generatorParams.SetSearchOption(entry.Key, d);
                    }
                    catch
                    {
                        // Deliberately best-effort: ignore values we can't convert.
                    }
                }
            }
        }
    }
}
0 commit comments