Commit

optimize: specify infer parameters (#422)
and move infer into core
fumiama authored Jun 24, 2024
1 parent a9cb840 commit 8235a46
Showing 5 changed files with 270 additions and 240 deletions.
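In practice, the change replaces the old loose dict arguments (params_refine_text={}, params_infer_code={'prompt':'[speed_5]'}) with typed dataclasses. A minimal usage sketch, not part of this commit — Chat and load_models() come from the repository, while accessing the dataclasses as Chat attributes is an assumption based on their placement in this diff:

import ChatTTS

chat = ChatTTS.Chat()
chat.load_models()  # download/load weights with default settings

texts = ["ChatTTS is a generative speech model for daily dialogue."]

# With no parameter objects passed, the dataclass defaults apply.
wavs = chat.infer(texts)

# Per-stage sampling can be overridden through the new typed objects
# (assumption: RefineTextParams/InferCodeParams are exposed on Chat).
refine = ChatTTS.Chat.RefineTextParams(temperature=0.5, top_P=0.4)
code = ChatTTS.Chat.InferCodeParams(prompt='[speed_5]', temperature=0.3)
wavs = chat.infer(texts, params_refine_text=refine, params_infer_code=code)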
321 changes: 247 additions & 74 deletions ChatTTS/core.py
@@ -1,19 +1,23 @@
import os
import logging
import tempfile
from dataclasses import dataclass
from typing import Literal, Optional, List, Callable, Tuple
from functools import lru_cache

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from vocos import Vocos
from huggingface_hub import snapshot_download
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper

from .model.dvae import DVAE
from .model.gpt import GPT
from .utils.gpu import select_device
from .model.processors import CustomRepetitionPenaltyLogitsProcessorRepeat
from .utils.io import get_latest_modified_file, del_all
from .utils.dl import check_all_assets, download_all_assets
from .utils.log import logger as utils_logger

@@ -103,6 +107,74 @@ def load_models(
device=device, compile=compile, coef=coef,
**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()},
)

def unload(self):
logger = self.logger
del_all(self.pretrain_models)
self.normalizer.destroy()
del self.normalizer
self._gen_logits.cache_clear()
del_list = ["vocos", "_vocos_decode", 'gpt', 'decoder', 'dvae']
for module in del_list:
if hasattr(self, module):
delattr(self, module)
self.__init__(logger)

def sample_random_speaker(self):
dim = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
std, mean = self.pretrain_models['spk_stat'].chunk(2)
return torch.randn(dim, device=std.device) * std + mean

@dataclass(repr=False, eq=False)
class RefineTextParams():
prompt: str = ''
top_P: float = 0.7
top_K: int = 20
temperature: float = 0.7
repetition_penalty: float = 1.0
max_new_token: int = 384
min_new_token: int = 0

@dataclass(repr=False, eq=False)
class InferCodeParams():
prompt: str = '[speed_5]'
spk_emb: Optional[torch.Tensor] = None
top_P: float = 0.7
top_K: int = 20
temperature: float = 0.3
repetition_penalty: float = 1.05
max_new_token: int = 2048
min_new_token: int = 0

def infer(
self,
text,
stream=False,
lang=None,
skip_refine_text=False,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text = RefineTextParams(),
params_infer_code = InferCodeParams(),
):
res_gen = self._infer(
text,
stream,
lang,
skip_refine_text,
refine_text_only,
use_decoder,
do_text_normalization,
do_homophone_replacement,
params_refine_text,
params_infer_code,
)
if stream:
return res_gen
else:
return next(res_gen)

def _load(
self,
@@ -185,34 +257,23 @@ def _load(
self.coef = coef

return self.has_loaded()


def _infer(
self,
text,
stream=False,
lang=None,
skip_refine_text=False,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
params_refine_text = RefineTextParams(),
params_infer_code = InferCodeParams(),
):

assert self.has_loaded(use_decoder=use_decoder)

if not isinstance(text, list):
text = [text]

@@ -221,11 +282,8 @@ def _infer(
) for t in text]

if not skip_refine_text:
refined = self._refine_text(
text, self.device, params_refine_text,
)
text_tokens = refined.ids
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
@@ -235,57 +293,14 @@ def _infer(
yield text
return

length = [0 for _ in range(len(text))]
for result in self._infer_code(
text, stream, self.device, use_decoder, params_infer_code,
):
wav = self._decode_to_wavs(result, length, use_decoder)
yield wav

def _decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool):
x = result.hiddens if use_decoder else result.ids
wavs: List[np.ndarray] = []
for i, chunk_data in enumerate(x):
@@ -304,3 +319,161 @@ def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool):
result.destroy()
del_all(x)
return wavs

def _gen_gpt_inputs(self, text: str, device="cpu"):

gpt = self.gpt
tokenizer = self.pretrain_models['tokenizer']

text_token_tmp = tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
text_token = text_token_tmp.to(device)
del text_token_tmp
input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq)
text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)

return input_ids, text_token, text_mask

@lru_cache
def _gen_logits(
self,
num_code: int,
top_P = 0.7,
top_K = 20,
repetition_penalty = 1.0,
):
logits_warpers = []
if top_P is not None:
logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
if top_K is not None:
logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

logits_processors = []
if repetition_penalty is not None and repetition_penalty != 1:
logits_processors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
repetition_penalty, num_code, 16))

return logits_warpers, logits_processors

def _apply_spk_emb(
self,
emb: torch.Tensor,
spk_emb: torch.Tensor,
input_ids: torch.Tensor,
text_len: int,
):

tokenizer = self.pretrain_models['tokenizer']

n = F.normalize(spk_emb.to(emb.dtype)[None].expand(text_len, -1), p=2.0, dim=1, eps=1e-12).to(self.gpt.device_gpt)
emb[input_ids[..., 0] == tokenizer.convert_tokens_to_ids('[spk_emb]')] = n
del n

def _infer_code(
self,
text: Tuple[List[str], str],
stream: bool,
device: torch.device,
return_hidden: bool,
params: InferCodeParams,
):

gpt = self.gpt

if not isinstance(text, list):
text = [text]

assert len(text), 'text should not be empty'

if not isinstance(params.temperature, list):
temperature = [params.temperature] * gpt.num_vq
else:
temperature = params.temperature

if params.prompt:
text = [params.prompt + i for i in text]

if params.spk_emb is not None:
text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
else:
text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]

input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)

emb = gpt(input_ids, text_mask)
del text_mask

if params.spk_emb is not None:
self._apply_spk_emb(emb, params.spk_emb, input_ids, len(text))

num_code = int(gpt.emb_code[0].num_embeddings - 1)

logits_warpers, logits_processors = self._gen_logits(
num_code=num_code,
top_P=params.top_P,
top_K=params.top_K,
repetition_penalty=params.repetition_penalty,
)

result = gpt.generate(
emb, input_ids,
temperature = torch.tensor(temperature, device=device),
eos_token = num_code,
attention_mask = text_token['attention_mask'],
max_new_token = params.max_new_token,
min_new_token = params.min_new_token,
logits_warpers = logits_warpers,
logits_processors = logits_processors,
infer_text = False,
return_hidden=return_hidden,
stream = stream,
)

del_all(text_token)
del emb, text_token, input_ids

return result

def _refine_text(
self,
text: str,
device: torch.device,
params: RefineTextParams,
):

gpt = self.gpt
tokenizer = self.pretrain_models['tokenizer']

if not isinstance(text, list):
text = [text]

text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]

input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)

logits_warpers, logits_processors = self._gen_logits(
num_code=len(tokenizer),
top_P=params.top_P,
top_K=params.top_K,
repetition_penalty=params.repetition_penalty,
)

emb = gpt(input_ids, text_mask)
del text_mask

result = gpt.generate(
emb, input_ids,
temperature = torch.tensor([params.temperature], device=device),
eos_token = torch.tensor(tokenizer.convert_tokens_to_ids('[Ebreak]'), device=gpt.device_gpt)[None],
attention_mask = text_token['attention_mask'],
max_new_token = params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers = logits_warpers,
logits_processors = logits_processors,
infer_text = True,
stream = False,
)

del_all(text_token)
del emb, text_token, input_ids

return next(result)
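Continuing the sketch above: based on the control flow in infer() (stream=True returns the generator itself) and the [spk_emb] branch in _infer_code, streaming with a sampled speaker would look roughly like this. The soundfile writer, the reshape, and the 24 kHz rate are assumptions, not part of this diff:

import numpy as np
import soundfile as sf  # assumption: any WAV writer would do

spk = chat.sample_random_speaker()  # random speaker embedding (defined above)
params = ChatTTS.Chat.InferCodeParams(spk_emb=spk)

chunks = []
for wavs in chat.infer(texts, stream=True, params_infer_code=params):
    chunks.append(np.asarray(wavs[0]).reshape(-1))  # one chunk per input text

audio = np.concatenate(chunks)        # assumption: chunks are contiguous audio
sf.write("output.wav", audio, 24000)  # assumption: ChatTTS decodes at 24 kHz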