Skip to content

Commit a91d4b3

Browse files
Prevent upcasting norm layers in prepare_model_for_kbit_training (#4457)
Co-authored-by: Quentin Gallouédec <[email protected]>
1 parent 121318e commit a91d4b3

File tree

1 file changed

+1
-9
lines changed

1 file changed

+1
-9
lines changed

trl/models/utils.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -485,18 +485,10 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad
485485
if gradient_checkpointing_kwargs is None:
486486
gradient_checkpointing_kwargs = {}
487487

488-
n_upcasted = 0
489-
for name, param in model.named_parameters():
488+
for _, param in model.named_parameters():
490489
# freeze all parameters
491490
param.requires_grad = False
492491

493-
# upcast LayerNorm / Norm to float32 for numerical stability
494-
if (param.dtype in [torch.float16, torch.bfloat16]) and (
495-
"norm" in name.lower() or "layernorm" in name.lower()
496-
):
497-
param.data = param.data.to(torch.float32)
498-
n_upcasted += 1
499-
500492
# Enable gradient checkpointing if needed
501493
if (loaded_in_kbit or is_quantized) and use_gradient_checkpointing:
502494
if hasattr(model, "enable_input_require_grads"):

0 commit comments

Comments (0)