cleanup, proper readme
gcie committed Jul 12, 2021
1 parent e8db677 commit 1deb2c3
Showing 17 changed files with 307 additions and 191 deletions.
28 changes: 28 additions & 0 deletions README.md
@@ -1 +1,29 @@
# SIMI experiments

Repository containing current experiments related to the simi (semantic) task of the ZeroSpeech 2021 Challenge.

## Scoring

**Example command**: `./scoring/run.sh`

We compute the PER (Phone Error Rate) of a clusterization by mapping every sentence piece to the ground-truth phone that occurred most often during the frames it covers. Then we perform a greedy alignment and compute the mismatch error rate.
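
This metric can be read as the following minimal sketch (hypothetical helper names, not the actual code in `scoring/`):

```python
from collections import Counter, defaultdict

def collapse(seq):
    """Collapse consecutive repeats: a a b b c -> a b c."""
    out = []
    for x in seq:
        if not out or out[-1] != x:
            out.append(x)
    return out

def phone_error_rate(pieces_per_frame, phones_per_frame):
    # Map each sentence piece to the ground-truth phone that occurred
    # most often during the frames the piece covers.
    overlap = defaultdict(Counter)
    for piece, phone in zip(pieces_per_frame, phones_per_frame):
        overlap[piece][phone] += 1
    piece_to_phone = {p: c.most_common(1)[0][0] for p, c in overlap.items()}

    # Decode pieces into phones; collapse repeats on both sides.
    hyp = collapse(piece_to_phone[p] for p in pieces_per_frame)
    ref = collapse(phones_per_frame)

    # Greedy position-by-position alignment: mismatches plus the leftover
    # tail of the longer sequence count as errors.
    errors = sum(h != r for h, r in zip(hyp, ref)) + abs(len(hyp) - len(ref))
    return errors / len(ref)
```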

## Clusterization

This is a script for running the CPC_kmeans checkpoint, but without the final argmax: each frame is kept as a distribution over the centroids rather than a single cluster id. This is required for the Viterbi segmentation.
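
For illustration, the difference between a hard and a soft assignment looks roughly like this (a standalone NumPy sketch; the actual script reuses the CPC clustering module, and the softmax temperature here is an assumption):

```python
import numpy as np

def soft_assign(features, centroids, temperature=1.0):
    """features: (T, d) frames, centroids: (k, d) -> (T, k) distributions."""
    # Squared Euclidean distance from every frame to every centroid.
    d2 = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
    # Softmax over negative distances: closer centroids get more mass.
    logits = -d2 / temperature
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)

# A plain clusterization would reduce this to one id per frame:
#   hard = soft_assign(features, centroids).argmax(axis=1)
```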

## Segmentation

**Example command**: `./simi/segmentation/run.sh`

We run sentencepiece on the given `trainset`, then use the learnt language model to predict the best segmentation. There are two ways of doing so:
1. use sentencepiece's default segmentation (the default), which works poorly here because our language is not very "exact": some pseudophones (the output of CPC + clustering) may be mismatched;
2. use Viterbi segmentation (flag `--viterbi`), which takes such errors into account and usually gives a lower PER (see the sketch below).

For a description of the parameters, run `python segmentation.py --help`.
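
A minimal sketch of Viterbi segmentation under a unigram piece model (illustrative only; the actual implementation additionally models clusterization errors):

```python
import math

def viterbi_segment(seq, piece_logprob, max_piece_len=8):
    """seq: list of pseudophone ids; piece_logprob: {tuple: log-probability}.

    Assumes every single-symbol piece is in the vocabulary, so a full
    segmentation always exists.
    """
    n = len(seq)
    best = [0.0] + [-math.inf] * n   # best[i] = best score of seq[:i]
    back = [0] * (n + 1)             # back[i] = where the last piece starts
    for i in range(1, n + 1):
        for j in range(max(0, i - max_piece_len), i):
            piece = tuple(seq[j:i])
            score = best[j] + piece_logprob.get(piece, -math.inf)
            if score > best[i]:
                best[i] = score
                back[i] = j
    # Walk the backpointers to recover the best segmentation.
    pieces, i = [], n
    while i > 0:
        pieces.append(tuple(seq[back[i]:i]))
        i = back[i]
    return pieces[::-1]
```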

## Grouping

**Example command**: `./simi/grouping/run.sh`

The idea is based on the fact that segmentation with sentencepiece requires a large vocabulary to work, but this results in many different sentence pieces mapping to the same phoneme. To reduce the number of sentence pieces, we run word2vec on them and then k-means, grouping and merging multiple sentence pieces together. This never increases the PER, but it (greatly) reduces the vocabulary size.
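
A sketch of this pipeline under stated assumptions: gensim's `Word2Vec` and scikit-learn's `KMeans` (both pinned in `environment.yml`) stand in for the repository's `simi.vectorization` and `simi.clusterization` helpers, and all names are illustrative:

```python
from gensim.models import Word2Vec   # gensim 3.8.3 API (size=...)
from sklearn.cluster import KMeans

def group_pieces(sentences, n_groups=50, dim=100):
    """sentences: corpus of sentence-piece tokens, e.g. [['ab', 'cde'], ...]."""
    # Treat each sentence piece as a word and learn an embedding for it.
    w2v = Word2Vec(sentences, size=dim, min_count=1)
    vocab = w2v.wv.index2word
    vectors = w2v.wv[vocab]

    # k-means over the piece embeddings; all pieces that land in the same
    # cluster are merged into a single new symbol.
    labels = KMeans(n_clusters=n_groups).fit_predict(vectors)
    piece_to_group = dict(zip(vocab, labels))

    # Rewrite the corpus with the reduced vocabulary.
    return [[f'G{piece_to_group[p]}' for p in s] for s in sentences]
```
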
50 changes: 34 additions & 16 deletions simi/quantization/quantization.py → clusterization.py
@@ -1,21 +1,38 @@
import json
import os
import pathlib
from argparse import ArgumentParser
from pathlib import Path
from random import shuffle
from time import time
from urllib import parse

import numpy as np
import progressbar
import torch
from cpc.dataset import findAllSeqs
from cpc.feature_loader import buildFeature, buildFeature_batch
from cpc.feature_loader import buildFeature_batch

from utils_functions import (loadClusterModule, loadCPCFeatureMaker, readArgs,
writeArgs)
from simi.clusterization import loadClusterModule, loadCPCFeatureMaker, readArgs

CUDA = False

def quantize_file(file_path, cpc_feature_function, clusterModule):
def parseArgs():
parser = ArgumentParser()

parser.add_argument('clustering_checkpoint', type=pathlib.Path,
help='Path to the clustering checkpoint')
parser.add_argument('dataset', type=pathlib.Path,
help='Path to the dataset, which is to be quantized')
parser.add_argument('output', type=pathlib.Path,
help='Output path')
parser.add_argument('--file_extension', type=str, default='vaw',
help='File extension of the audio files')
parser.add_argument('--CUDA', action='store_true',
help='Use CUDA')
return parser.parse_args()



def quantize_file(file_path, cpc_feature_function, clusterModule, args):
# Get CPC features
cFeatures = cpc_feature_function(file_path)
if clusterModule.Ck.is_cuda:
@@ -27,17 +44,17 @@ def quantize_file(file_path, cpc_feature_function, clusterModule):
cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1))

clustered = clusterModule(cFeatures)
if CUDA:
if args.CUDA:
clusterModule = clusterModule.cuda()

return clustered.detach().cpu().numpy().reshape(-1, clusterModule.k)

def main():
pathClusteringCheckpoint = '/pio/data/zerospeech2021/checkpoints/CPC-big-kmeans50/clustering_kmeans50/clustering_CPC_big_kmeans50.pt'
pathDB = '/pio/data/zerospeech2021/LibriSpeech/test-clean'
pathOutputDir = '/pio/scratch/1/i290956/zs2021/clusterings/LibriSpeech/test-clean'
def main(args):
pathClusteringCheckpoint = str(args.clustering_checkpoint) # '/pio/data/zerospeech2021/checkpoints/CPC-big-kmeans50/clustering_kmeans50/clustering_CPC_big_kmeans50.pt'
pathDB = str(args.dataset) # '/pio/data/zerospeech2021/LibriSpeech/test-clean'
pathOutputDir = str(args.output) # '/pio/scratch/1/i290956/zs2021/clusterings/LibriSpeech/test-clean'

file_extension = 'vaw'
file_extension = args.file_extension
seqNames, _ = findAllSeqs(pathDB, speaker_level=1, extension=file_extension, loadCache=True)

if not os.path.exists(pathOutputDir):
@@ -68,7 +85,7 @@ def main():
print("")
print(f"Loading ClusterModule at {pathClusteringCheckpoint}")
clusterModule = loadClusterModule(pathClusteringCheckpoint)
if CUDA:
if args.CUDA:
clusterModule.cuda()
print("ClusterModule loaded!")

@@ -88,7 +105,7 @@ def main():
featureMaker = torch.nn.Sequential(featureMaker, dimRed)
if not clustering_args.train_mode:
featureMaker.eval()
if CUDA:
if args.CUDA:
featureMaker.cuda()
def cpc_feature_function(x):
return buildFeature_batch(featureMaker, x,seqNorm=False, strict=True,
@@ -112,12 +129,13 @@ def cpc_feature_function(x):
# Quantization
f = open(outputPath, 'wb')
f.close()
clustered_file = quantize_file(file_path, cpc_feature_function, clusterModule)
clustered_file = quantize_file(file_path, cpc_feature_function, clusterModule, args)
np.save(outputPath, clustered_file)

bar.finish()
print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")


if __name__ == "__main__":
main()
args = parseArgs()
main(args)
191 changes: 191 additions & 0 deletions environment.yml
@@ -0,0 +1,191 @@
name: semantic
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=1_gnu
- aiohttp=3.7.4=py38h27cfd23_1
- async-timeout=3.0.1=py38h06a4308_0
- attrs=20.3.0=pyhd3eb1b0_0
- backcall=0.2.0=pyhd3eb1b0_0
- biopython=1.78=py38h7b6447c_0
- blas=1.0=mkl
- boto3=1.17.46=pyhd3eb1b0_0
- botocore=1.20.50=pyhd3eb1b0_1
- brotlipy=0.7.0=py38h27cfd23_1003
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2021.5.25=h06a4308_1
- cachetools=4.2.1=pyhd3eb1b0_0
- certifi=2021.5.30=py38h06a4308_0
- cffi=1.14.5=py38h261ae71_0
- chardet=3.0.4=py38h06a4308_1003
- cryptography=3.4.7=py38hd23ed53_0
- cudatoolkit=11.0.3=h15472ef_8
- cycler=0.10.0=py38_0
- cython=0.29.23=py38h2531618_0
- dbus=1.13.18=hb2f20db_0
- decorator=5.0.6=pyhd3eb1b0_0
- editdistance=0.5.3=py38h2531618_0
- expat=2.3.0=h2531618_2
- ffmpeg=4.3=hf484d3e_0
- fontconfig=2.13.1=h6c09931_0
- freetype=2.10.4=h5ab3b9f_0
- gensim=3.8.3=py38h2531618_2
- gettext=0.19.8.1=h0b5b191_1005
- glib=2.68.1=h36276a3_0
- gmp=6.2.1=h2531618_2
- gnutls=3.6.15=he1e5248_0
- google-api-core=1.25.1=pyhd3eb1b0_0
- google-auth=1.29.0=pyhd3eb1b0_0
- google-cloud-core=1.6.0=pyhd3eb1b0_0
- google-cloud-storage=1.37.1=pyhd3eb1b0_0
- google-crc32c=1.1.2=py38h27cfd23_0
- google-resumable-media=1.2.0=pyhd3eb1b0_1
- googleapis-common-protos=1.53.0=py38h06a4308_0
- gst-plugins-base=1.14.0=h8213a91_2
- gstreamer=1.14.0=h28cd5cc_2
- icu=58.2=he6710b0_3
- idna=2.10=pyhd3eb1b0_0
- iniconfig=1.1.1=pyhd3eb1b0_0
- intel-openmp=2021.2.0=h06a4308_610
- intervaltree=3.1.0=py_0
- ipykernel=5.3.4=py38h5ca1d4c_0
- ipython=7.22.0=py38hb070fc8_0
- ipython_genutils=0.2.0=pyhd3eb1b0_1
- jedi=0.17.0=py38_0
- jmespath=0.10.0=py_0
- joblib=1.0.1=pyhd3eb1b0_0
- jpeg=9b=h024ee3a_2
- jupyter_client=6.1.12=pyhd3eb1b0_0
- jupyter_core=4.7.1=py38h06a4308_0
- kiwisolver=1.3.1=py38h2531618_0
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.33.1=h53a641e_7
- libcrc32c=1.1.1=he6710b0_2
- libffi=3.3=he6710b0_2
- libflac=1.3.3=h9c3ff4c_1
- libgcc-ng=9.3.0=h2828fa1_19
- libgfortran-ng=7.3.0=hdf63c60_0
- libgomp=9.3.0=h2828fa1_19
- libiconv=1.15=h63c8f33_5
- libidn2=2.3.1=h27cfd23_0
- libllvm10=10.0.1=hbcb73fb_5
- libogg=1.3.4=h7f98852_1
- libopus=1.3.1=h7f98852_1
- libpng=1.6.37=hbc83047_0
- libprotobuf=3.14.0=h8c45485_0
- libsndfile=1.0.31=h9c3ff4c_1
- libsodium=1.0.18=h7b6447c_0
- libstdcxx-ng=9.3.0=h6de172a_19
- libtasn1=4.16.0=h27cfd23_0
- libtiff=4.1.0=h2733197_1
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.0.3=h1bed415_2
- libuv=1.40.0=h7b6447c_0
- libvorbis=1.3.7=h9c3ff4c_0
- libxcb=1.14=h7b6447c_0
- libxml2=2.9.10=hb55368b_3
- llvmlite=0.36.0=py38h612dafd_4
- lz4-c=1.9.3=h2531618_0
- matplotlib=3.3.4=py38h06a4308_0
- matplotlib-base=3.3.4=py38h62a2d02_0
- mkl=2021.2.0=h06a4308_296
- mkl-service=2.3.0=py38h27cfd23_1
- mkl_fft=1.3.0=py38h42c9631_2
- mkl_random=1.2.1=py38ha9443f7_2
- more-itertools=8.7.0=pyhd3eb1b0_0
- multidict=5.1.0=py38h27cfd23_2
- ncurses=6.2=he6710b0_1
- nettle=3.7.2=hbbd107a_1
- ninja=1.10.2=hff7bd54_1
- numba=0.53.1=py38ha9443f7_0
- numpy=1.20.1=py38h93e21f0_0
- numpy-base=1.20.1=py38h7d8b39e_0
- olefile=0.46=py_0
- openh264=2.1.0=hd408876_0
- openssl=1.1.1k=h27cfd23_0
- packaging=20.9=pyhd3eb1b0_0
- pandas=1.2.4=py38h2531618_0
- parso=0.8.2=pyhd3eb1b0_0
- pcre=8.44=he6710b0_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=8.2.0=py38he98fc37_0
- pip=21.0.1=py38h06a4308_0
- pluggy=0.13.1=py38h06a4308_0
- progressbar2=3.37.1=py38h06a4308_0
- prompt-toolkit=3.0.17=pyh06a4308_0
- protobuf=3.14.0=py38h2531618_1
- ptyprocess=0.7.0=pyhd3eb1b0_2
- py=1.10.0=pyhd3eb1b0_0
- pyasn1=0.4.8=py_0
- pyasn1-modules=0.2.8=py_0
- pycparser=2.20=py_2
- pygments=2.8.1=pyhd3eb1b0_0
- pyopenssl=20.0.1=pyhd3eb1b0_1
- pyparsing=2.4.7=pyhd3eb1b0_0
- pyqt=5.9.2=py38h05f1152_4
- pysocks=1.7.1=py38h06a4308_0
- pysoundfile=0.10.3.post1=pyhd3deb0d_0
- pytest=6.2.3=py38h06a4308_2
- pytest-runner=5.3.0=pyhd3eb1b0_0
- python=3.8.5=h7579374_1
- python-dateutil=2.8.1=pyhd3eb1b0_0
- python-utils=2.5.6=py38h06a4308_0
- python_abi=3.8=1_cp38
- pytorch=1.7.1=py3.8_cuda11.0.221_cudnn8.0.5_0
- pytz=2021.1=pyhd3eb1b0_0
- pyyaml=5.4.1=py38h27cfd23_1
- pyzmq=20.0.0=py38h2531618_1
- qt=5.9.7=h5867ecd_1
- readline=8.1=h27cfd23_0
- requests=2.25.1=pyhd3eb1b0_0
- rsa=4.7.2=pyhd3eb1b0_1
- s3transfer=0.3.6=pyhd3eb1b0_0
- scikit-learn=0.24.1=py38ha9443f7_0
- scipy=1.6.2=py38had2a1c9_1
- setuptools=52.0.0=py38h06a4308_0
- sip=4.19.13=py38he6710b0_0
- six=1.15.0=py38h06a4308_0
- smart_open=5.0.0=pyhd3eb1b0_0
- sortedcontainers=2.3.0=pyhd3eb1b0_0
- sqlite=3.35.4=hdfb4753_0
- tbb=2020.3=hfd86e86_0
- threadpoolctl=2.1.0=pyh5ca1d4c_0
- tk=8.6.10=hbc83047_0
- toml=0.10.2=pyhd3eb1b0_0
- torchaudio=0.7.2=py38
- torchvision=0.8.2=py38_cu110
- tornado=6.1=py38h27cfd23_0
- tqdm=4.59.0=pyhd3eb1b0_1
- traitlets=5.0.5=pyhd3eb1b0_0
- typing-extensions=3.7.4.3=hd3eb1b0_0
- typing_extensions=3.7.4.3=pyh06a4308_0
- urllib3=1.26.4=pyhd3eb1b0_0
- wcwidth=0.2.5=py_0
- wheel=0.36.2=pyhd3eb1b0_0
- xz=5.2.5=h7b6447c_0
- yaml=0.2.5=h7b6447c_0
- yarl=1.5.1=py38h7b6447c_0
- zeromq=4.3.4=h2531618_0
- zlib=1.2.11=h7b6447c_3
- zstd=1.4.9=haebb681_0
- pip:
- antlr4-python3-runtime==4.8
- cpc-audio==1.0
- fairseq==1.0.0a0+19793a7
- hydra-core==1.0.6
- importlib-resources==5.1.4
- jiwer==2.2.0
- omegaconf==2.0.6
- portalocker==2.0.0
- ptvsd==4.3.2
- python-levenshtein==0.12.2
- regex==2021.4.4
- sacrebleu==1.5.1
- sentencepiece==0.1.95
- zipp==3.4.1
prefix: /pio/scratch/1/i290956/miniconda3/envs/semantic
12 changes: 3 additions & 9 deletions cluster.py → grouping.py
@@ -1,17 +1,9 @@
import os
import pathlib
import random
import tqdm
from argparse import ArgumentError, ArgumentParser
import pickle
from argparse import ArgumentParser

from collections import defaultdict
import numpy as np
import sentencepiece
from more_itertools import grouper

from simi import dataset
from simi import utils
from simi.vectorization import vectorize
from simi.clusterization import cluster_kmeans

@@ -67,6 +59,7 @@ def save(self, path):
for x in sample:
output.write(','.join(map(str, x)) + '\n')


def run(args):
print(f'Loading train segmentation...')
segmentation = LibriSpeechSegmentation(args.segmentation)
@@ -87,6 +80,7 @@ def run(args):
segmentation.save(args.output)
print('Done!')


class StubArgs(object):
def __init__(self):
self.seed = 290956
