@@ -16,12 +16,12 @@ def load_dataset(filename):
16
16
with open (filename , 'rb' ) as f :
17
17
return pickle .load (f )
18
18
19
- def padded_batch_input (input , indices = None , dtype = K .floatx ()):
19
+ def padded_batch_input (input , indices = None , dtype = K .floatx (), maxlen = None ):
20
20
if indices is None :
21
21
indices = np .arange (len (input ))
22
22
23
23
batch_input = [input [i ] for i in indices ]
24
- return sequence .pad_sequences (batch_input , dtype = dtype , padding = 'post' )
24
+ return sequence .pad_sequences (batch_input , maxlen , dtype , padding = 'post' )
25
25
26
26
def categorical_batch_target (target , classes , indices = None , dtype = K .floatx ()):
27
27
if indices is None :
@@ -30,23 +30,24 @@ def categorical_batch_target(target, classes, indices=None, dtype=K.floatx()):
30
30
batch_target = [min (target [i ], classes - 1 ) for i in indices ]
31
31
return np_utils .to_categorical (batch_target , classes ).astype (dtype )
32
32
33
- # @np.vectorize
34
33
def lengthGroup (length ):
35
- if length < 240 :
36
- return 0
37
- if length < 380 :
38
- return 1
39
- if length < 520 :
40
- return 2
41
- if length < 660 :
42
- return 3
34
+ if length < 150 :
35
+ return 0
36
+ if length < 240 :
37
+ return 1
38
+ if length < 380 :
39
+ return 2
40
+ if length < 520 :
41
+ return 3
42
+ if length < 660 :
43
43
return 4
44
+ return 5
44
45
45
46
class BatchGen (object ):
46
47
def __init__ (self , inputs , targets = None , batch_size = None , stop = False ,
47
48
shuffle = True , balance = False , dtype = K .floatx (),
48
- flatten_targets = True , sort_by_length = False ,
49
- groupby = False ):
49
+ flatten_targets = False , sort_by_length = False ,
50
+ group = False , maxlen = None ):
50
51
assert len (set ([len (i ) for i in inputs ])) == 1
51
52
assert (not shuffle or not sort_by_length )
52
53
self .inputs = inputs
@@ -61,6 +62,10 @@ def __init__(self, inputs, targets=None, batch_size=None, stop=False,
61
62
self .balance = balance
62
63
self .targets = targets
63
64
self .flatten_targets = flatten_targets
65
+ if isinstance (maxlen , (list , tuple )):
66
+ self .maxlen = maxlen
67
+ else :
68
+ self .maxlen = [maxlen ] * len (inputs )
64
69
65
70
self .sort_by_length = None
66
71
if sort_by_length :
@@ -72,11 +77,9 @@ def __init__(self, inputs, targets=None, batch_size=None, stop=False,
72
77
self .generator = self ._generator ()
73
78
self ._steps = - (- self .nb_samples // self .batch_size ) # round up
74
79
75
- self .group_ids = [ 0 , 1 , 2 , 3 , 4 ] if groupby else None
76
- if groupby is not False :
80
+ self .groups = None
81
+ if group is not False :
77
82
indices = np .arange (self .nb_samples )
78
- # import ipdb
79
- # ipdb.set_trace()
80
83
81
84
ff = lambda i : lengthGroup (len (inputs [0 ][i ]))
82
85
@@ -85,8 +88,6 @@ def __init__(self, inputs, targets=None, batch_size=None, stop=False,
85
88
self .groups = itertools .groupby (indices , ff )
86
89
87
90
self .groups = {k : np .array (list (v )) for k , v in self .groups }
88
- # lengthGroups(np.array([len(_) for _ in inputs[0]]))
89
- # print(4)
90
91
91
92
def _generator (self ):
92
93
while True :
@@ -108,22 +109,17 @@ def _generator(self):
108
109
np .random .shuffle (v )
109
110
110
111
tmp = np .concatenate (self .groups .values ())
111
-
112
112
batches = np .array_split (tmp , self ._steps )
113
113
114
114
remainder = []
115
115
if len (batches [- 1 ]) < self ._steps :
116
116
remainder = batches [- 1 :]
117
117
batches = batches [:- 1 ]
118
118
119
- print ('------' , len (tmp ), self .batch_size , len (batches ))
120
119
shuffle (batches )
121
-
122
120
batches += remainder
123
-
124
121
permutation = np .concatenate (batches )
125
122
126
-
127
123
else :
128
124
permutation = np .arange (self .nb_samples )
129
125
@@ -142,11 +138,10 @@ def _generator(self):
142
138
# for i in range(0, self.nb_samples, self.batch_size):
143
139
# indices = permutation[i : i + self.batch_size]
144
140
145
- batch_X = [padded_batch_input (input , indices , self .dtype )
146
- for input in self .inputs ]
141
+ batch_X = [padded_batch_input (x , indices , self .dtype , maxlen )
142
+ for x , maxlen in zip ( self .inputs , self . maxlen ) ]
147
143
148
144
P = batch_X [0 ].shape [1 ]
149
- print ("[[[ {} {} {} ]]]]" .format (bs , P , lengthGroup (P )))
150
145
151
146
if not self .targets :
152
147
yield batch_X
@@ -182,7 +177,7 @@ def __next__(self):
182
177
def steps (self ):
183
178
if self .sort_by_length is None :
184
179
return self ._steps
185
-
180
+
186
181
print ("Steps was called" )
187
182
if self .shuffle :
188
183
permutation = np .random .permutation (self .nb_samples )
0 commit comments