
Commit bce7598

FanhaiLu1 authored and jwyang-google committed

Refactor readme (#41)

* Reafactor readme
* update README
* Update README.md

1 parent 735570e · commit bce7598

File tree

2 files changed (+301, -293 lines)


README.md (+12, -293)
@@ -9,306 +9,25 @@
JetStream is a throughput and memory optimized engine for LLM inference on XLA devices, starting with TPUs (and GPUs in future -- PRs welcome).

-## Documentation
-
-- [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](#jetstream-maxtext-inference-on-v5e-cloud-tpu-vm-user-guide)]
-- [Online Inference with Pytorch on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream-pytorch) [[README](https://github.com/google/jetstream-pytorch/tree/main?tab=readme-ov-file#jetstream-pytorch)]
-- [Serve Gemma using TPUs on GKE with JetStream](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream)
-- [JetStream Standalone Local Setup](#jetstream-standalone-local-setup)
-
----
-
-# JetStream MaxText Inference on v5e Cloud TPU VM User Guide
-
-## Outline
-
-1. Prerequisites: Prepare your GCP project and connect to Cloud TPU VM
-2. Download the JetStream and MaxText github repositories
-3. Set up your MaxText JetStream environment
-4. Convert Model Checkpoints
-5. Run the JetStream MaxText server
-6. Send a test request to the JetStream MaxText server
-7. Run benchmarks with the JetStream MaxText server
-8. Clean up
-
-## Prerequisites: Prepare your GCP project and connect to Cloud TPU VM
-
-Follow the steps in [Manage TPU resources | Google Cloud](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm) to create a Cloud TPU VM (recommended TPU type: `v5litepod-8`) and connect to the Cloud TPU VM.
-
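As a convenience, a hypothetical `gcloud` invocation for creating and connecting to a `v5litepod-8` VM is sketched below; the VM name, zone, project, and runtime version are placeholders, so follow the linked guide for values that match your project.

```bash
# Illustrative sketch only -- substitute your own values (see the guide above).
export TPU_NAME=jetstream-demo     # placeholder VM name
export ZONE=your-v5e-zone          # a zone with v5e capacity
export PROJECT_ID=your-gcp-project

gcloud compute tpus tpu-vm create ${TPU_NAME} \
  --zone=${ZONE} \
  --project=${PROJECT_ID} \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite    # check the guide for the current v5e runtime version

# Connect to the VM once it is ready.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} --project=${PROJECT_ID}
```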
-## Step 1: Download the JetStream and MaxText github repositories
-
-```bash
-git clone -b jetstream-v0.2.0 https://github.com/google/maxtext.git
-git clone -b v0.2.0 https://github.com/google/JetStream.git
-```
-
-## Step 2: Set up MaxText
-
-```bash
-# Create a python virtual environment for the demo.
-sudo apt install python3.10-venv
-python -m venv .env
-source .env/bin/activate
-
-# Set up MaxText.
-cd maxtext/
-bash setup.sh
-```
-
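Optionally, before moving on, you can sanity-check that the JAX installation from `setup.sh` can see the TPU chips. A minimal check, assuming the virtual environment created above is still active:

```bash
# Should report 8 devices on a v5litepod-8.
python3 -c "import jax; print(jax.device_count(), jax.devices())"
```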
-## Step 3: Convert Model Checkpoints
-
-You can run the JetStream MaxText Server with Gemma and Llama2 models. This section describes how to run the JetStream MaxText server with various sizes of these models.
-
-### Use a Gemma model checkpoint
-
-* You can download a [Gemma checkpoint from Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText/variations/7b).
-* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`.
-  * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
-  * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
-* Then, use the following command to convert the Gemma checkpoint into a MaxText compatible unscanned checkpoint.
-
-```bash
-# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}
-
-# For gemma-7b
-bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh gemma 7b ${CHKPT_BUCKET}
-```
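For context, `$CHKPT_BUCKET` in the command above is simply the GCS location holding the raw checkpoint you copied earlier. A hypothetical sketch (bucket name and local path are placeholders):

```bash
# Placeholders -- use your own bucket and downloaded checkpoint path.
export CHKPT_BUCKET=gs://your-bucket-name/gemma/7b
gsutil -m cp -r /path/to/downloaded/gemma-7b ${CHKPT_BUCKET}
```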
-
-Note: For more information about the Gemma model and checkpoints, see [About Gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma/Run_Gemma.md).
-
-### Use a Llama2 model checkpoint
-
-* You can use a Llama2 checkpoint you have generated or one from [the open source community](https://llama.meta.com/llama-downloads/).
-* After downloading checkpoints, copy them to your GCS bucket at `$CHKPT_BUCKET`.
-  * `gsutil -m cp -r ${YOUR_CKPT_PATH} ${CHKPT_BUCKET}`
-  * Please refer to the [conversion script](https://github.com/google/JetStream/blob/main/jetstream/tools/maxtext/model_ckpt_conversion.sh) for an example of `$CHKPT_BUCKET`.
-* Then, use the following command to convert the Llama2 checkpoint into a MaxText compatible unscanned checkpoint.
-
-```bash
-# bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh ${MODEL} ${MODEL_VARIATION} ${CHKPT_BUCKET}
-
-# For llama2-7b
-bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 7b ${CHKPT_BUCKET}
-
-# For llama2-13b
-bash ../JetStream/jetstream/tools/maxtext/model_ckpt_conversion.sh llama2 13b ${CHKPT_BUCKET}
-```
-
-Note: For more information about the Llama2 model and checkpoints, see [About Llama2](https://github.com/google/maxtext/blob/main/getting_started/Run_Llama2.md).
-
-## Step 4: Run the JetStream MaxText server
-
-### Create model config environment variables for server flags
-
-You can export the following environment variables based on the model you used.
-
-* You can copy and export the `UNSCANNED_CKPT_PATH` from the model\_ckpt\_conversion.sh output.
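For instance, the export might look like the following (the path is a placeholder; copy the actual value printed by the conversion script):

```bash
# Placeholder path -- use the exact value printed by model_ckpt_conversion.sh.
export UNSCANNED_CKPT_PATH=gs://your-output-bucket/path/printed/by/the/conversion/script
```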
-
-#### Create Gemma-7b environment variables for server flags
-
+## JetStream Engine Implementation

+Currently, there are two reference engine implementations available -- one for Jax models and another for Pytorch models.

-* Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passed to the JetStream MaxText server
+### Jax

-```bash
-export TOKENIZER_PATH=assets/tokenizer.gemma
-export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
-export MAX_PREFILL_PREDICT_LENGTH=1024
-export MAX_TARGET_LENGTH=2048
-export MODEL_NAME=gemma-7b
-export ICI_FSDP_PARALLELISM=1
-export ICI_AUTOREGRESSIVE_PARALLELISM=-1
-export ICI_TENSOR_PARALLELISM=1
-export SCAN_LAYERS=false
-export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=4
-```
-
-#### Create Llama2-7b environment variables for server flags
-
-* Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passed to the JetStream MaxText server
-
-```bash
-export TOKENIZER_PATH=assets/tokenizer.llama2
-export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
-export MAX_PREFILL_PREDICT_LENGTH=1024
-export MAX_TARGET_LENGTH=2048
-export MODEL_NAME=llama2-7b
-export ICI_FSDP_PARALLELISM=1
-export ICI_AUTOREGRESSIVE_PARALLELISM=-1
-export ICI_TENSOR_PARALLELISM=1
-export SCAN_LAYERS=false
-export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=6
-```
-
-#### Create Llama2-13b environment variables for server flags
-
-* Configure the [flags](#jetstream-maxtext-server-flag-descriptions) passed to the JetStream MaxText server
-
-```bash
-export TOKENIZER_PATH=assets/tokenizer.llama2
-export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH}
-export MAX_PREFILL_PREDICT_LENGTH=1024
-export MAX_TARGET_LENGTH=2048
-export MODEL_NAME=llama2-13b
-export ICI_FSDP_PARALLELISM=1
-export ICI_AUTOREGRESSIVE_PARALLELISM=-1
-export ICI_TENSOR_PARALLELISM=1
-export SCAN_LAYERS=false
-export WEIGHT_DTYPE=bfloat16
-export PER_DEVICE_BATCH_SIZE=2
-```
-
-### Run the following command to start the JetStream MaxText server
-
-```bash
-cd ~/maxtext
-python MaxText/maxengine_server.py \
-  MaxText/configs/base.yml \
-  tokenizer_path=${TOKENIZER_PATH} \
-  load_parameters_path=${LOAD_PARAMETERS_PATH} \
-  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
-  max_target_length=${MAX_TARGET_LENGTH} \
-  model_name=${MODEL_NAME} \
-  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
-  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
-  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
-  scan_layers=${SCAN_LAYERS} \
-  weight_dtype=${WEIGHT_DTYPE} \
-  per_device_batch_size=${PER_DEVICE_BATCH_SIZE}
-```
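Once the server is up, you can optionally confirm from another shell that something is listening on the port the Step 5 test client targets (9000). One way to do so, assuming `nc` is available on the VM:

```bash
# Succeeds once the JetStream server is accepting connections on port 9000.
nc -z 0.0.0.0 9000 && echo "server is up"
```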
-
-### JetStream MaxText Server flag descriptions:
-
-* tokenizer\_path: file path to a tokenizer (should match your model)
-* load\_parameters\_path: Loads the parameters (no optimizer states) from a specific directory
-* per\_device\_batch\_size: decoding batch size per device (1 TPU chip = 1 device)
-* max\_prefill\_predict\_length: Maximum length for the prefill when doing autoregression
-* max\_target\_length: Maximum sequence length
-* model\_name: Model name
-* ici\_fsdp\_parallelism: The number of shards for FSDP parallelism
-* ici\_autoregressive\_parallelism: The number of shards for autoregressive parallelism
-* ici\_tensor\_parallelism: The number of shards for tensor parallelism
-* weight\_dtype: Weight data type (e.g. bfloat16)
-* scan\_layers: Scan layers boolean flag
-
-Note: these flags are from [MaxText config](https://github.com/google/maxtext/blob/f9e04cdc1eec74a0e648411857c09403c3358461/MaxText/configs/base.yml)
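For example, on a `v5litepod-8` (8 chips, so 8 devices), `per_device_batch_size=4` corresponds to a global decoding batch of roughly 8 × 4 = 32 concurrent sequences.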
-
-## Step 5: Send a test request to the JetStream MaxText server
-
-```bash
-cd ~
-python JetStream/jetstream/tools/requester.py
-```
+- Git: https://github.com/google/maxtext
+- README: https://github.com/google/jetstream/blob/main/max_text/README.md

-The output will be similar to the following:
+### Pytorch

-```bash
-Sending request to: 0.0.0.0:9000
-Prompt: Today is a good day
-Response: to be a fan
-```
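The client defaults shown above target the server on port 9000. If you changed the server address or want to send a different prompt, run `python JetStream/jetstream/tools/requester.py --help` to see which flags the tool accepts in your checked-out version.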
-
-## Step 6: Run benchmarks with the JetStream MaxText server
-
-Note: The JetStream MaxText server is not running with quantization optimization in Step 3. To get the best benchmark results, enable quantization for both the weights and the KV cache (use AQT-trained or fine-tuned checkpoints to ensure accuracy), then add the quantization flags and restart the server as follows:
-
-```bash
-# Enable int8 quantization for both weights and KV cache
-export QUANTIZATION=int8
-export QUANTIZE_KVCACHE=true
-
-# For the Gemma 7b model, change per_device_batch_size to 12 to optimize performance.
-export PER_DEVICE_BATCH_SIZE=12
-
-cd ~/maxtext
-python MaxText/maxengine_server.py \
-  MaxText/configs/base.yml \
-  tokenizer_path=${TOKENIZER_PATH} \
-  load_parameters_path=${LOAD_PARAMETERS_PATH} \
-  max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \
-  max_target_length=${MAX_TARGET_LENGTH} \
-  model_name=${MODEL_NAME} \
-  ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \
-  ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \
-  ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
-  scan_layers=${SCAN_LAYERS} \
-  weight_dtype=${WEIGHT_DTYPE} \
-  per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \
-  quantization=${QUANTIZATION} \
-  quantize_kvcache=${QUANTIZE_KVCACHE}
-```
-
-### Benchmarking Gemma-7b
-
-Instructions:
-- Download the ShareGPT dataset.
-- Make sure to use the Gemma tokenizer (tokenizer.gemma) when running Gemma 7b.
-- Add the `--warmup-first` flag for your first run to warm up the server.
+- Git: https://github.com/google/jetstream-pytorch
+- README: https://github.com/google/jetstream-pytorch/blob/main/README.md

-```bash
-# Activate the python virtual environment we created in Step 2.
-cd ~
-source .env/bin/activate
-
-# Download the dataset.
-wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
-
-# Run the benchmark with the downloaded dataset and the tokenizer in maxtext.
-# You can control the qps by setting `--request-rate`; the default value is inf.
-python JetStream/benchmarks/benchmark_serving.py \
-  --tokenizer /home/$USER/maxtext/assets/tokenizer.gemma \
-  --num-prompts 1000 \
-  --dataset sharegpt \
-  --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
-  --max-output-length 1024 \
-  --request-rate 5 \
-  --warmup-first true
-```
-
-### Benchmarking Llama2-\*b
-
-```bash
-# Same as Gemma-7b except for the tokenizer (must use a tokenizer that matches your model, which should now be tokenizer.llama2).
-
-python JetStream/benchmarks/benchmark_serving.py \
-  --tokenizer maxtext/assets/tokenizer.llama2 \
-  --num-prompts 1000 \
-  --dataset sharegpt \
-  --dataset-path ~/ShareGPT_V3_unfiltered_cleaned_split.json \
-  --max-output-length 1024 \
-  --request-rate 5 \
-  --warmup-first true
-```
-
-## Clean Up
-
-```bash
-# Clean up GCS buckets.
-gcloud storage buckets delete ${MODEL_BUCKET}
-gcloud storage buckets delete ${BASE_OUTPUT_DIRECTORY}
-gcloud storage buckets delete ${DATASET_PATH}
-# Clean up repositories.
-rm -rf maxtext
-rm -rf JetStream
-# Clean up the python virtual environment.
-rm -rf .env
-```
+- [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](#jetstream-maxtext-inference-on-v5e-cloud-tpu-vm-user-guide)]
+- [Online Inference with Pytorch on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream-pytorch) [[README](https://github.com/google/jetstream-pytorch/tree/main?tab=readme-ov-file#jetstream-pytorch)]
+- [Serve Gemma using TPUs on GKE with JetStream](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream)
+- [JetStream Standalone Local Setup](#jetstream-standalone-local-setup)

----

# JetStream Standalone Local Setup
