|
| 1 | +#!/usr/bin/env python |
| 2 | +# coding=utf-8 |
| 3 | +# wujian@2018 |
| 4 | + |
| 5 | +import os |
| 6 | +import random |
| 7 | +import logging |
| 8 | +import pickle |
| 9 | + |
| 10 | +import numpy as np |
| 11 | +import torch as th |
| 12 | + |
| 13 | +from torch.nn.utils.rnn import pack_sequence, pad_sequence |
| 14 | + |
| 15 | +from utils import parse_scps, stft, compute_vad_mask, apply_cmvn |
| 16 | + |
# Module-level logger: INFO and above go to stderr through a single
# StreamHandler with a timestamped "[path:lineno - level]" prefix.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
| 25 | + |
| 26 | + |
class SpectrogramReader(object):
    """
    Wrapper for short-time fourier transform of dataset

    Maps utterance keys (from a .scp script file) to the STFT of the
    corresponding wave file. Spectrograms are computed on demand and
    not cached.
    """

    def __init__(self, wave_scp, **kwargs):
        """
        wave_scp: path of a "key wave-path" script file
        kwargs:   options forwarded unchanged to stft()
                  (frame size/shift etc. — see utils.stft)
        Raises FileNotFoundError if wave_scp does not exist.
        """
        if not os.path.exists(wave_scp):
            raise FileNotFoundError("Could not find file {}".format(wave_scp))
        self.stft_kwargs = kwargs
        self.wave_dict = parse_scps(wave_scp)
        # stable key list so utterances can also be addressed by position
        # (idiomatic list(dict) instead of a copy-comprehension over .keys())
        self.wave_keys = list(self.wave_dict)
        logger.info(
            "Create SpectrogramReader for {} with {} utterances".format(
                wave_scp, len(self.wave_dict)))

    def __len__(self):
        """Number of utterances listed in the script file"""
        return len(self.wave_dict)

    def __contains__(self, key):
        """True iff the utterance key appears in the script file"""
        return key in self.wave_dict

    # stft
    def _load(self, key):
        # compute the spectrogram for one utterance on demand
        return stft(self.wave_dict[key], **self.stft_kwargs)

    # sequential index
    def __iter__(self):
        # yields (key, spectrogram) pairs in script-file order
        for key in self.wave_dict:
            yield key, self._load(key)

    # random index
    def __getitem__(self, key):
        # random access by utterance key; KeyError on unknown keys
        if key not in self.wave_dict:
            raise KeyError("Could not find utterance {}".format(key))
        return self._load(key)
| 62 | + |
| 63 | + |
class Dataset(object):
    """
    Pair a mixture reader with a list of target readers.

    Supports three indexing modes: by utterance key (str), by position
    (int) and by a list of positions (list of int).
    """

    def __init__(self, mixture_reader, targets_reader_list):
        """
        mixture_reader:      reader for the mixture spectrograms;
                             must expose .wave_keys, __contains__, __getitem__
        targets_reader_list: list of readers, one per target source
        """
        self.mixture_reader = mixture_reader
        self.keys_list = mixture_reader.wave_keys
        self.targets_reader_list = targets_reader_list

    def __len__(self):
        """Number of mixture utterances"""
        return len(self.keys_list)

    def _has_target(self, key):
        """True iff every target reader contains this utterance key"""
        return all(key in reader for reader in self.targets_reader_list)

    def _index_by_key(self, key):
        """
        Return a tuple like (matrix, [matrix, ...])
        Raises KeyError if mixture or any target is missing.
        """
        if key not in self.mixture_reader or not self._has_target(key):
            raise KeyError("Missing targets or mixture")
        target_list = [reader[key] for reader in self.targets_reader_list]
        return (self.mixture_reader[key], target_list)

    def _index_by_num(self, num):
        """
        Return a tuple like (matrix, [matrix, ...])
        Raises IndexError for positions past the end of the dataset.
        """
        if num >= len(self.keys_list):
            raise IndexError("Index out of dataset, {} vs {}".format(
                num, len(self.keys_list)))
        return self._index_by_key(self.keys_list[num])

    def _index_by_list(self, list_idx):
        """
        Returns a list of tuple like [
            (matrix, [matrix, ...]),
            (matrix, [matrix, ...]),
            ...
        ]
        An empty index list yields an empty batch (previously this
        crashed inside max() with a ValueError).
        """
        if list_idx and max(list_idx) >= len(self.keys_list):
            raise IndexError("Index list contains index out of dataset")
        return [self._index_by_num(index) for index in list_idx]

    def __getitem__(self, index):
        # dispatch on index type: position, key or batch of positions
        if isinstance(index, int):
            return self._index_by_num(index)
        if isinstance(index, str):
            return self._index_by_key(index)
        if isinstance(index, list):
            return self._index_by_list(index)
        raise KeyError("Unsupported index type(int/str/list)")
| 119 | + |
| 120 | + |
class BatchSampler(object):
    """
    Produce mini-batch indices over a dataset of the given size.

    Each batch is a list of int positions, except when batch_size is 1,
    in which case bare ints are yielded. The trailing short batch is
    kept unless drop_last is set.
    """

    def __init__(self,
                 sampler_size,
                 batch_size=16,
                 shuffle=True,
                 drop_last=False):
        if batch_size <= 0:
            raise ValueError(
                "Illegal batch_size(= {}) detected".format(batch_size))
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler_size = sampler_size
        self.sampler_index = list(range(sampler_size))
        if shuffle:
            random.shuffle(self.sampler_index)

    def __len__(self):
        # number of indices (not number of batches)
        return self.sampler_size

    def __iter__(self):
        size = self.batch_size
        # emit all full batches with a stepped range
        for base in range(0, self.sampler_size - size + 1, size):
            chunk = self.sampler_index[base:base + size]
            yield chunk if size != 1 else chunk[0]
        # the short remainder batch, unless dropped
        rest = self.sampler_size % size
        if rest and not self.drop_last:
            yield self.sampler_index[-rest:]
| 151 | + |
| 152 | + |
class DataLoader(object):
    """
    Multi/Per utterance loader for DCNet training

    Iterates the dataset in (optionally shuffled) batches and yields
    (spectrogram, target_attr, silent_mask) tuples ready for training:
    packed/padded tensors for list batches, plain tensors per utterance
    when batch_size is 1.
    """

    def __init__(self,
                 dataset,
                 shuffle=True,
                 batch_size=16,
                 drop_last=False,
                 vad_threshold=40,
                 mvn_dict=None):
        """
        dataset:       Dataset instance supplying (mixture, [targets]) pairs
        shuffle:       reshuffle utterance order each epoch
        batch_size:    utterances per batch
        drop_last:     drop the final short batch if True
        vad_threshold: threshold forwarded to compute_vad_mask
        mvn_dict:      path of a pickled cmvn statistics file, or None
        """
        self.dataset = dataset
        self.vad_threshold = vad_threshold
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.shuffle = shuffle
        # holds the loaded statistics dict (not the path) once read;
        # the original assigned the path first and then clobbered it
        self.mvn_dict = None
        if mvn_dict:
            logger.info("Using cmvn dictionary from {}".format(mvn_dict))
            # NOTE(review): pickle.load is only safe on trusted, locally
            # generated cmvn files — never point this at untrusted data
            with open(mvn_dict, "rb") as f:
                self.mvn_dict = pickle.load(f)

    def __len__(self):
        """Number of batches yielded per epoch"""
        num_batches, remain = divmod(len(self.dataset), self.batch_size)
        if self.drop_last or not remain:
            return num_batches
        return num_batches + 1

    def _transform(self, mixture_specs, targets_specs_list):
        """
        Transform one utterance from numpy/list to torch types

        Returns a dict holding the frame count, the (optionally
        cmvn-normalized) mixture spectrogram, the per-TF-bin dominant
        target index and the voice-activity mask.
        """
        # compute vad mask before cmvn, on the raw spectrogram
        vad_mask = compute_vad_mask(
            mixture_specs, self.vad_threshold, apply_exp=True)
        # apply cmvn
        if self.mvn_dict:
            mixture_specs = apply_cmvn(mixture_specs, self.mvn_dict)
        # compute target embedding index: argmax across targets per TF-bin
        target_attr = np.argmax(np.array(targets_specs_list), 0)
        return {
            "num_frames": mixture_specs.shape[0],
            "spectrogram": th.tensor(mixture_specs, dtype=th.float32),
            "target_attr": th.tensor(target_attr, dtype=th.int64),
            "silent_mask": th.tensor(vad_mask, dtype=th.float32)
        }

    def _process(self, index):
        """
        Build tensors for one batch (list index) or one utterance
        (int index). Raises ValueError on any other index type.
        """
        if isinstance(index, list):
            # sort by length (descending) so the batch can be packed
            dict_list = sorted(
                [self._transform(s, t) for s, t in self.dataset[index]],
                key=lambda x: x["num_frames"],
                reverse=True)
            spectrogram = pack_sequence([d["spectrogram"] for d in dict_list])
            target_attr = pad_sequence(
                [d["target_attr"] for d in dict_list], batch_first=True)
            silent_mask = pad_sequence(
                [d["silent_mask"] for d in dict_list], batch_first=True)
            return spectrogram, target_attr, silent_mask
        if isinstance(index, int):
            s, t = self.dataset[index]
            data_dict = self._transform(s, t)
            return data_dict["spectrogram"], \
                data_dict["target_attr"], \
                data_dict["silent_mask"]
        raise ValueError("Unsupported index type({})".format(type(index)))

    def __iter__(self):
        # fresh sampler each epoch so shuffle order changes per epoch
        sampler = BatchSampler(
            len(self.dataset),
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            drop_last=self.drop_last)
        num_utts = 0
        for e, index in enumerate(sampler):
            num_utts += (len(index) if isinstance(index, list) else 1)
            # periodic progress report every 100 batches
            if not (e + 1) % 100:
                logger.info("Processed {} batches, {} utterances".format(
                    e + 1, num_utts))
            yield self._process(index)
        logger.info("Processed {} utterances in total".format(num_utts))
0 commit comments