
Commit cf74348

adding deferred init as Ke advised
1 parent f3daad1 commit cf74348

File tree

1 file changed: examples/llama/2d_llama.py

Lines changed: 5 additions & 4 deletions
@@ -27,7 +27,8 @@ def modify_view(
 
 # Grab the model
 llama = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
+    torch_dtype=torch.float16
 )
 tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
 
@@ -46,8 +47,8 @@ def modify_view(
 mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
 pp_group = mesh_2d["pp"].get_group()
 
-llama.to(device).eval()
-inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+llama.eval()
+inputs = tokenizer(prompts, return_tensors="pt", padding=True)
 
 # Cut model by equal number of layers per rank
 layers_per_stage = llama.config.num_hidden_layers // pp_group_size
@@ -90,7 +91,7 @@ def modify_view(
 parallelize_module(
     stage.submod, tp_mesh, {**attn_plan, **mlp_plan}
 )
-
+inputs = inputs.to(device)
 # Run
 if stage_idx == 0:
     args = inputs["input_ids"]
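
For context, the change defers device placement: the model is now loaded on CPU in float16 and stays there while the pipeline stage is cut and tensor parallelism is applied; only the inputs are moved to the GPU right before the run. Below is a minimal sketch of that pattern using only the Hugging Face and torch calls visible in the diff. The prompt, the pad-token workaround, and the device selection are illustrative assumptions, and the PiPPy stage splitting / parallelize_module steps are elided.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Deferred init: load on CPU in fp16; low_cpu_mem_usage avoids materializing
# an extra full-precision copy of the weights.
llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
llama.eval()

# Tokenize on CPU as well; nothing has touched the GPU yet.
prompts = ["How do you like PyTorch?"]           # placeholder prompt, not from the commit
tokenizer.pad_token = tokenizer.eos_token        # assumption: Llama's tokenizer defines no pad token
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# ... split llama into pipeline stages and apply parallelize_module to the
# local stage here, so each rank only materializes its own shard on its GPU ...

# Assumes the per-rank CUDA device has already been selected elsewhere.
device = torch.device("cuda", torch.cuda.current_device())
inputs = inputs.to(device)   # move only the inputs once the local shard is ready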
