
Commit 208b853

Structuring code
1 parent 9cf1c13 commit 208b853

22 files changed, +541 −288 lines

.gitignore

+5
@@ -1,3 +1,8 @@
+*.pth
+*.wav
+speech_commands/
+wandb/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md

+4
@@ -1,2 +1,6 @@
 # kws-pytorch
 A KWS model trained on SpeechCommands dataset, written in PyTorch.
+
+W&B logs:
+
+https://wandb.ai/raccooncoder/kws-dlaudio

augmentations.py

+31
@@ -0,0 +1,31 @@
import torch
from torch import nn
import torchaudio

# Background track mixed into training clips; channels are summed down to mono.
youtube_noise, _ = torchaudio.load('Cafe sounds ~ Ambient noise-i9a6ReFTHiw.wav')
youtube_noise = youtube_noise.sum(dim=0)

class GaussianNoise(nn.Module):
    def __init__(self, mean=0, std=0.05):
        super(GaussianNoise, self).__init__()

        self.noiser = torch.distributions.Normal(mean, std)

    def forward(self, wav):
        # Add i.i.d. Gaussian noise, then clamp back into the valid [-1, 1] range.
        wav = wav + self.noiser.sample(wav.size())
        wav = wav.clamp(-1, 1)

        return wav

class YoutubeNoise(nn.Module):
    def __init__(self, alpha=0.05):
        super(YoutubeNoise, self).__init__()

        self.alpha = alpha
        self.noise_wav = youtube_noise

    def forward(self, wav):
        # Mix in a scaled slice of the background track (assumes the track is
        # at least as long as the input waveform), then clamp.
        wav = wav + self.alpha * self.noise_wav[:wav.shape[-1]]
        wav = wav.clamp(-1, 1)

        return wav
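Since both augmentations are plain nn.Module transforms, they can be sanity-checked on a synthetic waveform in isolation. A minimal sketch, assuming the background-noise wav referenced above is present so the module imports; the sine input is a made-up stand-in for a real clip:

import math
import torch
from augmentations import GaussianNoise, YoutubeNoise

# Made-up 1-second, 16 kHz mono waveform standing in for a real recording.
t = torch.arange(16000, dtype=torch.float32) / 16000
wav = 0.5 * torch.sin(2 * math.pi * 440 * t).unsqueeze(0)

noisy = GaussianNoise(0, 0.05)(wav)
mixed = YoutubeNoise(alpha=0.1)(wav)

# Both transforms preserve the shape and keep samples clamped to [-1, 1].
assert noisy.shape == wav.shape
assert mixed.abs().max() <= 1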
File renamed without changes.

configs/config.json

+25
@@ -0,0 +1,25 @@
{
    "target_class": "marvin",
    "num_epochs": 20,
    "batch_size": 256,
    "random_seed": 13,
    "img_padding_length": 130,
    "enc_hidden_size": 128,
    "window_size": 100,
    "conv_out_channels": 16,
    "conv_kernel_size": 51,
    "learning_rate": 0.001,
    "dataloader_num_workers": 8,
    "weight_decay": 0.001,
    "lr_scheduler_step_size": 10,
    "lr_scheduler_gamma": 0.1,
    "melspec_sample_rate": 16000,
    "melspec_n_mels": 64,
    "melspec_n_fft": 512,
    "melspec_hop_length": 128,
    "melspec_f_max": 4000,
    "specaug_freq_mask_param": 5,
    "specaug_time_mask_param": 5,
    "confidence_threshold": 0.9,
    "teacher_alpha": 0.6
}

configs/config_gen.py

+30
@@ -0,0 +1,30 @@
import json

config = dict(
    target_class = 'marvin',
    num_epochs = 20,
    batch_size = 256,
    random_seed = 13,
    img_padding_length = 130,
    enc_hidden_size = 128,
    window_size = 100,
    conv_out_channels = 16,
    conv_kernel_size = 51,
    learning_rate = 1e-3,
    dataloader_num_workers = 8,
    weight_decay = 1e-3,
    lr_scheduler_step_size = 10,
    lr_scheduler_gamma = 0.1,
    melspec_sample_rate = 16000,
    melspec_n_mels = 64,
    melspec_n_fft = 512,
    melspec_hop_length = 128,
    melspec_f_max = 4000,
    specaug_freq_mask_param = 5,
    specaug_time_mask_param = 5,
    confidence_threshold = 0.9,
    teacher_alpha = 0.6
)

with open('config.json', 'w') as f:
    json.dump(config, f, indent=4)
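The training scripts read this JSON and then access fields as attributes through wandb.config. For poking at the values without a W&B run, a SimpleNamespace gives the same attribute-style access; a minimal sketch of an offline stand-in, not how the repository itself loads the config:

import json
from types import SimpleNamespace

# Offline stand-in for wandb.config: same attribute-style access, no W&B run.
with open('configs/config.json') as f:
    config = SimpleNamespace(**json.load(f))

print(config.target_class, config.melspec_n_mels)  # -> marvin 64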

dataset.py

+43
@@ -0,0 +1,43 @@
import torch
from torch import nn
import torchaudio
from torch.utils.data import Dataset
from augmentations import *

class SpeechCommands(Dataset):
    def __init__(self, config, X, y, train=True):
        self.paths = X
        self.labels = y
        self.train = train
        self.config = config

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Fixed-size canvas: clips shorter than img_padding_length frames are zero-padded in time.
        img = torch.zeros(1, self.config.melspec_n_mels, self.config.img_padding_length)
        wav, sr = torchaudio.load(self.paths[idx])

        if self.train:
            # Training pipeline: noise mixing, log-mel features, SpecAugment masking.
            wav_proc = nn.Sequential(
                # GaussianNoise(0, 0.01),  # alternative augmentation, currently disabled
                YoutubeNoise(0.1),
                torchaudio.transforms.MelSpectrogram(sample_rate=self.config.melspec_sample_rate,
                                                     n_mels=self.config.melspec_n_mels,
                                                     n_fft=self.config.melspec_n_fft,
                                                     hop_length=self.config.melspec_hop_length,
                                                     f_max=self.config.melspec_f_max),
                torchaudio.transforms.FrequencyMasking(freq_mask_param=self.config.specaug_freq_mask_param),
                torchaudio.transforms.TimeMasking(time_mask_param=self.config.specaug_time_mask_param)
            )
        else:
            # Evaluation pipeline: features only, no augmentation.
            wav_proc = nn.Sequential(
                torchaudio.transforms.MelSpectrogram(sample_rate=self.config.melspec_sample_rate,
                                                     n_mels=self.config.melspec_n_mels,
                                                     n_fft=self.config.melspec_n_fft,
                                                     hop_length=self.config.melspec_hop_length,
                                                     f_max=self.config.melspec_f_max)
            )

        mel_spectrogram = torch.log(wav_proc(wav) + 1e-9)
        img[0, :, :mel_spectrogram.size(2)] = mel_spectrogram

        return img.reshape(self.config.melspec_n_mels, self.config.img_padding_length), self.labels[idx]
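A quick way to check the dataset's shape contract is to feed it a single clip and assert on the output. A minimal sketch, assuming the noise wav needed by augmentations.py is on disk (it is loaded at import time); 'example.wav' is a hypothetical path to any clip of at most one second:

import json
from types import SimpleNamespace
from dataset import SpeechCommands

with open('configs/config.json') as f:
    config = SimpleNamespace(**json.load(f))

# 'example.wav' is a placeholder; train=False skips the augmentations.
ds = SpeechCommands(config, ['example.wav'], [0], train=False)
img, label = ds[0]

# Every item is a fixed-size (n_mels, img_padding_length) log-mel image.
assert img.shape == (config.melspec_n_mels, config.img_padding_length)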

distill.py

+89
@@ -0,0 +1,89 @@
import wandb
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import json
import glob
import pandas as pd
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

with open('configs/config.json', 'r') as f:
    config = json.load(f)

wandb.init(config=config, project="kws-dlaudio")
config = wandb.config

# Imported after wandb.init so that modules reading wandb.config at import
# time (e.g. inference.py) see an initialized run.
from utils import set_seed
from train import train_distill, evaluate
from inference import inference
from dataset import SpeechCommands
from model import KWSNet

set_seed(config.random_seed)

print(device)

paths = []
labels = []

# Binary labels: 1 for the target keyword, 0 for everything else.
for path in glob.glob('speech_commands/*/*.wav'):
    _, label, _ = path.split('/')
    paths.append(path)
    labels.append(int(label == config.target_class))

df = pd.DataFrame({'path': paths, 'label': labels})

X_train, X_test, y_train, y_test = train_test_split(np.array(df['path']),
                                                    np.array(df['label']),
                                                    test_size=0.1,
                                                    stratify=np.array(df['label']),
                                                    random_state=config.random_seed)

train_dataset = SpeechCommands(config, X_train, y_train)
test_dataset = SpeechCommands(config, X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.dataloader_num_workers, pin_memory=True)
val_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.dataloader_num_workers, pin_memory=True)

# The student uses half the teacher's encoder width.
student_model = KWSNet(config.enc_hidden_size // 2, config.conv_out_channels, config.conv_kernel_size)
student_model = student_model.to(device)

error = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_scheduler_step_size, gamma=config.lr_scheduler_gamma)

teacher_model = KWSNet(config.enc_hidden_size, config.conv_out_channels, config.conv_kernel_size)
teacher_model.load_state_dict(torch.load('checkpoints/teacher_model.pth'))
teacher_model = teacher_model.to(device)

alpha = config.teacher_alpha

for epoch in range(config.num_epochs):
    train_distill(epoch, teacher_model, student_model, alpha, optimizer, error, train_loader, device)
    evaluate(student_model, optimizer, error, val_loader, device)
    lr_scheduler.step()

negative_val = []
positive_val = []

for path, label in zip(X_test, y_test):
    if label == 1:
        positive_val.append(path)
    else:
        negative_val.append(path)


path = positive_val[1]
inference('results/student_positive_example.png', student_model, path, noise=True, device=device)

path = negative_val[1]
inference('results/student_negative_example.png', student_model, path, noise=True, device=device)

torch.save(student_model.state_dict(), 'checkpoints/student_model.pth')
wandb.save('checkpoints/student_model.pth')
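train_distill itself lives in train.py, which is not part of this diff. For reference, a typical distillation step blends the hard-label loss with agreement to the teacher's soft predictions; the sketch below is an assumption about that shape, not the repository's actual code (the name distill_loss and the exact blend are illustrative):

import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, targets, alpha):
    # Hypothetical blend: (1 - alpha) * hard cross-entropy + alpha * KL to the teacher.
    hard = F.cross_entropy(student_logits, targets)
    soft = F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits.detach(), dim=-1),
                    reduction='batchmean')
    return (1 - alpha) * hard + alpha * soft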

docker/Dockerfile

+21
@@ -0,0 +1,21 @@
FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04

ENV TZ=Europe/Moscow
ENV TERM xterm-256color

RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

# Refresh the package index before installing; wget is added here because the
# dataset download below depends on it.
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-tk \
    libboost-all-dev \
    wget

RUN apt-get -y install git
RUN python3 -m pip install --upgrade pip
COPY requirements.txt .
RUN pip3 install -r requirements.txt

RUN wget http://download.tensorflow.org/data/speech_commands_v0.01.tar.gz -O speech_commands_v0.01.tar.gz
RUN mkdir speech_commands && tar -C speech_commands -xvzf speech_commands_v0.01.tar.gz 1> log
COPY mv.sh .
# Exec-form entrypoint via bash, since mv.sh may not be marked executable.
ENTRYPOINT ["/bin/bash", "mv.sh"]

docker/build.sh

+1
@@ -0,0 +1 @@
docker build -t kws_pytorch .

docker/mv.sh

+4
@@ -0,0 +1,4 @@
#!/bin/bash

mv speech_commands /home/kws-pytorch
cd /home/kws-pytorch

requirements.txt renamed to docker/requirements.txt

+2-1
@@ -4,4 +4,5 @@ torch == 1.6.0
 torchaudio == 0.6.0
 wandb == 0.10.7
 librosa
-tqdm
+tqdm
+plotly

docker/run.sh

+1
@@ -0,0 +1 @@
docker run -v $(pwd)/:/home/kws-pytorch -it kws_pytorch

inference.py

+48
@@ -0,0 +1,48 @@
import wandb
import torch
import torchaudio
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

import time

# Assumes wandb.init() has already run (see the import order in distill.py).
config = wandb.config

def inference(fname, model, path, noise, device):
    model.eval()
    model.inference_mode()

    # Two speech clips concatenated around the input as non-keyword context.
    noise_wav1, _ = torchaudio.load('LJ001-0001.wav')
    noise_wav2, _ = torchaudio.load('LJ001-0014.wav')

    wav, sr = torchaudio.load(path)
    wav_proc = nn.Sequential(torchaudio.transforms.MelSpectrogram(sample_rate=config.melspec_sample_rate,
                                                                  n_mels=config.melspec_n_mels,
                                                                  n_fft=config.melspec_n_fft,
                                                                  hop_length=config.melspec_hop_length,
                                                                  f_max=config.melspec_f_max))
    mel_spectrogram = torch.log(wav_proc(wav) + 1e-9)

    if noise:
        # Pad the keyword clip with speech on both sides along the time axis
        # to exercise detection inside a longer stream.
        noise1_melspec = torch.log(wav_proc(noise_wav1) + 1e-9)
        noise2_melspec = torch.log(wav_proc(noise_wav2) + 1e-9)

        img = torch.cat((noise1_melspec, mel_spectrogram, noise2_melspec), -1)
    else:
        img = mel_spectrogram

    img = img.to(device)

    start = time.time()
    with torch.no_grad():
        # Per-frame probability of the keyword class (index 1).
        outputs = F.softmax(model(img).squeeze(1), dim=-1).detach().cpu().numpy()[:, 0, 1]

    finish = time.time()
    wandb.log({'Inference time': finish - start})

    plt.figure(figsize=(20, 10))
    plt.plot(range(len(outputs)), outputs)
    plt.axhline(y=config.confidence_threshold, color='r', linestyle='-')
    wandb.log({fname: wandb.Image(plt)})
    plt.savefig(fname, dpi=500)
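The threshold line drawn in the plot suggests a simple frame-wise detection rule on top of these probabilities. A sketch of one such rule; detect is a hypothetical helper, not part of the repository:

import numpy as np

def detect(outputs, threshold=0.9):
    # Index of the first frame whose keyword probability crosses the
    # threshold, or None if the keyword is never detected.
    hits = np.where(outputs > threshold)[0]
    return int(hits[0]) if len(hits) else None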
