Commit 59a3624 (parent f7ba918)
Showing 4 changed files with 172 additions and 0 deletions.
Empty file.
@@ -0,0 +1,52 @@
import random
import requests
import os, glob


# english literature
books = [
    'https://www.gutenberg.org/cache/epub/1513/pg1513.txt',
    'https://www.gutenberg.org/files/2701/2701-0.txt',
    'https://www.gutenberg.org/cache/epub/84/pg84.txt',
    'https://www.gutenberg.org/cache/epub/2641/pg2641.txt',
    'https://www.gutenberg.org/cache/epub/1342/pg1342.txt',
    'https://www.gutenberg.org/cache/epub/100/pg100.txt'
]

# default english
# allowed_chars = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~\n\\'

# german
allowed_chars = ' aäbcdefghijklmnoöpqrsßtuüvwxyzABCDEFGHIJKLMNOÖPQRSTUÜVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~\n\\'


def download_book(book):
    return requests.get(book).content.decode('utf-8')


def filter_data(data):
    print('Filtering data')
    return ''.join([char for char in data if char in allowed_chars])


def load_books(fromfolder=False):
    text_data = []
    if fromfolder:
        current_working_directory = os.getcwd()
        print(current_working_directory)
        path = 'text'
        for filename in glob.glob(os.path.join(path, '*.txt')):
            # open read-only; decode as utf-8 so the German umlauts in allowed_chars are matched correctly
            with open(os.path.join(os.getcwd(), filename), 'r', encoding='utf-8') as f:
                print(f'Loading {filename}')
                text_data.append(filter_data(f.read()))
    else:
        print(f'Loading {len(books)} books into ram')
        for book in books:
            text_data.append(filter_data(download_book(book)))
        print('Loaded books')
    return ' '.join(text_data)


def random_split_chunk(data, size=14):
    data = data.split(' ')
    index = random.randrange(0, len(data))
    return ' '.join(data[index:index+size])
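
For reference, a minimal usage sketch of this module (the generation script further down imports it as training.data; the call below simply downloads the listed books and samples one chunk):

from training.data import load_books, random_split_chunk

corpus = load_books(fromfolder=False)        # download and filter the six Gutenberg books
chunk = random_split_chunk(corpus, size=14)  # pull a random ~14-word snippet
print(chunk)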
@@ -0,0 +1,47 @@
import os
import fnmatch
import shutil

import numpy
import torchaudio
import gradio

from bark.hubert.pre_kmeans_hubert import CustomHubert
from bark.hubert.customtokenizer import auto_train
from tqdm.auto import tqdm


def training_prepare_files(path, model, progress=gradio.Progress(track_tqdm=True)):

    semanticsfolder = "./training/data/output"
    wavfolder = "./training/data/output_wav"
    ready = os.path.join(path, 'ready')

    testfiles = fnmatch.filter(os.listdir(ready), '*.npy')
    if len(testfiles) < 1:
        # prepare and copy for training
        hubert_model = CustomHubert(checkpoint_path=model)

        wavfiles = fnmatch.filter(os.listdir(wavfolder), '*.wav')
        for i, f in tqdm(enumerate(wavfiles), total=len(wavfiles)):
            semaname = '.'.join(f.split('.')[:-1])  # Cut off the extension
            semaname = f'{semaname}.npy'
            semafilename = os.path.join(semanticsfolder, semaname)
            if not os.path.isfile(semafilename):
                print(f'Skipping {f} no semantics pair found!')
                continue

            print('Processing', f)
            wav, sr = torchaudio.load(os.path.join(wavfolder, f))
            if wav.shape[0] == 2:  # Stereo to mono if needed
                wav = wav.mean(0, keepdim=True)
            output = hubert_model.forward(wav, input_sample_hz=sr)
            out_array = output.cpu().numpy()
            fname = f'{i}_semantic_features.npy'
            numpy.save(os.path.join(ready, fname), out_array)
            fname = f'{i}_semantic.npy'
            shutil.copy(semafilename, os.path.join(ready, fname))


def train(path, save_every, max_epochs):
    auto_train(path, save_epochs=save_every)
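
A hypothetical invocation sketch for these two steps; the training folder and HuBERT checkpoint path below are placeholders, not paths this commit defines:

# Assumed layout: <path>/ready exists, wavs live in ./training/data/output_wav
# and their semantic .npy pairs in ./training/data/output (see the constants above).
training_prepare_files('training/model', 'models/hubert/hubert.pt')  # both paths are placeholders
train('training/model', save_every=1, max_epochs=10)  # note: max_epochs is not forwarded to auto_train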
@@ -0,0 +1,73 @@
import uuid
import numpy
import os
import random
import fnmatch

from tqdm.auto import tqdm
from scipy.io import wavfile

from bark.generation import load_model, SAMPLE_RATE
from bark.api import semantic_to_waveform
from bark import text_to_semantic

from training.data import load_books, random_split_chunk

output = 'training/data/output'
output_wav = 'training/data/output_wav'


def prepare_semantics_from_text(num_generations):
    loaded_data = load_books(True)

    print('Loading semantics model')
    load_model(use_gpu=True, use_small=False, force_reload=False, model_type='text')

    if not os.path.isdir(output):
        os.mkdir(output)

    loop = 1
    while 1:  # runs until interrupted; num_generations is not used as a stopping condition
        filename = uuid.uuid4().hex + '.npy'
        file_name = os.path.join(output, filename)
        text = ''
        while not len(text) > 0:
            text = random_split_chunk(loaded_data)  # Obtain a short chunk of text
            text = text.strip()
        print(f'{loop} Generating semantics for text:', text)
        loop += 1
        semantics = text_to_semantic(text, temp=round(random.uniform(0.6, 0.8), ndigits=2))
        numpy.save(file_name, semantics)


def prepare_wavs_from_semantics():
    if not os.path.isdir(output):
        raise Exception('No \'output\' folder, make sure you run create_data.py first!')
    if not os.path.isdir(output_wav):
        os.mkdir(output_wav)

    print('Loading coarse model')
    load_model(use_gpu=True, use_small=False, force_reload=False, model_type='coarse')
    print('Loading fine model')
    load_model(use_gpu=True, use_small=False, force_reload=False, model_type='fine')

    files = fnmatch.filter(os.listdir(output), '*.npy')
    total = len(files)

    for i, f in tqdm(enumerate(files), total=total):
        real_name = '.'.join(f.split('.')[:-1])  # Cut off the extension
        file_name = os.path.join(output, f)
        out_file = os.path.join(output_wav, f'{real_name}.wav')
        # Skip files that already have a wav so a previous generation run can be resumed
        if not os.path.isfile(out_file) and os.path.isfile(file_name):
            print(f'Processing ({i+1}/{total}) -> {f}')
            wav = semantic_to_waveform(numpy.load(file_name), temp=round(random.uniform(0.6, 0.8), ndigits=2))
            # Change to PCM16
            # wav = (wav * 32767).astype(np.int16)
            wavfile.write(out_file, SAMPLE_RATE, wav)

    print('Done!')
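
A small driver sketch for these two stages (the argument value is arbitrary; as noted above, the generation loop runs until it is interrupted):

prepare_semantics_from_text(100)   # fills training/data/output with .npy semantic token files
prepare_wavs_from_semantics()      # renders each .npy to a .wav in training/data/output_wav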