1
+ using Microsoft . Extensions . AI ;
2
+ using System ;
3
+ using System . Collections . Generic ;
4
+ using System . Runtime . CompilerServices ;
5
+ using System . Text ;
6
+ using System . Threading ;
7
+ using System . Threading . Tasks ;
8
+
9
+ namespace Microsoft . ML . OnnxRuntimeGenAI ;
10
+
11
/// <summary>Provides an <see cref="IChatClient"/> implementation for interacting with a <see cref="Model"/>.</summary>
public sealed partial class ChatClient : IChatClient
{
    /// <summary>The options used to configure the instance.</summary>
    private readonly ChatClientConfiguration _config;
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="modelPath"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, string modelPath)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        _config = configuration;

        // This overload loads the model itself, so it owns (and must dispose) it.
        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai", new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="configuration">Options used to configure the client instance.</param>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> or <paramref name="model"/> is null.</exception>
    public ChatClient(ChatClientConfiguration configuration, Model model, bool ownsModel = true)
    {
        if (configuration is null)
        {
            throw new ArgumentNullException(nameof(configuration));
        }

        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _config = configuration;

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        Metadata = new("onnxruntime-genai");
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        // The tokenizer is always created (and thus owned) by this instance; the model
        // is only disposed when this instance owns it (see _ownsModel).
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public async Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        StringBuilder text = new();

        // Single completion id used for the returned ChatCompletion, mirroring
        // CompleteStreamingAsync which reuses one id across all of its updates.
        string completionId = Guid.NewGuid().ToString();

        // Generation is synchronous native work; offload it so the caller's thread isn't blocked.
        await Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

            using Generator generator = new(_model, generatorParams);
            generator.AppendTokenSequences(tokens);

            using var tokenizerStream = _tokenizer.CreateStream();

            while (!generator.IsDone())
            {
                cancellationToken.ThrowIfCancellationRequested();

                generator.GenerateNextToken();

                // Decode only the newly generated (last) token via the streaming decoder.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                string next = tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);

                // Stop sequences terminate generation and are not included in the output.
                if (IsStop(next, options))
                {
                    break;
                }

                text.Append(next);
            }
        }, cancellationToken).ConfigureAwait(false);

        return new ChatCompletion(new ChatMessage(ChatRole.Assistant, text.ToString()))
        {
            CompletionId = completionId,
            CreatedAt = DateTimeOffset.UtcNow,
            ModelId = Metadata.ModelId,
        };
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(_config.PromptFormatter(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);

        using Generator generator = new(_model, generatorParams);
        generator.AppendTokenSequences(tokens);

        using var tokenizerStream = _tokenizer.CreateStream();

        // One completion id shared by every update yielded for this request.
        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            // Each token step is synchronous native work; offload per-step so the
            // enumerator's consumer isn't blocked between yields.
            string next = await Task.Run(() =>
            {
                generator.GenerateNextToken();

                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken).ConfigureAwait(false);

            if (IsStop(next, options))
            {
                break;
            }

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public object GetService(Type serviceType, object key = null) =>
        key is not null ? null :
        serviceType == typeof(Model) ? _model :
        serviceType == typeof(Tokenizer) ? _tokenizer :
        serviceType?.IsInstanceOfType(this) is true ? this :
        null;

    /// <summary>Gets whether the specified token is a stop sequence.</summary>
    /// <remarks>Checks both the per-request <see cref="ChatOptions.StopSequences"/> and the client-wide configured stop sequences.</remarks>
    private bool IsStop(string token, ChatOptions options) =>
        options?.StopSequences?.Contains(token) is true ||
        Array.IndexOf(_config.StopSequences, token) >= 0;

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    /// <param name="numInputTokens">The number of input (prompt) tokens; added to <see cref="ChatOptions.MaxOutputTokens"/> to form "max_length".</param>
    /// <param name="generatorParams">The parameters object to populate.</param>
    /// <param name="options">The per-request options; when null, nothing is applied.</param>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.MaxOutputTokens.HasValue)
        {
            // "max_length" covers prompt + completion, so include the input token count.
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.TopK.HasValue)
        {
            generatorParams.SetSearchOption("top_k", options.TopK.Value);
        }

        if (options.Seed.HasValue)
        {
            generatorParams.SetSearchOption("random_seed", options.Seed.Value);
        }

        if (options.AdditionalProperties is { } props)
        {
            // Pass through any additional numeric/boolean options as raw search options.
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}
0 commit comments