15 changes: 14 additions & 1 deletion src/peft/mixed_model.py
@@ -200,7 +200,13 @@ def disable_adapter(self):
finally:
self.base_model.enable_adapter_layers()

def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
def add_adapter(
self,
adapter_name: str,
peft_config: PeftConfig,
low_cpu_mem_usage: bool = False,
autocast_adapter_dtype: bool = True,
) -> None:
"""
Add an adapter to the model based on the passed configuration.

@@ -222,6 +228,11 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_us

> [!TIP]
> Don't use `low_cpu_mem_usage=True` when creating a new PEFT adapter for training (training is untested
> and discouraged for PeftMixedModel in general).
autocast_adapter_dtype (`bool`, *optional*, defaults to `True`):
Whether to autocast the adapter dtype. Defaults to `True`. Right now, this will only cast adapter
weights from float16 and bfloat16 to float32, as this is typically required for stable training, and
it only affects select PEFT tuners. If set to `False`, the dtypes will stay the same as those of the
corresponding layer.
"""
_check_config_compatible(peft_config)

@@ -233,6 +244,8 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_us
del self.peft_config[adapter_name]
raise

self.base_model._cast_adapter_dtype(adapter_name=adapter_name, autocast_adapter_dtype=autocast_adapter_dtype)

self.set_modules_to_save(peft_config, adapter_name)

def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None:
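For context, a minimal usage sketch of the new `autocast_adapter_dtype` argument (not part of the diff); the model name and LoRA config are illustrative, but the call signatures mirror what the new test exercises:

```python
# Sketch only: illustrates the new `autocast_adapter_dtype` argument; model and config are illustrative.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", dtype=torch.bfloat16)
config = LoraConfig(task_type="CAUSAL_LM")

# Default behavior: bf16/fp16 adapter weights are upcast to float32 for stable training.
model = get_peft_model(base, config, autocast_adapter_dtype=True)

# Opt out: the second adapter keeps the dtype of the corresponding base layers (bfloat16 here).
model.add_adapter("other", config, autocast_adapter_dtype=False)

dtypes = {p.dtype for n, p in model.named_parameters() if "lora_" in n and ".other." in n}
print(dtypes)  # expected: {torch.bfloat16}
```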
105 changes: 81 additions & 24 deletions src/peft/peft_model.py

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions src/peft/tuners/osf/layer.py
@@ -254,7 +254,12 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
active_adapter = self.active_adapters[0] if self.active_adapters else None
if active_adapter and active_adapter in self.osf_svd_params:
weight = self._reconstruct_weight(active_adapter)
orig_dtype = x.dtype # assume that the intended dtype is that of the input
x = self._cast_input_dtype(x, weight.dtype)
if bias is not None:
bias = bias.to(weight.dtype)
result = F.linear(x, weight, bias)
result = result.to(orig_dtype)
else:
result = self.base_layer(x, *args, **kwargs)

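The forward change above follows a cast-in/cast-out pattern: compute in the weight's dtype, then restore the input's dtype. A standalone sketch of that pattern (the function name is hypothetical, not part of the PR):

```python
# Hypothetical standalone sketch of the cast-in/cast-out pattern used in the OSF forward pass above.
from typing import Optional

import torch
import torch.nn.functional as F


def linear_in_weight_dtype(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    orig_dtype = x.dtype               # assume the intended output dtype is that of the input
    x = x.to(weight.dtype)             # compute in the weight's dtype (e.g. float32 after autocasting)
    if bias is not None:
        bias = bias.to(weight.dtype)
    result = F.linear(x, weight, bias)
    return result.to(orig_dtype)       # restore the caller's dtype for downstream layers
```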
36 changes: 0 additions & 36 deletions src/peft/tuners/osf/model.py
@@ -118,42 +118,6 @@ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
if "osf_svd_params" not in n:
p.requires_grad = False

def _cast_adapter_dtype(self, adapter_name: str, autocast_adapter_dtype: bool = True) -> None:
"""
Ensure all OSF adapter components have consistent dtype with the base model.

Instead of forcing float32, we match the base model's actual dtype for consistency.
"""
if not autocast_adapter_dtype:
return

for module in self.model.modules():
if not hasattr(module, "osf_svd_params"):
continue

# Get target dtype from base layer weight
base_layer = getattr(module, "base_layer", None)
if base_layer is None or not hasattr(base_layer, "weight"):
continue

target_dtype = base_layer.weight.dtype

# Cast trainable low-rank parameters to match base model dtype
if adapter_name in module.osf_svd_params:
svd_params = module.osf_svd_params[adapter_name]
for param_name, param in svd_params.items():
if param.dtype != target_dtype:
param.data = param.data.to(target_dtype)

# Cast frozen high-rank buffers to match base model dtype
for buffer_dict_name in OSFLayer.other_param_names:
if hasattr(module, buffer_dict_name):
buffer_dict = getattr(module, buffer_dict_name)
if adapter_name in buffer_dict:
buffer = buffer_dict[adapter_name]
if buffer.dtype != target_dtype:
buffer_dict[adapter_name] = buffer.to(target_dtype)

# Use BaseTuner's merge and merge_and_unload implementations.
# Explicitly disallow unmerging at the model level for OSF.
def unmerge_adapter(self, *args, **kwargs):
57 changes: 38 additions & 19 deletions src/peft/tuners/tuners_utils.py
@@ -1507,24 +1507,44 @@ def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: b
if key in adapter_names_set:
layer.requires_grad_(requires_grad)

def _get_base_layer_device_and_dtype(self, base_layer):
"""
Helper function to determine the device and dtype of the base layer. If a value cannot be determined, return None for it.
"""
device, dtype = None, None

# check weight and qweight (for GPTQ)
for weight_name in ("weight", "qweight"):
weight = getattr(base_layer, weight_name, None)
if weight is not None:
device = weight.device
dtype = weight.dtype
break

if hasattr(base_layer, "compute_dtype"): # bnb Linear4bit
dtype = base_layer.compute_dtype

return device, dtype

def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
"""
Move the adapter of the given name to the device of the base layer.
Move the adapter of the given name to the device, and possibly dtype, of the base layer.
"""
if device is None:
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.MultiheadAttention):
base_layer = base_layer.out_proj
# check weight and qweight (for GPTQ)
for weight_name in ("weight", "qweight"):
weight = getattr(base_layer, weight_name, None)
if weight is not None:
device = weight.device
dtype = weight.dtype
break
else:
# no break encountered: could not determine the device
return
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.MultiheadAttention):
base_layer = base_layer.out_proj
base_layer_device, base_layer_dtype = self._get_base_layer_device_and_dtype(base_layer)

target_device = device if device is not None else base_layer_device
if target_device is None:
# could not determine device
return

target_dtype = None
if base_layer_dtype is not None:
# don't cast to int dtype
if base_layer_dtype.is_floating_point or base_layer_dtype.is_complex:
target_dtype = base_layer_dtype

meta = torch.device("meta")

@@ -1540,11 +1560,10 @@ def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optio
if any(p.device == meta for p in adapter_layer.parameters()):
continue

# TODO: weight is not necessarily defined here, leading to a NameError, fix that
if weight.dtype.is_floating_point or weight.dtype.is_complex:
adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device, dtype=dtype)
if target_dtype is not None:
adapter_layer[adapter_name] = adapter_layer[adapter_name].to(target_device, dtype=target_dtype)
else:
adapter_layer[adapter_name] = adapter_layer[adapter_name].to(device)
adapter_layer[adapter_name] = adapter_layer[adapter_name].to(target_device)

@overload
def _cast_input_dtype(self, x: None, dtype: torch.dtype) -> None: ...
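To summarize the refactor, here is a condensed sketch of the new flow; `probe_device_and_dtype` and `move_adapter` are hypothetical stand-ins for `_get_base_layer_device_and_dtype` and `_move_adapter_to_device_of_base_layer`, and the adapter is simplified to a single tensor:

```python
# Condensed, simplified sketch of the refactored logic above (names are stand-ins, not the PEFT API).
import torch


def probe_device_and_dtype(base_layer):
    device, dtype = None, None
    for weight_name in ("weight", "qweight"):      # qweight covers GPTQ-quantized layers
        weight = getattr(base_layer, weight_name, None)
        if weight is not None:
            device, dtype = weight.device, weight.dtype
            break
    if hasattr(base_layer, "compute_dtype"):       # bnb Linear4bit exposes its compute dtype here
        dtype = base_layer.compute_dtype
    return device, dtype


def move_adapter(adapter_weight: torch.Tensor, base_layer, device=None) -> torch.Tensor:
    base_device, base_dtype = probe_device_and_dtype(base_layer)
    target_device = device if device is not None else base_device
    if target_device is None:                      # could not determine a device: leave the adapter alone
        return adapter_weight
    # only cast to floating-point or complex dtypes; never to the int dtypes of quantized weights
    if base_dtype is not None and (base_dtype.is_floating_point or base_dtype.is_complex):
        return adapter_weight.to(target_device, dtype=base_dtype)
    return adapter_weight.to(target_device)
```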
2 changes: 2 additions & 0 deletions src/peft/utils/other.py
@@ -555,6 +555,8 @@ def _forward_wrapped_passthrough(self, x, *args, **kwargs):
return self.original_module(x, *args, **kwargs)

def _hasattr_wrapped(self, name, modules):
if not self.active_adapters:
return False
return self.active_adapters[0] in modules["modules_to_save"]

def _getattr_wrapped(self, name, modules):
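The added guard in `_hasattr_wrapped` handles the case where no adapter is active. A minimal illustration of the same check as a hypothetical standalone function (not PEFT's actual wrapper):

```python
# Minimal illustration of the guard: with no active adapters, report "not wrapped"
# instead of raising IndexError on the [0] lookup (the pre-fix behavior).
def hasattr_wrapped(active_adapters: list, modules_to_save: dict) -> bool:
    if not active_adapters:          # guard added by this diff
        return False
    return active_adapters[0] in modules_to_save


print(hasattr_wrapped([], {"default": object()}))           # False (previously: IndexError)
print(hasattr_wrapped(["default"], {"default": object()}))  # True
```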
65 changes: 65 additions & 0 deletions tests/test_custom_models.py
@@ -1990,6 +1990,71 @@ def test_forward_bfloat16_no_autocast(self, test_name, model_id, config_cls, con
model = model.merge_and_unload()
model(**X)

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("autocast_adapter_dtype", [True, False])
@pytest.mark.parametrize("low_cpu_mem_usage", [False, True])
def test_adapter_dtype_autocast(
self,
test_name,
model_id,
config_cls,
config_kwargs,
dtype,
autocast_adapter_dtype,
low_cpu_mem_usage,
tmp_path,
):
"""checks that the dtype of the PEFT adapter corresponds to the expected dtype.

Checks:
- get_peft_model
- add_adapter
- PeftModel.from_pretrained
- load_adapter
- with and without autocasting adapter dtype
- with and without low_cpu_mem_usage (which only makes sense for loading adapters)
"""
if autocast_adapter_dtype and (config_cls == LNTuningConfig):
# LN Tuning basically copies the base weight and makes it trainable, hence it makes sense to keep the dtype
# of the base model weight.
pytest.skip("LNTuning and OSF are exempted from casting the adapter weights to float32")

if autocast_adapter_dtype:
expected_dtype = torch.float32
else:
expected_dtype = dtype

model = self.transformers_class.from_pretrained(model_id, dtype=dtype).to(self.torch_device)
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
model = get_peft_model(model, config, autocast_adapter_dtype=autocast_adapter_dtype)
if config_kwargs.get("target_parameters", None) is None:
# target_parameters does not allow multiple adapters on the same parameter
model.add_adapter("other", config, autocast_adapter_dtype=autocast_adapter_dtype)
peft_params = [p for n, p in model.named_parameters() if model.prefix in n]
assert all(p.dtype == expected_dtype for p in peft_params)

model.save_pretrained(tmp_path)
del model

model = self.transformers_class.from_pretrained(model_id, dtype=dtype).to(self.torch_device)
model = PeftModel.from_pretrained(
model, tmp_path, autocast_adapter_dtype=autocast_adapter_dtype, low_cpu_mem_usage=low_cpu_mem_usage
)
if config_kwargs.get("target_parameters", None) is None:
# target_parameters does not allow multiple adapters on the same parameter
model.load_adapter(
tmp_path / "other",
adapter_name="other",
autocast_adapter_dtype=autocast_adapter_dtype,
low_cpu_mem_usage=low_cpu_mem_usage,
)
peft_params = [p for n, p in model.named_parameters() if model.prefix in n]
assert all(p.dtype == expected_dtype for p in peft_params)

@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs):
# An explicit test that when using an adapter on a custom model, only the adapter parameters are updated during
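Outside the test suite, the save/load half of what the new test checks can be reproduced with a short round trip; the model name and save path below are illustrative:

```python
# Sketch mirroring the from_pretrained path of the new test; model name and path are illustrative.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", dtype=torch.float16)
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"), autocast_adapter_dtype=False)
model.save_pretrained("/tmp/fp16-adapter")
del model

# Reload without autocasting: adapter weights keep the base model's float16 dtype.
base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", dtype=torch.float16)
model = PeftModel.from_pretrained(base, "/tmp/fp16-adapter", autocast_adapter_dtype=False)
assert all(p.dtype == torch.float16 for n, p in model.named_parameters() if "lora_" in n)
```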