
Commit 2638cff

Generalize train script to support other verifier types
Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 475f664 · commit 2638cff

File tree: 1 file changed (+17 −19 lines)

scripts/train_llama3_8b_drafter.py renamed to scripts/train.py

Lines changed: 17 additions & 19 deletions
@@ -24,12 +24,7 @@
 from speculators.train.trainer import Trainer, TrainerConfig
 from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed
 
-# Model
-DRAFT_VOCAB_SIZE = 32000  # Must match t2d and d2t tensors
-TOTAL_SEQ_LEN = 8192
-VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
-HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
-VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size
+# DRAFTER MODEL HYPARAMETERS
 NORM_BEFORE_RESIDUAL = True
 
 # Dataloader
@@ -58,12 +53,12 @@ def setup_dataloader(
 
     dataset = Eagle3SampleFileDataset(
         file_list=file_list,
-        max_len=TOTAL_SEQ_LEN,
+        max_len=args.total_seq_len,
         transform=noise_transform,
         standardize_fn=standardize_fn,
     )
     batch_sampler = MultipackDistributedBatchSamplerV2(
-        batch_max_length=TOTAL_SEQ_LEN,
+        batch_max_length=args.total_seq_len,
         lengths=dataset.approx_lengths,
         num_replicas=world_size,
         rank=local_rank,
@@ -74,7 +69,7 @@ def setup_dataloader(
         num_workers=NUM_WORKERS,
         prefetch_factor=PREFETCH_FACTOR,
         pin_memory=True,
-        collate_fn=create_collate_fn(TOTAL_SEQ_LEN),
+        collate_fn=create_collate_fn(args.total_seq_len),
         persistent_workers=True,
     )
 
@@ -91,15 +86,19 @@ def main(args: argparse.Namespace):
     device = torch.device(local_rank)
 
     # Setup speculator config
-    llama_config = LlamaConfig(
-        hidden_size=HIDDEN_SIZE,
-        vocab_size=VERIFIER_VOCAB_SIZE,
-        num_hidden_layers=args.num_layers,
-    )
+    llama_config = LlamaConfig.from_pretrained(args.verifier_name_or_path)
+    llama_config.num_hidden_layers = args.num_layers
+    llama_config.model_type = "llama"  # reset to llama (handles non-llama verifiers)
     llama_config._attn_implementation = "simple_flex_attention"  # noqa: SLF001
+
+    # Load t2d and d2t tensors
+    d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
+    t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
+    draft_vocab_size = d2t.shape[0]
+
     speculator_config = Eagle3SpeculatorConfig(
         transformer_layer_config=llama_config,
-        draft_vocab_size=DRAFT_VOCAB_SIZE,
+        draft_vocab_size=draft_vocab_size,
         norm_before_residual=NORM_BEFORE_RESIDUAL,
         speculators_config=SpeculatorsConfig(
             algorithm="eagle3",
@@ -111,15 +110,12 @@ def main(args: argparse.Namespace):
             ],
             default_proposal_method="greedy",
             verifier=VerifierConfig(
-                name_or_path=VERIFIER_MODEL_NAME_OR_PATH,
+                name_or_path=args.verifier_name_or_path,
                 architectures=["LlamaForCausalLM"],
             ),
         ),
     )
 
-    # Load t2d and d2t tensors
-    d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
-    t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
 
     # Setup draft model
     draft_model = Eagle3DraftModel(
@@ -165,6 +161,7 @@ def parse_args():
 
 def parse_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--verifier_name_or_path", type=str, required=True)
     parser.add_argument("--data-path", type=str, default="./data")
     parser.add_argument("--save-path", type=str, default="./checkpoints")
     parser.add_argument("--epochs", type=int, default=20)
@@ -176,6 +173,7 @@ def parse_args():
         default="",
         help="One of 'trackio', 'wandb', 'tensorboard' or comma separated list of them",
     )
+    parser.add_argument("--total-seq-len", type=int, default=8192)
    parser.add_argument("--data-format-version", type=int, default=1)
    parser.add_argument("--log-dir", type=str, default="./logs")
    parser.add_argument("--run-name", type=str, default=None)

0 commit comments