@@ -2974,10 +2974,7 @@ whisper_span<const float> whisper_mel_calc::hann_window() {
2974
2974
// naive Discrete Fourier Transform
2975
2975
// input is real-valued
2976
2976
// output is complex-valued
2977
- static void dft (const std::vector<float > & in, std::vector<float > & out) {
2978
- int N = in.size ();
2979
-
2980
- out.resize (N*2 );
2977
+ static void dft (const float * in, int N, float * out) {
2981
2978
const int sin_cos_step = SIN_COS_N_COUNT / N;
2982
2979
2983
2980
for (int k = 0 ; k < N; k++) {
@@ -2999,44 +2996,35 @@ static void dft(const std::vector<float> & in, std::vector<float> & out) {
2999
2996
// poor man's implementation - use something better
3000
2997
// input is real-valued
3001
2998
// output is complex-valued
3002
- static void fft (const std::vector<float > & in, std::vector<float > & out) {
3003
- out.resize (in.size ()*2 );
3004
-
3005
- int N = in.size ();
3006
-
2999
+ static void fft (float * in, int N, float * out) {
3007
3000
if (N == 1 ) {
3008
3001
out[0 ] = in[0 ];
3009
3002
out[1 ] = 0 ;
3010
3003
return ;
3011
3004
}
3012
3005
3013
- if (N%2 == 1 ) {
3014
- dft (in, out);
3006
+ const int half_N = N / 2 ;
3007
+ if (N - half_N*2 == 1 ) {
3008
+ dft (in, N, out);
3015
3009
return ;
3016
3010
}
3017
3011
3018
- std::vector<float > even;
3019
- std::vector<float > odd;
3020
-
3021
- even.reserve (N/2 );
3022
- odd.reserve (N/2 );
3023
-
3024
- for (int i = 0 ; i < N; i++) {
3025
- if (i % 2 == 0 ) {
3026
- even.push_back (in[i]);
3027
- } else {
3028
- odd.push_back (in[i]);
3029
- }
3012
+ float * even = in + N;
3013
+ for (int i = 0 ; i < half_N; ++i) {
3014
+ even[i]= in[2 *i];
3030
3015
}
3016
+ float * even_fft = out + 2 * N;
3017
+ fft (even, half_N, even_fft);
3031
3018
3032
- std::vector<float > even_fft;
3033
- std::vector<float > odd_fft;
3034
-
3035
- fft (even, even_fft);
3036
- fft (odd, odd_fft);
3019
+ float * odd = even;
3020
+ for (int i = 0 ; i < half_N; ++i) {
3021
+ odd[i] = in[2 *i + 1 ];
3022
+ }
3023
+ float * odd_fft = even_fft + N;
3024
+ fft (odd, half_N, odd_fft);
3037
3025
3038
3026
const int sin_cos_step = SIN_COS_N_COUNT / N;
3039
- for (int k = 0 ; k < N/ 2 ; k++) {
3027
+ for (int k = 0 ; k < half_N ; k++) {
3040
3028
int idx = k * sin_cos_step; // t = 2*M_PI*k/N
3041
3029
float re = global_cache.cos_vals [idx]; // cos(t)
3042
3030
float im = -global_cache.sin_vals [idx]; // sin(t)
@@ -3047,8 +3035,8 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
3047
3035
out[2 *k + 0 ] = even_fft[2 *k + 0 ] + re*re_odd - im*im_odd;
3048
3036
out[2 *k + 1 ] = even_fft[2 *k + 1 ] + re*im_odd + im*re_odd;
3049
3037
3050
- out[2 *(k + N/ 2 ) + 0 ] = even_fft[2 *k + 0 ] - re*re_odd + im*im_odd;
3051
- out[2 *(k + N/ 2 ) + 1 ] = even_fft[2 *k + 1 ] - re*im_odd - im*re_odd;
3038
+ out[2 *(k + half_N ) + 0 ] = even_fft[2 *k + 0 ] - re*re_odd + im*im_odd;
3039
+ out[2 *(k + half_N ) + 1 ] = even_fft[2 *k + 1 ] - re*im_odd - im*re_odd;
3052
3040
}
3053
3041
}
3054
3042
@@ -3066,8 +3054,8 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
3066
3054
const whisper_filters & filters, whisper_mel_data & mel) {
3067
3055
const auto frame_size = WHISPER_N_FFT;
3068
3056
const auto frame_step = WHISPER_HOP_LENGTH;
3069
- std::vector<float > fft_in (frame_size, 0.0 );
3070
- std::vector<float > fft_out (2 * frame_size );
3057
+ std::vector<float > fft_in (frame_size * 2 , 0.0 );
3058
+ std::vector<float > fft_out (frame_size * 2 * 2 * 2 );
3071
3059
int n_fft = filters.n_fft ;
3072
3060
int i = ith;
3073
3061
@@ -3088,7 +3076,7 @@ void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::v
3088
3076
}
3089
3077
3090
3078
// FFT
3091
- fft (fft_in, fft_out);
3079
+ fft (fft_in. data (), frame_size, fft_out. data () );
3092
3080
3093
3081
// Calculate modulus^2 of complex numbers
3094
3082
// Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting.
0 commit comments