 
 def set_seed(seed: int):
     """
-    Args:
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+    Args:
         seed (`int`): The seed to set.
     """
     random.seed(seed)
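For reference, a minimal usage sketch of the seeding helper (assuming `set_seed` is importable from `diffusers.training_utils`, where this file lives):

from diffusers.training_utils import set_seed

# One call seeds Python's `random`, NumPy, and PyTorch (CPU and, when
# available, CUDA) so a training run can be reproduced end to end.
set_seed(42)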
@@ -194,6 +195,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
 
 
 def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
+    """
+    Casts the training parameters of the model to the specified data type.
+
+    Args:
+        model: The PyTorch model whose parameters will be cast.
+        dtype: The data type to which the model parameters will be cast.
+    """
     if not isinstance(model, list):
         model = [model]
     for m in model:
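A short sketch of the mixed-precision pattern this helper supports, assuming (as the function body suggests) that it only touches parameters with `requires_grad=True`:

import torch

from diffusers.training_utils import cast_training_params

# Toy stand-in for a model whose weights were loaded in fp16.
model = torch.nn.Linear(8, 8).to(torch.float16)

# Upcast only the trainable parameters to fp32 for numerically stable
# optimizer updates; frozen fp16 weights would be left untouched.
cast_training_params(model, dtype=torch.float32)
print(next(model.parameters()).dtype)  # torch.float32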
@@ -225,7 +233,8 @@ def _set_state_dict_into_text_encoder(
 def compute_density_for_timestep_sampling(
     weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
 ):
-    """Compute the density for sampling the timesteps when doing SD3 training.
+    """
+    Compute the density for sampling the timesteps when doing SD3 training.
 
     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
 
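A hedged example of calling this sampler with the signature shown above; the commented mapping onto scheduler timesteps follows the convention of SD3 training scripts and uses a hypothetical `noise_scheduler`:

from diffusers.training_utils import compute_density_for_timestep_sampling

# Draw one density value in (0, 1) per sample with the logit-normal scheme.
u = compute_density_for_timestep_sampling(
    weighting_scheme="logit_normal", batch_size=4, logit_mean=0.0, logit_std=1.0
)
# Training scripts typically map the densities onto discrete timesteps, e.g.:
# indices = (u * noise_scheduler.config.num_train_timesteps).long()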
@@ -244,7 +253,8 @@ def compute_density_for_timestep_sampling(
 
 
 def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
-    """Computes loss weighting scheme for SD3 training.
+    """
+    Computes loss weighting scheme for SD3 training.
 
     Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
 
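A companion sketch for the loss-weighting helper; `"sigma_sqrt"` is one of the accepted scheme names, and the comment states my understanding of what it computes:

import torch

from diffusers.training_utils import compute_loss_weighting_for_sd3

# Per-sample sigmas, shaped so the weights broadcast against an image batch.
sigmas = torch.rand(4, 1, 1, 1) + 0.1
# "sigma_sqrt" corresponds to weighting the MSE term by 1 / sigma**2.
weighting = compute_loss_weighting_for_sd3(weighting_scheme="sigma_sqrt", sigmas=sigmas)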
@@ -261,7 +271,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
 
 
 def free_memory():
-    """Runs garbage collection. Then clears the cache of the available accelerator."""
+    """
+    Runs garbage collection. Then clears the cache of the available accelerator.
+    """
     gc.collect()
 
     if torch.cuda.is_available():
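A small, self-contained sketch of the intended usage, i.e. calling `free_memory` between large allocation phases of a script:

import torch

from diffusers.training_utils import free_memory

big = torch.randn(1024, 1024, device="cuda" if torch.cuda.is_available() else "cpu")
del big  # drop the reference first so garbage collection can reclaim it
free_memory()  # gc.collect() plus clearing the accelerator cache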
@@ -494,7 +506,8 @@ def pin_memory(self) -> None:
         self.shadow_params = [p.pin_memory() for p in self.shadow_params]
 
     def to(self, device=None, dtype=None, non_blocking=False) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+        r"""
+        Move internal buffers of the ExponentialMovingAverage to `device`.
 
         Args:
             device: like `device` argument to `torch.Tensor.to`
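For context, a sketch of moving the EMA buffers alongside the model, assuming these methods live on `diffusers.training_utils.EMAModel`:

import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)
ema = EMAModel(model.parameters())
if torch.cuda.is_available():
    model.to("cuda")
    # non_blocking copies can overlap with compute when buffers are pinned.
    ema.to(device="cuda", non_blocking=True)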
@@ -528,23 +541,25 @@ def state_dict(self) -> dict:
 
     def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
+        Saves the current parameters for restoring later.
+
         Args:
-            Save the current parameters for restoring later.
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                temporarily stored.
+            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
         """
         self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
 
     def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         r"""
-        Args:
-        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
-        affecting the original optimization process. Store the parameters before the `copy_to()` method. After
+        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
+        without affecting the original optimization process. Store the parameters before the `copy_to()` method. After
         validation (or model saving), use this to restore the former parameters.
+
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored parameters. If `None`, the parameters with which this
                 `ExponentialMovingAverage` was initialized will be used.
         """
+
         if self.temp_stored_params is None:
             raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
         if self.foreach:
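The `store`/`restore` pair enables the EMA-validation pattern these docstrings describe. A minimal sketch, again assuming `EMAModel` from `diffusers.training_utils`:

import torch

from diffusers.training_utils import EMAModel

model = torch.nn.Linear(8, 8)
ema = EMAModel(model.parameters())

ema.store(model.parameters())    # stash the live training weights
ema.copy_to(model.parameters())  # swap in the EMA weights
# ... run validation or save a checkpoint here ...
ema.restore(model.parameters())  # put the live weights back and resume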
@@ -560,9 +575,10 @@ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
 
     def load_state_dict(self, state_dict: dict) -> None:
         r"""
-        Args:
         Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
         ema state dict.
+
+        Args:
             state_dict (dict): EMA state. Should be an object returned
                 from a call to :meth:`state_dict`.
         """