Skip to content

Commit ffb234e

Browse files
committed
script to run mlperf llama2 70b on gpu
1 parent deb14a3 commit ffb234e

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
#!/usr/bin/env bash
#
# Run MLPerf llama2-70b offline benchmarks on 8x H100 GPUs.
#
# Usage:
#   bash benchmarks_llama2-70b-h100_8.sh [-n] [-p] [-t] [-r run_name] [-b benchmark_type]
#
#   benchmark_type: performance | audit | accuracy | all   (default: performance)
#
# Optional environment overrides:
#   KV_QUANT_DTYPE, QUANTIZE_KVCACHE, CHECKPOINT, TOKENIZER_PATH,
#   PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES, QUANTIZATION, QUANT_MP

# NOTE: -e is deliberately omitted so that "-b all" still attempts the
# remaining benchmark flavors if an earlier one fails.
set -uo pipefail

run_name="h100_llama2-70b"
dry_run=false
enable_profiler=false
test_mode=false
benchmark_type="performance"

helpFunction() {
  echo ""
  echo "Usage: $0 [-n] [-p] [-t] [-r run_name] [-b benchmark_type]"
  echo -e "\t-n Dry run mode"
  echo -e "\t-p Enable profiler"
  echo -e "\t-t Test mode"
  echo -e "\t-r Specify run name"
  echo -e "\t-b Specify benchmark type (performance|audit|accuracy|all)"
  exit 1
}

# Parse arguments with a while-loop over the live positional parameters.
# (The original iterated with `for arg in "$@"` — a snapshot — while also
# calling `shift` inside the loop, so value-taking flags like `-r name`
# consumed their argument unreliably.)
while [[ $# -gt 0 ]]; do
  case "$1" in
    -n) dry_run=true ;;
    -p) enable_profiler=true ;;
    -t) test_mode=true ;;
    -r=*|--run=*) run_name="${1#*=}" ;;
    -r|--run) shift; run_name="${1:?-r requires a value}" ;;
    -b=*|--benchmark=*) benchmark_type="${1#*=}" ;;
    -b|--benchmark) shift; benchmark_type="${1:?-b requires a value}" ;;
    -h|--help) helpFunction ;;
    *) echo "Unknown option: $1" >&2; helpFunction ;;
  esac
  shift
done

# Validate benchmark type.
case "$benchmark_type" in
  performance|audit|accuracy|all) ;;
  *) echo "Invalid benchmark type. Must be: performance, audit, accuracy, or all" >&2; exit 1 ;;
esac

cmd=''   # command prefix (e.g. set to 'echo' to print instead of execute)

RUN_OPTIONS=" -c "   # Enable prefill packing by default
if "$dry_run"; then
  RUN_OPTIONS="${RUN_OPTIONS} -n "
fi

if "$enable_profiler"; then
  RUN_OPTIONS="${RUN_OPTIONS} -p "
fi

if "$test_mode"; then
  RUN_OPTIONS="${RUN_OPTIONS} -t "
fi

export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_command_buffer=FUSION --xla_disable_hlo_passes=rematerialization"
echo "XLA_FLAGS: ${XLA_FLAGS}"

# if [[ -z "${QUANTIZATION:-}" ]]; then
#   export QUANTIZATION="aqt_fp8"
# fi

# Default to an fp8-quantized KV cache.  QUANTIZE_KVCACHE gets its own
# check so it is still exported when the caller supplies KV_QUANT_DTYPE
# (the original only set it when KV_QUANT_DTYPE was unset, which left
# "quantize_kvcache=" empty in that case).
if [[ -z "${KV_QUANT_DTYPE:-}" ]]; then
  export KV_QUANT_DTYPE="fp8"
fi
if [[ -z "${QUANTIZE_KVCACHE:-}" ]]; then
  export QUANTIZE_KVCACHE=True
fi

if [[ -z "${CHECKPOINT:-}" ]]; then
  export CHECKPOINT="gs://jwyang/maxtext/direct_generate_param_only_checkpoint_llama2_70b_chat/checkpoints/0/items"
fi

if [[ -z "${TOKENIZER_PATH:-}" ]]; then
  export TOKENIZER_PATH="/opt/maxtext/assets/tokenizer.llama2"
fi

# Default prefill length / per-device batch size pairing.
if [[ -z "${PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES:-}" ]]; then
  PREFILL_LEN="1024"
  BATCH_SIZE_PER_DEVICE="160"
  export PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES="${PREFILL_LEN},${BATCH_SIZE_PER_DEVICE}"
fi

BASE_CFG="model_name=llama2-70b tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${CHECKPOINT} scan_layers=false hardware=gpu async_checkpointing=False ici_tensor_parallelism=-1 weight_dtype=bfloat16"
KV_QUANT_CFG="quantize_kvcache=${QUANTIZE_KVCACHE} kv_quant_dtype=${KV_QUANT_DTYPE}"
export MAXENGINE_ARGS="${BASE_CFG} ${KV_QUANT_CFG} optimize_mesh_for_tpu_v6e=false"
echo
echo "${MAXENGINE_ARGS}"
echo

# PREFILL_LEN/BATCH_SIZE_PER_DEVICE may be unset when the caller provided
# PREFILL_LENS_AND_PER_DEVICE_BATCH_SIZES directly; use :- guards.
RUN_DESC="${run_name}_${PREFILL_LEN:-}_${BATCH_SIZE_PER_DEVICE:-}_quant_${QUANTIZATION:-}_${QUANT_MP:-}_kv_${KV_QUANT_DTYPE}_opt"
# NOTE(review): this path mixes "maxtext" and "Maxtext" casing — confirm
# against the container image layout.
export BASEDIR=/opt/maxtext/Maxtext/inference_mlperf/

$cmd cd ..

# Run one benchmark flavor (performance | audit | accuracy) via
# llama_offline_run.sh with the accumulated RUN_OPTIONS.
run_benchmark() {
  local type=$1
  case "$type" in
    "performance")
      # Original passed "-r -benchmarks_performance_..." (stray leading
      # dash, unlike the accuracy invocation); normalized here.
      $cmd bash llama_offline_run.sh ${RUN_OPTIONS} -r "benchmarks_performance_${RUN_DESC}"
      ;;
    "audit")
      $cmd bash llama_offline_run.sh ${RUN_OPTIONS} -r "benchmarks_audit_${RUN_DESC}" -d
      ;;
    "accuracy")
      # Accuracy evaluation compares against the HuggingFace reference.
      export HF_CKPT="meta-llama/Llama-2-70b-chat-hf"
      $cmd bash llama_offline_run.sh ${RUN_OPTIONS} -r "benchmarks_accuracy_${RUN_DESC}" -a
      ;;
  esac
}

if [[ "$benchmark_type" == "all" ]]; then
  run_benchmark "performance"
  run_benchmark "audit"
  run_benchmark "accuracy"
else
  run_benchmark "$benchmark_type"
fi

MaxText/maxengine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,8 @@ def set_engine_vars_from_base_engine(
12301230
"""Set internal vars from base_engine, which has already loaded the checkpoint and has sharding,
12311231
mesh, and kv cache related vars set.
12321232
"""
1233-
engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
1233+
if base_engine.model.quant:
1234+
engine.model.quant.quant_mode = base_engine.model.quant.quant_mode
12341235
engine.state_mesh_annotations = base_engine.state_mesh_annotations
12351236
engine.abstract_params = base_engine.abstract_params
12361237
engine.kv_cache_annotations = max_utils.get_kv_cache_annotations(engine.model, engine.config, rng, engine.mesh) # pylint: disable=protected-access

0 commit comments

Comments
 (0)