@@ -8118,29 +8118,6 @@ def aten_std_mean_correction(
81188118 return op .Sqrt (var ), mean
81198119
81208120
8121- def _center_window_around_zeros_if_needed (
8122- window : TFloat , n_fft : int
8123- ) -> TFloat :
8124- # Zero-pad `window` on both sides up to length n_fft; n_win is the size of its first dimension, taken as a shape slice
8125- n_win = op .Shape (window , start = 0 , end = 1 )
8126-
8127- left = op .Div (op .Sub (n_fft , n_win ), op .Constant (value_ints = [2 ]))
8128-
8129- right = op .Sub (op .Sub (n_fft , left ), n_win )
8130- left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8131- right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8132-
8133- left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8134- right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8135- right_win = op .CastLike (right_win , window )
8136- left_win = op .CastLike (left_win , window )
8137- window_padded = op .Concat (left_win , window , right_win , axis = 0 )
8138-
8139- # Select the zero-centered (padded) window only when n_win < n_fft, otherwise keep the input window unchanged (centering is required by ONNX's STFT)
8140- window = op .Where (op .Less (op .Squeeze (n_win ), n_fft ), window_padded , window )
8141- return window
8142-
8143-
81448121def _create_window_from_win_length (win_length : int , n_fft : int ) -> TFloat :
81458122 left = op .Div (op .Sub (n_fft , win_length ), op .Constant (value_ints = [2 ]))
81468123
@@ -8161,9 +8138,7 @@ def _create_window_from_n_fft(n_fft: int) -> TFloat:
81618138 return window
81628139
81638140
8164- def _normalize_fft_result (
8165- signal : TFloat , result : TFloat , n_fft : int
8166- ) -> TFloat :
8141+ def _normalize_fft_result (signal : TFloat , result : TFloat , n_fft : int ) -> TFloat :
81678142 n_fft_tensor = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
81688143 sqrt_nfft = op .Sqrt (op .CastLike (n_fft_tensor , signal ))
81698144 result = op .Div (result , sqrt_nfft )
@@ -8208,14 +8183,28 @@ def aten_stft(
82088183 frame_length_const = op .Reshape (n_fft , op .Constant (value_ints = [1 ]))
82098184
82108185 # Pre-process input if needed
8211- is_signal_rank1 = self . shape is not None and len (self .shape ) == 1
8186+ is_signal_rank1 = len (self .shape ) == 1
82128187 if is_signal_rank1 :
82138188 # Add a batch dimension
82148189 self = op .Identity (op .Unsqueeze (self , op .Constant (value_ints = [0 ])))
82158190
82168191 # Get window and make sure it's the same size as `win_length` or `n_fft`
82178192 if window is not None and window .shape [0 ] is not None :
8218- window = _center_window_around_zeros_if_needed (window , n_fft )
8193+ # first dimension
8194+ n_win = op .Shape (window , start = 0 , end = 1 )
8195+ # Center window around zeros if needed (required by ONNX's STFT)
8196+ if n_win < n_fft :
8197+ left = op .Div (op .Sub (n_fft , n_win ), op .Constant (value_ints = [2 ]))
8198+
8199+ right = op .Sub (op .Sub (n_fft , left ), n_win )
8200+ left = op .Reshape (left , op .Constant (value_ints = [1 ]))
8201+ right = op .Reshape (right , op .Constant (value_ints = [1 ]))
8202+
8203+ left_win = op .Expand (op .Constant (value_ints = [0 ]), left )
8204+ right_win = op .Expand (op .Constant (value_ints = [0 ]), right )
8205+ right_win = op .CastLike (right_win , window )
8206+ left_win = op .CastLike (left_win , window )
8207+ window = op .Concat (left_win , window , right_win , axis = 0 )
82198208 elif window is None :
82208209 if win_length is not None :
82218210 window = _create_window_from_win_length (win_length , n_fft )
@@ -8226,9 +8215,7 @@ def aten_stft(
82268215 onesided = 1
82278216 else :
82288217 onesided = 0
8229- result = _aten_stft_onnx (
8230- self , frame_step_const , window , frame_length_const , onesided
8231- )
8218+ result = _aten_stft_onnx (self , frame_step_const , window , frame_length_const , onesided )
82328219 # Remove batch dimension, if needed
82338220 if is_signal_rank1 :
82348221 result = op .Squeeze (result , op .Constant (value_ints = [0 ]))
0 commit comments