@@ -922,66 +922,6 @@ def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
922
922
923
923
return peft_config
924
924
925
- def get_lora_merged_state_dict (
926
- model : torch .nn .Module ,
927
- ) -> dict :
928
- r"""
929
- Create and return a state_dict that has the LoRA deltas
930
- merged into the base model’s weights, without modifying `model` in place.
931
-
932
- Arguments:
933
- model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.
934
-
935
- Returns:
936
- dict: A state_dict of the merged parameters.
937
- """
938
-
939
- if not is_peft_available ():
940
- raise ValueError (
941
- "You need to have PEFT library installed in your environment, make sure to install `peft`. "
942
- "Make sure to run `pip install -U peft`."
943
- )
944
-
945
- base_model_prefix = "base_model.model."
946
- state_dict = {}
947
- key_list = [key for key , _ in model .named_modules () if model .prefix not in key ]
948
- for key in key_list :
949
- try :
950
- _ , target , _ = _get_submodules (model , key )
951
- except AttributeError :
952
- continue
953
- with onload_layer (target ):
954
- weight_key = key .replace (base_model_prefix , "" ) + ".weight"
955
- bias_key = key .replace (base_model_prefix , "" ) + ".bias"
956
- if hasattr (target , "base_layer" ):
957
- target .merge (safe_merge = True , adapter_names = None )
958
- # get the state_dict of target.base_layer
959
- layer_state_dict = target .base_layer .state_dict ()
960
- state_dict [weight_key ] = layer_state_dict ["weight" ]
961
- elif isinstance (target , ModulesToSaveWrapper ):
962
- # save any additional trainable modules part of `modules_to_save`
963
- new_module = target .modules_to_save [target .active_adapter ]
964
- if hasattr (new_module , "base_layer" ):
965
- # check if the module is itself a tuner layer
966
- new_module .merge (safe_merge = True , adapter_names = None )
967
- layer_state_dict = new_module .state_dict ()
968
- state_dict [weight_key ] = layer_state_dict ["weight" ]
969
- elif hasattr (target , "weight" ):
970
- if any (
971
- skip in key
972
- for skip in [
973
- ".original_module" ,
974
- ".modules_to_save" ,
975
- ".base_layer" ,
976
- ]
977
- ):
978
- continue
979
- layer_state_dict = target .state_dict ()
980
- state_dict [weight_key ] = layer_state_dict ["weight" ]
981
- if hasattr (target , "bias" ) and "bias" in layer_state_dict .keys ():
982
- state_dict [bias_key ] = layer_state_dict ["bias" ]
983
- return state_dict
984
-
985
925
986
926
def get_exp_cap (value , decimal = 4 ):
987
927
"""
0 commit comments