-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdemo_mnist_tf.py
61 lines (45 loc) · 1.84 KB
/
demo_mnist_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import pickle
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# Seed NumPy's global RNG so the NumPy-driven randomness in this script is
# repeatable.
# NOTE(review): TensorFlow/Keras weight initialisation is NOT seeded here, so
# training results can still vary between runs.
np.random.seed(0)
# Load the train/test splits from a local .npz archive.
# allow_pickle is left at its default (False): it is only needed for object
# arrays, and enabling it lets a crafted archive execute pickled code.
# NOTE(review): assumes the archive stores plain numeric arrays, as the
# standard Keras mnist.npz does. Confirm if a different file is used.
with np.load("data/mnist/mnist.npz") as f:
    x_train, y_train = f["x_train"], f["y_train"]  # 60000 samples
    x_test, y_test = f["x_test"], f["y_test"]  # 10000 samples
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
# Map pixel values to [-1, 1] (assuming 8-bit pixels in [0, 255]). The
# explicit float32 cast matches Keras' default float type and uses half the
# memory of the float64 arrays NumPy would otherwise produce.
x_train = (x_train.astype("float32") - 127.5) / 127.5
x_test = (x_test.astype("float32") - 127.5) / 127.5
# Flatten each 28x28 image into a 784-long vector for the dense layers.
x_train = x_train.reshape(-1, 28 * 28)
x_test = x_test.reshape(-1, 28 * 28)
# One-hot encode the digit labels. Fixing num_classes=10 guarantees both
# splits get the same width; without it the width is inferred from each
# split's largest label.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
# Fully connected classifier: input -> 784 -> 500 -> 200 ReLU units,
# then a 10-way softmax over the digit classes.
model = tf.keras.Sequential([
    # Declaring the input with tf.keras.Input replaces the deprecated
    # input_shape= keyword on the first Dense layer.
    tf.keras.Input(shape=(x_train.shape[1],)),
    tf.keras.layers.Dense(784, activation="relu"),
    tf.keras.layers.Dense(500, activation="relu"),
    tf.keras.layers.Dense(200, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])
# "rmsprop" is the optimizer Keras falls back to when none is given. Naming
# it keeps behaviour the same while making the choice explicit.
model.compile(
    optimizer="rmsprop",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)
# summary() prints the layer table itself and returns None; wrapping it in
# print() only appended a stray "None" line.
model.summary()
# Train for 10 epochs, reporting metrics on the test split after each epoch.
# NOTE(review): the test split serves both as validation data here and as the
# final evaluation set below, so the final score is not from unseen data.
model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    validation_data=(x_test, y_test),
)

# Loss followed by the compiled metrics (accuracy), computed on the test split.
score = model.evaluate(x=x_test, y=y_test)
print(score)

# Persist the trained network in HDF5 format. The original note "Netron 13h45"
# presumably refers to inspecting this file in the Netron viewer.
model.save("data/mnist/mnist.h5")

# One softmax output row per test image.
predicted = model.predict(x=x_test)
print(predicted)
# Error analysis: collect the test images the model got wrong.
predicted_labels = predicted.argmax(axis=1)
true_labels = y_test.argmax(axis=1)
misclass = true_labels != predicted_labels
# Use a separate reshaped view so x_test keeps its flattened shape.
misclass_images = x_test.reshape((-1, 28, 28))[misclass]
misclass_predicted = predicted_labels[misclass]

n_misclass = misclass_images.shape[0]
if n_misclass == 0:
    # np.random.randint(0, ...) would raise here, so handle this case explicitly.
    print("No misclassified test images to display.")
else:
    # Pick up to 12 distinct misclassified images. Sampling without
    # replacement means the grid never shows the same image twice.
    select = np.random.choice(n_misclass, size=min(12, n_misclass), replace=False)
    # Show each selected image, titled with its (wrong) predicted label.
    for index, value in enumerate(select):
        plt.subplot(3, 4, index + 1)
        plt.axis("off")
        plt.imshow(misclass_images[value], cmap=plt.cm.gray_r, interpolation="nearest")
        plt.title("Predicted: %i" % misclass_predicted[value])
    # Display the whole grid once, after all panels are drawn.
    plt.show()