-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
52 lines (44 loc) · 1.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
import model
# Lower TensorFlow's log threshold to INFO so INFO-severity messages
# (e.g. Estimator training progress and LoggingTensorHook output) are shown.
tf.logging.set_verbosity(tf.logging.INFO)
def main(unused_argv):
    """Train ``model.cnn_model_fn`` on the local IDX training files.

    Reads labels and images from ``source/``, flattens each image into one
    float32 row with every value divided by 255.0, and runs 1000 training
    steps of a ``tf.estimator.Estimator`` whose model directory is ``net``.

    Args:
        unused_argv: Command-line arguments forwarded by ``tf.app.run``;
            ignored.

    Raises:
        ValueError: If the image file decodes to no image data.
    """
    # Load training and eval data
    print('Reading labels')
    with open('source/train-labels.idx1-ubyte', 'rb') as fd:
        # NOTE(review): labels keep whatever dtype parse_idx yields; confirm
        # cnn_model_fn accepts it (TF losses often expect int32/int64).
        train_labels = np.asarray(model.parse_idx(fd))
    print('Reading images')
    with open('source/train-images.idx3-ubyte', 'rb') as fd:
        # Materialise while the file is still open, in case parse_idx reads
        # from fd lazily.
        images = np.asarray(model.parse_idx(fd), dtype=np.float64)
    if images.size == 0:
        raise ValueError(
            'no image data decoded from source/train-images.idx3-ubyte')
    # Flatten each image to a single row and divide by 255.0 (presumably
    # uint8 pixels -> [0, 1]). Dividing in float64 before the float32 cast
    # reproduces the per-pixel float(value) / 255.0 result exactly, but runs
    # in NumPy instead of a nested per-pixel Python loop.
    train_data = (images.reshape(len(images), -1) / 255.0).astype(np.float32)

    mnist_classifier = tf.estimator.Estimator(
        model_fn=model.cnn_model_fn, model_dir="net")
    # "softmax_tensor" is a graph tensor name, presumably created inside
    # model.cnn_model_fn (not visible here); the hook logs it every 50 steps.
    tensors_to_log = {"probabilities": "softmax_tensor"}
    logging_hook = tf.train.LoggingTensorHook(
        tensors=tensors_to_log,
        every_n_iter=50,
    )
    # Train the model
    train_input_fn = tf.estimator.inputs.numpy_input_fn(
        # The feature key "x" must match what cnn_model_fn reads -- confirm.
        x={"x": train_data},
        y=train_labels,
        batch_size=100,
        num_epochs=None,  # repeat indefinitely; `steps` below bounds training
        shuffle=True,
    )
    mnist_classifier.train(
        input_fn=train_input_fn,
        steps=1000,
        hooks=[logging_hook],
    )
if __name__ == "__main__":
    # tf.app.run parses command-line flags, then invokes main(argv); naming
    # main explicitly is what the default (__main__.main) lookup resolves to.
    tf.app.run(main=main)