optimize: spk_emb & generate & webui #460

Closed · wants to merge 2 commits
8 changes: 5 additions & 3 deletions .github/workflows/unitest.yml
@@ -18,9 +18,11 @@ jobs:
      with:
        python-version: ${{ matrix.python-version }}

-      - name: Install dependencies
-        run: |
-          pip install -r requirements.txt
+      - name: Test Install
+        run: pip install .
+
+      - name: Install Dependencies
+        run: pip install -r requirements.txt

      - name: Run Test
        run: |
110 changes: 57 additions & 53 deletions ChatTTS/core.py
@@ -307,6 +307,11 @@ def _load(
        tokenizer = torch.load(tokenizer_path, map_location=device, mmap=True)
        tokenizer.padding_side = "left"
        self.pretrain_models["tokenizer"] = tokenizer
+        self.tokenizer_len = len(tokenizer)
+        self.tokenizer_spk_emb_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[spk_emb]")
+        self.tokenizer_eos_token: torch.Tensor = torch.tensor(
+            tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt
+        ).unsqueeze_(0)
        self.logger.log(logging.INFO, "tokenizer loaded.")

self.coef = coef
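Note: the point of this hunk is hoisting tokenizer lookups out of the hot path: `len(tokenizer)`, the `[spk_emb]` id, and the `[Ebreak]` eos tensor are now resolved once at load time. (Aside: `convert_tokens_to_ids` returns a plain `int`, so the `torch.Tensor` annotation on `tokenizer_spk_emb_ids` is misleading.) A minimal sketch of the idea, assuming a Hugging Face-style tokenizer; the wrapper class is hypothetical:

```python
import torch

class TokenCache:
    # Resolve special-token ids once, instead of on every inference call.
    def __init__(self, tokenizer, device: str = "cpu"):
        self.tokenizer_len = len(tokenizer)
        # Plain int, compared against input_ids later.
        self.spk_emb_id: int = tokenizer.convert_tokens_to_ids("[spk_emb]")
        # Kept as a 1-element tensor so generate() can reuse it directly.
        self.eos_token = torch.tensor(
            tokenizer.convert_tokens_to_ids("[Ebreak]"), device=device
        ).unsqueeze_(0)
```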
@@ -342,38 +347,40 @@ def _infer(
for t in text
]

-        if not skip_refine_text:
-            refined = self._refine_text(
-                text,
-                self.device,
-                params_refine_text,
-            )
-            text_tokens = refined.ids
-            text_tokens = [
-                i[
-                    i
-                    < self.pretrain_models["tokenizer"].convert_tokens_to_ids(
-                        "[break_0]"
-                    )
-                ]
-                for i in text_tokens
-            ]
-            text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
-            refined.destroy()
-            if refine_text_only:
-                yield text
-                return
-
-        length = [0 for _ in range(len(text))]
-        for result in self._infer_code(
-            text,
-            stream,
-            self.device,
-            use_decoder,
-            params_infer_code,
-        ):
-            wav = self._decode_to_wavs(result, length, use_decoder)
-            yield wav
+        with torch.no_grad():
+
+            if not skip_refine_text:
+                refined = self._refine_text(
+                    text,
+                    self.device,
+                    params_refine_text,
+                )
+                text_tokens = refined.ids
+                text_tokens = [
+                    i[
+                        i
+                        < self.pretrain_models["tokenizer"].convert_tokens_to_ids(
+                            "[break_0]"
+                        )
+                    ]
+                    for i in text_tokens
+                ]
+                text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
+                refined.destroy()
+                if refine_text_only:
+                    yield text
+                    return
+
+            length = [0 for _ in range(len(text))]
+            for result in self._infer_code(
+                text,
+                stream,
+                self.device,
+                use_decoder,
+                params_infer_code,
+            ):
+                wav = self._decode_to_wavs(result, length, use_decoder)
+                yield wav
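Note: the whole `_infer` body now sits under one `torch.no_grad()` scope, so none of the forward passes below record an autograd graph, which saves activation memory and a bit of overhead at inference time. A toy sketch of the pattern, with a stand-in model:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the GPT/decoder stack

def infer(batches):
    # One outer no_grad scope instead of sprinkling it over callees.
    with torch.no_grad():
        for x in batches:
            y = model(x)
            assert not y.requires_grad  # no graph was recorded
            yield y

list(infer([torch.randn(2, 16)]))
```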

def _decode_to_wavs(
self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool
@@ -397,7 +404,7 @@ def _decode_to_wavs(
del_all(x)
return wavs

-    def _gen_gpt_inputs(self, text: str, device="cpu"):
+    def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

gpt = self.gpt
tokenizer = self.pretrain_models["tokenizer"]
@@ -407,10 +414,14 @@ def _gen_gpt_inputs(self, text: str, device="cpu"):
        )
        text_token = text_token_tmp.to(device)
        del text_token_tmp
-        input_ids = text_token["input_ids"][..., None].expand(-1, -1, gpt.num_vq)
+
+        input_ids = text_token["input_ids"].unsqueeze(-1).expand(-1, -1, gpt.num_vq)
        text_mask = torch.ones(text_token["input_ids"].shape, dtype=bool, device=device)
+        attention_mask = text_token["attention_mask"]
+
+        del_all(text_token)

-        return input_ids, text_token, text_mask
+        return input_ids, attention_mask, text_mask
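Note: the renamed `_text_to_token` returns plain tensors and frees the tokenizer's `BatchEncoding` itself via `del_all(text_token)`, so callers no longer keep the whole dict alive just to read `attention_mask`. A rough standalone equivalent, assuming a Hugging Face tokenizer ("gpt2" is only a placeholder checkpoint and `num_vq` an assumed value):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
num_vq = 4  # number of vector-quantizer codebooks

def text_to_token(text, device="cpu"):
    enc = tokenizer(text, return_tensors="pt", padding=True).to(device)
    # Replicate each token id across the codebook heads:
    # (batch, seq) -> (batch, seq, num_vq).
    input_ids = enc["input_ids"].unsqueeze(-1).expand(-1, -1, num_vq)
    text_mask = torch.ones(enc["input_ids"].shape, dtype=bool, device=device)
    return input_ids, enc["attention_mask"], text_mask

ids, attn, mask = text_to_token(["hello there", "hi"])
print(ids.shape, attn.shape, mask.shape)  # (batch, seq, num_vq), (batch, seq), (batch, seq)
```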

def _apply_spk_emb(
self,
@@ -419,14 +430,12 @@ def _apply_spk_emb(
input_ids: torch.Tensor,
text_len: int,
):

-        tokenizer = self.pretrain_models["tokenizer"]
-
        n = F.normalize(
-            spk_emb.to(emb.dtype)[None].expand(text_len, -1), p=2.0, dim=1, eps=1e-12
-        ).to(self.gpt.device_gpt)
-        emb[input_ids[..., 0] == tokenizer.convert_tokens_to_ids("[spk_emb]")] = n
-        del n
+            spk_emb.unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
+        ).to(self.gpt.device_gpt).expand(emb.shape)
+        cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
+        torch.where(cond, n, emb, out=emb)
+        del cond, n
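Note: the speaker-embedding write switches from boolean-mask advanced indexing (`emb[mask] = n`, which builds index lists on the fly) to a broadcast compare plus `torch.where`; passing `out=emb` makes the write in-place as the PR does. A self-contained equivalence check with made-up shapes:

```python
import torch

B, T, H = 2, 5, 8
emb = torch.randn(B, T, H)
input_ids = torch.randint(0, 10, (B, T, 4))  # (batch, seq, num_vq)
spk_emb_id = 3
v = torch.randn(H)  # stand-in for the normalized speaker embedding

# Old style: boolean-mask advanced indexing.
ref = emb.clone()
ref[input_ids[..., 0] == spk_emb_id] = v

# New style: broadcast condition + torch.where.
cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_id).expand(emb.shape)
new = torch.where(cond, v.expand(emb.shape), emb)

print(torch.allclose(ref, new))  # True
```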

def _infer_code(
self,
@@ -457,7 +466,7 @@ def _infer_code(
else:
text = [f"[Stts][empty_spk]{i}[Ptts]" for i in text]

-        input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)
+        input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)

emb = gpt(input_ids, text_mask)
del text_mask
@@ -479,7 +488,7 @@ def _infer_code(
input_ids,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
-            attention_mask=text_token["attention_mask"],
+            attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers=logits_warpers,
Expand All @@ -490,8 +499,7 @@ def _infer_code(
context=self.context,
)

-        del_all(text_token)
-        del emb, text_token, input_ids
+        del emb, input_ids
del_all(logits_warpers)
del_all(logits_processors)

@@ -505,17 +513,16 @@ def _refine_text(
):

gpt = self.gpt
-        tokenizer = self.pretrain_models["tokenizer"]

if not isinstance(text, list):
text = [text]

text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]

-        input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)
+        input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)

logits_warpers, logits_processors = gen_logits(
-            num_code=len(tokenizer),
+            num_code=self.tokenizer_len,
top_P=params.top_P,
top_K=params.top_K,
repetition_penalty=params.repetition_penalty,
@@ -528,10 +535,8 @@ def _refine_text(
emb,
input_ids,
temperature=torch.tensor([params.temperature], device=device),
-            eos_token=torch.tensor(
-                tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt
-            )[None],
-            attention_mask=text_token["attention_mask"],
+            eos_token=self.tokenizer_eos_token,
+            attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers=logits_warpers,
Expand All @@ -541,8 +546,7 @@ def _refine_text(
context=self.context,
)

-        del_all(text_token)
-        del emb, text_token, input_ids
+        del emb, input_ids
del_all(logits_warpers)
del_all(logits_processors)

10 changes: 5 additions & 5 deletions ChatTTS/model/gpt.py
@@ -240,7 +240,7 @@ def _prepare_generation_inputs(
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids.masked_fill_(attention_mask.eq(0), 1)
if past_key_values:
position_ids = position_ids.narrow(1, -input_ids.shape[1], input_ids.shape[1])
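Note: only `==` to `.eq()` changes here, but the surrounding trick deserves a gloss: the cumsum turns a left-padded attention mask into position ids that start at 0 on each row's first real token. Quick repro:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded row
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```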

Expand Down Expand Up @@ -321,7 +321,7 @@ def generate(
inputs_ids: torch.Tensor,
temperature: torch.Tensor,
eos_token: Union[int, torch.Tensor],
-        attention_mask=None,
+        attention_mask: Optional[torch.Tensor]=None,
max_new_token=2048,
min_new_token=0,
logits_warpers: List[LogitsWarper] = [],
@@ -469,14 +469,14 @@ def generate(
if not infer_text:
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
idx_next = idx_next.view(-1, self.num_vq)
-                    finish_or = (idx_next == eos_token).any(1)
+                    finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_tmp = torch.cat(
[inputs_ids, idx_next.unsqueeze_(1)], 1
)
else:
-                    finish_or = (idx_next == eos_token).any(1)
+                    finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_tmp = torch.cat(
@@ -497,7 +497,7 @@ def generate(
if stream:
if (
end_idx.all()
-                        and (end_idx % 24 == 0).any()
+                        and end_idx.fmod(24).eq(0).any()
and minus_prev_end_index.add_(end_idx).any()
):
self.logger.debug("yield stream result, end: %d", end_idx)
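Note: `end_idx.fmod(24).eq(0)` is the method spelling of `(end_idx % 24 == 0)`; the two only diverge for negative operands, and `end_idx` is non-negative here. Quick check:

```python
import torch

end_idx = torch.tensor([0, 7, 24, 48])
assert torch.equal(end_idx % 24 == 0, end_idx.fmod(24).eq(0))
```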
1 change: 0 additions & 1 deletion examples/ipynb/colab.ipynb
@@ -310,7 +310,6 @@
"\n",
"wav = chat.infer(\n",
" \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n",
" params_refine_text=params_refine_text,\n",
" params_infer_code=params_infer_code,\n",
")"
]
1 change: 0 additions & 1 deletion examples/ipynb/example.ipynb
@@ -253,7 +253,6 @@
"\n",
"wav = chat.infer(\n",
" \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n",
" params_refine_text=params_refine_text,\n",
" params_infer_code=params_infer_code,\n",
")"
]
37 changes: 18 additions & 19 deletions examples/web/funcs.py
@@ -95,43 +95,34 @@ def reload_chat(coef: Optional[str]) -> str:
    return chat.coef


-def set_generate_buttons(generate_button, interrupt_button, is_reset=False):
+def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
    return gr.update(
        value=generate_button, visible=is_reset, interactive=is_reset
    ), gr.update(value=interrupt_button, visible=not is_reset, interactive=not is_reset)


def refine_text(
-    text, text_seed_input, refine_text_flag, generate_button, interrupt_button
+    text, text_seed_input, refine_text_flag,
):
-    global chat, has_interrupted
-    has_interrupted = False
+    global chat

    if not refine_text_flag:
-        sleep(1)  # to skip fast answer of loading mark
-        return text, *set_generate_buttons(
-            generate_button, interrupt_button, is_reset=True
-        )
+        return text

    with TorchSeedContext(text_seed_input):
        text = chat.infer(
            text,
            skip_refine_text=False,
            refine_text_only=True,
        )
-    return text[0] if isinstance(text, list) else text, *set_generate_buttons(
-        generate_button, interrupt_button, is_reset=True
-    )
-
-
-def text_output_listener(generate_button, interrupt_button):
-    return set_generate_buttons(generate_button, interrupt_button)
-
+    return text[0] if isinstance(text, list) else text


def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
    global chat, has_interrupted

-    if not text or text == "𝕃𝕠𝕒𝕕𝕚𝕟𝕘..." or has_interrupted:
+    if not text or has_interrupted:
        return None

    with TorchSeedContext(audio_seed_input):
@@ -157,9 +148,8 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
        if audio is not None and len(audio) > 0:
            yield 24000, unsafe_float_to_int16(audio[0])
            del audio
-        return
-
-    yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())
+    else:
+        yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())
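Note: after this change both branches yield `(sample_rate, int16_array)` tuples, the shape a streaming Gradio `Audio` output consumes, instead of the non-stream path falling through a bare `return`. A sketch of that yield contract; the body of `unsafe_float_to_int16` below is an assumption made to keep the example runnable, not the repo's implementation:

```python
import numpy as np

def unsafe_float_to_int16(audio: np.ndarray) -> np.ndarray:
    # "unsafe" in the sense that audio is assumed to lie within [-1, 1].
    return (audio * 32767.0).astype(np.int16)

def stream_chunks(chunks):
    for wav in chunks:  # each wav: float array in [-1, 1]
        yield 24000, unsafe_float_to_int16(np.asarray(wav).flatten())

for sr, pcm in stream_chunks([np.sin(np.linspace(0, 6.28, 2400))]):
    print(sr, pcm.dtype, pcm.shape)  # 24000 int16 (2400,)
```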


def interrupt_generate():
Expand All @@ -168,11 +158,20 @@ def interrupt_generate():
has_interrupted = True
chat.interrupt()

+def set_buttons_before_generate(generate_button, interrupt_button):
+    global has_interrupted
+
+    has_interrupted = False
+
+    return _set_generate_buttons(
+        generate_button,
+        interrupt_button,
+    )

def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
global has_interrupted

-    return set_generate_buttons(
+    return _set_generate_buttons(
generate_button,
interrupt_button,
audio_output is not None or has_interrupted,