Reproduction
In `nash_md_trainer.py`, line 171, function `_generate_completions`:
```python
def _generate_completions(self, model, prompts):
    with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
        model_output = unwrapped_model.generate(
            input_ids=prompts["input_ids"],
            attention_mask=prompts["attention_mask"],
            generation_config=self.generation_config,
        )

        ref_model = model if self.ref_model is None else self.ref_model
        with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
            mixture_model = GeometricMixtureWrapper(
                model=unwrapped_model,
                ref_model=unwrapped_ref_model,
                generation_config=self.generation_config,
                mixture_coef=self.mixture_coef,
                device=self.accelerator.device,
            )

            mixture_output = mixture_model.generate(
                input_ids=prompts["input_ids"],
                attention_mask=prompts["attention_mask"],
                generation_config=self.generation_config,
            )

    return model_output, mixture_output
```
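For context, `GeometricMixtureWrapper` samples from a geometric mixture of the policy and the reference policy: per the Nash-MD formulation, the mixture is proportional to model^(1 − β) · ref^β with β = `mixture_coef`, which in log space is a convex combination of the two sets of logits. A minimal plain-Python sketch of that combination (illustrative only, not TRL's implementation; the function name is hypothetical):

```python
import math

def geometric_mixture_logprobs(model_logits, ref_logits, mixture_coef):
    """Log-space geometric mixture of two next-token distributions.

    Sketch of the per-step combination assumed from the Nash-MD formulation:
    mixture ∝ model^(1 - beta) * ref^beta, with beta = mixture_coef.
    """
    mixed = [
        (1.0 - mixture_coef) * m + mixture_coef * r
        for m, r in zip(model_logits, ref_logits)
    ]
    # Normalize to log-probabilities (log-softmax over the vocabulary).
    log_z = math.log(sum(math.exp(x) for x in mixed))
    return [x - log_z for x in mixed]

# With mixture_coef = 0.5 and symmetric logits, both tokens get equal mass.
probs = geometric_mixture_logprobs([2.0, 0.0], [0.0, 2.0], 0.5)
```

This is why passing the raw model as its own reference is silently wrong: with `unwrapped_ref_model` identical to the policy (adapter still enabled), the mixture collapses to the policy itself for every value of `mixture_coef`.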
When we run the trainer with `model` being a PEFT model and no `self.ref_model` provided (meaning the reference model is obtained by disabling the adapter), the `unwrapped_ref_model` in this function is the model itself, with the adapter still active. A correct way is to pass `is_peft_model=<model is a PEFT model>` in the call `unwrap_model_for_generation(ref_model, self.accelerator)`.
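A sketch of the suggested change (the exact condition is an assumption; the `is_peft_model` argument name is taken from the suggestion above):

```diff
-        with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
+        # When no explicit ref_model exists, ref_model is the PEFT policy model
+        # and its adapter must be disabled to recover the reference policy.
+        ref_is_peft = self.ref_model is None and isinstance(model, PeftModel)
+        with torch.no_grad(), unwrap_model_for_generation(
+            ref_model, self.accelerator, is_peft_model=ref_is_peft
+        ) as unwrapped_ref_model:
```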
System Info
- Platform: Linux-5.14.0-503.22.1.el9_5.x86_64-x86_64-with-glibc2.34
- Python version: 3.10.13
- PyTorch version: 2.2.1
- CUDA device(s): NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000, NVIDIA RTX A6000
- Transformers version: 4.48.0
- Accelerate version: 1.2.1
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: MULTI_GPU
- mixed_precision: bf16
- use_cpu: False
- debug: False
- num_processes: 8
- machine_rank: 0
- num_machines: 1
- gpu_ids: all
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 3.2.0
- HF Hub version: 0.27.1
- TRL version: 0.13.0
- bitsandbytes version: not installed
- DeepSpeed version: 0.16.2
- Diffusers version: not installed
- Liger-Kernel version: not installed
- LLM-Blender version: 0.0.2
- OpenAI version: not installed
- PEFT version: 0.14.0
Checklist