
Commit 648bf48

Add llama-3 instructions to readme (#79)
* Add llama-3 instructions to readme
* Use the same sharding for llama-2 and llama-3
1 parent: 51647b3

File tree

* README.md
* default_shardings/llama-2.yaml (renamed to default_shardings/llama.yaml)
* jetstream_pt/engine.py

3 files changed: +17 −12 lines


README.md

Lines changed: 13 additions & 7 deletions
@@ -46,7 +46,9 @@ NOTE: the above script will export PYTHONPATH, so sourcing will make it to take
 ## LLaMA
 ### Get official llama weights from meta-llama
 
-Following instructions here: https://github.com/meta-llama/llama#download
+Following instructions here:
+* Llama-2: https://github.com/meta-llama/llama#download
+* Llama-3: https://github.com/meta-llama/llama3/#download
 
 After you have downloaded the weights, it will also download a `tokenizer.model` file that is
 the tokenizer that we will use.
@@ -68,7 +70,7 @@ Need to manually modify the `config.json` in the checkpoint folder to make it a
 export input_ckpt_dir=Original llama weights directory
 export output_ckpt_dir=The output directory
 export quantize=True #whether to quantize
-export model_name="llama-2" # or "gemma"
+export model_name="llama-3" # or "llama-2", "gemma"
 python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
 ```

@@ -80,16 +82,20 @@ Set tokenizer path
 export tokenizer_path=tokenizer model file path
 ```
 
-## Llama 7b
+## Llama-2 7b
 ```bash
-python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```
 
-## Llama 13b
+## Llama-2 13b
 ```bash
-python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```
 
+## Llama-3 8b
+```bash
+python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+```
 
 ## Gemma 7b
 ```bash
@@ -101,7 +107,7 @@ python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --m
 NOTE: the `--platform=tpu=8` need to specify number of tpu devices (which is 4 for v4-8 and 8 for v5light-8`)
 
 ```bash
-python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
+python run_server.py --param_size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
 ```
 Now you can fire gRPC to it
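If you are unsure which device count to pass to `--platform`, one quick check is to ask JAX directly. This is only a convenience sketch, assuming JAX is already installed on the TPU VM; it is not a command taken from this repository:

```python
# Prints the number of attached TPU devices,
# e.g. 4 on a v4-8 or 8 on a v5light-8.
import jax

print(jax.device_count())
```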

default_shardings/llama-2.yaml renamed to default_shardings/llama.yaml

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 
 
 freqs_cis : -1 # torch.complex64 (2048, 64)
-tok_embeddings.weight : 1 # torch.float32 (32000, 4096)
+tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096)
 tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
 layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,)
@@ -24,5 +24,5 @@ layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
 layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
 norm.weight : -1 # torch.float32 (4096,)
-output.weight : 0 # torch.float32 (32000, 4096)
+output.weight : 0 # torch.float32 (vocab_size, 4096)
 output.weight_scaler : 0 # torch.float32 (4096,)
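For readers new to these sharding files: each key is a glob pattern over parameter names and each value is the tensor dimension to shard across devices, with -1 meaning the tensor is replicated. The rename works because only the shape comments differed between Llama-2 and Llama-3 (their vocabulary sizes differ); the axes themselves are identical, which is what the commit message means by "use the same sharding". Below is a minimal sketch of how such a file could be interpreted; it assumes PyYAML is available and is not the loader jetstream_pt actually uses:

```python
# Sketch only: resolve the shard axis for a parameter name from a sharding
# YAML like default_shardings/llama.yaml (-1 means "replicate").
import fnmatch

import yaml  # assumes PyYAML is installed


def shard_axis(param_name, rules):
  """Return the shard dimension for `param_name`, or -1 for replicated."""
  for pattern, axis in rules.items():
    if fnmatch.fnmatch(param_name, pattern):
      return axis
  raise KeyError(f"no sharding rule matches {param_name}")


with open("default_shardings/llama.yaml") as f:
  rules = yaml.safe_load(f)

print(shard_axis("layers.0.attention.wo.weight", rules))  # 1  -> shard dim 1
print(shard_axis("norm.weight", rules))                   # -1 -> replicated
```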

jetstream_pt/engine.py

Lines changed: 2 additions & 3 deletions
@@ -701,7 +701,6 @@ def create_pytorch_engine(
     checkpoint_format = "safetensors"
     checkpoint_path = paths[0]
 
-  tokenizer = token_utils.load_vocab(tokenizer_path)
   pt_model = None
 
   if not sharding_config:
@@ -734,7 +733,7 @@
         max_cache_length,
         args.dim // args.n_heads,
     )
-    env_data.model_type = "llama-2-" + param_size
+    env_data.model_type = model_name + "-" + param_size
     env_data.num_layers = args.n_layers
     env = JetEngineEnvironment(env_data)
     pt_model = model_exportable.Transformer(args, env)
@@ -746,7 +745,7 @@
         max_cache_length,
         args.head_dim,
     )
-    env_data.model_type = "gemma-" + param_size
+    env_data.model_type = model_name + "-" + param_size
     env_data.num_layers = args.num_hidden_layers
     env = JetEngineEnvironment(env_data)
     pt_model = gemma_model.GemmaModel(args, env)
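Net effect of the engine.py changes: the standalone `token_utils.load_vocab` call is removed, and the model-type tag is derived from the `model_name` flag instead of being hard-coded per model family. A tiny illustration (plain Python, names taken from the diff):

```python
# The new assignment `env_data.model_type = model_name + "-" + param_size`
# yields the same tags as before for llama-2 and gemma, and now also covers
# llama-3 without a separate branch.
for model_name, param_size in [("llama-2", "7b"), ("llama-3", "8b"), ("gemma", "7b")]:
  print(model_name + "-" + param_size)  # llama-2-7b, llama-3-8b, gemma-7b
```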
