
Commit ec2662f

[LLMChat] Make llm_chat compatible with PagedKVCache (mlc-ai#293)
PagedKVCache was introduced into MLC-LLM a while back to unify the KVCache interface. This PR makes WebLLM compatible with the new PagedKVCache interface, encapsulating it so that WebLLM users will not notice any difference. This PR is equivalent to the changes to `llm_chat.cc` in mlc-ai/mlc-llm#1651, and should address issues like mlc-ai/mlc-llm#1628. There are still model compilation issues regarding `workgroup_size`, since WebGPU, unlike most other backends, supports at most 256 threads per workgroup. We will address this issue more elegantly soon; for now, compiling llama-based models requires manually changing kernel sizes as shown in [this branch](https://github.com/CharlieFRuan/mlc-llm/tree/local-workgroupSize-webLLM-kvCache). This PR also depends heavily on apache/tvm#16554.
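To make the interface change concrete, here is a minimal sketch of the two calling conventions that `LLMChatPipeline.forward()` now encapsulates, distilled from the `src/llm_chat.ts` diff below. `PipelineState` and `prefillSketch` are hypothetical names introduced only for illustration; they are not part of this commit:

```typescript
import * as tvmjs from "tvmjs";

// Hypothetical stand-in for the private state of LLMChatPipeline (illustration only).
interface PipelineState {
  tvm: tvmjs.Instance;
  kvCache: tvmjs.TVMObject;
  params: tvmjs.TVMObject;
  usePagedKVCache: boolean;
  prefill: tvmjs.PackedFunc;
  embed?: tvmjs.PackedFunc;
  fKVCacheBeginForward?: tvmjs.PackedFunc;
  fKVCacheEndForward?: tvmjs.PackedFunc;
}

// Sketch of the prefill path; the decode path is analogous with append length 1.
function prefillSketch(p: PipelineState, inputs: tvmjs.NDArray, curPos: number) {
  const seqLen = inputs.shape[1]; // num input tokens
  if (p.usePagedKVCache) {
    // New PagedKVCache flow: mark the forward window on sequence 0,
    // embed the tokens, run the model, then unmark the window.
    p.fKVCacheBeginForward!(
      p.kvCache,
      p.tvm.makeShapeTuple([0]),      // sequence ids
      p.tvm.makeShapeTuple([seqLen]), // append lengths
    );
    const embed = p.embed!(inputs, p.params);
    const logits = p.prefill(embed, p.kvCache, p.params);
    p.fKVCacheEndForward!(p.kvCache);
    return logits;
  }
  // Legacy flow: the model takes token ids plus an explicit total-length shape.
  const seqLenShape = p.tvm.makeShapeTuple([curPos]);
  return p.prefill(inputs, seqLenShape, p.kvCache, p.params);
}
```

Either way the caller sees a single `forward()` call; the branch on `usePagedKVCache` is what keeps the difference invisible to WebLLM users, as stated above.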
1 parent 3319d1c

2 files changed (+116, -32)

src/chat_module.ts

Lines changed: 0 additions & 1 deletion
@@ -139,7 +139,6 @@ export class ChatModule implements ChatInterface {
 
     this.pipeline = new LLMChatPipeline(tvm, tokenizer, config, this.logitProcessor);
     await this.pipeline?.asyncLoadWebGPUPipelines();
-
     const tend = performance.now();
 
     if (this.initProgressCallback !== undefined) {

src/llm_chat.ts

Lines changed: 116 additions & 31 deletions
@@ -1,3 +1,5 @@
+/* eslint-disable @typescript-eslint/no-non-null-assertion */
+/* eslint-disable no-prototype-builtins */
 import * as tvmjs from "tvmjs";
 import { Tokenizer } from "@mlc-ai/web-tokenizers";
 import { ChatConfig } from "./config";
@@ -16,6 +18,12 @@ export class LLMChatPipeline {
   private prefill: tvmjs.PackedFunc;
   private decoding: tvmjs.PackedFunc;
   private fclearKVCaches: tvmjs.PackedFunc;
+  // Functions for PagedKVCache only
+  private embed?: tvmjs.PackedFunc = undefined;
+  private fKVCacheAddSequence?: tvmjs.PackedFunc = undefined;
+  private fKVCacheRemoveSequence?: tvmjs.PackedFunc = undefined;
+  private fKVCacheBeginForward?: tvmjs.PackedFunc = undefined;
+  private fKVCacheEndForward?: tvmjs.PackedFunc = undefined;
 
   // parameter states
   private params: tvmjs.TVMObject;
@@ -41,10 +49,11 @@ export class LLMChatPipeline {
   private appearedTokens = new Set<number>();
   private conversation: Conversation;
   // Whether sink is in action
-  private sinkTriggered: boolean = false;
+  private sinkTriggered = false;
   // sliding window cache offset (Next position to be overridden on the rolling kv cache.)
-  private slidingWindowCacheOffset: number = 0;
-  // Total amount of seq len prefilled so far
+  private slidingWindowCacheOffset = 0;
+  // Whether we are using PagedKVCache (eventually this will become default)
+  private usePagedKVCache = false;
 
   // stats
   private decodingTotalTime = 0;
@@ -59,6 +68,7 @@ export class LLMChatPipeline {
   private logitProcessor?: LogitProcessor = undefined;
 
   constructor(tvm: tvmjs.Instance, tokenizer: Tokenizer, config: ChatConfig, logitProcessor?: LogitProcessor) {
+    // 0. Setting attributes
     this.tvm = tvm;
     this.tokenizer = tokenizer;
     this.config = config;
@@ -69,19 +79,28 @@ export class LLMChatPipeline {
     this.stopTokens = this.conversation.getStopTokens();
 
     this.device = this.tvm.webgpu();
-    tvm.beginScope();
 
+    // 1. Create VM and get the core functions
+    tvm.beginScope();
     this.vm = this.tvm.detachFromCurrentScope(
       this.tvm.createVirtualMachine(this.device)
     );
     this.prefill = this.tvm.detachFromCurrentScope(
       this.vm.getFunction("prefill")
     );
+    try {
+      // We expect to find `embed` if we usePagedKVCache
+      this.embed = this.tvm.detachFromCurrentScope(
+        this.vm.getFunction("embed")
+      );
+    } catch {
+      // Do nothing
+    }
     this.decoding = this.tvm.detachFromCurrentScope(
       this.vm.getFunction("decode")
     );
 
-    // Get json stored in the vm's metadata function
+    // 2. Get json stored in the vm's metadata function
     let fgetMetadata;
     let useSLIM = false; // SLIM is the new workflow
     try {
@@ -94,7 +113,7 @@
     const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
     const metadata = JSON.parse(metadataStr);
 
-    // Load parameters
+    // 3. Load parameters
     if (useSLIM) {
       // Under SLIM workflow, we load parameters by name
       const paramNames: string[] = [];
@@ -109,14 +128,16 @@
       );
     }
 
-    if (metadata.hasOwnProperty("prefill_chunk_size") && metadata.prefill_chunk_size != -1) {
+    // 4. Read in compilation configurations from metadata
+    if (metadata.hasOwnProperty("prefill_chunk_size")) {
       this.prefillChunkSize = metadata.prefill_chunk_size;
       this.logger("Using prefillChunkSize: ", this.prefillChunkSize);
       if (this.prefillChunkSize <= 0) {
         throw Error("Prefill chunk size needs to be positive.");
       }
+    } else {
+      throw Error("Cannot find `prefill_chunk_size` in metadata; make sure the wasm is up to date.");
     }
-
     // Only use one of slidingWindowSize and maxWindowLength
     if (metadata.hasOwnProperty("sliding_window_size") && metadata.sliding_window_size != -1) {
       this.slidingWindowSize = metadata.sliding_window_size;
@@ -149,20 +170,60 @@
       }
     }
 
+    // 5. Create cache
+    // Use `fcreateCache` to determine whether we are using the new KVCache implementation
     let fcreateCache;
-    if (useSLIM) {
-      fcreateCache = this.vm.getFunction("_initialize_effect");
-    } else {
-      fcreateCache = this.vm.getFunction("create_kv_cache");
+    try {
+      if (useSLIM) {
+        fcreateCache = this.vm.getFunction("_initialize_effect");
+      } else {
+        fcreateCache = this.vm.getFunction("create_kv_cache");
+      }
+    } catch (err) {
+      // If we cannot find function above, it means that we are using the new PagedKVCache
+      this.usePagedKVCache = true;
+      fcreateCache = this.vm.getFunction("create_tir_paged_kv_cache");
+      console.log("Using Paged KVCache");
+      if (this.embed === undefined) {
+        throw Error("If using paged KVCache, method `embed()` needs to be defined.");
+      }
     }
 
-    this.fclearKVCaches = this.tvm.detachFromCurrentScope(
-      this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
-    );
+    // Load cache functions and instantiate KVCache
+    if (this.usePagedKVCache) {
+      this.fclearKVCaches = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_clear")
+      );
+      this.fKVCacheAddSequence = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_add_sequence")
+      );
+      this.fKVCacheRemoveSequence = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_remove_sequence")
+      );
+      this.fKVCacheBeginForward = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_begin_forward")
+      );
+      this.fKVCacheEndForward = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_end_forward")
+      );
 
-    // use extern config for now
-    this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
+      // Create PagedKVCache; we do not expose KVCache config for now
+      const defaultPageSize = 16;
+      const defaultMaxNumSequence = 1;
+      this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache(
+        this.tvm.makeShapeTuple([defaultMaxNumSequence]), // max_num_sequence
+        this.tvm.makeShapeTuple([this.maxWindowLength]), // max_total_sequence_length
+        this.tvm.makeShapeTuple([this.prefillChunkSize]), // prefill_chunk_size
+        this.tvm.makeShapeTuple([defaultPageSize]), // page_size, hard coded for now
+      ));
+    } else {
+      this.fclearKVCaches = this.tvm.detachFromCurrentScope(
+        this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
+      );
+      this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
+    }
     this.filledKVCacheLength = 0;
+    this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence
     tvm.endScope();
   }
 
@@ -198,18 +259,28 @@
   /**
    * Reset the chat history
   */
-  resetChat(keepStats: boolean = false) {
+  resetChat(keepStats = false) {
     this.conversation.reset();
     if (!keepStats) {
       this.resetRuntimeStats();
     }
-    this.fclearKVCaches(this.kvCache);
+    this.resetKVCache();
     this.filledKVCacheLength = 0;
     this.sinkTriggered = false;
     this.slidingWindowCacheOffset = 0;
     this.logitProcessor?.resetState();
   }
 
+  /**
+   * Reset KV Cache
+   */
+  resetKVCache() {
+    this.fclearKVCaches(this.kvCache);
+    if (this.usePagedKVCache) {
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
+  }
+
   /**
    * @returns Whether stop is triggered.
   */
@@ -397,16 +468,23 @@
   private forward(inputs: tvmjs.NDArray, curPos: number): tvmjs.NDArray {
     this.tvm.beginScope();
     let retValue;
+    const seqLen = inputs.shape[1]; // Num input tokens
     const seqLenShape = this.tvm.makeShapeTuple([curPos]);
-    if (inputs.shape[1] > 1) {
+    if (seqLen > 1) {
       // Prefill
       if (this.slidingWindowSize == -1) {
-        retValue = this.prefill(
-          inputs, seqLenShape, this.kvCache, this.params
-        );
+        if (this.usePagedKVCache) {
+          const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+          const inputLenShape = this.tvm.makeShapeTuple([seqLen]);
+          this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
+          const embed = this.embed!(inputs, this.params);
+          retValue = this.prefill(embed, this.kvCache, this.params);
+          this.fKVCacheEndForward!(this.kvCache);
+        } else {
+          retValue = this.prefill(inputs, seqLenShape, this.kvCache, this.params);
+        }
       } else {
         // Sliding window attention needs extra shape parameters
-        const seqLen = inputs.shape[1]; // Num input tokens
         const cacheLen = Math.min(this.slidingWindowSize, curPos - seqLen); // Num elements in the cache
         const cacheLenShape = this.tvm.makeShapeTuple([cacheLen]);
         const kvSeqLenShape = this.tvm.makeShapeTuple([cacheLen + seqLen]);
@@ -419,9 +497,16 @@
     } else {
       // Decode
       if (this.slidingWindowSize == -1) {
-        retValue = this.decoding(
-          inputs, seqLenShape, this.kvCache, this.params
-        );
+        if (this.usePagedKVCache) {
+          const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+          const appendLength = this.tvm.makeShapeTuple([1]);
+          this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, appendLength);
+          const embed = this.embed!(inputs, this.params);
+          retValue = this.decoding(embed, this.kvCache, this.params);
+          this.fKVCacheEndForward!(this.kvCache);
+        } else {
+          retValue = this.decoding(inputs, seqLenShape, this.kvCache, this.params);
+        }
       } else {
         // Same logic as above; keeping this if-else structure to match mlc-llm's llm_chat.cc
         const seqLen = inputs.shape[1]; // Num input tokens
@@ -463,7 +548,7 @@
   ) {
     // 1. Move logits to CPU
     this.tvm.beginScope();
-    let logitsOnCPU = this.updateLogitsOnCPU(logitsOnGPU);
+    const logitsOnCPU = this.updateLogitsOnCPU(logitsOnGPU);
     this.tvm.endScope();
     await this.device.sync();
 
@@ -544,7 +629,7 @@
       // need shift window and re-encode
      this.logger("need shift window")
       this.filledKVCacheLength = 0;
-      this.fclearKVCaches(this.kvCache);
+      this.resetKVCache();
 
      // abandon all tokens we collected
      if (this.conversation.config.add_bos) {
@@ -585,7 +670,7 @@
     inputData.copyFrom(inputIds);
 
     // 2. Forward tokens and get logits
-    let logitsOnGPU: tvmjs.NDArray = this.forward(inputData, curPos);
+    const logitsOnGPU: tvmjs.NDArray = this.forward(inputData, curPos);
     const nextToken = await this.sampleTokenFromLogits(
       logitsOnGPU, this.config.temperature, this.config.top_p);
     this.tvm.endScope();
@@ -605,7 +690,7 @@
 
   async evaluate() {
     // run a canonical evaluation of the flow
-    this.fclearKVCaches(this.kvCache);
+    this.resetKVCache();
     this.filledKVCacheLength = 0;
 
     const testPrompt = "The capital of Canada is";
