diff --git a/src/levanter/main/sft.py b/src/levanter/main/sft.py
index 4bcae2ce0..f63b8a4e9 100644
--- a/src/levanter/main/sft.py
+++ b/src/levanter/main/sft.py
@@ -186,7 +186,6 @@ def train(config: SFTConfig):
 
     # some axes we need
     Pos = config.model.Pos
-
     # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
     # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
     # tokens: gpt-2 has 50257, for example. So we round up.
@@ -210,12 +209,12 @@ def train(config: SFTConfig):
         callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
     )
 
-    # reshuffle the examples before packing!
+    # TODO: reshuffle the examples before packing!
+    # Get current step from trainer state
+    current_step = int(state.step)
 
-    # to implement seeking
-    # check the step number in the trainer state if it's not zero
-    # then next the iterator until we get there, then continue training.
-    # batch size will be backed in from config
+
+    # change iterate tokenized requests to take a dict rather than a list
     # of where the first element is prompt ands econd is response
 
 
@@ -227,6 +226,22 @@ def train(config: SFTConfig):
     logger.info("Creating prompt completion iterator")
     prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos)
 
+    if current_step > 0:
+        logger.info(f"Resuming training from step {current_step}")
+        # Calculate how many examples to skip based on batch size
+        examples_to_skip = current_step * trainer.config.train_batch_size
+
+        # Skip through the iterator until we reach the right position
+        for _ in range(examples_to_skip):
+            try:
+                next(prompt_completion_iterator)
+            except StopIteration:
+                logger.warning("Ran out of examples while seeking - restarting from beginning")
+                # Recreate iterator and continue skipping
+                prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos)
+    else:
+        logger.info("Starting SFT from scratch")
+
     logger.info("Packing prompt completions")
     packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4)
     logger.info("Stacking batches to train batch")
@@ -279,20 +294,21 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi
 
     for i in range(length):
         example = asyncio.run(cached_dataset.getitem_async(i))
-        if int(example["sources_len"]) > Pos.size - 1:
+        sources_len = example["sources_len"].item()
+        if sources_len > Pos.size - 1:
             continue
 
         ids = example["input_ids"].tolist()
         if len(ids) > Pos.size:
             ids = ids[:Pos.size]
 
-        if len(ids) <= sources_len:
+        if len(ids) <= sources_len:
             continue
 
         try:
             yield PromptCompletion(
                 ids=ids,
-                prompt_length=int(example["sources_len"]),
+                prompt_length=sources_len,
                 segment_id=i
             )
         except ValueError:
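
For reference, the resume-time skip added in the second hunk can also be expressed with itertools.islice, which consumes the skipped examples without an explicit Python-level next() loop. This is a minimal sketch, not part of the diff: seek_iterator and make_iterator are hypothetical names, where make_iterator stands in for a zero-argument factory such as lambda: create_prompt_completion_iterator(train_dataset, Pos), and the restart-on-exhaustion behaviour mirrors the StopIteration branch above.

import itertools
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


def seek_iterator(make_iterator: Callable[[], Iterator[T]], current_step: int, batch_size: int) -> Iterator[T]:
    """Skip current_step * batch_size examples, restarting from the beginning
    if the iterator runs dry mid-seek (like the diff's StopIteration branch)."""
    it = make_iterator()
    remaining = current_step * batch_size
    while remaining > 0:
        # islice yields at most `remaining` items; count how many were actually consumed
        consumed = sum(1 for _ in itertools.islice(it, remaining))
        if consumed == 0:
            break  # empty dataset: bail out rather than loop forever
        remaining -= consumed
        if remaining > 0:
            it = make_iterator()  # exhausted mid-seek: restart from the beginning
    return it

Used at the resume site, the call would look something like:

prompt_completion_iterator = seek_iterator(
    lambda: create_prompt_completion_iterator(train_dataset, Pos),
    int(state.step),
    trainer.config.train_batch_size,
)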
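
The reworked filter in the last hunk enforces two invariants: the prompt must leave room for at least one completion token (sources_len <= Pos.size - 1), and after truncation to Pos.size the example must still contain tokens beyond the prompt. A toy check of those rules, assuming numpy-backed examples as the .item()/.tolist() calls in the diff suggest; POS_SIZE and keeps are hypothetical stand-ins for Pos.size and the inline filtering:

import numpy as np

POS_SIZE = 8  # stand-in for Pos.size


def keeps(example: dict) -> bool:
    sources_len = example["sources_len"].item()
    if sources_len > POS_SIZE - 1:
        return False  # prompt alone fills (or overflows) the context window
    ids = example["input_ids"].tolist()
    if len(ids) > POS_SIZE:
        ids = ids[:POS_SIZE]  # truncate overly long sequences
    return len(ids) > sources_len  # require at least one completion token


assert keeps({"input_ids": np.arange(6), "sources_len": np.int64(3)})       # 3 prompt + 3 completion tokens
assert not keeps({"input_ids": np.arange(6), "sources_len": np.int64(9)})   # prompt overflows the context
assert not keeps({"input_ids": np.arange(3), "sources_len": np.int64(3)})   # no completion tokens left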