Advancing Spatio-Temporal Processing in Spiking Neural Networks through Adaptation
To set up the required environment, run:
conda env create -f environment.yml
To start an experiment, use:
python run.py experiment=<experiment_name> ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir
Notes:
- `datadir` is mandatory and should contain the datasets.
- For SHD and SSC, the data is downloaded automatically if it is not found at `datadir/SHDWrapper`.
- For the BSD and oscillation toy tasks, the datasets are created on the fly, so `datadir` can point to an empty directory.
- Results are stored in a local `results` folder unless `resultdir` is specified.
- `<experiment_name>` refers to the configurations in `./config/experiment/`.
We use Hydra for configuration management. To override parameters, use the `++` syntax. For example, to change the number of training epochs:
python run.py experiment=SHD_SE_adLIF_small ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir ++n_epochs=10
For the BSD task with a different number of classes (Figure 6b):
python run.py experiment=BSD_SE_adLIF ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir ++dataset.num_classes=10
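If you want to inspect the merged configuration without launching a run, the following minimal sketch uses Hydra's compose API. It is not part of the repository's scripts and assumes Hydra >= 1.2 with `config/main.yaml` as the primary config; adjust names if they differ.

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose the merged configuration without starting a run
# (illustrative usage; assumes config/main.yaml is the primary config).
with initialize(config_path="config", version_base=None):
    cfg = compose(
        config_name="main",
        overrides=["experiment=SHD_SE_adLIF_small", "++n_epochs=10"],
    )
print(OmegaConf.to_yaml(cfg))
```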
To start an audio compression experiment, use:
python run_compress.py experiment=<experiment_name> ++logdir=path/to/my/logdir ++datadir=path/to/my/datadir
Available experiment configurations:
- SE-adLIF: `compress_libri_SE_adLIF`
- EF-adLIF: `compress_libri_EF_adLIF`
- LIF: `compress_libri_LIF`
Model checkpoints for each configuration are available at checkpoints.
To generate wave files from a checkpoint:
generate_waves.py ckpt_path=/path/to/ckpt/example.ckpt source_wave_path=/path/to/libritts/location/ pred_wave_path=/path/to/prediction/ encoder_only=$encoder_flag
- `$encoder_flag`: `true` or `false`.
- `source_wave_path` can be a single `.wav` file or a directory containing `.wav` files.
- If no valid `.wav` files exist, the clean test set from LibriTTS (~9h of audio) is used.
Use `evaluate_metrics.py` to compute SI-SNR or Visqol:
evaluate_metrics.py metric=$metric source_wave_path=path/to/source/waves pred_wave_path=path/to/model/predictions
- `$metric` can be `si_snr` or `visqol` (a reference SI-SNR sketch is given after these notes).
- Note: Visqol must be compiled manually following these instructions. Additionally, the project requires either `gcc-9`/`g++-9` or `gcc-10`/`g++-10`. Set the compiler using:
export CC=gcc-9 CXX=g++-9
Furthermore, Visqol relies on Bazel but references an outdated HTTP resource (Armadillo) in its WORKSPACE file. The resource has been moved here. You should modify the WORKSPACE file to reference your local copy as instructed here.
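For reference, SI-SNR has a simple closed form. The sketch below is an independent illustration of that formula, not necessarily what `evaluate_metrics.py` does internally:

```python
import torch

def si_snr(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Scale-invariant SNR in dB for waveforms of shape (..., time)."""
    # Remove the DC offset from both signals.
    pred = pred - pred.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # Project the prediction onto the target to get its "clean" component.
    scale = (pred * target).sum(dim=-1, keepdim=True) / (
        target.pow(2).sum(dim=-1, keepdim=True) + eps
    )
    s_target = scale * target
    e_noise = pred - s_target
    return 10 * torch.log10(
        s_target.pow(2).sum(dim=-1) / (e_noise.pow(2).sum(dim=-1) + eps) + eps
    )
```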
Global parameters (e.g., `device: 'cpu'` or `device: 'cuda:0'`) can be set in `config/main.yaml`. These settings are used by PyTorch Lightning's `SingleDeviceStrategy`.
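As a rough illustration of how such a device string maps onto Lightning (a hedged sketch; the Trainer setup here is illustrative and not the repository's code):

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import SingleDeviceStrategy

# Illustrative only: pass the configured device string to the strategy.
device = "cuda:0"  # or "cpu", as configured in config/main.yaml
trainer = pl.Trainer(strategy=SingleDeviceStrategy(device=device))
```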
For variable-length sequences (e.g., SHD, SSC), a custom masking procedure is used:
- Data vector: Contains actual data, padded with zeros.
- Block index (`block_idx`): indicates valid data (nonzero block indices) and padding (0s).
- Target vector: maps block indices to the corresponding labels.
data vector: |1011010100101001010000000000000|
             |-----data---------|--padding---|
              ---> time
block_idx:   |1111111111111111111100000000000|
target: [-1, 3]
Explanation:
- Block `0` has target `-1` (ignored).
- Block `1` has target class `3`.
data vector: |1 0 1 1 0 0 1 0 0 0 0 0 0|
             |-----data---|--padding---|
              ---> time
block_idx:   |1 2 3 4 5 6 7 0 0 0 0 0 0|
target: [-1, 4, 3, 1, 3, 4, 6, 3]
Explanation:
- Multiple blocks (`1`-`7`) have corresponding target labels.
- Padding (`0`s) is ignored during loss computation.
Using this method, per-block predictions can be gathered efficiently with `torch.scatter_reduce`, ignoring padded time steps.
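A minimal sketch of this gathering step is shown below. Tensor names and shapes are illustrative (they mirror the second example above) and are not taken from the repository's code:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: T time steps, C classes, B valid blocks; block index 0 marks padding.
T, C, B = 13, 8, 7
logits = torch.randn(T, C)  # hypothetical per-time-step model outputs
block_idx = torch.tensor([1, 2, 3, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0])
target = torch.tensor([-1, 4, 3, 1, 3, 4, 6, 3])  # block 0 -> -1 (ignored)

# Average the logits belonging to each block; row 0 collects the padded steps.
index = block_idx.unsqueeze(-1).expand(-1, C)  # shape [T, C]
pooled = torch.zeros(B + 1, C).scatter_reduce(
    0, index, logits, reduce="mean", include_self=False
)

# Padding (block 0, target -1) is excluded from the loss via ignore_index.
loss = F.cross_entropy(pooled, target, ignore_index=-1)
```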
This work is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.