Skip to content

Commit

Permalink
Add Python 3.12 test and rescale method
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed May 28, 2024
1 parent 96add78 commit 01bd08b
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
if: github.actor != 'dependabot[bot]' && github.actor != 'dependabot-preview[bot]'
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
Expand Down
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@

</div>

Inseq is a Pytorch-based hackable toolkit to democratize the access to common post-hoc **in**terpretability analyses of **seq**uence generation models.
Inseq is a Pytorch-based hackable toolkit to democratize access to common post-hoc **in**terpretability analyses of **seq**uence generation models.

## Installation

Inseq is available on PyPI and can be installed with `pip` for Python >= 3.9, <= 3.11:
Inseq is available on PyPI and can be installed with `pip` for Python >= 3.9, <= 3.12:

```bash
# Install latest stable version
Expand Down Expand Up @@ -308,13 +308,13 @@ If you use Inseq in your research we suggest to include a mention to the specifi
Inseq has been used in various research projects. A list of known publications that use Inseq to conduct interpretability analyses of generative models is shown below.

> [!TIP]
> Last update: April 2024. Please open a pull request to add your publication to the list.
> Last update: May 2024. Please open a pull request to add your publication to the list.
<details>
<summary><b>2023</b></summary>
<ol>
<li> <a href="https://aclanthology.org/2023.acl-demo.40/">Inseq: An Interpretability Toolkit for Sequence Generation Models</a> (Sarti et al., 2023) </li>
<li> <a href="https://arxiv.org/abs/2302.14220">Are Character-level Translations Worth the Wait? Comparing ByT5 and mT5 for Machine Translation</a> (Edman et al., 2023) </li>
<li> <a href="https://doi.org/10.1162/tacl_a_00651">Are Character-level Translations Worth the Wait? Comparing ByT5 and mT5 for Machine Translation</a> (Edman et al., 2023) </li>
<li> <a href="https://aclanthology.org/2023.nlp4convai-1.1/">Response Generation in Longitudinal Dialogues: Which Knowledge Representation Helps?</a> (Mousavi et al., 2023) </li>
<li> <a href="https://openreview.net/forum?id=XTHfNGI3zT">Quantifying the Plausibility of Context Reliance in Neural Machine Translation</a> (Sarti et al., 2023)</li>
<li> <a href="https://aclanthology.org/2023.emnlp-main.243/">A Tale of Pronouns: Interpretability Informs Gender Bias Mitigation for Fairer Instruction-Tuned Machine Translation</a> (Attanasio et al., 2023)</li>
Expand All @@ -330,6 +330,7 @@ Inseq has been used in various research projects. A list of known publications t
<li><a href="https://arxiv.org/abs/2401.12576">LLMCheckup: Conversational Examination of Large Language Models via Interpretability Tools</a> (Wang et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2402.00794">ReAGent: A Model-agnostic Feature Attribution Method for Generative Language Models</a> (Zhao et al., 2024)</li>
<li><a href="https://arxiv.org/abs/2404.02421">Revisiting subword tokenization: A case study on affixal negation in large language models</a> (Truong et al., 2024)</li>
<li><a href="https://hal.science/hal-04581586">Exploring NMT Explainability for Translators Using NMT Visualising Tools</a> (Gonzalez-Saez et al., 2024)</li>
</ol>

</details>
16 changes: 13 additions & 3 deletions inseq/data/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
validate_indices,
)
from ..utils import normalize as normalize_fn
from ..utils import rescale as rescale_fn
from ..utils.typing import IndexSpan, OneOrMoreIndices, TokenWithId
from .aggregation_functions import AggregationFunction
from .data_utils import TensorWrapper
Expand Down Expand Up @@ -307,9 +308,14 @@ def _process_attribution_scores(
attr: "FeatureAttributionSequenceOutput",
aggregate_fn: AggregationFunction,
select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
normalize: Optional[bool] = None,
rescale: Optional[bool] = None,
**kwargs,
):
if normalize and rescale:
raise ValueError("Only one of normalize and rescale can be set to True.")
if normalize is None:
normalize = rescale is None or not rescale
fn_kwargs = extract_signature_args(kwargs, aggregate_fn)
# If select_idx is a single int, no aggregation is performed
do_aggregate = not isinstance(select_idx, int)
Expand All @@ -331,6 +337,8 @@ def _process_attribution_scores(
scores = cls._aggregate_scores(scores, aggregate_fn, dim=-1, **fn_kwargs)
if normalize:
scores = normalize_fn(scores)
if rescale:
scores = rescale_fn(scores)
return scores

@classmethod
Expand Down Expand Up @@ -369,11 +377,12 @@ def aggregate_source_attributions(
aggregate_fn: AggregationFunction,
select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
rescale: bool = False,
**kwargs,
):
if attr.source_attributions is None:
return attr.source_attributions
scores = cls._process_attribution_scores(attr, aggregate_fn, select_idx, normalize, **kwargs)
scores = cls._process_attribution_scores(attr, aggregate_fn, select_idx, normalize, rescale, **kwargs)
return scores[0] if attr.target_attributions is not None else scores

@classmethod
Expand All @@ -383,11 +392,12 @@ def aggregate_target_attributions(
aggregate_fn: AggregationFunction,
select_idx: Optional[OneOrMoreIndices] = None,
normalize: bool = True,
rescale: bool = False,
**kwargs,
):
if attr.target_attributions is None:
return attr.target_attributions
scores = cls._process_attribution_scores(attr, aggregate_fn, select_idx, normalize, **kwargs)
scores = cls._process_attribution_scores(attr, aggregate_fn, select_idx, normalize, rescale, **kwargs)
return scores[1] if attr.source_attributions is not None else scores

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions inseq/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
pad_with_nan,
recursive_get_submodule,
remap_from_filtered,
rescale,
top_p_logits_mask,
validate_indices,
)
Expand All @@ -86,6 +87,7 @@
"remap_from_filtered",
"drop_padding",
"normalize",
"rescale",
"aggregate_contiguous",
"get_front_padding",
"get_sequences_from_batched_steps",
Expand Down
48 changes: 35 additions & 13 deletions inseq/utils/torch_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
from collections.abc import Sequence
from functools import wraps
from inspect import signature
from typing import TYPE_CHECKING, Callable, Literal, Optional, Union

import torch
Expand Down Expand Up @@ -38,23 +40,43 @@ def remap_from_filtered(
return new_source.scatter(0, index, filtered)


def postprocess_attribution_scores(func: Callable) -> Callable:
@wraps(func)
def postprocess_scores_wrapper(
attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]], dim: int = 0, *args, **kwargs
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
multi_input = False
if isinstance(attributions, tuple):
orig_sizes = [a.shape[dim] for a in attributions]
attributions = torch.cat(attributions, dim=dim)
multi_input = True
nan_mask = attributions.isnan()
attributions[nan_mask] = 0.0
if "dim" in signature(func).parameters:
kwargs["dim"] = dim
attributions = func(attributions, *args, **kwargs)
attributions[nan_mask] = float("nan")
if multi_input:
return tuple(attributions.split(orig_sizes, dim=dim))
return attributions

return postprocess_scores_wrapper


@postprocess_attribution_scores
def normalize(
attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]],
norm_dim: int = 0,
dim: int = 0,
norm_ord: int = 1,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
multi_input = False
if isinstance(attributions, tuple):
orig_sizes = [a.shape[norm_dim] for a in attributions]
attributions = torch.cat(attributions, dim=norm_dim)
multi_input = True
nan_mask = attributions.isnan()
attributions[nan_mask] = 0.0
attributions = F.normalize(attributions, p=norm_ord, dim=norm_dim)
attributions[nan_mask] = float("nan")
if multi_input:
return tuple(attributions.split(orig_sizes, dim=norm_dim))
return attributions
return F.normalize(attributions, p=norm_ord, dim=dim)


@postprocess_attribution_scores
def rescale(
attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]],
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
return attributions / attributions.abs().max()


def top_p_logits_mask(logits: torch.Tensor, top_p: float, min_tokens_to_keep: int) -> torch.Tensor:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Visualization",
"Typing :: Typed"
Expand Down

0 comments on commit 01bd08b

Please sign in to comment.