Skip to content

Commit

Permalink
data augmentation files uploaded
Browse files Browse the repository at this point in the history
  • Loading branch information
Miquel Tubau Pires committed Feb 2, 2020
1 parent d11a66d commit 1b3f3ba
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 0 deletions.
88 changes: 88 additions & 0 deletions scripts/augmentate_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import argparse
import string, random
from shutil import copyfile
import scipy.io.wavfile as wavfile
from wav_functions import readwav

# Set of the characters in string.printable.
# NOTE(review): nothing else in the visible script references this name -- confirm before removing.
printable = set(string.printable)


def augment_data(wav_data, factor, audio_seconds, sample_rate=16000):
    """Cut ``factor`` distinct fixed-length excerpts out of ``wav_data``.

    Every excerpt is ``audio_seconds * sample_rate`` samples long and starts
    at a randomly chosen position; no two excerpts start at the same sample.

    Args:
        wav_data: Sliceable sequence of samples. Only ``len()`` and slicing
            are used, so a list or an array indexed by sample both work.
        factor: Number of excerpts to produce.
        audio_seconds: Length of every excerpt, in seconds.
        sample_rate: Samples per second of ``wav_data``. Defaults to 16000,
            the value that used to be hard-coded.

    Returns:
        A list holding ``factor`` slices of ``wav_data``, in random order.

    Raises:
        ValueError: If the requested excerpt length is not positive, or if
            ``wav_data`` is too short to provide ``factor`` distinct excerpts.
    """
    window = int(audio_seconds * sample_rate)
    if window <= 0:
        raise ValueError('The excerpt length must be at least one sample')

    # Number of start positions whose window fits entirely inside the data.
    # The "+ 1" keeps the last legal window, so data of exactly ``window``
    # samples yields one excerpt instead of failing.
    available = len(wav_data) - window + 1
    if factor > max(available, 0):
        raise ValueError(
            'Cannot extract %d distinct excerpts of %d samples from %d samples'
            % (factor, window, len(wav_data))
        )

    # random.sample draws without replacement, which guarantees uniqueness
    # without a retry loop that could spin forever.
    starts = random.sample(range(available), factor)
    return [wav_data[start:start + window] for start in starts]

if __name__ == "__main__":
    # Command-line entry point. For every speaker folder under --folder_path,
    # cut --factor random excerpts of --audio_seconds seconds out of each .wav
    # file and write them, each paired with a copy of the matching face image,
    # into a folder of the same name under --output_folder.

    parser = argparse.ArgumentParser()

    parser.add_argument("--folder_path", required=True, help='path of the folder where to take the data and augment')
    # type=int lets argparse reject non-numeric values up front and removes
    # the repeated int() conversions further down.
    parser.add_argument("--factor", required=True, type=int, help='data augmentation factor')
    parser.add_argument('--output_folder', required=True, help='folder where to store the augmented data')
    parser.add_argument('--audio_seconds', required=True, type=int, help='audio of the resulting augmented .wav files')

    args = parser.parse_args()

    # Hidden entries (names starting with a dot) are skipped.
    youtubers = [youtuber for youtuber in os.listdir(args.folder_path) if not youtuber.startswith('.')]

    # Running counter shared by all speakers so output names never collide.
    id_number = 0

    for youtuber in youtubers:

        working_path = os.path.join(args.folder_path, youtuber)
        new_path = os.path.join(args.output_folder, youtuber)

        # creating if necessary a folder with the youtuber name in the output_folder
        if not os.path.exists(new_path):
            os.makedirs(new_path)

        audios = [audio for audio in os.listdir(working_path) if audio.endswith('.wav')]

        for audio in audios:

            # The text after the last underscore, without the extension, is
            # what links an audio file to its face image.
            number = os.path.splitext(audio)[0].split('_')[-1]

            # getting the paths of the original audio and face
            audio_path = os.path.join(working_path, audio)
            corresponding_face = os.path.join(working_path, 'cropped_face_frame_' + number + '.png')

            # reading the audio file and converting it into an array
            fm, _, wav_data = readwav(audio_path)

            if fm != 16000:
                raise ValueError('Sampling rate is expected to be 16 KHz!')

            if len(wav_data) < 16000 * args.audio_seconds:
                raise ValueError('The original audio is shorter than the desired output')

            # obtaining the audios as a result of the data augmentation
            wav_vectors = augment_data(wav_data, args.factor, args.audio_seconds)

            for wav_vec in wav_vectors:

                # converting wav_vec into a .wav file and storing it
                wavfile.write(
                    filename=os.path.join(new_path, 'preprocessed_frame_' + str(id_number) + '.wav'),
                    rate=16000,
                    data=wav_vec,
                )

                # copying the face from the older path to the newer one, under
                # the same id as the excerpt it belongs to
                new_face_path = os.path.join(new_path, 'cropped_face_frame_' + str(id_number) + '.png')
                copyfile(corresponding_face, new_face_path)

                # updating id_number to avoid overwriting files
                id_number += 1
80 changes: 80 additions & 0 deletions scripts/wav_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# wavio.py
# Author: Warren Weckesser
# License: BSD 3-Clause (http://opensource.org/licenses/BSD-3-Clause)

import wave
import numpy as np


def _wav2array(nchannels, sampwidth, data):
    """Convert raw wav frame bytes into a 2-D integer array.

    data must be the string containing the bytes from the wav file.

    Returns an array of shape (num_samples, nchannels). 8 bit samples come
    back as unsigned integers, 16 and 32 bit samples as signed integers of
    the same width, and 24 bit samples sign-extended into 32 bit integers.

    Raises ValueError if the length of data is not a multiple of
    sampwidth * nchannels, or if sampwidth is greater than 4.
    """
    num_samples, remainder = divmod(len(data), sampwidth * nchannels)
    if remainder > 0:
        raise ValueError('The length of data is not a multiple of '
                         'sampwidth * num_channels.')
    if sampwidth > 4:
        raise ValueError("sampwidth must not be greater than 4.")

    if sampwidth == 3:
        # There is no 3-byte dtype: copy each sample into the low three bytes
        # of a 4-byte slot and fill the fourth byte with the sign.
        a = np.empty((num_samples, nchannels, 4), dtype=np.uint8)
        # frombuffer replaces the deprecated binary mode of np.fromstring.
        raw_bytes = np.frombuffer(data, dtype=np.uint8)
        a[:, :, :sampwidth] = raw_bytes.reshape(-1, nchannels, sampwidth)
        # Top bit of the most significant byte is the sign: 0 -> 0x00, 1 -> 0xFF.
        a[:, :, sampwidth:] = (a[:, :, sampwidth - 1:sampwidth] >> 7) * 255
        result = a.view('<i4').reshape(a.shape[:-1])
    else:
        # 8 bit samples are stored as unsigned ints; others as signed ints.
        dt_char = 'u' if sampwidth == 1 else 'i'
        a = np.frombuffer(data, dtype='<%s%d' % (dt_char, sampwidth))
        # frombuffer returns a read-only view of the bytes; copy so callers
        # still receive a writable array, as they did with fromstring.
        result = a.reshape(-1, nchannels).copy()
    return result


def readwav(file):
    """
    Read a wav file.
    Returns the frame rate, sample width (in bytes) and a numpy array
    containing the data, shaped (num_frames, num_channels).
    This function does not read compressed wav files.
    """
    wav = wave.open(file)
    try:
        rate = wav.getframerate()
        nchannels = wav.getnchannels()
        sampwidth = wav.getsampwidth()
        nframes = wav.getnframes()
        data = wav.readframes(nframes)
    finally:
        # Release the file handle even when reading the frames fails.
        wav.close()
    array = _wav2array(nchannels, sampwidth, data)
    return rate, sampwidth, array


def writewav24(filename, rate, data):
    """Create a 24 bit wav file.

    data must be "array-like", either 1- or 2-dimensional. If it is 2-d,
    the rows are the frames (i.e. samples) and the columns are the channels.
    The data is assumed to be signed, and the values are assumed to be
    within the range of a 24 bit integer. Floating point values are
    converted to integers. The data is not rescaled or normalized before
    writing it to the file. The input array is never modified.

    Example: Create a 3 second 440 Hz sine wave.
    >>> rate = 22050  # samples per second
    >>> T = 3         # sample duration (seconds)
    >>> f = 440.0     # sound frequency (Hz)
    >>> t = np.linspace(0, T, T*rate, endpoint=False)
    >>> x = (2**23 - 1) * np.sin(2 * np.pi * f * t)
    >>> writewav24("sine24.wav", rate, x)
    """
    a32 = np.asarray(data, dtype=np.int32)
    if a32.ndim == 1:
        # Convert to a 2D array with a single column. reshape returns a new
        # view; assigning to .shape would also reshape the caller's array
        # whenever asarray returned the input itself.
        a32 = a32.reshape(-1, 1)
    # By shifting first 0 bits, then 8, then 16, the resulting output
    # is 24 bit little-endian.
    a8 = (a32.reshape(a32.shape + (1,)) >> np.array([0, 8, 16])) & 255
    # tobytes is the supported name of the deprecated tostring alias.
    wavdata = a8.astype(np.uint8).tobytes()

    w = wave.open(filename, 'wb')
    try:
        w.setnchannels(a32.shape[1])
        w.setsampwidth(3)
        w.setframerate(rate)
        w.writeframes(wavdata)
    finally:
        # Close the handle even when writing the frames fails.
        w.close()

0 comments on commit 1b3f3ba

Please sign in to comment.