+ /* eslint-disable @typescript-eslint/no-non-null-assertion */
+ /* eslint-disable no-prototype-builtins */
import * as tvmjs from "tvmjs";
import { Tokenizer } from "@mlc-ai/web-tokenizers";
import { ChatConfig } from "./config";
@@ -16,6 +18,12 @@ export class LLMChatPipeline {
  private prefill: tvmjs.PackedFunc;
  private decoding: tvmjs.PackedFunc;
  private fclearKVCaches: tvmjs.PackedFunc;
+  // Functions for PagedKVCache only
+  private embed?: tvmjs.PackedFunc = undefined;
+  private fKVCacheAddSequence?: tvmjs.PackedFunc = undefined;
+  private fKVCacheRemoveSequence?: tvmjs.PackedFunc = undefined;
+  private fKVCacheBeginForward?: tvmjs.PackedFunc = undefined;
+  private fKVCacheEndForward?: tvmjs.PackedFunc = undefined;
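+  // These stay undefined on the old cache flow; they are only called on paths
+  // guarded by `usePagedKVCache`, which is what makes the non-null assertions
+  // (hence the eslint-disable at the top of the file) safe.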

  // parameter states
  private params: tvmjs.TVMObject;
@@ -41,10 +49,11 @@ export class LLMChatPipeline {
  private appearedTokens = new Set<number>();
  private conversation: Conversation;
  // Whether sink is in action
-  private sinkTriggered: boolean = false;
+  private sinkTriggered = false;
  // Sliding window cache offset (next position to be overridden on the rolling KV cache)
-  private slidingWindowCacheOffset: number = 0;
-  // Total amount of seq len prefilled so far
+  private slidingWindowCacheOffset = 0;
+  // Whether we are using PagedKVCache (eventually this will become the default)
+  private usePagedKVCache = false;

  // stats
  private decodingTotalTime = 0;
@@ -59,6 +68,7 @@ export class LLMChatPipeline {
  private logitProcessor?: LogitProcessor = undefined;

  constructor(tvm: tvmjs.Instance, tokenizer: Tokenizer, config: ChatConfig, logitProcessor?: LogitProcessor) {
+    // 0. Setting attributes
    this.tvm = tvm;
    this.tokenizer = tokenizer;
    this.config = config;
@@ -69,19 +79,28 @@ export class LLMChatPipeline {
    this.stopTokens = this.conversation.getStopTokens();

    this.device = this.tvm.webgpu();
-    tvm.beginScope();

+    // 1. Create VM and get the core functions
+    tvm.beginScope();
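+    // Objects created inside a tvmjs scope are released at the matching
+    // endScope() unless detached, hence the `detachFromCurrentScope` calls
+    // below for everything that must outlive the constructor.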
    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );
    this.prefill = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("prefill")
    );
+    try {
+      // We expect to find `embed` if we usePagedKVCache
+      this.embed = this.tvm.detachFromCurrentScope(
+        this.vm.getFunction("embed")
+      );
+    } catch {
+      // Do nothing; `embed` is only exposed by wasms built for PagedKVCache
+    }
    this.decoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("decode")
    );

-    // Get json stored in the vm's metadata function
+    // 2. Get the JSON stored in the VM's metadata function
    let fgetMetadata;
    let useSLIM = false; // SLIM is the new workflow
    try {
@@ -94,7 +113,7 @@ export class LLMChatPipeline {
    const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
    const metadata = JSON.parse(metadataStr);

-    // Load parameters
+    // 3. Load parameters
    if (useSLIM) {
      // Under the SLIM workflow, we load parameters by name
      const paramNames: string[] = [];
@@ -109,14 +128,16 @@ export class LLMChatPipeline {
      );
    }

-    if (metadata.hasOwnProperty("prefill_chunk_size") && metadata.prefill_chunk_size != -1) {
+    // 4. Read in compilation configurations from metadata
+    if (metadata.hasOwnProperty("prefill_chunk_size")) {
      this.prefillChunkSize = metadata.prefill_chunk_size;
      this.logger("Using prefillChunkSize: ", this.prefillChunkSize);
      if (this.prefillChunkSize <= 0) {
        throw Error("Prefill chunk size needs to be positive.");
      }
+    } else {
+      throw Error("Cannot find `prefill_chunk_size` in metadata; make sure the wasm is up to date.");
    }
-
    // Only use one of slidingWindowSize and maxWindowLength
    if (metadata.hasOwnProperty("sliding_window_size") && metadata.sliding_window_size != -1) {
      this.slidingWindowSize = metadata.sliding_window_size;
@@ -149,20 +170,60 @@ export class LLMChatPipeline {
      }
    }

+    // 5. Create cache
+    // Use `fcreateCache` to determine whether we are using the new KVCache implementation
    let fcreateCache;
-    if (useSLIM) {
-      fcreateCache = this.vm.getFunction("_initialize_effect");
-    } else {
-      fcreateCache = this.vm.getFunction("create_kv_cache");
+    try {
+      if (useSLIM) {
+        fcreateCache = this.vm.getFunction("_initialize_effect");
+      } else {
+        fcreateCache = this.vm.getFunction("create_kv_cache");
+      }
+    } catch (err) {
+      // If we cannot find the functions above, it means that we are using the new PagedKVCache
+      this.usePagedKVCache = true;
+      fcreateCache = this.vm.getFunction("create_tir_paged_kv_cache");
+      console.log("Using Paged KVCache");
+      if (this.embed === undefined) {
+        throw Error("If using paged KVCache, method `embed()` needs to be defined.");
+      }
    }

-    this.fclearKVCaches = this.tvm.detachFromCurrentScope(
-      this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
-    );
+    // Load cache functions and instantiate KVCache
+    if (this.usePagedKVCache) {
+      this.fclearKVCaches = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_clear")
+      );
+      this.fKVCacheAddSequence = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_add_sequence")
+      );
+      this.fKVCacheRemoveSequence = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_remove_sequence")
+      );
+      this.fKVCacheBeginForward = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_begin_forward")
+      );
+      this.fKVCacheEndForward = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_end_forward")
+      );

-    // use extern config for now
-    this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
+      // Create PagedKVCache; we do not expose the KVCache config for now
+      const defaultPageSize = 16;
+      const defaultMaxNumSequence = 1;
+      this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache(
+        this.tvm.makeShapeTuple([defaultMaxNumSequence]),  // max_num_sequence
+        this.tvm.makeShapeTuple([this.maxWindowLength]),  // max_total_sequence_length
+        this.tvm.makeShapeTuple([this.prefillChunkSize]),  // prefill_chunk_size
+        this.tvm.makeShapeTuple([defaultPageSize]),  // page_size, hard-coded for now
+      ));
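+      // Sizing sketch: one sequence of up to maxWindowLength tokens at 16
+      // tokens per page needs roughly ceil(maxWindowLength / 16) pages,
+      // e.g. a 4096-token window maps to 256 pages.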
+    } else {
+      this.fclearKVCaches = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
+      );
+      this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
+    }
    this.filledKVCacheLength = 0;
+    this.resetChat();  // especially needed for PagedKVCache as we need to call fKVCacheAddSequence
    tvm.endScope();
  }

@@ -198,18 +259,28 @@ export class LLMChatPipeline {
  /**
   * Reset the chat history
   */
-  resetChat(keepStats: boolean = false) {
+  resetChat(keepStats = false) {
    this.conversation.reset();
    if (!keepStats) {
      this.resetRuntimeStats();
    }
-    this.fclearKVCaches(this.kvCache);
+    this.resetKVCache();
    this.filledKVCacheLength = 0;
    this.sinkTriggered = false;
    this.slidingWindowCacheOffset = 0;
    this.logitProcessor?.resetState();
  }

+  /**
+   * Reset KV Cache
+   */
+  resetKVCache() {
+    this.fclearKVCaches(this.kvCache);
+    if (this.usePagedKVCache) {
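+      // Clearing a paged cache drops all sequences, so re-register sequence 0
+      // (the only sequence this pipeline uses) before the next forward pass.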
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
+  }
+
  /**
   * @returns Whether stop is triggered.
   */
@@ -397,16 +468,23 @@ export class LLMChatPipeline {
  private forward(inputs: tvmjs.NDArray, curPos: number): tvmjs.NDArray {
    this.tvm.beginScope();
    let retValue;
+    const seqLen = inputs.shape[1];  // Num input tokens
    const seqLenShape = this.tvm.makeShapeTuple([curPos]);
-    if (inputs.shape[1] > 1) {
+    if (seqLen > 1) {
      // Prefill
      if (this.slidingWindowSize == -1) {
-        retValue = this.prefill(
-          inputs, seqLenShape, this.kvCache, this.params
-        );
+        if (this.usePagedKVCache) {
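+          // Paged convention: begin-forward declares which sequence ids get
+          // how many appended tokens, `embed` maps token ids to embeddings,
+          // and `prefill` then consumes the embeddings rather than raw ids.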
+          const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+          const inputLenShape = this.tvm.makeShapeTuple([seqLen]);
+          this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
+          const embed = this.embed!(inputs, this.params);
+          retValue = this.prefill(embed, this.kvCache, this.params);
+          this.fKVCacheEndForward!(this.kvCache);
+        } else {
+          retValue = this.prefill(inputs, seqLenShape, this.kvCache, this.params);
+        }
      } else {
        // Sliding window attention needs extra shape parameters
-        const seqLen = inputs.shape[1];  // Num input tokens
        const cacheLen = Math.min(this.slidingWindowSize, curPos - seqLen);  // Num elements in the cache
        const cacheLenShape = this.tvm.makeShapeTuple([cacheLen]);
        const kvSeqLenShape = this.tvm.makeShapeTuple([cacheLen + seqLen]);
@@ -419,9 +497,16 @@ export class LLMChatPipeline {
    } else {
      // Decode
      if (this.slidingWindowSize == -1) {
-        retValue = this.decoding(
-          inputs, seqLenShape, this.kvCache, this.params
-        );
+        if (this.usePagedKVCache) {
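+          // Decode appends exactly one token per step, hence append length [1]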
+          const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+          const appendLength = this.tvm.makeShapeTuple([1]);
+          this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, appendLength);
+          const embed = this.embed!(inputs, this.params);
+          retValue = this.decoding(embed, this.kvCache, this.params);
+          this.fKVCacheEndForward!(this.kvCache);
+        } else {
+          retValue = this.decoding(inputs, seqLenShape, this.kvCache, this.params);
+        }
      } else {
        // Same logic as above; keeping this if-else structure to match mlc-llm's llm_chat.cc
        const seqLen = inputs.shape[1];  // Num input tokens
@@ -463,7 +548,7 @@ export class LLMChatPipeline {
  ) {
    // 1. Move logits to CPU
    this.tvm.beginScope();
-    let logitsOnCPU = this.updateLogitsOnCPU(logitsOnGPU);
+    const logitsOnCPU = this.updateLogitsOnCPU(logitsOnGPU);
    this.tvm.endScope();
    await this.device.sync();

@@ -544,7 +629,7 @@ export class LLMChatPipeline {
      // need shift window and re-encode
      this.logger("need shift window")
      this.filledKVCacheLength = 0;
-      this.fclearKVCaches(this.kvCache);
+      this.resetKVCache();

      // abandon all tokens we collected
      if (this.conversation.config.add_bos) {
@@ -585,7 +670,7 @@ export class LLMChatPipeline {
      inputData.copyFrom(inputIds);

      // 2. Forward tokens and get logits
-      let logitsOnGPU: tvmjs.NDArray = this.forward(inputData, curPos);
+      const logitsOnGPU: tvmjs.NDArray = this.forward(inputData, curPos);
      const nextToken = await this.sampleTokenFromLogits(
        logitsOnGPU, this.config.temperature, this.config.top_p);
      this.tvm.endScope();
@@ -605,7 +690,7 @@ export class LLMChatPipeline {

  async evaluate() {
    // run a canonical evaluation of the flow
-    this.fclearKVCaches(this.kvCache);
+    this.resetKVCache();
    this.filledKVCacheLength = 0;

    const testPrompt = "The capital of Canada is";