forked from Dakewe-DS1000/Inception-Camelyon17
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
executable file
·94 lines (74 loc) · 2.63 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import sys
import argparse
import numpy as np
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import load_model
from keras.applications.inception_v3 import preprocess_input
# Input resolution fixed by the InceptionV3 architecture (width, height).
target_size = (299, 299)

# Names of the classes used by this project.
# NOTE(review): "tumer" looks like a typo for "tumor"; kept unchanged because
# other code may match on this exact string.
labels = ("background", "normal", "tumer")

# Path to the fine-tuned model file loaded by main().
model_dir = r"D:\Inception-Camelyon17\modules\fine_tune_model.h5"

# Path to the test label text file ("<image path> <class index>" per line).
test_label_dir = r"F:\ai_data\camelyon17\research_data\test_label.txt"

# Number of test images (lines of the label file) to evaluate.
test_file_number = 260
def predict(model, img):
    """Run the model on a single PIL image and return its output vector.

    The image is converted to 3-channel RGB when needed, resized to
    ``target_size``, given a leading batch dimension, passed through
    InceptionV3's ``preprocess_input`` and then through ``model.predict``.

    Arguments:
        model {keras model} -- model loaded from training
        img {PIL.Image.Image} -- image to classify

    Returns:
        The first (and only) row of ``model.predict``'s output, i.e. one
        score per output class.
    """
    # Greyscale ("L"), palette ("P") or RGBA images would otherwise yield an
    # array with 1 or 4 channels; the fine-tuned InceptionV3 model presumably
    # expects 3 (RGB) channels -- confirm against the training pipeline.
    if img.mode != "RGB":
        img = img.convert("RGB")
    # PIL reports size as (width, height); target_size is square here.
    if img.size != target_size:
        img = img.resize(target_size)
    x = image.img_to_array(img)
    # model.predict works on batches, so wrap the single image in a batch of 1.
    x = np.expand_dims(x, axis=0)
    x = preprocess_input(x)
    preds = model.predict(x)
    return preds[0]
def getFileNames(label_dir, limit=None):
    """Read image paths and class labels from a test label file.

    Each line is expected to hold an image path followed by its class index,
    separated by whitespace, e.g. ``path/to/img.png 1``.

    Arguments:
        label_dir {str} -- path to the label text file
        limit {int or None} -- number of lines to read; ``None`` (default)
            means the module-level ``test_file_number``

    Returns:
        (file_name, class_idx) -- two lists of ``limit`` strings holding the
            image paths and the (still textual) class indices.

    Raises:
        IndexError -- if the file has fewer than ``limit`` lines, or a line
            has fewer than two fields.
    """
    if limit is None:
        limit = test_file_number
    # The context manager guarantees the label file is closed.
    with open(label_dir, "r") as label_file:
        label_text = label_file.readlines()
    if len(label_text) < limit:
        raise IndexError(
            "{0} has {1} lines, expected at least {2}".format(
                label_dir, len(label_text), limit))
    file_name = []
    class_idx = []
    for context in label_text[:limit]:
        # split() tolerates repeated spaces/tabs and drops the line ending.
        fields = context.split()
        file_name.append(fields[0])
        class_idx.append(fields[1])
    return file_name, class_idx
def main():
    """Evaluate the fine-tuned model on the labelled test images.

    Loads the model from ``model_dir`` and reads (image path, class index)
    pairs via ``getFileNames(test_label_dir)``. It prints one ``Error`` line
    per misclassified image and finally the overall accuracy in percent.
    """
    model = load_model(model_dir)
    # Image paths and their expected class indices (as strings).
    file_names, classes_idx = getFileNames(test_label_dir)
    total = len(file_names)
    if total == 0:
        print("No test images listed in {0}".format(test_label_dir))
        return
    positive_num = 0
    negative_num = 0
    for img_file_name, class_name in zip(file_names, classes_idx):
        # Context manager releases the image file handle deterministically.
        with Image.open(img_file_name) as img:
            preds = predict(model, img)
        # argmax finds the top-scoring class without assuming scores are
        # positive (the previous max search was seeded with 0).
        pred_pos = int(np.argmax(preds))
        pred_max = preds[pred_pos]
        if pred_pos != int(class_name):
            negative_num += 1
            # NOTE(review): pred_max is printed unscaled followed by "%"; if
            # the model ends in softmax this is a fraction in [0, 1], not a
            # percentage -- confirm and scale by 100 if so.
            print("Error : {0} : {1} ( {2} ) :: {3:3.2f}%".format(img_file_name, pred_pos, class_name, pred_max))
        else:
            positive_num += 1
    accuracy_positive = float(positive_num) / float(total) * 100.0
    print("accuracy : {0:3.2f}[%]".format(accuracy_positive))
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()