Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 5672a90

Browse files
authored
Merge pull request #495 from Project-MONAI/494-clipping-min-and-max-values-on-scheduler
Allowing sample clipping values besides clip_sample True/False
2 parents cb464ad + ebfc3fb commit 5672a90

File tree

4 files changed

+28
-5
lines changed

4 files changed

+28
-5
lines changed

generative/networks/nets/patchgan_discriminator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def __init__(
8787
pool = None
8888
else:
8989
pool = get_pool_layer(
90-
(pooling_method, {"kernel_size": kernel_size, "stride": 2, 'padding': self.padding}), spatial_dims=spatial_dims
90+
(pooling_method, {"kernel_size": kernel_size, "stride": 2, "padding": self.padding}),
91+
spatial_dims=spatial_dims,
9192
)
9293
self.num_channels = num_channels
9394
for i_ in range(self.num_d):

generative/networks/nets/vqvae.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn as nn
1818
from monai.networks.blocks import Convolution
1919
from monai.networks.layers import Act
20-
from monai.utils import ensure_tuple_rep
20+
from monai.utils.misc import ensure_tuple_rep
2121

2222
from generative.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer
2323

generative/networks/schedulers/ddim.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class DDIMScheduler(Scheduler):
7070
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
7171
stable diffusion.
7272
prediction_type: member of DDPMPredictionType
73+
clip_sample_min: if clip_sample is True, minimum value to clamp the prediction by.
74+
clip_sample_max: if clip_sample is True, maximum value to clamp the prediction by.
7375
schedule_args: arguments to pass to the schedule function
7476
7577
"""
@@ -82,13 +84,18 @@ def __init__(
8284
set_alpha_to_one: bool = True,
8385
steps_offset: int = 0,
8486
prediction_type: str = DDIMPredictionType.EPSILON,
87+
clip_sample_min: int = -1,
88+
clip_sample_max: int = 1,
8589
**schedule_args,
8690
) -> None:
8791
super().__init__(num_train_timesteps, schedule, **schedule_args)
8892

8993
if prediction_type not in DDIMPredictionType.__members__.values():
9094
raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")
9195

96+
if clip_sample_min >= clip_sample_max:
97+
raise ValueError("clip_sample_min must be < clip_sample_max")
98+
9299
self.prediction_type = prediction_type
93100

94101
# At every step in ddim, we are looking into the previous alphas_cumprod
@@ -107,6 +114,7 @@ def __init__(
107114
self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
108115

109116
self.clip_sample = clip_sample
117+
self.clip_sample_values = [clip_sample_min, clip_sample_max]
110118
self.steps_offset = steps_offset
111119

112120
# default the number of inference timesteps to the number of train steps
@@ -203,7 +211,9 @@ def step(
203211

204212
# 4. Clip "predicted x_0"
205213
if self.clip_sample:
206-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
214+
pred_original_sample = torch.clamp(
215+
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
216+
)
207217

208218
# 5. compute variance: "sigma_t(η)" -> see formula (16)
209219
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
@@ -278,7 +288,9 @@ def reversed_step(
278288

279289
# 4. Clip "predicted x_0"
280290
if self.clip_sample:
281-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
291+
pred_original_sample = torch.clamp(
292+
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
293+
)
282294

283295
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
284296
pred_sample_direction = (1 - alpha_prod_t_next) ** (0.5) * pred_epsilon

generative/networks/schedulers/ddpm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ class DDPMScheduler(Scheduler):
7676
variance_type: member of DDPMVarianceType
7777
clip_sample: option to clip predicted sample between clip_sample_min and clip_sample_max (-1 and 1 by default) for numerical stability.
7878
prediction_type: member of DDPMPredictionType
79+
clip_sample_min: if clip_sample is True, minimum value to clamp the prediction by.
80+
clip_sample_max: if clip_sample is True, maximum value to clamp the prediction by.
7981
schedule_args: arguments to pass to the schedule function
8082
"""
8183

@@ -86,6 +88,8 @@ def __init__(
8688
variance_type: str = DDPMVarianceType.FIXED_SMALL,
8789
clip_sample: bool = True,
8890
prediction_type: str = DDPMPredictionType.EPSILON,
91+
clip_sample_min: int = -1,
92+
clip_sample_max: int = 1,
8993
**schedule_args,
9094
) -> None:
9195
super().__init__(num_train_timesteps, schedule, **schedule_args)
@@ -96,9 +100,13 @@ def __init__(
96100
if prediction_type not in DDPMPredictionType.__members__.values():
97101
raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")
98102

103+
if clip_sample_min >= clip_sample_max:
104+
raise ValueError("clip_sample_min must be < clip_sample_max")
105+
99106
self.clip_sample = clip_sample
100107
self.variance_type = variance_type
101108
self.prediction_type = prediction_type
109+
self.clip_sample_values = [clip_sample_min, clip_sample_max]
102110

103111
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
104112
"""
@@ -218,7 +226,9 @@ def step(
218226

219227
# 3. Clip "predicted x_0"
220228
if self.clip_sample:
221-
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
229+
pred_original_sample = torch.clamp(
230+
pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
231+
)
222232

223233
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
224234
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf

0 commit comments

Comments
 (0)