Skip to content

Commit

Permalink
WIP buggy, MFU stuck at 1
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmeda14960 committed Jan 27, 2025
1 parent 0d7acc7 commit efbcc60
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions src/levanter/main/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def train(config: SFTConfig):

# some axes we need
Pos = config.model.Pos

# to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
# For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
# tokens: gpt-2 has 50257, for example. So we round up.
Expand All @@ -210,12 +209,12 @@ def train(config: SFTConfig):
callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
)

# reshuffle the examples before packing!
# TODO: reshuffle the examples before packing!
# Get current step from trainer state
current_step = int(state.step)

# to implement seeking
# check the step number in the trainer state if it's not zero
# then next the iterator until we get there, then continue training.
# batch size will be baked in from config



# change iterate tokenized requests to take a dict rather than a list
# where the first element is the prompt and the second is the response
Expand All @@ -227,6 +226,22 @@ def train(config: SFTConfig):
logger.info("Creating prompt completion iterator")
prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos)

if current_step > 0:
logger.info(f"Resuming training from step {current_step}")
# Calculate how many examples to skip based on batch size
examples_to_skip = current_step * trainer.config.train_batch_size

# Skip through the iterator until we reach the right position
for _ in range(examples_to_skip):
try:
next(prompt_completion_iterator)
except StopIteration:
logger.warning("Ran out of examples while seeking - restarting from beginning")
# Recreate iterator and continue skipping
prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos)
else:
logger.info("Starting SFT from scratch")

logger.info("Packing prompt completions")
packed_iterator = _pack_requests(prompt_completion_iterator, tokenizer, Pos, max_pack_size=4)
logger.info("Stacking batches to train batch")
Expand Down Expand Up @@ -279,20 +294,21 @@ def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axi
for i in range(length):
example = asyncio.run(cached_dataset.getitem_async(i))

if int(example["sources_len"]) > Pos.size - 1:
sources_len = example["sources_len"].item()
if sources_len > Pos.size - 1:
continue

ids = example["input_ids"].tolist()
if len(ids) > Pos.size:
ids = ids[:Pos.size]

if len(ids) <= example["sources_len"]:
if len(ids) <= sources_len:
continue

try:
yield PromptCompletion(
ids=ids,
prompt_length=int(example["sources_len"]),
prompt_length=sources_len,
segment_id=i
)
except ValueError:
Expand Down

0 comments on commit efbcc60

Please sign in to comment.