
Commit feb9aad

feat:support async vllm generator
1 parent 651093e · commit feb9aad

File tree

13 files changed: +735 −191 lines changed

ChatTTS/core.py

Lines changed: 71 additions & 101 deletions
@@ -1,6 +1,7 @@
 import os
 import logging
 import tempfile
+import uuid
 from dataclasses import dataclass, asdict
 from typing import Literal, Optional, List, Tuple, Dict, Union
 from json import load
@@ -173,15 +174,17 @@ class RefineTextParams:
     min_new_token: int = 0
     show_tqdm: bool = True
     ensure_non_empty: bool = True
-    manual_seed: Optional[int] = None
+    manual_seed: Optional[int] = 0

 @dataclass(repr=False, eq=False)
 class InferCodeParams(RefineTextParams):
     prompt: str = "[speed_5]"
     spk_emb: Optional[str] = None
     spk_smp: Optional[str] = None
     txt_smp: Optional[str] = None
-    temperature: float = 0.3
+    top_P: float = 1
+    top_K: int = 1
+    temperature: float = 0.01
     repetition_penalty: float = 1.05
     max_new_token: int = 2048
     stream_batch: int = 24
@@ -193,16 +196,17 @@ def infer(
         text,
         stream=False,
         lang=None,
-        skip_refine_text=False,
+        skip_refine_text=True,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=RefineTextParams(),
-        params_infer_code=InferCodeParams(),
+        params_refine_text=None,
+        params_infer_code=None,
+        stream_batch_size=16,
     ):
         self.context.set(False)
-        res_gen = self._infer(
+        return self._infer(
             text,
             stream,
             lang,
@@ -213,11 +217,8 @@ def infer(
             do_homophone_replacement,
             params_refine_text,
             params_infer_code,
+            stream_batch_size,
         )
-        if stream:
-            return res_gen
-        else:
-            return next(res_gen)

     def interrupt(self):
         self.context.set(True)
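
With this change `infer` no longer drains the generator for non-stream calls; it hands back the async generator produced by `_infer` directly, so callers consume audio with `async for`. A minimal usage sketch, assuming the public API from the project README (`ChatTTS.Chat`, `chat.load`, and the sample text are illustrative, not part of this commit):

```python
import asyncio

import ChatTTS


async def main():
    chat = ChatTTS.Chat()
    chat.load(compile=False)  # assumed loader call, as in the upstream README

    # infer() now returns an async generator: each iteration yields the
    # next decoded wav chunk (a numpy array of shape [1, n_samples]).
    async for wav_chunk in chat.infer("Hello, async streaming."):
        print(f"got chunk with {wav_chunk.shape[1]} samples")


asyncio.run(main())
```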
@@ -272,7 +273,7 @@ def _load(
                 vq_config=asdict(self.config.dvae.vq),
                 dim=self.config.dvae.decoder.idim,
                 coef=coef,
-                device=self.device,
+                device=device,
             )
             .to(device)
             .eval()
@@ -289,8 +290,8 @@ def _load(
             self.config.embed.num_text_tokens,
             self.config.embed.num_vq,
         )
-        embed.from_pretrained(embed_path, device=self.device)
-        self.embed = embed.to(self.device)
+        embed.from_pretrained(embed_path, device=device)
+        self.embed = embed.to(device)
         self.logger.log(logging.INFO, "embed loaded.")

         gpt = GPT(
@@ -318,6 +319,7 @@ def _load(
                 decoder_config=asdict(self.config.decoder),
                 dim=self.config.decoder.idim,
                 coef=coef,
+                device=device,
             )
             .to(device)
             .eval()
@@ -338,18 +340,19 @@ def _load(

         return self.has_loaded()

-    def _infer(
+    async def _infer(
         self,
         text,
-        stream=False,
+        stream=True,
         lang=None,
-        skip_refine_text=False,
+        skip_refine_text=True,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=RefineTextParams(),
-        params_infer_code=InferCodeParams(),
+        params_refine_text=None,
+        params_infer_code=None,
+        stream_batch_size=16,
     ):

         assert self.has_loaded(use_decoder=use_decoder)
@@ -383,41 +386,38 @@ def _infer(
             yield text
             return

-        if stream:
-            length = 0
-            pass_batch_count = 0
-        for result in self._infer_code(
+        length = 0
+        async for result in self._infer_code(
             text,
             stream,
             self.device,
             use_decoder,
             params_infer_code,
+            stream_batch_size,
         ):
             wavs = self._decode_to_wavs(
                 result.hiddens if use_decoder else result.ids,
                 use_decoder,
             )
-            result.destroy()
-            if stream:
-                pass_batch_count += 1
-                if pass_batch_count <= params_infer_code.pass_first_n_batches:
-                    continue
-                a = length
-                b = a + params_infer_code.stream_speed
-                if b > wavs.shape[1]:
-                    b = wavs.shape[1]
-                new_wavs = wavs[:, a:b]
-                length = b
-                yield new_wavs
+
+            if result.finished:
+                yield wavs[:, length:]
             else:
-                yield wavs
-        if stream:
-            new_wavs = wavs[:, length:]
-            # Identify rows with non-zero elements using np.any
-            # keep_rows = np.any(array != 0, axis=1)
-            keep_cols = np.sum(new_wavs != 0, axis=0) > 0
-            # Filter both rows and columns using slicing
-            yield new_wavs[:][:, keep_cols]
+                # Hack: if there is a silence boundary, yield up to the last one; otherwise wait for another loop iteration.
+                import librosa
+
+                silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
+                silence_left = 0
+                if len(silence_intervals) == 0:
+                    silence_left = len(wavs[0])
+                else:
+                    for i in range(len(silence_intervals)):
+                        silence_left = silence_intervals[i][0]
+                if silence_left <= 0:
+                    continue
+                new_wavs = wavs[:, length : length + silence_left]
+                length += len(new_wavs[0])
+                yield new_wavs

     @torch.inference_mode()
     def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
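
The streaming path above cuts each intermediate chunk at a silence boundary so a yield never lands mid-phoneme. `librosa.effects.split` returns the `(start, end)` sample indices of the non-silent intervals, so the start of the last interval is the latest safe cut point. A small self-contained sketch of that behavior (the synthetic signal and sample rate are illustrative; boundaries are frame-quantized, hence approximate):

```python
import librosa
import numpy as np

sr = 24000
t = np.linspace(0, 0.5, sr // 2, endpoint=False)
tone = 0.5 * np.sin(2 * np.pi * 440 * t).astype(np.float32)
silence = np.zeros(sr // 4, dtype=np.float32)
wav = np.concatenate([tone, silence, tone])

# Non-silent intervals: anything more than top_db below the peak
# counts as silence, so the gap between the tones splits the signal.
intervals = librosa.effects.split(wav, top_db=10)
print(intervals)  # approximately [[0, 12000], [18000, 30000]]

# The commit keeps only samples before the start of the last
# non-silent interval, deferring the rest to the next loop.
cut = intervals[-1][0]
print(f"safe cut point: sample {cut} of {len(wav)}")
```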
@@ -456,13 +456,14 @@ def _decode_to_wavs(
         return wavs

     @torch.no_grad()
-    def _infer_code(
+    async def _infer_code(
         self,
         text: Tuple[List[str], str],
         stream: bool,
         device: torch.device,
         return_hidden: bool,
         params: InferCodeParams,
+        stream_batch_size: int,
     ):

         gpt = self.gpt
@@ -503,6 +504,17 @@ def _infer_code(
             repetition_penalty=params.repetition_penalty,
         )

+        speaker_embedding_param = gpt(input_ids, text_mask)
+
+        if params.spk_emb is not None:
+            self.speaker.apply(
+                speaker_embedding_param,
+                params.spk_emb,
+                input_ids,
+                self.tokenizer.spk_emb_ids,
+                self.gpt.device_gpt,
+            )
+
         if gpt.is_vllm:
             from .model.velocity import SamplingParams

@@ -518,65 +530,23 @@ def _infer_code(
             )
             input_ids = [i.tolist() for i in input_ids]

-            result = gpt.llm.generate(
-                None,
-                sample_params,
-                input_ids,
+            results_generator = gpt.llm.llm_engine.generate(
+                None, sample_params, uuid.uuid4(), speaker_embedding_param, input_ids[0]
             )
-
-            token_ids = []
-            hidden_states = []
-            for i in result:
-                token_ids.append(torch.tensor(i.outputs[0].token_ids))
-                hidden_states.append(
-                    i.outputs[0].hidden_states.to(torch.float32).to(self.device)
-                )
-
-            del text_mask, input_ids
-
-            return [
-                GPT.GenerationOutputs(
-                    ids=token_ids,
-                    hiddens=hidden_states,
-                    attentions=[],
-                ),
-            ]
-
-        emb = self.embed(input_ids, text_mask)
-
-        del text_mask
-
-        if params.spk_emb is not None:
-            self.speaker.apply(
-                emb,
-                params.spk_emb,
-                input_ids,
-                self.tokenizer.spk_emb_ids,
-                self.gpt.device_gpt,
-            )
-
-        result = gpt.generate(
-            emb,
-            input_ids,
-            temperature=torch.tensor(temperature, device=device),
-            eos_token=num_code,
-            attention_mask=attention_mask,
-            max_new_token=params.max_new_token,
-            min_new_token=params.min_new_token,
-            logits_processors=(*logits_processors, *logits_warpers),
-            infer_text=False,
-            return_hidden=return_hidden,
-            stream=stream,
-            show_tqdm=params.show_tqdm,
-            ensure_non_empty=params.ensure_non_empty,
-            stream_batch=params.stream_batch,
-            manual_seed=params.manual_seed,
-            context=self.context,
-        )
-
-        del emb, input_ids
-
-        return result
+            async for i in results_generator:
+                token_ids = []
+                hidden_states = []
+                if len(i.outputs[0].token_ids) % stream_batch_size == 0 or i.finished:
+                    token_ids.append(torch.tensor(i.outputs[0].token_ids))
+                    hidden_states.append(
+                        i.outputs[0].hidden_states.to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )

     @torch.no_grad()
     def _refine_text(
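
The vLLM branch now drives the velocity engine's async generator and only surfaces a result every `stream_batch_size` tokens, plus once at the end. The gating logic in isolation, with a dummy engine standing in for `gpt.llm.llm_engine.generate` (all names in this sketch are illustrative):

```python
import asyncio
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeOutput:
    token_ids: List[int] = field(default_factory=list)
    finished: bool = False


async def fake_engine(total_tokens: int):
    """Stand-in for the async engine: emits one more token per step."""
    out = FakeOutput()
    for t in range(total_tokens):
        out.token_ids.append(t)
        out.finished = t == total_tokens - 1
        yield out
        await asyncio.sleep(0)  # hand control back between steps


async def consume(stream_batch_size: int = 16) -> None:
    async for out in fake_engine(40):
        # Same condition as the commit: emit every stream_batch_size
        # tokens, and always emit the finished result.
        if len(out.token_ids) % stream_batch_size == 0 or out.finished:
            print(f"emit at {len(out.token_ids)} tokens, finished={out.finished}")


asyncio.run(consume())
# emit at 16 tokens, finished=False
# emit at 32 tokens, finished=False
# emit at 40 tokens, finished=True
```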

ChatTTS/model/dvae.py

Lines changed: 2 additions & 2 deletions
@@ -179,7 +179,7 @@ def __init__(
         hop_length=256,
         n_mels=100,
         padding: Literal["center", "same"] = "center",
-        device: torch.device = torch.device("cuda"),
+        device: torch.device = torch.device("cpu"),
     ):
         super().__init__()
         self.device = device
@@ -213,7 +213,7 @@ def __init__(
         vq_config: Optional[dict] = None,
         dim=512,
         coef: Optional[str] = None,
-        device: torch.device = torch.device("cuda"),
+        device: torch.device = torch.device("cpu"),
     ):
         super().__init__()
         if coef is None:

ChatTTS/model/gpt.py

Lines changed: 44 additions & 5 deletions
@@ -46,15 +46,14 @@ def __init__(
         self.is_te_llama = False
         self.is_vllm = use_vllm

-        if self.is_vllm:
-            return
-
-        self.llama_config = self._build_llama_config(gpt_config)
-
         self.emb_code = [ec.__call__ for ec in embed.emb_code]
         self.emb_text = embed.emb_text.__call__
         self.head_text = embed.head_text.__call__
         self.head_code = [hc.__call__ for hc in embed.head_code]
+        if self.is_vllm:
+            return
+
+        self.llama_config = self._build_llama_config(gpt_config)

     def from_pretrained(
         self, gpt_folder: str, embed_file_path: str, experimental=False
@@ -68,6 +67,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
+                dtype="float32",
             )
             self.logger.info("vLLM model loaded")
             return
@@ -138,6 +138,44 @@ def prepare(self, compile=False):
         except RuntimeError as e:
             self.logger.warning(f"compile failed: {e}. fallback to normal mode.")

+    def __call__(
+        self, input_ids: torch.Tensor, text_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        get_emb
+        """
+        return super().__call__(input_ids, text_mask)
+
+    def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
+        """
+        get_emb
+        """
+        input_ids = input_ids.clone()
+        text_mask = text_mask.clone()
+        emb_text: torch.Tensor = self.emb_text(
+            input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(self.device_gpt)
+        )
+
+        text_mask_inv = text_mask.logical_not().to(self.device_gpt)
+        masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(self.device_gpt)
+
+        emb_code = [
+            self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
+        ]
+        emb_code = torch.stack(emb_code, 2).sum(2)
+
+        emb = torch.zeros(
+            (input_ids.shape[:-1]) + (emb_text.shape[-1],),
+            device=emb_text.device,
+            dtype=emb_text.dtype,
+        )
+        emb[text_mask] = emb_text
+        emb[text_mask_inv] = emb_code.to(emb.dtype)
+
+        del emb_text, emb_code, text_mask_inv
+
+        return emb
+
 @dataclass(repr=False, eq=False)
 class _GenerationInputs:
     position_ids: torch.Tensor
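
The new `forward` merges two vocabularies into one embedding sequence: text positions take the text-token embedding, audio positions take the sum of the `num_vq` code embeddings, and both are scattered back by the boolean mask. A toy reproduction of just the scatter step (shapes and values are made up):

```python
import torch

batch, seq, dim = 1, 6, 4
# First three positions are text tokens, the rest are audio code tokens.
text_mask = torch.tensor([[True, True, True, False, False, False]])

emb_text = torch.ones(int(text_mask.sum()), dim)            # text-token embeddings
emb_code = torch.full((int((~text_mask).sum()), dim), 2.0)  # summed VQ code embeddings

# Boolean-mask assignment scatters each row back to its position.
emb = torch.zeros(batch, seq, dim)
emb[text_mask] = emb_text
emb[~text_mask] = emb_code

print(emb[0, :, 0])  # tensor([1., 1., 1., 2., 2., 2.])
```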
@@ -273,6 +311,7 @@ class GenerationOutputs:
     ids: List[torch.Tensor]
     attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
     hiddens: List[torch.Tensor]
+    finished: bool

     def destroy(self):
         del_all(self.ids)
