1
+ '''
1
2
import math
2
3
import torch
3
4
import torch.nn as nn
4
5
import torch.nn.functional as F
5
6
from torch.distributions import Normal, Uniform, HalfNormal
6
7
7
8
from torchaudio_contrib import STFT, TimeStretch, MelFilterbank, ComplexNorm, ApplyFilterbank
9
+ '''
8
10
9
11
12
+ from torchaudio .transforms import Spectrogram , MelSpectrogram , ComplexNorm
13
+
14
+ def _num_stft_bins (lengths , fft_length , hop_length , pad ):
15
+ return (lengths + 2 * pad - fft_length + hop_length ) // hop_length
16
+
17
class MelspectrogramStretch(MelSpectrogram):
    """Mel spectrogram transform with optional random time stretching.

    The pipeline in ``forward`` is: complex STFT -> (training only, with
    probability ``prob``) random time stretch -> power via ``ComplexNorm`` ->
    mel filterbank of the parent class -> ``SpecNormalization``.

    Args:
        hop_length: Hop length forwarded to the parent ``MelSpectrogram``.
        sample_rate: Sample rate forwarded to the parent ``MelSpectrogram``.
        num_mels: Number of mel bands (parent's ``n_mels``).
        fft_length: FFT size (parent's ``n_fft``).
        norm: Normalization type handed to ``SpecNormalization``.
        stretch_param: Two-item sequence ``(prob, max_perc)``: the probability
            of applying the stretch and the ``max_perc`` argument of
            ``RandomTimeStretch``.
    """

    def __init__(self, hop_length=None,
                 sample_rate=44100,
                 num_mels=128,
                 fft_length=2048,
                 norm='whiten',
                 stretch_param=(0.4, 0.4)):
        # NOTE: the default is a tuple rather than a list so that the default
        # object can never be mutated and shared between instances.
        super(MelspectrogramStretch, self).__init__(sample_rate=sample_rate,
                                                    n_fft=fft_length,
                                                    hop_length=hop_length,
                                                    n_mels=num_mels)

        # Complex-valued STFT (power=None) configured from the attributes the
        # parent constructor has just set; the stretch runs on this output.
        self.stft = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                hop_length=self.hop_length, pad=self.pad,
                                power=None, normalized=False)

        # Augmentation
        self.prob = stretch_param[0]
        self.random_stretch = RandomTimeStretch(stretch_param[1],
                                                self.hop_length,
                                                self.n_fft // 2 + 1,
                                                fixed_rate=None)

        # Normalization (post spec processing)
        self.complex_norm = ComplexNorm(power=2.)
        self.norm = SpecNormalization(norm)

    def forward(self, x, lengths=None):
        """Compute the normalized mel spectrogram of ``x``.

        Args:
            x: Input passed straight to the STFT transform
                (presumably a batch of waveforms -- confirm against callers).
            lengths: Optional tensor of per-item signal lengths. When given,
                it is converted to frame counts and rescaled if the time
                stretch is applied.

        Returns:
            The normalized mel spectrogram, or a ``(spectrogram, lengths)``
            tuple when ``lengths`` was provided.
        """
        x = self.stft(x)

        if lengths is not None:
            lengths = _num_stft_bins(lengths, self.n_fft, self.hop_length, self.n_fft // 2)
            lengths = lengths.long()

        if torch.rand(1)[0] <= self.prob and self.training:
            # Stretch spectrogram in time using Phase Vocoder
            x, rate = self.random_stretch(x)
            # Modify the lengths accordingly. Guarded: lengths is optional, and
            # the augmentation must still work when the caller omitted it.
            if lengths is not None:
                lengths = (lengths.float() / rate).long() + 1

        x = self.complex_norm(x)
        x = self.mel_scale(x)

        # Normalize melspectrogram
        x = self.norm(x)

        if lengths is not None:
            return x, lengths
        return x

    def __repr__(self):
        return self.__class__.__name__ + '()'
71
+
72
+
73
+ import numpy as np
74
+ import torch
75
+ import torch .nn as nn
76
+
77
+ from torchaudio .transforms import TimeStretch , AmplitudeToDB
78
+ from torch .distributions import Uniform
79
+
80
class RandomTimeStretch(TimeStretch):
    """``TimeStretch`` variant that samples a new stretch rate on every call.

    The rate is drawn from ``Uniform(1 - max_perc, 1 + max_perc)``.

    Args:
        max_perc: Half-width of the sampling interval around a rate of 1.
        hop_length: Forwarded to ``TimeStretch``.
        n_freq: Forwarded to ``TimeStretch``.
        fixed_rate: Forwarded to ``TimeStretch``; ``forward`` always passes a
            sampled rate explicitly.
    """

    def __init__(self, max_perc, hop_length=None, n_freq=201, fixed_rate=None):
        super().__init__(hop_length, n_freq, fixed_rate)
        lowest_rate = 1. - max_perc
        highest_rate = 1 + max_perc
        self._dist = Uniform(lowest_rate, highest_rate)

    def forward(self, x):
        """Stretch ``x`` by a freshly sampled rate.

        Returns:
            A ``(stretched, rate)`` tuple, where ``rate`` is the sampled
            Python float that was used.
        """
        sampled_rate = self._dist.sample().item()
        stretched = super().forward(x, sampled_rate)
        return stretched, sampled_rate
90
+
91
+
92
class SpecNormalization(nn.Module):
    """Apply one of several normalizations to a spectrogram batch.

    Args:
        norm_type: ``'db'`` converts power to decibels with ``AmplitudeToDB``,
            ``'whiten'`` standardizes each batch item (see ``z_transform``);
            any other value leaves the input untouched.
        top_db: Passed to ``AmplitudeToDB`` when ``norm_type == 'db'``.
    """

    def __init__(self, norm_type, top_db=80.0):
        super(SpecNormalization, self).__init__()

        # Bound methods are stored instead of lambdas so that the module
        # remains picklable.
        if 'db' == norm_type:
            self._norm = AmplitudeToDB(stype='power', top_db=top_db)
        elif 'whiten' == norm_type:
            self._norm = self.z_transform
        else:
            self._norm = self._identity

    @staticmethod
    def _identity(x):
        """Return ``x`` unchanged (used when no normalization is selected)."""
        return x

    def z_transform(self, x):
        """Standardize every batch item to zero mean and unit std.

        Mean and standard deviation are computed independently per batch
        item, over every dimension except the first.

        Args:
            x: Floating-point tensor whose first dimension is the batch.

        Returns:
            ``(x - mean) / std`` with the same shape as ``x``. A constant item
            (std == 0) maps to zeros instead of NaN.
        """
        # Reduce over every non-batch axis, whatever the rank of the input.
        non_batch_inds = list(range(1, x.dim()))
        mean = x.mean(non_batch_inds, keepdim=True)
        std = x.std(non_batch_inds, keepdim=True)
        # Guard against division by zero; the bound is the smallest positive
        # normal number, so any regular std value is left untouched.
        std = std.clamp(min=torch.finfo(x.dtype).tiny)
        return (x - mean) / std

    def forward(self, x):
        """Run the normalization selected at construction time."""
        return self._norm(x)
116
+
117
+
118
+ '''
10
119
def amplitude_to_db(spec, ref=1.0, amin=1e-10, top_db=80):
11
120
"""
12
121
Amplitude spectrogram to the db scale
13
122
"""
14
123
power = spec**2
15
124
return power_to_db(power, ref, amin, top_db)
16
125
17
-
18
126
def power_to_db(spec, ref=1.0, amin=1e-10, top_db=80.0):
19
127
"""
20
128
Power spectrogram to the db scale
@@ -41,7 +149,6 @@ def power_to_db(spec, ref=1.0, amin=1e-10, top_db=80.0):
41
149
#log_spec /= log_spec.max()
42
150
return log_spec
43
151
44
-
45
152
def spec_whiten(spec, eps=1):
46
153
47
154
along_dim = lambda f, x: f(x, dim=-1).view(-1,1,1,1)
@@ -58,10 +165,6 @@ def spec_whiten(spec, eps=1):
58
165
return resu
59
166
60
167
61
- def _num_stft_bins (lengths , fft_length , hop_length , pad ):
62
- return (lengths + 2 * pad - fft_length + hop_length ) // hop_length
63
-
64
-
65
168
class MelspectrogramStretch(nn.Module):
66
169
67
170
def __init__(self, hop_length=None, num_mels=128, fft_length=2048, norm='whiten', stretch_param=[0.4, 0.4]):
@@ -89,12 +192,15 @@ def __init__(self, hop_length=None, num_mels=128, fft_length=2048, norm='whiten'
89
192
90
193
self.counter = 0
91
194
195
+
196
+
92
197
def forward(self, x, lengths=None):
93
198
x = self.stft(x)
94
199
95
200
if lengths is not None:
96
201
lengths = _num_stft_bins(lengths, self.fft_length, self.hop_length, self.fft_length//2)
97
-
202
+ lengths = lengths.long()
203
+
98
204
if torch.rand(1)[0] <= self.prob and self.training:
99
205
rate = 1 - self.dist.sample()
100
206
x = self.pv(x, rate)
@@ -114,3 +220,4 @@ def __repr__(self):
114
220
param_str = '(num_mels={}, fft_length={}, norm={}, stretch_param={})'.format(
115
221
self.num_mels, self.fft_length, self.norm.__name__, self.stretch_param)
116
222
return self.__class__.__name__ + param_str
223
+ '''
0 commit comments