Skip to content

Commit bf4cb4a

Browse files
mkycoder and mike-fzy authored
whisper : optimize fft() function (#2242)
Co-authored-by: Mike Fan <[email protected]>
1 parent e293f17 commit bf4cb4a

File tree

1 file changed

+22
-34
lines changed

1 file changed

+22
-34
lines changed

whisper.cpp

+22-34
Original file line numberDiff line numberDiff line change
@@ -2974,10 +2974,7 @@ whisper_span<const float> whisper_mel_calc::hann_window() {
29742974
// naive Discrete Fourier Transform
29752975
// input is real-valued
29762976
// output is complex-valued
2977-
static void dft(const std::vector<float> & in, std::vector<float> & out) {
2978-
int N = in.size();
2979-
2980-
out.resize(N*2);
2977+
static void dft(const float* in, int N, float* out) {
29812978
const int sin_cos_step = SIN_COS_N_COUNT / N;
29822979

29832980
for (int k = 0; k < N; k++) {
@@ -2999,44 +2996,35 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
29992996
// poor man's implementation - use something better
30002997
// input is real-valued
30012998
// output is complex-valued
3002-
static void fft(const std::vector<float> & in, std::vector<float> & out) {
3003-
out.resize(in.size()*2);
3004-
3005-
int N = in.size();
3006-
2999+
static void fft(float* in, int N, float* out) {
30073000
if (N == 1) {
30083001
out[0] = in[0];
30093002
out[1] = 0;
30103003
return;
30113004
}
30123005

3013-
if (N%2 == 1) {
3014-
dft(in, out);
3006+
const int half_N = N / 2;
3007+
if (N - half_N*2 == 1) {
3008+
dft(in, N, out);
30153009
return;
30163010
}
30173011

3018-
std::vector<float> even;
3019-
std::vector<float> odd;
3020-
3021-
even.reserve(N/2);
3022-
odd.reserve(N/2);
3023-
3024-
for (int i = 0; i < N; i++) {
3025-
if (i % 2 == 0) {
3026-
even.push_back(in[i]);
3027-
} else {
3028-
odd.push_back(in[i]);
3029-
}
3012+
float* even = in + N;
3013+
for (int i = 0; i < half_N; ++i) {
3014+
even[i]= in[2*i];
30303015
}
3016+
float* even_fft = out + 2 * N;
3017+
fft(even, half_N, even_fft);
30313018

3032-
std::vector<float> even_fft;
3033-
std::vector<float> odd_fft;
3034-
3035-
fft(even, even_fft);
3036-
fft(odd, odd_fft);
3019+
float* odd = even;
3020+
for (int i = 0; i < half_N; ++i) {
3021+
odd[i] = in[2*i + 1];
3022+
}
3023+
float* odd_fft = even_fft + N;
3024+
fft(odd, half_N, odd_fft);
30373025

30383026
const int sin_cos_step = SIN_COS_N_COUNT / N;
3039-
for (int k = 0; k < N/2; k++) {
3027+
for (int k = 0; k < half_N; k++) {
30403028
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
30413029
float re = global_cache.cos_vals[idx]; // cos(t)
30423030
float im = -global_cache.sin_vals[idx]; // sin(t)
@@ -3047,8 +3035,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
30473035
out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd;
30483036
out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd;
30493037

3050-
out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
3051-
out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
3038+
out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd;
3039+
out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd;
30523040
}
30533041
}
30543042

@@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
30663054
const whisper_filters & filters, whisper_mel_data & mel) {
30673055
const auto frame_size = WHISPER_N_FFT;
30683056
const auto frame_step = WHISPER_HOP_LENGTH;
3069-
std::vector<float> fft_in(frame_size, 0.0);
3070-
std::vector<float> fft_out(2 * frame_size);
3057+
std::vector<float> fft_in(frame_size * 2, 0.0);
3058+
std::vector<float> fft_out(frame_size * 2 * 2 * 2);
30713059
int n_fft = filters.n_fft;
30723060
int i = ith;
30733061

@@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
30883076
}
30893077

30903078
// FFT
3091-
fft(fft_in, fft_out);
3079+
fft(fft_in.data(), frame_size, fft_out.data());
30923080

30933081
// Calculate modulus^2 of complex numbers
30943082
// Using pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) here causes an inference quality problem? Interesting.

0 commit comments

Comments
 (0)