Skip to content

Commit

Permalink
Add support for datasets in local json files (NVIDIA#597)
Browse files Browse the repository at this point in the history
* fix gitignore

* add local dataset dir

* add ignore

* add local dataset support

* add some about local dataset in README.md

* fix some

* add some in README

* remove data dir

* add line to gitignore

* fix some following good advices

* fix some about format

* reformat code using yapf

---------

Co-authored-by: 宋超 <[email protected]>
  • Loading branch information
LeetJoe and 宋超 authored Jun 17, 2023
1 parent 353603d commit 9d12674
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ __pycache__/
*.so

# Distribution / packaging
.idea
.Python
build/
develop-eggs/
dist/
downloads/
applications/DeepSpeed-Chat/data
eggs/
.eggs/
lib/
Expand Down
7 changes: 7 additions & 0 deletions applications/DeepSpeed-Chat/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ If you have downloaded huggingface datasets manually, you can add your local pat

One thing to note is that some datasets may only have one response instead of two responses. For those datasets, you can only use them in step 1. In such a case, you should add the dataset_name as part of the "--sft_only_data_path" arg instead of the "--data_path" arg. One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps.

If you have your own dataset in local files, you can also use it by following these rules:
* Pass "local/jsonfile" as the dataset name to the "--data_path" argument.
* Put your train data and evaluation data in applications/DeepSpeed-Chat/data/ with the names train.json and eval.json.
* The JSON data in each file should be a single list with each item like ***{"prompt": "Human: I have a question. Assistant:", "chosen": "Good answer.", "rejected": "Bad answer."}***.

What is more, when you use your own dataset files and modify some data in them, pay attention to the "reload" parameter of the ***create_prompt_dataset*** function. You should pass a True value to it, or the cache files will not refresh.

### 🐼 Customizing your own RLHF training pipeline using DeepSpeed-Chat’s RLHF APIs

DeepSpeed-Chat allows users to build their very own RLHF training pipeline using our flexible APIs shown below, which users can use to reconstruct their own RLHF training strategy. This enables a general interface and backend for creating a wide range of RLHF algorithms for research exploration.
Expand Down
19 changes: 16 additions & 3 deletions applications/DeepSpeed-Chat/training/utils/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@ def get_raw_dataset(dataset_name, output_path, seed, local_rank):
elif "lmqg/qag_jaquad" in dataset_name:
return raw_datasets.LmqgQagjaquadDataset(output_path, seed, local_rank,
dataset_name)
elif "local/jsonfile" in dataset_name:
chat_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), os.path.pardir,
os.path.pardir, os.path.pardir))
if not (os.path.isfile(chat_path + '/data/train.json')
and os.path.isfile(chat_path + '/data/eval.json')):
raise RuntimeError(
f"Please check both the train.json and eval.json files in your applications/DeepSpeed-Chat/data directory."
)
return raw_datasets.LocalJsonFileDataset(output_path, seed, local_rank,
dataset_name, chat_path)
else:
raise RuntimeError(
f"We do not have configs for dataset {dataset_name}, but you can add it by yourself in raw_datasets.py."
Expand All @@ -84,7 +95,8 @@ def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
split_name, data_split, split_index,
data_size):
index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
if not os.path.isfile(index_file_name):
# reindex each time when using local jsonfile since it's more likely to get modified
if (not os.path.isfile(index_file_name)) or (dataset_name == 'jsonfile'):
splits = [float(s) for s in data_split.split(',')]
splits_sum = sum(splits)
splits = [split / splits_sum for split in splits]
Expand Down Expand Up @@ -252,7 +264,8 @@ def create_prompt_dataset(local_rank,
tokenizer,
max_seq_len,
end_of_conversation_token="<|endoftext|>",
sft_only_data_path=[]):
sft_only_data_path=[],
reload=False):
"""
Creates the prompt dataset
"""
Expand All @@ -271,7 +284,7 @@ def create_prompt_dataset(local_rank,
buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
torch.distributed.all_reduce(buf_create_cache)

if local_rank <= 0 and buf_create_cache.item() != 0:
if local_rank <= 0 and (buf_create_cache.item() != 0 or reload):
if len(data_path) == 1: # Single dataset.
train_dataset, eval_dataset = create_dataset(
local_rank, data_path[0], data_split, output_path, train_phase,
Expand Down
57 changes: 56 additions & 1 deletion applications/DeepSpeed-Chat/training/utils/data/raw_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def __init__(self, output_path, seed, local_rank, dataset_name):
self.output_path = output_path
self.seed = seed
self.local_rank = local_rank
self.raw_datasets = load_dataset(dataset_name)
if not dataset_name == 'local/jsonfile':
self.raw_datasets = load_dataset(dataset_name)

def get_train_data(self):
return
Expand Down Expand Up @@ -342,6 +343,60 @@ def get_prompt_and_rejected(self, sample):
return None


class LocalJsonFileDataset(PromptRawDataset):
    """Prompt dataset backed by user-provided local JSON files.

    Expects ``train.json`` and ``eval.json`` under ``<chat_path>/data/``,
    each a single list of items shaped like
    ``{"prompt": ..., "chosen": ..., "rejected": ...}``.
    """

    def __init__(self, output_path, seed, local_rank, dataset_name, chat_path):
        super().__init__(output_path, seed, local_rank, dataset_name)
        self.dataset_name = "local/jsonfile"
        self.dataset_name_clean = "jsonfile"
        # Load both splits via the generic 'json' loading script.
        data_files = {
            "train": chat_path + '/data/train.json',
            "eval": chat_path + '/data/eval.json',
        }
        self.raw_datasets = load_dataset('json', data_files=data_files)

    def get_train_data(self):
        # Returns the split as-is (None stays None).
        return self.raw_datasets['train']

    def get_eval_data(self):
        # Returns the split as-is (None stays None).
        return self.raw_datasets['eval']

    # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:"
    def get_prompt(self, sample):
        prompt = sample['prompt']
        if prompt is None:
            return None
        return " " + prompt

    # The chosen response should be in the format of: " " + actual_response_sentence
    def get_chosen(self, sample):
        chosen = sample['chosen']
        if chosen is None:
            return None
        return " " + chosen

    # The rejected response should be in the format of: " " + actual_response_sentence
    # If the dataset does not have rejected response, return None
    def get_rejected(self, sample):
        rejected = sample['rejected']
        if rejected is None:
            return None
        return " " + rejected

    def get_prompt_and_chosen(self, sample):
        prompt, chosen = sample['prompt'], sample['chosen']
        if prompt is None or chosen is None:
            return None
        return " " + prompt + " " + chosen

    def get_prompt_and_rejected(self, sample):
        prompt, rejected = sample['prompt'], sample['rejected']
        if prompt is None or rejected is None:
            return None
        return " " + prompt + " " + rejected


# Chinese dataset
class Wangrui6ZhihuKOLDataset(PromptRawDataset):

Expand Down

0 comments on commit 9d12674

Please sign in to comment.