Skip to content

Commit

Permalink
Added parallel STFT implementation (#113)
Browse files Browse the repository at this point in the history
* Added parallel STFT implementation

* Implemented requested changes
  • Loading branch information
JPery authored Dec 3, 2020
1 parent f41eb4b commit f7c5b13
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 8 deletions.
70 changes: 70 additions & 0 deletions kapre/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
_CH_DEFAULT_STR (str): 'default', a pre-defined string.
"""
import multiprocessing
from tensorflow.keras import backend as K
import tensorflow as tf
from tensorflow.python.ops.signal import shape_ops, fft_ops, window_ops, spectral_ops
from tensorflow.python.framework import ops
from functools import partial
import numpy as np
import librosa

Expand Down Expand Up @@ -248,3 +252,69 @@ def mu_law_decoding(signal_mu, quantization_channels):
tf.math.sign(signal) * (tf.math.exp(tf.math.abs(signal) * tf.math.log1p(mu)) - 1.0) / mu
)
return signal


def parallel_stft(
signals,
frame_length,
frame_step,
fft_length=None,
window_fn=window_ops.hann_window,
pad_end=False,
name=None,
):
"""Workaround for a parallel implementation of tf.signal.stft
See `Wikipedia <https://en.wikipedia.org/wiki/Short-time_Fourier_transform>`_ for more details.
Args:
signals: A `[..., samples]` `float32`/`float64` `Tensor` of real-valued
signals.
frame_length: An integer scalar `Tensor`. The window length in samples.
frame_step: An integer scalar `Tensor`. The number of samples to step.
fft_length: An integer scalar `Tensor`. The size of the FFT to apply.
If not provided, uses the smallest power of 2 enclosing `frame_length`.
window_fn: A callable that takes a window length and a `dtype` keyword
argument and returns a `[window_length]` `Tensor` of samples in the
provided datatype. If set to `None`, no windowing is used.
pad_end: Whether to pad the end of `signals` with zeros when the provided
frame length and step produces a frame that lies partially past its end.
name: An optional name for the operation.
Returns:
A `[..., frames, fft_unique_bins]` `Tensor` of `complex64`/`complex128`
STFT values where `fft_unique_bins` is `fft_length // 2 + 1` (the unique
components of the FFT).
Raises:
ValueError: If `signals` is not at least rank 1, `frame_length` is
not scalar, or `frame_step` is not scalar.
"""
# If GPU available we return the default stft function
if len(tf.config.get_visible_devices('GPU')) > 0:
return tf.signal.stft(
signals, frame_length, frame_step, fft_length, window_fn, pad_end, name
)

# Else we return our implementation using map_fn
with ops.name_scope(name, 'stft', [signals, frame_length, frame_step]):
signals = ops.convert_to_tensor(signals, name='signals')
signals.shape.with_rank_at_least(1)
frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
frame_length.shape.assert_has_rank(0)
frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
frame_step.shape.assert_has_rank(0)
if fft_length is None:
fft_length = spectral_ops._enclosing_power_of_two(frame_length)
else:
fft_length = ops.convert_to_tensor(fft_length, name='fft_length')
framed_signals = shape_ops.frame(signals, frame_length, frame_step, pad_end=pad_end)
if window_fn is not None:
window = window_fn(frame_length, dtype=framed_signals.dtype)
framed_signals *= window
return tf.map_fn(
partial(fft_ops.rfft, fft_length=[fft_length]),
framed_signals,
fn_output_signature=tf.complex64,
parallel_iterations=multiprocessing.cpu_count(), # or how many parallel ops you see fit
)
30 changes: 30 additions & 0 deletions kapre/composed.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_stft_magnitude_layer(
input_data_format='default',
output_data_format='default',
name='stft_magnitude',
use_parallel_stft=False,
):
"""A function that returns a stft magnitude layer.
The layer is a `keras.Sequential` model consists of `STFT`, `Magnitude`, and optionally `MagnitudeToDecibel`.
Expand Down Expand Up @@ -72,6 +73,10 @@ def get_stft_magnitude_layer(
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
name (str): name of the returned layer
use_parallel_stft (bool): Whether to parallelize stft when running on CPU. If `True`, it uses
`kapre.backend.parallel_stft()` which uses multi-processing. If `False`, it uses Tensorflow
`tf.signal.stft`, which is significant slower than other STFT implementations such as librosa
or scipy on CPU. It does not affect the behavior when running on GPUs
Note:
STFT magnitude represents a linear-frequency spectrum of audio signal and probably the most popular choice
Expand Down Expand Up @@ -121,6 +126,7 @@ def get_stft_magnitude_layer(
pad_end=pad_end,
input_data_format=input_data_format,
output_data_format=output_data_format,
use_parallel_stft=use_parallel_stft,
)

stft_to_stftm = Magnitude()
Expand Down Expand Up @@ -156,6 +162,7 @@ def get_melspectrogram_layer(
input_data_format='default',
output_data_format='default',
name='melspectrogram',
use_parallel_stft=False,
):
"""A function that returns a melspectrogram layer, which is a `keras.Sequential` model consists of
`STFT`, `Magnitude`, `ApplyFilterbank(_mel_filterbank)`, and optionally `MagnitudeToDecibel`.
Expand Down Expand Up @@ -190,6 +197,10 @@ def get_melspectrogram_layer(
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
name (str): name of the returned layer
use_parallel_stft (bool): Whether to parallelize stft when running on CPU. If `True`, it uses
`kapre.backend.parallel_stft()` which uses multi-processing. If `False`, it uses Tensorflow
`tf.signal.stft`, which is significant slower than other STFT implementations such as librosa
or scipy on CPU. It does not affect the behavior when running on GPUs
Note:
Melspectrogram is originally developed for speech applications and has been *very* widely used for audio signal
Expand Down Expand Up @@ -234,6 +245,7 @@ def get_melspectrogram_layer(
pad_end=pad_end,
input_data_format=input_data_format,
output_data_format=output_data_format,
use_parallel_stft=use_parallel_stft,
)

stft_to_stftm = Magnitude()
Expand Down Expand Up @@ -281,6 +293,7 @@ def get_log_frequency_spectrogram_layer(
input_data_format='default',
output_data_format='default',
name='log_frequency_spectrogram',
use_parallel_stft=False,
):
"""A function that returns a log-frequency STFT layer, which is a `keras.Sequential` model consists of
`STFT`, `Magnitude`, `ApplyFilterbank(_log_filterbank)`, and optionally `MagnitudeToDecibel`.
Expand Down Expand Up @@ -314,6 +327,10 @@ def get_log_frequency_spectrogram_layer(
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
name (str): name of the returned layer
use_parallel_stft (bool): Whether to parallelize stft when running on CPU. If `True`, it uses
`kapre.backend.parallel_stft()` which uses multi-processing. If `False`, it uses Tensorflow
`tf.signal.stft`, which is significant slower than other STFT implementations such as librosa
or scipy on CPU. It does not affect the behavior when running on GPUs
Note:
Log-frequency spectrogram is similar to melspectrogram but its frequency axis is perfectly linear to octave scale.
Expand Down Expand Up @@ -349,6 +366,7 @@ def get_log_frequency_spectrogram_layer(
pad_end=pad_end,
input_data_format=input_data_format,
output_data_format=output_data_format,
use_parallel_stft=use_parallel_stft,
)

stft_to_stftm = Magnitude()
Expand Down Expand Up @@ -396,6 +414,7 @@ def get_perfectly_reconstructing_stft_istft(
stft_data_format='default',
stft_name='stft',
istft_name='istft',
use_parallel_stft=False,
):
"""A function that returns two layers, stft and inverse stft, which would be perfectly reconstructing pair.
Expand All @@ -420,6 +439,10 @@ def get_perfectly_reconstructing_stft_istft(
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
stft_name (str): name of the returned STFT layer
istft_name (str): name of the returned ISTFT layer
use_parallel_stft (bool): Whether to parallelize stft when running on CPU. If `True`, it uses
`kapre.backend.parallel_stft()` which uses multi-processing. If `False`, it uses Tensorflow
`tf.signal.stft`, which is significant slower than other STFT implementations such as librosa
or scipy on CPU. It does not affect the behavior when running on GPUs
Note:
Expand Down Expand Up @@ -483,6 +506,7 @@ def get_perfectly_reconstructing_stft_istft(
input_data_format=waveform_data_format,
output_data_format=stft_data_format,
name=stft_name,
use_parallel_stft=use_parallel_stft,
)

stft_to_waveform = InverseSTFT(
Expand Down Expand Up @@ -514,6 +538,7 @@ def get_stft_mag_phase(
input_data_format='default',
output_data_format='default',
name='stft_mag_phase',
use_parallel_stft=False,
):
"""A function that returns magnitude and phase of input audio.
Expand Down Expand Up @@ -542,6 +567,10 @@ def get_stft_mag_phase(
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
name (str): name of the returned layer
use_parallel_stft (bool): Whether to parallelize stft when running on CPU. If `True`, it uses
`kapre.backend.parallel_stft()` which uses multi-processing. If `False`, it uses Tensorflow
`tf.signal.stft`, which is significant slower than other STFT implementations such as librosa
or scipy on CPU. It does not affect the behavior when running on GPUs
Example:
::
Expand All @@ -566,6 +595,7 @@ def get_stft_mag_phase(
pad_end=pad_end,
input_data_format=input_data_format,
output_data_format=output_data_format,
use_parallel_stft=use_parallel_stft,
)

stft_to_stftm = Magnitude()
Expand Down
12 changes: 10 additions & 2 deletions kapre/time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tensorflow.keras.layers import Layer, Conv2D
from . import backend
from tensorflow.keras import backend as K
from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR
from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR, parallel_stft


__all__ = [
Expand Down Expand Up @@ -81,6 +81,10 @@ class STFT(Layer):
`'channels_last'` if you want `(batch, time, frequency, channels)` and
`'channels_first'` if you want `(batch, channels, time, frequency)`
Defaults to the setting of your Keras configuration. (`tf.keras.backend.image_data_format()`)
use_parallel_stft (bool): Whether to parallelize stft when running on CPU. If `True`, it uses
`kapre.backend.parallel_stft()` which uses multi-processing. If `False`, it uses Tensorflow
`tf.signal.stft`, which is significant slower than other STFT implementations such as librosa
or scipy on CPU. It does not affect the behavior when running on GPUs
**kwargs: Keyword args for the parent keras layer (e.g., `name`)
Expand All @@ -105,6 +109,7 @@ def __init__(
pad_end=False,
input_data_format='default',
output_data_format='default',
use_parallel_stft=False,
**kwargs,
):
super(STFT, self).__init__(**kwargs)
Expand All @@ -129,6 +134,8 @@ def __init__(
self.output_data_format = K.image_data_format() if odt == _CH_DEFAULT_STR else odt
self.input_data_format = K.image_data_format() if idt == _CH_DEFAULT_STR else idt

self.use_parallel_stft = use_parallel_stft

def call(self, x):
"""
Compute STFT of the input signal. If the `time` axis is not the last axis of `x`, it should be transposed first.
Expand Down Expand Up @@ -156,8 +163,9 @@ def call(self, x):
waveforms = tf.pad(
waveforms, tf.constant([[0, 0], [0, 0], [int(self.n_fft - self.hop_length), 0]])
)
stft_function = parallel_stft if self.use_parallel_stft else tf.signal.stft

stfts = tf.signal.stft(
stfts = stft_function(
signals=waveforms,
frame_length=self.win_length,
frame_step=self.hop_length,
Expand Down
18 changes: 17 additions & 1 deletion tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tensorflow as tf
from tensorflow.keras import backend as K
from kapre import backend as KPB
from kapre.backend import magnitude_to_decibel, validate_data_format_str
from kapre.backend import magnitude_to_decibel, validate_data_format_str, parallel_stft

from utils import SRC

Expand Down Expand Up @@ -128,5 +128,21 @@ def test_validate_fail():
_ = validate_data_format_str('weird_string')


@pytest.mark.parametrize('frame_length', [1024, 2048])
@pytest.mark.parametrize('frame_step', [256, 512])
@pytest.mark.parametrize('image_size', [256, 512, 1024])
def test_parallel_stft_correctness(frame_length, frame_step, image_size):
prev_physical_devices_list = tf.config.get_visible_devices('GPU')
tf.config.set_visible_devices([], 'GPU')
test_array = tf.cast(tf.random.normal([image_size, image_size]), dtype=tf.float32)
tf.test.TestCase().assertAllClose(
tf.signal.stft(test_array, frame_length, frame_step),
parallel_stft(test_array, frame_length, frame_step),
rtol=1e-4,
atol=1e-3,
)
tf.config.set_visible_devices(prev_physical_devices_list, 'GPU')


if __name__ == '__main__':
pytest.main([__file__])
32 changes: 27 additions & 5 deletions tests/test_time_frequency.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def allclose_complex_numbers(a, b, atol=1e-3):
@pytest.mark.parametrize('hop_length', [None, 256])
@pytest.mark.parametrize('n_ch', [1, 2, 6])
@pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last'])
def test_spectrogram_correctness(n_fft, hop_length, n_ch, data_format):
@pytest.mark.parametrize('use_parallel_stft', [True, False])
def test_spectrogram_correctness(n_fft, hop_length, n_ch, data_format, use_parallel_stft):
def _get_stft_model(following_layer=None):
# compute with kapre
stft_model = tensorflow.keras.models.Sequential()
Expand All @@ -72,6 +73,7 @@ def _get_stft_model(following_layer=None):
output_data_format=data_format,
input_shape=input_shape,
name='stft',
use_parallel_stft=use_parallel_stft,
)
)
if following_layer is not None:
Expand Down Expand Up @@ -108,7 +110,8 @@ def _get_stft_model(following_layer=None):

@pytest.mark.parametrize('data_format', ['channels_first', 'channels_last'])
@pytest.mark.parametrize('window_name', [None, 'hann_window', 'hamming_window'])
def test_spectrogram_correctness_more(data_format, window_name):
@pytest.mark.parametrize('use_parallel_stft', [True, False])
def test_spectrogram_correctness_more(data_format, window_name, use_parallel_stft):
def _get_stft_model(following_layer=None):
# compute with kapre
stft_model = tensorflow.keras.models.Sequential()
Expand All @@ -123,6 +126,7 @@ def _get_stft_model(following_layer=None):
output_data_format=data_format,
input_shape=input_shape,
name='stft',
use_parallel_stft=use_parallel_stft,
)
)
if following_layer is not None:
Expand Down Expand Up @@ -176,8 +180,19 @@ def _get_stft_model(following_layer=None):
@pytest.mark.parametrize('n_mels', [40])
@pytest.mark.parametrize('mel_f_min', [0.0])
@pytest.mark.parametrize('mel_f_max', [8000])
@pytest.mark.parametrize('use_parallel_stft', [True, False])
def test_melspectrogram_correctness(
n_fft, sr, hop_length, n_ch, data_format, amin, dynamic_range, n_mels, mel_f_min, mel_f_max
n_fft,
sr,
hop_length,
n_ch,
data_format,
amin,
dynamic_range,
n_mels,
mel_f_min,
mel_f_max,
use_parallel_stft,
):
"""Test the correctness of melspectrogram.
Expand All @@ -201,6 +216,7 @@ def _get_melgram_model(return_decibel, amin, dynamic_range, input_shape=None):
input_shape=input_shape,
db_amin=amin,
db_dynamic_range=dynamic_range,
use_parallel_stft=use_parallel_stft,
)
return melgram_model

Expand Down Expand Up @@ -276,7 +292,8 @@ def test_delta():


@pytest.mark.parametrize('data_format', ['default', 'channels_first', 'channels_last'])
def test_mag_phase(data_format):
@pytest.mark.parametrize('use_parallel_stft', [True, False])
def test_mag_phase(data_format, use_parallel_stft):
n_ch = 1
n_fft, hop_length, win_length = 512, 256, 512

Expand All @@ -289,6 +306,7 @@ def test_mag_phase(data_format):
hop_length=hop_length,
input_data_format=data_format,
output_data_format=data_format,
use_parallel_stft=use_parallel_stft,
)
model = tensorflow.keras.models.Sequential()
model.add(mag_phase_layer)
Expand Down Expand Up @@ -316,7 +334,10 @@ def test_mag_phase(data_format):
@pytest.mark.parametrize('waveform_data_format', ['default', 'channels_first', 'channels_last'])
@pytest.mark.parametrize('stft_data_format', ['default', 'channels_first', 'channels_last'])
@pytest.mark.parametrize('hop_ratio', [0.5, 0.25, 0.125])
def test_perfectly_reconstructing_stft_istft(waveform_data_format, stft_data_format, hop_ratio):
@pytest.mark.parametrize('use_parallel_stft', [True, False])
def test_perfectly_reconstructing_stft_istft(
waveform_data_format, stft_data_format, hop_ratio, use_parallel_stft
):
n_ch = 1
src_mono, batch_src, input_shape = get_audio(data_format=waveform_data_format, n_ch=n_ch)
time_axis = 1 if waveform_data_format == 'channels_first' else 0 # non-batch!
Expand All @@ -332,6 +353,7 @@ def test_perfectly_reconstructing_stft_istft(waveform_data_format, stft_data_for
hop_length=hop_length,
waveform_data_format=waveform_data_format,
stft_data_format=stft_data_format,
use_parallel_stft=use_parallel_stft,
)
# Test - [STFT -> ISTFT]
model = tf.keras.models.Sequential([stft, istft])
Expand Down

0 comments on commit f7c5b13

Please sign in to comment.