@@ -145,7 +145,7 @@ def _load(
145
145
146
146
self .check_model ()
147
147
148
- def infer (
148
+ def _infer (
149
149
self ,
150
150
text ,
151
151
skip_refine_text = False ,
@@ -155,22 +155,21 @@ def infer(
155
155
use_decoder = True ,
156
156
do_text_normalization = True ,
157
157
lang = None ,
158
+ stream = False ,
158
159
do_homophone_replacement = True
159
160
):
160
161
161
162
assert self .check_model (use_decoder = use_decoder )
162
163
163
164
if not isinstance (text , list ):
164
165
text = [text ]
165
-
166
166
if do_text_normalization :
167
167
for i , t in enumerate (text ):
168
168
_lang = detect_language (t ) if lang is None else lang
169
169
if self .init_normalizer (_lang ):
170
170
text [i ] = self .normalizer [_lang ](t )
171
171
if _lang == 'zh' :
172
172
text [i ] = apply_half2full_map (text [i ])
173
-
174
173
for i , t in enumerate (text ):
175
174
invalid_characters = count_invalid_characters (t )
176
175
if len (invalid_characters ):
@@ -190,18 +189,44 @@ def infer(
190
189
191
190
text = [params_infer_code .get ('prompt' , '' ) + i for i in text ]
192
191
params_infer_code .pop ('prompt' , '' )
193
- result = infer_code (self .pretrain_models , text , ** params_infer_code , return_hidden = use_decoder )
194
-
192
+ result_gen = infer_code (self .pretrain_models , text , ** params_infer_code , return_hidden = use_decoder , stream = stream )
195
193
if use_decoder :
196
- mel_spec = [self .pretrain_models ['decoder' ](i [None ].permute (0 ,2 ,1 )) for i in result ['hiddens' ]]
194
+ field = 'hiddens'
195
+ docoder_name = 'decoder'
197
196
else :
198
- mel_spec = [self .pretrain_models ['dvae' ](i [None ].permute (0 ,2 ,1 )) for i in result ['ids' ]]
199
-
200
- wav = [self .pretrain_models ['vocos' ].decode (
201
- i .cpu () if torch .backends .mps .is_available () else i
202
- ).cpu ().numpy () for i in mel_spec ]
203
-
204
- return wav
197
+ field = 'ids'
198
+ docoder_name = 'dvae'
199
+ vocos_decode = lambda spec : [self .pretrain_models ['vocos' ].decode (
200
+ i .cpu () if torch .backends .mps .is_available () else i
201
+ ).cpu ().numpy () for i in spec ]
202
+ if stream :
203
+
204
+ length = 0
205
+ for result in result_gen :
206
+ chunk_data = result [field ][0 ]
207
+ assert len (result [field ]) == 1
208
+ start_seek = length
209
+ length = len (chunk_data )
210
+ self .logger .debug (f'{ start_seek = } total len: { length } , new len: { length - start_seek = } ' )
211
+ chunk_data = chunk_data [start_seek :]
212
+ if not len (chunk_data ):
213
+ continue
214
+ self .logger .debug (f'new hidden { len (chunk_data )= } ' )
215
+ mel_spec = [self .pretrain_models [docoder_name ](i [None ].permute (0 ,2 ,1 )) for i in [chunk_data ]]
216
+ wav = vocos_decode (mel_spec )
217
+ self .logger .debug (f'yield wav chunk { len (wav [0 ])= } { len (wav [0 ][0 ])= } ' )
218
+ yield wav
219
+ return
220
+ mel_spec = [self .pretrain_models [docoder_name ](i [None ].permute (0 ,2 ,1 )) for i in next (result_gen )[field ]]
221
+ yield vocos_decode (mel_spec )
222
+
223
+ def infer (self , * args , ** kwargs ):
224
+ stream = kwargs .setdefault ('stream' , False )
225
+ res_gen = self ._infer (* args , ** kwargs )
226
+ if stream :
227
+ return res_gen
228
+ else :
229
+ return next (res_gen )
205
230
206
231
def sample_random_speaker (self , ):
207
232
0 commit comments