
Commit 7a3d117

Switch to using pytorch/audio

1 parent f658ede · commit 7a3d117

File tree: 7 files changed, +120 -15 lines


README.md (+2 -1)

@@ -24,7 +24,7 @@ Example results:
 
 - [soundfile](https://pypi.org/project/SoundFile/): audio loading
 - [torchparse](https://github.com/ksanjeevan/torchparse): .cfg easy model definition
-- [torchaudio_contrib](https://github.com/keunwoochoi/torchaudio-contrib): Audio transforms on GPU
+- [pytorch/audio](https://github.com/pytorch/audio): Audio transforms
 
 
 #### Features
@@ -160,6 +160,7 @@ Per fold metrics CRNN(Bidirectional, Dropout):
 - [x] CRNN entirely defined in .cfg
 - [x] Some bug in 'infer'
 - [x] Run 10-fold Cross Validation
+- [x] Switch over to pytorch/audio since the merge
 - [ ] Comment things
 
 

data/data_manager.py (+1 -1)

@@ -1,5 +1,5 @@
 
-import os, cv2
+import os
 import pandas as pd
 import numpy as np
 
data/data_sets.py (-2)

@@ -8,8 +8,6 @@
 import numpy as np
 import soundfile as sf
 import torch.utils.data as data
-import cv2
-
 
 class FolderDataset(data.Dataset):
 
net/audio.py (+114 -7)

@@ -1,20 +1,128 @@
+'''
 import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributions import Normal, Uniform, HalfNormal
 
 from torchaudio_contrib import STFT, TimeStretch, MelFilterbank, ComplexNorm, ApplyFilterbank
+'''
 
 
+from torchaudio.transforms import Spectrogram, MelSpectrogram, ComplexNorm
+
+def _num_stft_bins(lengths, fft_length, hop_length, pad):
+    return (lengths + 2 * pad - fft_length + hop_length) // hop_length
+
+class MelspectrogramStretch(MelSpectrogram):
+
+    def __init__(self, hop_length=None,
+                 sample_rate=44100,
+                 num_mels=128,
+                 fft_length=2048,
+                 norm='whiten',
+                 stretch_param=[0.4, 0.4]):
+
+        super(MelspectrogramStretch, self).__init__(sample_rate=sample_rate,
+                                                    n_fft=fft_length,
+                                                    hop_length=hop_length,
+                                                    n_mels=num_mels)
+
+        self.stft = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
+                                hop_length=self.hop_length, pad=self.pad,
+                                power=None, normalized=False)
+
+        # Augmentation
+        self.prob = stretch_param[0]
+        self.random_stretch = RandomTimeStretch(stretch_param[1],
+                                                self.hop_length,
+                                                self.n_fft//2+1,
+                                                fixed_rate=None)
+
+        # Normalization (post spec processing)
+        self.complex_norm = ComplexNorm(power=2.)
+        self.norm = SpecNormalization(norm)
+
+    def forward(self, x, lengths=None):
+        x = self.stft(x)
+
+        if lengths is not None:
+            lengths = _num_stft_bins(lengths, self.n_fft, self.hop_length, self.n_fft//2)
+            lengths = lengths.long()
+
+        if torch.rand(1)[0] <= self.prob and self.training:
+            # Stretch spectrogram in time using Phase Vocoder
+            x, rate = self.random_stretch(x)
+            # Modify the rate accordingly
+            lengths = (lengths.float()/rate).long()+1
+
+        x = self.complex_norm(x)
+        x = self.mel_scale(x)
+
+        # Normalize melspectrogram
+        x = self.norm(x)
+
+        if lengths is not None:
+            return x, lengths
+        return x
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torchaudio.transforms import TimeStretch, AmplitudeToDB
+from torch.distributions import Uniform
+
+class RandomTimeStretch(TimeStretch):
+
+    def __init__(self, max_perc, hop_length=None, n_freq=201, fixed_rate=None):
+
+        super(RandomTimeStretch, self).__init__(hop_length, n_freq, fixed_rate)
+        self._dist = Uniform(1.-max_perc, 1+max_perc)
+
+    def forward(self, x):
+        rate = self._dist.sample().item()
+        return super(RandomTimeStretch, self).forward(x, rate), rate
+
+
+class SpecNormalization(nn.Module):
+
+    def __init__(self, norm_type, top_db=80.0):
+
+        super(SpecNormalization, self).__init__()
+
+        if 'db' == norm_type:
+            self._norm = AmplitudeToDB(stype='power', top_db=top_db)
+        elif 'whiten' == norm_type:
+            self._norm = lambda x: self.z_transform(x)
+        else:
+            self._norm = lambda x: x
+
+
+    def z_transform(self, x):
+        # Independent mean, std per batch
+        non_batch_inds = [1, 2, 3]
+        mean = x.mean(non_batch_inds, keepdim=True)
+        std = x.std(non_batch_inds, keepdim=True)
+        x = (x - mean)/std
+        return x
+
+    def forward(self, x):
+        return self._norm(x)
+
+
+'''
 def amplitude_to_db(spec, ref=1.0, amin=1e-10, top_db=80):
     """
     Amplitude spectrogram to the db scale
     """
     power = spec**2
     return power_to_db(power, ref, amin, top_db)
 
-
 def power_to_db(spec, ref=1.0, amin=1e-10, top_db=80.0):
     """
     Power spectrogram to the db scale
@@ -41,7 +149,6 @@ def power_to_db(spec, ref=1.0, amin=1e-10, top_db=80.0):
     #log_spec /= log_spec.max()
     return log_spec
 
-
 def spec_whiten(spec, eps=1):
 
     along_dim = lambda f, x: f(x, dim=-1).view(-1,1,1,1)
@@ -58,10 +165,6 @@ def spec_whiten(spec, eps=1):
     return resu
 
 
-def _num_stft_bins(lengths, fft_length, hop_length, pad):
-    return (lengths + 2 * pad - fft_length + hop_length) // hop_length
-
-
 class MelspectrogramStretch(nn.Module):
 
     def __init__(self, hop_length=None, num_mels=128, fft_length=2048, norm='whiten', stretch_param=[0.4, 0.4]):
@@ -89,12 +192,15 @@ def __init__(self, hop_length=None, num_mels=128, fft_length=2048, norm='whiten'
 
         self.counter = 0
 
+
+
     def forward(self, x, lengths=None):
        x = self.stft(x)
 
        if lengths is not None:
            lengths = _num_stft_bins(lengths, self.fft_length, self.hop_length, self.fft_length//2)
-
+           lengths = lengths.long()
+
        if torch.rand(1)[0] <= self.prob and self.training:
            rate = 1 - self.dist.sample()
            x = self.pv(x, rate)
@@ -114,3 +220,4 @@ def __repr__(self):
         param_str = '(num_mels={}, fft_length={}, norm={}, stretch_param={})'.format(
             self.num_mels, self.fft_length, self.norm.__name__, self.stretch_param)
         return self.__class__.__name__ + param_str
+'''
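
For orientation, here is a minimal usage sketch of the torchaudio-based front end added above. It is illustrative only: the batch shape, clip length, and the net.audio import path are assumptions, not part of this commit, and the shape comments follow the torchaudio 0.4-era API used here (complex STFT with a trailing real/imag dimension, ComplexNorm, MelScale).

import torch
from net.audio import MelspectrogramStretch  # import path assumed from this repo's layout

spec = MelspectrogramStretch(sample_rate=44100, num_mels=128, fft_length=2048,
                             norm='whiten', stretch_param=[0.4, 0.4])
spec.train()  # the random time stretch only fires in training mode (prob 0.4)

# Hypothetical batch: 8 mono clips of 3 s at 44.1 kHz -> (batch, channel, time)
x = torch.randn(8, 1, 3 * 44100)
lengths = torch.full((8,), 3 * 44100, dtype=torch.long)

out, out_lengths = spec(x, lengths)
# out -> (batch, channel, n_mels, frames); with the default hop of n_fft//2 = 1024,
# _num_stft_bins gives (132300 + 2*1024 - 2048 + 1024) // 1024 = 130 frames,
# and out_lengths is rescaled whenever the stretch branch fired.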

net/model.py (+2 -3)

@@ -26,7 +26,7 @@ def __init__(self, classes, config={}, state_dict=None):
                                          stretch_param=[0.4, 0.4])
 
         # shape -> (channel, freq, token_time)
-        self.net = parse_cfg(config['cfg'], in_shape=[in_chan, self.spec.num_mels, 400])
+        self.net = parse_cfg(config['cfg'], in_shape=[in_chan, self.spec.n_mels, 400])
 
     def _many_to_one(self, t, lengths):
         return t[torch.arange(t.size(0)), lengths - 1]
@@ -39,14 +39,13 @@ def safe_param(elem):
         #if name.startswith(('conv2d','maxpool2d')):
         if isinstance(layer, (nn.Conv2d, nn.MaxPool2d)):
             p, k, s = map(safe_param, [layer.padding, layer.kernel_size, layer.stride])
-            lengths = (lengths + 2*p - k)//s + 1
+            lengths = ((lengths + 2*p - k)//s + 1).long()
 
         return torch.where(lengths > 0, lengths, torch.tensor(1, device=lengths.device))
 
     def forward(self, batch):
         # x -> (batch, time, channel)
         x, lengths, _ = batch  # unpacking seqs, lengths and srs
-
         # x -> (batch, channel, time)
         xt = x.float().transpose(1,2)
         # xt -> (batch, channel, freq, time)
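
As a quick sanity check of the patched length bookkeeping above: each Conv2d/MaxPool2d maps a time length L to (L + 2*p - k)//s + 1, and the new .long() cast keeps the running lengths usable as an index tensor downstream. The values below are made up for illustration.

import torch

lengths = torch.tensor([400., 250.])   # float lengths, e.g. as tracked alongside the batch
p, k, s = 1, 3, 2                      # padding, kernel_size, stride (example values)
lengths = ((lengths + 2*p - k)//s + 1).long()
print(lengths)                         # tensor([200, 125])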

run.py (-1)

@@ -85,7 +85,6 @@ def train_main(config, resume):
     model = getattr(net_module, m_name)(classes, config=config)
     num_classes = len(classes)
 
-    print(model)
 
     loss = getattr(net_module, config['train']['loss'])
     metrics = getattr(net_module, config['metrics'])(num_classes)

train/base_trainer.py (+1)

@@ -18,6 +18,7 @@ def __init__(self, model, loss, metrics, optimizer, resume, config, train_logger
         self.logger = logging.getLogger(self.__class__.__name__)
 
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         self.model = model.to(self.device)
 
         self.loss = loss