FIX: Correct adapter dtype with bnb weights #2893
Resolves #2889
Description
The reported bug is this: when the base model is quantized with 4-bit bitsandbytes, the adapter weights would be cast to float32 even if `autocast_adapter_dtype=False` was passed. This is because the dtype of the base layer was not correctly determined in that case. This PR now correctly determines the dtype.
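For context, a minimal reproduction sketch of the reported setup (the model id and LoRA settings are illustrative placeholders): with this fix, the LoRA parameters below should keep the bnb compute dtype instead of being upcast to float32.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# Quantize the base model with 4-bit bitsandbytes, computing in bfloat16.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=bnb_config)

# Ask PEFT *not* to upcast the adapter weights to float32.
peft_model = get_peft_model(model, LoraConfig(), autocast_adapter_dtype=False)

for name, param in peft_model.named_parameters():
    if "lora_" in name:
        # Expected after this fix: torch.bfloat16 (the compute dtype), not float32.
        print(name, param.dtype)
```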
While working on this, I noticed that the `peft_model.add_adapter` method was lacking the option to disable autocasting. This option has now been added and is covered by the tests as well; see the sketch below. I also updated some of the corresponding docstrings.
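A minimal usage sketch of the new keyword (continuing from the snippet above; the second adapter config is illustrative):

```python
# Add a second adapter whose weights should likewise not be upcast to float32.
# The autocast_adapter_dtype keyword on add_adapter is what this PR adds.
second_config = LoraConfig(r=16)
peft_model.add_adapter("second", second_config, autocast_adapter_dtype=False)
```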
Tangential changes

An unrelated issue I noticed while debugging: at one point, OSF calls `if not hasattr(module, "osf_svd_params")`. This would error when the module was a `ModulesToSaveWrapper` because `ModulesToSaveWrapper._hasattr_wrapped` was not taking into account the case that there is no active adapter. This is now fixed too.
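To illustrate the failure mode with a toy stand-in (not PEFT's actual wrapper implementation): `hasattr` only returns False when attribute lookup raises `AttributeError`, so a wrapper with no active adapter has to raise that specific error rather than failing while resolving the wrapped module.

```python
class ToyWrapper:
    """Toy stand-in for ModulesToSaveWrapper, illustrating the guard."""

    def __init__(self):
        self.active_adapters = []                 # no active adapter yet
        self._wrapped = {"default": object()}

    def __getattr__(self, name):
        if not self.active_adapters:
            # The guard: raise AttributeError so that plain hasattr(...)
            # returns False instead of propagating an unrelated error.
            raise AttributeError(name)
        return getattr(self._wrapped[self.active_adapters[0]], name)


print(hasattr(ToyWrapper(), "osf_svd_params"))    # False, no exception
```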
Moreover, OSF implemented its own `_cast_adapter_dtype`. This would basically bypass the upcasting of the OSF weights to float32 when the base model is loaded in lower precision. However, unless the user explicitly passes `autocast_adapter_dtype=False`, the default in PEFT is to upcast the adapter weights to float32. With the changes in this PR, upcasting is now done. To make this work with the `forward` pass, `x` is cast to the dtype of the `weight`; we assume that the output dtype should be the same as the original dtype of `x`.
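Schematically, the dtype handling in `forward` follows this pattern (a simplified sketch, not the literal OSF code):

```python
import torch
import torch.nn.functional as F

def forward_with_dtype_cast(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    previous_dtype = x.dtype        # remember the caller's dtype
    x = x.to(weight.dtype)          # e.g. bfloat16 input vs. float32 adapter weight
    out = F.linear(x, weight)
    # Assumption stated above: the output should match the original input dtype.
    return out.to(previous_dtype)
```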
TODOs

There is still an issue left with 8-bit bnb weights: they don't have a compute dtype, so at the layer level it is not possible to determine what the dtype of the PEFT adapter should be (it cannot be int8, of course). Therefore, the corresponding tests for 8-bit bnb are x-failing for now. One possible solution would be to pass down the dtype of the base model (if any) and use that as a fallback. This could be implemented in a later PR.
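To make the asymmetry concrete (a sketch; the attribute availability reflects my understanding of current bitsandbytes, so treat the details as assumptions): `Linear4bit` carries a `compute_dtype` the adapter can adopt, while `Linear8bitLt` exposes no such hint at the layer level.

```python
import torch
import bitsandbytes as bnb

lin4 = bnb.nn.Linear4bit(16, 16, compute_dtype=torch.bfloat16)
print(lin4.compute_dtype)                      # torch.bfloat16 -> usable dtype hint

lin8 = bnb.nn.Linear8bitLt(16, 16, has_fp16_weights=False)
print(getattr(lin8, "compute_dtype", None))    # None -> no local hint for the adapter
```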