
Make diffusion model conditioning more flexible #521

Open
wants to merge 7 commits into base: dev
53 changes: 41 additions & 12 deletions bayesflow/experimental/diffusion_model/diffusion_model.py
@@ -116,6 +116,7 @@
if subnet == "mlp":
subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
self.subnet = find_network(subnet, **subnet_kwargs)
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

@@ -149,6 +150,8 @@
"prediction_type": self._prediction_type,
"loss_type": self._loss_type,
"integrate_kwargs": self.integrate_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}
return base_config | serialize(config)

@@ -197,6 +200,35 @@
return (z + sigma_t**2 * pred) / alpha_t
raise ValueError(f"Unknown prediction type {self._prediction_type}.")

def _subnet_input(
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor:
"""
Runs the subnet either on the concatenation of the latent variable `xz`, the log signal-to-noise
ratio `log_snr`, and optional conditions, or on the separate tensors if concatenation is disabled.

Parameters
----------
xz : Tensor
The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
log_snr : Tensor
The log signal-to-noise ratio tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The subnet output, computed either from the concatenated input or from the separately passed tensors.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(xz, log_snr, conditions, training=training)

Codecov warning: added line bayesflow/experimental/diffusion_model/diffusion_model.py#L230 was not covered by tests.
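For illustration, a minimal sketch of a subnet that could be used with the new code path. The class name, layer widths, and the commented-out constructor call are hypothetical; only the `concatenate_subnet_input` kwarg and the `subnet(xz, log_snr, conditions, training=training)` call convention come from this diff.

import keras
from keras import ops

class SplitInputSubnet(keras.layers.Layer):
    """Hypothetical subnet that receives xz, log_snr, and conditions as separate arguments."""

    def __init__(self, width=256, **kwargs):
        super().__init__(**kwargs)
        # embed the noise level and the conditions separately before mixing them with xz
        self.embed_snr = keras.layers.Dense(width, activation="gelu")
        self.embed_cond = keras.layers.Dense(width, activation="gelu")
        self.body = keras.layers.Dense(width, activation="gelu")

    def call(self, xz, log_snr, conditions=None, training=False):
        h = ops.concatenate([xz, self.embed_snr(log_snr)], axis=-1)
        if conditions is not None:
            h = ops.concatenate([h, self.embed_cond(conditions)], axis=-1)
        return self.body(h)

# hypothetical construction; only the kwarg itself appears in this diff:
# model = DiffusionModel(subnet=SplitInputSubnet(), concatenate_subnet_input=False)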

def velocity(
self,
xz: Tensor,
@@ -221,7 +253,7 @@
If True, computes the velocity for the stochastic formulation (SDE).
If False, uses the deterministic formulation (ODE).
conditions : Tensor, optional
Optional conditional inputs to the network, such as conditioning variables
Conditional inputs to the network, such as conditioning variables
or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
training : bool, optional
Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +270,10 @@
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

if conditions is None:
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
else:
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)

pred = self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._subnet_input(
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

@@ -461,11 +491,10 @@
diffused_x = alpha_t * x + sigma_t * eps_t

# calculate output of the network
if conditions is None:
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
else:
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._subnet_input(
diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(
pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
45 changes: 36 additions & 9 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -4,7 +4,7 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, layer_kwargs, weighted_mean
from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
from bayesflow.utils.serialization import deserialize, serializable, serialize

from ..inference_network import InferenceNetwork
Expand Down Expand Up @@ -77,6 +77,7 @@
subnet_kwargs = subnet_kwargs or {}
if subnet == "mlp":
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.subnet = find_network(subnet, **subnet_kwargs)
self.output_projector = keras.layers.Dense(
@@ -119,6 +120,7 @@
"eps": self.eps,
"s0": self.s0,
"s1": self.s1,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}

@@ -256,6 +258,35 @@
x = self.consistency_function(x_n, t, conditions=conditions, training=training)
return x

def _subnet_input(
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor:
"""
Runs the subnet either on the concatenation of the input `x`, the time `t`, and optional
conditions, or on the separate tensors if concatenation is disabled.

Parameters
----------
x : Tensor
The input tensor for the consistency model, typically of shape (..., D), but can vary.
t : Tensor
The time tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The subnet output, computed either from the concatenated input or from the separately passed tensors.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(x, t, conditions, training=training)

Codecov warning: added line bayesflow/networks/consistency_models/consistency_model.py#L288 was not covered by tests.
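Both refactorings above rely on `tensor_utils.concatenate_valid` skipping `None` entries, which is what lets the previous `if conditions is None:` branches collapse into a single call. A rough sketch of that behaviour, inferred from its call sites here rather than from the helper's actual implementation:

from keras import ops

def concatenate_valid_sketch(tensors, axis=-1):
    # illustrative only: concatenate the entries that are not None along `axis`
    return ops.concatenate([t for t in tensors if t is not None], axis=axis)

# with conditions=None this reduces to concatenating [x, t] only,
# matching the removed `if conditions is None:` branches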

def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
"""Compute consistency function.

@@ -271,12 +302,8 @@
Whether internal layers (e.g., dropout) should behave in train or inference mode.
"""

if conditions is not None:
xtc = ops.concatenate([x, t, conditions], axis=-1)
else:
xtc = ops.concatenate([x, t], axis=-1)

f = self.output_projector(self.subnet(xtc, training=training))
subnet_out = self._subnet_input(x, t, conditions, training=training)
f = self.output_projector(subnet_out)

# Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
# Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +343,8 @@

log_p = ops.log(p)
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
t1 = ops.take(discretized_time, times)[..., None]
t2 = ops.take(discretized_time, times + 1)[..., None]
t1 = expand_right_as(ops.take(discretized_time, times), x)
t2 = expand_right_as(ops.take(discretized_time, times + 1), x)

# generate noise vector
noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
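The switch from `[..., None]` to `expand_right_as` for `t1` and `t2` matters when `x` has more than two dimensions: `[..., None]` always yields shape (batch_size, 1), while expanding on the right until the rank matches `x` also broadcasts correctly against inputs such as (batch_size, N, D). A small shape sketch in plain `keras.ops`; the behaviour of `expand_right_as` is assumed from its use here.

from keras import ops

x = ops.zeros((4, 8, 3))  # e.g. (batch_size, set_size, dim)
t = ops.zeros((4,))       # one time value per batch element

t_old = ops.expand_dims(t, axis=-1)                      # shape (4, 1): does not broadcast against (4, 8, 3)
t_new = ops.reshape(t, (4,) + (1,) * (ops.ndim(x) - 1))  # shape (4, 1, 1): broadcasts as intended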
44 changes: 37 additions & 7 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -12,6 +12,7 @@
layer_kwargs,
optimal_transport,
weighted_mean,
tensor_utils,
)
from bayesflow.utils.serialization import serialize, deserialize, serializable
from ..inference_network import InferenceNetwork
@@ -107,6 +108,7 @@
subnet_kwargs = subnet_kwargs or {}
if subnet == "mlp":
subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.subnet = find_network(subnet, **subnet_kwargs)
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
@@ -147,22 +149,50 @@
"loss_fn": self.loss_fn,
"integrate_kwargs": self.integrate_kwargs,
"optimal_transport_kwargs": self.optimal_transport_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}

return base_config | serialize(config)

def _subnet_input(
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor:
"""
Runs the subnet either on the concatenation of the input `x`, the time `t`, and optional
conditions, or on the separate tensors if concatenation is disabled.

Parameters
----------
x : Tensor
The input tensor for the flow matching network, typically of shape (..., D), but can vary.
t : Tensor
The time tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The subnet output, computed either from the concatenated input or from the separately passed tensors.
"""
if self._concatenate_subnet_input:
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
if training is False:
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
return self.subnet(x, t, conditions, training=training)

Codecov warning: added lines bayesflow/networks/flow_matching/flow_matching.py#L186-L188 were not covered by tests.
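FlowMatching mirrors the same pattern, so a subnet with a `call(x, t, conditions=None, training=False)` signature, such as the hypothetical `SplitInputSubnet` sketched for the diffusion model above, could be plugged in the same way. The import path and constructor call below are assumptions; only the kwarg itself comes from this diff.

from bayesflow.networks import FlowMatching  # import path assumed, not shown in this diff

# hypothetical construction mirroring the diffusion model example above
# flow = FlowMatching(subnet=SplitInputSubnet(), concatenate_subnet_input=False)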

def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
time = expand_right_as(time, xz)
time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))

if conditions is None:
xtc = keras.ops.concatenate([xz, time], axis=-1)
else:
xtc = keras.ops.concatenate([xz, time, conditions], axis=-1)

return self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._subnet_input(xz, time, conditions, training=training)
return self.output_projector(subnet_out, training=training)

def _velocity_trace(
self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False