10
10
class MnistGenerator :
11
11
def __init__ (self , im = None , om = None ):
12
12
self ._im , self ._om = im , om
13
+ self ._cursor = self ._indices = None
13
14
self ._x , self ._y = DataUtil .get_dataset ("mnist" , "../../_Data/mnist.txt" , quantized = True , one_hot = True )
14
15
self ._x = self ._x .reshape (- 1 , 28 , 28 )
15
16
self ._x_train , self ._x_test = self ._x [:1800 ], self ._x [1800 :]
16
17
self ._y_train , self ._y_test = self ._y [:1800 ], self ._y [1800 :]
17
18
19
+ def refresh (self ):
20
+ self ._cursor = 0
21
+ self ._indices = np .random .permutation (len (self ._x_train ))
22
+
18
23
def gen (self , batch , test = False ):
19
24
if batch == 0 :
20
25
if test :
21
26
return self ._x_test , self ._y_test
22
27
return self ._x_train , self ._y_train
23
- batch = np .random .choice (len (self ._x_train ), batch )
24
- return self ._x_train [batch ], self ._y_train [batch ]
28
+ end = min (self ._cursor + batch , len (self ._x_train ))
29
+ start , self ._cursor = self ._cursor , end
30
+ if start == end :
31
+ self .refresh ()
32
+ end = batch
33
+ start = self ._cursor = 0
34
+ indices = self ._indices [start :end ]
35
+ return self ._x_train [indices ], self ._y_train [indices ]
25
36
26
37
if __name__ == '__main__' :
27
38
generator = MnistGenerator ()
@@ -46,8 +57,8 @@ def gen(self, batch, test=False):
46
57
net = tflearn .input_data (shape = [None , 28 , 28 ])
47
58
net = tf .concat (tflearn .lstm (net , 128 , return_seq = True )[- 3 :], axis = 1 )
48
59
net = tflearn .fully_connected (net , 10 , activation = 'softmax' )
49
- net = tflearn .regression (net , optimizer = 'adam' , batch_size = 64 ,
60
+ net = tflearn .regression (net , optimizer = 'adam' , batch_size = 64 , learning_rate = 0.001 ,
50
61
loss = 'categorical_crossentropy' )
51
62
model = tflearn .DNN (net , tensorboard_verbose = 0 )
52
- model .fit (* generator .gen (0 ), n_epoch = 10 , validation_set = generator .gen (0 , True ), show_metric = True )
63
+ model .fit (* generator .gen (0 ), validation_set = generator .gen (0 , True ), show_metric = True )
53
64
print ("Time Cost: {}" .format (time .time () - t ))
0 commit comments