-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCNN-single-demo.py
89 lines (73 loc) · 3.02 KB
/
CNN-single-demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from __future__ import print_function
from aura.aura_loader import read_file
from keras.models import load_model
from aura.decode import decode
from aura.decode import preprocess
from aura.decode import view_image as view
from aura.aura_loader import parse_aura_dimensions
from scipy.ndimage.interpolation import rotate
from sys import stderr
from time import sleep
import numpy as np
# Load the trained Keras model first; the two messages bracket the load so the
# user can see where the script is while it runs.
print("Loading model...")
model = load_model("Model-12.hf")
print("Model loaded.")

# The three demo datasets all live under one common root directory.
root = "../Aura_Data/Dataset/GVRSF-Demoset/"
cancer_path, healthy_path, btp_path = [
    root + file_name
    for file_name in (
        "{256x256x10861}Chunk0.aura",
        "{136x136x5493}HealthyTestset.aura",
        "{256x256x631}BTPTestset.aura",
    )
]

# parse_aura_dimensions yields three values per dataset. Presumably these are
# length, width and image count taken from the {LxWxN} filename prefix; the
# third is used as the image count further down. TODO confirm in aura_loader.
cl, cw, cn = parse_aura_dimensions(cancer_path)
hl, hw, hn = parse_aura_dimensions(healthy_path)
bl, bw, bn = parse_aura_dimensions(btp_path)
def query_user(question, n, min=0, reader=input):
    """
    Prompt on the console until the user enters an integer in [min, n].

    The prompt is re-issued, after an error message on stderr, for any reply
    that is not an integer or that falls outside the inclusive range.

    :param question: Prompt text shown to the user; the accepted range is
        appended to it as " (min-n): ".
    :param n: Inclusive upper bound.
    :param min: Inclusive lower bound. (The name shadows the builtin ``min``
        inside this function; it is kept for caller compatibility.)
    :param reader: Callable that takes the prompt string and returns the
        user's reply as a string. Defaults to the builtin ``input``; pass a
        different callable to script the prompt (e.g. in tests).
    :return: The accepted value as an ``int``.
    """
    user_question = question + " (" + str(min) + "-" + str(n) + "): "
    while True:
        reply = reader(user_question)
        # int() is the real validator: it tolerates surrounding whitespace and
        # rejects non-numeric text, unlike str.isdigit() which also accepts
        # characters such as '²' that int() cannot parse.
        try:
            value = int(reply)
        except ValueError:
            value = None
        # Both bounds are inclusive, and the caller's lower bound is honoured.
        if value is not None and min <= value <= n:
            return value
        stderr.write("\nPlease enter a number between " + str(min) + " and " + str(n) + "\n")
        # Short pause, presumably so the stderr message shows up before the
        # next prompt on stdout.
        sleep(0.01)
def get_most_confident_prediction(prediction):
    """
    Return the (label, confidence) pair with the highest confidence.

    :param prediction: Iterable of indexable items where ``item[0]`` is a
        label and ``item[1]`` a comparable confidence score (presumably the
        output of ``decode``; TODO confirm its exact structure).
    :return: ``(label, confidence)`` of the first item holding the maximum
        confidence, or ``("", 0)`` when ``prediction`` is empty.
    """
    # Materialise once so generators work and emptiness can be tested.
    items = list(prediction)
    if not items:
        return "", 0
    # max() returns the first maximal item, matching a strict ">" scan. Unlike
    # a running maximum seeded with 0, it still picks a label when every
    # confidence is <= 0.
    best = max(items, key=lambda item: item[1])
    return best[0], best[1]
# Pick one image from each test set. Set INTERACTIVE to True to choose the
# indices at the console (via query_user) instead of the fixed demo images.
INTERACTIVE = False
if INTERACTIVE:
    cancer_image_index = query_user("Choose image from cancerous test set", cn - 1)
    healthy_image_index = query_user("Choose image from healthy test set", hn - 1)
    btp_image_index = query_user("Choose image from another cancerous test set", bn - 1)
else:
    healthy_image_index = 3061
    cancer_image_index = 1974
    btp_image_index = 295

# Each read_file call loads a whole dataset. If it returns a NumPy array,
# ``.T[i]`` selects slice i along the original array's last axis. This assumes
# the last axis indexes images; TODO confirm against aura_loader.
imageHealthy = read_file(healthy_path).T[healthy_image_index]
imageCancer = read_file(cancer_path).T[cancer_image_index]
imageBTP = read_file(btp_path).T[btp_image_index]

print("Processing images...")
# The list order fixes the patient numbering printed in the results.
all_images = [imageHealthy, imageCancer, imageBTP]
# Show each raw image, then rotate it by 270 degrees (as float32) and run the
# project's preprocessing so it is ready for the model.
for index, image in enumerate(all_images):
    view(image)
    all_images[index] = preprocess(rotate(image.astype(np.float32), 270))
print("Images processed.")

print("Analysing images...")
# One decoded prediction per preprocessed image, in the same order.
all_predictions = [decode(model.predict(image)) for image in all_images]
print("Images analysed. Processing results...")

print("\n---------------------RESULTS---------------------")
for i, prediction in enumerate(all_predictions):
    print("Patient " + str(i) + " is/has " + get_most_confident_prediction(prediction)[0])