initial sft_mixture
ahmeda14960 committed Feb 11, 2025
1 parent 7a78c13 commit ff737dd
Showing 2 changed files with 392 additions and 0 deletions.
92 changes: 92 additions & 0 deletions config/sft_llama3_mixture.yaml
@@ -0,0 +1,92 @@
dataset_type: chat_jsonl

# Config for mixture of supervised datasets
supervised_data:
  tulu:
    cache_dir: "gs://marin-us-central2/tokenized/tulu_sft_v3_llama3_tokenizer-7b19dc"
    train_urls:
      - "gs://marin-us-central2/documents/allenai--tulu-v2-sft-mixture-0ba27c/data/**/*.jsonl.gz"
  openthoughts:
    cache_dir: "gs://marin-us-central2/tokenized/openthoughts_llama3_tokenizer-404700"
    train_urls:
      - "gs://marin-us-central2/documents/open-thoughts--OpenThoughts-114k-216e29/data/**/*.jsonl.gz"
  prime_verified_math:
    cache_dir: "gs://marin-us-central2/tokenized/prime_verified_math_llama3_tokenizer-9256ad"
    train_urls:
      - "gs://marin-us-central2/documents/PrimeIntellect--verifiable-math-problems-ae89bf/data/**/*.jsonl.gz"
  acecode:
    cache_dir: "gs://marin-us-central2/tokenized/acecode_llama3_tokenizer-9c2672"
    train_urls:
      - "gs://marin-us-central2/documents/TIGER-Lab--AceCode-89K-73bd59/data/**/*.jsonl.gz"
  smoltalk:
    cache_dir: "gs://marin-us-central2/tokenized/smoltalk_llama3_tokenizer-ad5792"
    train_urls:
      - "gs://marin-us-central2/documents/HuggingFaceTB--smoltalk-6190ae/data/**/*.jsonl.gz"


# Mixture weights (example weights - adjust based on your needs)
mixture_weights:
  tulu: 0.318
  openthoughts: 0.037
  prime_verified_math: 0.193
  acecode: 0.013
  smoltalk: 0.439

mixture_block_size: 2048
stop_strategy: restart

max_seq_len: 4096
tokenizer: "meta-llama/Meta-Llama-3.1-8B"
model: # 8B llama3 class model
  type: llama
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 14336
  num_layers: 32
  num_heads: 32
  num_kv_heads: 8
  use_flash_attention: True
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: true
  initializer_range: 0.02
  rope:
    type: "llama3"

trainer:
  seed: 2
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama", "mixture"]
  wandb:
    project: "marin"
    name: "llama3.1_tulu_openthoughts_mixture"

  mp: p=f32,c=bfloat16
  train_batch_size: 128
  # Adjust steps based on combined dataset size and desired epochs
  num_train_steps: 12197
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
  checkpointer:
    base_path: "gs://levanter-checkpoints/marin/llama_3.1_mixture/seed_2/"

optimizer:
  learning_rate: 5e-6
  weight_decay: 0.0
  min_lr_ratio: 0.0
  lr_schedule: "linear"
  warmup: 0.03

hf_save_steps: 4000
hf_save_path: "gs://levanter-checkpoints/marin/llama_3.1_mixture/hf/seed_2/"

initialize_from_hf: True
model_name_or_path: "meta-llama/Llama-3.1-8B"

messages_field: "messages"
input_role: "user"
output_role: "assistant"
300 changes: 300 additions & 0 deletions src/levanter/main/sft_mixture.py
@@ -0,0 +1,300 @@
import asyncio
import dataclasses
import logging
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Iterator, List, Optional, Union

import jax.random as jrandom
import transformers

import haliax as hax
from haliax import Axis
from haliax.partitioning import round_axis_for_partitioning

import levanter
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset, batched
from levanter.data.dataset import AsyncDataset
from levanter.data.loader import stack_batches
from levanter.data.mixture import MixtureDataset, StopStrategy
from levanter.data.packing import PromptCompletion, pack_prompt_completions
from levanter.data.text import (
    ChatUrlDataSourceConfig,
    EpochDataset,
    SupervisedSourceConfig,
    mk_cached_sft_dataset,
    mk_supervised_dataset,
)
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig, LmHeadModel, compute_next_token_loss
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.utils.background_iterable import BackgroundIterator


logger = logging.getLogger(__name__)

# Define default special tokens
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


class DatasetType(str, Enum):
    """Type of dataset to use"""

    HUGGINGFACE = "huggingface"  # Use HF dataset
    CHAT_JSONL = "chat_jsonl"  # Use JSONL files with chat format


@dataclass
class SFTMixtureConfig:
    trainer: TrainerConfig = field(default_factory=TrainerConfig)
    model: LmConfig = field(default_factory=LlamaConfig)
    optimizer: OptimizerConfig = field(default_factory=AdamConfig)

    # Config for mixture of supervised datasets
    supervised_data: Dict[str, SupervisedSourceConfig] = field(default_factory=dict)
    mixture_weights: Dict[str, float] = field(default_factory=dict)
    mixture_block_size: int = 2048
    stop_strategy: str = StopStrategy.RESTART_STRATEGY

    # Config related to continued pretraining
    initialize_from_hf: Union[bool, str] = False
    hf_save_path: Optional[str] = None
    hf_upload: Optional[str] = None
    hf_save_steps: int = 0

    max_seq_len: int = 2048
    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
    tokenizer: str = "meta-llama/Llama-2-7b-hf"

    # Dataset type and chat-specific fields
    dataset_type: DatasetType = DatasetType.CHAT_JSONL
    messages_field: str = "messages"
    input_role: str = "user"
    output_role: str = "assistant"

    data_seed: Optional[int] = None

def train(config: SFTMixtureConfig):
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.tokenizer,
        model_max_length=config.max_seq_len,
        padding_side="right",
        trust_remote_code=True,
    )
    logger.info(f"Loaded tokenizer {tokenizer}")

    if config.initialize_from_hf:
        if config.trainer.initialize_from is not None:
            raise ValueError("Cannot use both --initialize_from_hf and --initialize_from")

        assert isinstance(config.model, HFCompatConfig)

        converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
        if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
            logger.warning("The tokenizers appear to be different. You may want to check this.")
        if isinstance(config.initialize_from_hf, str):
            converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
        else:
            converter = converter.replaced(tokenizer=tokenizer)

        model_config = dataclasses.replace(converter.default_config, seq_len=config.max_seq_len)
    elif config.trainer.initialize_from is None:
        raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
    else:
        if config.hf_save_steps:
            converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
            converter = converter.replaced(tokenizer=tokenizer)
        else:
            converter = None
        model_config = config.model

    config = dataclasses.replace(config, model=model_config)
    levanter.initialize(config)

    num_new_tokens = add_special_tokens(tokenizer)
    logger.info(f"Added {num_new_tokens} new tokens")

    seed = config.trainer.seed
    data_key, _, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4)

    if config.data_seed is not None:
        logger.info(f"Overriding data seed with {config.data_seed}")
        data_key = jrandom.PRNGKey(config.data_seed)

    # Create supervised datasets using generic machinery
    logger.info("Creating supervised datasets")

    # Create individual datasets
    train_datasets = {}
    for name, source_config in config.supervised_data.items():
        if config.dataset_type == DatasetType.CHAT_JSONL:
            train_dataset = mk_cached_sft_dataset(
                ChatUrlDataSourceConfig(
                    cache_dir=source_config.cache_dir,
                    train_urls=source_config.train_urls,
                    messages_field=config.messages_field,
                    input_role=config.input_role,
                    output_role=config.output_role,
                ),
                tokenizer,
                model_config.Pos,
            )
            train_dataset = PermutationDataset(train_dataset, data_key)
        else:
            train_dataset = mk_supervised_dataset(source_config, "train", tokenizer, model_config.Pos)

        train_datasets[name] = train_dataset

    # Create mixture dataset
    logger.info("Creating mixture dataset")
    train_dataset = MixtureDataset(
        datasets=train_datasets,
        weights=config.mixture_weights,
        block_size=config.mixture_block_size,
        key=data_key,
        stop_strategy=config.stop_strategy,
    )

logger.info("Creating optimizer")
optimizer = config.optimizer.build(config.trainer.num_train_steps)

with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:
parameter_axis_mapping = trainer.parameter_axis_mapping

Pos = config.model.Pos
vocab_size = len(tokenizer)
Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)

if config.initialize_from_hf:
logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
model: LmHeadModel = converter.load_pretrained(
model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
)
model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
state = trainer.initial_state(training_key, model=model)
else:
if vocab_size != Vocab.size:
logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

flops_per_token = config.model.flops_per_token(vocab_size)
flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
trainer.add_hook(
callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
)

current_step = int(state.step)

logger.info("Creating prompt completion iterator")
prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos)

if current_step > 0:
logger.info(f"Resuming training from step {current_step}")
examples_to_skip = current_step * trainer.config.train_batch_size

for _ in range(examples_to_skip):
try:
next(prompt_completion_iterator)
except StopIteration:
logger.warning("Ran out of examples while seeking - restarting from beginning")
prompt_completion_iterator = create_prompt_completion_iterator(train_dataset, Pos)
else:
logger.info("Starting SFT from scratch")

logger.info("Packing prompt completions")
packed_iterator = pack_prompt_completions(
Pos,
prompt_completion_iterator,
max_segments_per_example=4,
pad_token=tokenizer.pad_token_id,
max_buffered_examples=16,
)
logger.info("Stacking batches to train batch")
packed_iterator = stack_batches(example_iterator=packed_iterator, Pos=Pos, Batch=trainer.TrainBatch)
logger.info("Creating data loader")
packed_loader = BackgroundIterator(packed_iterator, max_capacity=1024)

if config.hf_save_path is not None:
if config.trainer.checkpointer.append_run_id_to_base_path:
full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
else:
full_save_path = config.hf_save_path

trainer.add_hook(
save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
every=config.hf_save_steps,
)

trainer.train(state, packed_loader)


def create_prompt_completion_iterator(cached_dataset: AsyncDataset, Pos: hax.Axis) -> Iterator[PromptCompletion]:
    """Creates an iterator that yields PromptCompletion objects from a cached dataset."""
    # Instead of getting total length upfront, we'll keep track of processed examples
    processed_examples = 0
    batch_size = 4096

    while True:  # Infinite loop for restart strategy
        # Get next batch of indices
        indices = list(range(processed_examples, processed_examples + batch_size))
        try:
            examples = asyncio.run(cached_dataset.get_batch(indices))
        except (IndexError, ValueError):
            # If we hit the end of the dataset, start over
            processed_examples = 0
            continue

        for i in range(len(examples)):
            example = examples[i]
            sources_len = example["sources_len"].item()
            if sources_len > Pos.size - 1:
                continue

            ids = example["input_ids"].tolist()
            if len(ids) > Pos.size:
                ids = ids[: Pos.size]

            if len(ids) <= sources_len:
                continue

            try:
                yield PromptCompletion(ids=ids, prompt_length=sources_len, segment_id=indices[i])
            except ValueError as e:
                logger.error(
                    f"Error creating PromptCompletion (ids length: {len(ids)}, sources_len: {sources_len}, segment id:"
                    f" {indices[i]}): {e}"
                )
                continue

        processed_examples += batch_size


def add_special_tokens(tokenizer, use_unk_instead_of_adding=False):
    special_tokens_dict = dict()
    if use_unk_instead_of_adding:
        if tokenizer.unk_token is None:
            raise ValueError("use_unk_instead_of_adding is True but tokenizer doesn't have an unk token")

    unk = tokenizer.unk_token if use_unk_instead_of_adding else None

    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN if not use_unk_instead_of_adding else unk
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN if not use_unk_instead_of_adding else unk
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN if not use_unk_instead_of_adding else unk
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    return tokenizer.add_special_tokens(special_tokens_dict)


if __name__ == "__main__":
    levanter.config.main(train)()
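
With both files in place, the new entry point would presumably be launched the same way as other Levanter mains, along the lines of: python -m levanter.main.sft_mixture --config_path config/sft_llama3_mixture.yaml. The --config_path flag follows the draccus-style convention used by levanter.config.main, so treat the exact invocation as an assumption rather than something this commit documents.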
