Skip to content

Commit 37903b2

Browse files
feat: option to norm inputs with mu-law
1 parent ed9d77c commit 37903b2

File tree

3 files changed

+36
-4
lines changed

3 files changed

+36
-4
lines changed

Diff for: audio_diffusion_pytorch/modules.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from einops_exts import rearrange_many
99
from torch import Tensor, einsum
1010

11-
from .utils import default, exists, prod
11+
from .utils import default, exists, prod, wave_norm, wave_unnorm
1212

1313
"""
1414
Utils
@@ -809,6 +809,7 @@ def __init__(
809809
use_nearest_upsample: bool,
810810
use_skip_scale: bool,
811811
use_context_time: bool,
812+
norm: float = 0.0,
812813
out_channels: Optional[int] = None,
813814
context_features: Optional[int] = None,
814815
context_channels: Optional[Sequence[int]] = None,
@@ -822,6 +823,8 @@ def __init__(
822823
use_context_channels = len(context_channels) > 0
823824
context_mapping_features = None
824825

826+
self.norm = norm
827+
self.use_norm = norm > 0.0
825828
self.num_layers = num_layers
826829
self.use_context_time = use_context_time
827830
self.use_context_features = use_context_features
@@ -997,9 +1000,11 @@ def forward(
9971000
# Concat context channels at layer 0 if provided
9981001
channels = self.get_channels(channels_list, layer=0)
9991002
x = torch.cat([x, channels], dim=1) if exists(channels) else x
1000-
10011003
mapping = self.get_mapping(time, features)
10021004

1005+
if self.use_norm:
1006+
x = wave_norm(x, peak=self.norm)
1007+
10031008
x = self.to_in(x, mapping)
10041009
skips_list = [x]
10051010

@@ -1019,6 +1024,9 @@ def forward(
10191024
x += skips_list.pop()
10201025
x = self.to_out(x, mapping)
10211026

1027+
if self.use_norm:
1028+
x = wave_unnorm(x, peak=self.norm)
1029+
10221030
return x
10231031

10241032

@@ -1120,11 +1128,14 @@ def __init__(
11201128
num_blocks: Sequence[int],
11211129
use_noisy: bool = False,
11221130
bottleneck: Optional[Bottleneck] = None,
1131+
norm: float = 0.0,
11231132
):
11241133
super().__init__()
11251134
num_layers = len(multipliers) - 1
11261135
self.bottleneck = bottleneck
11271136
self.use_noisy = use_noisy
1137+
self.use_norm = norm > 0.0
1138+
self.norm = norm
11281139

11291140
assert len(factors) >= num_layers and len(num_blocks) >= num_layers
11301141

@@ -1174,6 +1185,9 @@ def __init__(
11741185
def encode(
11751186
self, x: Tensor, with_info: bool = False
11761187
) -> Union[Tensor, Tuple[Tensor, Any]]:
1188+
if self.use_norm:
1189+
x = wave_norm(x, peak=self.norm)
1190+
11771191
x = self.to_in(x)
11781192
for downsample in self.downsamples:
11791193
x = downsample(x)
@@ -1190,7 +1204,12 @@ def decode(self, x: Tensor) -> Tensor:
11901204
x = upsample(x)
11911205
if self.use_noisy:
11921206
x = torch.cat([x, torch.randn_like(x)], dim=1)
1193-
return self.to_out(x)
1207+
x = self.to_out(x)
1208+
1209+
if self.use_norm:
1210+
x = wave_unnorm(x, peak=self.norm)
1211+
1212+
return x
11941213

11951214

11961215
class MultiEncoder1d(nn.Module):

Diff for: audio_diffusion_pytorch/utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,16 @@ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
8383

8484
def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    """Resample `waveforms` with an input factor of 1 and an output factor of `factor`.

    Thin wrapper that forwards to `resample`; any extra keyword arguments
    are passed through unchanged.
    """
    resampled = resample(waveforms, factor_in=1, factor_out=factor, **kwargs)
    return resampled
86+
87+
88+
def wave_norm(x: Tensor, bits: int = 24, peak: float = 0.5) -> Tensor:
    """Compress `x` with a mu-law style logarithmic curve and scale it by `peak`.

    Args:
        x: Tensor to compress; the transform is applied element-wise and
            preserves the sign of every element.
        bits: Exponent of the companding constant, ``mu = 2 ** bits``.
        peak: Factor applied to the companded values, so an input of
            magnitude 1 maps to magnitude `peak`.

    Returns:
        ``sign(x) * log1p(mu * |x|) / log1p(mu) * peak``.
    """
    levels = 2 ** bits
    # Logarithmic compression of the magnitude; the sign is re-applied below.
    magnitude = torch.log1p(levels * torch.abs(x))
    companded = torch.sign(x) * magnitude / math.log1p(levels)
    return companded * peak
92+
93+
94+
def wave_unnorm(x: Tensor, bits: int = 24, peak: float = 0.5) -> Tensor:
    """Expand a mu-law style companded tensor back to its linear scale.

    Applies ``sign(y) * (exp(|y| * log1p(mu)) - 1) / mu`` with
    ``y = x / peak`` and ``mu = 2 ** bits``, i.e. the algebraic inverse of
    ``sign(x) * log1p(mu * |x|) / log1p(mu) * peak``.

    Args:
        x: Companded tensor; the transform is applied element-wise and
            preserves the sign of every element.
        bits: Exponent of the companding constant, ``mu = 2 ** bits``.
        peak: Scale that was applied after companding; `x` is divided by it
            first, so it must be non-zero.

    Returns:
        The expanded tensor, same shape as `x`.
    """
    x = x / peak
    mu = 2 ** bits
    # expm1 instead of exp(t) - 1: for small |x| the exponential rounds to
    # exactly 1.0 in float32, so the subtraction would zero out quiet samples.
    x = torch.sign(x) * torch.expm1(torch.abs(x) * math.log1p(mu)) / mu
    return x

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.52",
6+
version="0.0.53",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)