4 changes: 2 additions & 2 deletions src/peft/tuners/lora/config.py
@@ -669,8 +669,8 @@ class LoraConfig(PeftConfig):
"help": (
"Whether to tie weights or not after peft initialization. "
"This will ensure that the adapters added to the tied layers "
"are also tied. This is only applicable for layers passed via "
"`modules_to_save`."
"are also tied. This is applicable for layers passed via "
"`modules_to_save` and `trainable_token_indices`."
)
},
)
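For context, an illustrative usage sketch of the documented behavior (not taken from the diff): it assumes a causal LM with tied input/output embeddings, and the checkpoint name and the layer keys `embed_tokens`/`lm_head` are placeholders that depend on the architecture.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Illustrative tiny checkpoint with tie_word_embeddings=True; any tied-weights causal LM behaves the same.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")

config = LoraConfig(
    target_modules="all-linear",
    # The dict form maps layer names to token indices; a plain list such as [1, 2, 3] targets the input embedding.
    trainable_token_indices={"embed_tokens": [1, 2], "lm_head": [1, 2]},
    # With identical indices on the tied layers the token deltas are shared; mismatched indices
    # raise a ValueError, and on a model without tied weights only a warning is emitted.
    ensure_weight_tying=True,
)
peft_model = get_peft_model(model, config)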
90 changes: 76 additions & 14 deletions src/peft/utils/other.py
@@ -1460,36 +1460,98 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n

modules_to_save = getattr(peft_config, "modules_to_save", None)
if modules_to_save is not None:
-for target_layer in target_layers:
-if target_layer in modules_to_save:
for target_layer_name in target_layers:
if target_layer_name in modules_to_save:
raise ValueError(
"The embedding layer is already marked to be trained fully, either specify "
-f'`modules_to_save=[..., "{target_layer}", ...]` or '
-f"`trainable_tokens={{'{target_layer}': x}}` but not both."
f'`modules_to_save=[..., "{target_layer_name}", ...]` or '
f"`trainable_tokens={{'{target_layer_name}': x}}` but not both."
)

-for target_layer, token_indices in target_layers.items():
# Check weight tying configuration first to determine which layers to wrap
weights_tied = model_config.get("tie_word_embeddings", False)
ensure_weight_tying = getattr(peft_config, "ensure_weight_tying", False)

# When multiple target layers are specified, check if they correspond to tied weights
indices_mismatch = False
layers_to_skip = set()
tied_layer_keys = []

if len(target_layers) > 1 and weights_tied:
# Get module names that are tied with the embedding
tied_module_names = set(_get_module_names_tied_with_embedding(model))

# Also get the input embedding layer name as it's the source of tied weights
embedding_module = model.get_input_embeddings()
# Get just the last part of the name for matching with target_layers
embedding_name = next(n.split(".")[-1] for n, m in model.named_modules() if m is embedding_module)

# Find which target layers are in the tied weights (including the embedding source)
for target_layer_name in target_layers:
# Check if this is the embedding layer
if target_layer_name == embedding_name:
tied_layer_keys.append(target_layer_name)
continue
# Check if this target layer matches any tied module (considering nested structures)
for tied_module in tied_module_names:
if tied_module.endswith(target_layer_name) or target_layer_name in tied_module.split("."):
tied_layer_keys.append(target_layer_name)
break

# If we found multiple tied layers in our targets, check their indices
if len(tied_layer_keys) >= 2:
# Check if all tied layers have the same indices
first_indices = target_layers[tied_layer_keys[0]]
indices_mismatch = not all(target_layers[key] == first_indices for key in tied_layer_keys[1:])

# Raise error immediately if ensure_weight_tying=True and indices mismatch
if indices_mismatch and ensure_weight_tying:
tied_layers_info = ", ".join([f"{key}: {target_layers[key]}" for key in tied_layer_keys])
raise ValueError(
f"Cannot ensure weight tying when different token indices are specified for tied layers. "
f"Conflicting layers: {tied_layers_info}. "
f"Please use the same indices for all tied layers or set ensure_weight_tying=False."
)

# If indices match, skip tied modules (except embedding) as they'll be handled by weight tying logic
if not indices_mismatch:
layers_to_skip = set(tied_layer_keys) & tied_module_names

# Wrap target layers (skip those that will be handled by weight tying logic)
for target_layer_name, token_indices in target_layers.items():
if target_layer_name in layers_to_skip:
continue

_set_trainable(
model,
adapter_name,
inference_mode=peft_config.inference_mode,
-module_names=[target_layer],
module_names=[target_layer_name],
strict_module_check=True,
wrapper_cls=TrainableTokensWrapper,
token_indices=token_indices,
activate_adapter=activate_adapter,
)

-tied_weights_module_names = _get_module_names_tied_with_embedding(model)
# Warn if user expects weight tying but model doesn't have tied weights
if not weights_tied and ensure_weight_tying:
warnings.warn(
"ensure_weight_tying=True but the model does not have tied weights "
"(tie_word_embeddings=False). Weight tying will not be applied for trainable_token_indices."
)

-# There might be the possibility that we have output weights that are tied to the input weights.
-# In that case we will tie any module that wants tied weights to the token adapter to make sure that
-# any modification is reflected in the tied layers as well.
-if (
-tied_weights_module_names
-and model_config.get("tie_word_embeddings", False)
# Apply weight tying when appropriate
should_apply_tying = (
weights_tied
and isinstance(model.get_input_embeddings(), TrainableTokensWrapper)
-):
and (ensure_weight_tying or not indices_mismatch)
)

if should_apply_tying:
# There might be the possibility that we have output weights that are tied to the input weights.
# In that case we will tie any module that wants tied weights to the token adapter to make sure that
# any modification is reflected in the tied layers as well.
tied_weights_module_names = _get_module_names_tied_with_embedding(model)
token_adapter = model.get_input_embeddings().token_adapter
_set_trainable(
model,
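To summarize the branching above, here is a standalone sketch that reduces the decision flow to plain data. It is not the PEFT implementation; the function name, parameters, and return shape are illustrative, and the embedding/isinstance checks are approximated by name lookups.

from typing import Dict, List, Set


def plan_token_adapters(
    target_layers: Dict[str, List[int]],  # layer name -> trainable token indices
    tied_module_names: Set[str],  # modules whose weights are tied to the input embedding
    embedding_name: str,  # e.g. "embed_tokens"
    weights_tied: bool,  # model_config.get("tie_word_embeddings", False)
    ensure_weight_tying: bool,
) -> dict:
    """Roughly mirrors the tying decisions made in set_additional_trainable_modules."""
    # Which requested layers participate in weight tying (the embedding itself plus tied modules)?
    tied_keys = [
        name
        for name in target_layers
        if name == embedding_name or any(m.endswith(name) for m in tied_module_names)
    ]

    indices_mismatch = False
    if weights_tied and len(tied_keys) >= 2:
        first = target_layers[tied_keys[0]]
        indices_mismatch = any(target_layers[key] != first for key in tied_keys[1:])
        if indices_mismatch and ensure_weight_tying:
            raise ValueError("Cannot ensure weight tying when different token indices are specified for tied layers.")

    # With matching indices, tied modules (but not the embedding) are skipped and later reuse the
    # embedding's token adapter; with a mismatch they are wrapped separately (backwards-compatible path).
    skip = set(tied_keys) & tied_module_names if (weights_tied and not indices_mismatch) else set()
    wrap = [name for name in target_layers if name not in skip]
    apply_tying = weights_tied and embedding_name in target_layers and (ensure_weight_tying or not indices_mismatch)
    return {"wrap": wrap, "skip": sorted(skip), "apply_tying": apply_tying}


# Example: same indices on both tied layers -> lm_head is skipped and shares the embedding's deltas.
plan_token_adapters(
    {"embed_tokens": [1, 2], "lm_head": [1, 2]},
    tied_module_names={"lm_head"},
    embedding_name="embed_tokens",
    weights_tied=True,
    ensure_weight_tying=False,
)
# -> {'wrap': ['embed_tokens'], 'skip': ['lm_head'], 'apply_tying': True}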
172 changes: 172 additions & 0 deletions tests/test_trainable_tokens.py
@@ -1016,3 +1016,175 @@ def test_scaled_embedding_with_lora(self):
orig_embedding.embed_scale.fill_(0)
embedding_output = peft_embedding(x)
assert (embedding_output == 0.0).all()

# Tests for ensure_weight_tying parameter with trainable_token_indices
# See #2864 for details on the expected behavior
@pytest.mark.parametrize(
"trainable_token_indices",
[
[1, 2, 3], # list format
{"embed_tokens": [1, 2, 3]}, # dict format - single layer
{"lm_head": [1, 2], "embed_tokens": [1, 2]}, # dict format - same indices
{"lm_head": [1, 2], "embed_tokens": [3, 4]}, # dict format - different indices
],
)
def test_ensure_weight_tying_warns_when_model_not_tied(
self, model_weight_untied, recwarn, trainable_token_indices
):
"""Should warn when ensure_weight_tying=True but model doesn't have tied weights"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=trainable_token_indices,
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_untied, peft_config)

warnings_list = [w.message.args[0] for w in recwarn]
expected = "ensure_weight_tying=True but the model does not have tied weights"
assert any(expected in msg for msg in warnings_list)

# Verify adapters are not tied (model doesn't have tied weights)
if isinstance(trainable_token_indices, dict) and len(trainable_token_indices) > 1:
embed_adapter = peft_model.model.model.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter
assert embed_adapter is not None
assert lm_head_adapter is not None
assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta

def test_weight_tying_bc_different_indices_treated_separately(self, model_weight_tied):
"""Backwards compatibility: different indices should be treated separately when ensure_weight_tying=False"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [3, 4]},
ensure_weight_tying=False, # BC behavior
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Check that both layers have token adapters but they're NOT tied
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

assert embed_adapter is not None
assert lm_head_adapter is not None
# They should NOT share the same delta parameters (treated as separate)
assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta
# They should have different token indices
assert embed_adapter.token_indices["default"] == [3, 4]
assert lm_head_adapter.token_indices["default"] == [1, 2]

def test_ensure_weight_tying_errors_with_different_indices(self, model_weight_tied):
"""Should raise error when ensure_weight_tying=True with different indices for embedding and lm_head"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [3, 4]},
ensure_weight_tying=True,
)

msg = "Cannot ensure weight tying when different token indices are specified"
with pytest.raises(ValueError, match=msg):
peft_model = get_peft_model(model_weight_tied, peft_config)

def test_ensure_weight_tying_applied_with_same_indices(self, model_weight_tied):
"""Should apply weight tying when ensure_weight_tying=True with same indices"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]},
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Check that weight tying is properly applied
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters (weight tying)
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta
# They should have the same token indices
assert embed_adapter.token_indices["default"] == [1, 2]
assert lm_head_adapter.token_indices["default"] == [1, 2]

def test_weight_tying_bc_same_indices_applied(self, model_weight_tied):
"""When indices are the same, weight tying should be applied even when ensure_weight_tying=False"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]},
ensure_weight_tying=False, # BC: still applies tying when indices are the same
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Even with ensure_weight_tying=False, BC behavior should still tie when indices are same
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters (BC behavior)
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta

def test_ensure_weight_tying_with_single_layer(self, model_weight_tied):
"""ensure_weight_tying should work with single layer (list format)"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=[1, 2, 3],
ensure_weight_tying=True,
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Should apply weight tying to tied layers automatically
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta

def test_untied_model_list_format_no_ensure(self, model_weight_untied):
"""Untied model with list format, ensure_weight_tying=False - trainable tokens on embeddings only"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=[1, 2, 3],
ensure_weight_tying=False,
)
peft_model = get_peft_model(model_weight_untied, peft_config)

# Only embed_tokens should have token adapter
assert hasattr(peft_model.model.model.embed_tokens, "token_adapter")
assert not hasattr(peft_model.model.lm_head, "token_adapter")

def test_tied_model_list_format_no_ensure(self, model_weight_tied):
"""Tied model with list format, ensure_weight_tying=False - tied trainable tokens"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=[1, 2, 3],
ensure_weight_tying=False,
)
peft_model = get_peft_model(model_weight_tied, peft_config)

# Both should have token adapters and be tied
embed_adapter = peft_model.model.model.decoder.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

# They should share the same delta parameters (BC behavior)
assert embed_adapter.trainable_tokens_delta is lm_head_adapter.trainable_tokens_delta

@pytest.mark.parametrize(
"trainable_token_indices",
[
{"lm_head": [1, 2], "embed_tokens": [1, 2]}, # same indices
{"lm_head": [1, 2], "embed_tokens": [3, 4]}, # different indices
],
)
def test_untied_model_dict_no_ensure(self, model_weight_untied, trainable_token_indices):
"""Untied model with dict format, ensure_weight_tying=False - treat as separate"""
peft_config = LoraConfig(
target_modules="all-linear",
trainable_token_indices=trainable_token_indices,
ensure_weight_tying=False,
)
peft_model = get_peft_model(model_weight_untied, peft_config)

# Both should have token adapters but NOT tied (since model doesn't have tied weights)
embed_adapter = peft_model.model.model.embed_tokens.token_adapter
lm_head_adapter = peft_model.model.lm_head.token_adapter

assert embed_adapter is not None
assert lm_head_adapter is not None
# They should NOT share delta parameters (model doesn't have tied weights)
assert embed_adapter.trainable_tokens_delta is not lm_head_adapter.trainable_tokens_delta