-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval_segmentation.py
91 lines (70 loc) · 3.78 KB
/
eval_segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import pathlib
import random
import sys
from argparse import ArgumentParser
import numpy as np
import sentencepiece
from simi import dataset, utils
from simi.segmentation import segment_sentencepiece, segment_viterbi
def parseArgs():
    """Parse command-line arguments for the segmentation evaluation script.

    Returns:
        argparse.Namespace with: dataset (Path), sentencepiece_prefix (Path),
        output (Path), clusterings (str | None), seed (int), viterbi (bool),
        alpha (float), output_format (str, comma-separated list of formats).
    """
    parser = ArgumentParser()
    parser.add_argument('dataset', type=pathlib.Path,
                        help='Path to the quantized dataset, which is to be segmented')
    parser.add_argument('sentencepiece_prefix', type=pathlib.Path,
                        help='Prefix for sentencepiece model. It must point to existing folder containing sentencepiece model & vocab.')
    parser.add_argument('output', type=pathlib.Path,
                        help='Output folder')
    parser.add_argument('--clusterings', type=str,
                        help='Path to the clusterings of the data, must match the dataset. Required if using Viterbi segmentation')
    parser.add_argument('--seed', type=int, default=290956,
                        help='Random seed')
    parser.add_argument('--viterbi', action='store_true',
                        help='Do Viterbi segmentation instead of sentencepiece\'s default')
    parser.add_argument('--alpha', type=float, default=1.0,
                        help='Temperature for sharpening/smoothening clustering distribution. More than 1: sharpening, less than 1: smoothening. Default: 1.0')
    parser.add_argument('--output_format', type=str, default='txt,csv',
                        help='Output format of the transformed dataset. Comma separated list of: \'txt\' (arrays of strings) or \'csv\' (similar to LibriSpeech alignments)')
    return parser.parse_args()
def save_segmentation(formatted, dataset, path, args):
    """Write a segmentation to disk in the format(s) requested by the CLI.

    Args:
        formatted: list of sentences, each a list of segment strings.
        dataset: dataset object providing ``filenames`` aligned with ``formatted``.
        path: folder that receives the per-utterance .csv files.
        args: parsed CLI namespace; uses ``output_format`` and ``output``
            (the folder receiving the aggregate .txt file).

    Raises:
        AssertionError: if ``args.output_format`` is empty/unset.
    """
    formats = args.output_format.split(',') if args.output_format else []
    assert formats, 'Output format should be a comma separated list'
    if 'txt' in formats:
        # One line per utterance: "<filename> <seg1> <seg2> ...".
        with open(args.output / 'segmented_outputs.txt', 'w') as output:
            for sentence, fname in zip(formatted, dataset.filenames):
                output.write(f'{fname} {" ".join(sentence)}\n')
    if 'csv' in formats:
        os.makedirs(path, exist_ok=True)
        for sentence, fname in zip(formatted, dataset.filenames):
            with open(path / (fname+'.csv'), 'w') as output:
                # Frame offsets divided by 100 — assumes 100 frames per
                # second (LibriSpeech-alignment-style rows); TODO confirm.
                start = 0
                for word in sentence:
                    end = start + len(word)
                    output.write(f'{start/100},{end/100},{word},phones\n')
                    start = end
def run(args):
    """Segment a quantized dataset and write the results to ``args.output``.

    Uses Viterbi segmentation when ``args.viterbi`` is set (requires
    ``args.clusterings``), otherwise SentencePiece's default segmentation.

    Args:
        args: parsed CLI namespace (see ``parseArgs``).

    Raises:
        AssertionError: if the sentencepiece model is missing, or if
            ``--viterbi`` is given without ``--clusterings``.
        SystemExit: (code 0) when a txt output already exists.
    """
    # Seed every RNG involved so segmentation is reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    sentencepiece.set_random_generator_seed(args.seed)

    os.makedirs(args.output, exist_ok=True)

    # Early exit if the aggregate txt output already exists.
    # NOTE(review): only the txt artifact is checked, not the csv folder.
    if 'txt' in args.output_format.split(',') and os.path.exists(args.output / 'segmented_outputs.txt'):
        print(f'Segmentation already found at {args.output}, skipping')
        sys.exit(0)

    assert utils.ensure_path(f'{args.sentencepiece_prefix}.model'), \
        f'Sentencepiece model not found at {args.sentencepiece_prefix}'

    print('Loading dataset...')
    devset = dataset.Data(args.dataset)

    if args.viterbi:
        print('Running Viterbi segmentation...')
        assert args.clusterings is not None, "If viterbi is used you have to specify path to the clusterings!"
        devset.load_clusterings(args.clusterings, args.alpha)
        # Only the formatted output is needed; the raw segmentation is unused.
        formatted, _ = segment_viterbi(devset.data, devset.clusterings, args.sentencepiece_prefix)
    else:
        print('Running SentencePiece segmentation...')
        formatted, _ = segment_sentencepiece(devset.data, args.sentencepiece_prefix)
    save_segmentation(formatted, devset, args.output, args)
if __name__ == "__main__":
args = parseArgs()
run(args)