
Commit f3cf2b7

Add Gemma 2b benchmark; fix a typo. (#81)
1 parent 776c1c4 commit f3cf2b7

File tree

3 files changed: +23 -2 lines changed

README.md
benchmarks/summary.md
jetstream_pt/environment.py

README.md

Lines changed: 11 additions & 1 deletion

@@ -109,7 +109,17 @@ NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 f
 ```bash
 python run_server.py --param_size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
 ```
-Now you can fire gRPC to it
+
+Now you can fire gRPC requests to it.
+
+Optional flags:
+* `--shard_on_batch=1` This makes the model shard on
+the batch dimension, i.e. it runs in data parallel mode instead of model
+parallel mode. This will ignore the sharding config. It is recommended for the Gemma 2B
+model, because Gemma 2B is small enough to fit on a single TPU chip.
+
+* `--sharding_config=<path>` This uses an alternative sharding config instead of
+the ones in the default_shardings directory.
 
 # Run benchmark
 go to the deps/JetStream folder (downloaded during `install_everything.sh`)
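
For context on what `--shard_on_batch=1` changes, below is a minimal JAX sketch (not taken from this repository) contrasting sharding an array on the batch dimension with sharding it on a model (hidden) dimension over one flat mesh axis. The mesh axis name, array shape, and device count are illustrative assumptions.

```python
# Minimal sketch, assuming a single flat mesh axis named "axis" and shapes
# that divide evenly across the available devices; not JetStream-pt code.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("axis",))  # e.g. 8 chips on a v5e-8
activations = jnp.zeros((16, 128, 256))  # (batch, seq, hidden), illustrative sizes

# Model-parallel style: split the hidden dimension, replicate the batch.
model_sharded = jax.device_put(
    activations, NamedSharding(mesh, P(None, None, "axis")))

# Data-parallel style (what --shard_on_batch=1 requests): split the batch
# dimension, so each chip keeps a full copy of the model weights.
batch_sharded = jax.device_put(
    activations, NamedSharding(mesh, P("axis", None, None)))

print(model_sharded.sharding)
print(batch_sharded.sharding)
```

Because Gemma 2B fits on a single chip, replicating its weights and splitting only the batch avoids cross-chip communication for the model math, which is why the README recommends this flag for that model.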

benchmarks/summary.md

Lines changed: 11 additions & 0 deletions

@@ -16,6 +16,17 @@ Date | Device | dtype | batch size | cache length |max input length |max output
 2024-05-10 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3236
 2024-05-10 | TPU v5e-8 | int8 | 128 | 2048 | 1024 | 1024 | 4695
 
+## Gemma - 2B
+
+Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)
+----| ------- | ------ |---------- | -------------|-----------------|------------------|----------------------
+2024-05-14 | TPU v5e-8 | bfloat16 | 512 | 2048 | 1024 | 1024 | 8700
+2024-05-14 | TPU v5e-8 | int8 | 1024 | 2048 | 1024 | 1024 | 8746
+
+**NOTE:** Gemma 2B uses the `--shard_on_batch` flag, so it's data parallel instead
+of model parallel.
+
+
 ## Llama 2 - 7B
 
 Date | Device | dtype | batch size | cache length |max input length |max output length| throughput (token/s)

jetstream_pt/environment.py

Lines changed: 1 addition & 1 deletion

@@ -176,7 +176,7 @@ def make_caches_generate(self):
   def sharding_by_name(self, name):
     """Create sharding specified in the config."""
     if self.shard_on_batch:
-      return self.shading_by_axis(0)  # batch dimension
+      return self.sharding_by_axis(0)  # batch dimension
 
     if name in self._sharding_config:
       return self.sharding_by_axis(self._sharding_config[name])
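
The fix above calls `self.sharding_by_axis(0)` instead of the misspelled `shading_by_axis`. As a rough illustration of what an axis-indexed sharding helper of this kind could look like, here is a small JAX sketch; it is an assumption for illustration, not the actual jetstream_pt implementation, and the class name, mesh axis name, and fixed-rank partition spec are made up.

```python
# Hypothetical sketch of an axis-indexed sharding helper; the real
# jetstream_pt.environment implementation may differ.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


class ShardingEnvSketch:
  """Toy stand-in for an environment object that owns a device mesh."""

  def __init__(self, rank: int = 2):
    self.mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    self.rank = rank  # rank of the tensors this sharding will be applied to

  def sharding_by_axis(self, axis: int) -> NamedSharding:
    # Split dimension `axis` across the mesh, replicate every other dimension.
    spec = [None] * self.rank
    spec[axis] = "x"
    return NamedSharding(self.mesh, PartitionSpec(*spec))


env = ShardingEnvSketch()
print(env.sharding_by_axis(0))  # axis 0 = batch dimension, i.e. data parallel
```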
