This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

update clustering_script.py, clustering_quantization.py, feature_loader.py #14

Open · wants to merge 3 commits into base: zerospeech
209 changes: 119 additions & 90 deletions cpc/criterion/clustering/clustering_quantization.py
@@ -9,7 +9,7 @@
import torch
from cpc.dataset import findAllSeqs
from cpc.feature_loader import buildFeature, FeatureModule, loadModel, buildFeature_batch
-from cpc.criterion.research.clustering import kMeanCluster
+from cpc.criterion.clustering import kMeanCluster

def readArgs(pathArgs):
print(f"Loading args from {pathArgs}")
@@ -18,31 +18,57 @@ def readArgs(pathArgs):

return args

+def writeArgs(pathArgs, args):
+    with open(pathArgs, 'w') as file:
+        json.dump(vars(args), file, indent=2)

def loadClusterModule(pathCheckpoint):
print(f"Loading ClusterModule at {pathCheckpoint}")
-    state_dict = torch.load(pathCheckpoint)
-    if "state_dict" in state_dict: #kmeans
-        clusterModule = kMeanCluster(torch.zeros(1, state_dict["n_clusters"], state_dict["dim"]))
-        clusterModule.load_state_dict(state_dict["state_dict"])
-    else: #dpmeans
-        clusterModule = kMeanCluster(state_dict["mu"])
-    clusterModule = clusterModule.cuda()
+    state_dict = torch.load(pathCheckpoint, map_location=torch.device('cpu'))
+    clusterModule = kMeanCluster(torch.zeros(1, state_dict["n_clusters"], state_dict["dim"]))
+    clusterModule.load_state_dict(state_dict["state_dict"])
return clusterModule
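
For reference, a minimal sketch of the checkpoint layout this loader now assumes, inferred from the keys it reads ("n_clusters", "dim", "state_dict"); the sizes and path below are hypothetical examples, not values from the PR:

import torch
from cpc.criterion.clustering import kMeanCluster

n_clusters, dim = 50, 256  # hypothetical sizes
clusterModule = kMeanCluster(torch.zeros(1, n_clusters, dim))
checkpoint = {
    "n_clusters": n_clusters,  # read back by loadClusterModule
    "dim": dim,
    "state_dict": clusterModule.state_dict(),
}
torch.save(checkpoint, "clustering_checkpoint.pt")  # hypothetical path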

+def quantize_file(file_path, cpc_feature_function, clusterModule):
+    # Get CPC features
+    cFeatures = cpc_feature_function(file_path)
+    if clusterModule.Ck.is_cuda:
+        cFeatures = cFeatures.cuda()
+
+    nGroups = cFeatures.size(-1)//clusterModule.Ck.size(-1) # groups information
+
+    # Quantize the output of clustering on the CPC features
+    cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1))
+    if cFeatures.size(1) > 50000: # Librilight: avoid GPU OOM; lower this threshold if it still OOMs
+        wasCuda = clusterModule.Ck.is_cuda  # remember the device (the parsed args are not visible here)
+        clusterModule = clusterModule.cpu()
+        cFeatures = cFeatures.cpu()
+        qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
+        if wasCuda:
+            clusterModule = clusterModule.cuda()
+    else:
+        qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
+    qFeatures = qFeatures[0].detach().cpu().numpy()
+
+    # Transform to quantized line
+    quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)])
+
+    return quantLine
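
To make the encoding concrete, a toy, self-contained illustration of the quantized-line format built above: frames are comma-separated, and the group indices within a frame are dash-separated (values hypothetical):

import numpy as np

qFeatures = np.array([3, 7, 1, 4])  # four cluster indices covering two frames
nGroups = 2                         # two codebook groups per frame
quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)])
print(quantLine)  # prints "3-7,1-4"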

def parseArgs(argv):
# Run parameters
parser = argparse.ArgumentParser(description='Quantize audio files using CPC Clustering Module.')
-    parser.add_argument('pathCheckpoint', type=str,
+    parser.add_argument('pathClusteringCheckpoint', type=str,
help='Path to the clustering checkpoint.')
parser.add_argument('pathDB', type=str,
help='Path to the dataset that we want to quantize.')
-    parser.add_argument('pathOutput', type=str,
+    parser.add_argument('pathOutputDir', type=str,
help='Path to the output directory.')
-    parser.add_argument('--pathSeq', type=str,
-                        help='Path to the sequences (file names) to be included used.')
+    parser.add_argument('--pathSeq', type=str,
+                        help='Path to the sequences (file names) to be included '
+                        '(if not specified, include all files found in pathDB).')
parser.add_argument('--split', type=str, default=None,
help="If you want to divide the dataset in small splits, specify it "
"with idxSplit-numSplits (idxSplit > 0), eg. --split 1-20.")
help='If you want to divide the dataset in small splits, specify it '
'with idxSplit-numSplits (idxSplit > 0), eg. --split 1-20.')
parser.add_argument('--file_extension', type=str, default=".flac",
help="Extension of the audio files in the dataset (default: .flac).")
parser.add_argument('--max_size_seq', type=int, default=10240,
@@ -62,11 +88,10 @@ def parseArgs(argv):
"NOTE: This can have better quantized units as we can set "
"model.gAR.keepHidden = True (line 162), but the quantization"
"will be a bit longer.")
-    parser.add_argument('--recursionLevel', type=int, default=1,
-                        help='Speaker level in pathDB (defaut: 1). This is only helpful'
-                        'when --separate-speaker is activated.')
-    parser.add_argument('--separate-speaker', action='store_true',
-                        help="Separate each speaker with a different output file.")
+    parser.add_argument('--cpu', action='store_true',
+                        help="Run on a cpu machine.")
+    parser.add_argument('--resume', action='store_true',
+                        help="Continue to quantize if an output file already exists.")
return parser.parse_args(argv)

def main(argv):
@@ -77,12 +102,6 @@ def main(argv):
print(f"Quantizing data from {args.pathDB}")
print("=============================================================")

-    # Check if directory exists
-    if not os.path.exists(args.pathOutput):
-        print("")
-        print(f"Creating the output directory at {args.pathOutput}")
-        Path(args.pathOutput).mkdir(parents=True, exist_ok=True)

# Get splits
if args.split:
assert len(args.split.split("-"))==2 and int(args.split.split("-")[1]) >= int(args.split.split("-")[0]) >= 1, \
@@ -93,40 +112,45 @@

# Find all sequences
print("")
print(f"Looking for all {args.file_extension} files in {args.pathDB} with speakerLevel {args.recursionLevel}")
seqNames, speakers = findAllSeqs(args.pathDB,
speaker_level=args.recursionLevel,
print(f"Looking for all {args.file_extension} files in {args.pathDB}")
seqNames, _ = findAllSeqs(args.pathDB,
speaker_level=1,
extension=args.file_extension,
loadCache=True)
+    if len(seqNames) == 0 or not os.path.splitext(seqNames[0][1])[1].endswith(args.file_extension):
+        print(f"Seems like the _seq_cache.txt does not contain the correct extension, reloading the file list")
+        seqNames, _ = findAllSeqs(args.pathDB,
+                                  speaker_level=1,
+                                  extension=args.file_extension,
+                                  loadCache=False)
+    print(f"Done! Found {len(seqNames)} files!")

# Filter specific sequences
if args.pathSeq:
-        with open(args.pathSeq, 'r') as f:
-            seqs = set([x.strip() for x in f])
-
-        filtered = []
-        for s in seqNames:
-            if s[1].split('/')[-1].split('.')[0] in seqs:
-                filtered.append(s)
+        print("")
+        print(f"Filtering seqs in {args.pathSeq}")
+        with open(args.pathSeq, 'r') as f:
+            seqs = set([x.strip() for x in f])
+        filtered = []
+        for s in seqNames:
+            if os.path.splitext(s[1].split('/')[-1])[0] in seqs:
+                filtered.append(s)
seqNames = filtered
print(f"Done! {len(seqNames)} files filtered!")

print(f"Done! Found {len(seqNames)} files and {len(speakers)} speakers!")
if args.separate_speaker:
seqNames_by_speaker = {}
for seq in seqNames:
speaker = seq[1].split("/")[args.recursionLevel-1]
if speaker not in seqNames_by_speaker:
seqNames_by_speaker[speaker] = []
seqNames_by_speaker[speaker].append(seq)
# Check if directory exists
if not os.path.exists(args.pathOutputDir):
print("")
print(f"Creating the output directory at {args.pathOutputDir}")
Path(args.pathOutputDir).mkdir(parents=True, exist_ok=True)
writeArgs(os.path.join(args.pathOutputDir, "_info_args.json"), args)

# Check if output file exists
if not args.split:
nameOutput = "quantized_outputs.txt"
else:
nameOutput = f"quantized_outputs_split_{idx_split}-{num_splits}.txt"
-    if args.separate_speaker is False:
-        outputFile = os.path.join(args.pathOutput, nameOutput)
-        assert not os.path.exists(outputFile), \
-            f"Output file {outputFile} already exists !!!"
+    outputFile = os.path.join(args.pathOutputDir, nameOutput)

# Get splits
if args.split:
@@ -147,27 +171,50 @@ def main(argv):
# shuffle(seqNames)
seqNames = seqNames[:nsamples]

+    # Continue from an earlier run?
+    addEndLine = False # whether to prepend "\n" before the first line written in this run
+    if args.resume:
+        if os.path.exists(outputFile):
+            with open(outputFile, 'r') as f:
+                lines = [line for line in f]
+            existing_files = set([x.split()[0] for x in lines if x.split()])
+            seqNames = [s for s in seqNames if os.path.splitext(s[1].split('/')[-1])[0] not in existing_files]
+            print(f"Found an existing output file, continuing with the {len(seqNames)} remaining audio files!")
+            if len(lines) > 0 and not lines[-1].endswith("\n"):
+                addEndLine = True
+    else:
+        assert not os.path.exists(outputFile), \
+            f"Output file {outputFile} already exists !!! If you want to continue quantizing audio files, please check the --resume option."
+
+    assert len(seqNames) > 0, \
+        "No file to be quantized!"

# Load Clustering args
-    assert args.pathCheckpoint[-3:] == ".pt"
-    if os.path.exists(args.pathCheckpoint[:-3] + "_args.json"):
-        pathConfig = args.pathCheckpoint[:-3] + "_args.json"
-    elif os.path.exists(os.path.join(os.path.dirname(args.pathCheckpoint), "checkpoint_args.json")):
-        pathConfig = os.path.join(os.path.dirname(args.pathCheckpoint), "checkpoint_args.json")
+    assert args.pathClusteringCheckpoint[-3:] == ".pt"
+    if os.path.exists(args.pathClusteringCheckpoint[:-3] + "_args.json"):
+        pathConfig = args.pathClusteringCheckpoint[:-3] + "_args.json"
+    elif os.path.exists(os.path.join(os.path.dirname(args.pathClusteringCheckpoint), "checkpoint_args.json")):
+        pathConfig = os.path.join(os.path.dirname(args.pathClusteringCheckpoint), "checkpoint_args.json")
else:
assert False, \
f"Args file not found in the directory {os.path.dirname(args.pathCheckpoint)}"
f"Args file not found in the directory {os.path.dirname(args.pathClusteringCheckpoint)}"
clustering_args = readArgs(pathConfig)
print("")
print(f"Clutering args:\n{json.dumps(vars(clustering_args), indent=4, sort_keys=True)}")
print('-' * 50)

    # Load ClusterModule
-    clusterModule = loadClusterModule(args.pathCheckpoint)
-    clusterModule.cuda()
+    clusterModule = loadClusterModule(args.pathClusteringCheckpoint)
+    if not args.cpu:
+        clusterModule.cuda()

# Load FeatureMaker
print("")
print("Loading CPC FeatureMaker")
+    if not os.path.isabs(clustering_args.pathCheckpoint): # Maybe it's a relative path
+        clustering_args.pathCheckpoint = os.path.join(os.path.dirname(os.path.abspath(args.pathClusteringCheckpoint)), clustering_args.pathCheckpoint)
+    assert os.path.exists(clustering_args.pathCheckpoint), \
+        f"CPC path at {clustering_args.pathCheckpoint} does not exist!!"
if 'level_gru' in vars(clustering_args) and clustering_args.level_gru is not None:
updateConfig = argparse.Namespace(nLevelsGRU=clustering_args.level_gru)
else:
@@ -183,8 +230,9 @@ def main(argv):
featureMaker = torch.nn.Sequential(featureMaker, dimRed)
if not clustering_args.train_mode:
featureMaker.eval()
-    featureMaker.cuda()
-    def feature_function(x):
+    if not args.cpu:
+        featureMaker.cuda()
+    def cpc_feature_function(x):
if args.nobatch is False:
return buildFeature_batch(featureMaker, x,
seqNorm=False,
@@ -196,11 +244,11 @@ def feature_function(x):
seqNorm=False,
strict=args.strict)
print("CPC FeatureMaker loaded!")

# Quantization of files
print("")
print(f"Quantizing audio files...")
seqQuantLines = []
print(f"Quantizing audio files and saving outputs to {outputFile}...")
f = open(outputFile, "a")
bar = progressbar.ProgressBar(maxval=len(seqNames))
bar.start()
start_time = time()
@@ -210,39 +258,20 @@ def feature_function(x):
file_path = vals[1]
file_path = os.path.join(args.pathDB, file_path)

# Get features & quantizing
cFeatures = feature_function(file_path).cuda()

nGroups = cFeatures.size(-1)//clusterModule.Ck.size(-1)

cFeatures = cFeatures.view(1, -1, clusterModule.Ck.size(-1))
# Quantizing
quantLine = quantize_file(file_path, cpc_feature_function, clusterModule)

-        if len(vals) > 2 and int(vals[-1]) > 9400000: # Librilight, to avoid OOM
-            clusterModule = clusterModule.cpu()
-            cFeatures = cFeatures.cpu()
-            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
-            clusterModule = clusterModule.cuda()
+        # Save the outputs
+        file_name = os.path.splitext(os.path.basename(file_path))[0]
+        outLine = "\t".join([file_name, quantLine])
+        if addEndLine:
+            f.write("\n"+outLine)
else:
-            qFeatures = torch.argmin(clusterModule(cFeatures), dim=-1)
-        qFeatures = qFeatures[0].detach().cpu().numpy()
-
-        # Transform to quantized line
-        quantLine = ",".join(["-".join([str(i) for i in item]) for item in qFeatures.reshape(-1, nGroups)])
-        seqQuantLines.append(quantLine)
+            f.write(outLine)
+            addEndLine = True
bar.finish()
print(f"...done {len(seqQuantLines)} files in {time()-start_time} seconds.")

# Saving outputs
print("")
print(f"Saving outputs to {outputFile}")
outLines = []
for vals, quantln in zip(seqNames, seqQuantLines):
file_path = vals[1]
file_name = os.path.splitext(os.path.basename(file_path))[0]
outLines.append("\t".join([file_name, quantln]))
with open(outputFile, "w") as f:
f.write("\n".join(outLines))
print(f"...done {len(seqNames)} files in {time()-start_time} seconds.")
f.close()

if __name__ == "__main__":
args = sys.argv[1:]
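For orientation, a typical invocation of the updated quantization script, reflecting the renamed positional arguments and the new flags; every path below is a hypothetical placeholder:

python cpc/criterion/clustering/clustering_quantization.py \
    /path/to/clustering_checkpoint.pt \
    /path/to/audio_dataset \
    /path/to/output_dir \
    --file_extension .flac --resume

Here --resume appends to an existing quantized_outputs.txt instead of aborting, and adding --cpu keeps both the cluster module and the feature maker off the GPU.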
10 changes: 5 additions & 5 deletions cpc/criterion/clustering/clustering_script.py
@@ -128,12 +128,12 @@ def parseArgs(argv):
print(f"Length of dataLoader: {len(trainLoader)}")
print("")

-    #if args.level_gru is None:
-    #    updateConfig = None
-    #else:
-    #    updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru)
+    if args.level_gru is None:
+        updateConfig = None
+    else:
+        updateConfig = argparse.Namespace(nLevelsGRU=args.level_gru)

-    model = loadModel([args.pathCheckpoint])[0]#, updateConfig=updateConfig)[0]
+    model = loadModel([args.pathCheckpoint], updateConfig=updateConfig)[0]
featureMaker = FeatureModule(model, args.encoder_layer)
print("Checkpoint loaded!")
print("")