-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdata_loader.py
105 lines (93 loc) · 3.73 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from datasets import load_dataset
class RaceQuestionAnswerGeneration(Dataset):
    """RACE examples formatted for joint question + answer generation.

    Each item maps an article (the context) to a single target string
    ``"<question> <sep> <answer>"``, where ``<sep>`` is ``separator``.
    """

    def __init__(self, tokenizer, data_split, separator='<sep>'):
        """Load the RACE ("all" configuration) split.

        Args:
            tokenizer: stored as ``self.tokenizer``; this class never uses it
                itself.
            data_split: split name passed to ``load_dataset``
                (e.g. "train", "validation", "test").
            separator: token placed between the question and the answer in
                the output string.
        """
        self.data = load_dataset("race", "all", split=data_split)
        self.tokenizer = tokenizer
        self.separator = separator
        # The gold answer is stored as a letter; map it to an index into the
        # example's ``options`` list.
        self.label_mapping = {label: i for i, label in enumerate(["A", "B", "C", "D"])}
        print("RaceQuestionAnswerGeneration Initialized")

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'input': article, 'output': "question <sep> answer"}``.

        Raises:
            KeyError: if the example's ``answer`` letter is not one of "A"-"D".
        """
        example = self.data[idx]
        answer = example["options"][self.label_mapping[example["answer"]]]
        # Named source_text/target_text to avoid shadowing the ``input`` builtin.
        source_text = example["article"]
        target_text = f"{example['question']} {self.separator} {answer}"
        return {'input': source_text, 'output': target_text}
class RaceDistractorGeneration(Dataset):
    """RACE examples formatted for distractor (wrong-option) generation.

    Each item maps ``"<question> <sep> <answer> <sep> <article>"`` to the
    non-gold options joined as ``"<d1> <sep> <d2> <sep> <d3>"``.
    """

    def __init__(self, tokenizer, data_split, shuffle_distractors=False, separator='<sep>'):
        """Load the RACE ("all" configuration) split.

        Args:
            tokenizer: stored as ``self.tokenizer``; this class never uses it
                itself.
            data_split: split name passed to ``load_dataset``
                (e.g. "train", "validation", "test").
            shuffle_distractors: if true, the order of the distractors in the
                output is shuffled on every ``__getitem__`` call.
            separator: token placed between the fields of the input and the
                distractors of the output.
        """
        self.data = load_dataset("race", "all", split=data_split)
        self.tokenizer = tokenizer
        self.separator = separator
        # The gold answer is stored as a letter; map it to an index into the
        # example's ``options`` list.
        self.label_mapping = {label: i for i, label in enumerate(["A", "B", "C", "D"])}
        self.all_labels = [0, 1, 2, 3]
        self.shuffle_distractors = shuffle_distractors
        # Report this class's own name (previously printed the sibling class's).
        print("RaceDistractorGeneration Initialized")

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'input': "q <sep> a <sep> article", 'output': "d1 <sep> d2 <sep> d3"}``.

        Raises:
            KeyError: if the example's ``answer`` letter is not one of "A"-"D".
        """
        example = self.data[idx]
        options = example["options"]
        answer_i = self.label_mapping[example["answer"]]
        answer = options[answer_i]

        # Every option index except the gold one is a distractor.
        distractor_ids = [i for i in self.all_labels if i != answer_i]
        if self.shuffle_distractors:
            # Uses the global ``random`` state; seed it upstream for reproducibility.
            random.shuffle(distractor_ids)
        distractors = [options[i] for i in distractor_ids]

        joiner = f" {self.separator} "
        # Named source_text/target_text to avoid shadowing the ``input`` builtin.
        source_text = joiner.join([example["question"], answer, example["article"]])
        target_text = joiner.join(distractors)
        return {'input': source_text, 'output': target_text}
class RaceAnsweringModel(Dataset):
    """RACE examples exposed as raw fields for a multiple-choice answering model.

    Unlike the generation datasets in this module, items are not flattened
    into strings: each one carries the article, question, option list and the
    integer index of the gold option.
    """

    def __init__(self, data_split):
        """Load the RACE ("all" configuration) split named ``data_split``."""
        self.data = load_dataset("race", "all", split=data_split)
        # Letter -> option index: "A" -> 0, ..., "D" -> 3.
        self.label_mapping = {letter: pos for pos, letter in enumerate("ABCD")}
        self.all_labels = list(range(4))
        print("RaceAnsweringModel Initialized")

    def __len__(self):
        """Return the number of examples in the loaded split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the raw fields of example ``idx`` with the gold answer as an index.

        Raises:
            KeyError: if the example's ``answer`` letter is not one of "A"-"D".
        """
        row = self.data[idx]
        return {
            'context': row["article"],
            'question': row["question"],
            'options': row["options"],
            'answer_i': self.label_mapping[row["answer"]],
        }