@@ -2,6 +2,7 @@
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
+from audio_encoders_pytorch import Bottleneck, Encoder1d
 from einops import rearrange
 from torch import Tensor, nn
 
@@ -16,15 +17,18 @@
     VKDiffusion,
     VSampler,
 )
-from .modules import (
-    STFT,
-    Bottleneck,
-    MultiEncoder1d,
-    SinusoidalEmbedding,
-    UNet1d,
-    UNetConditional1d,
+from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    prefix_dict,
+    prod,
+    to_list,
+    upsample,
 )
-from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample
 
 """
 Diffusion Classes (generic for 1d data)
@@ -36,7 +40,7 @@ def __init__(
         self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
     ):
         super().__init__()
-        diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
 
         UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
         self.unet = UNet(**kwargs)
@@ -149,31 +153,25 @@ def __init__(
         resnet_groups: int,
         kernel_multiplier_downsample: int,
         encoder_depth: int,
-        encoder_channels: int,
-        bottleneck: Optional[Bottleneck] = None,
         encoder_num_blocks: Optional[Sequence[int]] = None,
-        encoder_out_layers: int = 0,
+        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        use_stft: bool = False,
         **kwargs,
     ):
         self.in_channels = in_channels
         encoder_num_blocks = default(encoder_num_blocks, num_blocks)
         assert_message = "The number of encoder_num_blocks must match encoder_depth"
         assert len(encoder_num_blocks) >= encoder_depth, assert_message
+        assert patch_blocks == 1, "patch_blocks != 1 not supported"
+        assert not use_stft, "use_stft not supported"
+        self.factor = patch_factor * prod(factors[0:encoder_depth])
 
-        multiencoder = MultiEncoder1d(
-            in_channels=in_channels,
-            channels=channels,
-            patch_blocks=patch_blocks,
-            patch_factor=patch_factor,
-            num_layers=encoder_depth,
-            num_layers_out=encoder_out_layers,
-            latent_channels=encoder_channels,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=encoder_num_blocks,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            resnet_groups=resnet_groups,
-        )
+        context_channels = [0] * encoder_depth
+        if exists(bottleneck_channels):
+            context_channels += [bottleneck_channels]
+        else:
+            context_channels += [channels * multipliers[encoder_depth]]
 
         super().__init__(
             in_channels=in_channels,
@@ -185,89 +183,81 @@ def __init__(
             num_blocks=num_blocks,
             resnet_groups=resnet_groups,
             kernel_multiplier_downsample=kernel_multiplier_downsample,
-            context_channels=multiencoder.channels_list,
+            context_channels=context_channels,
             **kwargs,
         )
 
-        self.bottleneck = bottleneck
-        self.multiencoder = multiencoder
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+        self.encoder = Encoder1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_factor,
+            multipliers=multipliers[0 : encoder_depth + 1],
+            factors=factors[0:encoder_depth],
+            num_blocks=encoder_num_blocks[0:encoder_depth],
+            resnet_groups=resnet_groups,
+            out_channels=bottleneck_channels,
+        )
+
+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        latent, info = self.encoder(x, with_info=True)
+        for bottleneck in self.bottlenecks:
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
+            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
+        return (latent, info) if with_info else latent
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        if with_info:
-            latent, info = self.encode(x, with_info=True)
-        else:
-            latent = self.encode(x)
-
-        channels_list = self.multiencoder.decode(latent)
-        loss = self.diffusion(x, channels_list=channels_list, **kwargs)
+        latent, info = self.encode(x, with_info=True)
+        loss = self.diffusion(x, channels_list=[latent], **kwargs)
         return (loss, info) if with_info else loss
 
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent = self.multiencoder.encode(x)
-        latent = torch.tanh(latent)
-        # Apply bottleneck if provided (e.g. quantization module)
-        if exists(self.bottleneck):
-            latent, info = self.bottleneck(latent)
-            return (latent, info) if with_info else latent
-        return latent
-
     def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b, length = latent.shape[0], latent.shape[2] * self.multiencoder.factor
+        b, length = latent.shape[0], latent.shape[2] * self.factor
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length).to(latent)
         # Compute context from latent
-        channels_list = self.multiencoder.decode(latent)
-        default_kwargs = dict(channels_list=channels_list)
+        default_kwargs = dict(channels_list=[latent])
        # Decode by sampling while conditioning on latent channels
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
 class DiffusionVocoder1d(Model1d):
-    def __init__(
-        self,
-        in_channels: int,
-        vocoder_num_fft: int,
-        **kwargs,
-    ):
-        self.frequency_channels = vocoder_num_fft // 2 + 1
-        spectrogram_channels = in_channels * self.frequency_channels
-
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+    def __init__(self, in_channels: int, stft_num_fft: int, **kwargs):
+        self.stft_num_fft = stft_num_fft
+        spectrogram_channels = stft_num_fft // 2 + 1
         default_kwargs = dict(
-            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
+            in_channels=in_channels,
+            use_stft=True,
+            stft_num_fft=stft_num_fft,
+            context_channels=[in_channels * spectrogram_channels],
         )
-
         super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
-        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
-        # Get magnitude and phase of true wave
-        magnitude, phase = self.stft.encode(x)
+        # Get magnitude spectrogram from true wave
+        magnitude, _ = self.unet.stft.encode(x)
         magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
-        phase = rearrange(phase, "b c f t -> b (c f) t")
-        # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
-        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+        # Get diffusion loss while conditioning on magnitude
+        return self.diffusion(x, channels_list=[magnitude], **kwargs)
 
     def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
-        b, c, f, t, device = *spectrogram.shape, spectrogram.device
+        b, c, _, t, device = *spectrogram.shape, spectrogram.device
         magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
-        noise = torch.randn((b, c * f, t), device=device)
+        timesteps = closest_power_2(self.unet.stft.hop_length * t)
+        noise = torch.randn((b, c, timesteps), device=device)
         default_kwargs = dict(channels_list=[magnitude])
-        phase = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
-        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
-        wave = self.stft.decode(spectrogram, phase * pi)
-        return wave
+        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
 
 
 class DiffusionUpphaser1d(DiffusionUpsampler1d):
     def __init__(self, **kwargs):
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        stft_kwargs, kwargs = groupby("stft_", kwargs)
         super().__init__(**kwargs)
-        self.stft = STFT(**vocoder_kwargs)
+        self.stft = STFT(**stft_kwargs)
 
     def random_rephase(self, x: Tensor) -> Tensor:
         magnitude, phase = self.stft.encode(x)
@@ -305,7 +295,6 @@ def get_default_model_kwargs():
        use_nearest_upsample=False,
        use_skip_scale=True,
        use_context_time=True,
-       use_magnitude_channels=False,
        diffusion_type="v",
        diffusion_sigma_distribution=UniformDistribution(),
    )
@@ -380,12 +369,13 @@ class AudioDiffusionVocoder(DiffusionVocoder1d):
     def __init__(self, in_channels: int, **kwargs):
         default_kwargs = dict(
             in_channels=in_channels,
-            vocoder_num_fft=1023,
-            channels=32,
+            stft_num_fft=1023,
+            stft_hop_length=256,
+            channels=64,
             patch_blocks=1,
             patch_factor=1,
-            multipliers=[64, 32, 16, 8, 4, 2, 1],
-            factors=[1, 1, 1, 1, 1, 1],
+            multipliers=[48, 32, 16, 8, 8, 8, 8],
+            factors=[2, 2, 2, 1, 1, 1],
             num_blocks=[1, 1, 1, 1, 1, 1],
             attentions=[0, 0, 0, 1, 1, 1],
             attention_heads=8,
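
Note: after this change, DiffusionAutoencoder1d conditions the diffusion UNet on the (optionally bottlenecked) Encoder1d latent directly, instead of on a MultiEncoder1d channel list. A minimal round-trip sketch using the package's AudioDiffusionAutoencoder wrapper (the input length and num_steps here are illustrative, not prescribed by this commit):

    import torch
    from audio_diffusion_pytorch import AudioDiffusionAutoencoder

    autoencoder = AudioDiffusionAutoencoder(in_channels=1)

    x = torch.randn(1, 1, 2**18)  # [batch, in_channels, length]
    loss = autoencoder(x)         # diffusion loss; the encoder trains jointly
    loss.backward()

    # encode() downsamples time by self.factor = patch_factor * prod(factors[:encoder_depth]);
    # decode() samples noise of the matching length conditioned on the latent.
    latent = autoencoder.encode(x)
    sample = autoencoder.decode(latent, num_steps=10)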
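Similarly, DiffusionVocoder1d now diffuses the waveform itself while conditioned on the flattened magnitude spectrogram (the old phase-diffusion path and its inverse-STFT decode are gone), so sample() returns audio rather than phase. A sketch assuming the AudioDiffusionVocoder defaults above (stft_num_fft=1023 gives 1023 // 2 + 1 = 512 frequency bins, hop length 256; output length is hop_length * frames rounded to the closest power of two):

    import torch
    from audio_diffusion_pytorch import AudioDiffusionVocoder

    vocoder = AudioDiffusionVocoder(in_channels=1)

    # Magnitude spectrogram: [batch, channels, frequency bins, frames]
    spectrogram = torch.randn(1, 1, 512, 80)

    # Waveform: [1, 1, closest_power_2(256 * 80)], i.e. [1, 1, 16384] here
    waveform = vocoder.sample(spectrogram, num_steps=10)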