
Commit a8e4797

[lora] adapt new LoRA config injection method (#11999)
* use state dict when setting up LoRA.
* up
* up
* up
* comment
* up
* up
1 parent 50e18ee commit a8e4797
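
This commit moves diffusers onto the state-dict-aware injection path added in peft 0.17.0: instead of hand-deriving an `exclude_modules` list, the LoRA state dict is passed to `inject_adapter_in_model` so PEFT can see which modules the adapter actually covers. Below is a minimal sketch of that call pattern, assuming peft >= 0.17.0; the toy model, rank, and key names are illustrative and not taken from the diff.

import torch
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

# Toy two-layer model standing in for a denoiser (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

rank = 2
# A LoRA checkpoint that only covers the first Linear layer, named "0" in the Sequential.
lora_state_dict = {
    "0.lora_A.weight": torch.zeros(rank, 8),
    "0.lora_B.weight": torch.zeros(8, rank),
}

config = LoraConfig(r=rank, target_modules=["0"])
# peft >= 0.17.0: passing the checkpoint lets PEFT reconcile the config with the modules
# the adapter actually covers, which is what the peft.py change below relies on.
inject_adapter_in_model(config, model, adapter_name="default", state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name="default")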

File tree

6 files changed: +35, -109 lines changed

setup.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@
     "librosa",
     "numpy",
     "parameterized",
-    "peft>=0.15.0",
+    "peft>=0.17.0",
     "protobuf>=3.20.3,<4",
     "pytest",
     "pytest-timeout",

src/diffusers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
     "librosa": "librosa",
     "numpy": "numpy",
     "parameterized": "parameterized",
-    "peft": "peft>=0.15.0",
+    "peft": "peft>=0.17.0",
     "protobuf": "protobuf>=3.20.3,<4",
     "pytest": "pytest",
     "pytest-timeout": "pytest-timeout",

src/diffusers/loaders/peft.py

Lines changed: 3 additions & 1 deletion
@@ -320,7 +320,9 @@ def map_state_dict_for_hotswap(sd):
                 # it to None
                 incompatible_keys = None
             else:
-                inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
+                inject_adapter_in_model(
+                    lora_config, self, adapter_name=adapter_name, state_dict=state_dict, **peft_kwargs
+                )
                 incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
 
             if self._prepare_lora_hotswap_kwargs is not None:

src/diffusers/utils/peft_utils.py

Lines changed: 0 additions & 38 deletions
@@ -197,20 +197,6 @@ def get_peft_kwargs(
         "lora_bias": lora_bias,
     }
 
-    # Example: try load FusionX LoRA into Wan VACE
-    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
-    if exclude_modules:
-        if not is_peft_version(">=", "0.14.0"):
-            msg = """
-It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
-version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
-peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
-https://github.com/huggingface/diffusers/issues/new
-"""
-            logger.debug(msg)
-        else:
-            lora_config_kwargs.update({"exclude_modules": exclude_modules})
-
     return lora_config_kwargs
 
 
@@ -388,27 +374,3 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
 
     if warn_msg:
         logger.warning(warn_msg)
-
-
-def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
-    """
-    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
-    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
-    doesn't exist in `peft_state_dict`.
-    """
-    if model_state_dict is None:
-        return
-    all_modules = set()
-    string_to_replace = f"{adapter_name}." if adapter_name else ""
-
-    for name in model_state_dict.keys():
-        if string_to_replace:
-            name = name.replace(string_to_replace, "")
-        if "." in name:
-            module_name = name.rsplit(".", 1)[0]
-            all_modules.add(module_name)
-
-    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
-    exclude_modules = list(all_modules - target_modules_set)
-
-    return exclude_modules
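
For reference, the deleted helper reduces to a set difference: every module named in the model state dict that is not targeted by the LoRA state dict gets excluded. A standalone illustration with made-up module names:

# Made-up module names, purely to illustrate the set difference the deleted helper computed.
model_modules = {"blocks.0.proj", "blocks.0.ff", "blocks.1.proj"}  # from model_state_dict keys
lora_target_modules = {"blocks.0.proj"}  # from peft_state_dict keys, split at ".lora"
exclude_modules = list(model_modules - lora_target_modules)  # e.g. ["blocks.0.ff", "blocks.1.proj"] (set order is arbitrary)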

tests/lora/utils.py

Lines changed: 0 additions & 67 deletions
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import inspect
 import os
 import re
@@ -292,20 +291,6 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
-    def _get_exclude_modules(self, pipe):
-        from diffusers.utils.peft_utils import _derive_exclude_modules
-
-        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
-        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
-        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
-        pipe.unload_lora_weights()
-        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
-        exclude_modules = _derive_exclude_modules(
-            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
-        )
-        return exclude_modules
-
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2342,58 +2327,6 @@ def test_lora_unload_add_adapter(self):
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-    @require_peft_version_greater("0.13.2")
-    def test_lora_exclude_modules(self):
-        """
-        Test to check if `exclude_modules` works or not. It works in the following way:
-        we first create a pipeline and insert LoRA config into it. We then derive a `set`
-        of modules to exclude by investigating its denoiser state dict and denoiser LoRA
-        state dict.
-
-        We then create a new LoRA config to include the `exclude_modules` and perform tests.
-        """
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components).to(torch_device)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
-        # only supported for `denoiser` now
-        pipe_cp = copy.deepcopy(pipe)
-        pipe_cp, _ = self.add_adapters_to_pipeline(
-            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
-        pipe_cp.to("cpu")
-        del pipe_cp
-
-        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
-        pipe, _ = self.add_adapters_to_pipeline(
-            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(tmpdir)
-
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        self.assertTrue(
-            not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
-            "LoRA should change outputs.",
-        )
-        self.assertTrue(
-            np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
-            "Lora outputs should match.",
-        )
-
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 30 additions & 1 deletion
@@ -20,7 +20,7 @@
 from diffusers import FluxTransformer2DModel
 from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
 from diffusers.models.embeddings import ImageProjection
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, is_peft_available, torch_device
 
 from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
 
@@ -172,6 +172,35 @@ def test_gradient_checkpointing_is_applied(self):
         expected_set = {"FluxTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
+    # The test exists for cases like
+    # https://github.com/huggingface/diffusers/issues/11874
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_exclude_modules(self):
+        from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
+
+        lora_rank = 4
+        target_module = "single_transformer_blocks.0.proj_out"
+        adapter_name = "foo"
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        state_dict = model.state_dict()
+        target_mod_shape = state_dict[f"{target_module}.weight"].shape
+        lora_state_dict = {
+            f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
+            f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
+        }
+        # Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
+        config = LoraConfig(
+            r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
+        )
+        inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
+        set_peft_model_state_dict(model, lora_state_dict, adapter_name)
+        retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert len(retrieved_lora_state_dict) == len(lora_state_dict)
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
+        assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
+
 
 class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = FluxTransformer2DModel
