
Commit f0babd0

feat: support stream mode (#360)
* Update core.py
* Update core.py
* Update api.py
* gpt support streaming
1 parent: a63e9c2 · commit f0babd0
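In short: this commit threads a new `stream` flag from `Chat.infer` through `infer_code` (ChatTTS/infer/api.py) into `gpt.generate` (ChatTTS/model/gpt.py). `generate` becomes a generator that, in stream mode, yields cumulative results every 24 tokens; the old `Chat.infer` body becomes the internal generator `_infer`, and a new thin `infer` wrapper returns the generator when `stream=True` or drains it to a single result otherwise.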

File tree: 3 files changed (+61 −20 lines)

* ChatTTS/core.py
* ChatTTS/infer/api.py
* ChatTTS/model/gpt.py

ChatTTS/core.py

Lines changed: 38 additions & 13 deletions
```diff
@@ -145,7 +145,7 @@ def _load(
 
         self.check_model()
 
-    def infer(
+    def _infer(
         self,
         text,
         skip_refine_text=False,
@@ -155,22 +155,21 @@ def infer(
         use_decoder=True,
         do_text_normalization=True,
         lang=None,
+        stream=False,
         do_homophone_replacement=True
     ):
 
         assert self.check_model(use_decoder=use_decoder)
 
         if not isinstance(text, list):
             text = [text]
-
         if do_text_normalization:
             for i, t in enumerate(text):
                 _lang = detect_language(t) if lang is None else lang
                 if self.init_normalizer(_lang):
                     text[i] = self.normalizer[_lang](t)
                     if _lang == 'zh':
                         text[i] = apply_half2full_map(text[i])
-
         for i, t in enumerate(text):
             invalid_characters = count_invalid_characters(t)
             if len(invalid_characters):
@@ -190,18 +189,44 @@ def infer(
 
         text = [params_infer_code.get('prompt', '') + i for i in text]
         params_infer_code.pop('prompt', '')
-        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
-
+        result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
         if use_decoder:
-            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
+            field = 'hiddens'
+            docoder_name = 'decoder'
         else:
-            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
-
-        wav = [self.pretrain_models['vocos'].decode(
-            i.cpu() if torch.backends.mps.is_available() else i
-        ).cpu().numpy() for i in mel_spec]
-
-        return wav
+            field = 'ids'
+            docoder_name = 'dvae'
+        vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
+            i.cpu() if torch.backends.mps.is_available() else i
+        ).cpu().numpy() for i in spec]
+        if stream:
+
+            length = 0
+            for result in result_gen:
+                chunk_data = result[field][0]
+                assert len(result[field]) == 1
+                start_seek = length
+                length = len(chunk_data)
+                self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek = }')
+                chunk_data = chunk_data[start_seek:]
+                if not len(chunk_data):
+                    continue
+                self.logger.debug(f'new hidden {len(chunk_data)=}')
+                mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in [chunk_data]]
+                wav = vocos_decode(mel_spec)
+                self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
+                yield wav
+            return
+        mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
+        yield vocos_decode(mel_spec)
+
+    def infer(self, *args, **kwargs):
+        stream = kwargs.setdefault('stream', False)
+        res_gen = self._infer(*args, **kwargs)
+        if stream:
+            return res_gen
+        else:
+            return next(res_gen)
 
     def sample_random_speaker(self, ):
 
```
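After this change, `Chat.infer` returns a generator of wav chunks when `stream=True`, and (as before) a list of complete wavs otherwise. A minimal usage sketch, assuming the repo's usual `ChatTTS.Chat()` / `load_models()` entry points; the input text and the `numpy` concatenation step are illustrative:

```python
import numpy as np
import ChatTTS

chat = ChatTTS.Chat()
chat.load_models()  # assumption: default model loading as documented in the repo README

texts = ["Hello, streaming world."]  # illustrative input

# Non-streaming path: infer() drains the internal generator and returns the final wavs.
wavs = chat.infer(texts)

# Streaming path: infer() returns a generator; each item is a list of new wav
# chunks (one entry per input text) that can be played or buffered as they arrive.
chunks = [chunk[0] for chunk in chat.infer(texts, stream=True)]
audio = np.concatenate(chunks, axis=-1)  # stitch the chunks back into one waveform
```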

ChatTTS/infer/api.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -13,6 +13,7 @@ def infer_code(
     temperature = 0.3,
     repetition_penalty = 1.05,
     max_new_token = 2048,
+    stream=False,
     **kwargs
 ):
 
@@ -66,6 +67,7 @@ def infer_code(
         eos_token = num_code,
         max_new_token = max_new_token,
         infer_text = False,
+        stream = stream,
         **kwargs
     )
 
@@ -122,4 +124,4 @@ def refine_text(
         infer_text = True,
         **kwargs
     )
-    return result
\ No newline at end of file
+    return result
```
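The api.py change is pure plumbing: `infer_code` accepts the flag and forwards it to the model's `generate` as a keyword argument. A toy sketch of this pattern (hypothetical names, not the repo's API), showing why the non-streaming caller can still take a single result with `next()`:

```python
from typing import Iterator

def generate(max_new_token: int, stream: bool = False) -> Iterator[dict]:
    """Innermost generator: yields partial results only when streaming."""
    ids = []
    for i in range(max_new_token):
        ids.append(i)
        if stream and len(ids) % 24 == 0:  # same 24-token cadence as gpt.py
            yield {"ids": list(ids)}
    yield {"ids": ids}  # final (and, when stream=False, only) yield

def infer_code(stream: bool = False, **kwargs) -> Iterator[dict]:
    # Plumbing layer: just forwards the flag, like ChatTTS/infer/api.py does.
    return generate(stream=stream, **kwargs)

final = next(infer_code(max_new_token=50))                 # non-streaming: first yield is the result
chunks = list(infer_code(max_new_token=50, stream=True))   # streaming: partials, then the final
```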

ChatTTS/model/gpt.py

Lines changed: 20 additions & 6 deletions
```diff
@@ -167,6 +167,7 @@ def generate(
         infer_text=False,
         return_attn=False,
         return_hidden=False,
+        stream=False,
     ):
 
         with torch.no_grad():
@@ -264,7 +265,20 @@ def generate(
                 del idx_next
 
                 end_idx += (~finish).int().to(end_idx.device)
-
+                if stream:
+                    if end_idx % 24 and not finish.all():
+                        continue
+                    y_inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
+                    y_inputs_ids = [i[:, 0] for i in y_inputs_ids] if infer_text else y_inputs_ids
+                    y_hiddens = [[]]
+                    if return_hidden:
+                        y_hiddens = torch.stack(hiddens, 1)
+                        y_hiddens = [y_hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
+                    yield {
+                        'ids': y_inputs_ids,
+                        'attentions': attentions,
+                        'hiddens':y_hiddens,
+                    }
                 if finish.all():
                     pbar.update(max_new_token-i-1)
                     break
@@ -277,12 +291,12 @@ def generate(
                 hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
 
             if not finish.all():
-                self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
+                self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
+
+            del finish
 
-            del finish
-
-            return {
+            yield {
                 'ids': inputs_ids,
                 'attentions': attentions,
                 'hiddens':hiddens,
-            }
+            }
```
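Note the contract this creates between `generate` and `Chat._infer`: in stream mode `generate` yields the cumulative ids/hiddens every 24 tokens, and the consumer in core.py keeps a cursor (`start_seek`) so it only decodes the newly appended suffix. A toy sketch of that contract, with hypothetical names:

```python
def produce(n_tokens: int, chunk_every: int = 24):
    """Yields the cumulative sequence every `chunk_every` steps, then a final
    yield, mirroring gpt.generate's stream-mode behaviour."""
    seq = []
    for t in range(n_tokens):
        seq.append(t)
        if (t + 1) % chunk_every == 0:
            yield list(seq)
    yield list(seq)  # final yield covers any leftover tail

seen = 0  # plays the role of start_seek/length in Chat._infer
for cumulative in produce(60):
    new = cumulative[seen:]  # mirrors chunk_data[start_seek:]
    seen = len(cumulative)
    if not new:
        continue  # nothing appended since the last yield (e.g. the final duplicate)
    # ...decode `new` into a mel spectrogram and wav chunk here...
```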
