Commit a34014f

feat: add parameters linear schedule, uniform distribution
1 parent 7517b9f · commit a34014f
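
This commit makes the noise-level distribution and the sampling schedule configurable. A minimal usage sketch of the new parameters (the import path is taken from the file changed below; the numeric values are arbitrary, and forward() is called explicitly since the Schedule base class is not shown in this diff):

    import torch
    from audio_diffusion_pytorch.diffusion import LinearSchedule, UniformDistribution

    # Draw noise levels from [0.2, 0.8) instead of the previous fixed [0.0, 1.0).
    distribution = UniformDistribution(vmin=0.2, vmax=0.8)
    sigmas = distribution(num_samples=4, device=torch.device("cpu"))

    # Build a sampling schedule that starts below 1.0 and stops above 0.0;
    # the defaults (start=1.0, end=0.0) reproduce the old hard-coded behaviour.
    schedule = LinearSchedule(start=0.8, end=0.1)
    steps = schedule.forward(num_steps=5, device=torch.device("cpu"))

    print(sigmas)  # four values in [0.2, 0.8)
    print(steps)   # tensor([0.8000, 0.6250, 0.4500, 0.2750, 0.1000])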

File tree: 2 files changed, +16 -8 lines changed


audio_diffusion_pytorch/diffusion.py  (+15 -7)

@@ -19,8 +19,13 @@ def __call__(self, num_samples: int, device: torch.device):

 class UniformDistribution(Distribution):
+    def __init__(self, vmin: float = 0.0, vmax: float = 1.0):
+        super().__init__()
+        self.vmin, self.vmax = vmin, vmax
+
     def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
-        return torch.rand(num_samples, device=device)
+        vmax, vmin = self.vmax, self.vmin
+        return (vmax - vmin) * torch.rand(num_samples, device=device) + vmin

 """ Diffusion Methods """
@@ -132,8 +137,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:

 class LinearSchedule(Schedule):
+    def __init__(self, start: float = 1.0, end: float = 0.0):
+        super().__init__()
+        self.start, self.end = start, end
+
     def forward(self, num_steps: int, device: Any) -> Tensor:
-        return torch.linspace(1.0, 0.0, num_steps, device=device)
+        return torch.linspace(self.start, self.end, num_steps, device=device)

 """ Samplers """
@@ -158,14 +167,13 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
         return alpha, beta

     def forward( # type: ignore
-        self, noise: Tensor, num_steps: int, show_progress: bool = False, **kwargs
+        self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
     ) -> Tensor:
-        b = noise.shape[0]
-        sigmas = self.schedule(num_steps + 1, device=noise.device)
+        b = x_noisy.shape[0]
+        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
         sigmas = repeat(sigmas, "i -> i b", b=b)
-        sigmas_batch = extend_dim(sigmas, dim=noise.ndim + 1)
+        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
         alphas, betas = self.get_alpha_beta(sigmas_batch)
-        x_noisy = noise * sigmas_batch[0]
         progress_bar = tqdm(range(num_steps), disable=not show_progress)

         for i in progress_bar:
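
Note the behavioural change in the sampler's forward: the line x_noisy = noise * sigmas_batch[0] is removed, so the caller now passes an already-noised tensor rather than pure noise that gets scaled internally. A hedged sketch of the new calling convention (the sampler instance is hypothetical, since the commit does not show which class owns this forward):

    import torch
    from audio_diffusion_pytorch.diffusion import LinearSchedule

    schedule = LinearSchedule(start=0.8, end=0.1)
    # Mirror what the sampler previously did internally: take the first sigma of
    # the (num_steps + 1)-point schedule and scale pure noise by it.
    sigma_start = schedule.forward(num_steps=50 + 1, device=torch.device("cpu"))[0]

    noise = torch.randn(1, 2, 2 ** 15)  # arbitrary (batch, channels, samples) shape
    x_noisy = noise * sigma_start

    # Hypothetical call; `sampler` stands for an instance of the class edited above.
    # x = sampler(x_noisy, num_steps=50, show_progress=True)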

setup.py  (+1 -1)

@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.1.0",
+    version="0.1.1",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
