Commit 7fa0ad2

feat: decouple ae from diffae/diffmae, update ncca

Committed Nov 26, 2022 · 1 parent d730653 · commit 7fa0ad2

5 files changed: +60 −225 lines
 

README.md (+2 −2)

@@ -62,9 +62,9 @@ upsampled = upsampler.sample(
 
 ### Autoencoding
 ```py
-from audio_diffusion_pytorch import AudioDiffusionAutoencoder
+from audio_diffusion_pytorch import AudioDiffusionAE
 
-autoencoder = AudioDiffusionAutoencoder(in_channels=1)
+autoencoder = AudioDiffusionAE(in_channels=1)
 
 # Train on audio samples
 x = torch.randn(2, 1, 2 ** 18)
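
For context, the renamed class drops into the README's autoencoding walkthrough unchanged. A minimal sketch of the resulting section, assuming the surrounding README lines stay as before; the `encode`/`decode` calls mirror the class's methods shown in model.py below, and the `num_steps` value is illustrative:

```py
import torch
from audio_diffusion_pytorch import AudioDiffusionAE

autoencoder = AudioDiffusionAE(in_channels=1)

# Train on audio samples
x = torch.randn(2, 1, 2 ** 18)
loss = autoencoder(x)
loss.backward()

# Encode into a compressed latent, then decode it back by sampling.
# With the default encoder (out_channels=64, total downsample
# 16 * (4*4*4*2*2*2) = 8192) the latent should be [2, 64, 32].
latent = autoencoder.encode(x)
# Further sampling kwargs (sampler, schedule) may be needed here.
sample = autoencoder.decode(latent, num_steps=10)
```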

audio_diffusion_pytorch/__init__.py (+4 −3)

@@ -1,3 +1,5 @@
+from audio_encoders_pytorch import Encoder1d, ME1d
+
 from .diffusion import (
     ADPM2Sampler,
     AEulerSampler,
@@ -21,15 +23,14 @@
     XDiffusion,
 )
 from .model import (
-    AudioDiffusionAutoencoder,
+    AudioDiffusionAE,
     AudioDiffusionConditional,
     AudioDiffusionModel,
     AudioDiffusionUpphaser,
     AudioDiffusionUpsampler,
     AudioDiffusionVocoder,
+    DiffusionAE1d,
     DiffusionAR1d,
-    DiffusionAutoencoder1d,
-    DiffusionMAE1d,
     DiffusionUpphaser1d,
     DiffusionUpsampler1d,
     DiffusionVocoder1d,
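
For orientation, a sketch of the autoencoder-related import surface after this change; the reading of `ME1d` as the magnitude-encoder counterpart to the dropped `DiffusionMAE1d` path is an assumption, not stated in the diff:

```py
from audio_diffusion_pytorch import (
    AudioDiffusionAE,  # replaces AudioDiffusionAutoencoder
    DiffusionAE1d,     # replaces DiffusionAutoencoder1d; DiffusionMAE1d is removed
    Encoder1d,         # now re-exported from audio_encoders_pytorch
    ME1d,              # also re-exported; presumably the magnitude-encoder variant
)
```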

audio_diffusion_pytorch/model.py (+27 −193)

@@ -3,7 +3,7 @@
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
-from audio_encoders_pytorch import Bottleneck, Encoder1d
+from audio_encoders_pytorch import Encoder1d
 from einops import rearrange
 from torch import Tensor, nn
 from tqdm import tqdm
@@ -16,8 +16,6 @@
     downsample,
     exists,
     groupby,
-    prefix_dict,
-    prod,
     to_list,
     upsample,
 )
@@ -104,194 +102,40 @@ def sample(  # type: ignore
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
-class DiffusionAutoencoder1d(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        encoder_inject_depth: int,
-        encoder_channels: int,
-        encoder_factors: Sequence[int],
-        encoder_multipliers: Sequence[int],
-        encoder_patch_size: int = 1,
-        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
-        bottleneck_channels: Optional[int] = None,
-        unet_type: str = "base",
-        **kwargs,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-
-        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
-        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
-
-        # Compute context channels
-        context_channels = [0] * encoder_inject_depth
-        if exists(bottleneck_channels):
-            context_channels += [bottleneck_channels]
-        else:
-            context_channels += [encoder_channels * encoder_multipliers[-1]]
-
-        self.unet = XUNet1d(
-            type=unet_type,
-            in_channels=in_channels,
-            context_channels=context_channels,
-            **kwargs,
-        )
-
-        self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)
-
-        self.encoder = Encoder1d(
-            in_channels=in_channels,
-            channels=encoder_channels,
-            patch_size=encoder_patch_size,
-            factors=encoder_factors,
-            multipliers=encoder_multipliers,
-            out_channels=bottleneck_channels,
-            **encoder_kwargs,
-        )
-
-        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
-        self.bottleneck_channels = bottleneck_channels
-        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+class DiffusionAE1d(Model1d):
+    """Diffusion Auto Encoder"""
 
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent, info = self.encoder(x, with_info=True)
-        # Apply bottlenecks if present
-        for bottleneck in self.bottlenecks:
-            latent, info_bottleneck = bottleneck(latent, with_info=True)
-            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
-        return (latent, info) if with_info else latent
-
-    def forward(  # type: ignore
-        self, x: Tensor, with_info: bool = False, **kwargs
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent, info = self.encode(x, with_info=True)
-        loss = self.diffusion(x, channels_list=[latent], **kwargs)
-        return (loss, info) if with_info else loss
-
-    def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b = latent.shape[0]
-        length = latent.shape[2] * self.encoder_downsample_factor
-        # Compute noise by inferring shape from latent length
-        noise = torch.randn(b, self.in_channels, length, device=latent.device)
-        # Compute context form latent
-        default_kwargs = dict(channels_list=[latent])
-        # Decode by sampling while conditioning on latent channels
-        return self.sample(noise, **{**default_kwargs, **kwargs})
-
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.diffusion.sample(*args, **kwargs)
-
-
-class DiffusionMAE1d(nn.Module):
     def __init__(
-        self,
-        in_channels: int,
-        encoder_inject_depth: int,
-        encoder_channels: int,
-        encoder_factors: Sequence[int],
-        encoder_multipliers: Sequence[int],
-        diffusion_type: str,
-        stft_num_fft: int,
-        stft_hop_length: int,
-        stft_use_complex: bool,
-        stft_window_length: Optional[int] = None,
-        encoder_patch_size: int = 1,
-        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
-        bottleneck_channels: Optional[int] = None,
-        unet_type: str = "base",
-        **kwargs,
+        self, in_channels: int, encoder: Encoder1d, encoder_inject_depth: int, **kwargs
     ):
-        super().__init__()
-        self.in_channels = in_channels
-
-        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
-        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
-        stft_kwargs, kwargs = groupby("stft_", kwargs)
-
-        # Compute context channels
-        context_channels = [0] * encoder_inject_depth
-        if exists(bottleneck_channels):
-            context_channels += [bottleneck_channels]
-        else:
-            context_channels += [encoder_channels * encoder_multipliers[-1]]
-
-        self.spectrogram_channels = stft_num_fft // 2 + 1
-        self.stft_hop_length = stft_hop_length
-
-        self.encoder_stft = STFT(
-            num_fft=stft_num_fft,
-            hop_length=stft_hop_length,
-            window_length=stft_window_length,
-            use_complex=False,  # Magnitude encoding
-        )
-
-        self.unet = XUNet1d(
-            type=unet_type,
+        super().__init__(
             in_channels=in_channels,
-            context_channels=context_channels,
-            use_stft=True,
-            stft_use_complex=stft_use_complex,
-            stft_num_fft=stft_num_fft,
-            stft_hop_length=stft_hop_length,
-            stft_window_length=stft_window_length,
+            context_channels=[0] * encoder_inject_depth + [encoder.out_channels],
             **kwargs,
         )
-
-        self.diffusion = XDiffusion(
-            type=diffusion_type, net=self.unet, **diffusion_kwargs
-        )
-
-        self.encoder = Encoder1d(
-            in_channels=in_channels * self.spectrogram_channels,
-            channels=encoder_channels,
-            patch_size=encoder_patch_size,
-            factors=encoder_factors,
-            multipliers=encoder_multipliers,
-            out_channels=bottleneck_channels,
-            **encoder_kwargs,
-        )
-
-        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
-        self.bottleneck_channels = bottleneck_channels
-        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
-
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        # Extract magnitude and encode
-        magnitude, _ = self.encoder_stft.encode(x)
-        magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
-        latent, info = self.encoder(magnitude_flat, with_info=True)
-        # Apply bottlenecks if present
-        for bottleneck in self.bottlenecks:
-            latent, info_bottleneck = bottleneck(latent, with_info=True)
-            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
-        return (latent, info) if with_info else latent
+        self.in_channels = in_channels
+        self.encoder = encoder
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
         latent, info = self.encode(x, with_info=True)
-        loss = self.diffusion(x, channels_list=[latent], **kwargs)
+        print(latent.shape)
+        loss = super().forward(x, channels_list=[latent], **kwargs)
         return (loss, info) if with_info else loss
 
+    def encode(self, *args, **kwargs):
+        return self.encoder(*args, **kwargs)
+
     def decode(self, latent: Tensor, **kwargs) -> Tensor:
         b = latent.shape[0]
-        length = closest_power_2(
-            self.stft_hop_length * latent.shape[2] * self.encoder_downsample_factor
-        )
+        length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor)
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length, device=latent.device)
         # Compute context form latent
         default_kwargs = dict(channels_list=[latent])
         # Decode by sampling while conditioning on latent channels
-        return self.sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
-
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.diffusion.sample(*args, **kwargs)
+        return super().sample(noise, **{**default_kwargs, **kwargs})
 
 
 class DiffusionVocoder1d(Model1d):
@@ -499,31 +343,21 @@ def sample(self, *args, **kwargs):
         return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
 
 
-class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
-    def __init__(self, *args, **kwargs):
+class AudioDiffusionAE(DiffusionAE1d):
+    def __init__(self, in_channels: int, *args, **kwargs):
         default_kwargs = dict(
             **get_default_model_kwargs(),
+            in_channels=in_channels,
+            encoder=Encoder1d(
+                in_channels=in_channels,
+                patch_size=16,
+                channels=16,
+                multipliers=[1, 2, 4, 4, 4, 4, 4],
+                factors=[4, 4, 4, 2, 2, 2],
+                num_blocks=[2, 2, 2, 2, 2, 2],
+                out_channels=64,
+            ),
             encoder_inject_depth=6,
-            encoder_channels=16,
-            encoder_patch_size=16,
-            encoder_multipliers=[1, 2, 4, 4, 4, 4, 4],
-            encoder_factors=[4, 4, 4, 2, 2, 2],
-            encoder_num_blocks=[2, 2, 2, 2, 2, 2],
-            bottleneck_channels=64,
-        )
-        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
-
-    def decode(self, *args, **kwargs):
-        return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
-
-
-class AudioDiffusionMAE(DiffusionMAE1d):
-    def __init__(self, *args, **kwargs):
-        default_kwargs = dict(
-            diffusion_type="v",
-            diffusion_sigma_distribution=UniformDistribution(),
-            stft_num_fft=1023,
-            stft_hop_length=256,
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
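
The point of the decoupling: `DiffusionAE1d` no longer assembles its encoder from `encoder_*`/`bottleneck_*` kwargs, it takes a pre-built `Encoder1d` and derives the context channels from `encoder.out_channels`. A hedged sketch of what that enables; since `AudioDiffusionAE` merges user kwargs over its defaults, a custom encoder can be passed in, and keeping the default `patch_size`/`factors` preserves the downsample factor the UNet expects at `encoder_inject_depth=6`:

```py
from audio_diffusion_pytorch import AudioDiffusionAE, Encoder1d

# Swap in a custom encoder; only out_channels differs from the default
# configuration, so the injected latent resolution stays compatible.
autoencoder = AudioDiffusionAE(
    in_channels=1,
    encoder=Encoder1d(
        in_channels=1,
        patch_size=16,
        channels=16,
        multipliers=[1, 2, 4, 4, 4, 4, 4],
        factors=[4, 4, 4, 2, 2, 2],
        num_blocks=[2, 2, 2, 2, 2, 2],
        out_channels=32,  # 32-channel latent instead of the default 64
    ),
)
```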
529363

audio_diffusion_pytorch/modules.py (+26 −26)

@@ -8,7 +8,7 @@
 from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
-from .utils import closest_power_2, default, exists, groupby, is_sequence
+from .utils import closest_power_2, default, exists, groupby
 
 """
 Utils
@@ -1197,44 +1197,44 @@ def __init__(self, context_features: int, **kwargs):
         super().__init__(context_features=context_features, **kwargs)
         self.embedder = NumberEmbedder(features=context_features)
 
+    def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor:
+        x = x if torch.is_tensor(x) else torch.tensor(x)
+        return x.expand(shape)
+
     def forward(  # type: ignore
         self,
         x: Tensor,
         time: Tensor,
         *,
         channels_list: Sequence[Tensor],
-        channels_augmentation: bool = False,
-        channels_scale: Union[int, Sequence[int]] = 0,
+        channels_augmentation: Union[
+            bool, Sequence[bool], Sequence[Sequence[bool]], Tensor
+        ] = False,
+        channels_scale: Union[
+            float, Sequence[float], Sequence[Sequence[float]], Tensor
+        ] = 0,
         **kwargs,
     ) -> Tensor:
-        b, num_items = x.shape[0], len(channels_list)
-
-        if channels_augmentation:
-            # Random noise augmentation for each item
-            channels_scale = torch.rand(num_items, b).to(x)  # type: ignore
-            for i in range(num_items):
-                item = channels_list[i]
-                scale = rearrange(channels_scale[i], "b -> b 1 1")  # type: ignore
-                channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale)  # type: ignore # noqa
-        else:
-            # Expand same scale to each batch element
-            if is_sequence(channels_scale):
-                assert_message = "len(channels_scale) must match len(channels_list)"
-                assert len(channels_scale) == num_items, assert_message
-            else:
-                channels_scale = num_items * [channels_scale]  # type: ignore
-            channels_scale = torch.tensor(channels_scale).to(x)  # type: ignore
-            channels_scale = repeat(channels_scale, "n -> n b", b=b)
-
-        # Compute scale feature embedding
-        scale_embedding = self.embedder(channels_scale)
-        scale_embedding = reduce(scale_embedding, "n b d -> b d", "sum")
+        b, n = x.shape[0], len(channels_list)
+        channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x)
+        channels_scale = self.expand(channels_scale, shape=(b, n)).to(x)
+
+        # Augmentation (for each channel list item)
+        for i in range(n):
+            scale = channels_scale[:, i] * channels_augmentation[:, i]
+            scale = rearrange(scale, "b -> b 1 1")
+            item = channels_list[i]
+            channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale)  # type: ignore # noqa
+
+        # Scale embedding (sum reduction if more than one channel list item)
+        channels_scale_emb = self.embedder(channels_scale)
+        channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum")
 
         return super().forward(
             x=x,
             time=time,
             channels_list=channels_list,
-            features=scale_embedding,
+            features=channels_scale_emb,
             **kwargs,
         )
 
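The reworked `forward` above replaces the all-or-nothing `channels_augmentation: bool` with per-`(batch, item)` control: scalars, sequences, or tensors are broadcast to shape `(b, n)` via `expand`, the effective noise scale is `channels_scale * channels_augmentation`, and each conditioning item is blended with Gaussian noise by that scale. A standalone sketch of the blending rule (shapes and values are illustrative, not taken from the library):

```py
import torch
from einops import rearrange

b, n = 4, 1                                    # batch size, one conditioning item
item = torch.randn(b, 64, 32)                  # e.g. an autoencoder latent

# expand() semantics: a scalar, per-item sequence, or full tensor all
# broadcast to shape (b, n); here a 0-dim tensor is expanded.
augmentation = torch.tensor(1.0).expand(b, n)  # augmentation enabled everywhere
scale = torch.rand(b, n)                       # per-element noise scale in [0, 1)

# Blend the item with noise: scale 0 keeps it intact, scale 1 replaces
# it with pure noise (the embedder also conditions on channels_scale).
s = rearrange(scale[:, 0] * augmentation[:, 0], "b -> b 1 1")
noisy_item = torch.randn_like(item) * s + item * (1 - s)
```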

setup.py (+1 −1)

@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.95",
+    version="0.0.96",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
