3 files changed: +23 −2

@@ -109,7 +109,17 @@ NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 f
 ```bash
 python run_server.py --param_size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
 ```
-Now you can fire gRPC to it
+
+Now you can send gRPC requests to it.
+
+Optional flags:
+* `--shard_on_batch=1` Shards the model on
+  the batch dimension, i.e. it runs in data-parallel mode instead of model-parallel
+  mode, ignoring the sharding config. This is recommended for the Gemma 2B
+  model, because Gemma 2B is small enough to fit on a single TPU chip.
+
+* `--sharding_config=<path>` Uses an alternative sharding config instead of
+  the ones in the `default_shardings` directory.

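As an illustrative sketch only (not part of this change): a Gemma 2B launch in data-parallel mode would reuse the command above and add `--shard_on_batch=1`. Here `--param_size=2b` and `--batch_size=512` are assumed values (the batch size mirrors the Gemma 2B benchmark table below), and the `$...` variables are assumed to point at a Gemma 2B checkpoint and tokenizer.

```bash
# Illustrative only: --param_size=2b and --batch_size=512 are assumed values;
# --shard_on_batch=1 is the flag documented above.
python run_server.py --param_size=2b --model_name=$model_name --batch_size=512 \
  --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize \
  --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path \
  --platform=tpu=8 --model=$model_name --shard_on_batch=1
```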
 # Run benchmark
 go to the deps/JetStream folder (downloaded during `install_everything.sh`)

@@ -16,6 +16,17 @@ Date | Device | dtype | batch size | cache length |max input length |max output
 2024-05-10 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3236
 2024-05-10 | TPU v5e-8 | int8 | 128 | 2048 | 1024 | 1024 | 4695

+## Gemma - 2B
+
+Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
+----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
+2024-05-14 | TPU v5e-8 | bfloat16 | 512 | 2048 | 1024 | 1024 | 8700
+2024-05-14 | TPU v5e-8 | int8 | 1024 | 2048 | 1024 | 1024 | 8746
+
+**NOTE:** Gemma 2B uses the `--shard_on_batch` flag, so it runs data-parallel
+instead of model-parallel.
+
 ## Llama 2 - 7B

 Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)

@@ -176,7 +176,7 @@ def make_caches_generate(self):
   def sharding_by_name(self, name):
     """Create sharding specified in the config."""
     if self.shard_on_batch:
-      return self.shading_by_axis(0)  # batch dimension
+      return self.sharding_by_axis(0)  # batch dimension

     if name in self._sharding_config:
       return self.sharding_by_axis(self._sharding_config[name])
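For background (not taken from this repository's code): a minimal JAX sketch of what sharding along axis 0, the batch dimension, means on a 1-D device mesh. The repository's `sharding_by_axis` helper may be implemented differently; this only illustrates the data-parallel layout that `--shard_on_batch` selects.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative sketch, not the repository's sharding_by_axis implementation.
# 1-D mesh over all available devices (e.g. 8 TPU chips).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

# Shard axis 0 (the batch dimension) across the mesh; other axes stay replicated.
batch_sharding = NamedSharding(mesh, P("x", None))

x = jax.numpy.zeros((128, 2048))       # (batch, sequence) sized buffer
x = jax.device_put(x, batch_sharding)  # each device holds a slice of the batch
print(x.sharding)
```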