@@ -3,7 +3,7 @@
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
-from audio_encoders_pytorch import Bottleneck, Encoder1d
+from audio_encoders_pytorch import Encoder1d
 from einops import rearrange
 from torch import Tensor, nn
 from tqdm import tqdm
@@ -16,8 +16,6 @@
     downsample,
     exists,
    groupby,
-    prefix_dict,
-    prod,
     to_list,
     upsample,
 )
@@ -104,194 +102,39 @@ def sample(  # type: ignore
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
-class DiffusionAutoencoder1d(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        encoder_inject_depth: int,
-        encoder_channels: int,
-        encoder_factors: Sequence[int],
-        encoder_multipliers: Sequence[int],
-        encoder_patch_size: int = 1,
-        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
-        bottleneck_channels: Optional[int] = None,
-        unet_type: str = "base",
-        **kwargs,
-    ):
-        super().__init__()
-        self.in_channels = in_channels
-
-        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
-        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
-
-        # Compute context channels
-        context_channels = [0] * encoder_inject_depth
-        if exists(bottleneck_channels):
-            context_channels += [bottleneck_channels]
-        else:
-            context_channels += [encoder_channels * encoder_multipliers[-1]]
-
-        self.unet = XUNet1d(
-            type=unet_type,
-            in_channels=in_channels,
-            context_channels=context_channels,
-            **kwargs,
-        )
-
-        self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)
-
-        self.encoder = Encoder1d(
-            in_channels=in_channels,
-            channels=encoder_channels,
-            patch_size=encoder_patch_size,
-            factors=encoder_factors,
-            multipliers=encoder_multipliers,
-            out_channels=bottleneck_channels,
-            **encoder_kwargs,
-        )
-
-        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
-        self.bottleneck_channels = bottleneck_channels
-        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+class DiffusionAE1d(Model1d):
+    """Diffusion Auto Encoder"""
 
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent, info = self.encoder(x, with_info=True)
-        # Apply bottlenecks if present
-        for bottleneck in self.bottlenecks:
-            latent, info_bottleneck = bottleneck(latent, with_info=True)
-            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
-        return (latent, info) if with_info else latent
-
-    def forward(  # type: ignore
-        self, x: Tensor, with_info: bool = False, **kwargs
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent, info = self.encode(x, with_info=True)
-        loss = self.diffusion(x, channels_list=[latent], **kwargs)
-        return (loss, info) if with_info else loss
-
-    def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b = latent.shape[0]
-        length = latent.shape[2] * self.encoder_downsample_factor
-        # Compute noise by inferring shape from latent length
-        noise = torch.randn(b, self.in_channels, length, device=latent.device)
-        # Compute context form latent
-        default_kwargs = dict(channels_list=[latent])
-        # Decode by sampling while conditioning on latent channels
-        return self.sample(noise, **{**default_kwargs, **kwargs})
-
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.diffusion.sample(*args, **kwargs)
-
-
-class DiffusionMAE1d(nn.Module):
     def __init__(
-        self,
-        in_channels: int,
-        encoder_inject_depth: int,
-        encoder_channels: int,
-        encoder_factors: Sequence[int],
-        encoder_multipliers: Sequence[int],
-        diffusion_type: str,
-        stft_num_fft: int,
-        stft_hop_length: int,
-        stft_use_complex: bool,
-        stft_window_length: Optional[int] = None,
-        encoder_patch_size: int = 1,
-        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
-        bottleneck_channels: Optional[int] = None,
-        unet_type: str = "base",
-        **kwargs,
+        self, in_channels: int, encoder: Encoder1d, encoder_inject_depth: int, **kwargs
     ):
-        super().__init__()
-        self.in_channels = in_channels
-
-        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
-        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
-        stft_kwargs, kwargs = groupby("stft_", kwargs)
-
-        # Compute context channels
-        context_channels = [0] * encoder_inject_depth
-        if exists(bottleneck_channels):
-            context_channels += [bottleneck_channels]
-        else:
-            context_channels += [encoder_channels * encoder_multipliers[-1]]
-
-        self.spectrogram_channels = stft_num_fft // 2 + 1
-        self.stft_hop_length = stft_hop_length
-
-        self.encoder_stft = STFT(
-            num_fft=stft_num_fft,
-            hop_length=stft_hop_length,
-            window_length=stft_window_length,
-            use_complex=False,  # Magnitude encoding
-        )
-
-        self.unet = XUNet1d(
-            type=unet_type,
+        super().__init__(
             in_channels=in_channels,
-            context_channels=context_channels,
-            use_stft=True,
-            stft_use_complex=stft_use_complex,
-            stft_num_fft=stft_num_fft,
-            stft_hop_length=stft_hop_length,
-            stft_window_length=stft_window_length,
+            context_channels=[0] * encoder_inject_depth + [encoder.out_channels],
             **kwargs,
         )
-
-        self.diffusion = XDiffusion(
-            type=diffusion_type, net=self.unet, **diffusion_kwargs
-        )
-
-        self.encoder = Encoder1d(
-            in_channels=in_channels * self.spectrogram_channels,
-            channels=encoder_channels,
-            patch_size=encoder_patch_size,
-            factors=encoder_factors,
-            multipliers=encoder_multipliers,
-            out_channels=bottleneck_channels,
-            **encoder_kwargs,
-        )
-
-        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
-        self.bottleneck_channels = bottleneck_channels
-        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
-
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        # Extract magnitude and encode
-        magnitude, _ = self.encoder_stft.encode(x)
-        magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
-        latent, info = self.encoder(magnitude_flat, with_info=True)
-        # Apply bottlenecks if present
-        for bottleneck in self.bottlenecks:
-            latent, info_bottleneck = bottleneck(latent, with_info=True)
-            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
-        return (latent, info) if with_info else latent
+        self.in_channels = in_channels
+        self.encoder = encoder
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
         latent, info = self.encode(x, with_info=True)
-        loss = self.diffusion(x, channels_list=[latent], **kwargs)
+        loss = super().forward(x, channels_list=[latent], **kwargs)
         return (loss, info) if with_info else loss
 
+    def encode(self, *args, **kwargs):
+        return self.encoder(*args, **kwargs)
+
     def decode(self, latent: Tensor, **kwargs) -> Tensor:
         b = latent.shape[0]
-        length = closest_power_2(
-            self.stft_hop_length * latent.shape[2] * self.encoder_downsample_factor
-        )
+        length = closest_power_2(latent.shape[2] * self.encoder.downsample_factor)
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length, device=latent.device)
         # Compute context from latent
         default_kwargs = dict(channels_list=[latent])
         # Decode by sampling while conditioning on latent channels
-        return self.sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
-
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.diffusion.sample(*args, **kwargs)
+        return super().sample(noise, **{**default_kwargs, **kwargs})
 
 
 class DiffusionVocoder1d(Model1d):
@@ -499,31 +343,21 @@ def sample(self, *args, **kwargs):
         return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
 
 
-class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
-    def __init__(self, *args, **kwargs):
+class AudioDiffusionAE(DiffusionAE1d):
+    def __init__(self, in_channels: int, *args, **kwargs):
         default_kwargs = dict(
             **get_default_model_kwargs(),
+            in_channels=in_channels,
+            encoder=Encoder1d(
+                in_channels=in_channels,
+                patch_size=16,
+                channels=16,
+                multipliers=[1, 2, 4, 4, 4, 4, 4],
+                factors=[4, 4, 4, 2, 2, 2],
+                num_blocks=[2, 2, 2, 2, 2, 2],
+                out_channels=64,
+            ),
             encoder_inject_depth=6,
-            encoder_channels=16,
-            encoder_patch_size=16,
-            encoder_multipliers=[1, 2, 4, 4, 4, 4, 4],
-            encoder_factors=[4, 4, 4, 2, 2, 2],
-            encoder_num_blocks=[2, 2, 2, 2, 2, 2],
-            bottleneck_channels=64,
-        )
-        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
-
-    def decode(self, *args, **kwargs):
-        return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
-
-
-class AudioDiffusionMAE(DiffusionMAE1d):
-    def __init__(self, *args, **kwargs):
-        default_kwargs = dict(
-            diffusion_type="v",
-            diffusion_sigma_distribution=UniformDistribution(),
-            stft_num_fft=1023,
-            stft_hop_length=256,
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
 
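For reference, a minimal usage sketch of the refactored API. The class name and methods follow the diff above; the package-root import, tensor sizes, and the `num_steps` sampler argument are illustrative assumptions rather than confirmed defaults:

```python
import torch
from audio_diffusion_pytorch import AudioDiffusionAE  # assumed export path

autoencoder = AudioDiffusionAE(in_channels=1)

# Training step: forward() encodes x and returns the diffusion loss.
x = torch.randn(2, 1, 2**18)  # [batch, in_channels, length]
loss = autoencoder(x)
loss.backward()

# Inference: compress to a latent, then decode by sampling from noise
# conditioned on the latent channels.
latent = autoencoder.encode(x)  # [2, 64, 32] with the default encoder
sample = autoencoder.decode(latent, num_steps=20)  # num_steps: assumed sampler kwarg
```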
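A note on the decode-length arithmetic: with the default encoder above, `patch_size * prod(factors) = 16 * (4 * 4 * 4 * 2 * 2 * 2) = 8192`, which is what the removed `encoder_downsample_factor` computed and what `Encoder1d.downsample_factor` is assumed to expose. The UNet's context channels become `[0] * 6 + [64]`, so the 64-channel latent is injected only at depth 6, and a latent of length 32 decodes to `closest_power_2(32 * 8192) = 262144 = 2**18` samples, matching the training length in the sketch above.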