diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 7e246fa1e..0d96905d9 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -407,7 +407,7 @@ def _infer( ): wav = self._decode_to_wavs(result, length, use_decoder) yield wav - + def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray: if "mps" in str(self.device): return self.vocos.decode(spec.cpu()).cpu().numpy() @@ -460,13 +460,27 @@ def _text_to_token( attn_sz = attention_mask_lst[-1].size(0) if attn_sz > max_attention_mask_len: max_attention_mask_len = attn_sz - input_ids = torch.zeros(len(input_ids_lst), max_input_ids_len, device=device, dtype=input_ids_lst[0].dtype) + input_ids = torch.zeros( + len(input_ids_lst), + max_input_ids_len, + device=device, + dtype=input_ids_lst[0].dtype, + ) for i in range(len(input_ids_lst)): - input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(input_ids_lst[i]) + input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_( + input_ids_lst[i] + ) del_all(input_ids_lst) - attention_mask = torch.zeros(len(attention_mask_lst), max_attention_mask_len, device=device, dtype=attention_mask_lst[0].dtype) + attention_mask = torch.zeros( + len(attention_mask_lst), + max_attention_mask_len, + device=device, + dtype=attention_mask_lst[0].dtype, + ) for i in range(len(attention_mask_lst)): - attention_mask.narrow(0, i, 1).narrow(1, 0, attention_mask_lst[i].size(0)).copy_(attention_mask_lst[i]) + attention_mask.narrow(0, i, 1).narrow( + 1, 0, attention_mask_lst[i].size(0) + ).copy_(attention_mask_lst[i]) del_all(attention_mask_lst) text_mask = torch.ones(input_ids.shape, dtype=bool, device=device) diff --git a/examples/cmd/run.py b/examples/cmd/run.py index 9206b3dcb..7acb2796a 100644 --- a/examples/cmd/run.py +++ b/examples/cmd/run.py @@ -35,16 +35,19 @@ def main(texts: List[str], spk: Optional[str] = None): else: logger.error("Models load failed.") sys.exit(1) - + if spk is None: spk = chat.sample_random_speaker() logger.info("Use speaker:") print(spk) logger.info("Start inference.") - wavs = chat.infer(texts, params_infer_code=ChatTTS.Chat.InferCodeParams( - spk_emb=spk, - )) + wavs = chat.infer( + texts, + params_infer_code=ChatTTS.Chat.InferCodeParams( + spk_emb=spk, + ), + ) logger.info("Inference completed.") # Save each generated wav file to a local file for index, wav in enumerate(wavs): @@ -55,9 +58,15 @@ def main(texts: List[str], spk: Optional[str] = None): if __name__ == "__main__": logger.info("Starting ChatTTS commandline demo...") parser = argparse.ArgumentParser( - description="ChatTTS Command", usage='[--spk xxx] "Your text 1." " Your text 2."' + description="ChatTTS Command", + usage='[--spk xxx] "Your text 1." " Your text 2."', + ) + parser.add_argument( + "--spk", + help="Speaker (empty to sample a random one)", + type=Optional[str], + default=None, ) - parser.add_argument("--spk", help="Speaker (empty to sample a random one)", type=Optional[str], default=None) parser.add_argument( "texts", help="Original text", default="YOUR TEXT HERE", nargs="*" )