Skip to content

Commit

Permalink
feat: use str type spk_emb for easy recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jun 26, 2024
1 parent a5ccaf8 commit d54a80c
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 7 deletions.
40 changes: 33 additions & 7 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from typing import Literal, Optional, List, Callable, Tuple, Dict
from json import load
from pathlib import Path
import lzma

import numpy as np
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from vocos import Vocos
from huggingface_hub import snapshot_download
import pybase16384 as b14

from .model import DVAE, GPT, gen_logits
from .utils import (
Expand Down Expand Up @@ -151,10 +153,28 @@ def unload(self):
delattr(self, module)
self.__init__(logger)

def sample_random_speaker(self):
dim = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
std, mean = self.pretrain_models["spk_stat"].chunk(2)
return torch.randn(dim, device=std.device) * std + mean
def sample_random_speaker(self) -> str:
with torch.no_grad():
spk = self._sample_random_speaker()
arr: np.ndarray = spk.cpu().numpy()
s = b14.encode_to_string(
lzma.compress(
arr.tobytes(),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
)
del arr, spk
return s

def _sample_random_speaker(self) -> torch.Tensor:
with torch.no_grad():
dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
out: torch.Tensor = self.pretrain_models["spk_stat"]
std, mean = out.chunk(2)
spk = torch.randn(dim, device=std.device, dtype=torch.float16).mul_(std).add_(mean)
del out, std, mean
return spk

@dataclass(repr=False, eq=False)
class RefineTextParams:
Expand All @@ -169,7 +189,7 @@ class RefineTextParams:
@dataclass(repr=False, eq=False)
class InferCodeParams:
prompt: str = "[speed_5]"
spk_emb: Optional[torch.Tensor] = None
spk_emb: Optional[str] = None
top_P: float = 0.7
top_K: int = 20
temperature: float = 0.3
Expand Down Expand Up @@ -426,12 +446,18 @@ def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.T
def _apply_spk_emb(
self,
emb: torch.Tensor,
spk_emb: torch.Tensor,
spk_emb: str,
input_ids: torch.Tensor,
text_len: int,
):
n = F.normalize(
spk_emb.unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
torch.from_numpy(
np.frombuffer(lzma.decompress(
b14.decode_from_string(spk_emb),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
), dtype=np.float16).copy(),
).unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
).to(self.gpt.device_gpt).expand(emb.shape)
cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
torch.where(cond, n, emb, out=emb)
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Sample a speaker from Gaussian.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb = rand_spk, # add sampled speaker
Expand Down
1 change: 1 addition & 0 deletions docs/cn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Sample a speaker from Gaussian.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = {
'spk_emb': rand_spk, # add sampled speaker
Expand Down
1 change: 1 addition & 0 deletions docs/es/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Sample a speaker from Gaussian.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb = rand_spk, # add sampled speaker
Expand Down
1 change: 1 addition & 0 deletions docs/jp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# ガウス分布から話者をサンプリングします。

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = {
'spk_emb': rand_spk, # サンプリングされた話者を追加
Expand Down
1 change: 1 addition & 0 deletions docs/ru/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000)
# Выборка говорящего из Гауссиана.

rand_spk = chat.sample_random_speaker()
print(rand_spk) # save it for later timbre recovery

params_infer_code = {
'spk_emb': rand_spk, # добавить выбранного говорящего
Expand Down
2 changes: 2 additions & 0 deletions examples/ipynb/colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
"print(rand_spk) # save it for later timbre recovery\n",
"\n",
"params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
" spk_emb=rand_spk,\n",
")\n",
Expand Down
2 changes: 2 additions & 0 deletions examples/ipynb/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
"print(rand_spk) # save it for later timbre recovery\n",
"\n",
"params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
" spk_emb=rand_spk,\n",
")\n",
Expand Down

0 comments on commit d54a80c

Please sign in to comment.