-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathfrom_tfrecord.py
48 lines (43 loc) · 1.49 KB
/
from_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""
File: from_tfrecord.py
Author: Kwon-Young Choi
Email: [email protected]
Date: 2018-11-12
Description: Read the NSynth dataset from a TFRecord file.
"""
import tensorflow as tf
import autodebug
def parser(serialized_example):
    """Decode one serialized example into its named features.

    The feature spec declares seven scalar int64 fields ('note',
    'instrument', 'pitch', 'velocity', 'sample_rate', 'instrument_family',
    'instrument_source') and two sequence fields declared with
    ``allow_missing=True``: 'audio' (float32) and 'qualities' (int64).

    Args:
        serialized_example: the serialized record handed over by
            ``dataset.map``.

    Returns:
        Whatever ``tf.parse_single_example`` returns for this spec
        (consumed downstream as a mapping keyed by feature name).
    """
    def scalar_int64():
        # One int64 value per example.
        return tf.FixedLenFeature([], tf.int64)

    def sequence_of(dtype):
        # Sequence of scalars of the given dtype; absence is tolerated.
        return tf.FixedLenSequenceFeature(
            shape=[], dtype=dtype, allow_missing=True)

    feature_spec = {
        'note': scalar_int64(),
        'instrument': scalar_int64(),
        'pitch': scalar_int64(),
        'velocity': scalar_int64(),
        'sample_rate': scalar_int64(),
        'audio': sequence_of(tf.float32),
        'qualities': sequence_of(tf.int64),
        'instrument_family': scalar_int64(),
        'instrument_source': scalar_int64(),
    }
    return tf.parse_single_example(serialized_example, features=feature_spec)
# Input pipeline: read the serialized records, decode each one with
# `parser`, and group the decoded examples into batches of 32.
data_path = ['data/nsynth-test.tfrecord']
dataset = tf.data.TFRecordDataset(data_path).map(parser).batch(32)
iterator = dataset.make_one_shot_iterator()
batch_notes = iterator.get_next()

with tf.Session() as sess:
    cpt = 0  # index of the batch about to be fetched
    while True:
        print(cpt)
        try:
            out = sess.run(batch_notes)
        except tf.errors.OutOfRangeError:
            # The iterator has no further batches; stop looping.
            break
        # Report the dtype and shape of every feature in this batch.
        for name, array in out.items():
            print(name, array.dtype, array.shape)
        cpt += 1