 from abc import ABC, abstractmethod
-from typing import Union, Literal
+from typing import Literal

 from keras import ops

@@ -33,7 +33,7 @@ def __init__(
         weighting: Literal["sigmoid", "likelihood_weighting"] = None,
     ):
         """
-        Initialize the noise schedule.
+        Initialize the noise schedule with the given variance and weighting strategy.

         Parameters
         ----------
@@ -54,21 +54,23 @@ def __init__(
         self._weighting = weighting

     @abstractmethod
-    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
         pass

     @abstractmethod
-    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
         """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
         pass

     @abstractmethod
-    def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
+    def derivative_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
         r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
         pass

-    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
+    def get_drift_diffusion(
+        self, log_snr_t: Tensor, x: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor]:
         r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
         It can be derived from the derivative of the schedule:

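Note on the widened return type: as the likelihood-weighting hunk further down shows, the method yields only the squared diffusion term when `x` is omitted, and the (drift, squared diffusion) pair when `x` is given. A minimal caller sketch, where `schedule` is a hypothetical concrete instance of this ABC:

```python
# Sketch only; `schedule`, `log_snr_t`, and `x` are placeholders, not part of this PR.

# With x given: drift f(x, t) and squared diffusion beta(t) for the reverse SDE.
f, beta = schedule.get_drift_diffusion(log_snr_t, x=x, training=False)

# Without x: only the squared diffusion term g^2(t), as consumed by the
# likelihood weighting later in this diff.
g_squared = schedule.get_drift_diffusion(log_snr_t)
```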
@@ -97,10 +99,10 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
             raise ValueError(f"Unknown variance type: {self._variance_type}")
         return f, beta

-    def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
+    def get_alpha_sigma(self, log_snr_t: Tensor) -> tuple[Tensor, Tensor]:
         """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

-        Default is a variance preserving schedule::
+        Default is a variance preserving schedule:

             alpha(t) = sqrt(sigmoid(log_snr_t))
             sigma(t) = sqrt(sigmoid(-log_snr_t))
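The two formulas above imply the variance-preserving identity alpha(t)^2 + sigma(t)^2 = 1, since sigmoid(x) + sigmoid(-x) = 1 for any x. A quick standalone check with `keras.ops`:

```python
from keras import ops

# Variance-preserving identity: alpha(t)^2 + sigma(t)^2 = 1,
# because sigmoid(x) + sigmoid(-x) = 1.
log_snr_t = ops.convert_to_tensor([-4.0, 0.0, 4.0])
alpha_t = ops.sqrt(ops.sigmoid(log_snr_t))
sigma_t = ops.sqrt(ops.sigmoid(-log_snr_t))
print(ops.square(alpha_t) + ops.square(sigma_t))  # ~[1.0, 1.0, 1.0]
```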
@@ -120,9 +122,32 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Te
         return alpha_t, sigma_t

     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
-        """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
-        Default weighting is None, which means only ones are returned.
-        Generally, weighting functions should be defined for a noise prediction loss.
+        """
+        Compute loss weights based on the log signal-to-noise ratio (log-SNR).
+
+        This method returns a tensor of weights used for loss re-weighting in diffusion models,
+        depending on the selected strategy. If no weighting is specified, uniform weights (ones)
+        are returned.
+
+        Supported weighting strategies:
+        - "sigmoid": Based on Kingma et al. (2023), uses a sigmoid of the shifted log-SNR.
+        - "likelihood_weighting": Based on Song et al. (2021), uses the ratio of the squared
+          diffusion term to the squared noise scale.
+
+        Parameters
+        ----------
+        log_snr_t : Tensor
+            A tensor containing the log signal-to-noise ratio values.
+
+        Returns
+        -------
+        Tensor
+            A tensor of weights corresponding to each log-SNR value.
+
+        Raises
+        ------
+        TypeError
+            If the weighting strategy specified in `self._weighting` is unknown.
         """
         if self._weighting is None:
             return ops.ones_like(log_snr_t)
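To make the "sigmoid" strategy from the new docstring concrete: with the shift of 2 used in the hunk below, high log-SNR (nearly clean) timesteps receive small weights and very noisy ones receive weights near 1. A standalone sketch:

```python
from keras import ops

# w(lambda) = sigmoid(-lambda + 2), the shifted-sigmoid weighting
# attributed to Kingma et al. (2023) in the hunk below.
log_snr = ops.convert_to_tensor([-6.0, 0.0, 2.0, 6.0])
weights = ops.sigmoid(-log_snr + 2)
print(weights)  # decreasing in log-SNR: ~[1.00, 0.88, 0.50, 0.02]
```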
@@ -131,33 +156,37 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
             return ops.sigmoid(-log_snr_t + 2)
         elif self._weighting == "likelihood_weighting":
             # likelihood weighting based on Song et al. (2021)
-            g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
-            sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
+            g_squared = self.get_drift_diffusion(log_snr_t)
+            _, sigma_t = self.get_alpha_sigma(log_snr_t)
             return g_squared / ops.square(sigma_t)
         else:
             raise TypeError(f"Unknown weighting type: {self._weighting}")

     def get_config(self):
-        return dict(name=self.name, variance_type=self._variance_type, weighting=self._weighting)
+        return {"name": self.name, "variance_type": self._variance_type, "weighting": self._weighting}

     @classmethod
     def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))

     def validate(self):
         """Validate the noise schedule."""
+
         if self.log_snr_min >= self.log_snr_max:
             raise ValueError("min_log_snr must be less than max_log_snr.")
-        for training in [True, False]:
+
+        # Validate log SNR values and corresponding time mappings for both training and inference
+        for training in (True, False):
             if not ops.isfinite(self.get_log_snr(0.0, training=training)):
-                raise ValueError(f"log_snr(0) must be finite with training={training}.")
+                raise ValueError(f"log_snr(0.0) must be finite (training={training})")
             if not ops.isfinite(self.get_log_snr(1.0, training=training)):
-                raise ValueError(f"log_snr(1) must be finite with training={training}.")
+                raise ValueError(f"log_snr(1.0) must be finite (training={training})")
             if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
-                raise ValueError(f"t(0) must be finite with training={training}.")
+                raise ValueError(f"t(log_snr_max) must be finite (training={training})")
             if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
-                raise ValueError(f"t(1) must be finite with training={training}.")
-        if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)):
-            raise ValueError("dt/t log_snr(0) must be finite.")
-        if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)):
-            raise ValueError("dt/t log_snr(1) must be finite.")
+                raise ValueError(f"t(log_snr_min) must be finite (training={training})")
+
+        # Validate log SNR derivatives at the boundaries
+        for boundary, name in [(self.log_snr_max, "log_snr_max (t=0)"), (self.log_snr_min, "log_snr_min (t=1)")]:
+            if not ops.isfinite(self.derivative_log_snr(boundary, training=False)):
+                raise ValueError(f"derivative_log_snr at {name} must be finite.")
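To exercise the new `validate()` path end to end, here is a minimal hypothetical subclass whose log-SNR is linear in t. The class name, the base-class name `NoiseSchedule`, and the constructor passthrough are assumptions for illustration, not part of this PR:

```python
from keras import ops

class LinearLogSNRSchedule(NoiseSchedule):  # assumed name of the ABC in this file
    def __init__(self, log_snr_min=-10.0, log_snr_max=10.0, **kwargs):
        super().__init__(**kwargs)  # assumes the base accepts its kwargs unchanged
        self.log_snr_min = log_snr_min
        self.log_snr_max = log_snr_max

    def get_log_snr(self, t, training):
        # lambda(t) interpolates linearly from lambda_max (t=0) to lambda_min (t=1)
        return self.log_snr_max + t * (self.log_snr_min - self.log_snr_max)

    def get_t_from_log_snr(self, log_snr_t, training):
        # inverse of the linear map above
        return (log_snr_t - self.log_snr_max) / (self.log_snr_min - self.log_snr_max)

    def derivative_log_snr(self, log_snr_t, training):
        # beta(t) = d/dt log(1 + e^(-lambda(t))) = (lambda_max - lambda_min) * sigmoid(-lambda(t))
        # for the linear lambda(t) above, since dlambda/dt = lambda_min - lambda_max
        return (self.log_snr_max - self.log_snr_min) * ops.sigmoid(-log_snr_t)

schedule = LinearLogSNRSchedule()
schedule.validate()  # runs all the finiteness checks introduced in the hunk above
```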