-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval2.py
42 lines (33 loc) · 1.33 KB
/
eval2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import model_q2
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
import numpy as np
img_width, img_height = 224, 224
# Alternative input set used earlier: '/data/datasets/rbonatti/data_processed/2/valid'
valid_data_dir = '/data/datasets/rbonatti/data_processed/data_processed_test/test2'
batch_size = 1
# Expected number of test images. Kept for compatibility; the script now
# predicts exactly as many images as the directory holds and warns when
# that count differs from this value.
val_samples = 10000


def predictions_to_labels(predictions):
    """Convert per-class model outputs into 1-based integer class labels.

    Parameters
    ----------
    predictions : array-like
        2-D array with one row of per-class scores per image, as returned by
        ``predict_generator``.

    Returns
    -------
    numpy.ndarray
        1-D integer array holding, for each row, the index of the highest
        score plus one (so the first class is labelled ``1``).
    """
    return np.argmax(np.asarray(predictions), axis=1).astype(int) + 1


if __name__ == "__main__":
    network = model_q2.VGG_16('/data/datasets/rbonatti/vgg16_weights_with_name.h5')
    network.load_weights('/data/datasets/rbonatti/ml_weights2/weights.25-2.47.hdf5')
    # The optimizer settings do not affect inference; compilation is kept
    # because some Keras versions refuse to predict with an uncompiled
    # Sequential model.
    adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0)
    network.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])

    # Evaluation only rescales pixel values; no augmentation is applied.
    datagen = ImageDataGenerator(rescale=1. / 255)
    valid_generator = datagen.flow_from_directory(
        valid_data_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode=None,        # unlabeled: the generator yields images only
        color_mode='grayscale',
        shuffle=False)          # fixed order so output rows follow file order

    # Predict exactly one row per image found. A fixed sample count that
    # differs from the directory size would make the generator wrap around
    # (duplicated rows) or stop early (missing rows).
    n_samples = len(valid_generator.filenames)
    if n_samples == 0:
        raise SystemExit('No images found under %s' % valid_data_dir)
    if n_samples != val_samples:
        print('Warning: found %d images in %s but val_samples=%d; predicting %d.'
              % (n_samples, valid_data_dir, val_samples, n_samples))

    predictions = network.predict_generator(
        generator=valid_generator,
        val_samples=n_samples
    )
    labels = predictions_to_labels(predictions)
    # fmt='%d' writes plain integers; numpy's default '%.18e' would write
    # each label as a float in scientific notation.
    np.savetxt('/data/datasets/rbonatti/ml_prediction_q2_test2.out', labels,
               fmt='%d', delimiter=',')