Fix LLaMA 3.2, add clean_special_chars (#289)
gsarti authored Oct 15, 2024
1 parent 9c118bb commit d7c5269
Showing 7 changed files with 135 additions and 46 deletions.
20 changes: 12 additions & 8 deletions CHANGELOG.md
@@ -4,17 +4,19 @@

## 🚀 Features

- - Added [treescope](https://github.com/google-deepmind/treescope) for interactive model and tensor visualization. ([#283](https://github.com/inseq-team/inseq/pull/283))
+ - Added [treescope](https://github.com/google-deepmind/treescope) for interactive model and tensor visualization ([#283](https://github.com/inseq-team/inseq/pull/283)).

- - New `treescope`-powered methods `FeatureAttributionOutput.show_granular` and `FeatureAttributionSequenceOutput.show_tokens` for interactive visualization of multidimensional attribution tensors and token highlights. ([#283](https://github.com/inseq-team/inseq/pull/283))
+ - New `treescope`-powered methods `FeatureAttributionOutput.show_granular` and `FeatureAttributionSequenceOutput.show_tokens` for interactive visualization of multidimensional attribution tensors and token highlights ([#283](https://github.com/inseq-team/inseq/pull/283)).

- - Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.
+ - Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM`, `OlmoeForCausalLM`, `GraniteForCausalLM`, `GraniteMoeForCausalLM` to model config.

- Add `rescale_attributions` to Inseq CLI commands for `rescale=True` ([#280](https://github.com/inseq-team/inseq/pull/280)).

- - Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment [#282](https://github.com/inseq-team/inseq/pull/282)
+ - Rows and columns in the visualization now have indices alongside tokens to facilitate index-based slicing, aggregation and alignment ([#282](https://github.com/inseq-team/inseq/pull/282)).

- - Added a `scores_precision` to `FeatureAttributionOutput.save` to enable efficient saving in `float16` and `float8` formats. This is useful for saving large attribution outputs in a more memory-efficient way. [#273](https://github.com/inseq-team/inseq/pull/273)
+ - New parameter `clean_special_chars` in `model.attribute` to automatically clean special characters from output tokens, such as `▁` and `Ġ` ([#289](https://github.com/inseq-team/inseq/pull/289)).
+
+ - Added a `scores_precision` to `FeatureAttributionOutput.save` to enable efficient saving in `float16` and `float8` formats. This is useful for saving large attribution outputs in a more memory-efficient way ([#273](https://github.com/inseq-team/inseq/pull/273)).

```python
import inseq
@@ -53,7 +55,7 @@ out_sliced = out.aggregate("slices", target_spans=(13,73))
out_sliced = out[13:73]
```

- - The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` [#282](https://github.com/inseq-team/inseq/pull/282)
+ - The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` ([#282](https://github.com/inseq-team/inseq/pull/282)).


```python
# … (example collapsed in the diff view)
```

@@ -84,12 +86,14 @@ out_female = attrib_model.attribute(

- Fix multi-device support and duplicate BOS for chat template models ([#280](https://github.com/inseq-team/inseq/pull/280)).

- - The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y [#282](https://github.com/inseq-team/inseq/pull/282)
+ - The directions of generated/attributed tokens were clarified in the visualization using arrows instead of x/y ([#282](https://github.com/inseq-team/inseq/pull/282)).

+ - Fix support for multi-EOS tokens (e.g. LLaMA 3.2, see [#287](https://github.com/inseq-team/inseq/issues/287)).

## 📝 Documentation and Tutorials

- Updated tutorial with `treescope` usage examples.

## 💥 Breaking Changes

- - Dropped support for Python 3.9. Please use Python >= 3.10. ([#283](https://github.com/inseq-team/inseq/pull/283))
+ - Dropped support for Python 3.9. Current support is Python >= 3.10, <= 3.12 ([#283](https://github.com/inseq-team/inseq/pull/283)).
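A minimal usage sketch for the new `clean_special_chars` flag (the model and input below are illustrative, not part of the commit):

```python
import inseq

# Any supported checkpoint works; GPT-2 is used here purely for illustration.
model = inseq.load_model("gpt2", "saliency")

# With clean_special_chars=True, tokenizer markers such as "Ġ" (byte-level BPE)
# and "▁" (SentencePiece) are stripped from the displayed tokens.
out = model.attribute("Hello world, how are you?", clean_special_chars=True)
out.show()
```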
4 changes: 2 additions & 2 deletions Makefile
@@ -56,7 +56,7 @@ install:
.PHONY: install-dev
install-dev:
make uv-activate && uv pip install -r requirements-dev.txt && pre-commit install && pre-commit autoupdate


.PHONY: install-ci
install-ci:
@@ -82,7 +82,7 @@ fix-style:

.PHONY: check-safety
check-safety:
- $(PYTHON) -m safety check --full-report -i 70612 -i 71670 -i 72089
+ $(PYTHON) -m safety check --full-report -i 70612 -i 72089

.PHONY: lint
lint: fix-style check-safety
21 changes: 21 additions & 0 deletions inseq/attr/feat/feature_attribution.py
@@ -178,6 +178,7 @@ def prepare_and_attribute(
step_scores: list[str] = [],
include_eos_baseline: bool = False,
skip_special_tokens: bool = False,
+ clean_special_chars: bool = False,
attributed_fn: str | Callable[..., SingleScorePerStepTensor] | None = None,
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
@@ -210,6 +211,8 @@ def prepare_and_attribute(
attribution. By default the EOS token is not used for attribution. Defaults to False.
skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when encoding the input.
Defaults to False.
+ clean_special_chars (:obj:`bool`, `optional`): Whether to clean special characters from the input and the
+ generated tokens. Defaults to False.
attributed_fn (:obj:`str` or :obj:`Callable[..., SingleScorePerStepTensor]`, `optional`): The identifier or
function of model outputs representing what should be attributed (e.g. output probabilities of the model's best
prediction after softmax). If it is a string, it must be a valid function.
@@ -252,6 +255,7 @@ def prepare_and_attribute(
attribute_target=attribute_target,
step_scores=step_scores,
skip_special_tokens=skip_special_tokens,
+ clean_special_chars=clean_special_chars,
attribution_args=attribution_args,
attributed_fn_args=attributed_fn_args,
step_scores_args=step_scores_args,
@@ -368,6 +372,7 @@ def attribute(
attribute_target: bool = False,
step_scores: list[str] = [],
skip_special_tokens: bool = False,
+ clean_special_chars: bool = False,
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
@@ -397,6 +402,8 @@ def attribute(
step scores can be added by using the :meth:`~inseq.register_step_function` function.
skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when encoding the input.
Defaults to False.
+ clean_special_chars (:obj:`bool`, `optional`): Whether to clean special characters from the input and the
+ generated tokens. Defaults to False.
attribution_args (:obj:`dict`, `optional`): Additional arguments to pass to the attribution method.
attributed_fn_args (:obj:`dict`, `optional`): Additional arguments to pass to the attributed function.
step_scores_args (:obj:`dict`, `optional`): Additional arguments to pass to the step scores function.
@@ -522,6 +529,20 @@ def attribute(
attr_pos_start=attr_pos_start,
attr_pos_end=iter_pos_end,
)
+ if clean_special_chars:
+ for out in attribution_outputs:
+ out.source = self.attribution_model.clean_tokens(out.source) if out.source is not None else None
+ out.prefix = (
+ self.attribution_model.clean_tokens(out.prefix, as_targets=True)
+ if out.prefix is not None
+ else None
+ )
+ out.target = (
+ self.attribution_model.clean_tokens(out.target, as_targets=True)
+ if out.target is not None
+ else None
+ )
+ target_tokens_with_ids = self.attribution_model.clean_tokens(target_tokens_with_ids, as_targets=True)
out = FeatureAttributionOutput(
sequence_attributions=FeatureAttributionSequenceOutput.from_step_attributions(
attributions=attribution_outputs,
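To make the cleaning step above concrete, here is a hedged sketch of what `clean_tokens` does to GPT-2-style tokens (the input is illustrative; exact outputs depend on the tokenizer's `convert_tokens_to_string`):

```python
import inseq

model = inseq.load_model("gpt2", "saliency")

# "Ġ" marks a leading space in GPT-2's byte-level BPE vocabulary; clean_tokens
# resolves such markers into plain strings for display.
print(model.clean_tokens([["The", "Ġdeveloper", "Ġargued"]]))
```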
4 changes: 4 additions & 0 deletions inseq/data/batch.py
@@ -27,6 +27,10 @@ class BatchEncoding(TensorWrapper):
def __len__(self) -> int:
return len(self.input_tokens)

+ @property
+ def num_sequences(self) -> int:
+ return self.input_ids.shape[0]


@dataclass(eq=False, repr=False)
class BatchEmbedding(TensorWrapper):
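`num_sequences` simply exposes the batch dimension of `input_ids`; `HuggingfaceModel.generate` (later in this diff) uses it to pick the default for `skip_special_tokens`. A hypothetical standalone equivalent:

```python
import torch

input_ids = torch.zeros(3, 7, dtype=torch.long)  # 3 sequences of length 7
num_sequences = input_ids.shape[0]  # what BatchEncoding.num_sequences returns
assert num_sequences == 3
```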
80 changes: 54 additions & 26 deletions inseq/models/attribution_model.py
@@ -300,6 +300,47 @@ def get_attributed_fn(
attributed_fn = STEP_SCORES_MAP[attributed_fn]
return attributed_fn

+ def validate_attribute_args(
+ self,
+ input_texts: TextInput,
+ generated_texts: TextInput,
+ has_generated_texts: bool,
+ attribution_method: FeatureAttribution,
+ batch_size: int,
+ attr_pos_start: int | None,
+ attr_pos_end: int | None,
+ ) -> int:
+ logger.debug(f"reference_texts={generated_texts}")
+ if not self.is_encoder_decoder:
+ error_input_gen_mismatch = "Forced generations of decoder-only models must start with the input texts:\n\n"
+ mismatch_seqs = []
+ for idx in range(len(input_texts)):
+ if not generated_texts[idx].startswith(input_texts[idx]):
+ mismatch_seqs.append(f"{repr(input_texts[idx])}\n!=\n{repr(generated_texts[idx])}")
+ assert len(mismatch_seqs) == 0, error_input_gen_mismatch + "\n\n".join(mismatch_seqs)
+ if has_generated_texts and len(input_texts) > 1:
+ logger.warning(
+ "Batched constrained decoding is currently not supported for decoder-only models."
+ " Using batch size of 1."
+ )
+ batch_size = 1
+ if len(input_texts) > 1 and (attr_pos_start is not None or attr_pos_end is not None):
+ logger.warning(
+ "Custom attribution positions are currently not supported when batching generations for"
+ " decoder-only models. Using batch size of 1."
+ )
+ batch_size = 1
+ elif attribution_method.is_final_step_method and len(input_texts) > 1:
+ logger.warning(
+ "Batched attribution with encoder-decoder models currently not supported for final-step methods."
+ " Using batch size of 1."
+ )
+ batch_size = 1
+ if attribution_method.method_name == "lime":
+ logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.")
+ batch_size = 1
+ return batch_size

def attribute(
self,
input_texts: TextInput,
@@ -319,6 +360,7 @@ def attribute(
batch_size: int | None = None,
generate_from_target_prefix: bool = False,
skip_special_tokens: bool = False,
+ clean_special_chars: bool = False,
generation_args: dict[str, Any] = {},
**kwargs,
) -> FeatureAttributionOutput:
@@ -370,6 +412,8 @@ def attribute(
achieved by modifying the input texts for decoder-only models. Default: False.
skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when attributing the input
texts. Default: False.
+ clean_special_chars (:obj:`bool`, `optional`): Whether to clean special characters from the input and
+ generated texts. Default: False.
**kwargs: Additional keyword arguments. These can include keyword arguments for the attribution method, for
the generation process or for the attributed function. Generation arguments can be provided explicitly
as a dictionary named ``generation_args``.
@@ -445,32 +489,15 @@
logger.warning(
f"Generation arguments {generation_args} are provided, but will be ignored (constrained decoding)."
)
- logger.debug(f"reference_texts={generated_texts}")
- if not self.is_encoder_decoder:
- assert all(
- generated_texts[idx].startswith(input_texts[idx]) for idx in range(len(input_texts))
- ), "Forced generations with decoder-only models must start with the input texts."
- if has_generated_texts and len(input_texts) > 1:
- logger.warning(
- "Batched constrained decoding is currently not supported for decoder-only models."
- " Using batch size of 1."
- )
- batch_size = 1
- if len(input_texts) > 1 and (attr_pos_start is not None or attr_pos_end is not None):
- logger.warning(
- "Custom attribution positions are currently not supported when batching generations for"
- " decoder-only models. Using batch size of 1."
- )
- batch_size = 1
- elif attribution_method.is_final_step_method and len(input_texts) > 1:
- logger.warning(
- "Batched attribution with encoder-decoder models currently not supported for final-step methods."
- " Using batch size of 1."
- )
- batch_size = 1
- if attribution_method.method_name == "lime":
- logger.warning("Batched attribution currently not supported for LIME. Using batch size of 1.")
- batch_size = 1
+ batch_size = self.validate_attribute_args(
+ input_texts=input_texts,
+ generated_texts=generated_texts,
+ has_generated_texts=has_generated_texts,
+ attribution_method=attribution_method,
+ batch_size=batch_size,
+ attr_pos_start=attr_pos_start,
+ attr_pos_end=attr_pos_end,
+ )
attribution_outputs = attribution_method.prepare_and_attribute(
input_texts,
generated_texts,
Expand All @@ -484,6 +511,7 @@ def attribute(
step_scores=step_scores,
include_eos_baseline=include_eos_baseline,
skip_special_tokens=skip_special_tokens,
clean_special_chars=clean_special_chars,
attributed_fn=attributed_fn,
attribution_args=attribution_args,
attributed_fn_args=attributed_fn_args,
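Extracting `validate_attribute_args` keeps `attribute` readable without changing behavior. One check it enforces, sketched below with placeholder texts: forced generations for decoder-only models must extend their own inputs.

```python
import inseq

model = inseq.load_model("gpt2", "saliency")  # illustrative decoder-only model
out = model.attribute(
    input_texts="The developer",
    # Must start with the input text, otherwise validation raises an AssertionError.
    generated_texts="The developer argued with the designer",
)
```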
40 changes: 30 additions & 10 deletions inseq/models/huggingface_model.py
@@ -26,7 +26,9 @@
MultiLayerMultiUnitScoreTensor,
OneOrMoreIdSequences,
OneOrMoreTokenSequences,
+ OneOrMoreTokenWithIdSequences,
TextInput,
+ TokenWithId,
VocabularyEmbeddingsTensor,
)
from .attribution_model import AttributionModel
@@ -110,6 +112,8 @@ def __init__(
else:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, **tokenizer_kwargs)
self.eos_token_id = getattr(self.model.config, "eos_token_id", None)
+ if isinstance(self.eos_token_id, list):
+ self.eos_token_id = self.eos_token_id[0]
pad_token_id = self.model.config.pad_token_id
if pad_token_id is None:
if self.tokenizer.pad_token_id is None:
Expand All @@ -118,6 +122,8 @@ def __init__(
else:
pad_token_id = self.tokenizer.pad_token_id
self.pad_token = self._convert_ids_to_tokens(pad_token_id, skip_special_tokens=False)
+ if isinstance(self.pad_token, list):
+ self.pad_token = self.pad_token[0]
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.pad_token
if self.model.config.pad_token_id is None:
@@ -206,7 +212,7 @@ def generate(
self,
inputs: TextInput | BatchEncoding,
return_generation_output: bool = False,
- skip_special_tokens: bool = True,
+ skip_special_tokens: bool | None = None,
output_generated_only: bool = False,
**kwargs,
) -> list[str] | tuple[list[str], ModelOutput]:
@@ -229,7 +235,7 @@
isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
):
inputs = self.encode(inputs, add_special_tokens=not skip_special_tokens)
- inputs = inputs.to(self.device)
+ inputs: BatchEncoding = inputs.to(self.device)
generation_out = self.model.generate(
inputs=inputs.input_ids,
return_dict_in_generate=True,
@@ -238,6 +244,10 @@
sequences = generation_out.sequences
if output_generated_only and not self.is_encoder_decoder:
sequences = sequences[:, inputs.input_ids.shape[1] :]

+ # Left-padding in multi-sentence sequences is skipped by default.
+ if skip_special_tokens is None:
+ skip_special_tokens = inputs.num_sequences != 1 or self.is_encoder_decoder
texts = self.decode(ids=sequences, skip_special_tokens=skip_special_tokens)
if return_generation_output:
return texts, generation_out
@@ -278,6 +288,15 @@ def encode(
return_tensors="pt",
).to(self.device)
baseline_ids = None

+ # Fix: If two BOS tokens are present (e.g. when using chat templates), the second one is removed.
+ if (
+ batch["input_ids"].shape[0] == 1
+ and len(batch["input_ids"][0]) >= 2
+ and batch["input_ids"][0][0] == batch["input_ids"][0][1] == self.bos_token_id
+ ):
+ batch["input_ids"] = batch["input_ids"][:, 1:]
+ batch["attention_mask"] = batch["attention_mask"][:, 1:]
if return_baseline:
if include_eos_baseline:
baseline_ids = torch.ones_like(batch["input_ids"]).long() * self.tokenizer.unk_token_id
@@ -377,7 +396,7 @@ def convert_string_to_tokens(

def clean_tokens(
self,
- tokens: OneOrMoreTokenSequences,
+ tokens: OneOrMoreTokenSequences | OneOrMoreTokenWithIdSequences,
skip_special_tokens: bool = False,
as_targets: bool = False,
) -> OneOrMoreTokenSequences:
@@ -396,16 +415,17 @@ def clean_tokens(
"""
if isinstance(tokens, list) and len(tokens) == 0:
return []
- elif isinstance(tokens[0], bytes | str):
+ elif isinstance(tokens[0], bytes | str | TokenWithId):
clean_tokens = []
for tok in tokens:
- clean_tok = self.convert_tokens_to_string(
- [tok], skip_special_tokens=skip_special_tokens, as_targets=as_targets
+ str_tok = tok.token if isinstance(tok, TokenWithId) else tok
+ clean_str_tok = self.convert_tokens_to_string(
+ [str_tok], skip_special_tokens=skip_special_tokens, as_targets=as_targets
)
- if clean_tok:
- clean_tokens.append(clean_tok)
- elif tok:
- clean_tokens.append(" ")
+ if not clean_str_tok and tok:
+ clean_str_tok = tok
+ clean_tok = TokenWithId(clean_str_tok, tok.id) if isinstance(tok, TokenWithId) else clean_str_tok
+ clean_tokens.append(clean_tok)
return clean_tokens
return [self.clean_tokens(token_seq, skip_special_tokens, as_targets) for token_seq in tokens]

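The `eos_token_id` guard in `__init__` is what fixes LLaMA 3.2: its configs expose a list of EOS ids rather than a single integer. A minimal sketch of the failure mode (the specific ids are an assumption based on LLaMA 3.x instruct configs):

```python
eos_token_id = [128001, 128008, 128009]  # list-valued, as in LLaMA 3.x configs
# Downstream code compares token ids against a single integer, so a list breaks
# it; the fix keeps the first entry as the canonical EOS id.
if isinstance(eos_token_id, list):
    eos_token_id = eos_token_id[0]
assert eos_token_id == 128001
```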
12 changes: 12 additions & 0 deletions inseq/models/model_config.yaml
@@ -38,6 +38,12 @@ GPTNeoForCausalLM:
GPTNeoXForCausalLM:
self_attention_module: "attention"
value_vector: "value"
+ GraniteForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
+ GraniteMoeForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
LlamaForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
@@ -47,12 +53,18 @@ MistralForCausalLM:
MixtralForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
+ NemotronForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
MptForCausalLM:
self_attention_module: "attn"
value_vector: "value_states"
OlmoForCausalLM:
self_attention_module: "self_attn"
value_vector: "value_states"
+ OlmoeForCausalLM:
+ self_attention_module: "self_attn"
+ value_vector: "value_states"
OpenAIGPTLMHeadModel:
self_attention_module: "attn"
value_vector: "value"
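These entries tell attention-based attribution methods which submodule and tensor names to hook for each architecture. A hedged example using one of the newly supported architectures (the checkpoint id is an assumption, not part of the commit):

```python
import inseq

# Any checkpoint whose architecture is GraniteForCausalLM should resolve to the
# new config entry above.
model = inseq.load_model("ibm-granite/granite-3.0-2b-instruct", "attention")
out = model.attribute("The capital of France is")
out.show()
```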
