
Commit d2e8e62

2D working without TP self attention
1 parent a4cc35f commit d2e8e62

File tree

1 file changed, +78 -0 lines changed

examples/llama/2d_llama.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
# $ torchrun --nproc-per-node 8 2d_llama.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
from torch.distributed._tensor import init_device_mesh

# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)  # bs = 8
tokenizer.pad_token = tokenizer.eos_token
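# Llama-2's tokenizer ships without a pad token, so EOS is reused above for batch padding.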

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

pp_group_size = 2
tp_group_size = 4
mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()
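# Rank layout of the (2, 4) mesh: ranks 0-3 form the 4-way "tp" group of
# pipeline stage 0 and ranks 4-7 that of stage 1; mesh_2d["pp"] gives each
# rank its 2-way pipeline group (e.g. {0, 4}, {1, 5}, ...).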

llama.to(device).eval()
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

# Cut model by equal number of layers per rank
layers_per_stage = llama.config.num_hidden_layers // pp_group_size
for i in range(1, pp_group_size):
    annotate_split_points(llama,
        {f"model.layers.{i * layers_per_stage}": PipeSplitWrapper.SplitPoint.BEGINNING})

# Create a pipeline representation from the model
llama_pipe = Pipe.from_tracing(llama, pp_group_size, example_args=(inputs["input_ids"],))
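# Tracing splits the model at the annotated points into one submodule per
# pipeline stage, each holding layers_per_stage decoder layers.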

# Create pipeline stage for each rank
stage_idx = rank // tp_group_size
stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
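# rank // tp_group_size maps ranks 0-3 to stage 0 and ranks 4-7 to stage 1,
# matching the mesh layout above.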

# Tensor parallel
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
starting_layer = stage_idx * layers_per_stage
plan = {}
for i in range(layers_per_stage):
    # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
    extra = "_mod" if starting_layer > 0 and i == 0 else ""
    layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
    plan.update({
        # Parallel self attention not working yet due to the dimension mismatch
        # after TP in view operation
        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
        f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
        f"{layer_name}_mlp_up_proj": ColwiseParallel(),
        f"{layer_name}_mlp_down_proj": RowwiseParallel(),
    })
tp_mesh = mesh_2d["tp"]
parallelize_module(stage.submod, tp_mesh, plan)
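# Megatron-style MLP sharding: gate/up projections are column-parallel (their
# outputs are split across the 4 TP ranks), down_proj is row-parallel and
# all-reduces its output back to a full tensor.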

# Run
if stage_idx == 0:
    args = inputs["input_ids"]
else:
    args = None
output = stage(args)
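# Only the first stage is fed real token ids; the other stage receives its
# input activations from the previous stage over the pipeline group, and only
# the last stage returns logits (hence the None check below).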

# Decode
if output is not None:
    next_token_logits = output[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))
