-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCNN-visualiser.py
62 lines (55 loc) · 2.36 KB
/
CNN-visualiser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from keras.utils import plot_model
from keras.models import load_model
from vis.visualization.saliency import visualize_saliency
from vis.visualization.activation_maximization import visualize_activation
from matplotlib import pyplot
def generate_model_image(model_path, filename):
    """
    Render a diagram of a saved Keras model's layer graph to an image file.

    :param model_path: Path of the Keras model to be loaded. Expects string input.
    :param filename: Path of the image file to write (passed to ``plot_model`` as
        ``to_file``). Expects string input.
    :return: None
    """
    print("Loading model...")
    model = load_model(model_path)
    print("Model loaded. Plotting model...")
    # show_shapes=True asks plot_model to include layer shape information in the diagram.
    plot_model(model, to_file=filename, show_shapes=True)
    # Report the path that was actually written rather than a fixed file name.
    print("Plotting complete. File is ready at {}".format(filename))
def visualize_saliency_map(model_path, layer, image, sensitivity):
    """
    Compute a saliency map for ``image`` with keras-vis and display one slice of it.

    :param model_path: Path of the Keras model to be loaded. Expects string input.
    :param layer: Index of the layer the saliency is computed for (passed to
        ``visualize_saliency`` as ``layer_idx``). Expects integer input.
    :param image: Input passed to ``visualize_saliency`` as ``seed_input``.
        NOTE(review): presumably must match the model's input shape — confirm.
    :param sensitivity: Index into the last axis of the saliency output; selects
        which 2D slice is plotted. Expects integer input.
        NOTE(review): the valid range depends on how many channels the installed
        keras-vis version returns (see the printed shape) — confirm.
    :return: None
    """
    print("Loading model...")
    model = load_model(model_path)
    print("Model loaded.")
    print("Visualising saliency map...")
    # filter_indices=None: do not restrict the saliency to particular filters.
    saliency = visualize_saliency(model=model, layer_idx=layer, filter_indices=None, seed_input=image)
    print(saliency.shape)
    print("Saliency map visualised.")
    _, axis = pyplot.subplots(dpi=300)
    axis.set_aspect('equal', 'datalim')
    axis.pcolormesh(saliency[:, :, sensitivity])
    # Without explicit X/Y, pcolormesh places row 0 at the bottom of a y-up axis;
    # invert the y-axis so row 0 is at the top and the map matches image orientation.
    axis.invert_yaxis()
    pyplot.show()
def visualize_feature_maps(model_path, layer):
    """
    Run keras-vis activation maximization for a layer and display the result.

    Only index 0 of the last axis of the generated image is plotted.

    :param model_path: Path of the Keras model to be loaded. Expects string input.
    :param layer: Index of the layer to visualise (passed to ``visualize_activation``
        as ``layer_idx``). Expects integer input.
    :return: None
    """
    print("Loading model...")
    model = load_model(model_path)
    print("Model loaded.")
    print("Visualising filters...")
    activation = visualize_activation(model, layer_idx=layer)
    print(activation.shape)
    print("Filters visualised.")
    _, axis = pyplot.subplots(dpi=300)
    axis.set_aspect('equal', 'datalim')
    axis.pcolormesh(activation[:, :, 0])
    # Without explicit X/Y, pcolormesh places row 0 at the bottom of a y-up axis;
    # invert the y-axis so row 0 is at the top, matching image orientation.
    axis.invert_yaxis()
    pyplot.show()