@@ -8,7 +8,7 @@
 from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
-from .utils import default, exists, prod
+from .utils import default, exists, prod, wave_norm, wave_unnorm
 
 """
 Utils
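
`wave_norm` and `wave_unnorm` are imported here, but their implementations live in `.utils` and are not part of this diff. A minimal sketch of a stateless, mutually inverse pair that would be consistent with the call sites below (squash the waveform toward a target peak on the way in, undo it on the way out) — an assumption, not the repository's actual code:

import torch
from torch import Tensor

def wave_norm(x: Tensor, peak: float = 0.5) -> Tensor:
    # HYPOTHETICAL: compand the waveform into (-peak, peak) with a tanh squash.
    return torch.tanh(x) * peak

def wave_unnorm(x: Tensor, peak: float = 0.5) -> Tensor:
    # HYPOTHETICAL inverse: undo the peak scaling, then invert the tanh.
    # The clamp keeps atanh finite at the interval edges.
    eps = 1e-6
    return torch.atanh(torch.clamp(x / peak, -1.0 + eps, 1.0 - eps))

Whatever the real pair does, the call sites only require that wave_unnorm(wave_norm(x, peak=p), peak=p) ≈ x.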
@@ -809,6 +809,7 @@ def __init__(
         use_nearest_upsample: bool,
         use_skip_scale: bool,
         use_context_time: bool,
+        norm: float = 0.0,
         out_channels: Optional[int] = None,
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
@@ -822,6 +823,8 @@ def __init__(
         use_context_channels = len(context_channels) > 0
         context_mapping_features = None
 
+        self.norm = norm
+        self.use_norm = norm > 0.0
         self.num_layers = num_layers
         self.use_context_time = use_context_time
         self.use_context_features = use_context_features
@@ -997,9 +1000,11 @@ def forward(
         # Concat context channels at layer 0 if provided
         channels = self.get_channels(channels_list, layer=0)
         x = torch.cat([x, channels], dim=1) if exists(channels) else x
-
         mapping = self.get_mapping(time, features)
 
+        if self.use_norm:
+            x = wave_norm(x, peak=self.norm)
+
         x = self.to_in(x, mapping)
         skips_list = [x]
 
@@ -1019,6 +1024,9 @@ def forward(
         x += skips_list.pop()
         x = self.to_out(x, mapping)
 
+        if self.use_norm:
+            x = wave_unnorm(x, peak=self.norm)
+
         return x
 
 
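The two insertions in `forward` mirror each other: when `norm > 0.0`, the input passes through `wave_norm` before the first projection `to_in`, and the output through `wave_unnorm` after the final `to_out`, so every layer in between only ever sees waveforms at the normalized peak. The autoencoder below applies the same pattern, split across `encode` and `decode`.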
@@ -1120,11 +1128,14 @@ def __init__(
         num_blocks: Sequence[int],
         use_noisy: bool = False,
         bottleneck: Optional[Bottleneck] = None,
+        norm: float = 0.0,
     ):
         super().__init__()
         num_layers = len(multipliers) - 1
         self.bottleneck = bottleneck
         self.use_noisy = use_noisy
+        self.use_norm = norm > 0.0
+        self.norm = norm
 
         assert len(factors) >= num_layers and len(num_blocks) >= num_layers
 
@@ -1174,6 +1185,9 @@ def __init__(
     def encode(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        if self.use_norm:
+            x = wave_norm(x, peak=self.norm)
+
         x = self.to_in(x)
         for downsample in self.downsamples:
             x = downsample(x)
@@ -1190,7 +1204,12 @@ def decode(self, x: Tensor) -> Tensor:
             x = upsample(x)
             if self.use_noisy:
                 x = torch.cat([x, torch.randn_like(x)], dim=1)
-        return self.to_out(x)
+        x = self.to_out(x)
+
+        if self.use_norm:
+            x = wave_unnorm(x, peak=self.norm)
+
+        return x
 
 
 class MultiEncoder1d(nn.Module):
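
Since `norm` defaults to `0.0` and `use_norm = norm > 0.0`, both classes behave exactly as before unless a positive peak is passed (e.g. `norm=0.5`). Under the hypothetical tanh sketch above, the pair round-trips within floating-point tolerance for in-range audio, which is what the symmetric normalize/denormalize placement relies on:

# Round-trip check for the HYPOTHETICAL helpers sketched earlier.
x = torch.rand(1, 2, 1024) * 2 - 1  # dummy stereo waveform in [-1, 1]
y = wave_unnorm(wave_norm(x, peak=0.5), peak=0.5)
assert torch.allclose(x, y, atol=1e-4)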