-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathsplit-data.py
44 lines (40 loc) · 1.57 KB
/
split-data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
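Script to split a CSV dataset into train and test CSV files.

Example usage:
    $ python3 split-data.py dataset.csv
    $ python3 split-data.py dataset.csv <split percentage float>

The optional split percentage is the fraction of lines (e.g. 0.1) written to
the test file; the remaining lines go to the train file. It defaults to 0.1.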
"""
import sys
from preprocess import write_status
def split(filename, split_percentage=0.1):
    """Split ``<filename>.csv`` into a train file and a test file.

    Lines are copied verbatim and in their original order (no shuffling, no
    special treatment of a header row). The first
    ``int(total - split_percentage * total)`` lines are written to
    ``<filename>-train.csv``; every remaining line is written to
    ``<filename>-test.csv``.

    Args:
        filename: Path of the input CSV *without* the ``.csv`` extension.
        split_percentage: Fraction of the lines that goes to the test file
            (0.1 means 10% test, 90% train).

    Raises:
        OSError: If the input cannot be read or an output cannot be written.
    """
    # Report real percentages; %g keeps the output free of float noise
    # such as 7.000000000000001.
    print("Splitting data in %g%% train and %g%% test" % (
        (1 - split_percentage) * 100, split_percentage * 100))

    # Read the input before creating any output so that a missing or
    # unreadable input does not leave empty train/test files behind.
    with open("%s.csv" % filename, "r", encoding="utf-8") as source:
        lines = source.readlines()
    total = len(lines)
    split_index = int(total - split_percentage * total)

    # Context managers guarantee both outputs are flushed and closed even if
    # a write or the progress callback raises.
    with open("%s-train.csv" % filename, "w", encoding="utf-8") as save_train, \
            open("%s-test.csv" % filename, "w", encoding="utf-8") as save_test:
        for i, line in enumerate(lines):
            # Equivalent to lines[:split_index] / lines[split_index:]; the
            # train side must not receive an extra line.
            if i < split_index:
                save_train.write(line)
            else:
                save_test.write(line)
            write_status(i + 1, total)

    print("\nData successfully split and saved in %s-train.csv and "
          "%s-test.csv" % (filename, filename))
if __name__ == '__main__':
    # Command-line entry point:
    #   argv[1] - path to the input CSV (with or without its extension)
    #   argv[2] - optional fraction of lines for the test file
    # Local import: only the entry point needs path handling.
    import os

    if len(sys.argv) < 2:
        print('Usage: python3 split-data.py <CSV>')
        # sys.exit is always available (the bare exit() helper comes from
        # the site module); a usage error is reported with a non-zero status.
        sys.exit(1)

    # split() appends ".csv" itself, so strip only the final extension.
    # Splitting on the first "." would mangle paths such as "./data/x.csv"
    # or names such as "data.v2.csv".
    csv_file_name = os.path.splitext(sys.argv[1])[0]
    print(csv_file_name)

    # Use the explicit percentage whenever one is supplied; an empty or
    # missing second argument falls back to split()'s default instead of
    # silently doing nothing.
    if len(sys.argv) >= 3 and sys.argv[2]:
        split(csv_file_name, float(sys.argv[2]))
    else:
        split(csv_file_name)