1
+ using Microsoft . Extensions . AI ;
2
+ using System ;
3
+ using System . Collections . Generic ;
4
+ using System . Reflection ;
5
+ using System . Runtime . CompilerServices ;
6
+ using System . Text ;
7
+ using System . Threading ;
8
+ using System . Threading . Tasks ;
9
+
10
+ namespace Microsoft . ML . OnnxRuntimeGenAI ;
11
+
12
+ /// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
13
/// <summary>An <see cref="IChatClient"/> implementation based on ONNX Runtime GenAI.</summary>
public sealed class ChatClient : IChatClient, IDisposable
{
    /// <summary>The wrapped <see cref="Model"/>.</summary>
    private readonly Model _model;
    /// <summary>The wrapped <see cref="Tokenizer"/>.</summary>
    private readonly Tokenizer _tokenizer;
    /// <summary>Whether to dispose of <see cref="_model"/> when this instance is disposed.</summary>
    private readonly bool _ownsModel;

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="modelPath">The file path to the model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="modelPath"/> is null.</exception>
    public ChatClient(string modelPath)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        // This instance created the model, so it is responsible for disposing of it.
        _ownsModel = true;
        _model = new Model(modelPath);
        _tokenizer = new Tokenizer(_model);

        Metadata = new(typeof(ChatClient).Namespace, new Uri($"file://{modelPath}"), modelPath);
    }

    /// <summary>Initializes an instance of the <see cref="ChatClient"/> class.</summary>
    /// <param name="model">The model to employ.</param>
    /// <param name="ownsModel">
    /// <see langword="true"/> if this <see cref="IChatClient"/> owns the <paramref name="model"/> and should
    /// dispose of it when this <see cref="IChatClient"/> is disposed; otherwise, <see langword="false"/>.
    /// The default is <see langword="true"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public ChatClient(Model model, bool ownsModel = true)
    {
        if (model is null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _ownsModel = ownsModel;
        _model = model;
        _tokenizer = new Tokenizer(_model);

        // Use the namespace as the provider name for consistency with the path-based constructor.
        Metadata = new(typeof(ChatClient).Namespace);
    }

    /// <inheritdoc/>
    public ChatClientMetadata Metadata { get; }

    /// <inheritdoc/>
    public void Dispose()
    {
        // The tokenizer is always created by this instance, so it is always disposed here;
        // the model is only disposed when this instance owns it.
        _tokenizer.Dispose();

        if (_ownsModel)
        {
            _model.Dispose();
        }
    }

    /// <inheritdoc/>
    public Task<ChatCompletion> CompleteAsync(IList<ChatMessage> chatMessages, ChatOptions options = null, CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        // Generation in the native layer is synchronous and compute-bound, so offload it
        // to the thread pool rather than blocking the caller.
        return Task.Run(() =>
        {
            using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
            using GeneratorParams generatorParams = new(_model);
            UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
            generatorParams.SetInputSequences(tokens);

            // Model.Generate drives the whole generation loop; no explicit Generator is needed here.
            // (The previous code constructed a Generator that was never used and immediately disposed.)
            using Sequences outputSequences = _model.Generate(generatorParams);

            return new ChatCompletion(new ChatMessage(ChatRole.Assistant, _tokenizer.Decode(outputSequences[0])))
            {
                CompletionId = Guid.NewGuid().ToString(),
                CreatedAt = DateTimeOffset.UtcNow,
                ModelId = Metadata.ModelId,
            };
        }, cancellationToken);
    }

    /// <inheritdoc/>
    public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
        IList<ChatMessage> chatMessages, ChatOptions options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        if (chatMessages is null)
        {
            throw new ArgumentNullException(nameof(chatMessages));
        }

        using Sequences tokens = _tokenizer.Encode(CreatePrompt(chatMessages));
        using GeneratorParams generatorParams = new(_model);
        UpdateGeneratorParamsFromOptions(tokens[0].Length, generatorParams, options);
        generatorParams.SetInputSequences(tokens);

        using Generator generator = new(_model, generatorParams);
        using var tokenizerStream = _tokenizer.CreateStream();

        var completionId = Guid.NewGuid().ToString();
        while (!generator.IsDone())
        {
            // Each token step is compute-bound native work; run it on the thread pool so the
            // async enumerator does not block its caller. Task.Run observes cancellation
            // before starting each step.
            string next = await Task.Run(() =>
            {
                generator.ComputeLogits();
                generator.GenerateNextToken();

                // Decode only the most recently generated token for this update.
                ReadOnlySpan<int> outputSequence = generator.GetSequence(0);
                return tokenizerStream.Decode(outputSequence[outputSequence.Length - 1]);
            }, cancellationToken);

            yield return new StreamingChatCompletionUpdate
            {
                CompletionId = completionId,
                CreatedAt = DateTimeOffset.UtcNow,
                Role = ChatRole.Assistant,
                Text = next,
            };
        }
    }

    /// <inheritdoc/>
    public TService GetService<TService>(object key = null) where TService : class =>
        typeof(TService) == typeof(Model) ? (TService)(object)_model :
        typeof(TService) == typeof(Tokenizer) ? (TService)(object)_tokenizer :
        this as TService;

    /// <summary>Creates a prompt string from the supplied chat history.</summary>
    /// <remarks>
    /// Text content is rendered with role markers (e.g. <c>&lt;|user|&gt;</c>); non-text content is ignored.
    /// A trailing <c>&lt;|end|&gt;</c>/<c>&lt;|assistant|&gt;</c> marker cues the model to respond.
    /// NOTE(review): this template appears to follow the Phi-style chat format — confirm it matches
    /// the loaded model's expected prompt template.
    /// </remarks>
    private string CreatePrompt(IEnumerable<ChatMessage> messages)
    {
        StringBuilder prompt = new();

        foreach (var message in messages)
        {
            foreach (var content in message.Contents)
            {
                switch (content)
                {
                    case TextContent tc when !string.IsNullOrWhiteSpace(tc.Text):
                        prompt.Append("<|").Append(message.Role.Value).Append("|>\n").Append(tc.Text);
                        break;
                }
            }
        }

        return prompt.Append("<|end|>\n<|assistant|>").ToString();
    }

    /// <summary>Updates the <paramref name="generatorParams"/> based on the supplied <paramref name="options"/>.</summary>
    /// <param name="numInputTokens">The number of tokens in the input prompt, used to compute <c>max_length</c>.</param>
    /// <param name="generatorParams">The generator parameters to populate.</param>
    /// <param name="options">The chat options to translate; may be <see langword="null"/>, in which case nothing is set.</param>
    private static void UpdateGeneratorParamsFromOptions(int numInputTokens, GeneratorParams generatorParams, ChatOptions options)
    {
        if (options is null)
        {
            return;
        }

        if (options.Temperature.HasValue)
        {
            generatorParams.SetSearchOption("temperature", options.Temperature.Value);
        }

        if (options.TopP.HasValue)
        {
            generatorParams.SetSearchOption("top_p", options.TopP.Value);
        }

        if (options.MaxOutputTokens.HasValue)
        {
            // max_length covers prompt + generation, so add the prompt length to the requested budget.
            generatorParams.SetSearchOption("max_length", numInputTokens + options.MaxOutputTokens.Value);
        }

        // Pass through any additional primitive-valued properties as raw search options;
        // unsupported value types are silently ignored (best-effort by design).
        if (options.AdditionalProperties is { } props)
        {
            foreach (var entry in props)
            {
                switch (entry.Value)
                {
                    case int i: generatorParams.SetSearchOption(entry.Key, i); break;
                    case long l: generatorParams.SetSearchOption(entry.Key, l); break;
                    case float f: generatorParams.SetSearchOption(entry.Key, f); break;
                    case double d: generatorParams.SetSearchOption(entry.Key, d); break;
                    case bool b: generatorParams.SetSearchOption(entry.Key, b); break;
                }
            }
        }
    }
}
0 commit comments