Skip to content

Replace CPU Function Calls with GPU Kernel Invocations #697

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 143 additions & 55 deletions src/llm_chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ export class LLMChatPipeline {
private image_embed: tvmjs.PackedFunc | undefined;
private embed: tvmjs.PackedFunc;
private fapplyBitmask: tvmjs.PackedFunc;
private fapplyPenalty: tvmjs.PackedFunc;
private fapplyLogitBias: tvmjs.PackedFunc;
private fsoftmaxWithTemperature: tvmjs.PackedFunc;

// Functions related to PagedKVCache
private fclearKVCaches: tvmjs.PackedFunc;
private fKVCacheAddSequence: tvmjs.PackedFunc;
Expand Down Expand Up @@ -190,6 +194,15 @@ export class LLMChatPipeline {
this.fapplyBitmask = this.tvm.detachFromCurrentScope(
this.vm.getFunction("apply_bitmask_inplace"),
);
this.fapplyPenalty = this.tvm.detachFromCurrentScope(
this.vm.getFunction("apply_penalty_inplace"),
);
this.fapplyLogitBias = this.tvm.detachFromCurrentScope(
this.vm.getFunction("apply_logit_bias_inplace"),
);
this.fsoftmaxWithTemperature = this.tvm.detachFromCurrentScope(
this.vm.getFunction("softmax_with_temperature"),
);
try {
this.image_embed = this.tvm.detachFromCurrentScope(
this.vm.getFunction("image_embed"),
Expand Down Expand Up @@ -1091,80 +1104,155 @@ export class LLMChatPipeline {
if (this.logitProcessor !== undefined) {
logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray);
}

if (_hasValue(logit_bias)) {
for (const tokenID in logit_bias) {
const curBias = logit_bias[tokenID];
const curTokenID = parseInt(tokenID);
if (curTokenID > vocab_size) {
throw Error(
"Token " +
curTokenID +
" in logit_bias exceeds vocab_size " +
vocab_size,
);
}
logitsOnCPUArray[curTokenID] += curBias;
this.tvm.beginScope();
const numTokens = Object.keys(logit_bias ?? {}).length;
const pos2seq_id = new Int32Array(numTokens).fill(0);
const tokenIds = new Int32Array(numTokens);
const tokenLogitBias = new Float32Array(numTokens);

const logitBiasKeys = Object.keys(logit_bias ?? {});
for (let index = 0; index < numTokens; index++) {
const tokenId = parseInt(logitBiasKeys[index]);
tokenIds[index] = tokenId;
tokenLogitBias[index] = logit_bias![tokenId];
}

const pos2seqIdsArray = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(pos2seq_id);

const tokenIdsArray = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(tokenIds);

const tokenLogitBiasArray = this.tvm
.empty([numTokens], "float32", this.device)
.copyFrom(tokenLogitBias);

const logitsOnGPU = this.tvm
.empty([1, this.fullVocabSize], "float32", this.device)
.copyFrom(logitsOnCPUArray);

this.fapplyLogitBias(
logitsOnGPU,
pos2seqIdsArray,
tokenIdsArray,
tokenLogitBiasArray,
);
this.updateLogitsOnCPU(logitsOnGPU);
this.tvm.endScope();
}
this.logitsOnCPU.copyFrom(logitsOnCPUArray);
await this.device.sync();
}

// 3. Apply penalties to logits
if (_hasValue(frequency_penalty) && _hasValue(presence_penalty)) {
// 3.1. Use frequency and presence penalty
if (
frequency_penalty != 0.0 ||
presence_penalty != 0.0 ||
repetition_penalty != 1.0
) {
this.tvm.beginScope();
// Both `keys()` and `values()` are in insertion order.
const appearedTokens = [...this.appearedTokensFreq.keys()];
const appearedTokensFreqs = [...this.appearedTokensFreq.values()];
const appeared_tokens_ndarray = this.tvm.empty(
[1, appearedTokens.length],
"int32",
this.tvm.cpu(),
);
const appeared_tokens_freqs_ndarray = this.tvm.empty(
[1, appearedTokensFreqs.length],
"int32",
this.tvm.cpu(),
);
appeared_tokens_ndarray.copyFrom(appearedTokens);
appeared_tokens_freqs_ndarray.copyFrom(appearedTokensFreqs);
this.tvm.applyPresenceAndFrequencyPenalty(
this.logitsOnCPU,
appeared_tokens_ndarray,
appeared_tokens_freqs_ndarray,
presence_penalty!,
frequency_penalty!,
);
this.tvm.endScope();
} else if (repetition_penalty != 1.0) {
// 3.2. Use repetition penalty
this.tvm.beginScope();
const appearedTokens = [...this.appearedTokensFreq.keys()];
const appeared_tokens_ndarray = this.tvm.empty(
[1, appearedTokens.length],
"int32",
this.tvm.cpu(),
);
appeared_tokens_ndarray.copyFrom(appearedTokens);
this.tvm.applyRepetitionPenalty(
this.logitsOnCPU,
appeared_tokens_ndarray,

const numTokens = appearedTokens.length;

const seqIdsArray = this.tvm
.empty([1], "int32", this.device)
.copyFrom([0]);

const pos2seq_id = new Int32Array(numTokens).fill(0);
const tokenIds = new Int32Array(numTokens).fill(0);
const tokenCnt = new Int32Array(numTokens).fill(0);
const penalties = new Float32Array([
presence_penalty,
frequency_penalty,
repetition_penalty,
);
]);
const paddedPenalties = new Float32Array(3);
paddedPenalties.set(penalties);

tokenIds.set(appearedTokens);
tokenCnt.set(appearedTokensFreqs);

const pos2seqIdsArray = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(pos2seq_id);

const tokenIdsArray = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(tokenIds);

const tokenCntArray = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(tokenCnt);

const penaltiesArray = this.tvm
.empty([1, 3], "float32", this.device)
.copyFrom(paddedPenalties);

const logitsOnGPU = this.tvm
.empty([1, this.fullVocabSize], "float32", this.device)
.copyFrom(this.logitsOnCPU.toArray());

if (numTokens > 0) {
this.fapplyPenalty(
logitsOnGPU,
seqIdsArray,
pos2seqIdsArray,
tokenIdsArray,
tokenCntArray,
penaltiesArray,
);
}
this.updateLogitsOnCPU(logitsOnGPU);
this.tvm.endScope();
}
await this.device.sync();

// 4. Sample token from logits
// If logprobs, need the actual distribution via softmax, otherwise directly sample from logits
let sampledToken: number;
if (logprobs) {
// Inplace transform logitsOnCPU to a distribution
temperature = Math.max(1e-6, temperature); // to prevent division by zero
this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature);
sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p);
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);

const numSeqs = 1;
const numTokens = this.appearedTokensFreq.size;

if (numTokens > 0) {
const temperatures = new Float32Array([temperature]);

this.tvm.beginScope();
const temperaturesArray = this.tvm
.empty([numSeqs], "float32", this.device)
.copyFrom(temperatures);

const logitsOnGPU = this.tvm
.empty([numSeqs, 1, this.fullVocabSize], "float32", this.device)
.copyFrom(this.logitsOnCPU.toArray());

const probs = this.fsoftmaxWithTemperature(
logitsOnGPU,
temperaturesArray,
);
this.updateLogitsOnCPU(probs);
this.tvm.endScope();
await this.device.sync();

sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p);
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);
} else {
this.tvm.applySoftmaxWithTemperature(this.logitsOnCPU, temperature);
sampledToken = this.tvm.sampleTopPFromProb(this.logitsOnCPU, top_p);
this.tokenLogprobArray.push(
this.getTokenLogprob(sampledToken, top_logprobs!),
);
}
} else {
// temperature being 0 is allowed here, equivalent to argmax
sampledToken = this.tvm.sampleTopPFromLogits(
Expand Down