Skip to content

Commit

Permalink
feat(webui): support external spk_emb (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 26, 2024
1 parent cc58be2 commit c00ba37
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 30 deletions.
62 changes: 38 additions & 24 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
custom_path: Optional[str] = None

has_interrupted = False
is_in_generate = False

seed_min = 1
seed_max = 4294967295

# 音色选项:用于预置合适的音色
voices = {
Expand All @@ -38,13 +42,18 @@


def generate_seed():
return gr.update(value=random.randint(1, 100000000))
return gr.update(value=random.randint(seed_min, seed_max))


# 返回选择音色对应的seed
def on_voice_change(vocie_selection):
return voices.get(vocie_selection)["seed"]

def on_audio_seed_change(audio_seed_input):
with TorchSeedContext(audio_seed_input):
rand_spk = chat.sample_random_speaker()
return rand_spk


def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:
if cust_path == None:
Expand Down Expand Up @@ -79,6 +88,12 @@ def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool:


def reload_chat(coef: Optional[str]) -> str:
global is_in_generate

if is_in_generate:
gr.Warning("Cannot reload when generating!")
return coef

chat.unload()
gr.Info("Model unloaded.")
if len(coef) != 230:
Expand Down Expand Up @@ -119,37 +134,33 @@ def refine_text(

return text[0] if isinstance(text, list) else text

def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
def generate_audio(text, temperature, top_P, top_K, spk_emb_text: str, stream):
global chat, has_interrupted

if not text or has_interrupted:
if not text or has_interrupted or not spk_emb_text.startswith("蘁淰"):
return None

with TorchSeedContext(audio_seed_input):
rand_spk = chat.sample_random_speaker()

params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb=rand_spk,
spk_emb=spk_emb_text,
temperature=temperature,
top_P=top_P,
top_K=top_K,
)

with TorchSeedContext(audio_seed_input):
wav = chat.infer(
text,
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=stream,
)
if stream:
for gen in wav:
audio = gen[0]
if audio is not None and len(audio) > 0:
yield wav_arr_to_mp3_view(audio[0]).tobytes()
del audio
else:
yield wav_arr_to_mp3_view(np.array(wav[0]).flatten()).tobytes()
wav = chat.infer(
text,
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=stream,
)
if stream:
for gen in wav:
audio = gen[0]
if audio is not None and len(audio) > 0:
yield wav_arr_to_mp3_view(audio[0]).tobytes()
del audio
else:
yield wav_arr_to_mp3_view(np.array(wav[0]).flatten()).tobytes()


def interrupt_generate():
Expand All @@ -159,17 +170,20 @@ def interrupt_generate():
chat.interrupt()

def set_buttons_before_generate(generate_button, interrupt_button):
global has_interrupted
global has_interrupted, is_in_generate

has_interrupted = False
is_in_generate = True

return _set_generate_buttons(
generate_button,
interrupt_button,
)

def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
global has_interrupted
global has_interrupted, is_in_generate

is_in_generate = False

return _set_generate_buttons(
generate_button,
Expand Down
29 changes: 23 additions & 6 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,31 @@ def main():
voice_selection = gr.Dropdown(
label="Timbre", choices=voices.keys(), value="Default"
)
audio_seed_input = gr.Number(value=2, label="Audio Seed", interactive=True)
audio_seed_input = gr.Number(
value=2, label="Audio Seed", interactive=True,
minimum=seed_min, maximum=seed_max,
)
generate_audio_seed = gr.Button("\U0001F3B2")
text_seed_input = gr.Number(value=42, label="Text Seed")
text_seed_input = gr.Number(
value=42, label="Text Seed", interactive=True,
minimum=seed_min, maximum=seed_max,
)
generate_text_seed = gr.Button("\U0001F3B2")

with gr.Row():
spk_emb_text = gr.Textbox(
label="Speaker Embedding",
max_lines=3,
show_copy_button=True,
interactive=True,
scale=2,
)
dvae_coef_text = gr.Textbox(
label="DVAE Coefficient",
max_lines=3,
show_copy_button=True,
scale=4,
interactive=True,
scale=2,
)
reload_chat_button = gr.Button("Reload", scale=1)

Expand All @@ -88,9 +102,11 @@ def main():
fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input
)

generate_audio_seed.click(generate_seed, inputs=[], outputs=audio_seed_input)
generate_audio_seed.click(generate_seed, outputs=audio_seed_input)

generate_text_seed.click(generate_seed, outputs=text_seed_input)

generate_text_seed.click(generate_seed, inputs=[], outputs=text_seed_input)
audio_seed_input.change(on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)

reload_chat_button.click(
reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text
Expand Down Expand Up @@ -123,7 +139,7 @@ def make_audio(autoplay, stream):
temperature_slider,
top_p_slider,
top_k_slider,
audio_seed_input,
spk_emb_text,
stream_mode_checkbox,
],
outputs=audio_output,
Expand Down Expand Up @@ -168,6 +184,7 @@ def make_audio(autoplay, stream):
logger.error("Models load failed.")
sys.exit(1)

spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
dvae_coef_text.value = chat.coef

demo.launch(
Expand Down

0 comments on commit c00ba37

Please sign in to comment.