Add aten.stft.center and decomposition #3880
base: main
Conversation
This reverts commit ae145aa.
CI errors (undefined symbols referenced by
Might need to add tosa xfails if you haven't done so already. IIRC Tosa got added to the CI in your last sync.
A general comment before reviewing further:
Is it absolutely necessary to use a loop? The only time I would consider using a loop is when there is a necessary loop-carried dependency, but I don't think that is the case here.
Even if there is a nice algorithm for rfft, I don't think decomposing to many rffts in a loop would be more efficient than converting this to, say, a convolution with something like window*exp (if that is possible with the configurations you are trying to support).
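For reference, a sketch of the identity behind the "convolution with window*exp" suggestion, using the usual non-centered STFT definition (my paraphrase, not part of the review):

$$X[f, m] = \sum_{n=0}^{N-1} x[m \cdot \text{hop} + n] \, w[n] \, e^{-2\pi i f n / N}$$

which is a strided correlation of the signal x against the fixed kernels k_f[n] = w[n] e^{-2\pi i f n / N}, i.e. expressible as a 1-D convolution whose weights are the window multiplied by complex exponentials, evaluated with stride hop_length.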
// init_freq_tensor = aten.empty.memory_format([batch_dim?, n_freqs, n_frames],
//                                              self.dtype, None, None, None, None)
// final_freq_tensor = prim.loop n_frames, %true, init(init_freq_tensor)
// {
// ^bb0(frame, freq_tensor):
//   begin = frame * hop_length
//   end = begin + n_fft
//   narrow_length = min(end, signal_len) - begin
//   missing = n_fft - narrow_length
//   sliced = torch.narrow(self, axis_signal, begin, narrow_length) :
//       !torch.vtensor<[batch_dim?,?],f32>
//   padded_sliced = aten.pad(sliced, [0, missing], "constant", 0.0) :
//       !torch.vtensor<[batch_dim?,?],f32>
//   padded_sliced = tensor_static_info_cast(padded_sliced) :
//       !torch.vtensor<[batch_dim?,n_fft],f32>
//   weighted = aten.mul.Tensor(padded_sliced, window) :
//       !torch.vtensor<[batch_dim?,n_fft],f32>
//   f = onesidedBool ? aten.fft_rfft : aten.fft_fft
//   freq_slice_sq = f(weighted, None, axis_signal) :
//       !torch.vtensor<[batch_dim?,n_freqs],f32>
//   freq_slice = aten.unsqueeze(freq_slice_sq, axis_frames) :
//       !torch.vtensor<[batch_dim?,n_freqs, 1],f32>
//   new_freq_tensor = aten.slice_scatter(
//       freq_tensor, freq_slice,
//       dim=axis_frames, start=frame,
//       end=None, step=1)
//   torch.prim.Loop.condition %true, iter(%new_freq_tensor)
// }
// return final_freq_tensor
Pseudo-IR isn't very helpful as a comment, since you include lit tests for this decomposition.
Removed in last commit.
if (isa<Torch::NoneType>(hopLength.getType())) {
  hopLength = rewriter.create<AtenFloordivIntOp>(
      loc, n_fft,
      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(4)));
nit: There is a builder for Torch::ConstantIntOp which allows passing an int directly, which is a bit easier to read.
-      rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(4)));
+      rewriter.create<ConstantIntOp>(loc, 4));
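Applied in context, the default-hop_length snippet above would then read roughly as follows (a sketch only; it assumes the same rewriter, loc, n_fft, and hopLength values that are in scope in the quoted code):

if (isa<Torch::NoneType>(hopLength.getType())) {
  // torch.stft defaults hop_length to floor(n_fft / 4) when it is not given.
  Value four = rewriter.create<ConstantIntOp>(loc, 4);
  hopLength = rewriter.create<AtenFloordivIntOp>(loc, n_fft, four);
}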
Changed in last commit.
Value center = op.getCenter();
bool centerBool;
// TODO: add support for non-constant center and center=True
if (!matchPattern(center, m_TorchConstantBool(&centerBool)))
  return rewriter.notifyMatchFailure(op,
                                     "Unsupported: non-constant center");
if (centerBool)
  return rewriter.notifyMatchFailure(op, "Unsupported: center=True");

Value normalized = op.getNormalized();
bool normalizedBool;
// TODO: add support for non-constant normalized and normalized=True
if (!matchPattern(normalized, m_TorchConstantBool(&normalizedBool)))
  return rewriter.notifyMatchFailure(
      op, "Unsupported: non-constant normalized");
if (normalizedBool)
  return rewriter.notifyMatchFailure(op, "Unsupported: normalized=True");

bool onesidedBool;
// Default: True for real input and window, False otherwise.
// TODO: add support for non-constant onesided
if (isa<Torch::NoneType>(op.getOnesided().getType())) {
  Type dtype = selfType.getDtype();
  onesidedBool = !isa<mlir::ComplexType>(dtype);
} else if (!matchPattern(op.getOnesided(),
                         m_TorchConstantBool(&onesidedBool)))
  return rewriter.notifyMatchFailure(op,
                                     "Unsupported: non-constant onesided");
Value returnComplex = op.getReturnComplex();
bool returnComplexBool;
// TODO: add support for non-constant return_complex and return_complex=True
if (!matchPattern(returnComplex, m_TorchConstantBool(&returnComplexBool)))
  return rewriter.notifyMatchFailure(
      op, "Unsupported: non-constant return_complex");
if (!returnComplexBool)
  return rewriter.notifyMatchFailure(op,
                                     "Unsupported: return_complex=False");
Move all of these match failures before the generation of runtime assert ops in the if (hasWindow) block.
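A sketch of the ordering being asked for (hasWindow and the runtime asserts are taken from the comment above; everything else is assumed context). Match failures should come first because a pattern that returns failure is expected to leave the IR untouched, and notifyMatchFailure creates no ops:

// 1) Reject unsupported/non-constant arguments up front; nothing has been
//    created yet, so bailing out here leaves the IR unmodified.
if (!matchPattern(op.getCenter(), m_TorchConstantBool(&centerBool)))
  return rewriter.notifyMatchFailure(op, "Unsupported: non-constant center");
// ... the remaining center/normalized/onesided/return_complex checks ...

// 2) Only after all match failures have been ruled out, start emitting IR,
//    e.g. the runtime assert ops guarding the window length:
if (hasWindow) {
  // ... create the runtime assert ops here ...
}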
Changed in last commit.
The two approaches have different time complexities: assuming N is both the number of frequency samples and the window length (as per the default) and T is the number of frames, STFT via Conv would be O(T N^2), while STFT via many FFTs (Cooley–Tukey algorithm) would be O(T N log2 N).
For single-core execution the many-FFTs approach would give a ~56x speedup. In a parallel execution setting the Conv-based algorithm would have one more axis for parallelization (dimension N=512), so which algorithm wins would depend heavily on the number of cores available and the signal length (and also the batch dimension). Regarding the use of a loop, the many-FFTs approach requires one because the signal-length dimension can be unknown; if I remember correctly, the loop gets unrolled completely when the signal-length dimension is known.
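For concreteness, the ~56x figure is just the ratio of the two per-frame operation counts at the default window length, assuming N = 512:

$$\frac{T \cdot N^2}{T \cdot N \log_2 N} = \frac{N}{\log_2 N} = \frac{512}{9} \approx 57$$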
The choice to work with aten.stft.center instead of aten.stft is because the latter doesn't match the signature that gets exposed (see https://pytorch.org/docs/stable/generated/torch.stft.html).
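For reference (reconstructed from the linked documentation rather than quoted from the PR, so treat the exact argument defaults as an assumption): the Python-level torch.stft signature carries center and pad_mode arguments, which the plain aten::stft overload does not; the aten::stft.center overload does, roughly:

aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None) -> Tensor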