 from speculators.train.trainer import Trainer, TrainerConfig
 from speculators.train.utils import maybe_destroy_distributed, maybe_setup_distributed
 
-# Model
-DRAFT_VOCAB_SIZE = 32000  # Must match t2d and d2t tensors
-TOTAL_SEQ_LEN = 8192
-VERIFIER_MODEL_NAME_OR_PATH = "meta-llama/Llama-3.1-8B-Instruct"
-HIDDEN_SIZE = 4096  # Must match the verifier model's hidden size
-VERIFIER_VOCAB_SIZE = 128256  # Must match the verifier model's vocab size
+# DRAFTER MODEL HYPERPARAMETERS
 NORM_BEFORE_RESIDUAL = True
 
 # Dataloader
@@ -58,12 +53,12 @@ def setup_dataloader(
 
     dataset = Eagle3SampleFileDataset(
         file_list=file_list,
-        max_len=TOTAL_SEQ_LEN,
+        max_len=args.total_seq_len,
         transform=noise_transform,
         standardize_fn=standardize_fn,
     )
     batch_sampler = MultipackDistributedBatchSamplerV2(
-        batch_max_length=TOTAL_SEQ_LEN,
+        batch_max_length=args.total_seq_len,
         lengths=dataset.approx_lengths,
         num_replicas=world_size,
         rank=local_rank,
@@ -74,7 +69,7 @@ def setup_dataloader(
         num_workers=NUM_WORKERS,
         prefetch_factor=PREFETCH_FACTOR,
         pin_memory=True,
-        collate_fn=create_collate_fn(TOTAL_SEQ_LEN),
+        collate_fn=create_collate_fn(args.total_seq_len),
         persistent_workers=True,
     )
 
@@ -91,15 +86,19 @@ def main(args: argparse.Namespace):
     device = torch.device(local_rank)
 
     # Setup speculator config
-    llama_config = LlamaConfig(
-        hidden_size=HIDDEN_SIZE,
-        vocab_size=VERIFIER_VOCAB_SIZE,
-        num_hidden_layers=args.num_layers,
-    )
+    llama_config = LlamaConfig.from_pretrained(args.verifier_name_or_path)
+    llama_config.num_hidden_layers = args.num_layers
+    llama_config.model_type = "llama"  # reset to llama (handles non-llama verifiers)
     llama_config._attn_implementation = "simple_flex_attention"  # noqa: SLF001
+
+    # Load t2d and d2t tensors
+    d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
+    t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
+    draft_vocab_size = d2t.shape[0]  # d2t has one entry per draft-vocab token
+
     speculator_config = Eagle3SpeculatorConfig(
         transformer_layer_config=llama_config,
-        draft_vocab_size=DRAFT_VOCAB_SIZE,
+        draft_vocab_size=draft_vocab_size,
         norm_before_residual=NORM_BEFORE_RESIDUAL,
         speculators_config=SpeculatorsConfig(
             algorithm="eagle3",
@@ -111,15 +110,12 @@ def main(args: argparse.Namespace):
             ],
             default_proposal_method="greedy",
             verifier=VerifierConfig(
-                name_or_path=VERIFIER_MODEL_NAME_OR_PATH,
+                name_or_path=args.verifier_name_or_path,
                 architectures=["LlamaForCausalLM"],
             ),
         ),
     )
 
-    # Load t2d and d2t tensors
-    d2t = torch.from_numpy(np.load(args.d2t_path)).to(device)
-    t2d = torch.from_numpy(np.load(args.t2d_path)).to(device)
 
     # Setup draft model
     draft_model = Eagle3DraftModel(
@@ -165,6 +161,7 @@ def main(args: argparse.Namespace):
 
 def parse_args():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--verifier-name-or-path", type=str, required=True)
     parser.add_argument("--data-path", type=str, default="./data")
     parser.add_argument("--save-path", type=str, default="./checkpoints")
     parser.add_argument("--epochs", type=int, default=20)
@@ -176,6 +173,7 @@ def parse_args():
         default="",
         help="One of 'trackio', 'wandb', 'tensorboard', or a comma-separated list of them",
     )
+    parser.add_argument("--total-seq-len", type=int, default=8192)
     parser.add_argument("--data-format-version", type=int, default=1)
     parser.add_argument("--log-dir", type=str, default="./logs")
     parser.add_argument("--run-name", type=str, default=None)
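Example invocation with the flags this commit adds (a sketch, not from the commit itself: the script path and the torchrun launcher are assumptions, and only flags visible in this diff are shown; the d2t/t2d paths come from arguments defined elsewhere in the file):

    # hypothetical script name; launcher assumed since the script calls maybe_setup_distributed
    torchrun --nproc-per-node=8 train_eagle3.py \
        --verifier-name-or-path meta-llama/Llama-3.1-8B-Instruct \
        --total-seq-len 8192 \
        --data-path ./data \
        --save-path ./checkpoints \
        --epochs 20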