"""
File: nsynth.py
Author: Kwon-Young Choi
Email: [email protected]
Date: 2018-11-13
Description: Load NSynth dataset using pytorch Dataset.
If you want to modify the output of the dataset, use the transform
and target_transform callbacks as ususal.
"""
import os
import json
import glob
import numpy as np
import scipy.io.wavfile
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder


class NSynth(data.Dataset):
    """PyTorch dataset for the NSynth dataset.

    Args:
        root: root dir containing examples.json and an audio directory
            with wav files.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it.
        blacklist_pattern: list of strings used to blacklist dataset
            elements. If one of the strings is present in the audio
            filename, the sample together with its metadata is removed
            from the dataset.
        categorical_field_list: list of strings. Each string is a key,
            such as instrument_family, that will be used as a
            classification target. Each field value is encoded as an
            integer using a sklearn LabelEncoder.
    """

    def __init__(self, root, transform=None, target_transform=None,
                 blacklist_pattern=[],
                 categorical_field_list=["instrument_family"]):
        """Constructor"""
        assert isinstance(root, str)
        assert isinstance(blacklist_pattern, list)
        assert isinstance(categorical_field_list, list)
        self.root = root
        self.filenames = glob.glob(os.path.join(root, "audio/*.wav"))
        with open(os.path.join(root, "examples.json"), "r") as f:
            self.json_data = json.load(f)
        for pattern in blacklist_pattern:
            self.filenames, self.json_data = self.blacklist(
                self.filenames, self.json_data, pattern)
        self.categorical_field_list = categorical_field_list
        self.le = []
        for i, field in enumerate(self.categorical_field_list):
            self.le.append(LabelEncoder())
            field_values = [value[field] for value in self.json_data.values()]
            self.le[i].fit(field_values)
        self.transform = transform
        self.target_transform = target_transform

    def blacklist(self, filenames, json_data, pattern):
        """Filter out filenames and metadata entries containing pattern."""
        filenames = [filename for filename in filenames
                     if pattern not in filename]
        json_data = {
            key: value for key, value in json_data.items()
            if pattern not in key
        }
        return filenames, json_data

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            list: [audio sample, *categorical targets, metadata dict]
        """
        name = self.filenames[index]
        _, sample = scipy.io.wavfile.read(name)
        target = self.json_data[os.path.splitext(os.path.basename(name))[0]]
        categorical_target = [
            le.transform([target[field]])[0]
            for field, le in zip(self.categorical_field_list, self.le)]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return [sample, *categorical_target, target]
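

# A hedged sketch, not part of the original class: "pitch" is assumed to be
# a key of each examples.json entry (it is part of the published NSynth
# metadata). A callable like this can be passed as target_transform so that
# __getitem__ returns just the MIDI pitch instead of the full metadata dict,
# e.g. NSynth("../nsynth-test", target_transform=pitch_only).
def pitch_only(metadata):
    """Example target_transform: keep only the MIDI pitch of a note."""
    return metadata["pitch"]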
if __name__ == "__main__":
# audio samples are loaded as an int16 numpy array
# rescale intensity range as float [-1, 1]
toFloat = transforms.Lambda(lambda x: x / np.iinfo(np.int16).max)
# use instrument_family and instrument_source as classification targets
dataset = NSynth(
"../nsynth-test",
transform=toFloat,
blacklist_pattern=["string"], # blacklist string instrument
categorical_field_list=["instrument_family", "instrument_source"])
loader = data.DataLoader(dataset, batch_size=32, shuffle=True)
for samples, instrument_family_target, instrument_source_target, targets \
in loader:
print(samples.shape, instrument_family_target.shape,
instrument_source_target.shape)
print(torch.min(samples), torch.max(samples))
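
    # A minimal sketch of decoding targets back to metadata values
    # (assumption: this runs after the loop above, so the last batch's
    # tensors are still in scope). The fitted LabelEncoders are kept in
    # dataset.le, so inverse_transform maps the integer targets back to the
    # original instrument_family values from examples.json.
    decoded = dataset.le[0].inverse_transform(
        instrument_family_target.numpy())
    print(decoded[:5])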