-
Notifications
You must be signed in to change notification settings - Fork 3.3k
/
Copy pathclean_dataset.py
130 lines (109 loc) · 3.92 KB
/
clean_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Example usage:
python clean_dataset.py /
"2023-11-05_oasst_all.jsonl" /
"2023-11-05_oasst_all.clean.jsonl" /
--instructions "instructions.xlsx"
"""
import argparse
from collections import OrderedDict
import pandas
from oasst_data.reader import read_message_trees
from oasst_data.schemas import ExportMessageNode, ExportMessageTree
from oasst_data.traversal import visit_messages_depth_first
from oasst_data.writer import write_message_trees
def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        argparse.Namespace with ``input_file_name``, ``output_file_name``,
        ``instructions`` and ``exclude_nulls`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Clean an OASST message-tree export using an .xlsx instruction sheet"
    )
    parser.add_argument(
        "input_file_name",
        type=str,
        help="path to input .jsonl or .jsonl.gz file",
    )
    parser.add_argument(
        "output_file_name",
        type=str,
        help="path to output .jsonl or .jsonl.gz file",
    )
    # main() always reads this sheet; requiring it turns a confusing pandas
    # error on None into a clear argparse usage error.
    parser.add_argument(
        "--instructions",
        type=str,
        required=True,
        help="xlsx file with instructions",
    )
    parser.add_argument(
        "--exclude-nulls",
        action="store_true",
        default=False,
        help="omit None-valued fields when writing the output (exclude_none)",
    )
    return parser.parse_args()
def main():
    """Apply the instruction sheet to an OASST export and write the result.

    Reads all message trees from ``args.input_file_name``, then processes each
    row of the ``--instructions`` spreadsheet (columns ``UUID``, ``Action``,
    ``Category``, ``Find``, ``Replace``) and writes the surviving trees to
    ``args.output_file_name``.

    Supported actions:
      * ``Delete``  - remove the message and all of its replies; deleting a
        root prompt removes the whole tree.
      * ``Replace`` - overwrite the message text with ``Replace``.
      * ``Edit``    - replace every occurrence of ``Find`` with ``Replace``
        (category ``Copy Code`` uses a fixed find/replace pair instead).
    Instructions that cannot be applied are reported on stdout and skipped.
    """
    args = parse_args()
    # na_filter=False keeps empty cells as "" rather than NaN.
    instructions_df = pandas.read_excel(args.instructions, na_filter=False)

    # load dataset and index messages by id
    tree_by_id: dict[str, ExportMessageTree] = OrderedDict()
    message_by_id: dict[str, ExportMessageNode] = {}

    def index_message(msg: ExportMessageNode):
        message_by_id[msg.message_id] = msg

    print(f"Reading: {args.input_file_name}")
    for message_tree in read_message_trees(args.input_file_name):
        tree_by_id[message_tree.message_tree_id] = message_tree
        visit_messages_depth_first(message_tree.prompt, index_message)
    print(f"Loaded {len(tree_by_id)} trees with {len(message_by_id)} messages.")

    def unindex_subtree(msg: ExportMessageNode) -> int:
        """Drop msg and all of its descendants from message_by_id; return their count."""
        message_by_id.pop(msg.message_id, None)
        return 1 + sum(unindex_subtree(reply) for reply in msg.replies or [])

    def delete_message(msg: ExportMessageNode):
        """Detach msg (and its subtree) from the dataset and from the id index."""
        if msg.parent_id is None:
            # Relies on the tree id being the root prompt's message_id.
            # pop() with a default avoids a KeyError on duplicate Delete rows.
            if tree_by_id.pop(msg.message_id, None) is None:
                print(f"Tree not found: {msg.message_id}")
                return
            unindex_subtree(msg)
            print(f"Tree deleted: {msg.message_id}")
            return

        parent_msg = message_by_id.get(msg.parent_id)
        siblings = parent_msg.replies if parent_msg is not None and parent_msg.replies else []
        # Match by id instead of list.remove(), which compares whole subtrees for equality.
        for i, reply in enumerate(siblings):
            if reply.message_id == msg.message_id:
                del siblings[i]
                break
        else:
            print(f"Message not found: {msg.message_id}")
            return
        # Unindex so later instructions on this branch report "Not found"
        # instead of silently editing messages that will no longer be written.
        removed = unindex_subtree(msg)
        print(f"Branch deleted: {msg.message_id} ({removed} messages)")

    # cleaning
    print("Cleaning...")
    for _, row in instructions_df.iterrows():
        message_id = str(row["UUID"]).strip()
        print(f"Cleaning id={message_id}")
        msg = message_by_id.get(message_id)
        if msg is None:
            print(f"Not found: {message_id}")
            print(f"Skipping instructions for : {message_id}")
            continue

        action = row["Action"]
        print(f"Action={action}")
        if action == "Delete":
            print(f"deleting: {message_id}")
            delete_message(msg)
        elif action == "Replace":
            print(f"replace: {message_id}")
            # Excel may yield numbers; message text must stay a str.
            msg.text = str(row["Replace"])
        elif action == "Edit":
            print(f"edit: {message_id}")
            if row["Category"] == "Copy Code":
                find, replace = "\nCopy code\n", "\n\n"
            else:
                find, replace = str(row["Find"]), str(row["Replace"])
            if not find:
                # "" is always "found", and str.replace("", x) would insert x
                # between every character of the message.
                print(f"Empty find string, skipping edit: {message_id}")
            elif find not in msg.text:
                print(f"find not found: {find}")
            else:
                msg.text = msg.text.replace(find, replace)
        else:
            print(f"Unsupported action {action}")
    print("Done")

    # write cleaned dataset to output file
    print(f"Writing: {args.output_file_name}")
    write_message_trees(
        args.output_file_name,
        tree_by_id.values(),
        exclude_none=args.exclude_nulls,
    )
if __name__ == "__main__":
main()