
Commit 2243cf3

[ChatModule] Add GenerationConfig and set up unit tests (mlc-ai#298)
This PR adds `GenerationConfig`, which allows per-generation configuration. See `get_started.ts` for example usage:

```typescript
let genConfig: webllm.GenerationConfig = {
  presence_penalty: 0.5,
  frequency_penalty: 0.5,
  max_gen_len: 20,
  // stop: ["is", "Canada"] // for demonstration purpose
}
const prompt0 = "What is the capital of Canada?";
const reply0 = await chat.generate(prompt0, generateProgressCallback, 1, genConfig);
```

In addition to the existing fields in `mlc-chat-config.json`, we also support the OpenAI-like fields `frequency_penalty`, `presence_penalty`, and `stop`, in preparation for the upcoming OpenAI-like APIs.

This PR also sets up unit tests; run them with `npm test`. Some work remains to support end-to-end testing (e.g. accessing WebGPU in a test environment).

All prebuilt WASMs are updated correspondingly (mlc-ai/binary-mlc-llm-libs#90), since we introduced a new API in tvmjs's `runtime.ts` via apache/tvm#16504. Note that the update to the Llama WASMs is breaking, in the sense that users will have to update their WebLLM npm package.
1 parent 3178ec1 commit 2243cf3
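For reference, the fields named above suggest a config shape along the following lines. This is a hedged sketch, not the committed interface: the field optionality and exact declaration in the source are assumptions, while the four field names come from the commit itself.

```typescript
// Sketch of the per-generation config implied by this commit (assumed shape).
interface GenerationConfig {
  presence_penalty?: number;  // OpenAI-like: penalize tokens already present
  frequency_penalty?: number; // OpenAI-like: penalize frequently repeated tokens
  max_gen_len?: number;       // cap on the number of tokens generated for this call
  stop?: string[];            // OpenAI-like stop strings
}
```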

File tree: 12 files changed (+5293 −2951 lines)


examples/get-started-rest/package.json

+1 −1

```diff
@@ -1,5 +1,5 @@
 {
-  "name": "get-started",
+  "name": "get-started-rest",
   "version": "0.1.0",
   "private": true,
   "scripts": {
```

examples/get-started/src/get_started.ts

+13 −3

```diff
@@ -33,7 +33,7 @@ async function main() {
     ]
   }
   const selectedModel = "Llama-2-7b-chat-hf-q4f32_1"
-  // const selectedModel = "Mistral-7B-Instruct-v0.1-q4f16_1"
+  // const selectedModel = "Mistral-7B-Instruct-v0.2-q4f16_1"
   await chat.reload(selectedModel, undefined, myAppConfig);

   // Option 2: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
@@ -43,14 +43,24 @@ async function main() {
     setLabel("generate-label", message);
   };

+  // Per-generation configuration
+  let genConfig: webllm.GenerationConfig = {
+    presence_penalty: 0.5,
+    frequency_penalty: 0.5,
+    // stop: ["is", "Canada"] // for demonstration purpose
+  }
+
   const prompt0 = "What is the capital of Canada?";
   setLabel("prompt-label", prompt0);
-  const reply0 = await chat.generate(prompt0, generateProgressCallback);
+  const reply0 = await chat.generate(prompt0, generateProgressCallback, 1, genConfig);
   console.log(reply0);

+  genConfig = {
+    max_gen_len: 20,
+  }
   const prompt1 = "Can you write a poem about it?";
   setLabel("prompt-label", prompt1);
-  const reply1 = await chat.generate(prompt1, generateProgressCallback);
+  const reply1 = await chat.generate(prompt1, generateProgressCallback, 1, genConfig);
   console.log(reply1);

   console.log(await chat.runtimeStatsText());
```
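The call sites above pass two extra trailing arguments to `chat.generate`. A plausible reading of the resulting signature follows; this is an assumption inferred from the call sites, with the third parameter read here as the pre-existing stream interval:

```typescript
// Assumed shape of ChatModule.generate after this commit.
generate(
  input: string,
  progressCallback?: (step: number, currentMessage: string) => void,
  streamInterval?: number,             // e.g. 1 = report progress every decode step
  genConfig?: webllm.GenerationConfig, // new: per-generation overrides
): Promise<string>;
```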

jest.config.cjs

+4

```diff
@@ -0,0 +1,4 @@
+module.exports = {
+  preset: "ts-jest",
+  testEnvironment: "node",
+};
```
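With this preset, `npm test` runs `*.test.ts` files through ts-jest in a Node environment. A minimal illustrative test follows; the file path and assertions are hypothetical, since the commit's actual test files are not shown in this excerpt, and it deliberately avoids WebGPU, which is unavailable in the Node test environment:

```typescript
// tests/generation_config.test.ts (hypothetical path)
describe("GenerationConfig", () => {
  test("accepts OpenAI-like penalty and stop fields", () => {
    const genConfig = {
      presence_penalty: 0.5,
      frequency_penalty: 0.5,
      stop: ["is", "Canada"],
    };
    expect(genConfig.presence_penalty).toBeCloseTo(0.5);
    expect(genConfig.stop).toContain("Canada");
  });
});
```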
