@@ -1,6 +1,7 @@
 import os
 import logging
 import tempfile
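+# used in _infer_code to tag each vllm engine request with a unique id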
+import uuid
 from dataclasses import dataclass, asdict
 from typing import Literal, Optional, List, Tuple, Dict, Union
 from json import load
@@ -173,15 +174,17 @@ class RefineTextParams:
         min_new_token: int = 0
         show_tqdm: bool = True
         ensure_non_empty: bool = True
-        manual_seed: Optional[int] = None
+        manual_seed: Optional[int] = 0

     @dataclass(repr=False, eq=False)
     class InferCodeParams(RefineTextParams):
         prompt: str = "[speed_5]"
         spk_emb: Optional[str] = None
         spk_smp: Optional[str] = None
         txt_smp: Optional[str] = None
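+        # Near-greedy decoding: top_K=1 with a near-zero temperature makes
+        # audio-code sampling effectively deterministic (presumably for
+        # reproducible streaming output).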
-        temperature: float = 0.3
+        top_P: float = 1
+        top_K: int = 1
+        temperature: float = 0.01
         repetition_penalty: float = 1.05
         max_new_token: int = 2048
         stream_batch: int = 24
@@ -193,16 +196,17 @@ def infer(
         text,
         stream=False,
         lang=None,
-        skip_refine_text=False,
+        skip_refine_text=True,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=RefineTextParams(),
-        params_infer_code=InferCodeParams(),
+        params_refine_text=None,
+        params_infer_code=None,
+        stream_batch_size=16,
     ):
         self.context.set(False)
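+        # _infer is now an async generator, so it is handed straight back to
+        # the caller to drive with `async for`.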
-        res_gen = self._infer(
+        return self._infer(
             text,
             stream,
             lang,
@@ -213,11 +217,8 @@ def infer(
             do_homophone_replacement,
             params_refine_text,
             params_infer_code,
+            stream_batch_size,
         )
-        if stream:
-            return res_gen
-        else:
-            return next(res_gen)

     def interrupt(self):
         self.context.set(True)
@@ -272,7 +273,7 @@ def _load(
                 vq_config=asdict(self.config.dvae.vq),
                 dim=self.config.dvae.decoder.idim,
                 coef=coef,
-                device=self.device,
+                device=device,
             )
             .to(device)
             .eval()
@@ -289,8 +290,8 @@ def _load(
             self.config.embed.num_text_tokens,
             self.config.embed.num_vq,
         )
-        embed.from_pretrained(embed_path, device=self.device)
-        self.embed = embed.to(self.device)
+        embed.from_pretrained(embed_path, device=device)
+        self.embed = embed.to(device)
         self.logger.log(logging.INFO, "embed loaded.")

         gpt = GPT(
@@ -318,6 +319,7 @@ def _load(
                 decoder_config=asdict(self.config.decoder),
                 dim=self.config.decoder.idim,
                 coef=coef,
+                device=device,
             )
             .to(device)
             .eval()
@@ -338,18 +340,19 @@ def _load(

         return self.has_loaded()

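+    # Now an async generator: wav chunks are yielded as the vllm engine
+    # produces tokens instead of blocking until synthesis completes.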
-    def _infer(
+    async def _infer(
         self,
         text,
-        stream=False,
+        stream=True,
         lang=None,
-        skip_refine_text=False,
+        skip_refine_text=True,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
-        params_refine_text=RefineTextParams(),
-        params_infer_code=InferCodeParams(),
+        params_refine_text=None,
+        params_infer_code=None,
+        stream_batch_size=16,
     ):

         assert self.has_loaded(use_decoder=use_decoder)
@@ -383,41 +386,38 @@ def _infer(
                 yield text
             return

-        if stream:
-            length = 0
-            pass_batch_count = 0
-        for result in self._infer_code(
+        length = 0
+        async for result in self._infer_code(
             text,
             stream,
             self.device,
             use_decoder,
             params_infer_code,
+            stream_batch_size,
         ):
             wavs = self._decode_to_wavs(
                 result.hiddens if use_decoder else result.ids,
                 use_decoder,
             )
-            result.destroy()
-            if stream:
-                pass_batch_count += 1
-                if pass_batch_count <= params_infer_code.pass_first_n_batches:
-                    continue
-                a = length
-                b = a + params_infer_code.stream_speed
-                if b > wavs.shape[1]:
-                    b = wavs.shape[1]
-                new_wavs = wavs[:, a:b]
-                length = b
-                yield new_wavs
+
+            if result.finished:
+                yield wavs[:, length:]
             else:
-                yield wavs
-        if stream:
-            new_wavs = wavs[:, length:]
-            # Identify rows with non-zero elements using np.any
-            # keep_rows = np.any(array != 0, axis=1)
-            keep_cols = np.sum(new_wavs != 0, axis=0) > 0
-            # Filter both rows and columns using slicing
-            yield new_wavs[:][:, keep_cols]
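+                # librosa.effects.split returns (start, end) sample indices of
+                # the NON-silent stretches of the tail, so silence_left lands
+                # on the start of the last voiced interval.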
+                # Hack: flush audio only up to the start of the last voiced
+                # segment, so words are not cut mid-utterance; if the tail is
+                # entirely silent, flush all of it, otherwise wait another loop.
+                import librosa
+
+                silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
+                silence_left = 0
+                if len(silence_intervals) == 0:
+                    silence_left = len(wavs[0])
+                else:
+                    for i in range(len(silence_intervals)):
+                        silence_left = silence_intervals[i][0]
+                if silence_left <= 0:
+                    continue
+                new_wavs = wavs[:, length : length + silence_left]
+                length += len(new_wavs[0])
+                yield new_wavs

     @torch.inference_mode()
     def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
@@ -456,13 +456,14 @@ def _decode_to_wavs(
         return wavs

     @torch.no_grad()
-    def _infer_code(
+    async def _infer_code(
         self,
         text: Tuple[List[str], str],
         stream: bool,
         device: torch.device,
         return_hidden: bool,
         params: InferCodeParams,
+        stream_batch_size: int,
     ):

         gpt = self.gpt
@@ -503,6 +504,17 @@ def _infer_code(
             repetition_penalty=params.repetition_penalty,
         )

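+        # Build the input embeddings once up front; the registered speaker
+        # embedding (if any) is patched into them before they are handed to
+        # the vllm engine.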
+        speaker_embedding_param = gpt(input_ids, text_mask)
+
+        if params.spk_emb is not None:
+            self.speaker.apply(
+                speaker_embedding_param,
+                params.spk_emb,
+                input_ids,
+                self.tokenizer.spk_emb_ids,
+                self.gpt.device_gpt,
+            )
+
         if gpt.is_vllm:
             from .model.velocity import SamplingParams

@@ -518,65 +530,23 @@ def _infer_code(
             )
             input_ids = [i.tolist() for i in input_ids]

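+            # Async engine call: prompt is None since precomputed embeddings
+            # are passed instead, and uuid4() presumably serves as the unique
+            # request id.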
-            result = gpt.llm.generate(
-                None,
-                sample_params,
-                input_ids,
+            results_generator = gpt.llm.llm_engine.generate(
+                None, sample_params, uuid.uuid4(), speaker_embedding_param, input_ids[0]
             )
-
-            token_ids = []
-            hidden_states = []
-            for i in result:
-                token_ids.append(torch.tensor(i.outputs[0].token_ids))
-                hidden_states.append(
-                    i.outputs[0].hidden_states.to(torch.float32).to(self.device)
-                )
-
-            del text_mask, input_ids
-
-            return [
-                GPT.GenerationOutputs(
-                    ids=token_ids,
-                    hiddens=hidden_states,
-                    attentions=[],
-                ),
-            ]
-
-        emb = self.embed(input_ids, text_mask)
-
-        del text_mask
-
-        if params.spk_emb is not None:
-            self.speaker.apply(
-                emb,
-                params.spk_emb,
-                input_ids,
-                self.tokenizer.spk_emb_ids,
-                self.gpt.device_gpt,
-            )
-
-        result = gpt.generate(
-            emb,
-            input_ids,
-            temperature=torch.tensor(temperature, device=device),
-            eos_token=num_code,
-            attention_mask=attention_mask,
-            max_new_token=params.max_new_token,
-            min_new_token=params.min_new_token,
-            logits_processors=(*logits_processors, *logits_warpers),
-            infer_text=False,
-            return_hidden=return_hidden,
-            stream=stream,
-            show_tqdm=params.show_tqdm,
-            ensure_non_empty=params.ensure_non_empty,
-            stream_batch=params.stream_batch,
-            manual_seed=params.manual_seed,
-            context=self.context,
-        )
-
-        del emb, input_ids
-
-        return result
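+            # Stream partial results: every stream_batch_size tokens, and once
+            # more on completion, re-emit the full token/hidden sequences
+            # accumulated so far for downstream decoding.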
+            async for i in results_generator:
+                token_ids = []
+                hidden_states = []
+                if len(i.outputs[0].token_ids) % stream_batch_size == 0 or i.finished:
+                    token_ids.append(torch.tensor(i.outputs[0].token_ids))
+                    hidden_states.append(
+                        i.outputs[0].hidden_states.to(torch.float32).to(self.device)
+                    )
+                    yield GPT.GenerationOutputs(
+                        ids=token_ids,
+                        finished=i.finished,
+                        hiddens=hidden_states,
+                        attentions=[],
+                    )

     @torch.no_grad()
     def _refine_text(