ZS snapshot [conflict with ABX_script PR as ABX_script one has been rebased] #14

Open
wants to merge 11 commits into base: master
11 changes: 10 additions & 1 deletion README.md
@@ -1,3 +1,12 @@
This repo is based on the `CPC_audio` repo (<https://github.com/facebookresearch/CPC_audio>), but it also contains:
- part of the code used for the University of Wroclaw ZeroSpeech2021 submission (our modifications, plus files merged from: <https://github.com/facebookresearch/CPC_audio/tree/zerospeech> - `criterion/clustering/`, <https://github.com/tuanh208/CPC_audio/tree/zerospeech> - `feature_loader.py -> buildFeature_batch`, <https://github.com/bootphon/zerospeech2021_baseline> - `scripts/`)
- code used for the CPC-CTC paper

Below is the original README, updated with some of our modifications; part of our code is also described in `cpc/README.md`.

--------------------------------------------------------------------------------------------------


# CPC_audio

This code implements the Contrastive Predictive Coding (CPC) algorithm on audio data, as described in the paper [Unsupervised Pretraining Transfers well Across Languages](https://arxiv.org/abs/2002.02848). It is an unsupervised method for training audio features directly from the raw waveform.
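For intuition, here is a minimal sketch of the contrastive (InfoNCE-style) objective behind CPC: a predictor maps the autoregressive context to a guess of the encoder frame `k` steps ahead, and that guess is scored against the other frames, which act as negatives. This is an illustration only, not this repository's implementation; `infonce_step`, `predictor`, and the sizes below are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def infonce_step(z, c, predictor, k=1):
    """z: encoder outputs (B, T, D); c: autoregressive context (B, T, D).
    Predict the encoder frame k steps ahead from the context and contrast it
    against the other frames of the same sequence (the negatives)."""
    pred = predictor(c[:, :-k])                            # (B, T-k, D) predictions
    target = z[:, k:]                                      # (B, T-k, D) true future frames
    scores = torch.einsum("btd,bsd->bts", pred, target)    # similarity of every prediction to every frame
    B, T, _ = scores.shape
    labels = torch.arange(T, device=scores.device).repeat(B)  # positive = the aligned frame
    return F.cross_entropy(scores.reshape(B * T, T), labels)

# Example with random tensors and a linear prediction head (placeholder sizes):
loss = infonce_step(torch.randn(4, 100, 256), torch.randn(4, 100, 256), nn.Linear(256, 256), k=12)
print(loss.item())
```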
@@ -123,7 +132,7 @@ Will evaluate the speaker separability of the concatenation of the features from

`--gru_level` controls from which layer of the autoregressive part of CPC the features are extracted. By default it is the last one.

Nullspaces:
### Nullspaces:

To conduct the nullspace experiment, first train a speaker classifier factorized into two matrices `A` (`DIM_EMBEDDING` x `DIM_INBETWEEN`) and `B` (`DIM_INBETWEEN` x `SPEAKERS`). Then extract `A'`, the nullspace of matrix `A` (of size `DIM_EMBEDDING` x (`DIM_EMBEDDING` - `DIM_INBETWEEN`)), to make the embeddings less sensitive to speakers.
```bash
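As a rough sketch of the linear algebra involved (this is an illustration, not the exact code in `cpc/eval/linear_separability.py`; the helper `nullspace_of` and the usage line are assumptions), `A'` can be obtained from `A` via an SVD and then used to project the embeddings:

```python
import torch

def nullspace_of(A: torch.Tensor) -> torch.Tensor:
    """A: (DIM_EMBEDDING, DIM_INBETWEEN), assumed full column rank.
    Returns A': (DIM_EMBEDDING, DIM_EMBEDDING - DIM_INBETWEEN), whose orthonormal
    columns are orthogonal to the columns of A."""
    U, _, _ = torch.linalg.svd(A, full_matrices=True)
    return U[:, A.shape[1]:]   # left singular vectors outside A's column space

# Projecting embeddings (batch x DIM_EMBEDDING) into the nullspace removes
# the directions the factorized speaker classifier relied on:
# embeddings_null = embeddings @ nullspace_of(A)
```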
16 changes: 0 additions & 16 deletions centerpush_nonullspace_phoneme_classification.sh

This file was deleted.

18 changes: 0 additions & 18 deletions centerpush_nullspace_phoneme_classification.sh

This file was deleted.

6 changes: 4 additions & 2 deletions cpc/eval/linear_separability.py
@@ -39,7 +39,8 @@ def train_step(feature_maker, criterion, data_loader, optimizer, label_key="spea
if centerpushSettings:
centers, pushDeg = centerpushSettings
c_feature = utils.pushToClosestForBatch(c_feature, centers, deg=pushDeg)
encoded_data = utils.pushToClosestForBatch(encoded_data, centers, deg=pushDeg)
# [!] ONLY c_features are projected into nullspace, so encoded_data is of no use with nullspace currently
#encoded_data = utils.pushToClosestForBatch(encoded_data, centers, deg=pushDeg)
all_losses, all_acc = criterion(c_feature, encoded_data, label)

totLoss = all_losses.sum()
@@ -70,7 +71,8 @@ def val_step(feature_maker, criterion, data_loader, label_key="speaker", centerp
if centerpushSettings:
centers, pushDeg = centerpushSettings
c_feature = utils.pushToClosestForBatch(c_feature, centers, deg=pushDeg)
encoded_data = utils.pushToClosestForBatch(encoded_data, centers, deg=pushDeg)
# [!] ONLY c_features are projected into nullspace, so encoded_data is of no use with nullspace currently
#encoded_data = utils.pushToClosestForBatch(encoded_data, centers, deg=pushDeg)
all_losses, all_acc = criterion(c_feature, encoded_data, label)

logs["locLoss_val"] += np.asarray([all_losses.mean().item()])
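For readers unfamiliar with the center-push step above, below is a hypothetical sketch of what `utils.pushToClosestForBatch` might do, assuming it moves each frame a fraction `deg` of the way toward its nearest cluster center; the repository's actual implementation may differ.

```python
import torch

def push_to_closest(features: torch.Tensor, centers: torch.Tensor, deg: float) -> torch.Tensor:
    """features: (batch, time, dim); centers: (K, dim); deg in [0, 1].
    Interpolates every frame toward its nearest center by a factor deg."""
    flat = features.reshape(-1, features.shape[-1])          # (batch*time, dim)
    nearest = centers[torch.cdist(flat, centers).argmin(dim=1)]
    pushed = (1.0 - deg) * flat + deg * nearest
    return pushed.reshape(features.shape)
```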
1 change: 0 additions & 1 deletion cpc/model.py
@@ -304,7 +304,6 @@ def __init__(self,
def forward(self, batchData, label):
cFeature, encodedData, label = self.cpc(batchData, label)
cFeature = self.nullspace(cFeature)
encodedData = self.nullspace(encodedData)
return cFeature, encodedData, label


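For context: after this change only the context features (`cFeature`) are mapped into the nullspace, while the encoder outputs pass through untouched, which is also why the `encoded_data` push is commented out in `linear_separability.py` above. A hedged sketch of such a wrapper, consistent with the diff (placeholder class name and assumed constructor; not the repository's actual code):

```python
import torch.nn as nn

class CPCWithNullspaceSketch(nn.Module):
    """Wraps a trained CPC model and projects its context features with A'."""
    def __init__(self, cpc, nullspace_matrix):
        super().__init__()
        self.cpc = cpc
        # A': (DIM_EMBEDDING, DIM_EMBEDDING - DIM_INBETWEEN), applied as x @ A'
        self.nullspace = nn.Linear(nullspace_matrix.shape[0], nullspace_matrix.shape[1], bias=False)
        self.nullspace.weight = nn.Parameter(nullspace_matrix.T)

    def forward(self, batchData, label):
        cFeature, encodedData, label = self.cpc(batchData, label)
        cFeature = self.nullspace(cFeature)   # project context features only
        return cFeature, encodedData, label
```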
2 changes: 1 addition & 1 deletion environment.yml
@@ -14,8 +14,8 @@ dependencies:
- tqdm
- nose
- cython
- pysoundfile
- pip:
- soundfile
- progressbar2
- matplotlib
- torchaudio
74 changes: 67 additions & 7 deletions finetune_nullspace.sh
@@ -1,30 +1,90 @@
SAVE_DIR="/pio/scratch/1/i273233/linear_separability/cpc/gru_level2/cpc_official"
SPEAKERS="speakers_factorized"
PHONEMES="phonemes_nullspace"
SPEAKERS_NULLSPACE="speakers_nullspace"

DIM_INTER=$1
DATASET_PATH=false
TRAIN_SET=false
VALIDATION_SET=false
CHECKPOINT_PATH=false
OUTPUT_DIR=false
DIM_INBETWEEN=false
FROM_STEP=$SPEAKERS
if [[ $# -ge 2 ]]; then
FROM_STEP=$2
PHONES_PATH=false
AUDIO_FORMAT=flac

print_usage() {
echo -e "Usage: ./finetune_nullspace.sh"
echo -e "\t-d DATASET_PATH (E.g. LIBRISPEECH_DATASET_PATH/train-clean-100)"
echo -e "\t-t TRAIN_SET (E.g. LIBRISPEECH_TRAIN_CLEAN_100_TRAIN_SPLIT_FILE_PATH)"
echo -e "\t-v VALIDATION_SET (E.g. LIBRISPEECH_TRAIN_CLEAN_100_TEST_SPLIT_FILE_PATH)"
echo -e "\t-c CHECKPOINT_PATH"
echo -e "\t-o OUTPUT_DIR"
echo -e "\t-n DIM_INBETWEEN (Dimension of nullspace will be DIM_EMBEDDING - DIM_INBETWEEN)"
echo -e "\t-p PHONES_PATH (Path to the file containing phonemes for the entire dataset)"
echo -e "OPTIONAL FLAGS:"
echo -e "\t-s FROM_STEP (From which step do you want to start. Order: $SPEAKERS [default] -> $PHONEMES -> $SPEAKERS_NULLSPACE)"
echo -e "\t-f audio files format in -d dataset (without a dot)"
}

while getopts 'd:t:v:c:o:n:s:p:f:' flag; do
case "${flag}" in
d) DATASET_PATH="${OPTARG}" ;;
t) TRAIN_SET="${OPTARG}" ;;
v) VALIDATION_SET="${OPTARG}" ;;
c) CHECKPOINT_PATH="${OPTARG}" ;;
o) OUTPUT_DIR="${OPTARG}" ;;
n) DIM_INBETWEEN="${OPTARG}" ;;
s) FROM_STEP="${OPTARG}" ;;
p) PHONES_PATH="${OPTARG}" ;;
f) AUDIO_FORMAT=${OPTARG} ;;
*) print_usage
exit 1 ;;
esac
done

echo $DATASET_PATH $TRAIN_SET $VALIDATION_SET $CHECKPOINT_PATH $OUTPUT_DIR $DIM_INBETWEEN $FROM_STEP $PHONES_PATH

if [[ $DATASET_PATH == false || $TRAIN_SET == false || $VALIDATION_SET == false || $CHECKPOINT_PATH == false || $OUTPUT_DIR == false || $DIM_INBETWEEN == false || ( $PHONES_PATH == false && $FROM_STEP != $SPEAKERS ) ]]
then
echo "Either DATASET_PATH, TRAIN_SET, VALIDATION_SET, CHECKPOINT_PATH, OUTPUT_DIR or DIM_INBETWEEN is not set or there are invalid PHONES_PATH and FROM_STEP."
print_usage
exit 1
fi

mkdir -p $OUTPUT_DIR

case $FROM_STEP in
$SPEAKERS)
echo $SPEAKERS
mkdir -p ${SAVE_DIR}_${SPEAKERS}_${DIM_INTER} && python cpc/eval/linear_separability.py $zd/LibriSpeech/train-clean-100/ $zd/LibriSpeech/labels_split/train_split_100.txt $zd/LibriSpeech/labels_split/test_split_100.txt $zd/checkpoints/CPC-big-kmeans50/cpc_ll6k/checkpoint_32.pt --pathCheckpoint ${SAVE_DIR}_${SPEAKERS}_${DIM_INTER} --mode $SPEAKERS --max_size_loaded 40000000 --n_process_loader 2 --model cpc --dim_inter $DIM_INTER --gru_level 2 | tee ${SAVE_DIR}_${SPEAKERS}_${DIM_INTER}/log.txt
mkdir -p ${OUTPUT_DIR}/${SPEAKERS}_${DIM_INBETWEEN}
python cpc/eval/linear_separability.py $DATASET_PATH $TRAIN_SET $VALIDATION_SET $CHECKPOINT_PATH \
--pathCheckpoint ${OUTPUT_DIR}/${SPEAKERS}_${DIM_INBETWEEN} --mode $SPEAKERS \
--max_size_loaded 40000000 --n_process_loader 2 --model cpc --dim_inter $DIM_INBETWEEN --gru_level 2 --file_extension .$AUDIO_FORMAT
;&
$PHONEMES)
echo $PHONEMES
mkdir -p ${SAVE_DIR}_${PHONEMES}_${DIM_INTER} && python cpc/eval/linear_separability.py $zd/LibriSpeech/train-clean-100/ $zd/LibriSpeech/labels_split/train_split_100.txt $zd/LibriSpeech/labels_split/test_split_100.txt $zd/checkpoints/CPC-big-kmeans50/cpc_ll6k/checkpoint_32.pt --pathCheckpoint ${SAVE_DIR}_${PHONEMES}_${DIM_INTER} --mode $PHONEMES --max_size_loaded 40000000 --n_process_loader 2 --model cpc --pathPhone $zd/LibriSpeech/alignments2/converted_aligned_phones.txt --path_speakers_factorized ${SAVE_DIR}_${SPEAKERS}_${DIM_INTER}/checkpoint_9.pt --dim_inter $DIM_INTER --gru_level 2 | tee ${SAVE_DIR}_${PHONEMES}_${DIM_INTER}/log.txt
mkdir -p ${OUTPUT_DIR}/${PHONEMES}_${DIM_INBETWEEN}
python cpc/eval/linear_separability.py $DATASET_PATH $TRAIN_SET $VALIDATION_SET $CHECKPOINT_PATH \
--pathCheckpoint ${OUTPUT_DIR}/${PHONEMES}_${DIM_INBETWEEN} --mode $PHONEMES \
--max_size_loaded 40000000 --n_process_loader 2 --model cpc --pathPhone $PHONES_PATH \
--path_speakers_factorized ${OUTPUT_DIR}/${SPEAKERS}_${DIM_INBETWEEN}/checkpoint_9.pt \
--dim_inter $DIM_INBETWEEN --gru_level 2 --file_extension .$AUDIO_FORMAT
;&
$SPEAKERS_NULLSPACE)
echo $SPEAKERS_NULLSPACE
mkdir -p ${SAVE_DIR}_${SPEAKERS_NULLSPACE}_${DIM_INTER} && python cpc/eval/linear_separability.py $zd/LibriSpeech/train-clean-100/ $zd/LibriSpeech/labels_split/train_split_100.txt $zd/LibriSpeech/labels_split/test_split_100.txt $zd/checkpoints/CPC-big-kmeans50/cpc_ll6k/checkpoint_32.pt --pathCheckpoint ${SAVE_DIR}_${SPEAKERS_NULLSPACE}_${DIM_INTER} --mode $SPEAKERS_NULLSPACE --max_size_loaded 40000000 --n_process_loader 2 --model cpc --path_speakers_factorized ${SAVE_DIR}_${SPEAKERS}_${DIM_INTER}/checkpoint_9.pt --dim_inter $DIM_INTER --gru_level 2 | tee ${SAVE_DIR}_${SPEAKERS_NULLSPACE}_${DIM_INTER}/log.txt
mkdir -p ${OUTPUT_DIR}/${SPEAKERS_NULLSPACE}_${DIM_INBETWEEN}
python cpc/eval/linear_separability.py $DATASET_PATH $TRAIN_SET $VALIDATION_SET $CHECKPOINT_PATH \
--pathCheckpoint ${OUTPUT_DIR}/${SPEAKERS_NULLSPACE}_${DIM_INBETWEEN} --mode $SPEAKERS_NULLSPACE \
--max_size_loaded 40000000 --n_process_loader 2 --model cpc \
--path_speakers_factorized ${OUTPUT_DIR}/${SPEAKERS}_${DIM_INBETWEEN}/checkpoint_9.pt \
--dim_inter $DIM_INBETWEEN --gru_level 2 --file_extension .$AUDIO_FORMAT
;;
*)
echo "Invalid from step: ${FROM_STEP} while it should be either ${SPEAKERS}, ${PHONEMES} or ${SPEAKERS_NULLSPACE}"
;;
esac

echo "Checkpoint with nullspace is located in ${OUTPUT_DIR}/${PHONEMES}_${DIM_INBETWEEN}/checkpoint_9.pt"
echo "The results of all the experiments are located in ${OUTPUT_DIR}/DIRECTORY/checkpoint_logs.json"

exit 0
55 changes: 55 additions & 0 deletions scripts/create_ls_dataset_for_abx_eval.py
@@ -0,0 +1,55 @@
import os
import sys
import shutil
import argparse
from pathlib import Path
import numpy as np
import soundfile as sf

def parse_args():
# Run parameters
parser = argparse.ArgumentParser()
parser.add_argument("librispeech_path", type=str,
help="Path to the root directory of LibriSpeech.")
parser.add_argument("zerospeech_dataset_path", type=str,
help="Path to the ZeroSpeech dataset.")
parser.add_argument("target_path", type=str,
help="Path to the output directory.")
parser.add_argument("--file_extension", type=str, default="flac",
help="Extension of the audio files in the dataset (default: flac).")
return parser.parse_args()

def main():
# Parse and print args
args = parse_args()
#logger.info(args)

phonetic = "phonetic"
datasets = ["dev-clean", "dev-other", "test-clean", "test-other"]

for dataset in datasets:
print("> {}".format(dataset))
target_dirname = os.path.join(args.target_path, phonetic, dataset)
Path(target_dirname).mkdir(parents=True, exist_ok=True)

librispeech_dirname = os.path.join(args.librispeech_path, dataset)
files = [(filename, dirname) for dirname, _, files in os.walk(librispeech_dirname, followlinks=True) for filename in files if filename.endswith(args.file_extension)]
for i, (filename, dirname) in enumerate(files):
print("Progress {:2.1%}".format(i / len(files)), end="\r")
input_path = os.path.join(dirname, filename)
output_path = os.path.join(target_dirname, os.path.splitext(filename)[0] + "." + args.file_extension)
data, sample_rate = sf.read(input_path)
sf.write(output_path, data, sample_rate)

if dataset.startswith("dev"):
source_item_path = os.path.join(args.zerospeech_dataset_path, phonetic, dataset, dataset + ".item")
target_item_path = os.path.join(target_dirname, dataset + ".item")
shutil.copy(source_item_path, target_item_path)


if __name__ == "__main__":
#import ptvsd
#ptvsd.enable_attach(('0.0.0.0', 7310))
#print("Attach debugger now")
#ptvsd.wait_for_attach()
main()
133 changes: 133 additions & 0 deletions scripts/embeddings_abx.py
@@ -0,0 +1,133 @@
#!/usr/bin/env python3 -u

import logging
import os
import sys
import argparse
from itertools import chain
from pathlib import Path
import time
import copy
import numpy as np
import soundfile as sf

from cpc.feature_loader import loadModel, FeatureModule

import torch
import torch.nn as nn
import torch.nn.functional as F

logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("zerospeech2021 abx")

def parse_args():
# Run parameters
parser = argparse.ArgumentParser()
parser.add_argument("path_checkpoint", type=str,
help="Path to the trained fairseq wav2vec2.0 model.")
parser.add_argument("path_data", type=str,
help="Path to the dataset that we want to compute ABX for.")
parser.add_argument("path_output_dir", type=str,
help="Path to the output directory.")
parser.add_argument("--debug", action="store_true",
help="Load only a very small amount of files for "
"debugging purposes.")
parser.add_argument("--cpu", action="store_true",
help="Run on a cpu machine.")
parser.add_argument("--file_extension", type=str, default="wav",
help="Extension of the audio files in the dataset (default: wav).")
parser.add_argument("--no_test", action="store_true",
help="Don't compute embeddings for test-* parts of dataset")
parser.add_argument('--gru_level', type=int, default=-1,
help='Hidden level of the LSTM autoregressive model to be taken '
'(default: -1, last layer).')
parser.add_argument('--nullspace', action='store_true',
help="Additionally load nullspace")
return parser.parse_args()

def main():
# Parse and print args
args = parse_args()
logger.info(args)

# Load the model
print("")
print(f"Loading model from {args.path_checkpoint}")

if args.gru_level is not None and args.gru_level > 0:
updateConfig = argparse.Namespace(nLevelsGRU=args.gru_level)
else:
updateConfig = None

model = loadModel([args.path_checkpoint], load_nullspace=args.nullspace, updateConfig=updateConfig)[0]

if args.gru_level is not None and args.gru_level > 0:
# Keep hidden units at LSTM layers on sequential batches
if args.nullspace:
model.cpc.gAR.keepHidden = True
else:
model.gAR.keepHidden = True

device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"

# Register the hooks
layer_outputs = {}
def get_layer_output(name):
def hook(model, input, output):
if type(output) is tuple:
layer_outputs[name] = output[0].detach().squeeze(1).cpu().numpy()
elif type(output) is dict:
layer_outputs[name] = output["x"].detach().squeeze(0).cpu().numpy()
else:
layer_outputs[name] = output.detach().squeeze(0).cpu().numpy()
return hook

layer_names = []
layer_name = os.path.basename(os.path.dirname(args.path_checkpoint))
layer_names.append(layer_name)
if not args.nullspace:
model.gAR.register_forward_hook(get_layer_output(layer_name))
else:
model.nullspace.register_forward_hook(get_layer_output(layer_name))

model = model.eval().to(device)
print("Model loaded!")
print(model)

# Extract values from chosen layers and save them to files
phonetic = "phonetic"
datasets_path = os.path.join(args.path_data, phonetic)
datasets = os.listdir(datasets_path)
datasets = [dataset for dataset in datasets if not args.no_test or not dataset.startswith("test")]
print(datasets)

with torch.no_grad():
for dataset in datasets:
print("> {}".format(dataset))
dataset_path = os.path.join(datasets_path, dataset)
files = [f for f in os.listdir(dataset_path) if f.endswith(args.file_extension)]
for i, f in enumerate(files):
print("Progress {:2.1%}".format(i / len(files)), end="\r")
input_f = os.path.join(dataset_path, f)
x, sample_rate = sf.read(input_f)
x = torch.tensor(x).float().reshape(1,1,-1).to(device)
output = model(x, None)[0]

for layer_name, value in layer_outputs.items():
output_dir = os.path.join(args.path_output_dir, layer_name, phonetic, dataset)
Path(output_dir).mkdir(parents=True, exist_ok=True)
out_f = os.path.join(output_dir, os.path.splitext(f)[0] + ".txt")
np.savetxt(out_f, value)

if __name__ == "__main__":
#import ptvsd
#ptvsd.enable_attach(('0.0.0.0', 7310))
#print("Attach debugger now")
#ptvsd.wait_for_attach()
main()
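The script above captures the representations it writes to disk via PyTorch forward hooks. A minimal, self-contained illustration of that mechanism (toy module and placeholder names, not the CPC model itself):

```python
import torch
import torch.nn as nn

captured = {}

def capture(name):
    # Standard forward-hook signature: (module, inputs, output)
    def hook(module, inputs, output):
        captured[name] = output.detach().cpu().numpy()
    return hook

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
net[0].register_forward_hook(capture("first_linear"))

with torch.no_grad():
    net(torch.randn(2, 8))

print(captured["first_linear"].shape)  # (2, 16): output of the hooked layer
```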
