@@ -19,8 +19,13 @@ def __call__(self, num_samples: int, device: torch.device):
19
19
20
20
21
21
class UniformDistribution (Distribution ):
22
+ def __init__ (self , vmin : float = 0.0 , vmax : float = 1.0 ):
23
+ super ().__init__ ()
24
+ self .vmin , self .vmax = vmin , vmax
25
+
22
26
def __call__ (self , num_samples : int , device : torch .device = torch .device ("cpu" )):
23
- return torch .rand (num_samples , device = device )
27
+ vmax , vmin = self .vmax , self .vmin
28
+ return (vmax - vmin ) * torch .rand (num_samples , device = device ) + vmin
24
29
25
30
26
31
""" Diffusion Methods """
@@ -132,8 +137,12 @@ def forward(self, num_steps: int, device: torch.device) -> Tensor:
132
137
133
138
134
139
class LinearSchedule (Schedule ):
140
+ def __init__ (self , start : float = 1.0 , end : float = 0.0 ):
141
+ super ().__init__ ()
142
+ self .start , self .end = start , end
143
+
135
144
def forward (self , num_steps : int , device : Any ) -> Tensor :
136
- return torch .linspace (1.0 , 0.0 , num_steps , device = device )
145
+ return torch .linspace (self . start , self . end , num_steps , device = device )
137
146
138
147
139
148
""" Samplers """
@@ -158,14 +167,13 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
158
167
return alpha , beta
159
168
160
169
def forward ( # type: ignore
161
- self , noise : Tensor , num_steps : int , show_progress : bool = False , ** kwargs
170
+ self , x_noisy : Tensor , num_steps : int , show_progress : bool = False , ** kwargs
162
171
) -> Tensor :
163
- b = noise .shape [0 ]
164
- sigmas = self .schedule (num_steps + 1 , device = noise .device )
172
+ b = x_noisy .shape [0 ]
173
+ sigmas = self .schedule (num_steps + 1 , device = x_noisy .device )
165
174
sigmas = repeat (sigmas , "i -> i b" , b = b )
166
- sigmas_batch = extend_dim (sigmas , dim = noise .ndim + 1 )
175
+ sigmas_batch = extend_dim (sigmas , dim = x_noisy .ndim + 1 )
167
176
alphas , betas = self .get_alpha_beta (sigmas_batch )
168
- x_noisy = noise * sigmas_batch [0 ]
169
177
progress_bar = tqdm (range (num_steps ), disable = not show_progress )
170
178
171
179
for i in progress_bar :
0 commit comments