|
22 | 22 | import torch
|
23 | 23 | import torch.utils.data
|
24 | 24 | import transformers
|
25 |
| -from accelerate.utils import broadcast_object_list, gather, gather_object, set_seed |
| 25 | +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed |
26 | 26 | from accelerate.utils.other import is_compiled_module
|
27 | 27 | from datasets import Dataset, IterableDataset
|
28 | 28 | from packaging import version
|
@@ -491,21 +491,20 @@ def _move_model_to_vllm(self):
|
491 | 491 | self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
492 | 492 | ) as unwrapped_model:
|
493 | 493 | if is_compiled_module(unwrapped_model):
|
494 |
| - state_dict = unwrapped_model._orig_mod.state_dict() |
495 |
| - elif isinstance(unwrapped_model, PeftModel): |
| 494 | + unwrapped_model = unwrapped_model._orig_mod |
| 495 | + if is_peft_model(unwrapped_model): |
496 | 496 | unwrapped_model.merge_adapter()
|
497 | 497 | state_dict = unwrapped_model.state_dict()
|
498 |
| - unwrapped_model.unmerge_adapter() |
499 | 498 | state_dict = {
|
500 |
| - k.removeprefix("base_model.model.").replace(".base_layer", ""): v |
| 499 | + k.removeprefix("base_model.model.") |
| 500 | + .removeprefix("base_model.model.") |
| 501 | + .replace(".default", "") |
| 502 | + .replace(".base_layer", "") |
| 503 | + .replace(".modules_to_save", ""): v |
501 | 504 | for k, v in state_dict.items()
|
502 |
| - if self.model.prefix not in k |
503 |
| - } |
504 |
| - state_dict = { |
505 |
| - k.replace("modules_to_save.default.", ""): v |
506 |
| - for k, v in state_dict.items() |
507 |
| - if "original_module" not in k |
| 505 | + if unwrapped_model.prefix not in k and "original_module" not in k |
508 | 506 | }
|
| 507 | + unwrapped_model.unmerge_adapter() |
509 | 508 | else:
|
510 | 509 | state_dict = unwrapped_model.state_dict()
|
511 | 510 | if self.accelerator.is_main_process:
|
|
0 commit comments