-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathreport.py
54 lines (47 loc) · 1.88 KB
/
report.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from sklearn import metrics
# def get_report(result_file: str, label_file: str):
# with open(label_file, "r", encoding="utf-8") as fr:
# labels = [line.strip() for line in fr]
# with open(result_file, "r", encoding="utf-8") as fr:
# fr = [line.strip().rsplit("\t", 1) for line in fr]
# y_predict = []
# y_label = []
# for line in fr:
# raw_data, y = line
# label = raw_data.split("_!_")[2]
# y_predict.append(labels.index(y))
# y_label.append(labels.index(label))
# print(metrics.classification_report(y_label, y_predict, target_names=labels))
from sklearn import metrics
def get_report(result_file: str, label_file: str) -> str:
    """Print (and return) a per-class evaluation report for text-classification results.

    If your files use a different format, adapt the parsing below and pass the
    gold and predicted label indices to ``metrics.classification_report``.

    :param result_file: classification-result file produced after training. Each
        non-blank line is ``<raw_data>\\t<predicted_label>`` (split on the LAST tab);
        ``raw_data`` is a ``_!_``-separated record whose third field is the gold label.
    :param label_file: path to ``labels.txt`` in the saved-model directory, one
        label name per line (blank lines are ignored).
    :return: the report text produced by ``metrics.classification_report``; it is
        also printed to stdout.
    :raises ValueError: if the result file has no usable lines, a line is
        malformed, or a gold/predicted label is not listed in ``label_file``.
    """
    with open(label_file, "r", encoding="utf-8") as label_fh:
        # Skip blank lines (e.g. a trailing newline) so they do not become a bogus
        # "" class that makes target_names longer than the real class list.
        labels = [name for name in (line.strip() for line in label_fh) if name]

    # O(1) name -> index lookup instead of a list.index() scan per result line.
    # setdefault keeps the FIRST index for a duplicated name, matching list.index().
    label_to_index = {}
    for index, name in enumerate(labels):
        label_to_index.setdefault(name, index)

    y_label = []
    y_predict = []
    with open(result_file, "r", encoding="utf-8") as result_fh:
        for line_no, line in enumerate(result_fh, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing lines
            parts = line.rsplit("\t", 1)
            if len(parts) != 2:
                raise ValueError(
                    f"{result_file}:{line_no}: expected '<raw_data>\\t<label>', got {line!r}"
                )
            raw_data, predicted = parts
            fields = raw_data.split("_!_")
            if len(fields) < 3:
                raise ValueError(
                    f"{result_file}:{line_no}: expected at least 3 '_!_'-separated "
                    f"fields in {raw_data!r}"
                )
            gold = fields[2].strip()
            predicted = predicted.strip()  # str.strip() already removes tabs
            for name in (gold, predicted):
                if name not in label_to_index:
                    raise ValueError(
                        f"{result_file}:{line_no}: label {name!r} not found in {label_file}"
                    )
            y_label.append(label_to_index[gold])
            y_predict.append(label_to_index[predicted])

    if not y_label:
        raise ValueError(f"{result_file} contains no result lines")

    report = metrics.classification_report(
        y_label,
        y_predict,
        # Explicit class indices keep target_names aligned even when some classes
        # never occur in this run; without this sklearn raises a size-mismatch error.
        labels=list(range(len(labels))),
        target_names=labels,
    )
    print(report)
    return report
if __name__ == "__main__":
    # Script entry point: evaluate the default result file against the saved model's labels.
    get_report(
        result_file="resources/cropus/data_result.txt",
        label_file="resources/model/new_model/labels.txt",
    )