
Commit 38186ec

Align diffusion model with other inference networks and remove deprecation warnings (#489)
* Align dm implementation with other networks
* Remove deprecation warning for using subnet_kwargs
* Fix tests
* Remove redundant training arg in get_alpha_sigma and some redundant comments
* Fix configs creation - do not get base config due to fixed call of super().__init__()
* Remove redundant training arg from tests
* Fix dispatch tests for dms
* Improve docs and mark option for x prediction in literal
* Fix start/stop time
* minor cleanup of refactory

Co-authored-by: Valentin Pratz <[email protected]>
1 parent 361fa45 commit 38186ec

File tree: 13 files changed, +239 / -191 lines


bayesflow/experimental/diffusion_model/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
 from .diffusion_model import DiffusionModel
-from .noise_schedule import NoiseSchedule
-from .cosine_noise_schedule import CosineNoiseSchedule
-from .edm_noise_schedule import EDMNoiseSchedule
+from bayesflow.experimental.diffusion_model.schedules import CosineNoiseSchedule
+from bayesflow.experimental.diffusion_model.schedules import EDMNoiseSchedule
+from bayesflow.experimental.diffusion_model.schedules import NoiseSchedule
 from .dispatch import find_noise_schedule

 from ...utils._docs import _add_imports_to_all
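
For orientation, a minimal usage sketch of the reorganized imports; it assumes only the re-exports shown in this `__init__.py` diff (the concrete schedule classes now live in the new `schedules/` subpackage, but the top-level names remain importable):

```python
# Minimal sketch based on the re-exports above; requires bayesflow to be installed.
from bayesflow.experimental.diffusion_model import (
    CosineNoiseSchedule,
    EDMNoiseSchedule,
    NoiseSchedule,
)

schedule = CosineNoiseSchedule()  # default arguments, per the constructors shown below
assert isinstance(schedule, NoiseSchedule)
assert isinstance(EDMNoiseSchedule(), NoiseSchedule)
```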

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 139 additions & 76 deletions
Large diffs are not rendered by default.

bayesflow/experimental/diffusion_model/dispatch.py

Lines changed: 6 additions & 22 deletions

@@ -1,5 +1,6 @@
 from functools import singledispatch
-from .noise_schedule import NoiseSchedule
+
+from .schedules.noise_schedule import NoiseSchedule


 @singledispatch
@@ -16,34 +17,17 @@ def _(noise_schedule: NoiseSchedule):
 def _(name: str, *args, **kwargs):
     match name.lower():
         case "cosine":
-            from .cosine_noise_schedule import CosineNoiseSchedule
+            from .schedules import CosineNoiseSchedule

-            return CosineNoiseSchedule()
+            return CosineNoiseSchedule(*args, **kwargs)
         case "edm":
-            from .edm_noise_schedule import EDMNoiseSchedule
+            from .schedules import EDMNoiseSchedule

-            return EDMNoiseSchedule()
+            return EDMNoiseSchedule(*args, **kwargs)
         case other:
             raise ValueError(f"Unsupported noise schedule name: '{other}'.")


-@find_noise_schedule.register
-def _(config: dict, *args, **kwargs):
-    name = config.get("name", "").lower()
-    params = {k: v for k, v in config.items() if k != "name"}
-    match name:
-        case "cosine":
-            from .cosine_noise_schedule import CosineNoiseSchedule
-
-            return CosineNoiseSchedule(**params)
-        case "edm":
-            from .edm_noise_schedule import EDMNoiseSchedule
-
-            return EDMNoiseSchedule(**params)
-        case other:
-            raise ValueError(f"Unsupported noise schedule config: '{other}'.")
-
-
 @find_noise_schedule.register
 def _(cls: type, *args, **kwargs):
     if issubclass(cls, NoiseSchedule):
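
To make the behavioral change concrete, a hedged sketch of how the resolver might be called after this commit. Names follow the diff; the constructor arguments are illustrative, and the instance-dispatch behavior (returning the schedule unchanged) is assumed from a registration whose body is not shown here:

```python
# Hedged sketch of the singledispatch-based resolver above (arguments are illustrative).
from bayesflow.experimental.diffusion_model import EDMNoiseSchedule, find_noise_schedule

# By name: *args/**kwargs are now forwarded to the schedule constructor.
schedule = find_noise_schedule("edm", sigma_data=0.5)

# By instance: an existing NoiseSchedule is presumably returned as-is
# (registration `def _(noise_schedule: NoiseSchedule)` above; body not shown in this diff).
schedule = find_noise_schedule(EDMNoiseSchedule(sigma_data=0.5))

# Note: the dict-based registration (e.g. {"name": "cosine", ...}) was removed in this commit.
```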

bayesflow/experimental/diffusion_model/schedules/__init__.py (new file)

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+from .noise_schedule import NoiseSchedule
+from .cosine_noise_schedule import CosineNoiseSchedule
+from .edm_noise_schedule import EDMNoiseSchedule

bayesflow/experimental/diffusion_model/cosine_noise_schedule.py renamed to bayesflow/experimental/diffusion_model/schedules/cosine_noise_schedule.py

Lines changed: 18 additions & 7 deletions

@@ -1,5 +1,5 @@
 import math
-from typing import Union, Literal
+from typing import Literal

 from keras import ops

@@ -14,7 +14,14 @@
 class CosineNoiseSchedule(NoiseSchedule):
     """Cosine noise schedule for diffusion models. This schedule is based on the cosine schedule from [1].

-    [1] Diffusion Models Beat GANs on Image Synthesis: Dhariwal and Nichol (2022)
+    A cosine schedule is a popular technique for controlling how the variance (noise level) or
+    learning rate evolves during the training of diffusion models. It was proposed as an improvement
+    over the original linear beta schedule in [2].
+
+    [1] Dhariwal, P., & Nichol, A. (2021). Diffusion models beat gans on image synthesis.
+        Advances in Neural Information Processing Systems, 34, 8780-8794.
+    [2] Ho, J., Jain, A., & Abbeel, P. (2020). Denoising diffusion probabilistic models.
+        Advances in Neural Information Processing Systems, 33, 6840-6851.
     """

     def __init__(
@@ -51,12 +58,12 @@ def __init__(
     def _truncated_t(self, t: Tensor) -> Tensor:
         return self._t_min + (self._t_max - self._t_min) * t

-    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_log_snr(self, t: Tensor | float, training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
         t_trunc = self._truncated_t(t)
         return -2 * ops.log(ops.tan(math.pi * t_trunc * 0.5)) + 2 * self._shift

-    def get_t_from_log_snr(self, log_snr_t: Union[Tensor, float], training: bool) -> Tensor:
+    def get_t_from_log_snr(self, log_snr_t: Tensor | float, training: bool) -> Tensor:
         """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
         # SNR = -2 * log(tan(pi*t/2)) => t = 2/pi * arctan(exp(-snr/2))
         return 2 / math.pi * ops.arctan(ops.exp((2 * self._shift - log_snr_t) * 0.5))
@@ -76,9 +83,13 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
         return -factor * dsnr_dt

     def get_config(self):
-        return dict(
-            min_log_snr=self.log_snr_min, max_log_snr=self.log_snr_max, shift=self._shift, weighting=self._weighting
-        )
+        config = {
+            "min_log_snr": self.log_snr_min,
+            "max_log_snr": self.log_snr_max,
+            "shift": self._shift,
+            "weighting": self._weighting,
+        }
+        return config

     @classmethod
     def from_config(cls, config, custom_objects=None):
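
As a sanity check on the formulas above, a standalone sketch (pure `math`, no BayesFlow dependency) that mirrors the shifted-cosine log-SNR and its inverse from `get_log_snr` / `get_t_from_log_snr` and verifies they round-trip; the truncation of t to `[t_min, t_max]` is omitted here:

```python
import math

SHIFT = 0.0  # illustrative value of the schedule's shift parameter

def log_snr(t: float) -> float:
    # lambda(t) = -2 * log(tan(pi * t / 2)) + 2 * shift, as in get_log_snr above
    return -2.0 * math.log(math.tan(math.pi * t * 0.5)) + 2.0 * SHIFT

def t_from_log_snr(lam: float) -> float:
    # inverse: t = 2/pi * arctan(exp((2 * shift - lambda) / 2)), as in get_t_from_log_snr above
    return 2.0 / math.pi * math.atan(math.exp((2.0 * SHIFT - lam) * 0.5))

for t in (0.1, 0.5, 0.9):
    assert abs(t_from_log_snr(log_snr(t)) - t) < 1e-12
```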

bayesflow/experimental/diffusion_model/edm_noise_schedule.py renamed to bayesflow/experimental/diffusion_model/schedules/edm_noise_schedule.py

Lines changed: 12 additions & 11 deletions

@@ -1,5 +1,4 @@
 import math
-from typing import Union

 from keras import ops

@@ -15,7 +14,8 @@ class EDMNoiseSchedule(NoiseSchedule):
     """EDM noise schedule for diffusion models. This schedule is based on the EDM paper [1].
     This should be used with the F-prediction type in the diffusion model.

-    [1] Elucidating the Design Space of Diffusion-Based Generative Models: Karras et al. (2022)
+    [1] Karras, T., Aittala, M., Aila, T., & Laine, S. (2022). Elucidating the design space of diffusion-based
+        generative models. Advances in Neural Information Processing Systems, 35, 26565-26577.
     """

     def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max: float = 80.0):
@@ -26,7 +26,7 @@ def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max:
         ----------
         sigma_data : float, optional
             The standard deviation of the output distribution. Input of the network is scaled by this factor and
-            the weighting function is scaled by this factor as well.
+            the weighting function is scaled by this factor as well. Default is 1.0.
         sigma_min : float, optional
             The minimum noise level. Only relevant for sampling. Default is 1e-4.
         sigma_max : float, optional
@@ -50,21 +50,21 @@ def __init__(self, sigma_data: float = 1.0, sigma_min: float = 1e-4, sigma_max:
         self._log_snr_min_training = self.log_snr_min - 1  # one is never sampler during training
         self._log_snr_max_training = self.log_snr_max + 1  # 0 is almost surely never sampled during training

-    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
         if training:
-            # SNR = -dist.icdf(t_trunc)  # negative seems to be wrong in the paper in the Kingma paper
+            # SNR = -dist.icdf(t_trunc)  # negative seems to be wrong in the Kingma paper
             loc = -2 * self.p_mean
             scale = 2 * self.p_std
             snr = loc + scale * ops.erfinv(2 * t - 1) * math.sqrt(2)
             snr = ops.clip(snr, x_min=self._log_snr_min_training, x_max=self._log_snr_max_training)
-        else:  # sampling
+        else:
             sigma_min_rho = self.sigma_min ** (1 / self.rho)
             sigma_max_rho = self.sigma_max ** (1 / self.rho)
             snr = -2 * self.rho * ops.log(sigma_max_rho + (1 - t) * (sigma_min_rho - sigma_max_rho))
         return snr

-    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
         """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
         if training:
             # SNR = -dist.icdf(t_trunc) => t = dist.cdf(-snr)  # negative seems to be wrong in the Kingma paper
@@ -80,7 +80,7 @@ def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) ->
             t = 1 - ((ops.exp(-log_snr_t / (2 * self.rho)) - sigma_max_rho) / (sigma_min_rho - sigma_max_rho))
         return t

-    def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:
+    def derivative_log_snr(self, log_snr_t: Tensor, training: bool = False) -> Tensor:
         """Compute d/dt log(1 + e^(-snr(t))), which is used for the reverse SDE."""
         if training:
             raise NotImplementedError("Derivative of log SNR is not implemented for training mode.")
@@ -101,11 +101,12 @@ def derivative_log_snr(self, log_snr_t: Tensor, training: bool) -> Tensor:

     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
         """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda)."""
-        # for F-prediction: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2)
-        return ops.exp(-log_snr_t) / ops.square(self.sigma_data) + 1
+        # for F-loss: w = (ops.exp(-log_snr_t) + sigma_data^2) / (ops.exp(-log_snr_t)*sigma_data^2)
+        return 1 + ops.exp(-log_snr_t) / ops.square(self.sigma_data)

     def get_config(self):
-        return dict(sigma_data=self.sigma_data, sigma_min=self.sigma_min, sigma_max=self.sigma_max)
+        config = {"sigma_data": self.sigma_data, "sigma_min": self.sigma_min, "sigma_max": self.sigma_max}
+        return config

     @classmethod
     def from_config(cls, config, custom_objects=None):
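
Two of the formulas touched in this file can be checked in isolation: the training-time log-SNR draw (the `erfinv` expression above is the inverse CDF of a normal with `loc = -2 * p_mean`, `scale = 2 * p_std`) and the rewritten F-loss weight `w(lambda) = 1 + exp(-lambda) / sigma_data^2`. A standalone sketch with illustrative hyperparameters (the diff does not show the `p_mean` / `p_std` defaults), clipping omitted:

```python
import math
from statistics import NormalDist

P_MEAN, P_STD = -1.2, 1.2  # illustrative; actual defaults are not visible in this diff
SIGMA_DATA = 1.0

def training_log_snr(t: float) -> float:
    # Equivalent to loc + scale * erfinv(2 * t - 1) * sqrt(2) in get_log_snr above:
    # the inverse CDF of a normal with loc = -2 * p_mean and scale = 2 * p_std.
    return NormalDist(mu=-2.0 * P_MEAN, sigma=2.0 * P_STD).inv_cdf(t)

def f_loss_weight(log_snr_t: float) -> float:
    # Rewritten weighting: w = 1 + exp(-lambda) / sigma_data^2
    return 1.0 + math.exp(-log_snr_t) / SIGMA_DATA**2

lam = training_log_snr(0.5)           # the median draw is exactly -2 * p_mean
assert abs(lam - (-2.0 * P_MEAN)) < 1e-12
print(f_loss_weight(lam))             # ~1.09 for these illustrative values
```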

bayesflow/experimental/diffusion_model/noise_schedule.py renamed to bayesflow/experimental/diffusion_model/schedules/noise_schedule.py

Lines changed: 52 additions & 23 deletions

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Union, Literal
+from typing import Literal

 from keras import ops

@@ -33,7 +33,7 @@ def __init__(
         weighting: Literal["sigmoid", "likelihood_weighting"] = None,
     ):
         """
-        Initialize the noise schedule.
+        Initialize the noise schedule with given variance and weighting strategy.

         Parameters
         ----------
@@ -54,21 +54,23 @@ def __init__(
         self._weighting = weighting

     @abstractmethod
-    def get_log_snr(self, t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_log_snr(self, t: float | Tensor, training: bool) -> Tensor:
         """Get the log signal-to-noise ratio (lambda) for a given diffusion time."""
         pass

     @abstractmethod
-    def get_t_from_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
+    def get_t_from_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
         """Get the diffusion time (t) from the log signal-to-noise ratio (lambda)."""
         pass

     @abstractmethod
-    def derivative_log_snr(self, log_snr_t: Union[float, Tensor], training: bool) -> Tensor:
+    def derivative_log_snr(self, log_snr_t: float | Tensor, training: bool) -> Tensor:
         r"""Compute \beta(t) = d/dt log(1 + e^(-snr(t))). This is usually used for the reverse SDE."""
         pass

-    def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: bool = False) -> tuple[Tensor, Tensor]:
+    def get_drift_diffusion(
+        self, log_snr_t: Tensor, x: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor]:
         r"""Compute the drift and optionally the squared diffusion term for the reverse SDE.
         It can be derived from the derivative of the schedule:

@@ -97,10 +99,10 @@ def get_drift_diffusion(self, log_snr_t: Tensor, x: Tensor = None, training: boo
             raise ValueError(f"Unknown variance type: {self._variance_type}")
         return f, beta

-    def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Tensor]:
+    def get_alpha_sigma(self, log_snr_t: Tensor) -> tuple[Tensor, Tensor]:
         """Get alpha and sigma for a given log signal-to-noise ratio (lambda).

-        Default is a variance preserving schedule::
+        Default is a variance preserving schedule:

            alpha(t) = sqrt(sigmoid(log_snr_t))
            sigma(t) = sqrt(sigmoid(-log_snr_t))
@@ -120,9 +122,32 @@ def get_alpha_sigma(self, log_snr_t: Tensor, training: bool) -> tuple[Tensor, Te
         return alpha_t, sigma_t

     def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
-        """Get weights for the signal-to-noise ratio (snr) for a given log signal-to-noise ratio (lambda).
-        Default weighting is None, which means only ones are returned.
-        Generally, weighting functions should be defined for a noise prediction loss.
+        """
+        Compute loss weights based on log signal-to-noise ratio (log-SNR).
+
+        This method returns a tensor of weights used for loss re-weighting in diffusion models,
+        depending on the selected strategy. If no weighting is specified, uniform weights (ones)
+        are returned.
+
+        Supported weighting strategies:
+        - "sigmoid": Based on Kingma et al. (2023), uses a sigmoid of shifted log-SNR.
+        - "likelihood_weighting": Based on Song et al. (2021), uses ratio of diffusion drift
+          to squared noise scale.
+
+        Parameters
+        ----------
+        log_snr_t : Tensor
+            A tensor containing the log signal-to-noise ratio values.
+
+        Returns
+        -------
+        Tensor
+            A tensor of weights corresponding to each log-SNR value.
+
+        Raises
+        ------
+        TypeError
+            If the weighting strategy specified in `self._weighting` is unknown.
         """
         if self._weighting is None:
             return ops.ones_like(log_snr_t)
@@ -131,33 +156,37 @@ def get_weights_for_snr(self, log_snr_t: Tensor) -> Tensor:
             return ops.sigmoid(-log_snr_t + 2)
         elif self._weighting == "likelihood_weighting":
             # likelihood weighting based on Song et al. (2021)
-            g_squared = self.get_drift_diffusion(log_snr_t=log_snr_t)
-            sigma_t = self.get_alpha_sigma(log_snr_t=log_snr_t, training=True)[1]
+            g_squared = self.get_drift_diffusion(log_snr_t)
+            _, sigma_t = self.get_alpha_sigma(log_snr_t)
             return g_squared / ops.square(sigma_t)
         else:
             raise TypeError(f"Unknown weighting type: {self._weighting}")

     def get_config(self):
-        return dict(name=self.name, variance_type=self._variance_type, weighting=self._weighting)
+        return {"name": self.name, "variance_type": self._variance_type, "weighting": self._weighting}

     @classmethod
     def from_config(cls, config, custom_objects=None):
         return cls(**deserialize(config, custom_objects=custom_objects))

     def validate(self):
         """Validate the noise schedule."""
+
         if self.log_snr_min >= self.log_snr_max:
             raise ValueError("min_log_snr must be less than max_log_snr.")
-        for training in [True, False]:
+
+        # Validate log SNR values and corresponding time mappings for both training and inference
+        for training in (True, False):
             if not ops.isfinite(self.get_log_snr(0.0, training=training)):
-                raise ValueError(f"log_snr(0) must be finite with training={training}.")
+                raise ValueError(f"log_snr(0.0) must be finite (training={training})")
             if not ops.isfinite(self.get_log_snr(1.0, training=training)):
-                raise ValueError(f"log_snr(1) must be finite with training={training}.")
+                raise ValueError(f"log_snr(1.0) must be finite (training={training})")
             if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_max, training=training)):
-                raise ValueError(f"t(0) must be finite with training={training}.")
+                raise ValueError(f"t(log_snr_max) must be finite (training={training})")
             if not ops.isfinite(self.get_t_from_log_snr(self.log_snr_min, training=training)):
-                raise ValueError(f"t(1) must be finite with training={training}.")
-        if not ops.isfinite(self.derivative_log_snr(self.log_snr_max, training=False)):
-            raise ValueError("dt/t log_snr(0) must be finite.")
-        if not ops.isfinite(self.derivative_log_snr(self.log_snr_min, training=False)):
-            raise ValueError("dt/t log_snr(1) must be finite.")
+                raise ValueError(f"t(log_snr_min) must be finite (training={training})")
+
+        # Validate log SNR derivatives at the boundaries
+        for boundary, name in [(self.log_snr_max, "log_snr_max (t=0)"), (self.log_snr_min, "log_snr_min (t=1)")]:
+            if not ops.isfinite(self.derivative_log_snr(boundary, training=False)):
+                raise ValueError(f"derivative_log_snr at {name} must be finite.")

bayesflow/experimental/free_form_flow/free_form_flow.py

Lines changed: 0 additions & 9 deletions

@@ -1,8 +1,6 @@
 import keras
 from keras import ops

-import warnings
-
 from bayesflow.distributions import Distribution
 from bayesflow.types import Tensor
 from bayesflow.utils import (
@@ -86,13 +84,6 @@
         """
         super().__init__(base_distribution, **kwargs)

-        if encoder_subnet_kwargs or decoder_subnet_kwargs:
-            warnings.warn(
-                "Using `subnet_kwargs` is deprecated."
-                "Instead, instantiate the network yourself and pass the arguments directly.",
-                DeprecationWarning,
-            )
-
         encoder_subnet_kwargs = encoder_subnet_kwargs or {}
         decoder_subnet_kwargs = decoder_subnet_kwargs or {}

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 0 additions & 9 deletions

@@ -3,8 +3,6 @@

 import numpy as np

-import warnings
-
 from bayesflow.types import Tensor
 from bayesflow.utils import find_network, layer_kwargs, weighted_mean
 from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -76,13 +74,6 @@

         self.total_steps = float(total_steps)

-        if subnet_kwargs:
-            warnings.warn(
-                "Using `subnet_kwargs` is deprecated."
-                "Instead, instantiate the network yourself and pass the arguments directly.",
-                DeprecationWarning,
-            )
-
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
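
Context for the warnings removed in this file and in free_form_flow.py above: passing `subnet_kwargs` (or `encoder_subnet_kwargs` / `decoder_subnet_kwargs`) is no longer flagged as deprecated. A hedged sketch of the two styles this enables; the import paths and all constructor arguments other than `subnet`, `subnet_kwargs`, and `total_steps` are illustrative assumptions and may differ from the actual API:

```python
# Hedged sketch only: "mlp", subnet_kwargs, and total_steps appear in the diff above;
# the MLP import and its arguments are illustrative assumptions.
from bayesflow.networks import MLP, ConsistencyModel

# Style 1: let the network build its subnet, forwarding keyword arguments
# (this no longer emits a DeprecationWarning after this commit).
cm = ConsistencyModel(subnet="mlp", subnet_kwargs={"dropout": 0.1}, total_steps=10_000)

# Style 2: instantiate the subnet yourself and pass it in directly.
cm = ConsistencyModel(subnet=MLP(widths=(256, 256)), total_steps=10_000)
```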
