Skip to content

Commit ab8048e

Browse files
committed
Clean up
1 parent 3c12aae commit ab8048e

File tree

1 file changed

+18
-31
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+18
-31
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8118,29 +8118,6 @@ def aten_std_mean_correction(
81188118
return op.Sqrt(var), mean
81198119

81208120

8121-
def _center_window_around_zeros_if_needed(
8122-
window: TFloat, n_fft: int
8123-
) -> TFloat:
8124-
# first dimension
8125-
n_win = op.Shape(window, start=0, end=1)
8126-
8127-
left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2]))
8128-
8129-
right = op.Sub(op.Sub(n_fft, left), n_win)
8130-
left = op.Reshape(left, op.Constant(value_ints=[1]))
8131-
right = op.Reshape(right, op.Constant(value_ints=[1]))
8132-
8133-
left_win = op.Expand(op.Constant(value_ints=[0]), left)
8134-
right_win = op.Expand(op.Constant(value_ints=[0]), right)
8135-
right_win = op.CastLike(right_win, window)
8136-
left_win = op.CastLike(left_win, window)
8137-
window_padded = op.Concat(left_win, window, right_win, axis=0)
8138-
8139-
# Center window around zeros if needed (required by ONNX's STFT)
8140-
window = op.Where(op.Less(op.Squeeze(n_win), n_fft), window_padded, window)
8141-
return window
8142-
8143-
81448121
def _create_window_from_win_length(win_length: int, n_fft: int) -> TFloat:
81458122
left = op.Div(op.Sub(n_fft, win_length), op.Constant(value_ints=[2]))
81468123

@@ -8161,9 +8138,7 @@ def _create_window_from_n_fft(n_fft: int) -> TFloat:
81618138
return window
81628139

81638140

8164-
def _normalize_fft_result(
8165-
signal: TFloat, result: TFloat, n_fft: int
8166-
) -> TFloat:
8141+
def _normalize_fft_result(signal: TFloat, result: TFloat, n_fft: int) -> TFloat:
81678142
n_fft_tensor = op.Reshape(n_fft, op.Constant(value_ints=[1]))
81688143
sqrt_nfft = op.Sqrt(op.CastLike(n_fft_tensor, signal))
81698144
result = op.Div(result, sqrt_nfft)
@@ -8208,14 +8183,28 @@ def aten_stft(
82088183
frame_length_const = op.Reshape(n_fft, op.Constant(value_ints=[1]))
82098184

82108185
# Pre-process input if needed
8211-
is_signal_rank1 = self.shape is not None and len(self.shape) == 1
8186+
is_signal_rank1 = len(self.shape) == 1
82128187
if is_signal_rank1:
82138188
# Add a batch dimension
82148189
self = op.Identity(op.Unsqueeze(self, op.Constant(value_ints=[0])))
82158190

82168191
# Get window and make sure it's the same size as `win_length` or `n_fft`
82178192
if window is not None and window.shape[0] is not None:
8218-
window = _center_window_around_zeros_if_needed(window, n_fft)
8193+
# first dimension
8194+
n_win = op.Shape(window, start=0, end=1)
8195+
# Center window around zeros if needed (required by ONNX's STFT)
8196+
if n_win < n_fft:
8197+
left = op.Div(op.Sub(n_fft, n_win), op.Constant(value_ints=[2]))
8198+
8199+
right = op.Sub(op.Sub(n_fft, left), n_win)
8200+
left = op.Reshape(left, op.Constant(value_ints=[1]))
8201+
right = op.Reshape(right, op.Constant(value_ints=[1]))
8202+
8203+
left_win = op.Expand(op.Constant(value_ints=[0]), left)
8204+
right_win = op.Expand(op.Constant(value_ints=[0]), right)
8205+
right_win = op.CastLike(right_win, window)
8206+
left_win = op.CastLike(left_win, window)
8207+
window = op.Concat(left_win, window, right_win, axis=0)
82198208
elif window is None:
82208209
if win_length is not None:
82218210
window = _create_window_from_win_length(win_length, n_fft)
@@ -8226,9 +8215,7 @@ def aten_stft(
82268215
onesided = 1
82278216
else:
82288217
onesided = 0
8229-
result = _aten_stft_onnx(
8230-
self, frame_step_const, window, frame_length_const, onesided
8231-
)
8218+
result = _aten_stft_onnx(self, frame_step_const, window, frame_length_const, onesided)
82328219
# Remove batch dimension, if needed
82338220
if is_signal_rank1:
82348221
result = op.Squeeze(result, op.Constant(value_ints=[0]))

0 commit comments

Comments
 (0)