diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index d5d11244f..02d12e8c5 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -5,6 +5,7 @@
 from typing import Literal, Optional, List, Callable, Tuple, Dict
 from json import load
 from pathlib import Path
+import lzma
 
 import numpy as np
 import torch
@@ -12,6 +13,7 @@
 from omegaconf import OmegaConf
 from vocos import Vocos
 from huggingface_hub import snapshot_download
+import pybase16384 as b14
 
 from .model import DVAE, GPT, gen_logits
 from .utils import (
@@ -151,10 +153,28 @@ def unload(self):
             delattr(self, module)
         self.__init__(logger)
 
-    def sample_random_speaker(self):
-        dim = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
-        std, mean = self.pretrain_models["spk_stat"].chunk(2)
-        return torch.randn(dim, device=std.device) * std + mean
+    def sample_random_speaker(self) -> str:
+        with torch.no_grad():
+            spk = self._sample_random_speaker()
+            arr: np.ndarray = spk.cpu().numpy()
+            s = b14.encode_to_string(
+                lzma.compress(
+                    arr.tobytes(),
+                    format=lzma.FORMAT_RAW,
+                    filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
+                ),
+            )
+            del arr, spk
+            return s
+
+    def _sample_random_speaker(self) -> torch.Tensor:
+        with torch.no_grad():
+            dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
+            out: torch.Tensor = self.pretrain_models["spk_stat"]
+            std, mean = out.chunk(2)
+            spk = torch.randn(dim, device=std.device, dtype=torch.float16).mul_(std).add_(mean)
+            del out, std, mean
+            return spk
 
     @dataclass(repr=False, eq=False)
     class RefineTextParams:
@@ -169,7 +189,7 @@ class RefineTextParams:
     @dataclass(repr=False, eq=False)
     class InferCodeParams:
         prompt: str = "[speed_5]"
-        spk_emb: Optional[torch.Tensor] = None
+        spk_emb: Optional[str] = None
         top_P: float = 0.7
         top_K: int = 20
         temperature: float = 0.3
@@ -426,12 +446,18 @@ def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.Tensor]:
     def _apply_spk_emb(
         self,
         emb: torch.Tensor,
-        spk_emb: torch.Tensor,
+        spk_emb: str,
         input_ids: torch.Tensor,
         text_len: int,
     ):
         n = F.normalize(
-            spk_emb.unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
+            torch.from_numpy(
+                np.frombuffer(lzma.decompress(
+                    b14.decode_from_string(spk_emb),
+                    format=lzma.FORMAT_RAW,
+                    filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
+                ), dtype=np.float16).copy(),
+            ).unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
         ).to(self.gpt.device_gpt).expand(emb.shape)
         cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
         torch.where(cond, n, emb, out=emb)
diff --git a/README.md b/README.md
index 8b261f398..f2158e2c0 100644
--- a/README.md
+++ b/README.md
@@ -131,6 +131,7 @@
 torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
 
 # Sample a speaker from Gaussian.
 rand_spk = chat.sample_random_speaker()
+print(rand_spk) # save it for later timbre recovery
 params_infer_code = ChatTTS.Chat.InferCodeParams(
     spk_emb = rand_spk, # add sampled speaker
diff --git a/docs/cn/README.md b/docs/cn/README.md
index 6f1d6ae57..fb45ed2f0 100644
--- a/docs/cn/README.md
+++ b/docs/cn/README.md
@@ -142,6 +142,7 @@
 torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
 
 # Sample a speaker from Gaussian.
 rand_spk = chat.sample_random_speaker()
+print(rand_spk) # save it for later timbre recovery
 params_infer_code = {
     'spk_emb': rand_spk, # add sampled speaker
diff --git a/docs/es/README.md b/docs/es/README.md
index daf2da526..ec479be3c 100644
--- a/docs/es/README.md
+++ b/docs/es/README.md
@@ -139,6 +139,7 @@
 torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
 
 # Sample a speaker from Gaussian.
 rand_spk = chat.sample_random_speaker()
+print(rand_spk) # save it for later timbre recovery
 params_infer_code = ChatTTS.Chat.InferCodeParams(
     spk_emb = rand_spk, # add sampled speaker
diff --git a/docs/jp/README.md b/docs/jp/README.md
index 0f40c3376..51c6888d4 100644
--- a/docs/jp/README.md
+++ b/docs/jp/README.md
@@ -53,6 +53,7 @@
 torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
 
 # Sample a speaker from the Gaussian distribution.
 rand_spk = chat.sample_random_speaker()
+print(rand_spk) # save it for later timbre recovery
 params_infer_code = {
     'spk_emb': rand_spk, # add the sampled speaker
diff --git a/docs/ru/README.md b/docs/ru/README.md
index d0d62b61d..40ded2752 100644
--- a/docs/ru/README.md
+++ b/docs/ru/README.md
@@ -53,6 +53,7 @@
 torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
 
 # Sample a speaker from the Gaussian.
 rand_spk = chat.sample_random_speaker()
+print(rand_spk) # save it for later timbre recovery
 params_infer_code = {
     'spk_emb': rand_spk, # add the selected speaker
diff --git a/examples/ipynb/colab.ipynb b/examples/ipynb/colab.ipynb
index ab11fec72..78899c800 100644
--- a/examples/ipynb/colab.ipynb
+++ b/examples/ipynb/colab.ipynb
@@ -304,6 +304,8 @@
    "outputs": [],
    "source": [
     "rand_spk = chat.sample_random_speaker()\n",
+    "print(rand_spk) # save it for later timbre recovery\n",
+    "\n",
     "params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
     "    spk_emb=rand_spk,\n",
     ")\n",
diff --git a/examples/ipynb/example.ipynb b/examples/ipynb/example.ipynb
index 20e4e316a..c31844e71 100644
--- a/examples/ipynb/example.ipynb
+++ b/examples/ipynb/example.ipynb
@@ -247,6 +247,8 @@
    "outputs": [],
    "source": [
     "rand_spk = chat.sample_random_speaker()\n",
+    "print(rand_spk) # save it for later timbre recovery\n",
+    "\n",
     "params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
     "    spk_emb=rand_spk,\n",
     ")\n",
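
Note: after this patch, spk_emb travels as a plain string rather than a torch.Tensor: the float16 speaker tensor is serialized to bytes, compressed as a raw LZMA2 stream, and rendered as base16384 text. Below is a minimal round-trip sketch of that scheme, built only from the calls visible in the hunks above; the helper names encode_spk_emb and decode_spk_emb are illustrative and not part of the patch.

import lzma

import numpy as np
import pybase16384 as b14
import torch

# Same raw-LZMA2 filter chain used by sample_random_speaker/_apply_spk_emb above.
LZMA_FILTERS = [{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}]


def encode_spk_emb(spk: torch.Tensor) -> str:
    # float16 tensor -> raw bytes -> raw LZMA2 stream -> base16384 string
    arr = spk.to(dtype=torch.float16, device="cpu").numpy()
    return b14.encode_to_string(
        lzma.compress(arr.tobytes(), format=lzma.FORMAT_RAW, filters=LZMA_FILTERS)
    )


def decode_spk_emb(s: str) -> torch.Tensor:
    # Inverse path: base16384 string -> raw LZMA2 stream -> float16 tensor.
    # .copy() gives a writable buffer, which torch.from_numpy requires.
    raw = lzma.decompress(
        b14.decode_from_string(s), format=lzma.FORMAT_RAW, filters=LZMA_FILTERS
    )
    return torch.from_numpy(np.frombuffer(raw, dtype=np.float16).copy())


# Every step is lossless, so the decoded embedding matches the original exactly.
emb = torch.randn(768).to(torch.float16)  # 768 is a stand-in embedding dimension
assert torch.equal(decode_spk_emb(encode_spk_emb(emb)), emb)

Because the string format is self-contained, the value printed by the new print(rand_spk) lines in the READMEs can be stored in a plain text file and passed back later as InferCodeParams(spk_emb=...) to recover the same timbre.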