chore(format): run black on main (#475)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
github-actions[bot] authored Jun 27, 2024
1 parent bb9d3b3 commit 52d75d5
Showing 9 changed files with 123 additions and 57 deletions.
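Note: this is a formatting-only commit; black rewraps code to its default 88-character line length without changing behavior. A minimal sketch of the formatter at work via black's Python API (hypothetical snippet; the repository's CI may pin a specific black version and invoke it differently):

import black

# black's defaults (line length 88, magic trailing comma) produce exactly the
# "ex=[" -> "ex = [" rewrite seen in examples/web/ex.py below
src = 'ex=[\n    "item one",\n]\n'
print(black.format_str(src, mode=black.Mode()), end="")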
116 changes: 76 additions & 40 deletions ChatTTS/core.py
@@ -72,10 +72,15 @@ def download_models(
) -> Optional[str]:
if source == "local":
download_path = os.getcwd()
if not check_all_assets(Path(download_path), self.sha256_map, update=True) or force_redownload:
if (
not check_all_assets(Path(download_path), self.sha256_map, update=True)
or force_redownload
):
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(Path(download_path), self.sha256_map, update=False):
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
):
self.logger.error(
"download to local path %s failed.", download_path
)
@@ -109,9 +114,7 @@ def download_models(
elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
self.logger.error(
"check models in custom path %s failed.", custom_path
)
self.logger.error("check models in custom path %s failed.", custom_path)
return None
download_path = custom_path

@@ -164,7 +167,9 @@ def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
lzma.compress(
arr.tobytes(),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
filters=[
{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
],
),
)
del arr
@@ -175,7 +180,11 @@ def _sample_random_speaker(self) -> torch.Tensor:
dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
out: torch.Tensor = self.pretrain_models["spk_stat"]
std, mean = out.chunk(2)
spk = torch.randn(dim, device=std.device, dtype=torch.float16).mul_(std).add_(mean)
spk = (
torch.randn(dim, device=std.device, dtype=torch.float16)
.mul_(std)
.add_(mean)
)
del out, std, mean
return spk

@@ -331,8 +340,12 @@ def _load(
tokenizer.padding_side = "left"
self.pretrain_models["tokenizer"] = tokenizer
self.tokenizer_len = len(tokenizer)
self.tokenizer_spk_emb_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[spk_emb]")
self.tokenizer_break_0_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[break_0]")
self.tokenizer_spk_emb_ids: torch.Tensor = tokenizer.convert_tokens_to_ids(
"[spk_emb]"
)
self.tokenizer_break_0_ids: torch.Tensor = tokenizer.convert_tokens_to_ids(
"[break_0]"
)
self.tokenizer_eos_token: torch.Tensor = torch.tensor(
tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt
).unsqueeze_(0)
@@ -381,8 +394,7 @@ def _infer(
)
text_tokens = refined.ids
text_tokens = [
i[i.less(self.tokenizer_break_0_ids)]
for i in text_tokens
i[i.less(self.tokenizer_break_0_ids)] for i in text_tokens
]
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
refined.destroy()
@@ -423,7 +435,9 @@ def _decode_to_wavs(
del_all(x)
return wavs

def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _text_to_token(
self, text: str, device="cpu"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

gpt = self.gpt
tokenizer = self.pretrain_models["tokenizer"]
@@ -441,14 +455,17 @@ def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.T
del_all(text_token)

return input_ids, attention_mask, text_mask

@staticmethod
def _decode_spk_emb(spk_emb: str) -> np.ndarray:
return np.frombuffer(lzma.decompress(
b14.decode_from_string(spk_emb),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
), dtype=np.float16).copy()
return np.frombuffer(
lzma.decompress(
b14.decode_from_string(spk_emb),
format=lzma.FORMAT_RAW,
filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
),
dtype=np.float16,
).copy()

def _apply_spk_emb(
self,
@@ -457,12 +474,24 @@ def _apply_spk_emb(
input_ids: torch.Tensor,
text_len: int,
):
n = F.normalize(
torch.from_numpy(
self._decode_spk_emb(spk_emb),
).unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
).to(self.gpt.device_gpt).unsqueeze_(1).expand(emb.shape)
cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
n = (
F.normalize(
torch.from_numpy(
self._decode_spk_emb(spk_emb),
)
.unsqueeze(0)
.expand(text_len, -1),
p=2.0,
dim=1,
eps=1e-12,
)
.to(self.gpt.device_gpt)
.unsqueeze_(1)
.expand(emb.shape)
)
cond = (
input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
)
torch.where(cond, n, emb, out=emb)
del cond, n

@@ -486,9 +515,14 @@ def _infer_code(
temperature = [params.temperature] * gpt.num_vq
else:
temperature = params.temperature

for i, t in enumerate(text):
text[i] = t.replace('[Stts]', '').replace('[spk_emb]', '').replace('[empty_spk]', '').strip()
text[i] = (
t.replace("[Stts]", "")
.replace("[spk_emb]", "")
.replace("[empty_spk]", "")
.strip()
)
"""
see https://github.com/2noise/ChatTTS/issues/459
"""
@@ -566,20 +600,22 @@ def _refine_text(
emb = gpt(input_ids, text_mask)
del text_mask

result = next(gpt.generate(
emb,
input_ids,
temperature=torch.tensor([params.temperature], device=device),
eos_token=self.tokenizer_eos_token,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers=logits_warpers,
logits_processors=logits_processors,
infer_text=True,
stream=False,
context=self.context,
))
result = next(
gpt.generate(
emb,
input_ids,
temperature=torch.tensor([params.temperature], device=device),
eos_token=self.tokenizer_eos_token,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers=logits_warpers,
logits_processors=logits_processors,
infer_text=True,
stream=False,
context=self.context,
)
)

del emb, input_ids
del_all(logits_warpers)
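A quick note on the _encode_spk_emb / _decode_spk_emb hunks in the core.py diff above: the two methods are exact inverses — the float16 speaker embedding is compressed with raw LZMA2 at preset 9 | PRESET_EXTREME, then rendered as text with the b14 codec (visible on the decode side). A self-contained round-trip sketch with the same lzma settings, substituting standard base64 for b14 and an assumed embedding size:

import base64
import lzma

import numpy as np

FILTERS = [{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}]

spk = np.random.randn(768).astype(np.float16)  # 768 is an assumed dimension
encoded = base64.b64encode(
    lzma.compress(spk.tobytes(), format=lzma.FORMAT_RAW, filters=FILTERS)
).decode()
decoded = np.frombuffer(
    lzma.decompress(base64.b64decode(encoded), format=lzma.FORMAT_RAW, filters=FILTERS),
    dtype=np.float16,
).copy()  # .copy() makes the buffer-backed array writable, as in core.py
assert np.array_equal(spk, decoded)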
18 changes: 13 additions & 5 deletions ChatTTS/model/gpt.py
@@ -226,7 +226,9 @@ def _prepare_generation_inputs(
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids.narrow(1, past_length, input_ids.size(1)-past_length)
input_ids = input_ids.narrow(
1, past_length, input_ids.size(1) - past_length
)
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
@@ -235,14 +237,18 @@
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask.narrow(1, -max_cache_length, max_cache_length)
attention_mask = attention_mask.narrow(
1, -max_cache_length, max_cache_length
)

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 1)
if past_key_values:
position_ids = position_ids.narrow(1, -input_ids.shape[1], input_ids.shape[1])
position_ids = position_ids.narrow(
1, -input_ids.shape[1], input_ids.shape[1]
)

input_length = (
position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
@@ -321,7 +327,7 @@ def generate(
inputs_ids: torch.Tensor,
temperature: torch.Tensor,
eos_token: Union[int, torch.Tensor],
attention_mask: Optional[torch.Tensor]=None,
attention_mask: Optional[torch.Tensor] = None,
max_new_token=2048,
min_new_token=0,
logits_warpers: List[LogitsWarper] = [],
@@ -360,7 +366,9 @@ def generate(
device=inputs_ids.device,
)
if attention_mask is not None:
attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(attention_mask)
attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
attention_mask
)

with tqdm(
total=max_new_token,
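One detail worth noting in the _prepare_generation_inputs hunks above: position ids for left-padded batches are built on the fly by cumulative-summing the attention mask, so padding slots never advance the position counter. A tiny standalone illustration (values invented):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])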
5 changes: 4 additions & 1 deletion ChatTTS/model/processors.py
@@ -22,7 +22,9 @@ def __call__(
input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
freq = F.one_hot(input_ids, scores.size(1)).sum(1)
if freq.size(0) > self.max_input_ids:
freq.narrow(0, self.max_input_ids, freq.size(0)-self.max_input_ids).zero_()
freq.narrow(
0, self.max_input_ids, freq.size(0) - self.max_input_ids
).zero_()
alpha = torch.pow(self.penalty, freq)
scores = scores.contiguous()
inp = scores.multiply(alpha)
@@ -32,6 +34,7 @@ def __call__(
del inp, oth, scores, con, alpha
return out


def gen_logits(
num_code: int,
top_P=0.7,
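The processor touched above is a repetition penalty: token frequencies over a recent window are counted with a one-hot sum, and each token's logit is then scaled by the penalty raised to its count. A rough sketch of the counting step alone, with shapes simplified and values invented:

import torch
import torch.nn.functional as F

vocab_size = 6
input_ids = torch.tensor([[2, 2, 5, 1]])        # recent window, batch of 1
freq = F.one_hot(input_ids, vocab_size).sum(1)  # (1, vocab_size) counts
print(freq)  # tensor([[0, 1, 2, 0, 0, 1]]) -> token 2 was seen twice
alpha = torch.pow(torch.tensor(1.2), freq)      # per-token penalty factors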
2 changes: 1 addition & 1 deletion examples/ipynb/colab.ipynb
@@ -304,7 +304,7 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
"print(rand_spk) # save it for later timbre recovery\n",
"print(rand_spk) # save it for later timbre recovery\n",
"\n",
"params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
" spk_emb=rand_spk,\n",
2 changes: 1 addition & 1 deletion examples/ipynb/example.ipynb
@@ -247,7 +247,7 @@
"outputs": [],
"source": [
"rand_spk = chat.sample_random_speaker()\n",
"print(rand_spk) # save it for later timbre recovery\n",
"print(rand_spk) # save it for later timbre recovery\n",
"\n",
"params_infer_code = ChatTTS.Chat.InferCodeParams(\n",
" spk_emb=rand_spk,\n",
2 changes: 1 addition & 1 deletion examples/web/ex.py
@@ -1,4 +1,4 @@
ex=[
ex = [
[
"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。",
0.3,
10 changes: 8 additions & 2 deletions examples/web/funcs.py
@@ -49,6 +49,7 @@ def generate_seed():
def on_voice_change(vocie_selection):
return voices.get(vocie_selection)["seed"]


def on_audio_seed_change(audio_seed_input):
with TorchSeedContext(audio_seed_input):
rand_spk = chat.sample_random_speaker()
@@ -117,12 +118,14 @@ def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):


def refine_text(
text, text_seed_input, refine_text_flag,
text,
text_seed_input,
refine_text_flag,
):
global chat

if not refine_text_flag:
sleep(1) # to skip fast answer of loading mark
sleep(1) # to skip fast answer of loading mark
return text

with TorchSeedContext(text_seed_input):
@@ -134,6 +137,7 @@

return text[0] if isinstance(text, list) else text


def generate_audio(text, temperature, top_P, top_K, spk_emb_text: str, stream):
global chat, has_interrupted

@@ -169,6 +173,7 @@ def interrupt_generate():
has_interrupted = True
chat.interrupt()


def set_buttons_before_generate(generate_button, interrupt_button):
global has_interrupted, is_in_generate

@@ -180,6 +185,7 @@ def set_buttons_before_generate(generate_button, interrupt_button):
interrupt_button,
)


def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
global has_interrupted, is_in_generate

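funcs.py wraps both speaker sampling and text refinement in TorchSeedContext so a given seed reproduces the same output. The helper itself is not part of this diff; a hypothetical re-implementation of such a context manager might look like this (the repository's actual version may differ):

import torch

class TorchSeedContext:
    # hypothetical sketch: pin torch's RNG inside the block, then restore it
    def __init__(self, seed: int):
        self.seed = seed

    def __enter__(self):
        self.state = torch.random.get_rng_state()
        torch.manual_seed(self.seed)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.random.set_rng_state(self.state)

# usage: draws inside the block are deterministic for a given seed
with TorchSeedContext(42):
    x = torch.randn(3)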
24 changes: 18 additions & 6 deletions examples/web/webui.py
@@ -57,13 +57,19 @@ def main():
label="Timbre", choices=voices.keys(), value="Default"
)
audio_seed_input = gr.Number(
value=2, label="Audio Seed", interactive=True,
minimum=seed_min, maximum=seed_max,
value=2,
label="Audio Seed",
interactive=True,
minimum=seed_min,
maximum=seed_max,
)
generate_audio_seed = gr.Button("\U0001F3B2")
text_seed_input = gr.Number(
value=42, label="Text Seed", interactive=True,
minimum=seed_min, maximum=seed_max,
value=42,
label="Text Seed",
interactive=True,
minimum=seed_min,
maximum=seed_max,
)
generate_text_seed = gr.Button("\U0001F3B2")

@@ -107,7 +113,9 @@

generate_text_seed.click(generate_seed, outputs=text_seed_input)

audio_seed_input.change(on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
audio_seed_input.change(
on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text
)

reload_chat_button.click(
reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text
@@ -125,7 +133,11 @@ def make_audio(autoplay, stream):
interactive=False,
show_label=True,
)
generate_button.click(fn=set_buttons_before_generate, inputs=[generate_button, interrupt_button], outputs=[generate_button, interrupt_button]).then(
generate_button.click(
fn=set_buttons_before_generate,
inputs=[generate_button, interrupt_button],
outputs=[generate_button, interrupt_button],
).then(
refine_text,
inputs=[
text_input,
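The webui.py hunk above chains Gradio events with .then(), so the button-state update completes before text refinement begins. A minimal sketch of that pattern (handlers invented for illustration):

import gradio as gr

with gr.Blocks() as demo:
    btn = gr.Button("Generate")
    status = gr.Textbox(label="status")
    result = gr.Textbox(label="result")
    # the second handler runs only after the first finishes, mirroring the
    # set_buttons_before_generate -> refine_text chain in webui.py
    btn.click(lambda: "generating...", outputs=status).then(
        lambda: "done", outputs=result
    )

if __name__ == "__main__":
    demo.launch()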
1 change: 1 addition & 0 deletions tools/audio/mp3.py
@@ -6,6 +6,7 @@
from .np import unsafe_float_to_int16
from .av import wav2


def wav_arr_to_mp3_view(wav: np.ndarray):
buf = BytesIO()
with wave.open(buf, "wb") as wf:
