Commit 3da590d

Update scripts (#69)
* [WIP] Add an all_in_one shell script to simplify experiments.
* Fix typo.
* Fix typo.
* [Fix] Fix kernel copy command in all_in_one.
* Update README script.
* [script] Remove limitation of inference_type.
* [fix] Fix a random compile error.
* Add --disable-t-mac for convenience.
1 parent cdd878b

5 files changed: +227 −20 lines

README.md (+2 −2)

@@ -313,14 +313,14 @@ We have provided an **all-in-one script**. Invoke it with:
 ```bash
 pip install 3rdparty/llama.cpp/gguf-py
 huggingface-cli download 1bitLLM/bitnet_b1_58-3B --local-dir ${model_dir}
-python tools/run_pipeline.py -o ${model_dir}
+python tools/run_pipeline.py -o ${model_dir} -q int_n
 ```

 We have also supported models in GPTQ format from [GPTQModel](https://github.com/ModelCloud/GPTQModel)/[EfficientQAT](https://github.com/OpenGVLab/EfficientQAT). Try it out with the officially released EfficientQAT (GPTQ format) [Llama-3-8b-instruct-w2-g128](https://huggingface.co/ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128-GPTQ):

 ```bash
 huggingface-cli download ChenMnZ/Llama-3-8b-instruct-EfficientQAT-w2g128-GPTQ --local-dir ${model_dir}
-python tools/run_pipeline.py -o ${model_dir} -m llama-3-8b-2bit
+python tools/run_pipeline.py -o ${model_dir} -m llama-3-8b-2bit -q int_n
 ```

 > - Use `-p` or `-s` argument to select the steps you want to run.
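
The `-s` note above maps onto the step list in `tools/run_pipeline.py`. A hedged example, assuming the step indices from the `STEPS` list in this commit (6 = "Run inference", 7 = the new "Run llama-bench") and the comma-separated `-s` syntax that `tools/all_in_one.sh` itself uses:

```bash
# Re-run only inference (step 6) and the new llama-bench step (step 7)
# against a model that has already been compiled and converted:
python tools/run_pipeline.py -o ${model_dir} -it int_n -s 6,7
```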

python/t_mac/intrins/tbl.py (+6 −6)

@@ -41,18 +41,18 @@ def tbl(

     if m_groups == -1:
         if zero_point:
-            scales_shape = (1, m // bits * 2)
+            scales_shape = (kfactor * g // act_group_size, m // bits * 2)
             def _get_scale(m, k):
-                return Scales[0, m // bits * 2] - Scales[0, m // bits * 2 + 1]
+                return Scales[k * g // act_group_size, m // bits * 2] - Scales[k * g // act_group_size, m // bits * 2 + 1]
         else:
-            scales_shape = (1, m // bits)
+            scales_shape = (kfactor * g // act_group_size, m // bits)
             def _get_scale(m, k):
-                return Scales[0, m // bits]
+                return Scales[k * g // act_group_size, m // bits]
         scale_buffer_strides = [te.var("ss"), 1]
     else:
-        scales_shape = (1,)
+        scales_shape = (kfactor * g // act_group_size,)
         def _get_scale(m, k):
-            return Scales[0]
+            return Scales[k * g // act_group_size]
         scale_buffer_strides = [1]

     alpha = te.const(get_bits_alphas(bits)[0], dtype=out_dtype)
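
This change gives the scales buffer one row per activation group instead of a single shared row: the buffer now holds `kfactor * g // act_group_size` rows, and group `k` reads row `k * g // act_group_size`. As a worked example (illustrative values only): with `g = 4`, `kfactor = 16`, and `act_group_size = 64`, the buffer keeps 16 × 4 / 64 = 1 row and the old single-row behavior is recovered; with `act_group_size = 32` it keeps 2 rows, and group `k = 12` reads row 12 × 4 / 32 = 1.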

tools/all_in_one.sh (new file, +103)

@@ -0,0 +1,103 @@
+set -e
+
+if [ "$#" -lt 3 ]; then
+    echo "Usage: $0 <model_path> <kernel_name> <model_type> [--rechunk] [--convert-model] [--run-only] [--disable-t-mac]"
+    echo " model_path: path to the model directory"
+    echo " kernel_name: name of the kernel for compiler, e.g., llama-2-7b-4bit, hf-bitnet-3b, hf-bitnet-large-intn, hf-bitnet-large-tq, trilm-3.9b"
+    echo " model_type: type of the model, e.g., f16, int_n, tq1_0, tq2_0, q4_0"
+    echo " --rechunk: optional. Rechunk the model if set."
+    echo " --convert-model: optional. Convert the model to gguf format if set."
+    echo " --run-only: optional. Skip the compilation and only run the inference and benchmark if set."
+    echo " --disable-t-mac: optional. Disable T-MAC if set."
+    exit 1
+fi
+
+
+if [[ "$3" == "q4_0" ]]; then
+    export EXTRA_COMPILE_ARGS=("-gs=32" "-ags=32")
+elif [[ "$3" == "tq1_0" || "$3" == "tq2_0" ]]; then
+    export EXTRA_COMPILE_ARGS=("-gs=256" "-ags=64")
+else
+    export EXTRA_COMPILE_ARGS=()
+fi
+
+
+RECHUNK=false
+for arg in "$@"; do
+    case $arg in
+        --rechunk)
+            RECHUNK=true
+            ;;
+        *)
+            ;;
+    esac
+done
+
+
+CONVERT_MODEL=false
+for arg in "$@"; do
+    case $arg in
+        --convert-model)
+            CONVERT_MODEL=true
+            ;;
+        *)
+            ;;
+    esac
+done
+
+RUN_ONLY=false
+for arg in "$@"; do
+    case $arg in
+        --run-only)
+            RUN_ONLY=true
+            ;;
+        *)
+            ;;
+    esac
+done
+
+DISABLE_T_MAC=false
+for arg in "$@"; do
+    case $arg in
+        --disable-t-mac)
+            DISABLE_T_MAC=true
+            ;;
+        *)
+            ;;
+    esac
+done
+
+export MODEL_DIR=$(readlink -f "$1")
+export KERNEL_NAME=$2
+export MODEL_DTYPE=$3
+
+echo "MODEL_DIR: $MODEL_DIR"
+echo "KERNEL_NAME: $KERNEL_NAME"
+echo "MODEL_DTYPE: $MODEL_DTYPE"
+echo "RECHUNK: $RECHUNK"
+echo "CONVERT_MODEL: $CONVERT_MODEL"
+echo "RUN_ONLY: $RUN_ONLY"
+echo "DISABLE_T_MAC: $DISABLE_T_MAC"
+
+
+if [ "$RUN_ONLY" != true ]; then
+    if [ "$DISABLE_T_MAC" == true ]; then
+        echo "=== python tools/run_pipeline.py -o $MODEL_DIR -m $KERNEL_NAME -nt 4 -s 4,5 "${EXTRA_COMPILE_ARGS[@]}" --disable-t-mac ==="
+        python tools/run_pipeline.py -o $MODEL_DIR -m $KERNEL_NAME -nt 4 -s 4,5 ${EXTRA_COMPILE_ARGS[@]} --disable-t-mac
+    else
+        echo "=== python tools/run_pipeline.py -o $MODEL_DIR -m $KERNEL_NAME -nt 4 -s 0,1,2,4,5 "${EXTRA_COMPILE_ARGS[@]}" -q $MODEL_DTYPE ==="
+        python tools/run_pipeline.py -o $MODEL_DIR -m $KERNEL_NAME -nt 4 -s 0,1,2,4,5 ${EXTRA_COMPILE_ARGS[@]} -q $MODEL_DTYPE
+        if $CONVERT_MODEL; then
+            echo "=== python tools/run_pipeline.py -o $MODEL_DIR -m $KERNEL_NAME -nt 4 -s 3 "${EXTRA_COMPILE_ARGS[@]}" -q $MODEL_DTYPE ==="
+            python tools/run_pipeline.py -o $MODEL_DIR -m $KERNEL_NAME -nt 4 -s 3 ${EXTRA_COMPILE_ARGS[@]} -q $MODEL_DTYPE
+        fi
+    fi
+fi
+
+echo "=== python tools/run_pipeline.py -o "$MODEL_DIR" -it "$MODEL_DTYPE" -s 6 ==="
+python tools/run_pipeline.py -o "$MODEL_DIR" -it $MODEL_DTYPE -s 6
+for threads in $(seq 1 4); do
+    echo "=== Running with $threads threads, 1 batch ==="
+    python tools/run_pipeline.py -o "$MODEL_DIR" -it $MODEL_DTYPE -nt $threads -s 7
+done
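
A hedged usage sketch of the new script, with the kernel name and model type taken from its own help text and a placeholder model path:

```bash
# Full pipeline: compile kernels, convert to GGUF, build llama.cpp,
# then run inference plus llama-bench sweeps over 1-4 threads:
bash tools/all_in_one.sh /path/to/bitnet_b1_58-3B hf-bitnet-3b int_n --convert-model

# Skip compilation and conversion; only re-run inference and the benchmarks:
bash tools/all_in_one.sh /path/to/bitnet_b1_58-3B hf-bitnet-3b int_n --run-only
```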

tools/run_pipeline.py (+115 −11)

@@ -16,10 +16,14 @@ def run_command(command, pwd, ignore_errors=False):
     print(f" Running command in {pwd}:")
     print(f" {' '.join(command)}")
     os.makedirs(FLAGS.logs_dir, exist_ok=True)
-    log_file = os.path.join(FLAGS.logs_dir, datetime.now().strftime("%Y-%m-%d-%H-%M-%S.log"))
+    command_name = command[0].split(os.path.sep)[-1]
+    log_file = os.path.join(FLAGS.logs_dir, f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}_{command_name}.log")
     with open(log_file, "w") as fp:
         try:
-            subprocess.check_call(command, cwd=pwd, stdout=fp, stderr=fp)
+            if "llama-bench" in command_name:
+                subprocess.check_call(command, cwd=pwd)
+            else:
+                subprocess.check_call(command, cwd=pwd, stdout=fp, stderr=fp)
         except subprocess.CalledProcessError as err:
             if not ignore_errors:
                 print(RED + f"Please check {log_file} for what's wrong" + RESET)
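
Two effects of this hunk: each log file is now suffixed with the basename of the command it captures (e.g. a `llama-cli` run logs to `<timestamp>_llama-cli.log`), and `llama-bench` output deliberately bypasses the log file so its benchmark table streams to stdout.
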
@@ -48,6 +52,7 @@ def get_llamacpp_build_dir():


 def compile_kernels():
+    model_name = f"{FLAGS.model}_{str(FLAGS.quant_type).upper()}"
     deploy_dir = os.path.join(ROOT_DIR, "deploy")
     tuned_dir = os.path.join(deploy_dir, "tuned")
     prebuilt_dir = os.path.join(tuned_dir, f"{get_arch(FLAGS.device)}-{FLAGS.model}")
@@ -56,10 +61,18 @@ def compile_kernels():
         shutil.copytree(prebuilt_dir, tuned_dir, dirs_exist_ok=True)
         return

+    # Clear previous tune.log
+    command = [
+        'rm',
+        os.path.join("tuned", "preprocessor", "tune.log"),
+        os.path.join("tuned", "qgemm_lut", "tune.log"),
+    ]
+    run_command(command, deploy_dir, ignore_errors=True)
+
     qargs = get_quant_args()
     command = [
         'python', 'compile.py',
-        '-o', 'tuned',
+        '-o', f'{os.path.join("tuned", model_name)}',
         '-da',
         '-nt', f'{FLAGS.num_threads}',
         '-tb',
@@ -82,6 +95,11 @@
         command.append('-v')
     run_command(command, deploy_dir)

+    # Move to pre-install directory
+    kernel_dir = os.path.join(tuned_dir, model_name)
+    print(f" Copy built kernels from {kernel_dir} to {tuned_dir}")
+    shutil.copytree(kernel_dir, tuned_dir, dirs_exist_ok=True)
+

 def _clean_cmake(build_dir):
     command = ['cmake', '--build', '.', '--target', 'clean']
@@ -123,31 +141,51 @@ def convert_models():
     model_dir = FLAGS.model_dir
     if not os.path.exists(model_dir):
         raise FileNotFoundError(model_dir)
-    out_path = os.path.join(model_dir, f"ggml-model.{FLAGS.quant_type}.gguf")
+
+    out_type = FLAGS.quant_type
+    if FLAGS.quant_type == "q4_0":
+        out_type = "f16"
+
+    model_name = f"{os.path.split(model_dir)[-1]}.{str(out_type).upper()}.gguf"
+    out_path = os.path.join(model_dir, model_name)
     kcfg_path = os.path.join(ROOT_DIR, "install", "lib", "kcfg.ini")
     llamacpp_dir = os.path.join(ROOT_DIR, "3rdparty", "llama.cpp")
     command = [
         'python',
         'convert_hf_to_gguf.py',
         f'{model_dir}',
-        '--outtype', f'{FLAGS.quant_type}',
+        '--outtype', f'{out_type}',
         '--outfile', f'{out_path}',
         '--kcfg', f'{kcfg_path}',
         '--enable-t-mac',
         '--verbose',
     ]
     run_command(command, llamacpp_dir)

+    if FLAGS.quant_type == "q4_0":
+        quantized_model_name = f"{os.path.split(model_dir)[-1]}.Q4_0.gguf"
+        quantized_out_path = os.path.join(model_dir, quantized_model_name)
+        command = [
+            './build/bin/llama-quantize',
+            '--token-embedding-type', 'f16',
+            '--output-tensor-type', 'f16',
+            f'{out_path}',
+            f'{quantized_out_path}',
+            'q4_0',
+        ]
+        run_command(command, llamacpp_dir)
+

 def cmake_llamacpp():
     build_dir = get_llamacpp_build_dir()
     cmake_prefix_path = os.path.join(ROOT_DIR, "install", "lib", "cmake", "t-mac")
     command = [
         'cmake', '..',
-        '-DGGML_TMAC=ON',
+        f'-DGGML_TMAC={"OFF" if FLAGS.disable_t_mac else "ON"}',
         f'-DCMAKE_PREFIX_PATH={cmake_prefix_path}',
         '-DCMAKE_BUILD_TYPE=Release',
         '-DGGML_OPENMP=OFF',
+        f'-DGGML_TMAC_RECHUNK={"ON" if FLAGS.rechunk else "OFF"}',
     ]
     if FLAGS.device == "android":
         try:
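
Returning to the `convert_models` change in the hunk above: for `q4_0` the conversion is now two stages, since `convert_hf_to_gguf.py` first emits an F16 GGUF and `llama-quantize` then produces the Q4_0 file with token embeddings and the output tensor kept in f16. A hedged sketch of the resulting invocation and output names (placeholder model directory; the `-gs`/`-ags` values mirror what `tools/all_in_one.sh` passes for `q4_0`):

```bash
# Writes /models/llama-2-7b/llama-2-7b.F16.gguf first, then
# /models/llama-2-7b/llama-2-7b.Q4_0.gguf via llama-quantize:
python tools/run_pipeline.py -o /models/llama-2-7b -m llama-2-7b-4bit -q q4_0 -gs 32 -ags 32
```
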
@@ -178,13 +216,14 @@ def cmake_llamacpp():

 def build_llamacpp():
     build_dir = get_llamacpp_build_dir()
-    command = ['cmake', '--build', '.', '--target', 'llama-cli', 'llama-bench', 'llama-quantize', '--config', 'Release']
+    command = ['cmake', '--build', '.', '--target', 'llama-cli', 'llama-bench', 'llama-quantize', 'llama-perplexity', '--config', 'Release']
     run_command(command, build_dir)


 def run_inference():
     build_dir = get_llamacpp_build_dir()
-    out_path = os.path.join(FLAGS.model_dir, f"ggml-model.{FLAGS.quant_type}.gguf")
+    model_name = f"{os.path.split(FLAGS.model_dir)[-1]}.{str(FLAGS.inference_type).upper()}.gguf"
+    out_path = os.path.join(FLAGS.model_dir, model_name)
     if is_win():
         main_path = os.path.join(build_dir, "bin", "Release", "llama-cli.exe")
         if not os.path.exists(main_path):
@@ -229,14 +268,67 @@ def run_inference():
             '-m', f'{out_path}',
             '-n', '128',
             '-t', f'{FLAGS.num_threads}',
-            '-p', prompt,
+            '-p', f'{prompt}',
             '-ngl', '0',
             '-c', '2048'
         ]
         log_file = run_command(command, build_dir)
     print(GREEN + f"Check {log_file} for inference output" + RESET)


+def run_llama_bench():
+    build_dir = get_llamacpp_build_dir()
+    model_name = f"{os.path.split(FLAGS.model_dir)[-1]}.{str(FLAGS.inference_type).upper()}.gguf"
+    out_path = os.path.join(FLAGS.model_dir, model_name)
+    if is_win():
+        main_path = os.path.join(build_dir, "bin", "Release", "llama-bench.exe")
+        if not os.path.exists(main_path):
+            main_path = os.path.join(build_dir, "bin", "llama-bench")
+    else:
+        main_path = os.path.join(build_dir, "bin", "llama-bench")
+    prompt = 256
+    # TODO: verify in Android
+    if FLAGS.device == "android":
+        remote_bin_path = os.path.join(FLAGS.remote_dir, "bin")
+        command = ['push', os.path.join(build_dir, "bin"), FLAGS.remote_dir]
+        run_adb_command(command, build_dir)
+        remote_main_path = os.path.join(remote_bin_path, "llama-bench")
+        command = ['shell', 'chmod', '-R', '+x', remote_bin_path]
+        run_adb_command(command, build_dir)
+        remote_out_path = os.path.join(
+            FLAGS.remote_dir,
+            f"{os.path.basename(FLAGS.model_dir)}-{os.path.basename(out_path)}",
+        )
+        if not FLAGS.skip_push_model:
+            command = ['push', out_path, remote_out_path]
+            run_adb_command(command, build_dir)
+        kcfg_path = os.path.join(ROOT_DIR, "install", "lib", "kcfg.ini")
+        remote_kcfg_path = os.path.join(FLAGS.remote_dir, "kcfg.ini")
+        command = ['push', kcfg_path, remote_kcfg_path]
+        run_adb_command(command, build_dir)
+        command = [
+            'shell',
+            f'TMAC_KCFG_FILE={remote_kcfg_path}',
+            f'{remote_main_path}',
+            '-m', f'{remote_out_path}',
+            '-n', '128',
+            '-t', f'{FLAGS.num_threads}',
+            '-p', f'{prompt}',
+            '-ngl', '0',
+        ]
+        log_file = run_adb_command(command, build_dir)
+    else:
+        command = [
+            f'{main_path}',
+            '-m', f'{out_path}',
+            '-n', '128',
+            '-t', f'{FLAGS.num_threads}',
+            '-p', f'{prompt}',
+            '-ngl', '0',
+        ]
+        log_file = run_command(command, build_dir)
+    print(GREEN + f"Check {log_file} for llama-bench output" + RESET)
+
 STEPS = [
     ("Compile kernels", compile_kernels),
     ("Build T-MAC C++ CMakeFiles", cmake_t_mac),
@@ -245,6 +337,7 @@ def run_inference():
     ("Build llama.cpp CMakeFiles", cmake_llamacpp),
     ("Build llama.cpp", build_llamacpp),
     ("Run inference", run_inference),
+    ("Run llama-bench", run_llama_bench)
 ]

@@ -278,7 +371,10 @@ def parse_args():
     parser.add_argument("-gs", "--group_size", type=int, default=None, help="Don't set this argument if you don't know its meaning.")
     parser.add_argument("-ags", "--act_group_size", type=int, default=None, help="Don't set this argument if you don't know its meaning.")
     parser.add_argument("-ld", "--logs_dir", type=str, default="logs")
-    parser.add_argument("-q", "--quant_type", type=str, choices=["int_n", "f16", "f32"], default="int_n")
+    parser.add_argument("-q", "--quant_type", type=str, choices=["int_n", "f16", "f32", "tq1_0", "tq2_0", "q4_0"], default=None,
+                        help="Quantization model type. This will override inference_type.")
+    parser.add_argument("-it", "--inference_type", type=str, default="int_n",
+                        help="Inference model type. This will be overridden by quant_type if quant_type is set.")
     parser.add_argument("-zp", "--zero_point", action="store_true", help="Enforce enable zero_point. Required by EfficientQAT models.")
     parser.add_argument("-nzp", "--no_zero_point", action="store_false", help="Enforce disable zero_point. Don't set this argument if you don't know its meaning.")
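
The new `-q`/`-it` pair decouples conversion from inference: `-q` drives model conversion and, via the override added to `parse_args` in the next hunk, also sets the inference type, while `-it` alone selects which existing GGUF to run. A hedged sketch of the intended usage:

```bash
# Convert and run in one go; -q int_n implies -it int_n:
python tools/run_pipeline.py -o ${model_dir} -m hf-bitnet-3b -q int_n

# Run-only: pick up the existing <dirname>.INT_N.gguf for inference (step 6):
python tools/run_pipeline.py -o ${model_dir} -it int_n -s 6
```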

@@ -293,8 +389,16 @@ def parse_args():
     parser.add_argument("-ndk", "--ndk_home", type=str, default="", help="NDK home")
     parser.add_argument("-spm", "--skip_push_model", action="store_true", help="Suppose the model is unchanged to skip pushing the model file")

+    parser.add_argument("-rc", "--rechunk", action="store_true", help="Set this argument if you want to use rechunk in computation.")
+    parser.add_argument("--disable-t-mac", action="store_true", help="Set this argument if you want to disable T-MAC.")
+
     parser.set_defaults(zero_point=None)
-    return parser.parse_args()
+    args = parser.parse_args()
+
+    if args.quant_type is not None:
+        args.inference_type = args.quant_type
+
+    return args


 def get_quant_args():
