# $ torchrun --nproc-per-node 8 2d_llama.py
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
from torch.distributed._tensor import init_device_mesh

# Grab the model
llama = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)  # bs = 8
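# Llama-2's tokenizer ships without a pad token; reuse EOS so the batch of
# prompts can be padded to a common length.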
tokenizer.pad_token = tokenizer.eos_token

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

pp_group_size = 2
tp_group_size = 4
mesh_2d = init_device_mesh("cuda", (pp_group_size, tp_group_size), mesh_dim_names=("pp", "tp"))
pp_group = mesh_2d["pp"].get_group()
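# The 8 ranks form a 2 x 4 device mesh: the outer "pp" dimension gives
# 2 pipeline stages, the inner "tp" dimension gives 4-way tensor parallelism
# inside each stage (ranks 0-3 -> stage 0, ranks 4-7 -> stage 1).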

llama.to(device).eval()
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

# Cut the model into an equal number of decoder layers per pipeline stage
layers_per_stage = llama.config.num_hidden_layers // pp_group_size
for i in range(1, pp_group_size):
    annotate_split_points(llama,
        {f"model.layers.{i * layers_per_stage}": PipeSplitWrapper.SplitPoint.BEGINNING})
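# For Llama-2-7B (32 decoder layers) and pp_group_size = 2 this marks a single
# split point before model.layers.16, i.e. 16 layers per stage.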

# Create a pipeline representation from the model
llama_pipe = Pipe.from_tracing(llama, pp_group_size, example_args=(inputs["input_ids"],))

# Create a pipeline stage for each rank
stage_idx = rank // tp_group_size
stage = PipelineStage(llama_pipe, stage_idx, device=device, group=pp_group)
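# Each rank materializes only its own stage's submodule on its GPU and
# exchanges activations with the neighbouring stage over pp_group.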

# Tensor parallelism within each pipeline stage
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel
starting_layer = stage_idx * layers_per_stage
plan = {}
for i in range(layers_per_stage):
    # HACK: the right fix is to remove the ".mod" added by PipeSplitWrapper
    extra = "_mod" if starting_layer > 0 and i == 0 else ""
    layer_name = f"L__self___model_layers_{starting_layer + i}{extra}"
    plan.update({
        # Parallelizing self-attention does not work yet because of a dimension
        # mismatch in the view op after TP
        #f"{layer_name}_self_attn_q_proj": ColwiseParallel(),
        #f"{layer_name}_self_attn_k_proj": ColwiseParallel(),
        #f"{layer_name}_self_attn_v_proj": ColwiseParallel(),
        #f"{layer_name}_self_attn_o_proj": RowwiseParallel(),
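        # The MLP is sharded in the usual Megatron pattern: gate/up projections
        # are column-parallel, down_proj is row-parallel, so its output is
        # all-reduced back to the full hidden size across the "tp" dimension.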
        f"{layer_name}_mlp_gate_proj": ColwiseParallel(),
        f"{layer_name}_mlp_up_proj": ColwiseParallel(),
        f"{layer_name}_mlp_down_proj": RowwiseParallel(),
    })
tp_mesh = mesh_2d["tp"]
parallelize_module(stage.submod, tp_mesh, plan)

# Run
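# Only the first stage feeds in input_ids; activations flow to the next stage
# over the "pp" group, and only the last stage gets a non-None output (logits).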
if stage_idx == 0:
    args = inputs["input_ids"]
else:
    args = None
output = stage(args)

# Decode
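# Greedy single-step decoding on the last stage: take each prompt's logits at
# the final position and pick the argmax token.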
if output is not None:
    next_token_logits = output[0][:, -1, :]
    next_token = torch.argmax(next_token_logits, dim=-1)
    print(tokenizer.batch_decode(next_token))