From 3377f1bf6d6e00d2486120d68c8aaaa4cecfb298 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:29:34 +0900 Subject: [PATCH] fix: redundant `[spk_emb]`s in refine_text (fix #459) --- ChatTTS/core.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 02d12e8c5..bcb2bdf49 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -329,6 +329,7 @@ def _load( self.pretrain_models["tokenizer"] = tokenizer self.tokenizer_len = len(tokenizer) self.tokenizer_spk_emb_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[spk_emb]") + self.tokenizer_break_0_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[break_0]") self.tokenizer_eos_token: torch.Tensor = torch.tensor( tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt ).unsqueeze_(0) @@ -377,12 +378,7 @@ def _infer( ) text_tokens = refined.ids text_tokens = [ - i[ - i - < self.pretrain_models["tokenizer"].convert_tokens_to_ids( - "[break_0]" - ) - ] + i[i.less(self.tokenizer_break_0_ids)] for i in text_tokens ] text = self.pretrain_models["tokenizer"].batch_decode(text_tokens) @@ -458,7 +454,7 @@ def _apply_spk_emb( filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], ), dtype=np.float16).copy(), ).unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12 - ).to(self.gpt.device_gpt).expand(emb.shape) + ).to(self.gpt.device_gpt).unsqueeze_(1).expand(emb.shape) cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape) torch.where(cond, n, emb, out=emb) del cond, n @@ -483,6 +479,12 @@ def _infer_code( temperature = [params.temperature] * gpt.num_vq else: temperature = params.temperature + + for i, t in enumerate(text): + text[i] = t.replace('[Stts]', '').replace('[spk_emb]', '').replace('[empty_spk]', '').strip() + """ + see https://github.com/2noise/ChatTTS/issues/459 + """ if params.prompt: text = [params.prompt + i for i in text] @@ -557,7 +559,7 @@ def _refine_text( emb = gpt(input_ids, text_mask) del text_mask - result = gpt.generate( + result = next(gpt.generate( emb, input_ids, temperature=torch.tensor([params.temperature], device=device), @@ -570,10 +572,10 @@ def _refine_text( infer_text=True, stream=False, context=self.context, - ) + )) del emb, input_ids del_all(logits_warpers) del_all(logits_processors) - return next(result) + return result