@@ -1364,7 +1364,7 @@ def forward(
         return (x, info) if with_info else x
 
 
-class AutoEncoder1d(nn.Module):
+class Encoder1d(nn.Module):
     def __init__(
         self,
         in_channels: int,
@@ -1375,16 +1375,10 @@ def __init__(
         multipliers: Sequence[int],
         factors: Sequence[int],
         num_blocks: Sequence[int],
-        use_noisy: bool = False,
-        bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
-        use_magnitude_channels: bool = False,
+        out_channels: Optional[int] = None,
     ):
         super().__init__()
         num_layers = len(multipliers) - 1
-        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
-        self.use_noisy = use_noisy
-        self.use_magnitude_channels = use_magnitude_channels
-
         assert len(factors) >= num_layers and len(num_blocks) >= num_layers
 
         self.to_in = Patcher(
@@ -1408,10 +1402,66 @@ def __init__(
             ]
         )
 
+        self.to_out = (
+            nn.Conv1d(
+                in_channels=channels * multipliers[-1],
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+            if exists(out_channels)
+            else nn.Identity()
+        )
+
+    def forward(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        xs = []
+        x = self.to_in(x)
+
+        for downsample in self.downsamples:
+            x = downsample(x)
+            xs += [x]
+
+        x = self.to_out(x)
+
+        info = dict(xs=xs)
+        return (x, info) if with_info else x
+
+
+class Decoder1d(nn.Module):
+    def __init__(
+        self,
+        out_channels: int,
+        channels: int,
+        patch_blocks: int,
+        patch_factor: int,
+        resnet_groups: int,
+        multipliers: Sequence[int],
+        factors: Sequence[int],
+        num_blocks: Sequence[int],
+        use_magnitude_channels: bool = False,
+        in_channels: Optional[int] = None,
+    ):
+        super().__init__()
+        num_layers = len(multipliers) - 1
+        self.use_magnitude_channels = use_magnitude_channels
+
+        assert len(factors) >= num_layers and len(num_blocks) >= num_layers
+
+        self.to_in = (
+            Conv1d(
+                in_channels=in_channels,
+                out_channels=channels * multipliers[-1],
+                kernel_size=1,
+            )
+            if exists(in_channels)
+            else nn.Identity()
+        )
+
         self.upsamples = nn.ModuleList(
             [
                 UpsampleBlock1d(
-                    in_channels=channels * multipliers[i + 1] * (use_noisy + 1),
+                    in_channels=channels * multipliers[i + 1],
                     out_channels=channels * multipliers[i],
                     factor=factors[i],
                     num_groups=resnet_groups,
@@ -1424,12 +1474,73 @@ def __init__(
         )
 
         self.to_out = Unpatcher(
-            in_channels=channels * (use_noisy + 1),
-            out_channels=in_channels * (2 if use_magnitude_channels else 1),
+            in_channels=channels,
+            out_channels=out_channels * (2 if use_magnitude_channels else 1),
             blocks=patch_blocks,
             factor=patch_factor,
         )
 
+    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Any]]:
+        x = self.to_in(x)
+
+        for upsample in self.upsamples:
+            x = upsample(x)
+
+        x = self.to_out(x)
+
+        if self.use_magnitude_channels:
+            x = merge_magnitude_channels(x)
+
+        return x
+
+
+class AutoEncoder1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        channels: int,
+        patch_blocks: int,
+        patch_factor: int,
+        resnet_groups: int,
+        multipliers: Sequence[int],
+        factors: Sequence[int],
+        num_blocks: Sequence[int],
+        use_noisy: bool = False,
+        bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        use_magnitude_channels: bool = False,
+    ):
+        super().__init__()
+        num_layers = len(multipliers) - 1
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+
+        assert len(factors) >= num_layers and len(num_blocks) >= num_layers
+
+        self.encoder = Encoder1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_blocks=patch_blocks,
+            patch_factor=patch_factor,
+            resnet_groups=resnet_groups,
+            multipliers=multipliers,
+            factors=factors,
+            num_blocks=num_blocks,
+            out_channels=bottleneck_channels,
+        )
+
+        self.decoder = Decoder1d(
+            in_channels=bottleneck_channels,
+            out_channels=in_channels,
+            channels=channels,
+            patch_blocks=patch_blocks,
+            patch_factor=patch_factor,
+            resnet_groups=resnet_groups,
+            multipliers=multipliers,
+            factors=factors,
+            num_blocks=num_blocks,
+            use_magnitude_channels=use_magnitude_channels,
+        )
+
     def forward(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
@@ -1440,12 +1551,7 @@ def forward(
     def encode(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        xs = []
-        x = self.to_in(x)
-        for downsample in self.downsamples:
-            x = downsample(x)
-            xs += [x]
-        info = dict(xs=xs)
+        x, info = self.encoder(x, with_info=True)
 
         for bottleneck in self.bottlenecks:
             x, info_bottleneck = bottleneck(x, with_info=True)
@@ -1454,20 +1560,7 @@ def encode(
         return (x, info) if with_info else x
 
     def decode(self, x: Tensor) -> Tensor:
-        for upsample in self.upsamples:
-            if self.use_noisy:
-                x = torch.cat([x, torch.randn_like(x)], dim=1)
-            x = upsample(x)
-
-        if self.use_noisy:
-            x = torch.cat([x, torch.randn_like(x)], dim=1)
-
-        x = self.to_out(x)
-
-        if self.use_magnitude_channels:
-            x = merge_magnitude_channels(x)
-
-        return x
+        return self.decoder(x)
 
 
 class MultiEncoder1d(nn.Module):
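
Note: a minimal usage sketch of the refactored split above. The configuration values below are illustrative assumptions, not defaults from this commit; only the constructor keywords follow the diff.

import torch

# Hypothetical settings; any multipliers/factors/num_blocks work as long as
# len(factors) >= len(multipliers) - 1 and len(num_blocks) >= len(multipliers) - 1.
autoencoder = AutoEncoder1d(
    in_channels=2,               # e.g. a stereo waveform
    channels=32,
    patch_blocks=1,
    patch_factor=2,
    resnet_groups=8,
    multipliers=[1, 2, 4, 8],
    factors=[2, 2, 2],
    num_blocks=[2, 2, 2],
    bottleneck_channels=64,      # new: 1x1 conv projection into/out of the latent
)

x = torch.randn(1, 2, 2**15)     # (batch, in_channels, samples)
z = autoencoder.encode(x)        # Encoder1d, then any configured bottlenecks
y = autoencoder.decode(z)        # Decoder1d back to in_channels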