
Commit d0b206a

feat: update diffusion ae, remove aes, stft unet1d, new vocoder
1 parent 21014f9 commit d0b206a

File tree: 6 files changed (+129, -592 lines)

README.md

Lines changed: 2 additions & 4 deletions
@@ -66,8 +66,7 @@ from audio_diffusion_pytorch import AudioDiffusionAutoencoder
 
 autoencoder = AudioDiffusionAutoencoder(
     in_channels=1,
-    encoder_depth=4,
-    encoder_channels=32
+    encoder_depth=4
 )
 
 # Train on audio samples
@@ -157,8 +156,7 @@ unet = UNet1d(
     kernel_multiplier_downsample=2,
     use_nearest_upsample=False,
     use_skip_scale=True,
-    use_context_time=True,
-    use_magnitude_channels=False
+    use_context_time=True
 )
 
 x = torch.randn(3, 1, 2 ** 16)
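Both README tweaks track API removals in this commit: encoder_channels is gone from AudioDiffusionAutoencoder (the latent width now follows channels * multipliers[encoder_depth], or an explicit bottleneck_channels), and use_magnitude_channels is gone from UNet1d. A minimal round trip against the updated API (a sketch; batch size and sample length are illustrative, and num_steps assumes the preset sampling defaults):

import torch
from audio_diffusion_pytorch import AudioDiffusionAutoencoder

autoencoder = AudioDiffusionAutoencoder(in_channels=1, encoder_depth=4)

# Train on audio samples: [batch, in_channels, timesteps]
x = torch.randn(2, 1, 2 ** 18)
loss = autoencoder(x)
loss.backward()

# Encode to a downsampled latent, then decode by diffusion sampling
latent = autoencoder.encode(x)
y = autoencoder.decode(latent, num_steps=10)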

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 13 deletions
@@ -32,16 +32,4 @@
     DiffusionVocoder1d,
     Model1d,
 )
-from .modules import (
-    AutoEncoder1d,
-    Decoder1d,
-    Encoder1d,
-    MultiEncoder1d,
-    Noiser,
-    STFTAutoEncoder1d,
-    T5Embedder,
-    Tanh,
-    UNet1d,
-    UNetConditional1d,
-    Variational,
-)
+from .modules import T5Embedder, UNet1d, UNetConditional1d
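These dropped re-exports are the "remove aes" half of the commit: the autoencoder building blocks no longer live in .modules, and their replacements (Encoder1d, Bottleneck) are imported from the external audio_encoders_pytorch package in model.py below. Downstream code would presumably migrate along these lines (a sketch; only the two names visible in this diff are confirmed):

# Before (removed by this commit):
# from audio_diffusion_pytorch import AutoEncoder1d, Encoder1d

# After: encoder/bottleneck primitives come from audio-encoders-pytorch
from audio_encoders_pytorch import Bottleneck, Encoder1d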

audio_diffusion_pytorch/model.py

Lines changed: 71 additions & 81 deletions
@@ -2,6 +2,7 @@
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
+from audio_encoders_pytorch import Bottleneck, Encoder1d
 from einops import rearrange
 from torch import Tensor, nn
 
@@ -16,15 +17,18 @@
     VKDiffusion,
     VSampler,
 )
-from .modules import (
-    STFT,
-    Bottleneck,
-    MultiEncoder1d,
-    SinusoidalEmbedding,
-    UNet1d,
-    UNetConditional1d,
+from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    prefix_dict,
+    prod,
+    to_list,
+    upsample,
 )
-from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample
 
 """
 Diffusion Classes (generic for 1d data)
@@ -36,7 +40,7 @@ def __init__(
         self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
     ):
         super().__init__()
-        diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
 
         UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
         self.unet = UNet(**kwargs)
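groupby is the renamed groupby_kwargs_prefix: it splits one kwargs dict by key prefix, so diffusion_-prefixed options reach the diffusion wrapper while everything else configures the UNet. A minimal stand-in showing the contract (not the exact utils.groupby implementation):

from typing import Any, Dict, Tuple

def groupby(prefix: str, d: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    # Return ({prefixed keys, prefix stripped}, {remaining keys})
    grouped = {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
    rest = {k: v for k, v in d.items() if not k.startswith(prefix)}
    return grouped, rest

diffusion_kwargs, kwargs = groupby("diffusion_", {"diffusion_type": "v", "channels": 128})
assert diffusion_kwargs == {"type": "v"}
assert kwargs == {"channels": 128}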
@@ -149,31 +153,25 @@ def __init__(
         resnet_groups: int,
         kernel_multiplier_downsample: int,
         encoder_depth: int,
-        encoder_channels: int,
-        bottleneck: Optional[Bottleneck] = None,
         encoder_num_blocks: Optional[Sequence[int]] = None,
-        encoder_out_layers: int = 0,
+        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        use_stft: bool = False,
         **kwargs,
     ):
         self.in_channels = in_channels
         encoder_num_blocks = default(encoder_num_blocks, num_blocks)
         assert_message = "The number of encoder_num_blocks must match encoder_depth"
         assert len(encoder_num_blocks) >= encoder_depth, assert_message
+        assert patch_blocks == 1, "patch_blocks != 1 not supported"
+        assert not use_stft, "use_stft not supported"
+        self.factor = patch_factor * prod(factors[0:encoder_depth])
 
-        multiencoder = MultiEncoder1d(
-            in_channels=in_channels,
-            channels=channels,
-            patch_blocks=patch_blocks,
-            patch_factor=patch_factor,
-            num_layers=encoder_depth,
-            num_layers_out=encoder_out_layers,
-            latent_channels=encoder_channels,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=encoder_num_blocks,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            resnet_groups=resnet_groups,
-        )
+        context_channels = [0] * encoder_depth
+        if exists(bottleneck_channels):
+            context_channels += [bottleneck_channels]
+        else:
+            context_channels += [channels * multipliers[encoder_depth]]
 
         super().__init__(
             in_channels=in_channels,
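The new context_channels list injects the encoder latent at a single UNet depth: zeros for the first encoder_depth resolutions, then one non-zero entry sized either by bottleneck_channels or by the channel count at that depth. A worked example with hypothetical hyperparameters:

# Hypothetical values for illustration only
encoder_depth = 4
channels = 32
multipliers = [1, 2, 4, 4, 4, 4, 4]
bottleneck_channels = None

context_channels = [0] * encoder_depth
if bottleneck_channels is not None:
    context_channels += [bottleneck_channels]
else:
    context_channels += [channels * multipliers[encoder_depth]]

print(context_channels)  # [0, 0, 0, 0, 128]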
@@ -185,89 +183,81 @@ def __init__(
             num_blocks=num_blocks,
             resnet_groups=resnet_groups,
             kernel_multiplier_downsample=kernel_multiplier_downsample,
-            context_channels=multiencoder.channels_list,
+            context_channels=context_channels,
             **kwargs,
         )
 
-        self.bottleneck = bottleneck
-        self.multiencoder = multiencoder
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+        self.encoder = Encoder1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_factor,
+            multipliers=multipliers[0 : encoder_depth + 1],
+            factors=factors[0:encoder_depth],
+            num_blocks=encoder_num_blocks[0:encoder_depth],
+            resnet_groups=resnet_groups,
+            out_channels=bottleneck_channels,
+        )
+
+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        latent, info = self.encoder(x, with_info=True)
+        for bottleneck in self.bottlenecks:
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
+            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
+        return (latent, info) if with_info else latent
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        if with_info:
-            latent, info = self.encode(x, with_info=True)
-        else:
-            latent = self.encode(x)
-
-        channels_list = self.multiencoder.decode(latent)
-        loss = self.diffusion(x, channels_list=channels_list, **kwargs)
+        latent, info = self.encode(x, with_info=True)
+        loss = self.diffusion(x, channels_list=[latent], **kwargs)
         return (loss, info) if with_info else loss
 
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent = self.multiencoder.encode(x)
-        latent = torch.tanh(latent)
-        # Apply bottleneck if provided (e.g. quantization module)
-        if exists(self.bottleneck):
-            latent, info = self.bottleneck(latent)
-            return (latent, info) if with_info else latent
-        return latent
-
     def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b, length = latent.shape[0], latent.shape[2] * self.multiencoder.factor
+        b, length = latent.shape[0], latent.shape[2] * self.factor
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length).to(latent)
         # Compute context from latent
-        channels_list = self.multiencoder.decode(latent)
-        default_kwargs = dict(channels_list=channels_list)
+        default_kwargs = dict(channels_list=[latent])
         # Decode by sampling while conditioning on latent channels
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
 class DiffusionVocoder1d(Model1d):
-    def __init__(
-        self,
-        in_channels: int,
-        vocoder_num_fft: int,
-        **kwargs,
-    ):
-        self.frequency_channels = vocoder_num_fft // 2 + 1
-        spectrogram_channels = in_channels * self.frequency_channels
-
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+    def __init__(self, in_channels: int, stft_num_fft: int, **kwargs):
+        self.stft_num_fft = stft_num_fft
+        spectrogram_channels = stft_num_fft // 2 + 1
         default_kwargs = dict(
-            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
+            in_channels=in_channels,
+            use_stft=True,
+            stft_num_fft=stft_num_fft,
+            context_channels=[in_channels * spectrogram_channels],
         )
-
         super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
-        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
-        # Get magnitude and phase of true wave
-        magnitude, phase = self.stft.encode(x)
+        # Get magnitude spectrogram from true wave
+        magnitude, _ = self.unet.stft.encode(x)
         magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
-        phase = rearrange(phase, "b c f t -> b (c f) t")
-        # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
-        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+        # Get diffusion loss while conditioning on magnitude
+        return self.diffusion(x, channels_list=[magnitude], **kwargs)
 
     def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
-        b, c, f, t, device = *spectrogram.shape, spectrogram.device
+        b, c, _, t, device = *spectrogram.shape, spectrogram.device
         magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
-        noise = torch.randn((b, c * f, t), device=device)
+        timesteps = closest_power_2(self.unet.stft.hop_length * t)
+        noise = torch.randn((b, c, timesteps), device=device)
         default_kwargs = dict(channels_list=[magnitude])
-        phase = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
-        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
-        wave = self.stft.decode(spectrogram, phase * pi)
-        return wave
+        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
 
 
 class DiffusionUpphaser1d(DiffusionUpsampler1d):
     def __init__(self, **kwargs):
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        stft_kwargs, kwargs = groupby("stft_", kwargs)
         super().__init__(**kwargs)
-        self.stft = STFT(**vocoder_kwargs)
+        self.stft = STFT(**stft_kwargs)
 
     def random_rephase(self, x: Tensor) -> Tensor:
         magnitude, phase = self.stft.encode(x)
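The rewritten DiffusionVocoder1d diffuses the waveform itself, conditioned on the flattened magnitude spectrogram, instead of diffusing phase and running an inverse STFT; sample therefore infers the waveform length from the spectrogram's t frames as roughly hop_length * t, snapped to a power of two for the UNet. A quick check of that arithmetic (a sketch; this log-space rounding is a simplified stand-in for utils.closest_power_2, and hop_length=256 matches the preset below):

from math import log2

def closest_power_2(x: float) -> int:
    # Simplified: round the exponent in log space
    return 2 ** int(log2(x) + 0.5)

hop_length, t = 256, 250           # 250 spectrogram frames
print(hop_length * t)              # 64000 raw timesteps
print(closest_power_2(256 * 250))  # 65536 == 2 ** 16 noise length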
@@ -305,7 +295,6 @@ def get_default_model_kwargs():
         use_nearest_upsample=False,
         use_skip_scale=True,
         use_context_time=True,
-        use_magnitude_channels=False,
         diffusion_type="v",
         diffusion_sigma_distribution=UniformDistribution(),
     )
@@ -380,12 +369,13 @@ class AudioDiffusionVocoder(DiffusionVocoder1d):
     def __init__(self, in_channels: int, **kwargs):
         default_kwargs = dict(
             in_channels=in_channels,
-            vocoder_num_fft=1023,
-            channels=32,
+            stft_num_fft=1023,
+            stft_hop_length=256,
+            channels=64,
             patch_blocks=1,
             patch_factor=1,
-            multipliers=[64, 32, 16, 8, 4, 2, 1],
-            factors=[1, 1, 1, 1, 1, 1],
+            multipliers=[48, 32, 16, 8, 8, 8, 8],
+            factors=[2, 2, 2, 1, 1, 1],
             num_blocks=[1, 1, 1, 1, 1, 1],
             attentions=[0, 0, 0, 1, 1, 1],
             attention_heads=8,
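With the vocoder_* arguments renamed to stft_* and the STFT folded into the UNet (use_stft=True), the preset vocoder runs end to end like this (a sketch; num_steps assumes the preset sampling defaults, and the 512 frequency bins follow from stft_num_fft=1023, i.e. 1023 // 2 + 1):

import torch
from audio_diffusion_pytorch import AudioDiffusionVocoder

vocoder = AudioDiffusionVocoder(in_channels=1)

# Train on raw waveforms; the magnitude STFT is computed internally
x = torch.randn(2, 1, 2 ** 16)
loss = vocoder(x)
loss.backward()

# Vocode a magnitude spectrogram back to a waveform
spectrogram = torch.randn(1, 1, 512, 256)  # [b, c, freq_bins, frames]
wave = vocoder.sample(spectrogram, num_steps=10)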
