@@ -60,6 +60,8 @@
 
 if is_peft_available():
     from peft import LoraConfig, PeftConfig
+    from peft.tuners.tuners_utils import onload_layer
+    from peft.utils import ModulesToSaveWrapper, _get_submodules
 
 
 class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
@@ -920,6 +922,69 @@ def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]":
 
     return peft_config
 
+def get_lora_merged_state_dict(
+    model: torch.nn.Module,
+) -> dict:
+    r"""
+    Create and return a state_dict that has the LoRA deltas merged into the base model's weights. Note that
+    `merge(safe_merge=True)` is called on every adapter layer, so the adapters are also merged into `model` itself.
+
+    Args:
+        model (torch.nn.Module): A model that has LoRA/PEFT adapters attached.
+
+    Returns:
+        dict: A state_dict of the merged parameters.
+    """
+
+    if not is_peft_available():
+        raise ValueError(
+            "You need to have the `peft` library installed in your environment. "
+            "Make sure to run `pip install -U peft`."
+        )
+
+    base_model_prefix = "base_model.model."
+    state_dict = {}
+    # iterate over all modules except the adapter parameters themselves (e.g. `lora_A`/`lora_B`)
+    key_list = [key for key, _ in model.named_modules() if model.prefix not in key]
+    for key in key_list:
+        try:
+            _, target, _ = _get_submodules(model, key)
+        except AttributeError:
+            continue
+        # `onload_layer` temporarily restores weights that were offloaded to CPU or disk
+        with onload_layer(target):
+            weight_key = key.replace(base_model_prefix, "") + ".weight"
+            bias_key = key.replace(base_model_prefix, "") + ".bias"
+            layer_state_dict = {}
+            if hasattr(target, "base_layer"):
+                target.merge(safe_merge=True, adapter_names=None)
+                # get the state_dict of target.base_layer
+                layer_state_dict = target.base_layer.state_dict()
+                state_dict[weight_key] = layer_state_dict["weight"]
+            elif isinstance(target, ModulesToSaveWrapper):
+                # save any additional trainable modules part of `modules_to_save`
+                new_module = target.modules_to_save[target.active_adapter]
+                if hasattr(new_module, "base_layer"):
+                    # check if the module is itself a tuner layer
+                    new_module.merge(safe_merge=True, adapter_names=None)
+                layer_state_dict = new_module.state_dict()
+                state_dict[weight_key] = layer_state_dict["weight"]
+            elif hasattr(target, "weight"):
+                if any(
+                    skip in key
+                    for skip in [
+                        ".original_module",
+                        ".modules_to_save",
+                        ".base_layer",
+                    ]
+                ):
+                    continue
+                layer_state_dict = target.state_dict()
+                state_dict[weight_key] = layer_state_dict["weight"]
+            if hasattr(target, "bias") and "bias" in layer_state_dict:
+                state_dict[bias_key] = layer_state_dict["bias"]
+    return state_dict
+
 
 def get_exp_cap(value, decimal=4):
     """
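
A minimal usage sketch of the new helper. It assumes the function is importable from `trl.trainer.utils` (the file shown in this diff); the base checkpoint and LoRA hyperparameters are illustrative placeholders, not values taken from this change:

    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    from trl.trainer.utils import get_lora_merged_state_dict  # assumed import path

    # Wrap any causal LM with a LoRA adapter; "gpt2" is only a placeholder checkpoint.
    base = AutoModelForCausalLM.from_pretrained("gpt2")
    peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"))

    # ... fine-tune `peft_model` as usual ...

    # Fold the adapter deltas into plain base-model tensors and save them.
    merged_state_dict = get_lora_merged_state_dict(peft_model)
    torch.save(merged_state_dict, "merged_model.bin")

Like PEFT's own `merge_and_unload`, the helper walks the wrapped modules and calls `merge(safe_merge=True)` on each adapter layer, but it collects the merged tensors into a plain dict (with the `base_model.model.` prefix stripped from the keys) instead of replacing modules, so the result can be saved directly or loaded back into the bare `transformers` architecture.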