Commit 1ca4897

init repo
0 parents  commit 1ca4897

12 files changed: 1019 additions and 0 deletions

.gitignore

+6
@@ -0,0 +1,6 @@
*.wav
*.mat
egs.py
__pycache__/
data/
.vscode/

compute_cmvn.py

+45
@@ -0,0 +1,45 @@
#!/usr/bin/env python
# coding=utf-8

# wujian@2018

import argparse
import pickle
import tqdm
import numpy as np

from dataset import SpectrogramReader
from utils import parse_yaml


def run(args):
    num_bins, conf_dict = parse_yaml(args.train_conf)
    reader = SpectrogramReader(args.wave_scp, **conf_dict["spectrogram_reader"])
    mean = np.zeros(num_bins)
    std = np.zeros(num_bins)
    num_frames = 0
    # accumulate sums, then use D(X) = E(X^2) - E(X)^2
    for _, spectrogram in tqdm.tqdm(reader):
        num_frames += spectrogram.shape[0]
        mean += np.sum(spectrogram, 0)
        std += np.sum(spectrogram**2, 0)
    mean = mean / num_frames
    std = np.sqrt(std / num_frames - mean**2)
    with open(args.cmvn_dst, "wb") as f:
        cmvn_dict = {"mean": mean, "std": std}
        pickle.dump(cmvn_dict, f)
    print("Processed {} frames in total".format(num_frames))
    print("Global mean: {}".format(mean))
    print("Global std: {}".format(std))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Command to compute global cmvn stats")
    parser.add_argument(
        "wave_scp", type=str, help="Location of mixture wave scripts")
    parser.add_argument(
        "train_conf", type=str, help="Location of training configuration file")
    parser.add_argument(
        "cmvn_dst", type=str, help="Location to dump cmvn stats")
    args = parser.parse_args()
    run(args)
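
The pickled {"mean", "std"} dict is later consumed in dataset.py through utils.apply_cmvn, which is not part of this commit. A minimal sketch of what that helper plausibly does with these stats (an assumption, not the committed implementation):

def apply_cmvn(spectrogram, cmvn_dict):
    # spectrogram: T x F matrix; cmvn_dict holds the F-dim "mean"/"std"
    # arrays dumped above; numpy broadcasting normalizes each frequency bin
    return (spectrogram - cmvn_dict["mean"]) / cmvn_dict["std"]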

dataset.py

+236
@@ -0,0 +1,236 @@
#!/usr/bin/env python
# coding=utf-8
# wujian@2018

import os
import random
import logging
import pickle

import numpy as np
import torch as th

from torch.nn.utils.rnn import pack_sequence, pad_sequence

from utils import parse_scps, stft, compute_vad_mask, apply_cmvn

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)


class SpectrogramReader(object):
    """
    Wrapper for short-time Fourier transform of dataset
    """

    def __init__(self, wave_scp, **kwargs):
        if not os.path.exists(wave_scp):
            raise FileNotFoundError("Could not find file {}".format(wave_scp))
        self.stft_kwargs = kwargs
        self.wave_dict = parse_scps(wave_scp)
        self.wave_keys = list(self.wave_dict.keys())
        logger.info(
            "Create SpectrogramReader for {} with {} utterances".format(
                wave_scp, len(self.wave_dict)))

    def __len__(self):
        return len(self.wave_dict)

    def __contains__(self, key):
        return key in self.wave_dict

    # stft
    def _load(self, key):
        return stft(self.wave_dict[key], **self.stft_kwargs)

    # sequential index
    def __iter__(self):
        for key in self.wave_dict:
            yield key, self._load(key)

    # random index
    def __getitem__(self, key):
        if key not in self.wave_dict:
            raise KeyError("Could not find utterance {}".format(key))
        return self._load(key)


class Dataset(object):
    def __init__(self, mixture_reader, targets_reader_list):
        self.mixture_reader = mixture_reader
        self.keys_list = mixture_reader.wave_keys
        self.targets_reader_list = targets_reader_list

    def __len__(self):
        return len(self.keys_list)

    def _has_target(self, key):
        for targets_reader in self.targets_reader_list:
            if key not in targets_reader:
                return False
        return True

    def _index_by_key(self, key):
        """
        Return a tuple like (matrix, [matrix, ...])
        """
        if key not in self.mixture_reader or not self._has_target(key):
            raise KeyError("Missing targets or mixture")
        target_list = [reader[key] for reader in self.targets_reader_list]
        return (self.mixture_reader[key], target_list)

    def _index_by_num(self, num):
        """
        Return a tuple like (matrix, [matrix, ...])
        """
        if num >= len(self.keys_list):
            raise IndexError("Index out of dataset, {} vs {}".format(
                num, len(self.keys_list)))
        key = self.keys_list[num]
        return self._index_by_key(key)

    def _index_by_list(self, list_idx):
        """
        Returns a list of tuples like [
            (matrix, [matrix, ...]),
            (matrix, [matrix, ...]),
            ...
        ]
        """
        if max(list_idx) >= len(self.keys_list):
            raise IndexError("Index list contains index out of dataset")
        return [self._index_by_num(index) for index in list_idx]

    def __getitem__(self, index):
        if type(index) == int:
            return self._index_by_num(index)
        elif type(index) == str:
            return self._index_by_key(index)
        elif type(index) == list:
            return self._index_by_list(index)
        else:
            raise KeyError("Unsupported index type (expect int/str/list)")


class BatchSampler(object):
    def __init__(self,
                 sampler_size,
                 batch_size=16,
                 shuffle=True,
                 drop_last=False):
        if batch_size <= 0:
            raise ValueError(
                "Illegal batch_size(= {}) detected".format(batch_size))
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler_index = list(range(sampler_size))
        self.sampler_size = sampler_size
        if shuffle:
            random.shuffle(self.sampler_index)

    def __len__(self):
        return self.sampler_size

    def __iter__(self):
        base = 0
        step = self.batch_size
        while True:
            if base + step > self.sampler_size:
                break
            yield (self.sampler_index[base:base + step]
                   if step != 1 else self.sampler_index[base])
            base += step
        if not self.drop_last and base < self.sampler_size:
            yield self.sampler_index[base:]


class DataLoader(object):
    """
    Multi/per-utterance loader for DCNet training
    """

    def __init__(self,
                 dataset,
                 shuffle=True,
                 batch_size=16,
                 drop_last=False,
                 vad_threshold=40,
                 mvn_dict=None):
        self.dataset = dataset
        self.vad_threshold = vad_threshold
        self.mvn_dict = mvn_dict
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        if mvn_dict:
            logger.info("Using cmvn dictionary from {}".format(mvn_dict))
            with open(mvn_dict, "rb") as f:
                self.mvn_dict = pickle.load(f)

    def __len__(self):
        remain = len(self.dataset) % self.batch_size
        if self.drop_last or not remain:
            return len(self.dataset) // self.batch_size
        else:
            return len(self.dataset) // self.batch_size + 1

    def _transform(self, mixture_specs, targets_specs_list):
        """
        Transform from numpy/list to torch types
        """
        # compute vad mask before cmvn
        vad_mask = compute_vad_mask(
            mixture_specs, self.vad_threshold, apply_exp=True)
        # apply cmvn
        if self.mvn_dict:
            mixture_specs = apply_cmvn(mixture_specs, self.mvn_dict)
        # compute target embedding index
        target_attr = np.argmax(np.array(targets_specs_list), 0)
        return {
            "num_frames": mixture_specs.shape[0],
            "spectrogram": th.tensor(mixture_specs, dtype=th.float32),
            "target_attr": th.tensor(target_attr, dtype=th.int64),
            "silent_mask": th.tensor(vad_mask, dtype=th.float32)
        }

    def _process(self, index):
        if type(index) is list:
            dict_list = sorted(
                [self._transform(s, t) for s, t in self.dataset[index]],
                key=lambda x: x["num_frames"],
                reverse=True)
            spectrogram = pack_sequence([d["spectrogram"] for d in dict_list])
            target_attr = pad_sequence(
                [d["target_attr"] for d in dict_list], batch_first=True)
            silent_mask = pad_sequence(
                [d["silent_mask"] for d in dict_list], batch_first=True)
            return spectrogram, target_attr, silent_mask
        elif type(index) is int:
            s, t = self.dataset[index]
            data_dict = self._transform(s, t)
            return data_dict["spectrogram"], \
                data_dict["target_attr"], \
                data_dict["silent_mask"]
        else:
            raise ValueError("Unsupported index type ({})".format(type(index)))

    def __iter__(self):
        sampler = BatchSampler(
            len(self.dataset),
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=self.drop_last)
        num_utts = 0
        for e, index in enumerate(sampler):
            num_utts += (len(index) if type(index) is list else 1)
            if not (e + 1) % 100:
                logger.info("Processed {} batches, {} utterances".format(
                    e + 1, num_utts))
            yield self._process(index)
        logger.info("Processed {} utterances in total".format(num_utts))
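
A minimal wiring sketch for these classes. The paths and the STFT keyword names are assumptions (utils.stft is not included in this commit), not values taken from the repo:

from dataset import SpectrogramReader, Dataset, DataLoader

stft_conf = {"frame_length": 256, "frame_shift": 128}  # assumed kwargs for utils.stft
mix_reader = SpectrogramReader("data/train/mix.scp", **stft_conf)
spk_readers = [
    SpectrogramReader("data/train/spk{}.scp".format(n), **stft_conf)
    for n in (1, 2)
]
loader = DataLoader(
    Dataset(mix_reader, spk_readers),
    batch_size=16,
    mvn_dict="data/cmvn.dict")  # stats pickled by compute_cmvn.py

for spectrogram, target_attr, silent_mask in loader:
    pass  # one packed/padded mini-batch per iteration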

dcnet.py

+76
@@ -0,0 +1,76 @@
#!/usr/bin/env python
# coding=utf-8
# wujian@2018

import torch as th
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence


def l2_loss(x):
    norm = th.norm(x, 2)
    return norm**2


def l2_normalize(x, dim=0, eps=1e-12):
    assert (dim < x.dim())
    norm = th.norm(x, 2, dim, keepdim=True)
    return x / (norm + eps)


class DCNet(th.nn.Module):
    def __init__(self,
                 num_bins,
                 rnn="lstm",
                 embedding_dim=20,
                 num_layers=2,
                 hidden_size=600,
                 dropout=0.0,
                 non_linear="tanh",
                 bidirectional=True):
        super(DCNet, self).__init__()
        if non_linear not in ['tanh', 'sigmoid']:
            raise ValueError(
                "Unsupported non-linear type: {}".format(non_linear))
        rnn = rnn.upper()
        if rnn not in ['RNN', 'LSTM', 'GRU']:
            raise ValueError("Unsupported rnn type: {}".format(rnn))
        self.rnn = getattr(th.nn, rnn)(
            num_bins,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional)
        self.drops = th.nn.Dropout(p=dropout)
        self.embed = th.nn.Linear(
            hidden_size * 2 if bidirectional else hidden_size,
            num_bins * embedding_dim)
        self.non_linear = {
            "tanh": th.nn.functional.tanh,
            "sigmoid": th.nn.functional.sigmoid
        }[non_linear]
        self.embedding_dim = embedding_dim

    def forward(self, x, train=True):
        is_packed = isinstance(x, PackedSequence)
        if not is_packed and x.dim() != 3:
            x = th.unsqueeze(x, 0)
        x, _ = self.rnn(x)
        if is_packed:
            x, _ = pad_packed_sequence(x, batch_first=True)
        N = x.size(0)
        # N x T x H
        x = self.drops(x)
        # N x T x FD
        x = self.embed(x)
        x = self.non_linear(x)

        if train:
            # N x T x FD => N x TF x D
            x = x.view(N, -1, self.embedding_dim)
        else:
            # for inference
            # N x T x FD => NTF x D
            x = x.view(-1, self.embedding_dim)
        x = l2_normalize(x, -1)
        return x
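
A quick shape check for the model above, assuming 129 frequency bins (a 256-point STFT) and the defaults defined in DCNet:

import torch as th
from dcnet import DCNet

net = DCNet(num_bins=129)
# N x T x F batch of spectrogram frames
x = th.randn(4, 100, 129)
emb = net(x, train=True)
print(emb.shape)  # N x TF x D, i.e. torch.Size([4, 12900, 20])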

requirements.txt

+8
@@ -0,0 +1,8 @@
numpy==1.13.3
torch==0.4.0
scipy==1.0.0
librosa==0.5.1
tqdm==4.19.4
config==0.3.9
scikit_learn==0.19.1
PyYAML==3.12

run_demo.sh

+16
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

mix_scp=./data/tune/mix.scp
mdl_dir=./tune/2spk_dcnet_a

set -eu

[ -d ./cache ] && rm -rf cache

mkdir cache

shuf $mix_scp | head -n30 > test.scp

./separate.py --dump-pca --num-spks 2 $mdl_dir/train.yaml $mdl_dir/final.pkl test.scp

rm -f test.scp
