
Commit fd52664

Update runner for weights sharing
Parent: ea5cd4d

File tree

9 files changed: 2175 additions, 78 deletions


backends/mediatek/runtime/include/api/NeuronAdapter.h

Lines changed: 2143 additions & 0 deletions
Large diffs are not rendered by default.

examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.cpp

Lines changed: 8 additions & 4 deletions
@@ -63,11 +63,16 @@ LlamaModelChunk::LlamaModelChunk(
       enableSWA(enableSWA),
       kCacheTypeSize(llm_helper::getLLMTypeSize(kCacheType)) {}

+LlamaModelChunk::~LlamaModelChunk() {}
+
 std::string LlamaModelChunk::SelectMethod(
     const std::vector<std::string>& methodNames) const {
   const size_t curTokenSize = GetModelId();
   for (const auto& methodName : methodNames) {
     const auto matches = utils::extract_substr(methodName, "([0-9]+)t[0-9]+c");
+    if (matches.empty()) {
+      continue;
+    }
     ET_CHECK_MSG(
         matches.size() == 2, "Invalid method name: %s", methodName.c_str());
     // Extract the first match group as token size
@@ -88,8 +93,6 @@ std::string LlamaModelChunk::SelectMethod(
   return {};
 }

-LlamaModelChunk::~LlamaModelChunk() {}
-
 void LlamaModelChunk::Initialize() {
   LoadModels();
   GetModelIoInfo();
@@ -367,8 +370,9 @@ void LlamaModelChunk::UpdatePosEmbAndMask(const size_t numInputToken) {
     const auto swaMaskSizeBytes = swaMaskBufferInfo.nbytesUsed;
     mMaskBuilder->setMaskBuffer(swaMaskBuffer, swaMaskSizeBytes);
     mMaskBuilder->enableSlidingWindow(kWindowSize);
-    mMaskBuilder->updateMask(
-        mTokenBatchSize, mCurrentTokenIndex, numInputToken);
+    // mMaskBuilder->updateMask(mTokenBatchSize, mCurrentTokenIndex,
+    //     numInputToken);
+    mMaskBuilder->buildMask(mTokenBatchSize, mCurrentTokenIndex);
   }
   // Pass same isMaskUpdatable to both mask
   mMaskBuilder->setIsMaskUpdatable(isMaskUpdatable);
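
Two behavioral changes stand out in this hunk: SelectMethod now skips method names that carry no "<T>t<C>c" token-size tag (a weight-shared package can contain such methods), and the SWA mask is rebuilt in full with buildMask rather than incrementally updated. As a rough illustration of the selection logic, here is a hedged, self-contained C++ sketch that substitutes std::regex for the repo's utils::extract_substr helper, whose exact behavior is assumed; SelectMethodByTokenSize and wantedTokenSize are hypothetical names, not part of the codebase:

#include <regex>
#include <string>
#include <vector>

// Hypothetical standalone version of the selection loop above: return the
// first method whose "<T>t<C>c" tag matches the wanted token batch size,
// skipping names without such a tag (the newly added guard).
std::string SelectMethodByTokenSize(
    const std::vector<std::string>& methodNames,
    size_t wantedTokenSize) {
  const std::regex kPattern("([0-9]+)t[0-9]+c");
  for (const auto& name : methodNames) {
    std::smatch match;
    if (!std::regex_search(name, match, kPattern)) {
      continue;  // e.g. a weights-only method with no token-size tag
    }
    if (std::stoul(match[1].str()) == wantedTokenSize) {
      return name;  // e.g. "forward_128t512c" when wantedTokenSize == 128
    }
  }
  return {};  // no matching method found
}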

examples/mediatek/executor_runner/llama_runner/LlamaModelChunk.h

Lines changed: 0 additions & 12 deletions
@@ -89,8 +89,6 @@ class LlamaModelChunk : public ModelChunk {

   void InitMaskBuilder();

-  void InitSWAMaskBuilder();
-
   void InitCache();

   void PrepareCacheIOs();
@@ -134,10 +132,6 @@ class LlamaModelChunk : public ModelChunk {

   void CheckIoCount();

-  size_t GetExpectedInputCount() const;
-
-  size_t GetExpectedOutputCount() const;
-
  private:
   bool AllowModelsCoexist() const override {
     return kIsSharedWeightsUsed;
@@ -150,12 +144,6 @@ class LlamaModelChunk : public ModelChunk {
   // Whether shared weights is used
   bool kIsSharedWeightsUsed = false;

-  // Input/Output Indexes
-  const size_t kMaskInputIndex;
-  const std::vector<size_t> kRotEmbInputIndexes;
-  const std::vector<size_t> kCacheInputIndexes;
-  const std::vector<size_t> kCacheOutputIndexes;
-
   // Cache
   TensorShape mCacheShape;
   const LLMType kCacheType;

examples/mediatek/executor_runner/llama_runner/llm_helper/mask_builder.cpp

Lines changed: 0 additions & 1 deletion
@@ -5,7 +5,6 @@
  * except in compliance with the License. See the license file in the root
  * directory of this source tree for more details.
  */
-#include <iostream> //TODO: DELETE

 #include "llm_helper/include/mask_builder.h"

examples/mediatek/executor_runner/run_gemma2_sample.sh

Lines changed: 4 additions & 12 deletions
@@ -37,16 +37,9 @@ TOKENIZER_PATH="/data/local/tmp/et_mtk/tokenizer_gemma2.json"
 TOKEN_EMBEDDING_PATH="/data/local/tmp/et_mtk/embedding_gemma2_2b_it_fp32.bin"

 # Comma-Separated Paths
-PROMPT_MODEL_PATHS="\
-/data/local/tmp/et_mtk/gemma2_2b_it_A16W4_2_chunks_128t512c/gemma2_2b_it_A16W4_2_chunks_128t512c_0.pte,\
-/data/local/tmp/et_mtk/gemma2_2b_it_A16W4_2_chunks_128t512c/gemma2_2b_it_A16W4_2_chunks_128t512c_1.pte,"
-
-# # Comma-Separated Paths
-GEN_MODEL_PATHS="\
-/data/local/tmp/et_mtk/gemma2_2b_it_A16W4_2_chunks_1t512c/gemma2_2b_it_A16W4_2_chunks_1t512c_0.pte,\
-/data/local/tmp/et_mtk/gemma2_2b_it_A16W4_2_chunks_1t512c/gemma2_2b_it_A16W4_2_chunks_1t512c_1.pte,"
-
-
+WEIGHT_SHARED_MODEL_PACKAGE_PATHS="\
+/data/local/tmp/et_mtk/gemma2_2b_it_A16W4_2_chunks/gemma2_2b_it_A16W4_2_chunks_0.pte,\
+/data/local/tmp/et_mtk/gemma2_2b_it_A16W4_2_chunks/gemma2_2b_it_A16W4_2_chunks_1.pte,"

 PROMPT_FILE=/data/local/tmp/et_mtk/prompt_gemma.txt

@@ -75,6 +68,5 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD
   --tokenizer_type=$TOKENIZER_TYPE \
   --tokenizer_path=$TOKENIZER_PATH \
   --token_embedding_path=$TOKEN_EMBEDDING_PATH \
-  --prompt_model_paths=$PROMPT_MODEL_PATHS \
-  --gen_model_paths=$GEN_MODEL_PATHS \
+  --model_package_paths=$WEIGHT_SHARED_MODEL_PACKAGE_PATHS \
   --prompt_file=$PROMPT_FILE
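
The same substitution appears in all four sample scripts in this commit: the separate 128-token prompt models and 1-token generation models are replaced by one weight-shared package per chunk, passed through the runner's new --model_package_paths flag. The runner's flag-parsing code is not part of this diff; the C++ sketch below only illustrates, as an assumption, how such a comma-separated value (note the trailing comma in the scripts) might be split into individual .pte paths:

#include <sstream>
#include <string>
#include <vector>

// Split a comma-separated path list such as
// "/data/.../pkg_0.pte,/data/.../pkg_1.pte," into individual paths,
// dropping the empty token produced by the trailing comma.
std::vector<std::string> SplitModelPackagePaths(const std::string& csv) {
  std::vector<std::string> paths;
  std::stringstream stream(csv);
  std::string token;
  while (std::getline(stream, token, ',')) {
    if (!token.empty()) {
      paths.push_back(token);
    }
  }
  return paths;
}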

examples/mediatek/executor_runner/run_phi3_sample.sh

Lines changed: 6 additions & 15 deletions
@@ -36,19 +36,11 @@ TOKENIZER_PATH="/data/local/tmp/et_mtk/tokenizer.bin"
 TOKEN_EMBEDDING_PATH="/data/local/tmp/et_mtk/embedding_phi3.5-mini-instruct_fp32.bin"

 # Comma-Separated Paths
-PROMPT_MODEL_PATHS="\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_128t512c/phi3.5-mini-instruct_A16W4_4_chunks_128t512c_0.pte,\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_128t512c/phi3.5-mini-instruct_A16W4_4_chunks_128t512c_1.pte,\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_128t512c/phi3.5-mini-instruct_A16W4_4_chunks_128t512c_2.pte,\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_128t512c/phi3.5-mini-instruct_A16W4_4_chunks_128t512c_3.pte,"
-
-
-# Comma-Separated Paths
-GEN_MODEL_PATHS="\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_1t512c/phi3.5-mini-instruct_A16W4_4_chunks_1t512c_0.pte,\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_1t512c/phi3.5-mini-instruct_A16W4_4_chunks_1t512c_1.pte,\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_1t512c/phi3.5-mini-instruct_A16W4_4_chunks_1t512c_2.pte,\
-/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks_1t512c/phi3.5-mini-instruct_A16W4_4_chunks_1t512c_3.pte,"
+WEIGHT_SHARED_MODEL_PACKAGE_PATHS="\
+/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks/phi3.5-mini-instruct_A16W4_4_chunks_0.pte,\
+/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks/phi3.5-mini-instruct_A16W4_4_chunks_1.pte,\
+/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks/phi3.5-mini-instruct_A16W4_4_chunks_2.pte,\
+/data/local/tmp/et_mtk/phi3.5-mini-instruct_A16W4_4_chunks/phi3.5-mini-instruct_A16W4_4_chunks_3.pte,"

 PROMPT_FILE=/data/local/tmp/et_mtk/prompt_phi3.txt

@@ -76,6 +68,5 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD
   --tokenizer_type=$TOKENIZER_TYPE \
   --tokenizer_path=$TOKENIZER_PATH \
   --token_embedding_path=$TOKEN_EMBEDDING_PATH \
-  --prompt_model_paths=$PROMPT_MODEL_PATHS \
-  --gen_model_paths=$GEN_MODEL_PATHS \
+  --model_package_paths=$WEIGHT_SHARED_MODEL_PACKAGE_PATHS \
   --prompt_file=$PROMPT_FILE

examples/mediatek/executor_runner/run_qwen2_sample.sh

Lines changed: 6 additions & 16 deletions
@@ -36,20 +36,11 @@ TOKENIZER_PATH="/data/local/tmp/et_mtk/tokenizer_qwen3.json"
 TOKEN_EMBEDDING_PATH="/data/local/tmp/et_mtk/embedding_Qwen2-7B-Instruct_fp32.bin"

 # Comma-Separated Paths
-PROMPT_MODEL_PATHS="\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c_0.pte,\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c_1.pte,\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c_2.pte,\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c/Qwen2-7B-Instruct_A16W4_4_chunks_128t512c_3.pte,"
-
-# # Comma-Separated Paths
-GEN_MODEL_PATHS="\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c_0.pte,\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c_1.pte,\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c_2.pte,\
-/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c/Qwen2-7B-Instruct_A16W4_4_chunks_1t512c_3.pte,"
-
-
+WEIGHT_SHARED_MODEL_PACKAGE_PATHS="\
+/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks/Qwen2-7B-Instruct_A16W4_4_chunks_0.pte,\
+/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks/Qwen2-7B-Instruct_A16W4_4_chunks_1.pte,\
+/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks/Qwen2-7B-Instruct_A16W4_4_chunks_2.pte,\
+/data/local/tmp/et_mtk/Qwen2-7B-Instruct_A16W4_4_chunks/Qwen2-7B-Instruct_A16W4_4_chunks_3.pte,"

 PROMPT_FILE=/data/local/tmp/et_mtk/prompt.txt

@@ -77,6 +68,5 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD
   --tokenizer_type=$TOKENIZER_TYPE \
   --tokenizer_path=$TOKENIZER_PATH \
   --token_embedding_path=$TOKEN_EMBEDDING_PATH \
-  --prompt_model_paths=$PROMPT_MODEL_PATHS \
-  --gen_model_paths=$GEN_MODEL_PATHS \
+  --model_package_paths=$WEIGHT_SHARED_MODEL_PACKAGE_PATHS \
   --prompt_file=$PROMPT_FILE

examples/mediatek/executor_runner/run_qwen3_sample.sh

Lines changed: 7 additions & 17 deletions
@@ -37,22 +37,13 @@ TOKENIZER_PATH="/data/local/tmp/et_mtk/tokenizer_qwen3.json"
 TOKEN_EMBEDDING_PATH="/data/local/tmp/et_mtk/embedding_Qwen3-4B_fp32.bin"

 # Comma-Separated Paths
-PROMPT_MODEL_PATHS="\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_128t512c/Qwen3-4B_A16W4_4_chunks_128t512c_0.pte,\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_128t512c/Qwen3-4B_A16W4_4_chunks_128t512c_1.pte,\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_128t512c/Qwen3-4B_A16W4_4_chunks_128t512c_2.pte,\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_128t512c/Qwen3-4B_A16W4_4_chunks_128t512c_3.pte,"
+WEIGHT_SHARED_MODEL_PACKAGE_PATHS="\
+/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks/Qwen3-4B_A16W4_4_chunks_0.pte,\
+/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks/Qwen3-4B_A16W4_4_chunks_1.pte,\
+/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks/Qwen3-4B_A16W4_4_chunks_2.pte,\
+/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks/Qwen3-4B_A16W4_4_chunks_3.pte,"

-# # Comma-Separated Paths
-GEN_MODEL_PATHS="\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_1t512c/Qwen3-4B_A16W4_4_chunks_1t512c_0.pte,\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_1t512c/Qwen3-4B_A16W4_4_chunks_1t512c_1.pte,\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_1t512c/Qwen3-4B_A16W4_4_chunks_1t512c_2.pte,\
-/data/local/tmp/et_mtk/Qwen3-4B_A16W4_4_chunks_1t512c/Qwen3-4B_A16W4_4_chunks_1t512c_3.pte,"
-
-
-
-PROMPT_FILE=/data/local/tmp/et_mtk/prompt.txt
+PROMPT_FILE=/data/local/tmp/et_mtk/prompt_qwen3.txt

 chmod +x mtk_llama_executor_runner

@@ -79,6 +70,5 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$PWD
   --tokenizer_type=$TOKENIZER_TYPE \
   --tokenizer_path=$TOKENIZER_PATH \
   --token_embedding_path=$TOKEN_EMBEDDING_PATH \
-  --prompt_model_paths=$PROMPT_MODEL_PATHS \
-  --gen_model_paths=$GEN_MODEL_PATHS \
+  --model_package_paths=$WEIGHT_SHARED_MODEL_PACKAGE_PATHS \
   --prompt_file=$PROMPT_FILE

examples/mediatek/model_export_scripts/gemma.py

Lines changed: 1 addition & 1 deletion
@@ -315,7 +315,7 @@ def prepare_model_inputs(
     if window_size is not None:
         local_mask = generate_mask(
             max_cache_size,
-            0,
+            seq_length,
             input_length,
             input_length,
             sliding_window=True,
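
This Python fix passes the current sequence start (seq_length) rather than a hard-coded 0 as the mask's start position. generate_mask's full signature is not shown in this diff, so the C++ sketch below is only a hypothetical illustration of why the start offset matters when filling a sliding-window mask; BuildSlidingWindowMask and its parameters are invented names for this example:

#include <cstddef>
#include <vector>

// Hypothetical sliding-window mask fill: row i describes query position
// (startPos + i). A cache position j is visible only if it is causal
// (j <= query) and within the last `windowSize` positions. With a start
// position of 0, every row would be built as if decoding restarted at the
// beginning of the cache instead of at the current sequence length.
std::vector<std::vector<bool>> BuildSlidingWindowMask(
    std::size_t cacheSize,
    std::size_t startPos,
    std::size_t numRows,
    std::size_t windowSize) {
  std::vector<std::vector<bool>> mask(
      numRows, std::vector<bool>(cacheSize, false));
  for (std::size_t i = 0; i < numRows; ++i) {
    const std::size_t queryPos = startPos + i;
    for (std::size_t j = 0; j < cacheSize; ++j) {
      mask[i][j] = (j <= queryPos) && ((queryPos - j) < windowSize);
    }
  }
  return mask;
}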
