Skip to content

Commit a3f4407

Browse files
committed
Clean unnecessary comments
1 parent e959e62 commit a3f4407

File tree

1 file changed

+23
-28
lines changed

1 file changed

+23
-28
lines changed

data.py

+23-28
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def load_dataset(filename):
1616
with open(filename, 'rb') as f:
1717
return pickle.load(f)
1818

19-
def padded_batch_input(input, indices=None, dtype=K.floatx(), maxlen=None):
    """Select sequences from ``input`` and pad them into a single batch.

    Args:
        input: Indexable collection of sequences. (The name shadows the
            builtin ``input``; it is kept because callers may pass it by
            keyword.)
        indices: Positions in ``input`` to include, in the given order.
            ``None`` selects every element in its original order.
        dtype: Forwarded to ``sequence.pad_sequences``. The default is
            ``K.floatx()`` evaluated once, when this function is defined.
        maxlen: Forwarded to ``sequence.pad_sequences``; ``None`` leaves the
            padded length up to that function.

    Returns:
        Whatever ``sequence.pad_sequences`` returns for the selected
        sequences with ``padding='post'``.
    """
    if indices is None:
        indices = np.arange(len(input))

    batch_input = [input[i] for i in indices]
    # Pass maxlen/dtype by keyword so the call does not depend on the
    # positional parameter order of pad_sequences.
    # NOTE(review): `sequence` is presumably keras.preprocessing.sequence --
    # confirm against the file's imports.
    return sequence.pad_sequences(batch_input, maxlen=maxlen, dtype=dtype,
                                  padding='post')
2525

2626
def categorical_batch_target(target, classes, indices=None, dtype=K.floatx()):
2727
if indices is None:
@@ -30,23 +30,24 @@ def categorical_batch_target(target, classes, indices=None, dtype=K.floatx()):
3030
batch_target = [min(target[i], classes-1) for i in indices]
3131
return np_utils.to_categorical(batch_target, classes).astype(dtype)
3232

33-
# @np.vectorize
3433
def lengthGroup(length, boundaries=(150, 240, 380, 520, 660)):
    """Return the index of the length bucket that ``length`` falls into.

    Bucket ``i`` holds every length that is strictly below ``boundaries[i]``
    and not below any earlier boundary. Lengths that are not below any
    boundary land in the last bucket, ``len(boundaries)``.

    With the default boundaries the mapping is:
    ``<150 -> 0``, ``<240 -> 1``, ``<380 -> 2``, ``<520 -> 3``,
    ``<660 -> 4``, otherwise ``5``.

    Args:
        length: The length to classify (anything comparable with ``<``).
        boundaries: Exclusive upper bounds of the buckets. They are checked
            in the order given, so they should be ascending.

    Returns:
        int: The bucket index, between ``0`` and ``len(boundaries)``.
    """
    for group, upper_bound in enumerate(boundaries):
        if length < upper_bound:
            return group
    # Not below any boundary: overflow bucket.
    return len(boundaries)
4445

4546
class BatchGen(object):
4647
def __init__(self, inputs, targets=None, batch_size=None, stop=False,
4748
shuffle=True, balance=False, dtype=K.floatx(),
48-
flatten_targets=True, sort_by_length=False,
49-
groupby=False):
49+
flatten_targets=False, sort_by_length=False,
50+
group=False, maxlen=None):
5051
assert len(set([len(i) for i in inputs])) == 1
5152
assert(not shuffle or not sort_by_length)
5253
self.inputs = inputs
@@ -61,6 +62,10 @@ def __init__(self, inputs, targets=None, batch_size=None, stop=False,
6162
self.balance = balance
6263
self.targets = targets
6364
self.flatten_targets = flatten_targets
65+
if isinstance(maxlen, (list, tuple)):
66+
self.maxlen = maxlen
67+
else:
68+
self.maxlen = [maxlen] * len(inputs)
6469

6570
self.sort_by_length = None
6671
if sort_by_length:
@@ -72,11 +77,9 @@ def __init__(self, inputs, targets=None, batch_size=None, stop=False,
7277
self.generator = self._generator()
7378
self._steps = -(-self.nb_samples // self.batch_size) # round up
7479

75-
self.group_ids = [0, 1, 2, 3, 4] if groupby else None
76-
if groupby is not False:
80+
self.groups = None
81+
if group is not False:
7782
indices = np.arange(self.nb_samples)
78-
# import ipdb
79-
# ipdb.set_trace()
8083

8184
ff = lambda i: lengthGroup(len(inputs[0][i]))
8285

@@ -85,8 +88,6 @@ def __init__(self, inputs, targets=None, batch_size=None, stop=False,
8588
self.groups = itertools.groupby(indices, ff)
8689

8790
self.groups = {k: np.array(list(v)) for k, v in self.groups}
88-
# lengthGroups(np.array([len(_) for _ in inputs[0]]))
89-
# print(4)
9091

9192
def _generator(self):
9293
while True:
@@ -108,22 +109,17 @@ def _generator(self):
108109
np.random.shuffle(v)
109110

110111
tmp = np.concatenate(self.groups.values())
111-
112112
batches = np.array_split(tmp, self._steps)
113113

114114
remainder = []
115115
if len(batches[-1]) < self._steps:
116116
remainder = batches[-1:]
117117
batches = batches[:-1]
118118

119-
print('------', len(tmp), self.batch_size, len(batches))
120119
shuffle(batches)
121-
122120
batches += remainder
123-
124121
permutation = np.concatenate(batches)
125122

126-
127123
else:
128124
permutation = np.arange(self.nb_samples)
129125

@@ -142,11 +138,10 @@ def _generator(self):
142138
# for i in range(0, self.nb_samples, self.batch_size):
143139
# indices = permutation[i : i + self.batch_size]
144140

145-
batch_X = [padded_batch_input(input, indices, self.dtype)
146-
for input in self.inputs]
141+
batch_X = [padded_batch_input(x, indices, self.dtype, maxlen)
142+
for x, maxlen in zip(self.inputs, self.maxlen)]
147143

148144
P = batch_X[0].shape[1]
149-
print("[[[ {} {} {} ]]]]".format(bs, P, lengthGroup(P)))
150145

151146
if not self.targets:
152147
yield batch_X
@@ -182,7 +177,7 @@ def __next__(self):
182177
def steps(self):
183178
if self.sort_by_length is None:
184179
return self._steps
185-
180+
186181
print("Steps was called")
187182
if self.shuffle:
188183
permutation = np.random.permutation(self.nb_samples)

0 commit comments

Comments
 (0)