-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Miquel Tubau Pires
committed
Feb 2, 2020
1 parent
d11a66d
commit 1b3f3ba
Showing
2 changed files
with
168 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import os | ||
import argparse | ||
import string, random | ||
from shutil import copyfile | ||
import scipy.io.wavfile as wavfile | ||
from wav_functions import readwav | ||
|
||
# Set of every character in string.printable.
# NOTE(review): not referenced anywhere else in the visible script — confirm
# whether it is still needed.
printable = set(string.printable)
|
||
|
||
def augment_data(wav_data, factor, audio_seconds, sample_rate=16000):
    """Cut ``factor`` distinct, randomly placed windows out of one recording.

    Each window is centred on a randomly chosen sample and spans half of the
    requested length to either side, so all returned clips have the same
    number of samples.

    Args:
        wav_data: Sliceable sequence of samples (anything supporting
            ``len()`` and slicing, e.g. a list or a NumPy array).
        factor: Number of clips to produce. Values <= 0 yield an empty list.
        audio_seconds: Length of every clip, in seconds.
        sample_rate: Samples per second of ``wav_data``; defaults to 16000.

    Returns:
        A list with ``factor`` slices of ``wav_data``, each centred on a
        different sample.

    Raises:
        ValueError: If the requested clip length is empty, or if the
            recording does not contain ``factor`` distinct window positions.
    """
    if factor <= 0:
        return []

    samples_to_retain = int(audio_seconds * sample_rate)
    # Integer division: slice bounds must be ints (true division would
    # produce floats on Python 3).
    half_window = samples_to_retain // 2
    if half_window <= 0:
        raise ValueError('audio_seconds is too short to produce any sample')

    # A centre is valid when a full half window fits on both sides of it.
    # The last centre is inclusive, so a recording exactly as long as the
    # clip still yields one window.
    first_center = half_window
    last_center = len(wav_data) - half_window
    available = last_center - first_center + 1
    if available < factor:
        raise ValueError(
            'Cannot extract %d distinct windows: only %d positions available'
            % (factor, max(available, 0)))

    # Sampling without replacement guarantees distinct centres and cannot
    # spin forever the way a "retry until unused" loop can.
    centers = random.sample(range(first_center, last_center + 1), factor)
    return [wav_data[center - half_window:center + half_window]
            for center in centers]
|
||
if __name__ == "__main__":
    # Command-line driver: for every sub-folder of --folder_path, cut each
    # .wav file into --factor random clips of --audio_seconds seconds and
    # store every clip, together with a copy of the matching face image,
    # under --output_folder/<sub-folder name>.

    # Sampling rate (Hz) required from the inputs and used for the outputs.
    expected_rate = 16000

    parser = argparse.ArgumentParser()
    parser.add_argument("--folder_path", required=True,
                        help='path of the folder where to take the data and augment')
    # type=int lets argparse reject non-numeric values with a usage message
    # instead of failing later inside the processing loop.
    parser.add_argument("--factor", required=True, type=int,
                        help='data augmentation factor')
    parser.add_argument('--output_folder', required=True,
                        help='folder where to store the augmented data')
    parser.add_argument('--audio_seconds', required=True, type=int,
                        help='audio of the resulting augmented .wav files')
    args = parser.parse_args()

    # Hidden entries (names starting with '.') are skipped.
    youtubers = [youtuber for youtuber in os.listdir(args.folder_path)
                 if not youtuber.startswith('.')]

    # Running counter shared by all outputs so file names never collide.
    id_number = 0

    for youtuber in youtubers:
        working_path = os.path.join(args.folder_path, youtuber)

        # creating if necessary a folder with the youtuber name in the output_folder
        new_path = os.path.join(args.output_folder, youtuber)
        if not os.path.exists(new_path):
            os.makedirs(new_path)

        audios = [audio for audio in os.listdir(working_path)
                  if audio.endswith('.wav')]

        for audio in audios:
            # Text between the last '_' and the '.wav' extension; it pairs
            # the audio file with its face image.
            number = audio.split('_')[-1][:-4]

            # getting the paths of the original audio and face
            audio_path = os.path.join(working_path, audio)
            corresponding_face = os.path.join(
                working_path, 'cropped_face_frame_' + number + '.png')

            # reading the audio file and converting it into an array
            fm, _, wav_data = readwav(audio_path)

            if fm != expected_rate:
                raise ValueError('Sampling rate is expected to be 16 KHz!')

            if len(wav_data) < expected_rate * args.audio_seconds:
                raise ValueError('The original audio is shorter than the desired output')

            # obtaining the audios as a result of the data augmentation
            wav_vectors = augment_data(wav_data, args.factor, args.audio_seconds)

            for wav_vec in wav_vectors:
                # converting wav_vec into a .wav file and storing it
                new_audio_path = os.path.join(
                    new_path, 'preprocessed_frame_' + str(id_number) + '.wav')
                wavfile.write(filename=new_audio_path, rate=expected_rate, data=wav_vec)

                # copying the face from the older path to the newer one
                new_face_path = os.path.join(
                    new_path, 'cropped_face_frame_' + str(id_number) + '.png')
                copyfile(corresponding_face, new_face_path)

                # updating id_number to avoid overwriting files
                id_number += 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# wavio.py | ||
# Author: Warren Weckesser | ||
# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause) | ||
|
||
import wave | ||
import numpy as np | ||
|
||
|
||
def _wav2array(nchannels, sampwidth, data): | ||
"""data must be the string containing the bytes from the wav file.""" | ||
num_samples, remainder = divmod(len(data), sampwidth * nchannels) | ||
if remainder > 0: | ||
raise ValueError('The length of data is not a multiple of ' | ||
'sampwidth * num_channels.') | ||
if sampwidth > 4: | ||
raise ValueError("sampwidth must not be greater than 4.") | ||
|
||
if sampwidth == 3: | ||
a = np.empty((num_samples, nchannels, 4), dtype=np.uint8) | ||
raw_bytes = np.fromstring(data, dtype=np.uint8) | ||
a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth) | ||
a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255 | ||
result = a.view('<i4').reshape(a.shape[:-1]) | ||
else: | ||
# 8 bit samples are stored as unsigned ints; others as signed ints. | ||
dt_char = 'u' if sampwidth == 1 else 'i' | ||
a = np.fromstring(data, dtype='<%s%d' % (dt_char, sampwidth)) | ||
result = a.reshape(-1, nchannels) | ||
return result | ||
|
||
|
||
def readwav(file):
    """
    Read a wav file.

    Returns the frame rate, sample width (in bytes) and a numpy array
    containing the data, as produced by ``_wav2array`` (one row per frame,
    one column per channel).

    This function does not read compressed wav files.
    """
    wav = wave.open(file)
    try:
        rate = wav.getframerate()
        nchannels = wav.getnchannels()
        sampwidth = wav.getsampwidth()
        nframes = wav.getnframes()
        data = wav.readframes(nframes)
    finally:
        # Release the file handle even if reading the header or frames fails.
        wav.close()
    array = _wav2array(nchannels, sampwidth, data)
    return rate, sampwidth, array
|
||
|
||
def writewav24(filename, rate, data):
    """Create a 24 bit wav file.

    data must be "array-like", either 1- or 2-dimensional. If it is 2-d,
    the rows are the frames (i.e. samples) and the columns are the channels.

    The data is assumed to be signed, and the values are assumed to be
    within the range of a 24 bit integer. Floating point values are
    converted to integers. The data is not rescaled or normalized before
    writing it to the file. The input array is never modified.

    Example: Create a 3 second 440 Hz sine wave.

    >>> rate = 22050  # samples per second
    >>> T = 3         # sample duration (seconds)
    >>> f = 440.0     # sound frequency (Hz)
    >>> t = np.linspace(0, T, T*rate, endpoint=False)
    >>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t)
    >>> writewav24("sine24.wav", rate, x)
    """
    a32 = np.asarray(data, dtype=np.int32)
    if a32.ndim == 1:
        # Convert to a 2D array with a single column. reshape returns a new
        # view, whereas assigning to .shape would also reshape the caller's
        # array when np.asarray handed back the very same object.
        a32 = a32.reshape(-1, 1)
    # By shifting first 0 bits, then 8, then 16, the resulting output
    # is 24 bit little-endian.
    a8 = (a32.reshape(a32.shape + (1,)) >> np.array([0, 8, 16])) & 255
    # tobytes replaces the deprecated ndarray.tostring alias.
    wavdata = a8.astype(np.uint8).tobytes()

    w = wave.open(filename, 'wb')
    try:
        w.setnchannels(a32.shape[1])
        w.setsampwidth(3)
        w.setframerate(rate)
        w.writeframes(wavdata)
    finally:
        # Close (and finalize the header) even if writing fails.
        w.close()