Merge remote-tracking branch 'upstream/support_more_data_format'
SLR722 committed Jan 10, 2025
2 parents 96735e9 + 1e915d8 commit 1095bb9
Showing 8 changed files with 155 additions and 77 deletions.
7 changes: 7 additions & 0 deletions llama_stack/apis/post_training/post_training.py
@@ -27,11 +27,18 @@ class OptimizerType(Enum):
sgd = "sgd"


@json_schema_type
class DatasetFormat(Enum):
instruct = "instruct"
dialog = "dialog"


@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
data_format: DatasetFormat
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
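
With data_format now a required field on DataConfig, callers must pick a format explicitly. A minimal sketch of what a post-change config might look like (field values are illustrative; the import path follows the module shown above):

from llama_stack.apis.post_training.post_training import DataConfig, DatasetFormat

# Hypothetical config: "instruct" rows carry chat_completion_input /
# expected_answer columns (see the validator added below).
data_config = DataConfig(
    dataset_id="my-sft-dataset",  # assumed dataset identifier
    batch_size=8,
    shuffle=True,
    data_format=DatasetFormat.instruct,
)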
48 changes: 48 additions & 0 deletions llama_stack/providers/inline/post_training/common/validator.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.common.type_system import StringType
from llama_stack.apis.datasets import Datasets
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
validate_dataset_schema,
)

EXPECTED_DATASET_SCHEMA = {
"instruct": [
{
ColumnName.chat_completion_input.value: StringType(),
ColumnName.expected_answer.value: StringType(),
}
],
"dialog": [
{
ColumnName.dialog.value: StringType(),
}
],
}


async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

if dataset_type not in EXPECTED_DATASET_SCHEMA:
raise ValueError(f"Dataset type {dataset_type} is not supported.")

validate_dataset_schema(
dataset_def.dataset_schema, EXPECTED_DATASET_SCHEMA[dataset_type]
)
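
For reference, a sketch of rows that would satisfy each schema above. The column names come from ColumnName; storing the message list as a stringified Python literal is an assumption based on format_adapter.py below, which parses that column with eval:

# Hypothetical "instruct" row
{
    "chat_completion_input": '[{"role": "user", "content": "What is 2+2?"}]',
    "expected_answer": "4",
}

# Hypothetical "dialog" row
{
    "dialog": '[{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]',
}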
llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -10,28 +10,22 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict

import torch
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets

from llama_stack.apis.post_training import DatasetFormat

from pydantic import BaseModel
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages

from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b


class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
from torchtune.modules.transforms import Transform


class ModelConfig(BaseModel):
@@ -40,10 +34,6 @@ class ModelConfig(BaseModel):
checkpoint_type: str


class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]


MODEL_CONFIGS: Dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
@@ -57,26 +47,11 @@ class DatasetSchema(BaseModel):
),
}

DATA_FORMATS: Dict[str, Transform] = {
"instruct": InputOutputToMessages,
"dialog": ShareGPTToMessages,
}

EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
ColumnName.text.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
]
)

BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -123,19 +98,5 @@ async def get_checkpointer_model_type(
return model_config.checkpoint_type


async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")

if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")

if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)
async def get_data_transform(data_format: DatasetFormat) -> Transform:
return DATA_FORMATS[data_format.value]
62 changes: 62 additions & 0 deletions llama_stack/providers/inline/post_training/torchtune/datasets/format_adapter.py
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from typing import Any, Mapping

from llama_stack.providers.utils.common.data_schema_validator import ColumnName


def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any]
) -> Mapping[str, Any]:
assert (
ColumnName.chat_completion_input.value in sample
and ColumnName.expected_answer.value in sample
), "Invalid input row"
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))

assert (
len(input_messages) == 1
), "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]

assert "content" in input_message, "content not found in input message"
input = input_message["content"]
output = sample[ColumnName.expected_answer.value]

return {
"input": input,
"output": output,
}


def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
dialog = eval(str(sample[ColumnName.dialog.value]))

assert len(dialog) > 1, "dialog must have at least 2 messages"
roles = []
conversations = []
for message in dialog:
assert (
"role" in message and "content" in message
), "role and content must in message"
roles.append(message["role"])
conversations.append(
{"from": role_map[message["role"]], "value": message["content"]}
)

assert roles[0] == "user", "first message must be from user"
assert "assistant" in roles, "at least 1 message should be from assistant"

return {"conversations": conversations}
llama_stack/providers/inline/post_training/torchtune/datasets/sft.py
@@ -13,6 +13,10 @@
from typing import Any, Dict, List, Mapping

import numpy as np
from llama_stack.providers.inline.post_training.torchtune.datasets.format_adapter import (
llama_stack_chat_to_torchtune_chat,
llama_stack_instruct_to_torchtune_instruct,
)

from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
@@ -26,10 +30,12 @@ def __init__(
rows: List[Dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
dataset_type: str,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform
self._dataset_type = dataset_type

def __len__(self):
return len(self._rows)
@@ -39,6 +45,12 @@ def __getitem__(self, index: int) -> Dict[str, Any]:
return self._prepare_sample(sample)

def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
if self._dataset_type == "instruct":
sample = llama_stack_instruct_to_torchtune_instruct(sample)
elif self._dataset_type == "dialog":
sample = llama_stack_chat_to_torchtune_chat(sample)
else:
raise ValueError(f"Invalid dataset type: {self._dataset_type}")
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
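With the new required dataset_type argument, constructing the dataset could look like this sketch (rows and tokenizer assumed from the recipe below; the transform class comes from DATA_FORMATS in common/utils.py):

from torchtune.data._messages import InputOutputToMessages

ds = SFTDataset(
    rows,  # list of row dicts fetched via the datasetio API
    message_transform=InputOutputToMessages(train_on_input=False),
    model_transform=tokenizer,  # assumed Llama3 tokenizer transform
    dataset_type="instruct",    # selects the instruct format adapter
)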
llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -29,6 +29,9 @@
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR

from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.common.validator import (
validate_input_dataset_schema,
)

from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
@@ -42,7 +45,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.data import padded_collate_sft

from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@@ -129,8 +132,10 @@ def model_checkpoint_dir(model) -> str:
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._data_format = training_config.data_config.data_format
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
self._train_on_input = training_config.data_config.train_on_input

# this is important for debugging purposes
self.max_steps_per_epoch = training_config.max_steps_per_epoch
@@ -354,18 +359,17 @@ async def fetch_rows(dataset_id: str):
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows

# Currently only support alpaca instruct dataset
# TODO @SLR722 make the message_transform swappable and support more dataset types
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
await utils.validate_input_dataset_schema(
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type="alpaca",
dataset_type=self._data_format.value,
)
data_transform = await utils.get_data_transform(self._data_format)
ds = SFTDataset(
rows,
message_transform=AlpacaToMessages(train_on_input=False),
message_transform=data_transform(train_on_input=self._train_on_input),
model_transform=tokenizer,
dataset_type=self._data_format.value,
)

sampler = DistributedSampler(
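Taken together, the recipe now dispatches on data_format twice: once to validate the dataset schema and once to pick the torchtune message transform. A condensed sketch of the mapping (per DATA_FORMATS in common/utils.py):

# instruct -> InputOutputToMessages, dialog -> ShareGPTToMessages
data_transform = await utils.get_data_transform(DatasetFormat.dialog)
assert data_transform is ShareGPTToMessages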
llama_stack/providers/utils/common/data_schema_validator.py
@@ -23,6 +23,7 @@ class ColumnName(Enum):
completion_input = "completion_input"
generated_answer = "generated_answer"
context = "context"
dialog = "dialog"


VALID_SCHEMAS_FOR_SCORING = [
23 changes: 3 additions & 20 deletions llama_stack/templates/experimental-post-training/run.yaml
@@ -29,8 +29,8 @@ providers:
provider_type: inline::basic
config: {}
datasetio:
- provider_id: huggingface-0
provider_type: remote::huggingface
- provider_id: localfs
provider_type: inline::localfs
config: {}
telemetry:
- provider_id: meta-reference
@@ -68,23 +68,6 @@ metadata_store:
models: []
shields: []
memory_banks: []
datasets:
- dataset_id: alpaca
provider_id: huggingface-0
url:
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
metadata:
path: tatsu-lab/alpaca
name:
split: train
dataset_schema:
instruction:
type: string
input:
type: string
output:
type: string
text:
type: string
datasets: []
scoring_fns: []
eval_tasks: []
