-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathSTFT.py
68 lines (56 loc) · 3.03 KB
/
STFT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
class STFT(torch.nn.Module):
    """Short-time Fourier transform and its inverse, implemented as 1-D convolutions.

    The forward basis holds the real and imaginary parts of the first
    ``filter_length // 2 + 1`` DFT rows. It is built from the raw DFT of an
    identity matrix, so no analysis window is applied (rectangular frames).
    The inverse basis is the pseudo-inverse of the forward basis divided by
    ``filter_length / hop_length``. Overlap-adding the reconstructed frames
    with ``conv_transpose1d`` therefore recovers the input signal.
    NOTE(review): this exact cancellation relies on ``hop_length`` dividing
    ``filter_length``; other ratios are not guaranteed to reconstruct exactly.

    Args:
        filter_length: DFT size / frame length in samples.
        hop_length: Stride between consecutive frames in samples.
        num_samples: Output length used by :meth:`inverse` until
            :meth:`transform` records the length of an actual input. The
            default 219904 is preserved from the original code; its origin is
            not documented here.
    """

    def __init__(self, filter_length=1024, hop_length=512, num_samples=219904):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        # Kept for backward compatibility; not read anywhere in this class.
        self.forward_transform = None

        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = self.filter_length // 2 + 1
        # Real parts stacked above imaginary parts: shape (2 * cutoff, filter_length).
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        # conv1d weight layout: (out_channels=2*cutoff, in_channels=1, kernel).
        forward_basis = torch.tensor(fourier_basis[:, None, :], dtype=torch.float32)
        # conv_transpose1d weight layout: (in_channels=2*cutoff, out_channels=1, kernel).
        # pinv(scale * B) == pinv(B) / scale, which divides each frame's
        # contribution by the number of overlapping frames.
        inverse_basis = torch.tensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :],
            dtype=torch.float32)

        self.register_buffer('forward_basis', forward_basis)
        self.register_buffer('inverse_basis', inverse_basis)
        self.num_samples = num_samples

    def transform(self, input_data):
        """Return the magnitude and phase spectrograms of a batch of signals.

        Args:
            input_data: Tensor of shape (batch, num_samples).

        Returns:
            ``(magnitude, phase)``, each of shape
            (batch, filter_length // 2 + 1, num_frames). ``phase`` is computed
            from detached tensors, so no gradient flows through it.

        Side effect: stores ``num_samples`` on the module so that
        :meth:`inverse` trims its output back to the input length.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)
        self.num_samples = num_samples

        # conv1d expects (batch, channels, length).
        input_data = input_data.reshape(num_batches, 1, num_samples)
        # Zero-pad filter_length samples on both sides; inverse() removes the
        # left padding again.
        forward_transform = F.conv1d(input_data,
                                     self.forward_basis,
                                     stride=self.hop_length,
                                     padding=self.filter_length)

        cutoff = self.filter_length // 2 + 1
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        # NOTE(review): sqrt has an unbounded gradient where the magnitude is 0.
        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        # Phase is deliberately kept out of the autograd graph.
        phase = torch.atan2(imag_part.detach(), real_part.detach())
        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct time-domain signals from a magnitude/phase pair.

        Args:
            magnitude: Tensor shaped like the ``magnitude`` returned by
                :meth:`transform`.
            phase: Tensor of the same shape, in radians.

        Returns:
            Tensor of shape (batch, 1, length), truncated to at most
            ``self.num_samples`` samples. That value comes from the last
            :meth:`transform` call, or from the constructor otherwise.
        """
        # Real parts stacked above imaginary parts, matching the channel
        # layout of forward_basis / inverse_basis.
        recombined = torch.cat([magnitude * torch.cos(phase),
                                magnitude * torch.sin(phase)], dim=1)

        # Overlap-add of the per-frame pseudo-inverse reconstructions.
        inverse_transform = F.conv_transpose1d(recombined,
                                               self.inverse_basis,
                                               stride=self.hop_length,
                                               padding=0)

        # Drop the filter_length samples of left padding added in transform(),
        # then trim to the recorded signal length.
        inverse_transform = inverse_transform[:, :, self.filter_length:]
        return inverse_transform[:, :, :self.num_samples]

    def forward(self, input_data):
        """Analyse and resynthesise ``input_data``.

        The intermediate spectrogram is stored on ``self.magnitude`` and
        ``self.phase``. The return value is the reconstruction produced by
        :meth:`inverse`.
        """
        self.magnitude, self.phase = self.transform(input_data)
        return self.inverse(self.magnitude, self.phase)