From 9d126747409eeaaf72f505190d1b2d2da4ecdb4b Mon Sep 17 00:00:00 2001 From: NeoSong Date: Sat, 17 Jun 2023 12:49:02 +0800 Subject: [PATCH] Add support for datasets in local json files (#597) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix gitignore * add local dataset dir * add ignore * add local dataset support * add some about local dataset in README.md * fix some * add some in README * remove data dir * add line to gitignore * fix some following good advices * fix some about format * reformat code using yapf --------- Co-authored-by: 宋超 --- .gitignore | 2 + applications/DeepSpeed-Chat/README.md | 7 +++ .../training/utils/data/data_utils.py | 19 ++++++- .../training/utils/data/raw_datasets.py | 57 ++++++++++++++++++- 4 files changed, 81 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index b6e47617de..faf4e2aff3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,11 +7,13 @@ __pycache__/ *.so # Distribution / packaging +.idea .Python build/ develop-eggs/ dist/ downloads/ +applications/DeepSpeed-Chat/data eggs/ .eggs/ lib/ diff --git a/applications/DeepSpeed-Chat/README.md b/applications/DeepSpeed-Chat/README.md index f274afce01..22aa1f1578 100644 --- a/applications/DeepSpeed-Chat/README.md +++ b/applications/DeepSpeed-Chat/README.md @@ -239,6 +239,13 @@ If you have downloaded huggingface datasets manually, you can add your local pat One thing to note that some datasets may only have one response instead of two responses. For those datasets, you can only use them in step 1. And in such case, you should add the dataset_name as part of the "--sft_only_data_path" arg instead of the "--data_path" arg. One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps. +If you have your own dataset in local files, you can also use it by following these rules: +* Pass "local/jsonfile" as the dataset name to the "--data_path" argument. +* Put your train data and evaluation data in applications/DeepSpeed-Chat/data/ with name train.json and eval.json. +* The json data in file should be a single list with each item like ***{"prompt": "Human: I have a question. Assistant:", "chosen": "Good answer.", "rejected": "Bad answer."}***. + +What is more, when you use your own dataset files and modified some data in them, pay attention to the parameter "reload" of ***create_prompt_dataset*** function. You should pass a True value to it or the cache files will not refresh. + ### 🐼 Customizing your own RLHF training pipeline using DeepSpeed-Chat’s RLHF APIs DeepSpeed-Chat allows users to build their very own RLHF training pipeline using our flexible APIs shown below, which users can use to reconstruct their own RLHF training strategy. This enables a general interface and backend for creating a wide range of RLHF algorithms for research exploration. diff --git a/applications/DeepSpeed-Chat/training/utils/data/data_utils.py b/applications/DeepSpeed-Chat/training/utils/data/data_utils.py index a6e4a601a5..4f31f06617 100644 --- a/applications/DeepSpeed-Chat/training/utils/data/data_utils.py +++ b/applications/DeepSpeed-Chat/training/utils/data/data_utils.py @@ -64,6 +64,17 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank): elif "lmqg/qag_jaquad" in dataset_name: return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank, dataset_name) + elif "local/jsonfile" in dataset_name: + chat_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.path.pardir, + os.path.pardir, os.path.pardir)) + if not (os.path.isfile(chat_path + '/data/train.json') + and os.path.isfile(chat_path + '/data/eval.json')): + raise RuntimeError( + f"Please check both the train.json and eval.json files in your applications/DeepSpeed-Chat/data directory." + ) + return raw_datasets.LocalJsonFileDataset(output_path, seed, local_rank, + dataset_name, chat_path) else: raise RuntimeError( f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py." @@ -84,7 +95,8 @@ def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed, split_name, data_split, split_index, data_size): index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy" - if not os.path.isfile(index_file_name): + # reindex each time when using local jsonfile since it's more likely to get modified + if (not os.path.isfile(index_file_name)) or (dataset_name == 'jsonfile'): splits = [float(s) for s in data_split.split(',')] splits_sum = sum(splits) splits = [split / splits_sum for split in splits] @@ -252,7 +264,8 @@ def create_prompt_dataset(local_rank, tokenizer, max_seq_len, end_of_conversation_token="<|endoftext|>", - sft_only_data_path=[]): + sft_only_data_path=[], + reload=False): """ Creates the prompt dataset """ @@ -271,7 +284,7 @@ def create_prompt_dataset(local_rank, buf_create_cache = torch.ByteTensor([not cache_found]).cuda() torch.distributed.all_reduce(buf_create_cache) - if local_rank <= 0 and buf_create_cache.item() != 0: + if local_rank <= 0 and (buf_create_cache.item() != 0 or reload): if len(data_path) == 1: # Single dataset. train_dataset, eval_dataset = create_dataset( local_rank, data_path[0], data_split, output_path, train_phase, diff --git a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py b/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py index 23666e234e..e8c519a8fc 100644 --- a/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py +++ b/applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py @@ -15,7 +15,8 @@ def __init__(self, output_path, seed, local_rank, dataset_name): self.output_path = output_path self.seed = seed self.local_rank = local_rank - self.raw_datasets = load_dataset(dataset_name) + if not dataset_name == 'local/jsonfile': + self.raw_datasets = load_dataset(dataset_name) def get_train_data(self): return @@ -342,6 +343,60 @@ def get_prompt_and_rejected(self, sample): return None +class LocalJsonFileDataset(PromptRawDataset): + + def __init__(self, output_path, seed, local_rank, dataset_name, chat_path): + super().__init__(output_path, seed, local_rank, dataset_name) + self.dataset_name = "local/jsonfile" + self.dataset_name_clean = "jsonfile" + self.raw_datasets = load_dataset('json', + data_files={ + "train": + chat_path + '/data/train.json', + "eval": + chat_path + '/data/eval.json' + }) + + def get_train_data(self): + if self.raw_datasets['train'] is not None: + return self.raw_datasets['train'] + return None + + def get_eval_data(self): + if self.raw_datasets['eval'] is not None: + return self.raw_datasets['eval'] + return None + + # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:" + def get_prompt(self, sample): + if sample['prompt'] is not None: + return " " + sample['prompt'] + return None + + # The chosen response should be in the format of: " " + actual_response_sentence + def get_chosen(self, sample): + if sample['chosen'] is not None: + return " " + sample['chosen'] + return None + + # The rejected response should be in the format of: " " + actual_response_sentence + # If the dataset does not have rejected response, return None + def get_rejected(self, sample): + if sample['rejected'] is not None: + return " " + sample['rejected'] + return None + + def get_prompt_and_chosen(self, sample): + if sample['prompt'] is not None and sample['chosen'] is not None: + return " " + sample['prompt'] + " " + sample['chosen'] + return None + + def get_prompt_and_rejected(self, sample): + if sample['prompt'] is not None and sample['rejected'] is not None: + return " " + sample['prompt'] + " " + sample['rejected'] + return None + + # Chinese dataset class Wangrui6ZhihuKOLDataset(PromptRawDataset):