Mod-Chatbot #67

Status: Open. Wants to merge 40 commits into base branch: main.

Commits (40)
3397a4a  initial chat datamodule (matheper, Jul 23, 2024)
8a1fea8  Split chat into individual chat turns (matheper, Jul 25, 2024)
01faf42  Merge branch 'main' into chat_template (matheper, Jul 25, 2024)
ea0ab8b  black and isort (matheper, Jul 25, 2024)
5b2e3cd  Added get_clusters from #70 (matheper, Jul 31, 2024)
df2e2aa  lazy (literally) setup_dataset and get_datamodule (matheper, Aug 2, 2024)
8b4a99a  Merge remote-tracking branch 'origin/main' into chat_template (Aug 2, 2024)
f1e6ac8  add pack sequences (Aug 2, 2024)
a944701  update config file (Aug 2, 2024)
01a2113  quant config (Aug 3, 2024)
d83d8a8  task name can be int (Aug 3, 2024)
778ca09  chat datamodule filter task name (Aug 3, 2024)
40d974b  prepare kbit on 4bit (Aug 3, 2024)
1b60a8a  paged optim on bnb (Aug 3, 2024)
8b06f7c  trust (Aug 3, 2024)
d724f20  remove unused config options (Aug 5, 2024)
c7fd876  remove length printing (Aug 5, 2024)
19464d7  remove some comments (Aug 5, 2024)
7166a20  mistral for chatbot (Aug 5, 2024)
820f653  task_name_field to filter on a specific column in the dataset (Aug 5, 2024)
cd7dbc6  add support for a few flags in the embedding model (Aug 5, 2024)
0510f31  remove breakpoint, oops (Aug 5, 2024)
7a47431  script to dump orca in json format + em topic script (Aug 7, 2024)
b7a7ed9  some speed-ups (Aug 7, 2024)
6b29acb  remove previous instruction (Aug 7, 2024)
7c70ac4  resume inference from output file (Aug 8, 2024)
003b556  contrastive tagging (Aug 9, 2024)
c55fe17  minimal template changes (Aug 9, 2024)
6df090a  gpt tagging (Aug 9, 2024)
09a2bef  infer/train model (Aug 9, 2024)
dfd7a0f  Merge branch 'main' into chat_template (matheper, Aug 12, 2024)
7ad0e61  Merge branch 'chat_template' of github.com:microsoft/mttl into chat_t… (matheper, Aug 12, 2024)
6154d08  Merge branch 'main' into em-topic (matheper, Aug 12, 2024)
1f79e22  Merge remote-tracking branch 'origin/move-files' into chat_template (Aug 13, 2024)
6f9ac59  Merge remote-tracking branch 'origin/main' into chat_template (Aug 14, 2024)
ee82cf9  Merge remote-tracking branch 'origin/chat_template' into em-topic (Aug 14, 2024)
d803dae  move stuff into modular-chatbot (Aug 14, 2024)
d68af91  small updates (Aug 16, 2024)
8c239bd  load from dataset_type (Aug 16, 2024)
5f787ed  Merge pull request #92 from microsoft/em-topic (matheper, Aug 16, 2024)
39 changes: 39 additions & 0 deletions mttl/cli/prettify_orca.py
@@ -0,0 +1,39 @@
import json
import os

import click
from datasets import load_dataset
from transformers import AutoTokenizer


@click.command()
@click.option("--input_jsonl", help="Path to the input jsonl file")
@click.option("--output_jsonl", help="Model name or path to the model")
def main(input_jsonl, output_jsonl):
num_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16))  # env values are strings; num_proc must be an int

def prettify(examples):
examples_ = {key: value for key, value in examples.items()}
examples_["messages"] = []
examples_["metadata"] = []
examples_["task_name"] = []
for messages, metadata in zip(examples["messages"], examples["metadata"]):
messages = json.loads(messages)
task_name = json.loads(metadata or "{}").get("task_name", "unknown")
examples_["messages"].append(messages)
examples_["metadata"].append(metadata)
examples_["task_name"].append(task_name)
return examples_

dataset = load_dataset("json", data_files=input_jsonl)
dataset = dataset.map(
prettify,
batched=True,  # batched mapping may return more examples than the input
remove_columns=dataset["train"].column_names,
num_proc=num_proc,
)
dataset["train"].to_json(output_jsonl)


if __name__ == "__main__":
main()
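
For context, a minimal sketch of the record shapes this script consumes and produces; the field values below are illustrative, and only the parsing logic mirrors the script above:

import json

# Hypothetical raw Orca record: "messages" and "metadata" arrive as
# JSON-encoded strings rather than structured fields.
raw = {
    "messages": json.dumps([{"role": "user", "content": "What is MTTL?"}]),
    "metadata": json.dumps({"task_name": "topic_42"}),
}

# prettify() parses "messages" and promotes "task_name" to a top-level
# column, defaulting to "unknown" when metadata is missing or empty.
messages = json.loads(raw["messages"])
task_name = json.loads(raw["metadata"] or "{}").get("task_name", "unknown")
print(task_name)  # -> topic_42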
82 changes: 61 additions & 21 deletions mttl/datamodule/base.py
Expand Up @@ -73,6 +73,8 @@ def __getitem__(self, idx):

@dataclass
class DatasetConfig:
"""Generic dataclass for dataset and batching configuration."""

dataset: str = None
data_dir: str = None
model: str = None
Expand All @@ -91,17 +93,19 @@ class DatasetConfig:
subsample_train: int = None
subsample_dev: int = None
subsample_test: int = None
subsample_per_task: bool = False # Changing default to False
subsample_per_task: bool = False
subsample: int = -1
pack_sequences: bool = False # True
pad_to_multiple_of: int = 8
max_seq_per_pack: int = 4
task_id_field: str = "task_id"
task_name_field: str = "task_name"
task_source_field: str = "task_source"


@dataclass
class DefaultCollator:
"""Simple collator

"""
Converts a batch of examples into a batch of inputs and labels for a sequence to sequence task.
If model_family is "gpt", then the inputs and outputs are constructed for a causal language model,
e.g. concatenated in a single string and labels are set to be -100 for all tokens in the input.
Expand All @@ -117,8 +121,15 @@ class DefaultCollator:
model_family: str = "seq2seq"
for_generation: bool = False
train_on_inputs: bool = False
task_to_id: dict = None
add_eos_to_targets: bool = True
task_to_id: dict = None # mapping from task name to task id
add_eos_to_targets: bool = True # add eos token to the end of the target sequence
task_id_field: str = (
"task_id" # where to read task id information from in the batch
)
task_name_field: str = (
"task_name" # where to read task name information from in the batch
)
task_source_field: str = "task_source"

def enforce_eos(self, targets):
# simulate the default behaviour of the LlamaTokenizer when adding an eos token and truncating: the last token must always be eos
Expand Down Expand Up @@ -408,9 +419,10 @@ def pad_sequence_wrapper(tensor_list, batch_first, padding_value, side="right"):
# Otherwise process as expected
sources = [b["source"] for b in batch]
labels = [b["target"] for b in batch]
task_ids = [b.get("task_id", None) for b in batch]
task_names = [b.get("task_name", None) for b in batch]
task_sources = [b.get("task_source", None) for b in batch]

task_ids = [b.get(self.task_id_field, None) for b in batch]
task_names = [b.get(self.task_name_field, None) for b in batch]
task_sources = [b.get(self.task_source_field, None) for b in batch]

output_batch = (
self.prepare_inputs_for_gpt_family(sources, labels)
Expand All @@ -427,7 +439,7 @@ def pad_sequence_wrapper(tensor_list, batch_first, padding_value, side="right"):
[self.task_to_id[tn] for tn in task_names]
)
elif has_task_ids:
output_batch["task_ids"] = torch.LongTensor(task_ids)
output_batch["task_ids"] = torch.LongTensor(list(map(int, task_ids)))

if has_task_names and not has_task_sources:
task_sources = task_names
Expand Down Expand Up @@ -583,6 +595,9 @@ def collate_fn(self):
train_on_inputs=self.config.train_on_inputs,
add_eos_to_targets=self.config.add_eos_to_targets,
task_to_id=self.task_to_id,
task_name_field=self.config.task_name_field,
task_id_field=self.config.task_id_field,
task_source_field=self.config.task_source_field,
)

def print_infos(self):
Expand Down Expand Up @@ -644,7 +659,6 @@ def subsample_dataset(self, dataset, n_samples, per_task=False):

Raises:
AssertionError: If `per_task` is True and the dataset is not an ArrowDataset.

"""

def get_dst_idxs_sampled(n_samples, total_size):
Expand All @@ -658,7 +672,7 @@ def get_dst_idxs_sampled(n_samples, total_size):
# make this deterministic to always sample the same subset
if isinstance(dataset, ArrowDataset):
if per_task:
task_names = dataset.unique("task_name")
task_names = dataset.unique(self.config.task_name_field)
subsampled_dataset = []
for i, task_name in enumerate(task_names):
logger.info(
Expand All @@ -667,15 +681,22 @@ def get_dst_idxs_sampled(n_samples, total_size):
task_idxs = torch.tensor(
[
index
for index, value in enumerate(dataset["task_name"])
for index, value in enumerate(
dataset[self.config.task_name_field]
)
if value == task_name
]
)
idxs = get_dst_idxs_sampled(n_samples, len(task_idxs))
task_idxs = task_idxs[idxs]
task_dataset = dataset.select(task_idxs)
subsampled_dataset.append(task_dataset)
assert all([t == task_name for t in task_dataset["task_name"]])
assert all(
[
t == task_name
for t in task_dataset[self.config.task_name_field]
]
)
subsampled_dataset = concatenate_datasets(subsampled_dataset)
else:
idxs = get_dst_idxs_sampled(n_samples, total_size)
Expand All @@ -699,7 +720,7 @@ def __init__(
self.for_generation = for_generation
self.tokenizer = get_tokenizer(config, for_generation=for_generation)
self.setup_dataset()
self.post_setup_dataset()
self._post_setup_dataset()

def setup(self, stage=None):
pass
Expand Down Expand Up @@ -737,9 +758,6 @@ def pack_sequences(self, dataset, max_sequences=4, shuffle=True):
if shuffle:
dataset = dataset.shuffle(seed=42)

# TODO: first partition dataset according to `task_name`, and
# pack each task individually to ensure that we don't mix tasks

# Very basic code that will iterate over sequences one by one,
# and merge together until the max_input_length is reached
# This is not optimal, but it's a start
Expand All @@ -762,7 +780,6 @@ def append_to_running_seq(container, example):
else:
raise ValueError(f"Unknown type {type(v)}")

# TODO: THis is SOMEHOW WRONG. CHECK.
container["seq_lens"] += [len(example["input_ids"])]

def add_finished_sequence(container, example):
Expand Down Expand Up @@ -814,10 +831,11 @@ def dict_get_item(ex, i):
)
return dataset
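
Since the packing loop is mostly elided context here, a minimal, self-contained illustration of the greedy strategy it implements (merge consecutive tokenized examples until max_input_length would be exceeded, recording per-sequence boundaries in seq_lens) may help; the toy data is hypothetical and the max_seq_per_pack cap is omitted for brevity:

# Greedy sequence packing: append examples to a running pack until the
# token budget would be exceeded, then start a new pack. seq_lens records
# the length of each original sequence inside the pack.
max_input_length = 8
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]

packs, current, seq_lens = [], [], []
for ids in examples:
    if current and len(current) + len(ids) > max_input_length:
        packs.append({"input_ids": current, "seq_lens": seq_lens})
        current, seq_lens = [], []
    current = current + ids
    seq_lens = seq_lens + [len(ids)]
if current:
    packs.append({"input_ids": current, "seq_lens": seq_lens})

print(packs)
# [{'input_ids': [1, 2, 3, 4, 5], 'seq_lens': [3, 2]},
#  {'input_ids': [6, 7, 8, 9, 10], 'seq_lens': [4, 1]}]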

def post_setup_dataset(self):
def _post_setup_dataset(self):
# subsample the splits if needed
for split in ["train", "dev", "test"]:

subsample = getattr(self.config, f"subsample_{split}", None)

if subsample and subsample > 0:
dataset = getattr(self, f"{split}_dataset")
logger.warning(
Expand All @@ -832,6 +850,7 @@ def post_setup_dataset(self):
if self.config.pack_sequences and split == "train":
dataset = getattr(self, f"{split}_dataset")
logger.info(f"Packing sequences for {split} dataset")

dataset = self.tokenize_dataset(dataset)
dataset = self.pack_sequences(
dataset, max_sequences=self.config.max_seq_per_pack
Expand All @@ -856,6 +875,9 @@ def collate_fn(self):
train_on_inputs=self.config.train_on_inputs,
task_to_id=self.task_to_id,
add_eos_to_targets=self.config.add_eos_to_targets,
task_name_field=self.config.task_name_field,
task_id_field=self.config.task_id_field,
task_source_field=self.config.task_source_field,
)


Expand All @@ -880,10 +902,14 @@ def collate_fn(self):
task_to_id=self.task_to_id,
multisource=True,
add_eos_to_targets=self.config.add_eos_to_targets,
task_name_field=self.config.task_name_field,
task_id_field=self.config.task_id_field,
task_source_field=self.config.task_source_field,
)


def get_datamodule(args, for_generation=False, dataset_override=None):
from mttl.config import DataArgs
from mttl.datamodule.arc_data_module import ArcDataConfig, ArcMultiChoiceDataModule
from mttl.datamodule.codex_data_module import CodexDataConfig, CodexDataModule
from mttl.datamodule.hellaswag_data_module import (
Expand Down Expand Up @@ -914,7 +940,15 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
WinograndeMultiChoiceDataModule,
)

# refactor all the common arguments below into a dict common kwargs
# if we have a DataArgs object, we can directly create the datamodule
if isinstance(args, DataArgs) and args.dataset_type is not None:
dataset_config = args.dataset_config

return DataModule.get_class_by_config_class(type(dataset_config))(
dataset_config, for_generation=for_generation
)

# we fall back to previous behavior
dataset = args.dataset if not dataset_override else dataset_override

common_kwargs = {
Expand All @@ -937,6 +971,7 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
"pad_to_multiple_of": args.pad_to_multiple_of,
"padding_side": args.padding_side,
"max_seq_per_pack": args.max_seq_per_pack,
"pack_sequences": args.pack_sequences,
}

if dataset in [
Expand Down Expand Up @@ -985,6 +1020,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None):
assert not for_generation
config = dataset_to_klass_map[dataset][0]
dm = dataset_to_klass_map[dataset][1](config)
elif "clusters" in dataset:
config = ClusterDataConfig(
**common_kwargs,
)
dm = ClusterDataModule(config, for_generation=for_generation)
elif "flan" in dataset:
config = FlanConfig(
**common_kwargs,
Expand Down
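
A quick illustration of the field plumbing this file adds: DatasetConfig and DefaultCollator now carry task_id_field / task_name_field / task_source_field, so task metadata can be read from arbitrary dataset columns instead of the hard-coded "task_id" / "task_name" / "task_source" keys. A minimal sketch with a toy batch (the column names below are hypothetical):

# The collator resolves task metadata through configurable field names.
task_id_field = "topic_id"    # instead of the default "task_id"
task_name_field = "topic"     # instead of the default "task_name"

batch = [
    {"source": "Q1", "target": "A1", "topic_id": 3, "topic": "history"},
    {"source": "Q2", "target": "A2", "topic_id": 7, "topic": "math"},
]

task_ids = [b.get(task_id_field, None) for b in batch]
task_names = [b.get(task_name_field, None) for b in batch]

# Per the "task name can be int" commit, ids are coerced to int before
# building the tensor: torch.LongTensor(list(map(int, task_ids))).
print(list(map(int, task_ids)), task_names)  # [3, 7] ['history', 'math']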
44 changes: 32 additions & 12 deletions mttl/datamodule/mt_seq_to_seq_module.py
Expand Up @@ -32,6 +32,8 @@ def augment_few_shot_task(
max_input_length=None,
seed=42,
modify_task_source=True,
task_source_field="task_source",
task_name_field="task_name",
):
if num_samples is None and few_shots is None:
raise ValueError("Either num_samples or few_shots must be specified.")
Expand Down Expand Up @@ -77,11 +79,11 @@ def map_to_few_shot(_, index):
return {
"source": prompt,
"target": dataset[index]["target"],
"task_name": dataset[index]["task_name"],
"task_source": (
"few_shot_{}".format(dataset[index]["task_source"])
task_name_field: dataset[index][task_name_field],
task_source_field: (
"few_shot_{}".format(dataset[index][task_source_field])
if modify_task_source
else dataset[index]["task_source"]
else dataset[index][task_source_field]
),
"split": (
dataset[index]["split"] if "split" in dataset.column_names else None
Expand All @@ -93,21 +95,31 @@ def map_to_few_shot(_, index):


def augment_few_shot(
dataset, num_samples, tokenizer=None, max_input_length=None, seed=42
dataset,
num_samples,
tokenizer=None,
max_input_length=None,
seed=42,
task_name_field="task_name",
task_source_field="task_source",
):
"""Augment the dataset with few-shot examples."""
import tqdm

augmented_dataset = []
for source in tqdm.tqdm(dataset.unique("task_name")):
for source in tqdm.tqdm(dataset.unique(task_name_field)):
augmented_dataset.append(
Dataset.from_list(
augment_few_shot_task(
dataset.filter(lambda x: x["task_name"] == source),
num_samples,
tokenizer,
max_input_length,
seed,
dataset.filter(lambda x: x[task_name_field] == source),
num_samples=num_samples,
few_shots=None,
tokenizer=tokenizer,
max_input_length=max_input_length,
seed=seed,
modify_task_source=True,
task_name_field=task_name_field,
task_source_field=task_source_field,
)
)
)
Expand Down Expand Up @@ -140,6 +152,9 @@ def setup_dataset(self):
self.dataset = DatasetLibrary.pull_dataset_with_retry(self.config.dataset)
n_proc = int(os.environ.get("MTTL_NUM_PROC_DATASETS", 16))

if "train" not in self.dataset.column_names:
raise ValueError("Flat multi-task datasets must have a 'train' split!")

if "split" not in self.dataset.column_names["train"]:
logger.warning(
"Dataset *should* have a 'split' column, try removing the dataset manually from the cache! Creating a new 'split' column."
Expand All @@ -161,7 +176,10 @@ def create_split(rng, _):
_,
_,
) = maybe_filter_hf_dataset_by_task(
self.dataset, "task_name", self.config.finetune_task_name, n_proc=n_proc
self.dataset,
self.config.task_name_field,
self.config.finetune_task_name,
n_proc=n_proc,
)

train_dataset = apply_source_template(
Expand All @@ -174,6 +192,8 @@ def create_split(rng, _):
self.config.augment_few_shot,
tokenizer=self.tokenizer,
max_input_length=self.config.max_input_length,
task_name_field=self.config.task_name_field,
task_source_field=self.config.task_source_field,
)
train_dataset_aug = train_dataset_aug.shuffle()
train_dataset = train_dataset_aug.select(range(len(train_dataset)))
Expand Down
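
The same configurability reaches the few-shot augmentation helpers: augment_few_shot now groups examples by an arbitrary task_name_field rather than a hard-coded "task_name" column. A minimal sketch of that grouping, assuming the Hugging Face datasets library and a toy dataset with a hypothetical "topic" column:

from datasets import Dataset

# Toy dataset whose task identity lives in a custom "topic" column.
ds = Dataset.from_list([
    {"source": "Q1", "target": "A1", "topic": "history"},
    {"source": "Q2", "target": "A2", "topic": "history"},
    {"source": "Q3", "target": "A3", "topic": "math"},
])

task_name_field = "topic"
for task in ds.unique(task_name_field):
    subset = ds.filter(lambda x: x[task_name_field] == task)
    # augment_few_shot_task would build few-shot prompts from `subset` here
    print(task, len(subset))  # -> history 2, then math 1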
8 changes: 2 additions & 6 deletions mttl/datamodule/utils.py
Expand Up @@ -19,12 +19,8 @@ def maybe_filter_hf_dataset_by_task(
if "test" in dataset:
all_tasks = all_tasks.union(set(dataset["test"][task_field]))

if task_names:
task_names = (
sorted(task_names.split(","))
if isinstance(task_names, str)
else sorted(task_names)
)
if task_names is not None:
task_names = sorted(str(task_names).split(","))
if not set(task_names).issubset(all_tasks):
raise ValueError(
"task_names must be a subset of the available tasks. Got {} and {}".format(
Expand Down
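
The simplified task_names handling above normalizes whatever it receives through str() before splitting, which is what lets integer task names flow through (see the "task name can be int" commit); note that a Python list would now be stringified, so callers are expected to pass a scalar or a comma-separated string. A small sketch of the behavior:

# task_names may be a comma-separated string, a single name, or an int;
# all are normalized to a sorted list of strings.
for task_names in ["b,a", "a", 3]:
    print(sorted(str(task_names).split(",")))
# -> ['a', 'b'], then ['a'], then ['3']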