Skip to content

Commit 24ff00f

Browse files
feat: encoder1d, decoder1d, autoencoder bottleneck channels
1 parent d48ac1a commit 24ff00f

File tree

3 files changed

+127
-32
lines changed

3 files changed

+127
-32
lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
)
3131
from .modules import (
3232
AutoEncoder1d,
33+
Decoder1d,
34+
Encoder1d,
3335
MultiEncoder1d,
3436
Noiser,
3537
T5Embedder,

audio_diffusion_pytorch/modules.py

Lines changed: 124 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1364,7 +1364,7 @@ def forward(
13641364
return (x, info) if with_info else x
13651365

13661366

1367-
class AutoEncoder1d(nn.Module):
1367+
class Encoder1d(nn.Module):
13681368
def __init__(
13691369
self,
13701370
in_channels: int,
@@ -1375,16 +1375,10 @@ def __init__(
13751375
multipliers: Sequence[int],
13761376
factors: Sequence[int],
13771377
num_blocks: Sequence[int],
1378-
use_noisy: bool = False,
1379-
bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
1380-
use_magnitude_channels: bool = False,
1378+
out_channels: Optional[int] = None,
13811379
):
13821380
super().__init__()
13831381
num_layers = len(multipliers) - 1
1384-
self.bottlenecks = nn.ModuleList(to_list(bottleneck))
1385-
self.use_noisy = use_noisy
1386-
self.use_magnitude_channels = use_magnitude_channels
1387-
13881382
assert len(factors) >= num_layers and len(num_blocks) >= num_layers
13891383

13901384
self.to_in = Patcher(
@@ -1408,10 +1402,66 @@ def __init__(
14081402
]
14091403
)
14101404

1405+
self.to_out = (
1406+
nn.Conv1d(
1407+
in_channels=channels * multipliers[-1],
1408+
out_channels=out_channels,
1409+
kernel_size=1,
1410+
)
1411+
if exists(out_channels)
1412+
else nn.Identity()
1413+
)
1414+
1415+
def forward(
1416+
self, x: Tensor, with_info: bool = False
1417+
) -> Union[Tensor, Tuple[Tensor, Any]]:
1418+
xs = []
1419+
x = self.to_in(x)
1420+
1421+
for downsample in self.downsamples:
1422+
x = downsample(x)
1423+
xs += [x]
1424+
1425+
x = self.to_out(x)
1426+
1427+
info = dict(xs=xs)
1428+
return (x, info) if with_info else x
1429+
1430+
1431+
class Decoder1d(nn.Module):
1432+
def __init__(
1433+
self,
1434+
out_channels: int,
1435+
channels: int,
1436+
patch_blocks: int,
1437+
patch_factor: int,
1438+
resnet_groups: int,
1439+
multipliers: Sequence[int],
1440+
factors: Sequence[int],
1441+
num_blocks: Sequence[int],
1442+
use_magnitude_channels: bool = False,
1443+
in_channels: Optional[int] = None,
1444+
):
1445+
super().__init__()
1446+
num_layers = len(multipliers) - 1
1447+
self.use_magnitude_channels = use_magnitude_channels
1448+
1449+
assert len(factors) >= num_layers and len(num_blocks) >= num_layers
1450+
1451+
self.to_in = (
1452+
Conv1d(
1453+
in_channels=in_channels,
1454+
out_channels=channels * multipliers[-1],
1455+
kernel_size=1,
1456+
)
1457+
if exists(in_channels)
1458+
else nn.Identity()
1459+
)
1460+
14111461
self.upsamples = nn.ModuleList(
14121462
[
14131463
UpsampleBlock1d(
1414-
in_channels=channels * multipliers[i + 1] * (use_noisy + 1),
1464+
in_channels=channels * multipliers[i + 1],
14151465
out_channels=channels * multipliers[i],
14161466
factor=factors[i],
14171467
num_groups=resnet_groups,
@@ -1424,12 +1474,73 @@ def __init__(
14241474
)
14251475

14261476
self.to_out = Unpatcher(
1427-
in_channels=channels * (use_noisy + 1),
1428-
out_channels=in_channels * (2 if use_magnitude_channels else 1),
1477+
in_channels=channels,
1478+
out_channels=out_channels * (2 if use_magnitude_channels else 1),
14291479
blocks=patch_blocks,
14301480
factor=patch_factor,
14311481
)
14321482

1483+
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Any]]:
1484+
x = self.to_in(x)
1485+
1486+
for upsample in self.upsamples:
1487+
x = upsample(x)
1488+
1489+
x = self.to_out(x)
1490+
1491+
if self.use_magnitude_channels:
1492+
x = merge_magnitude_channels(x)
1493+
1494+
return x
1495+
1496+
1497+
class AutoEncoder1d(nn.Module):
1498+
def __init__(
1499+
self,
1500+
in_channels: int,
1501+
channels: int,
1502+
patch_blocks: int,
1503+
patch_factor: int,
1504+
resnet_groups: int,
1505+
multipliers: Sequence[int],
1506+
factors: Sequence[int],
1507+
num_blocks: Sequence[int],
1508+
use_noisy: bool = False,
1509+
bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
1510+
bottleneck_channels: Optional[int] = None,
1511+
use_magnitude_channels: bool = False,
1512+
):
1513+
super().__init__()
1514+
num_layers = len(multipliers) - 1
1515+
self.bottlenecks = nn.ModuleList(to_list(bottleneck))
1516+
1517+
assert len(factors) >= num_layers and len(num_blocks) >= num_layers
1518+
1519+
self.encoder = Encoder1d(
1520+
in_channels=in_channels,
1521+
channels=channels,
1522+
patch_blocks=patch_blocks,
1523+
patch_factor=patch_factor,
1524+
resnet_groups=resnet_groups,
1525+
multipliers=multipliers,
1526+
factors=factors,
1527+
num_blocks=num_blocks,
1528+
out_channels=bottleneck_channels,
1529+
)
1530+
1531+
self.decoder = Decoder1d(
1532+
in_channels=bottleneck_channels,
1533+
out_channels=in_channels,
1534+
channels=channels,
1535+
patch_blocks=patch_blocks,
1536+
patch_factor=patch_factor,
1537+
resnet_groups=resnet_groups,
1538+
multipliers=multipliers,
1539+
factors=factors,
1540+
num_blocks=num_blocks,
1541+
use_magnitude_channels=use_magnitude_channels,
1542+
)
1543+
14331544
def forward(
14341545
self, x: Tensor, with_info: bool = False
14351546
) -> Union[Tensor, Tuple[Tensor, Any]]:
@@ -1440,12 +1551,7 @@ def forward(
14401551
def encode(
14411552
self, x: Tensor, with_info: bool = False
14421553
) -> Union[Tensor, Tuple[Tensor, Any]]:
1443-
xs = []
1444-
x = self.to_in(x)
1445-
for downsample in self.downsamples:
1446-
x = downsample(x)
1447-
xs += [x]
1448-
info = dict(xs=xs)
1554+
x, info = self.encoder(x, with_info=True)
14491555

14501556
for bottleneck in self.bottlenecks:
14511557
x, info_bottleneck = bottleneck(x, with_info=True)
@@ -1454,20 +1560,7 @@ def encode(
14541560
return (x, info) if with_info else x
14551561

14561562
def decode(self, x: Tensor) -> Tensor:
1457-
for upsample in self.upsamples:
1458-
if self.use_noisy:
1459-
x = torch.cat([x, torch.randn_like(x)], dim=1)
1460-
x = upsample(x)
1461-
1462-
if self.use_noisy:
1463-
x = torch.cat([x, torch.randn_like(x)], dim=1)
1464-
1465-
x = self.to_out(x)
1466-
1467-
if self.use_magnitude_channels:
1468-
x = merge_magnitude_channels(x)
1469-
1470-
return x
1563+
return self.decoder(x)
14711564

14721565

14731566
class MultiEncoder1d(nn.Module):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.73",
6+
version="0.0.74",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)