From 904c893153a061c656239d07026d13e1b009cedc Mon Sep 17 00:00:00 2001
From: Gabriele Sarti
Date: Fri, 9 Aug 2024 16:32:37 +0200
Subject: [PATCH] `treescope` support and new visualizations (#283)

* Add treescope requirement
* Drop Python 3.9 support
* Fix tornado version for safety
* show_granular first version working
* Add __treescope_repr__ to FeatureAttributionSequenceOutput
* Add slicing for show_granular, started show_tokens
* Finished show_tokens
* Fix vmin for step scores
* Fix lint
* Add docs
* Update changelog
* Fix viz for attribute_context, improved cmaps
* Fix safety
---
 .github/workflows/build.yml                   |   2 +-
 .readthedocs.yaml                             |   2 +-
 CHANGELOG.md                                  |   8 +-
 Makefile                                      |   2 +-
 README.md                                     |   6 +-
 docs/source/main_classes/main_functions.rst   |   4 +
 examples/inseq_tutorial.ipynb                 | 456 +++++++++-------
 inseq/__init__.py                             |   4 +
 inseq/attr/attribution_decorators.py          |  14 +-
 inseq/attr/feat/attribution_utils.py          |  25 +-
 inseq/attr/feat/feature_attribution.py        |  23 +-
 inseq/attr/feat/internals_attribution.py      |   8 +-
 .../ops/discretized_integrated_gradients.py   |  13 +-
 inseq/attr/feat/ops/lime.py                   |  14 +-
 inseq/attr/feat/ops/monotonic_path_builder.py |  18 +-
 inseq/attr/feat/ops/reagent.py                |   7 +-
 .../importance_score_evaluator.py             |   7 +-
 .../feat/ops/reagent_core/rationalizer.py     |   5 +-
 .../stopping_condition_evaluator.py           |   5 +-
 .../feat/ops/reagent_core/token_sampler.py    |   8 +-
 .../ops/sequential_integrated_gradients.py    |  25 +-
 inseq/attr/feat/ops/value_zeroing.py          |  21 +-
 inseq/attr/step_functions.py                  |  50 +-
 inseq/commands/attribute/attribute.py         |   7 +-
 inseq/commands/attribute/attribute_args.py    |  17 +-
 .../attribute_context_args.py                 |  28 +-
 .../attribute_context_helpers.py              |  46 +-
 .../attribute_context_viz_helpers.py          |  20 +-
 .../attribute_dataset/attribute_dataset.py    |   4 +-
 .../attribute_dataset_args.py                 |  19 +-
 inseq/commands/base.py                        |   9 +-
 inseq/data/__init__.py                        |   4 +-
 inseq/data/aggregation_functions.py           |   3 +-
 inseq/data/aggregator.py                      |  68 +--
 inseq/data/attribution.py                     | 505 +++++++++++++++---
 inseq/data/batch.py                           |  31 +-
 inseq/data/data_utils.py                      |  10 +-
 inseq/data/viz.py                             | 383 ++++++++++++-
 inseq/models/__init__.py                      |   4 +-
 inseq/models/attribution_model.py             | 101 ++--
 inseq/models/decoder_only.py                  |  19 +-
 inseq/models/encoder_decoder.py               |  27 +-
 inseq/models/huggingface_model.py             |  48 +-
 inseq/models/model_config.py                  |   3 +-
 inseq/models/model_decorators.py              |   3 +-
 inseq/utils/alignment_utils.py                |  31 +-
 inseq/utils/argparse.py                       |  25 +-
 inseq/utils/cache.py                          |   3 +-
 inseq/utils/contrast_utils.py                 |  19 +-
 inseq/utils/hooks.py                          |   4 +-
 inseq/utils/misc.py                           |  47 +-
 inseq/utils/serialization.py                  |  29 +-
 inseq/utils/torch_utils.py                    |  36 +-
 inseq/utils/typing.py                         |  27 +-
 inseq/utils/viz_utils.py                      |  97 +++-
 pyproject.toml                                |   8 +-
 requirements-dev.txt                          |   7 +-
 requirements.txt                              |   5 +-
 tests/attr/feat/test_attribution_utils.py     |   6 +-
 tests/attr/feat/test_feature_attribution.py   |   8 +-
 tests/attr/feat/test_step_functions.py        |   4 +-
 tests/models/test_huggingface_model.py        |   6 +-
 62 files changed, 1674 insertions(+), 774 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 73a000bd..46d1813b 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -12,7 +12,7 @@ jobs:
     if: github.actor != 'dependabot[bot]' && github.actor != 'dependabot-preview[bot]'
     strategy:
       matrix:
-        python-version: ["3.9", "3.10", "3.11", "3.12"]
+        python-version: ["3.10", "3.11", "3.12"]

     steps:
       - uses: actions/checkout@v3

diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 0481a6c0..f19ccb34 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -14,7 +14,7 @@ sphinx:
 build:
   os: ubuntu-20.04
   tools:
-    python: "3.9"
+    python: "3.10"

 python:
   install:

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1e91eba..15ea1f9b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,10 @@

 ## 🚀 Features

+- Added [treescope](https://github.com/google-deepmind/treescope) for interactive model and tensor visualization. ([#283](https://github.com/inseq-team/inseq/pull/283))
+
+- New `treescope`-powered methods `FeatureAttributionOutput.show_granular` and `FeatureAttributionSequenceOutput.show_tokens` for interactive visualization of multidimensional attribution tensors and token highlights. ([#283](https://github.com/inseq-team/inseq/pull/283))
+
 - Added new models `DbrxForCausalLM`, `OlmoForCausalLM`, `Phi3ForCausalLM`, `Qwen2MoeForCausalLM`, `Gemma2ForCausalLM` to model config.

 - Add `rescale_attributions` to Inseq CLI commands for `rescale=True` ([#280](https://github.com/inseq-team/inseq/pull/280)).

@@ -84,8 +88,8 @@ out_female = attrib_model.attribute(

 ## 📝 Documentation and Tutorials

-*No changes*
+- Updated tutorial with `treescope` usage examples.

 ## 💥 Breaking Changes

-*No changes*
+- Dropped support for Python 3.9. Please use Python >= 3.10. ([#283](https://github.com/inseq-team/inseq/pull/283))

diff --git a/Makefile b/Makefile
index 96febe14..711aa0e0 100644
--- a/Makefile
+++ b/Makefile
@@ -82,7 +82,7 @@ fix-style:

 .PHONY: check-safety
 check-safety:
-	$(PYTHON) -m safety check --full-report -i 70612 -i 71670
+	$(PYTHON) -m safety check --full-report -i 70612 -i 71670 -i 72089

 .PHONY: lint
 lint: fix-style check-safety

diff --git a/README.md b/README.md
index 7757aa7d..4c2d73a0 100644
--- a/README.md
+++ b/README.md
@@ -13,12 +13,12 @@
 [![Downloads](https://static.pepy.tech/badge/inseq)](https://pepy.tech/project/inseq)
 [![License](https://img.shields.io/github/license/inseq-team/inseq)](https://github.com/inseq-team/inseq/blob/main/LICENSE)
 [![Demo Paper](https://img.shields.io/badge/ACL%20Anthology%20-%20?logo=data%3Aimage%2Fx-icon%3Bbase64%2C[base64-encoded favicon data omitted]&labelColor=white&color=red&link=https%3A%2F%2Faclanthology.org%2F2023.acl-demo.40%2F
-)](http://arxiv.org/abs/2302.13942)
+)](https://aclanthology.org/2023.acl-demo.40)
-  [![Follow Inseq on Twitter]( https://img.shields.io/badge/Twitter-1DA1F2?style=for-the-badge&logo=twitter&logoColor=white)](https://twitter.com/InseqLib)
+  [![Follow Inseq on Twitter](https://img.shields.io/badge/Twitter-1DA1F2?style=for-the-badge&logo=twitter&logoColor=white)](https://twitter.com/InseqLib)
   [![Join the Inseq Discord server](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/V5VgwwFPbu)
   [![Read the Docs](https://img.shields.io/badge/-Docs-blue?style=for-the-badge&logo=Read-the-Docs&logoColor=white&link=https://inseq.org)](https://inseq.org)
   [![Tutorial](https://img.shields.io/badge/-Tutorial-orange?style=for-the-badge&logo=Jupyter&logoColor=white&link=https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb)](https://github.com/inseq-team/inseq/blob/main/examples/inseq_tutorial.ipynb)

@@ -30,7 +30,7 @@ Inseq is a Pytorch-based hackable toolkit to democratize access to common post-h

 ## Installation

-Inseq is available on PyPI and can be installed with `pip` for Python >= 3.9, <= 3.12:
+Inseq is available on PyPI and can be installed with `pip` for Python >= 3.10, <= 3.12:

 ```bash
 # Install latest stable version

diff --git a/docs/source/main_classes/main_functions.rst b/docs/source/main_classes/main_functions.rst
index ba8f2003..43e80c3f 100644
--- a/docs/source/main_classes/main_functions.rst
+++ b/docs/source/main_classes/main_functions.rst
@@ -38,4 +38,8 @@ functionalities required for its usage.

 .. autofunction:: show_attributions

+.. autofunction:: show_granular_attributions
+
+.. autofunction:: show_token_attributions
+
 .. autofunction:: merge_attributions

diff --git a/examples/inseq_tutorial.ipynb b/examples/inseq_tutorial.ipynb
index cf38b86b..0062d132 100644
--- a/examples/inseq_tutorial.ipynb
+++ b/examples/inseq_tutorial.ipynb
@@ -91,20 +91,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "/home/gsarti/.cache/pypoetry/virtualenvs/inseq-PzwjmCYf-py3.9/lib/python3.9/site-packages/transformers/models/marian/tokenization_marian.py:194: UserWarning: Recommended: pip install sacremoses.\n",
-      "  warnings.warn(\"Recommended: pip install sacremoses.\")\n",
-      "/home/gsarti/.cache/pypoetry/virtualenvs/inseq-PzwjmCYf-py3.9/lib/python3.9/site-packages/transformers/generation/utils.py:1255: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
-      "  warnings.warn(\n",
-      "/home/gsarti/.cache/pypoetry/virtualenvs/inseq-PzwjmCYf-py3.9/lib/python3.9/site-packages/transformers/generation/utils.py:1349: UserWarning: Using `max_length`'s default (512) to control the generation length. This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we recommend using `max_new_tokens` to control the maximum length of the generation.\n",
-      "  warnings.warn(\n",
-      "Attributing with input_x_gradient...: 100%|██████████| 20/20 [00:00<00:00, 24.44it/s]\n"
+      "Attributing with input_x_gradient...: 100%|██████████| 20/20 [00:02<00:00,  6.53it/s]\n"
      ]
     },
     {
      "data": {
       "text/html": [
       "[rendered HTML omitted — Source Saliency Heatmap for the input_x_gradient attribution of “Hello everyone, hope you're enjoying the tutorial!” → “Ciao a tutti, spero che vi stiate godendo il tutorial!”; this PR relabels the axes from “x: Generated tokens, y: Attributed tokens” to “→ : Generated tokens, ↓ : Attributed tokens” and adds token position indices to the table]\n",
       "[rendered HTML omitted — Target Saliency Heatmap, → : Generated tokens, ↓ : Attributed tokens; the updated rendering adds a <pad> row for the shifted target and token position indices]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The FeatureAttributionOutput class\n",
    "\n",
-   "The output returned by the `model.attribute` function contains attribution tensors, plus rich information regarding the attribution settings:"
+   "The output returned by the `model.attribute` function contains attribution tensors, plus rich information regarding the attribution settings. From v0.7 onwards, in notebook environments `inseq` uses [`treescope`](https://github.com/google-deepmind/treescope) as the default visualizer to provide interactive visualizations of objects and tensors:"
   ]
  },
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { + "text/html": [ + "" + ], "text/plain": [ "FeatureAttributionOutput({\n", - " sequence_attributions: list with 1 elements of type GranularFeatureAttributionSequenceOutput:[\n", + " sequence_attributions: list with 1 elements of type GranularFeatureAttributionSequenceOutput: [\n", " GranularFeatureAttributionSequenceOutput({\n", - " source: list with 14 elements of type TokenWithId:[\n", + " source: list with 14 elements of type TokenWithId: [\n", " '▁H', 'ello', '▁everyone', ',', '▁hope', '▁you', ''', 're', '▁enjoying', '▁the', '▁tutor', 'ial', '!', ''\n", " ],\n", - " target: list with 19 elements of type TokenWithId:[\n", - " '▁Ci', 'a', 'o', '▁a', '▁tutti', ',', '▁spero', '▁che', '▁vi', '▁stia', 'te', '▁gode', 'ndo', '▁il', '▁tu', 'tori', 'al', '!', ''\n", + " target: list with 20 elements of type TokenWithId: [\n", + " '', '▁Ci', 'a', 'o', '▁a', '▁tutti', ',', '▁spero', '▁che', '▁vi', '▁stia', 'te', '▁gode', 'ndo', '▁il', '▁tu', 'tori', 'al', '!', ''\n", " ],\n", " source_attributions: torch.float32 tensor of shape [14, 19, 512] on cpu,\n", - " target_attributions: torch.float32 tensor of shape [19, 19, 512] on cpu,\n", + " target_attributions: torch.float32 tensor of shape [20, 19, 512] on cpu,\n", " step_scores: {\n", " probability: torch.float32 tensor of shape [19] on cpu: [\n", " 0.47, 0.91, 0.91, 0.87, 0.91, 0.74, 0.51, 0.49, 0.67, 0.71, 0.65, 0.74, 0.95, 0.80, 0.88, 0.94, 0.96, 0.87, 0.89\n", " ],\n", " },\n", " sequence_scores: {},\n", - " attr_pos_start: 0,\n", - " attr_pos_end: 19,\n", + " attr_pos_start: 1,\n", + " attr_pos_end: 20,\n", " })\n", " ],\n", " step_attributions: None,\n", " info: {\n", - " attribution_method: input_x_gradient,\n", + " attribution_method: \"input_x_gradient\",\n", " attr_pos_start: 1,\n", " attr_pos_end: 20,\n", " output_step_attributions: False,\n", " attribute_target: True,\n", - " step_scores: list with 1 elements of type str:[\n", + " step_scores: list with 1 elements of type str: [\n", " 'probability'\n", " ],\n", - " exec_time: 0.777145,\n", - " model_name: Helsinki-NLP/opus-mt-en-it,\n", - " model_class: MarianMTModel,\n", + " exec_time: 3.076626,\n", + " model_name: \"Helsinki-NLP/opus-mt-en-it\",\n", + " model_class: \"MarianMTModel\",\n", " tokenizer_name: None,\n", - " tokenizer_class: MarianTokenizer,\n", + " tokenizer_class: \"MarianTokenizer\",\n", " include_eos_baseline: False,\n", - " attributed_fn: probability_fn,\n", + " attributed_fn: \"probability_fn\",\n", " attribution_args: {},\n", " attributed_fn_args: {},\n", " step_scores_args: {},\n", - " input_texts: list with 1 elements of type str:[\n", + " input_texts: list with 1 elements of type str: [\n", " 'Hello everyone, hope you're enjoying the tutorial!'\n", " ],\n", - " generated_texts: list with 1 elements of type str:[\n", + " generated_texts: list with 1 elements of type str: [\n", " 'Ciao a tutti, spero che vi stiate godendo il tutorial!'\n", " ],\n", " generation_args: {},\n", @@ -293,10 +320,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ + "# Use compress=True and scores_precision=\"float16\" or \"float8\" to reduce the size of saved outputs\n", "out.save(\"marian_en_it_attribution.json\")\n", "\n", "# Reload the saved output\n", @@ -308,6 +336,65 @@ " f.write(html)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "From v0.7 onwards, two new methods of 
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
+   "From v0.7 onwards, two new methods of the `FeatureAttributionOutput` and `FeatureAttributionSequenceOutput` classes are available:\n",
+   "\n",
+   "- [`show_granular`](https://inseq.org/en/latest/main_classes/data_classes.html#inseq.data.attribution.FeatureAttributionSequenceOutput.show_granular) can be used to visualize multidimensional attribution matrices without aggregation.\n",
+   "\n",
+   "- [`show_tokens`](https://inseq.org/en/latest/main_classes/data_classes.html#inseq.data.attribution.FeatureAttributionSequenceOutput.show_tokens) can be used to visualize aggregated attribution scores (the same used by the `show` method) as highlights over input tokens for every attributed token, facilitating visualization for long sequences."
   ]
  },
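The same visualizations are also exported at module level in this PR (see the `inseq/__init__.py` diff at the end of this patch). A hedged sketch, assuming the functions accept the attribution output and mirror the method arguments demonstrated in the next cells:

```python
import inseq

# Hedged sketch: show_granular_attributions / show_token_attributions are the
# module-level exports added in this PR; the argument layout is assumed to
# mirror the corresponding methods shown below.
inseq.show_granular_attributions(out)
inseq.show_token_attributions(out, step_score_highlight="probability")
```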
" + ], + "text/plain": [ + "Source Saliency HeatmapTarget Saliency Heatmap" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Use the Embedding Dimension slider to visualize specific dimensions of the gradient attribution tensor\n", + "# Equivalent to out.show(do_aggregation=False)\n", + "out.show_granular()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "Generated text with probability highlights:\n", + "▁Ci a o ▁a ▁tutti , ▁spero ▁che ▁vi ▁stia te ▁gode ndo ▁il ▁tu tori al ! " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Adding highlights over generated tokens using probability scores\n", + "out.show_tokens(step_score_highlight=\"probability\")" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -324,23 +411,38 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Attributing with attention...: 100%|██████████| 20/20 [00:00<00:00, 98.32it/s] \n" + "Attributing with attention...: 2it [00:00, 19.00it/s] \n" ] }, { "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], "text/plain": [ "torch.Size([14, 19, 6, 8])" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -363,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -371,21 +473,20 @@ "text/html": [ "
       "[rendered HTML omitted — Source and Target Saliency Heatmaps for attention attribution (scores from the [source_len, target_len, n_layers, n_heads] tensor averaged for display), → : Generated tokens, ↓ : Attributed tokens; the updated rendering adds a <pad> row and token position indices]\n",
\n", @@ -441,15 +541,15 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Aggregators: ['spans', 'pair', 'subwords', 'scores']\n", - "Aggregation functions: ['vnorm', 'absmax', 'prod', 'sum', 'max', 'min', 'mean']\n" + "Aggregators: ['scores', 'slices', 'pair', 'spans', 'subwords']\n", + "Aggregation functions: ['absmax', 'prod', 'min', 'vnorm', 'mean', 'sum', 'max']\n" ] } ], @@ -460,7 +560,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -476,21 +576,20 @@ "text/html": [ "
       "[rendered HTML omitted — Source and Target Saliency Heatmaps showing aggregated attribution scores for the tutorial translation, → : Generated tokens, ↓ : Attributed tokens]\n",
\n", @@ -561,7 +659,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -569,21 +667,20 @@ "text/html": [ "
       "[rendered HTML omitted — word-level (subword-aggregated) Source and Target Saliency Heatmaps with merged tokens such as “▁Hello”, “▁everyone,”, “▁tutorial!”, → : Generated tokens, ↓ : Attributed tokens]\n",
\n", @@ -638,7 +734,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -759,14 +855,16 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n", + "Batched constrained decoding is currently not supported for decoder-only models. Using batch size of 1.\n" ] }, { @@ -780,8 +878,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Attributing with integrated_gradients...: 100%|██████████| 9/9 [00:00<00:00, 9.99it/s]\n", - "Attributing with integrated_gradients...: 100%|██████████| 9/9 [00:00<00:00, 10.60it/s]\n" + "Attributing with integrated_gradients...: 100%|██████████| 9/9 [00:04<00:00, 1.07s/it]\n", + "Attributing with integrated_gradients...: 100%|██████████| 9/9 [00:04<00:00, 1.06s/it]\n" ] }, { @@ -789,21 +887,20 @@ "text/html": [ "
       "[rendered HTML omitted — Target Saliency Heatmap for the GPT-2 continuation “ he was sick.” of “The manager went home because…”, with a per-step probability row]\n",
\n", @@ -823,21 +920,20 @@ "text/html": [ "
       "[rendered HTML omitted — Target Saliency Heatmap for the alternative continuation “ she was sick.”, with a per-step probability row]\n",
\n", @@ -876,7 +972,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -884,21 +980,20 @@ "text/html": [ "
       "[rendered HTML omitted — contrastive Target Saliency Heatmap for the pair “Ġhe → Ġshe”, with a per-step probability-difference row]\n",
\n", @@ -941,7 +1036,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -964,7 +1059,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -1007,24 +1102,21 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/gsarti/.cache/pypoetry/virtualenvs/inseq-PzwjmCYf-py3.9/lib/python3.9/site-packages/transformers/models/marian/tokenization_marian.py:197: UserWarning: Recommended: pip install sacremoses.\n", + "/Users/gsarti/Documents/projects/inseq/.venv/lib/python3.12/site-packages/transformers/models/marian/tokenization_marian.py:197: UserWarning: Recommended: pip install sacremoses.\n", " warnings.warn(\"Recommended: pip install sacremoses.\")\n", - "Provided alignments do not cover all 6 tokens from the original sequence.\n", - "Filling missing position with right-aligned 1:1 position alignments.\n", - "Generated alignments: [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)]\n", - "Attributing with saliency...: 14%|█▍ | 1/7 [00:000th instance:
\n", "\n", - "
\n", - "
\n", - "
\n", + "
\n", + "
\n", + "
\n", " \n", - "
\n", + "
\n", "
\n", " Source Saliency Heatmap\n", "
\n", - " x: Generated tokens, y: Attributed tokens\n", + " → : Generated tokens, ↓ : Attributed tokens\n", "
\n", " \n", "\n", - " \n", - "
▁Ho▁salutato▁la → ▁il▁manager</s>
▁I0.00.00.00.050.00.0
▁said0.00.00.00.0630.00.0
▁hi0.00.00.00.0740.00.0
▁to0.00.00.00.0340.00.0
▁the0.00.00.00.0860.00.0
▁manager0.00.00.00.4290.00.0
</s>0.00.00.00.0920.00.0
contrast_prob_diff0.00.00.00.7070.013-0.043
\n", + "012345▁Ho▁salutato▁la → ▁il▁manager</s>0▁I0.00.00.00.0480.00.01▁said0.00.00.00.0610.00.02▁hi0.00.00.00.0720.00.03▁to0.00.00.00.0330.00.04▁the0.00.00.00.0840.00.05▁manager0.00.00.00.4140.00.06</s>0.00.00.00.0890.00.0contrast_prob_diff0.00.00.00.7070.013-0.043\n", "
\n", "\n", "
\n", @@ -1055,21 +1146,20 @@ "\n", "
0th instance:
\n", "\n", - "
\n", - "
\n", - "
\n", + "
\n", + "
\n", + "
\n", " \n", - "
\n", + "
\n", "
\n", " Target Saliency Heatmap\n", "
\n", - " x: Generated tokens, y: Attributed tokens\n", + " → : Generated tokens, ↓ : Attributed tokens\n", "
\n", " \n", "\n", - " \n", - "
▁Ho▁salutato▁la → ▁il▁manager</s>
▁Ho0.00.00.0250.00.0
▁saluta0.00.1030.00.0
to0.0430.00.0
▁la → ▁il0.00.0
▁manager0.0
</s>
contrast_prob_diff0.00.00.00.7070.013-0.043
\n", + "012345▁Ho▁salutato▁la → ▁il▁manager</s>0<pad>0.00.00.00.0340.00.01▁Ho0.00.00.0250.00.02▁saluta0.00.10.00.03to0.0410.00.04▁la → ▁il0.00.05▁manager0.06</s>contrast_prob_diff0.00.00.00.7070.013-0.043\n", "
\n", "\n", "
\n", @@ -1125,7 +1215,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1254,22 +1344,16 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/home/gsarti/.cache/pypoetry/virtualenvs/inseq-PzwjmCYf-py3.9/lib/python3.9/site-packages/transformers/models/marian/tokenization_marian.py:197: UserWarning: Recommended: pip install sacremoses.\n", - " warnings.warn(\"Recommended: pip install sacremoses.\")\n", - "Provided alignments do not cover all 7 tokens from the original sequence.\n", - "Filling missing position with right-aligned 1:1 position alignments.\n", - "\n", - "Generated alignments: [(1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8), (7, 9)]\n", - "Attributing with saliency...: 12%|█▎ | 1/8 [00:000th instance:
\n", "\n", - "
\n", - "
\n", - "
\n", + "
\n", + "
\n", + "
\n", " \n", - "
\n", + "
\n", "
\n", " Source Saliency Heatmap\n", "
\n", - " x: Generated tokens, y: Attributed tokens\n", + " → : Generated tokens, ↓ : Attributed tokens\n", "
\n", " \n", "\n", - " \n", - "
▁militaires → ▁Les▁de → ▁soldats▁paix → ▁de▁des → ▁la▁Nations → ▁paix▁Unies → ▁ONU</s>
▁UN0.090.070.0690.2580.180.2580.092
▁peacekeepers0.580.6280.6110.4860.5170.4470.589
</s>0.330.3020.3190.2560.3030.2950.319
contrast_prob_diff0.042-0.6070.8650.4670.056-0.6910.015
\n", + "0123456▁militaires → ▁Les▁de → ▁soldats▁paix → ▁de▁des → ▁la▁Nations → ▁paix▁Unies → ▁ONU</s>0▁UN0.090.070.0690.2580.1790.2580.0921▁peacekeepers0.580.6280.6110.4860.5170.4470.5882</s>0.330.3020.3190.2560.3030.2950.319contrast_prob_diff0.042-0.6070.8650.4670.056-0.6910.015\n", "
\n", "\n", "
\n", @@ -1339,7 +1422,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -1359,14 +1442,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Attributing with saliency...: 100%|██████████| 8/8 [00:00<00:00, 17.28it/s]\n" + "Attributing with saliency...: 100%|██████████| 8/8 [00:01<00:00, 5.55it/s]\n" ] }, { @@ -1374,21 +1457,20 @@ "text/html": [ "
       "[rendered HTML omitted — Source Saliency Heatmap with aligned contrastive tokens (“▁Le → ▁Les”, “▁forces → ▁soldats”, …) and a contrast_prob_diff step-score row]\n",
\n", @@ -1429,20 +1511,29 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 20, "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a25e05015d9c4b8a9b438fd6e1b6b970", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/397 [00:000th instance:
\n", "\n", - "
\n", - "
\n", - "
\n", + "
\n", + "
\n", + "
\n", " \n", - "
\n", + "
\n", "
\n", " Source Saliency Heatmap\n", "
\n", - " x: Generated tokens, y: Attributed tokens\n", + " → : Generated tokens, ↓ : Attributed tokens\n", "
\n", " \n", "\n", - " \n", - "
▁Le → ▁Les▁forces → ▁soldats▁de▁de → ▁la▁paix▁Nations → ▁ONU</s>
▁UN0.090.0570.0820.0810.0750.2570.092
▁peacekeepers0.5680.6320.6160.6150.6180.4540.589
</s>0.3420.3110.3020.3040.3070.2890.319
contrast_prob_diff0.0310.1380.1420.1560.954-0.920.015
\n", + "0123456▁Le → ▁Les▁forces → ▁soldats▁de▁de → ▁la▁paix▁Nations → ▁ONU</s>0▁UN0.090.0570.0820.0810.0750.2570.0921▁peacekeepers0.5680.6320.6160.6150.6180.4540.5882</s>0.3420.3110.3020.3040.3070.2890.319contrast_prob_diff0.0310.1380.1420.1560.954-0.920.015\n", "
\n", "\n", "
\n", @@ -2059,7 +2149,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.12.4" }, "orig_nbformat": 4 }, diff --git a/inseq/__init__.py b/inseq/__init__.py index e916561a..2eedbb14 100644 --- a/inseq/__init__.py +++ b/inseq/__init__.py @@ -7,6 +7,8 @@ list_aggregators, merge_attributions, show_attributions, + show_granular_attributions, + show_token_attributions, ) from .models import AttributionModel, list_supported_frameworks, load_model, register_model_config from .utils.id_utils import explain @@ -28,6 +30,8 @@ def get_version() -> str: "load_model", "explain", "show_attributions", + "show_granular_attributions", + "show_token_attributions", "list_feature_attribution_methods", "list_aggregators", "list_aggregation_functions", diff --git a/inseq/attr/attribution_decorators.py b/inseq/attr/attribution_decorators.py index 60c58d12..18fea9fc 100644 --- a/inseq/attr/attribution_decorators.py +++ b/inseq/attr/attribution_decorators.py @@ -14,9 +14,9 @@ """Decorators for attribution methods.""" import logging -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import wraps -from typing import Any, Callable, Optional +from typing import Any from ..data.data_utils import TensorWrapper @@ -55,14 +55,14 @@ def batched(f: Callable[..., Any]) -> Callable[..., Any]: """Decorator that enables batching of the args.""" @wraps(f) - def batched_wrapper(self, *args, batch_size: Optional[int] = None, **kwargs): - def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]: + def batched_wrapper(self, *args, batch_size: int | None = None, **kwargs): + def get_batched(bs: int | None, seq: Sequence[Any]) -> list[list[Any]]: if isinstance(seq, str): seq = [seq] if isinstance(seq, list): return [seq[i : i + bs] for i in range(0, len(seq), bs)] # noqa if isinstance(seq, tuple): - return list(zip(*[get_batched(bs, s) for s in seq])) + return list(zip(*[get_batched(bs, s) for s in seq], strict=False)) elif isinstance(seq, TensorWrapper): return [seq.slice_batch(slice(i, i + bs)) for i in range(0, len(seq), bs)] # noqa else: @@ -75,7 +75,9 @@ def get_batched(bs: Optional[int], seq: Sequence[Any]) -> list[list[Any]]: len_batches = len(batched_args[0]) assert all(len(batch) == len_batches for batch in batched_args) output = [] - zipped_batched_args = zip(*batched_args) if len(batched_args) > 1 else [(x,) for x in batched_args[0]] + zipped_batched_args = ( + zip(*batched_args, strict=False) if len(batched_args) > 1 else [(x,) for x in batched_args[0]] + ) for i, batch in enumerate(zipped_batched_args): logger.debug(f"Batching enabled: processing batch {i + 1} of {len_batches}...") out = f(self, *batch, **kwargs) diff --git a/inseq/attr/feat/attribution_utils.py b/inseq/attr/feat/attribution_utils.py index a9679845..bf1ac94a 100644 --- a/inseq/attr/feat/attribution_utils.py +++ b/inseq/attr/feat/attribution_utils.py @@ -1,6 +1,7 @@ import logging import math -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from collections.abc import Callable +from typing import TYPE_CHECKING, Any from ...utils import extract_signature_args, get_aligned_idx from ...utils.typing import ( @@ -24,8 +25,8 @@ def tok2string( attribution_model: "AttributionModel", token_lists: OneOrMoreTokenSequences, - start: Optional[int] = None, - end: Optional[int] = None, + start: int | None = None, + end: int | None = None, as_targets: bool = True, ) -> TextInput: """Enables bounded tokenization of a 
list of lists of tokens with start and end positions."""

@@ -42,14 +43,14 @@ def rescale_attributions_to_tokens(
 ) -> OneOrMoreAttributionSequences:
     return [
         attr[: len(tokens)] if not all(math.isnan(x) for x in attr) else []
-        for attr, tokens in zip(attributions, tokens)
+        for attr, tokens in zip(attributions, tokens, strict=False)
     ]


 def check_attribute_positions(
     max_length: int,
-    attr_pos_start: Optional[int] = None,
-    attr_pos_end: Optional[int] = None,
+    attr_pos_start: int | None = None,
+    attr_pos_end: int | None = None,
 ) -> tuple[int, int]:
     r"""Checks whether the combination of start/end positions for attribution is valid.
@@ -88,8 +89,8 @@ def join_token_ids(
     tokens: OneOrMoreTokenSequences,
     ids: OneOrMoreIdSequences,
-    contrast_tokens: Optional[OneOrMoreTokenSequences] = None,
-    contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
+    contrast_tokens: OneOrMoreTokenSequences | None = None,
+    contrast_targets_alignments: list[list[tuple[int, int]]] | None = None,
 ) -> list[TokenWithId]:
     """Joins tokens and ids into a list of TokenWithId objects."""
     if contrast_tokens is None:
@@ -99,10 +100,10 @@
         contrast_targets_alignments = [[(idx, idx) for idx, _ in enumerate(seq)] for seq in tokens]
     sequences = []
     for target_tokens_seq, contrast_target_tokens_seq, input_ids_seq, alignments_seq in zip(
-        tokens, contrast_tokens, ids, contrast_targets_alignments
+        tokens, contrast_tokens, ids, contrast_targets_alignments, strict=False
     ):
         curr_seq = []
-        for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq)):
+        for pos_idx, (token, token_idx) in enumerate(zip(target_tokens_seq, input_ids_seq, strict=False)):
             contrast_pos_idx = get_aligned_idx(pos_idx, alignments_seq)
             if contrast_pos_idx != -1 and token != contrast_target_tokens_seq[contrast_pos_idx]:
                 curr_seq.append(TokenWithId(f"{contrast_target_tokens_seq[contrast_pos_idx]} → {token}", -1))
@@ -142,10 +143,10 @@ def extract_args(


 def get_source_target_attributions(
-    attr: Union[StepAttributionTensor, tuple[StepAttributionTensor, StepAttributionTensor]],
+    attr: StepAttributionTensor | tuple[StepAttributionTensor, StepAttributionTensor],
     is_encoder_decoder: bool,
     has_sequence_scores: bool = False,
-) -> tuple[Optional[StepAttributionTensor], Optional[StepAttributionTensor]]:
+) -> tuple[StepAttributionTensor | None, StepAttributionTensor | None]:
     if isinstance(attr, tuple):
         if is_encoder_decoder:
             if has_sequence_scores:
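A recurring change above adds `strict=False` to `zip` calls. The flag exists from Python 3.10: `strict=True` raises `ValueError` on length mismatch, while `strict=False` explicitly pins the older silent-truncation behavior these call sites rely on. A minimal illustration (not from the codebase):

```python
# strict=False keeps pre-3.10 semantics: the longer iterable is truncated
# instead of raising ValueError as strict=True would.
pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))
print(pairs)  # [(1, 'a'), (2, 'b')]
```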
""" import logging +from collections.abc import Callable from datetime import datetime -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional import torch from jaxtyping import Int @@ -123,7 +124,7 @@ def load( cls, method_name: str, attribution_model: Optional["AttributionModel"] = None, - model_name_or_path: Optional[ModelIdentifier] = None, + model_name_or_path: ModelIdentifier | None = None, **kwargs, ) -> "FeatureAttribution": r"""Load the selected method and hook it to an existing or available @@ -168,8 +169,8 @@ def prepare_and_attribute( self, sources: FeatureAttributionInput, targets: FeatureAttributionInput, - attr_pos_start: Optional[int] = None, - attr_pos_end: Optional[int] = None, + attr_pos_start: int | None = None, + attr_pos_end: int | None = None, show_progress: bool = True, pretty_progress: bool = True, output_step_attributions: bool = False, @@ -177,7 +178,7 @@ def prepare_and_attribute( step_scores: list[str] = [], include_eos_baseline: bool = False, skip_special_tokens: bool = False, - attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None, + attributed_fn: str | Callable[..., SingleScorePerStepTensor] | None = None, attribution_args: dict[str, Any] = {}, attributed_fn_args: dict[str, Any] = {}, step_scores_args: dict[str, Any] = {}, @@ -317,7 +318,7 @@ def format_contrastive_targets( attr_pos_start: int, attr_pos_end: int, skip_special_tokens: bool = False, - ) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]: + ) -> tuple[DecoderOnlyBatch | None, list[list[tuple[int, int]]] | None, dict[str, Any], dict[str, Any]]: contrast_batch, contrast_targets_alignments = None, None contrast_targets = attributed_fn_args.get("contrast_targets", None) if contrast_targets is None: @@ -357,10 +358,10 @@ def format_contrastive_targets( def attribute( self, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, attributed_fn: Callable[..., SingleScorePerStepTensor], - attr_pos_start: Optional[int] = None, - attr_pos_end: Optional[int] = None, + attr_pos_start: int | None = None, + attr_pos_end: int | None = None, show_progress: bool = True, pretty_progress: bool = True, output_step_attributions: bool = False, @@ -545,10 +546,10 @@ def attribute( def filtered_attribute_step( self, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, target_ids: Int[torch.Tensor, "batch_size 1"], attributed_fn: Callable[..., SingleScorePerStepTensor], - target_attention_mask: Optional[Int[torch.Tensor, "batch_size 1"]] = None, + target_attention_mask: Int[torch.Tensor, "batch_size 1"] | None = None, attribute_target: bool = False, step_scores: list[str] = [], attribution_args: dict[str, Any] = {}, diff --git a/inseq/attr/feat/internals_attribution.py b/inseq/attr/feat/internals_attribution.py index f3f0479d..3fa1667b 100644 --- a/inseq/attr/feat/internals_attribution.py +++ b/inseq/attr/feat/internals_attribution.py @@ -14,7 +14,7 @@ """Attention-based feature attribution methods.""" import logging -from typing import Any, Optional +from typing import Any from captum._utils.typing import TensorOrTupleOfTensorsGeneric @@ -46,9 +46,9 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, additional_forward_args: TensorOrTupleOfTensorsGeneric, - encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, - decoder_self_attentions: 
Optional[MultiLayerMultiUnitScoreTensor] = None, - cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + encoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None, + decoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None, + cross_attentions: MultiLayerMultiUnitScoreTensor | None = None, ) -> MultiDimensionalFeatureAttributionStepOutput: """Extracts the attention weights from the model. diff --git a/inseq/attr/feat/ops/discretized_integrated_gradients.py b/inseq/attr/feat/ops/discretized_integrated_gradients.py index c3ad4908..30f98fb2 100644 --- a/inseq/attr/feat/ops/discretized_integrated_gradients.py +++ b/inseq/attr/feat/ops/discretized_integrated_gradients.py @@ -16,8 +16,9 @@ # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable, Union +from typing import Any import torch from captum._utils.common import ( @@ -94,9 +95,9 @@ def attribute( # type: ignore additional_forward_args: Any = None, n_steps: int = 50, method: str = "greedy", - internal_batch_size: Union[None, int] = None, + internal_batch_size: None | int = None, return_convergence_delta: bool = False, - ) -> Union[TensorOrTupleOfTensorsGeneric, tuple[TensorOrTupleOfTensorsGeneric, Tensor]]: + ) -> TensorOrTupleOfTensorsGeneric | tuple[TensorOrTupleOfTensorsGeneric, Tensor]: n_examples = inputs[0].shape[0] # Keeps track whether original input is a tuple or not before # converting it into a tuple. @@ -112,7 +113,7 @@ def attribute( # type: ignore n_steps=n_steps, scale_strategy=method, ) - for input_tensor, baseline_tensor in zip(inputs, baselines) + for input_tensor, baseline_tensor in zip(inputs, baselines, strict=False) ) if internal_batch_size is not None: attributions = _batch_attribution( @@ -181,7 +182,7 @@ def _attribute( # total_grads has the same dimensionality as the original inputs total_grads = tuple( _reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]) - for (scaled_grad, grad) in zip(scaled_grads, grads) + for (scaled_grad, grad) in zip(scaled_grads, grads, strict=False) ) # computes attribution for each tensor in input_tuple # attributions has the same dimensionality as the original inputs @@ -191,5 +192,5 @@ def _attribute( inputs, baselines = self.get_inputs_baselines(scaled_features_tpl, n_steps) return tuple( total_grad * (input - baseline) - for (total_grad, input, baseline) in zip(total_grads, inputs, baselines) + for (total_grad, input, baseline) in zip(total_grads, inputs, baselines, strict=False) ) diff --git a/inseq/attr/feat/ops/lime.py b/inseq/attr/feat/ops/lime.py index 19cc869c..1a6326cb 100644 --- a/inseq/attr/feat/ops/lime.py +++ b/inseq/attr/feat/ops/lime.py @@ -1,8 +1,9 @@ import inspect import logging import math +from collections.abc import Callable from functools import partial -from typing import Any, Callable, Optional, cast +from typing import Any, cast import torch from captum._utils.common import _expand_additional_forward_args, _expand_target @@ -25,8 +26,8 @@ def __init__( similarity_func: Callable = None, perturb_func: Callable = None, perturb_interpretable_space: bool = False, - from_interp_rep_transform: Optional[Callable] = None, - to_interp_rep_transform: Optional[Callable] = None, + from_interp_rep_transform: Callable | None = None, + to_interp_rep_transform: Callable | None = None, mask_prob: float = 0.3, ) -> 
None: if interpretable_model is None: @@ -271,7 +272,12 @@ def detach_to_list(t): # Merge the binary mask with the special_token_ids mask mask = ( - torch.tensor([m + s if s == 0 else s for m, s in zip(mask_multinomial_binary, mask_special_token_ids)]) + torch.tensor( + [ + m + s if s == 0 else s + for m, s in zip(mask_multinomial_binary, mask_special_token_ids, strict=False) + ] + ) .to(self.attribution_model.device) .unsqueeze(-1) # 1D -> 2D ) diff --git a/inseq/attr/feat/ops/monotonic_path_builder.py b/inseq/attr/feat/ops/monotonic_path_builder.py index d45fec4b..f5e192d1 100644 --- a/inseq/attr/feat/ops/monotonic_path_builder.py +++ b/inseq/attr/feat/ops/monotonic_path_builder.py @@ -23,7 +23,7 @@ from enum import Enum from itertools import islice from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import torch from jaxtyping import Float, Int @@ -85,7 +85,7 @@ def __init__( @staticmethod @cache_results def compute_embeddings_knn( - vocabulary_embeddings: Optional[VocabularyEmbeddingsTensor], + vocabulary_embeddings: VocabularyEmbeddingsTensor | None, n_neighbors: int = 50, mode: str = "distance", n_jobs: int = -1, @@ -111,7 +111,7 @@ def load( save_cache: bool = True, overwrite_cache: bool = False, cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "path_knn", - vocabulary_embeddings: Optional[VocabularyEmbeddingsTensor] = None, + vocabulary_embeddings: VocabularyEmbeddingsTensor | None = None, special_tokens: list[int] = [], embedding_scaling: int = 1, ) -> "MonotonicPathBuilder": @@ -140,8 +140,8 @@ def scale_inputs( self, input_ids: Int[torch.Tensor, "batch_size seq_len"], baseline_ids: Int[torch.Tensor, "batch_size seq_len"], - n_steps: Optional[int] = None, - scale_strategy: Optional[str] = None, + n_steps: int | None = None, + scale_strategy: str | None = None, ) -> MultiStepEmbeddingsTensor: """Generate paths required by DIG.""" if n_steps is None: @@ -186,8 +186,8 @@ def find_path( self, word_idx: int, baseline_idx: int, - n_steps: Optional[int] = 30, - strategy: Optional[str] = "greedy", + n_steps: int | None = 30, + strategy: str | None = "greedy", ) -> list[int]: """Find a monotonic path from a word to a baseline.""" # if word_idx is a special token copy it and return @@ -260,7 +260,7 @@ def get_word_distance( baseline_idx: int, original_idx: int, n_steps: int, - ) -> Union[float, int]: + ) -> float | int: """Get the distance between the anchor word and the baseline word.""" if strategy == PathBuildingStrategies.GREEDY.value: # calculate the distance of the monotonized vec from the interpolated point @@ -299,7 +299,7 @@ def make_monotonic_vec( anchor: torch.Tensor, baseline: torch.Tensor, input: torch.Tensor, - n_steps: Optional[int] = 30, + n_steps: int | None = 30, ) -> torch.Tensor: """Create a new monotonic vector w.r.t. 
input and baseline from an existing anchor.""" non_monotonic_dims = ~cls.get_monotonic_dims(anchor, baseline, input) diff --git a/inseq/attr/feat/ops/reagent.py b/inseq/attr/feat/ops/reagent.py index 080a2131..baea0c4e 100644 --- a/inseq/attr/feat/ops/reagent.py +++ b/inseq/attr/feat/ops/reagent.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any import torch from captum._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric @@ -106,10 +106,7 @@ def attribute( # type: ignore inputs: TensorOrTupleOfTensorsGeneric, _target: TargetType = None, additional_forward_args: Any = None, - ) -> Union[ - TensorOrTupleOfTensorsGeneric, - tuple[TensorOrTupleOfTensorsGeneric, Tensor], - ]: + ) -> TensorOrTupleOfTensorsGeneric | tuple[TensorOrTupleOfTensorsGeneric, Tensor]: """Implement attribute""" # encoder-decoder if self.forward_func.is_encoder_decoder: diff --git a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py index adc79b7b..de1c6b18 100644 --- a/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/importance_score_evaluator.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod -from typing import Optional import torch from jaxtyping import Float @@ -33,7 +32,7 @@ def __call__( self, input_ids: IdsTensor, target_id: TargetIdsTensor, - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> MultipleScoresPerStepTensor: """Evaluate importance score of input sequence @@ -84,7 +83,7 @@ def update_importance_score( input_ids: IdsTensor, target_id: TargetIdsTensor, prob_original_target: Float[torch.Tensor, "batch_size 1"], - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> MultipleScoresPerStepTensor: """Update importance score by one step @@ -138,7 +137,7 @@ def __call__( self, input_ids: IdsTensor, target_id: TargetIdsTensor, - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> MultipleScoresPerStepTensor: """Evaluate importance score of input sequence diff --git a/inseq/attr/feat/ops/reagent_core/rationalizer.py b/inseq/attr/feat/ops/reagent_core/rationalizer.py index ab7c2be8..878f26c5 100644 --- a/inseq/attr/feat/ops/reagent_core/rationalizer.py +++ b/inseq/attr/feat/ops/reagent_core/rationalizer.py @@ -1,6 +1,5 @@ import math from abc import ABC, abstractmethod -from typing import Optional import torch from jaxtyping import Int64 @@ -21,7 +20,7 @@ def __call__( self, input_ids: IdsTensor, target_id: TargetIdsTensor, - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size other_dims"]: """Compute rational of a sequence on a target @@ -78,7 +77,7 @@ def __call__( self, input_ids: IdsTensor, target_id: TargetIdsTensor, - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> Int64[torch.Tensor, "batch_size other_dims"]: """Compute rational of a sequence on a target diff --git a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py index fd3bb67d..d45fc022 100644 --- 
a/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py +++ b/inseq/attr/feat/ops/reagent_core/stopping_condition_evaluator.py @@ -1,6 +1,5 @@ import logging from abc import ABC, abstractmethod -from typing import Optional import torch from transformers import AutoModelForCausalLM @@ -19,7 +18,7 @@ def __call__( input_ids: IdsTensor, target_id: TargetIdsTensor, importance_score: MultipleScoresPerStepTensor, - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> TargetIdsTensor: """Evaluate stop condition according to the specified strategy. @@ -75,7 +74,7 @@ def __call__( input_ids: IdsTensor, target_id: TargetIdsTensor, importance_score: MultipleScoresPerStepTensor, - decoder_input_ids: Optional[IdsTensor] = None, + decoder_input_ids: IdsTensor | None = None, attribute_target: bool = False, ) -> TargetIdsTensor: """Evaluate stop condition diff --git a/inseq/attr/feat/ops/reagent_core/token_sampler.py b/inseq/attr/feat/ops/reagent_core/token_sampler.py index 7ca41bf2..db39abb1 100644 --- a/inseq/attr/feat/ops/reagent_core/token_sampler.py +++ b/inseq/attr/feat/ops/reagent_core/token_sampler.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import torch from transformers import AutoTokenizer, PreTrainedTokenizerBase @@ -35,13 +35,13 @@ class POSTagTokenSampler(TokenSampler): def __init__( self, - tokenizer: Union[str, PreTrainedTokenizerBase], + tokenizer: str | PreTrainedTokenizerBase, identifier: str = "pos_tag_sampler", save_cache: bool = True, overwrite_cache: bool = False, cache_dir: Path = INSEQ_ARTIFACTS_CACHE / "pos_tag_sampler_cache", - device: Optional[str] = None, - tokenizer_kwargs: Optional[dict[str, Any]] = {}, + device: str | None = None, + tokenizer_kwargs: dict[str, Any] | None = {}, ) -> None: if isinstance(tokenizer, PreTrainedTokenizerBase): self.tokenizer = tokenizer diff --git a/inseq/attr/feat/ops/sequential_integrated_gradients.py b/inseq/attr/feat/ops/sequential_integrated_gradients.py index e63a27f2..9cf043e9 100644 --- a/inseq/attr/feat/ops/sequential_integrated_gradients.py +++ b/inseq/attr/feat/ops/sequential_integrated_gradients.py @@ -17,7 +17,8 @@ # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. import typing -from typing import Any, Callable, Union +from collections.abc import Callable +from typing import Any import torch from captum._utils.common import ( @@ -121,7 +122,7 @@ def attribute( additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", - internal_batch_size: Union[None, int] = None, + internal_batch_size: None | int = None, return_convergence_delta: Literal[False] = False, ) -> TensorOrTupleOfTensorsGeneric: ... 
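[Editor's note] The hunks above and below systematically convert `Optional[...]`/`Union[...]` annotations to the PEP 604 `|` syntax, which is only valid from Python 3.10 onward and is why this patch drops the 3.9 builds. A minimal sketch of the equivalence, using only the standard library (not part of the patch):

```python
# Illustrative sketch: `X | None` and `A | B` are PEP 604 unions,
# available from Python 3.10 -- hence the dropped 3.9 support.
from collections.abc import Callable


def scale(value: int | float, factor: float | None = None) -> int | float:
    # Same meaning as the old Union[int, float] / Optional[float] spelling.
    return value if factor is None else value * factor


# `Callable` is now imported from collections.abc instead of typing,
# matching the import rewrites throughout this diff.
double: Callable[[int | float], int | float] = lambda x: scale(x, 2.0)
print(double(21))  # 42.0
```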
@@ -135,7 +136,7 @@ def attribute( additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", - internal_batch_size: Union[None, int] = None, + internal_batch_size: None | int = None, *, return_convergence_delta: Literal[True], ) -> tuple[TensorOrTupleOfTensorsGeneric, Tensor]: @@ -149,12 +150,9 @@ def attribute( # type: ignore additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", - internal_batch_size: Union[None, int] = None, + internal_batch_size: None | int = None, return_convergence_delta: bool = False, - ) -> Union[ - TensorOrTupleOfTensorsGeneric, - tuple[TensorOrTupleOfTensorsGeneric, Tensor], - ]: + ) -> TensorOrTupleOfTensorsGeneric | tuple[TensorOrTupleOfTensorsGeneric, Tensor]: r""" This method attributes the output of the model with given target index (in case it is provided, otherwise it assumes that output is a @@ -368,13 +366,13 @@ def attribute( # type: ignore def _attribute( self, inputs: tuple[Tensor, ...], - baselines: tuple[Union[Tensor, int, float], ...], + baselines: tuple[Tensor | int | float, ...], target: TargetType = None, additional_forward_args: Any = None, n_steps: int = 50, method: str = "gausslegendre", idx: int = None, - step_sizes_and_alphas: Union[None, tuple[list[float], list[float]]] = None, + step_sizes_and_alphas: None | tuple[list[float], list[float]] = None, ) -> tuple[Tensor, ...]: if step_sizes_and_alphas is None: # retrieve step size and scaling factor for specified @@ -412,7 +410,7 @@ def _attribute( ], dim=1, ) - for input, baseline in zip(inputs, baselines_) + for input, baseline in zip(inputs, baselines_, strict=False) ) additional_forward_args = _format_additional_forward_args(additional_forward_args) @@ -447,7 +445,7 @@ def _attribute( # total_grads has the same dimensionality as inputs total_grads = tuple( _reshape_and_sum(scaled_grad, n_steps, grad.shape[0] // n_steps, grad.shape[1:]) - for (scaled_grad, grad) in zip(scaled_grads, grads) + for (scaled_grad, grad) in zip(scaled_grads, grads, strict=False) ) # computes attribution for each tensor in input tuple @@ -456,7 +454,8 @@ def _attribute( attributions = total_grads else: attributions = tuple( - total_grad * (input - baseline) for total_grad, input, baseline in zip(total_grads, inputs, baselines) + total_grad * (input - baseline) + for total_grad, input, baseline in zip(total_grads, inputs, baselines, strict=False) ) return attributions diff --git a/inseq/attr/feat/ops/value_zeroing.py b/inseq/attr/feat/ops/value_zeroing.py index ed95eb12..20f0764a 100644 --- a/inseq/attr/feat/ops/value_zeroing.py +++ b/inseq/attr/feat/ops/value_zeroing.py @@ -13,9 +13,10 @@ # limitations under the License. 
import logging +from collections.abc import Callable from enum import Enum from types import FrameType -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric @@ -101,8 +102,8 @@ def get_value_zeroing_hook(varname: str = "value") -> Callable[..., None]: def value_zeroing_forward_mid_hook( frame: FrameType, - zeroed_token_index: Optional[int] = None, - zeroed_units_indices: Optional[OneOrMoreIndices] = None, + zeroed_token_index: int | None = None, + zeroed_units_indices: OneOrMoreIndices | None = None, batch_size: int = 1, ) -> None: if varname not in frame.f_locals: @@ -166,10 +167,10 @@ def compute_modules_post_zeroing_similarity( additional_forward_args: TensorOrTupleOfTensorsGeneric, hidden_states: MultiLayerEmbeddingsTensor, attention_module_name: str, - attributed_seq_len: Optional[int] = None, + attributed_seq_len: int | None = None, similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, mode: str = ValueZeroingModule.DECODER.value, - zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, + zeroed_units_indices: OneOrMoreIndicesDict | None = None, min_score_threshold: float = 1e-5, use_causal_mask: bool = False, ) -> MultiLayerScoreTensor: @@ -306,11 +307,11 @@ def attribute( inputs: TensorOrTupleOfTensorsGeneric, additional_forward_args: TensorOrTupleOfTensorsGeneric, similarity_metric: str = ValueZeroingSimilarityMetric.COSINE.value, - encoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, - decoder_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, - cross_zeroed_units_indices: Optional[OneOrMoreIndicesDict] = None, - encoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, - decoder_hidden_states: Optional[MultiLayerEmbeddingsTensor] = None, + encoder_zeroed_units_indices: OneOrMoreIndicesDict | None = None, + decoder_zeroed_units_indices: OneOrMoreIndicesDict | None = None, + cross_zeroed_units_indices: OneOrMoreIndicesDict | None = None, + encoder_hidden_states: MultiLayerEmbeddingsTensor | None = None, + decoder_hidden_states: MultiLayerEmbeddingsTensor | None = None, output_decoder_self_scores: bool = True, output_encoder_self_scores: bool = True, ) -> TensorOrTupleOfTensorsGeneric: diff --git a/inseq/attr/step_functions.py b/inseq/attr/step_functions.py index df82fef7..5b556b01 100644 --- a/inseq/attr/step_functions.py +++ b/inseq/attr/step_functions.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass from inspect import signature -from typing import TYPE_CHECKING, Any, Optional, Protocol, Union +from typing import TYPE_CHECKING, Any, Protocol import torch import torch.nn.functional as F @@ -70,7 +70,7 @@ class StepFunctionDecoderOnlyArgs(StepFunctionBaseArgs): pass -StepFunctionArgs = Union[StepFunctionEncoderDecoderArgs, StepFunctionDecoderOnlyArgs] +StepFunctionArgs = StepFunctionEncoderDecoderArgs | StepFunctionDecoderOnlyArgs class StepFunction(Protocol): @@ -128,9 +128,9 @@ def perplexity_fn(args: StepFunctionArgs) -> SingleScorePerStepTensor: @contrast_fn_docstring() def contrast_logits_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None 
= None, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, ): @@ -153,9 +153,9 @@ def contrast_logits_fn( @contrast_fn_docstring() def contrast_prob_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, logprob: bool = False, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, @@ -179,9 +179,9 @@ def contrast_prob_fn( @contrast_fn_docstring() def pcxmi_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, ) -> SingleScorePerStepTensor: @@ -205,9 +205,9 @@ def pcxmi_fn( @contrast_fn_docstring() def kl_divergence_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, top_k: int = 0, top_p: float = 1.0, min_tokens_to_keep: int = 1, @@ -266,9 +266,9 @@ def kl_divergence_fn( @contrast_fn_docstring() def contrast_prob_diff_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, logprob: bool = False, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, @@ -296,9 +296,9 @@ def contrast_prob_diff_fn( @contrast_fn_docstring() def contrast_logits_diff_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, ): @@ -320,9 +320,9 @@ def contrast_logits_diff_fn( @contrast_fn_docstring() def in_context_pvi_fn( args: StepFunctionArgs, - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None 
= None, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, ): @@ -440,7 +440,7 @@ def get_step_scores( def get_step_scores_args( - score_identifiers: list[str], kwargs: dict[str, Any], default_args: Optional[dict[str, Any]] = None + score_identifiers: list[str], kwargs: dict[str, Any], default_args: dict[str, Any] | None = None ) -> dict[str, Any]: step_scores_args = {} for step_fn_id in score_identifiers: @@ -468,7 +468,7 @@ def list_step_functions() -> list[str]: def register_step_function( fn: StepFunction, identifier: str, - aggregate_map: Optional[dict[str, str]] = None, + aggregate_map: dict[str, str] | None = None, overwrite: bool = False, ) -> None: """Registers a function to be used to compute step scores and store them in the diff --git a/inseq/commands/attribute/attribute.py b/inseq/commands/attribute/attribute.py index 95c4e5d7..341db914 100644 --- a/inseq/commands/attribute/attribute.py +++ b/inseq/commands/attribute/attribute.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from ... import FeatureAttributionOutput, load_model from ..base import BaseCLICommand @@ -8,13 +7,13 @@ def aggregate_attribution_scores( out: FeatureAttributionOutput, - selectors: Optional[list[int]] = None, - aggregators: Optional[list[str]] = None, + selectors: list[int] | None = None, + aggregators: list[str] | None = None, normalize_attributions: bool = False, rescale_attributions: bool = False, ) -> FeatureAttributionOutput: if selectors is not None and aggregators is not None: - for select_idx, aggregator_fn in zip(selectors, aggregators): + for select_idx, aggregator_fn in zip(selectors, aggregators, strict=False): out = out.aggregate( aggregator=aggregator_fn, normalize=normalize_attributions, diff --git a/inseq/commands/attribute/attribute_args.py b/inseq/commands/attribute/attribute_args.py index dfca76dd..d778cbc3 100644 --- a/inseq/commands/attribute/attribute_args.py +++ b/inseq/commands/attribute/attribute_args.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from ... import ( list_aggregation_functions, @@ -17,7 +16,7 @@ class AttributeBaseArgs: model_name_or_path: str = cli_arg( default=None, aliases=["-m"], help="The name or path of the model on which attribution is performed." ) - attribution_method: Optional[str] = cli_arg( + attribution_method: str | None = cli_arg( default="saliency", aliases=["-a"], help="The attribution method used to perform feature attribution.", @@ -28,7 +27,7 @@ class AttributeBaseArgs: aliases=["--dev"], help="The device used for inference with Pytorch. Multi-GPU is not supported.", ) - attributed_fn: Optional[str] = cli_arg( + attributed_fn: str | None = cli_arg( default=None, aliases=["-fn"], choices=list_step_functions(), @@ -38,7 +37,7 @@ class AttributeBaseArgs: " specified using the ``attribution_kwargs`` argument." ), ) - attribution_selectors: Optional[list[int]] = cli_arg( + attribution_selectors: list[int] | None = cli_arg( default=None, help=( "The indices of the attribution scores to be used for the attribution aggregation. 
If specified, the" @@ -125,19 +124,19 @@ class AttributeExtendedArgs(AttributeBaseArgs): aliases=["--hide"], help="If specified, the attribution visualization are not shown in the output.", ) - save_path: Optional[str] = cli_arg( + save_path: str | None = cli_arg( default=None, aliases=["-o"], help="Path where the attribution output should be saved in JSON format.", ) - viz_path: Optional[str] = cli_arg( + viz_path: str | None = cli_arg( default=None, help="Path where the attribution visualization should be saved in HTML format.", ) - start_pos: Optional[int] = cli_arg( + start_pos: int | None = cli_arg( default=None, aliases=["-s"], help="Start position for the attribution. Default: first token" ) - end_pos: Optional[int] = cli_arg( + end_pos: int | None = cli_arg( default=None, aliases=["-e"], help="End position for the attribution. Default: last token" ) verbose: bool = cli_arg( @@ -152,7 +151,7 @@ class AttributeExtendedArgs(AttributeBaseArgs): @dataclass class AttributeWithInputsArgs(AttributeExtendedArgs): input_texts: list[str] = cli_arg(default=None, aliases=["-i"], help="One or more input texts used for generation.") - generated_texts: Optional[list[str]] = cli_arg( + generated_texts: list[str] | None = cli_arg( default=None, aliases=["-g"], help="If specified, constrains the decoding procedure to the specified outputs." ) diff --git a/inseq/commands/attribute_context/attribute_context_args.py b/inseq/commands/attribute_context/attribute_context_args.py index 295d5cc2..6616d83e 100644 --- a/inseq/commands/attribute_context/attribute_context_args.py +++ b/inseq/commands/attribute_context/attribute_context_args.py @@ -1,7 +1,7 @@ import logging from dataclasses import dataclass from enum import Enum -from typing import Any, Optional +from typing import Any from ... import list_step_functions from ...attr.step_functions import is_contrastive_step_function @@ -30,7 +30,7 @@ class AttributeContextInputArgs: "``input_template``." ), ) - input_context_text: Optional[str] = cli_arg( + input_context_text: str | None = cli_arg( default=None, help=( "Additional input context influencing the generation of ``output_current_text``. If the model is a" @@ -39,7 +39,7 @@ class AttributeContextInputArgs: " It will be formatted as {context} in the ``input_template``." ), ) - input_template: Optional[str] = cli_arg( + input_template: str | None = cli_arg( default=None, help=( "The template used to format model inputs. The template must contain at least the" @@ -49,7 +49,7 @@ class AttributeContextInputArgs: " Defaults to '{context} {current}' if ``input_context_text`` is provided, '{current}' otherwise." ), ) - output_context_text: Optional[str] = cli_arg( + output_context_text: str | None = cli_arg( default=None, help=( "An output contexts for which context sensitivity should be detected. For encoder-decoder models, this" @@ -60,7 +60,7 @@ class AttributeContextInputArgs: " along with the output current text, and user validation might be required to separate the two." ), ) - output_current_text: Optional[str] = cli_arg( + output_current_text: str | None = cli_arg( default=None, help=( "The output text generated by the model when all available contexts are provided. Tokens in " @@ -70,7 +70,7 @@ class AttributeContextInputArgs: " and ``output_template``. It will be formatted as {current} in the ``output_template``." ), ) - output_template: Optional[str] = cli_arg( + output_template: str | None = cli_arg( default=None, help=( "The template used to format model outputs. 
The template must contain at least the" @@ -80,7 +80,7 @@ class AttributeContextInputArgs: " Defaults to '{context} {current}' if ``output_context_text`` is provided, '{current}' otherwise." ), ) - contextless_input_current_text: Optional[str] = cli_arg( + contextless_input_current_text: str | None = cli_arg( default=None, help=( "The input current text or template to use in the contrastive comparison with contextual input. By default" @@ -91,7 +91,7 @@ class AttributeContextInputArgs: " be used as-is for the contrastive comparison, enabling contrastive comparison with different inputs." ), ) - contextless_output_current_text: Optional[str] = cli_arg( + contextless_output_current_text: str | None = cli_arg( default=None, help=( "The output current text or template to use in the contrastive comparison with contextual output. By default" @@ -162,7 +162,7 @@ class AttributeContextMethodArgs(AttributeBaseArgs): " ``context_sensitivity_metric`` score for tokens to be considered context-sensitive." ), ) - context_sensitivity_topk: Optional[int] = cli_arg( + context_sensitivity_topk: int | None = cli_arg( default=None, help=( "If set, after selecting the salient context-sensitive tokens with ``context_sensitivity_std_threshold`` " @@ -179,7 +179,7 @@ class AttributeContextMethodArgs(AttributeBaseArgs): "used in the visualization of context reliance." ), ) - attribution_topk: Optional[int] = cli_arg( + attribution_topk: int | None = cli_arg( default=None, help=( "If set, after selecting the most salient tokens with ``attribution_std_threshold`` " @@ -198,7 +198,7 @@ class AttributeContextOutputArgs: "identification (CTI) and contextual cues imputation (CCI) are shown during the process." ), ) - save_path: Optional[str] = cli_arg( + save_path: str | None = cli_arg( default=None, aliases=["-o"], help="If present, the output of the two-step process will be saved in JSON format at the specified path.", @@ -207,7 +207,7 @@ class AttributeContextOutputArgs: default=True, help="If specified, additional information about the attribution process is added to the saved output.", ) - viz_path: Optional[str] = cli_arg( + viz_path: str | None = cli_arg( default=None, help="If specified, the visualization produced from the output is saved in HTML format at the specified path.", ) @@ -225,11 +225,11 @@ def __repr__(self): @classmethod def _to_dict(cls, val: Any) -> dict[str, Any]: - if val is None or isinstance(val, (str, int, float, bool)): + if val is None or isinstance(val, str | int | float | bool): return val elif isinstance(val, dict): return {k: cls._to_dict(v) for k, v in val.items()} - elif isinstance(val, (list, tuple)): + elif isinstance(val, list | tuple): return [cls._to_dict(v) for v in val] else: return str(val) diff --git a/inseq/commands/attribute_context/attribute_context_helpers.py b/inseq/commands/attribute_context/attribute_context_helpers.py index cb8793e2..eb505e41 100644 --- a/inseq/commands/attribute_context/attribute_context_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_helpers.py @@ -1,7 +1,7 @@ import logging import re from dataclasses import dataclass, field, fields -from typing import Any, Optional +from typing import Any from rich import print as rprint from rich.prompt import Confirm, Prompt @@ -25,8 +25,8 @@ class CCIOutput: cti_score: float contextual_output: str contextless_output: str - input_context_scores: Optional[list[float]] = None - output_context_scores: Optional[list[float]] = None + input_context_scores: list[float] | None = None + 
output_context_scores: list[float] | None = None def __repr__(self): return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})" @@ -39,15 +39,15 @@ def to_dict(self) -> dict[str, Any]: class AttributeContextOutput: """Output of the overall context attribution process.""" - input_context: Optional[str] = None - input_context_tokens: Optional[list[str]] = None - output_context: Optional[str] = None - output_context_tokens: Optional[list[str]] = None - output_current: Optional[str] = None - output_current_tokens: Optional[list[str]] = None - cti_scores: Optional[list[float]] = None + input_context: str | None = None + input_context_tokens: list[str] | None = None + output_context: str | None = None + output_context_tokens: list[str] | None = None + output_current: str | None = None + output_current_tokens: list[str] | None = None + cti_scores: list[float] | None = None cci_scores: list[CCIOutput] = field(default_factory=list) - info: Optional[AttributeContextArgs] = None + info: AttributeContextArgs | None = None def __repr__(self): return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})" @@ -80,7 +80,7 @@ def concat_with_sep(s1: str, s2: str, sep: str) -> bool: return s1 + s2 -def format_template(template: str, current: str, context: Optional[str] = None) -> str: +def format_template(template: str, current: str, context: str | None = None) -> str: kwargs = {"current": current} if context is not None: kwargs["context"] = context @@ -149,7 +149,7 @@ def generate_model_output( return output_gen -def prompt_user_for_context(output: str, context_candidate: Optional[str] = None) -> str: +def prompt_user_for_context(output: str, context_candidate: str | None = None) -> str: """Prompt the user to provide the correct context for the provided output.""" while True: if context_candidate: @@ -194,16 +194,16 @@ def get_output_context_from_aligned_inputs(input_context: str, output_text: str) def prepare_outputs( model: HuggingfaceModel, - input_context_text: Optional[str], + input_context_text: str | None, input_full_text: str, - output_context_text: Optional[str], - output_current_text: Optional[str], + output_context_text: str | None, + output_current_text: str | None, output_template: str, handle_output_context_strategy: str, generation_kwargs: dict[str, Any] = {}, special_tokens_to_keep: list[str] = [], decoder_input_output_separator: str = " ", -) -> tuple[Optional[str], str]: +) -> tuple[str | None, str]: """Handle model outputs and prepare them for attribution. This procedure is valid both for encoder-decoder and decoder-only models. 
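[Editor's note] The `zip(..., strict=False)` edits applied across these files pin down behavior for a keyword that only exists on Python 3.10+. A small self-contained sketch of what the flag controls (not part of the patch; the token strings are made up for illustration):

```python
# Illustrative sketch: zip() gained the `strict` flag in Python 3.10.
# strict=False keeps the old behavior of truncating to the shortest
# iterable; strict=True raises ValueError on a length mismatch.
scores = [0.1, 0.7]
tokens = ["▁Hello", "▁world", "</s>"]

print(list(zip(scores, tokens, strict=False)))
# [(0.1, '▁Hello'), (0.7, '▁world')] -- the extra token is dropped

try:
    list(zip(scores, tokens, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is longer than argument 1
```

Passing `strict=False` explicitly preserves the pre-3.10 truncating behavior while satisfying linters that flag bare `zip()` calls.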
@@ -318,11 +318,11 @@ def get_scores_threshold(scores: list[float], std_weight: float) -> float: def filter_rank_tokens( tokens: list[str], scores: list[float], - std_threshold: Optional[float] = None, - topk: Optional[int] = None, + std_threshold: float | None = None, + topk: int | None = None, ) -> tuple[list[tuple[int, float, str]], float]: indices = list(range(0, len(scores))) - token_score_tuples = sorted(zip(indices, scores, tokens), key=lambda x: abs(x[1]), reverse=True) + token_score_tuples = sorted(zip(indices, scores, tokens, strict=False), key=lambda x: abs(x[1]), reverse=True) threshold = get_scores_threshold(scores, std_threshold) token_score_tuples = [(i, s, t) for i, s, t in token_score_tuples if abs(s) >= threshold] if topk: @@ -336,7 +336,7 @@ def get_contextless_output( output_current_tokens: list[str], cti_idx: int, cti_ranked_tokens: tuple[int, float, str], - contextless_output_next_tokens: Optional[list[str]], + contextless_output_next_tokens: list[str] | None, prompt_user_for_contextless_output_next_tokens: bool, cci_step_idx: int, decoder_input_output_separator: str = " ", @@ -420,7 +420,7 @@ def get_source_target_cci_scores( model_has_lang_tag: bool, decoder_input_output_separator: str, special_tokens_to_keep: list[str] = [], -) -> tuple[Optional[list[float]], Optional[list[float]]]: +) -> tuple[list[float] | None, list[float] | None]: """Extract attribution scores for the input and output contexts.""" input_scores, output_scores = None, None if has_input_context: @@ -456,7 +456,7 @@ def prompt_user_for_contextless_output_next_tokens( cti_idx: int, model: HuggingfaceModel, special_tokens_to_keep: list[str] = [], -) -> Optional[str]: +) -> str | None: """Prompt the user to provide the next tokens of the contextless output. Args: diff --git a/inseq/commands/attribute_context/attribute_context_viz_helpers.py b/inseq/commands/attribute_context/attribute_context_viz_helpers.py index 8a2dff76..d8e74bd6 100644 --- a/inseq/commands/attribute_context/attribute_context_viz_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_viz_helpers.py @@ -1,10 +1,11 @@ from copy import deepcopy -from typing import Literal, Optional, Union +from typing import Literal from rich.console import Console from ... 
import load_model from ...models import HuggingfaceModel +from ...utils.viz_utils import treescope_ignore from .attribute_context_args import AttributeContextArgs from .attribute_context_helpers import ( AttributeContextOutput, @@ -15,7 +16,7 @@ def get_formatted_procedure_details(args: AttributeContextArgs) -> str: - def format_comment(std: Optional[float] = None, topk: Optional[int] = None) -> str: + def format_comment(std: float | None = None, topk: int | None = None) -> str: comment = [] if std: comment.append(f"std λ={std:.2f}") @@ -55,7 +56,7 @@ def format_context_comment( special_tokens_to_keep: list[str], context: str, context_scores: list[float], - other_context_scores: Optional[list[float]] = None, + other_context_scores: list[float] | None = None, is_target: bool = False, context_type: Literal["Input", "Output"] = "Input", ) -> str: @@ -117,16 +118,16 @@ def format_context_comment( return out_string +@treescope_ignore def visualize_attribute_context( output: AttributeContextOutput, - model: Union[HuggingfaceModel, str, None] = None, - cti_threshold: Optional[float] = None, + model: HuggingfaceModel | str | None = None, + cti_threshold: float | None = None, return_html: bool = False, -) -> Optional[str]: +) -> str | None: if output.info is None: raise ValueError("Cannot visualize attribution results without args. Set add_output_info = True.") console = Console(record=True) - viz = get_formatted_procedure_details(output.info) if model is None: model = output.info.model_name_or_path if isinstance(model, str): @@ -140,15 +141,16 @@ def visualize_attribute_context( raise TypeError(f"Unsupported model type {type(model)} for visualization.") if cti_threshold is None and len(output.cti_scores) > 1: cti_threshold = get_scores_threshold(output.cti_scores, output.info.context_sensitivity_std_threshold) + viz = get_formatted_procedure_details(output.info) viz += "\n\n" + get_formatted_attribute_context_results(model, output.info, output, cti_threshold) with console.capture() as _: console.print(viz, soft_wrap=False) + if output.info.show_viz: + console.print(viz, soft_wrap=False) html = console.export_html() if output.info.viz_path: with open(output.info.viz_path, "w", encoding="utf-8") as f: f.write(html) - if output.info.show_viz: - console.print(viz, soft_wrap=False) if return_html: return html return None diff --git a/inseq/commands/attribute_dataset/attribute_dataset.py b/inseq/commands/attribute_dataset/attribute_dataset.py index c81f65a7..eead8ec5 100644 --- a/inseq/commands/attribute_dataset/attribute_dataset.py +++ b/inseq/commands/attribute_dataset/attribute_dataset.py @@ -1,5 +1,3 @@ -from typing import Optional - from ...utils import is_datasets_available from ..attribute import AttributeExtendedArgs from ..attribute.attribute import attribute @@ -10,7 +8,7 @@ from datasets import load_dataset -def load_fields_from_dataset(dataset_args: LoadDatasetArgs) -> tuple[list[str], Optional[list[str]]]: +def load_fields_from_dataset(dataset_args: LoadDatasetArgs) -> tuple[list[str], list[str] | None]: if not is_datasets_available(): raise ImportError("The datasets library needs to be installed to use the attribute-dataset client.") dataset = load_dataset( diff --git a/inseq/commands/attribute_dataset/attribute_dataset_args.py b/inseq/commands/attribute_dataset/attribute_dataset_args.py index 3910d997..e3e480b5 100644 --- a/inseq/commands/attribute_dataset/attribute_dataset_args.py +++ b/inseq/commands/attribute_dataset/attribute_dataset_args.py @@ -1,5 +1,4 @@ from dataclasses import 
dataclass -from typing import Optional from ...utils import cli_arg from ..commands_utils import command_args_docstring @@ -12,29 +11,29 @@ class LoadDatasetArgs: aliases=["-d", "--dataset"], help="The type of dataset to be loaded for attribution.", ) - input_text_field: Optional[str] = cli_arg( + input_text_field: str | None = cli_arg( aliases=["-in", "--input"], help="Name of the field containing the input texts used for attribution." ) - generated_text_field: Optional[str] = cli_arg( + generated_text_field: str | None = cli_arg( default=None, aliases=["-gen", "--generated"], help="Name of the field containing the generated texts used for constrained decoding.", ) - dataset_config: Optional[str] = cli_arg( + dataset_config: str | None = cli_arg( default=None, aliases=["--config"], help="The name of the Huggingface dataset configuration." ) - dataset_dir: Optional[str] = cli_arg( + dataset_dir: str | None = cli_arg( default=None, aliases=["--dir"], help="Path to the directory containing the data files." ) - dataset_files: Optional[list[str]] = cli_arg(default=None, aliases=["--files"], help="Path to the dataset files.") - dataset_split: Optional[str] = cli_arg(default="train", aliases=["--split"], help="Dataset split.") - dataset_revision: Optional[str] = cli_arg( + dataset_files: list[str] | None = cli_arg(default=None, aliases=["--files"], help="Path to the dataset files.") + dataset_split: str | None = cli_arg(default="train", aliases=["--split"], help="Dataset split.") + dataset_revision: str | None = cli_arg( default=None, aliases=["--revision"], help="The Huggingface dataset revision." ) - dataset_auth_token: Optional[str] = cli_arg( + dataset_auth_token: str | None = cli_arg( default=None, aliases=["--auth"], help="The auth token for the Huggingface dataset." 
) - dataset_kwargs: Optional[dict] = cli_arg( + dataset_kwargs: dict | None = cli_arg( default_factory=dict, help="Additional keyword arguments passed to the dataset constructor in JSON format.", ) diff --git a/inseq/commands/base.py b/inseq/commands/base.py index 8f412d39..418d3f73 100644 --- a/inseq/commands/base.py +++ b/inseq/commands/base.py @@ -1,13 +1,13 @@ import dataclasses -from abc import ABC, abstractstaticmethod +from abc import ABC, abstractmethod from argparse import Namespace from collections.abc import Iterable -from typing import Any, NewType, Union +from typing import Any, NewType from ..utils import InseqArgumentParser DataClassType = NewType("DataClassType", Any) -OneOrMoreDataClasses = Union[DataClassType, Iterable[DataClassType]] +OneOrMoreDataClasses = DataClassType | Iterable[DataClassType] class BaseCLICommand(ABC): @@ -44,6 +44,7 @@ def build(cls, args: Namespace): dataclasses_args = dataclasses_args[0] return cls, dataclasses_args - @abstractstaticmethod + @staticmethod + @abstractmethod def run(args: OneOrMoreDataClasses): raise NotImplementedError() diff --git a/inseq/data/__init__.py b/inseq/data/__init__.py index 7beefa23..99e20f55 100644 --- a/inseq/data/__init__.py +++ b/inseq/data/__init__.py @@ -31,7 +31,7 @@ EncoderDecoderBatch, slice_batch_from_position, ) -from .viz import show_attributions +from .viz import show_attributions, show_granular_attributions, show_token_attributions __all__ = [ "Aggregator", @@ -58,6 +58,8 @@ "OneOrMoreTokenSequences", "TextInput", "show_attributions", + "show_granular_attributions", + "show_token_attributions", "list_aggregation_functions", "MultiDimensionalFeatureAttributionStepOutput", "get_batch_from_inputs", diff --git a/inseq/data/aggregation_functions.py b/inseq/data/aggregation_functions.py index ac35dd84..904a05e7 100644 --- a/inseq/data/aggregation_functions.py +++ b/inseq/data/aggregation_functions.py @@ -14,7 +14,6 @@ import logging from abc import abstractmethod -from typing import Union import torch from torch.linalg import vector_norm @@ -37,7 +36,7 @@ def __init__(self): @abstractmethod def __call__( self, - scores: Union[torch.Tensor, tuple[torch.Tensor, ...]], + scores: torch.Tensor | tuple[torch.Tensor, ...], dim: int, **kwargs, ) -> ScoreTensor: diff --git a/inseq/data/aggregator.py b/inseq/data/aggregator.py index 32265b09..95b10309 100644 --- a/inseq/data/aggregator.py +++ b/inseq/data/aggregator.py @@ -1,7 +1,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union +from collections.abc import Callable, Sequence +from typing import TYPE_CHECKING, TypeVar import torch @@ -45,7 +45,7 @@ def _get_fn(name: str) -> Callable: ) return AggregationFunction.available_classes()[name]() - def __init__(self, default: Union[str, Callable], **kwargs): + def __init__(self, default: str | Callable, **kwargs): super().__init__(**kwargs) self.default = self._get_fn(default) if isinstance(default, str) else default @@ -127,8 +127,8 @@ def end_aggregation_hook(cls, tensors: TensorWrapper, **kwargs): def _get_aggregators_from_id( aggregator: str, - aggregate_fn: Optional[str] = None, -) -> tuple[type[Aggregator], Optional[AggregationFunction]]: + aggregate_fn: str | None = None, +) -> tuple[type[Aggregator], AggregationFunction | None]: if aggregator in available_classes(Aggregator): aggregator = Aggregator.available_classes()[aggregator] elif aggregator in available_classes(AggregationFunction): @@ -158,8 +158,8 @@ 
def _get_aggregators_from_id( class AggregatorPipeline: def __init__( self, - aggregators: list[Union[str, type[Aggregator]]], - aggregate_fn: Optional[list[Union[str, Callable]]] = None, + aggregators: list[str | type[Aggregator]], + aggregate_fn: list[str | Callable] | None = None, ): self.aggregators: list[type[Aggregator]] = [] self.aggregate_fn: list[Callable] = [] @@ -186,7 +186,7 @@ def aggregate( if do_pre_aggregation_checks: for aggregator in self.aggregators: aggregator.start_aggregation_hook(tensors, **kwargs) - for aggregator, aggregate_fn in zip(self.aggregators, self.aggregate_fn): + for aggregator, aggregate_fn in zip(self.aggregators, self.aggregate_fn, strict=False): curr_aggregation_kwargs = kwargs.copy() if aggregate_fn is not None: curr_aggregation_kwargs["aggregate_fn"] = aggregate_fn @@ -199,7 +199,7 @@ def aggregate( return tensors -AggregatorInput = Union[AggregatorPipeline, type[Aggregator], str, Sequence[Union[str, type[Aggregator]]], None] +AggregatorInput = AggregatorPipeline | type[Aggregator] | str | Sequence[str | type[Aggregator]] | None def list_aggregators() -> list[str]: @@ -208,12 +208,12 @@ def list_aggregators() -> list[str]: class AggregableMixin(ABC): - _aggregator: Union[AggregatorPipeline, type[Aggregator]] + _aggregator: AggregatorPipeline | type[Aggregator] def aggregate( self: AggregableMixinClass, aggregator: AggregatorInput = None, - aggregate_fn: Union[str, Sequence[str], None] = None, + aggregate_fn: str | Sequence[str] | None = None, do_pre_aggregation_checks: bool = True, do_post_aggregation_checks: bool = True, **kwargs, @@ -230,7 +230,7 @@ def aggregate( if aggregator is None: aggregator = self._aggregator if isinstance(aggregator, str): - if isinstance(aggregate_fn, (list, tuple)): + if isinstance(aggregate_fn, list | tuple): raise ValueError( "If a single aggregator is used, aggregate_fn should also be a string identifier for the " "corresponding aggregation function if defined." 
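[Editor's note] The next hunk replaces `isinstance(x, (list, tuple))` checks with `isinstance(x, list | tuple)`. Since Python 3.10, `isinstance()` and `issubclass()` accept `X | Y` unions directly, so the two spellings are equivalent. A one-screen sketch (not part of the patch):

```python
# Illustrative sketch: runtime union checks, valid from Python 3.10.
value = (1, 2, 3)

assert isinstance(value, (list, tuple))  # pre-3.10 spelling
assert isinstance(value, list | tuple)   # equivalent from 3.10 onward
print("both checks pass")
```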
@@ -238,11 +238,11 @@ def aggregate( aggregator, aggregate_fn = _get_aggregators_from_id(aggregator, aggregate_fn) if aggregate_fn is not None: kwargs["aggregate_fn"] = aggregate_fn - elif isinstance(aggregator, (list, tuple)): - if all(isinstance(a, (str, type)) for a in aggregator): + elif isinstance(aggregator, list | tuple): + if all(isinstance(a, str | type) for a in aggregator): aggregator = AggregatorPipeline(aggregator, aggregate_fn) elif all(isinstance(agg, tuple) for agg in aggregator): - if all(isinstance(idx, (str, type)) for agg in aggregator for idx in agg): + if all(isinstance(idx, str | type) for agg in aggregator for idx in agg): aggregator = AggregatorPipeline([a[0] for a in aggregator], [a[1] for a in aggregator]) else: raise ValueError( @@ -280,7 +280,7 @@ class SequenceAttributionAggregator(Aggregator): @classmethod def _aggregate( - cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: Union[str, Callable, None] = None, **kwargs + cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: str | Callable | None = None, **kwargs ) -> "FeatureAttributionSequenceOutput": if aggregate_fn is None and isinstance(attr._dict_aggregate_fn, dict): aggregate_fn = DictWithDefault(default=cls.default_fn, **attr._dict_aggregate_fn) @@ -307,9 +307,9 @@ def _process_attribution_scores( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Optional[OneOrMoreIndices] = None, - normalize: Optional[bool] = None, - rescale: Optional[bool] = None, + select_idx: OneOrMoreIndices | None = None, + normalize: bool | None = None, + rescale: bool | None = None, **kwargs, ): if normalize and rescale: @@ -375,7 +375,7 @@ def aggregate_source_attributions( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Optional[OneOrMoreIndices] = None, + select_idx: OneOrMoreIndices | None = None, normalize: bool = True, rescale: bool = False, **kwargs, @@ -390,7 +390,7 @@ def aggregate_target_attributions( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Optional[OneOrMoreIndices] = None, + select_idx: OneOrMoreIndices | None = None, normalize: bool = True, rescale: bool = False, **kwargs, @@ -409,7 +409,7 @@ def aggregate_sequence_scores( cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: AggregationFunction, - select_idx: Optional[OneOrMoreIndices] = None, + select_idx: OneOrMoreIndices | None = None, **kwargs, ): if aggregate_fn.takes_sequence_scores: @@ -450,7 +450,7 @@ def is_compatible(attr: "FeatureAttributionSequenceOutput"): def _filter_scores( scores: torch.Tensor, dim: int = -1, - indices: Optional[OneOrMoreIndices] = None, + indices: OneOrMoreIndices | None = None, ) -> torch.Tensor: indexed = scores.index_select(dim, validate_indices(scores, dim, indices).to(scores.device)) if isinstance(indices, int): @@ -459,11 +459,11 @@ def _filter_scores( @staticmethod def _aggregate_scores( - scores: Union[torch.Tensor, tuple[torch.Tensor, ...]], + scores: torch.Tensor | tuple[torch.Tensor, ...], aggregate_fn: AggregationFunction, dim: int = -1, **kwargs, - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + ) -> torch.Tensor | tuple[torch.Tensor, ...]: if isinstance(scores, tuple) and aggregate_fn.takes_single_tensor: return tuple(aggregate_fn(score, dim=dim, **kwargs) for score in scores) return aggregate_fn(scores, dim=dim, **kwargs) @@ -492,8 +492,8 @@ class ContiguousSpanAggregator(SequenceAttributionAggregator): def start_aggregation_hook( cls, attr: 
"FeatureAttributionSequenceOutput", - source_spans: Optional[IndexSpan] = None, - target_spans: Optional[IndexSpan] = None, + source_spans: IndexSpan | None = None, + target_spans: IndexSpan | None = None, **kwargs, ): super().start_aggregation_hook(attr, **kwargs) @@ -508,8 +508,8 @@ def end_aggregation_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs def aggregate( cls, attr: "FeatureAttributionSequenceOutput", - source_spans: Optional[IndexSpan] = None, - target_spans: Optional[IndexSpan] = None, + source_spans: IndexSpan | None = None, + target_spans: IndexSpan | None = None, **kwargs, ): """Spans can be: @@ -530,7 +530,7 @@ def format_spans(spans) -> list[tuple[int, int]]: return [spans] if isinstance(spans[0], int) else spans @classmethod - def validate_spans(cls, span_sequence: list[TokenWithId], spans: Optional[IndexSpan] = None): + def validate_spans(cls, span_sequence: list[TokenWithId], spans: IndexSpan | None = None): if not spans: return allmatch = lambda l, type: all(isinstance(x, type) for x in l) @@ -685,7 +685,7 @@ def aggregate( attr: "FeatureAttributionSequenceOutput", aggregate_source: bool = True, aggregate_target: bool = True, - special_chars: Union[str, tuple[str, ...]] = "▁", + special_chars: str | tuple[str, ...] = "▁", is_suffix_symbol: bool = False, **kwargs, ): @@ -698,7 +698,7 @@ def aggregate( return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs) @staticmethod - def get_spans(tokens: list[TokenWithId], special_chars: Union[str, tuple[str, ...]], is_suffix_symbol: bool): + def get_spans(tokens: list[TokenWithId], special_chars: str | tuple[str, ...], is_suffix_symbol: bool): spans = [] last_prefix_idx = 0 has_special_chars = any(sym in token.token for token in tokens for sym in special_chars) @@ -828,8 +828,8 @@ class SliceAggregator(ContiguousSpanAggregator): def aggregate( cls, attr: "FeatureAttributionSequenceOutput", - source_spans: Optional[IndexSpan] = None, - target_spans: Optional[IndexSpan] = None, + source_spans: IndexSpan | None = None, + target_spans: IndexSpan | None = None, **kwargs, ): """Spans can be: diff --git a/inseq/data/attribution.py b/inseq/data/attribution.py index e42c0300..fbf2b686 100644 --- a/inseq/data/attribution.py +++ b/inseq/data/attribution.py @@ -1,12 +1,14 @@ import base64 import logging +from collections.abc import Callable from copy import deepcopy from dataclasses import dataclass, field from os import PathLike from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any import torch +import treescope as ts from ..utils import ( convert_from_safetensor, @@ -36,16 +38,23 @@ from .aggregator import AggregableMixin, Aggregator, AggregatorPipeline from .batch import Batch, BatchEmbedding, BatchEncoding, DecoderOnlyBatch, EncoderDecoderBatch from .data_utils import TensorWrapper +from .viz import get_saliency_heatmap_treescope, get_tokens_heatmap_treescope if TYPE_CHECKING: from ..models import AttributionModel -FeatureAttributionInput = Union[TextInput, BatchEncoding, Batch] +FeatureAttributionInput = TextInput | BatchEncoding | Batch logger = logging.getLogger(__name__) +DEFAULT_ATTRIBUTION_DIM_NAMES = { + "source_attributions": {0: "Input Tokens", 1: "Generated Tokens"}, + "target_attributions": {0: "Input Tokens", 1: "Generated Tokens"}, +} + + def get_batch_from_inputs( attribution_model: "AttributionModel", inputs: FeatureAttributionInput, @@ -56,7 +65,7 @@ def get_batch_from_inputs( if 
isinstance(inputs, Batch): batch = inputs else: - if isinstance(inputs, (str, list)): + if isinstance(inputs, str | list): encodings: BatchEncoding = attribution_model.encode( inputs, as_targets=as_targets, @@ -149,14 +158,15 @@ class FeatureAttributionSequenceOutput(TensorWrapper, AggregableMixin): source: list[TokenWithId] target: list[TokenWithId] - source_attributions: Optional[SequenceAttributionTensor] = None - target_attributions: Optional[SequenceAttributionTensor] = None - step_scores: Optional[dict[str, SingleScoresPerSequenceTensor]] = None - sequence_scores: Optional[dict[str, MultipleScoresPerSequenceTensor]] = None + source_attributions: SequenceAttributionTensor | None = None + target_attributions: SequenceAttributionTensor | None = None + step_scores: dict[str, SingleScoresPerSequenceTensor] | None = None + sequence_scores: dict[str, MultipleScoresPerSequenceTensor] | None = None attr_pos_start: int = 0 - attr_pos_end: Optional[int] = None - _aggregator: Union[str, list[str], None] = None - _dict_aggregate_fn: Optional[dict[str, str]] = None + attr_pos_end: int | None = None + _aggregator: str | list[str] | None = None + _dict_aggregate_fn: dict[str, str] | None = None + _attribution_dim_names: dict[str, dict[int, str]] | None = None def __post_init__(self): if self._dict_aggregate_fn is None: @@ -164,12 +174,17 @@ def __post_init__(self): default_aggregate_fn = DEFAULT_ATTRIBUTION_AGGREGATE_DICT default_aggregate_fn.update(self._dict_aggregate_fn) self._dict_aggregate_fn = default_aggregate_fn + if self._attribution_dim_names is None: + self._attribution_dim_names = {} + default_dim_names = DEFAULT_ATTRIBUTION_DIM_NAMES + default_dim_names.update(self._attribution_dim_names) + self._attribution_dim_names = default_dim_names if self._aggregator is None: self._aggregator = "scores" if self.attr_pos_end is None or self.attr_pos_end > len(self.target): self.attr_pos_end = len(self.target) - def __getitem__(self, s: Union[slice, int]) -> "FeatureAttributionSequenceOutput": + def __getitem__(self, s: slice | int) -> "FeatureAttributionSequenceOutput": source_spans = None if self.source_attributions is None else (s.start, s.stop) target_spans = None if self.source_attributions is not None else (s.start, s.stop) return self.aggregate("slices", source_spans=source_spans, target_spans=target_spans) @@ -179,6 +194,72 @@ def __sub__(self, other: "FeatureAttributionSequenceOutput") -> "FeatureAttribut raise ValueError(f"Cannot compare {type(other)} with {type(self)}") return self.aggregate("pair", paired_attr=other, do_post_aggregation_checks=False) + def __treescope_repr__( + self, + path: str, + subtree_renderer: Callable[[Any, str | None], ts.rendering_parts.Rendering], + ) -> ts.rendering_parts.Rendering: + def granular_attribution_visualizer( + value: Any, + path: tuple[Any, ...] 
| None, + ): + if isinstance(value, torch.Tensor): + tname = path.split(".")[-1] + column_labels = [t.token for t in self.target[self.attr_pos_start : self.attr_pos_end]] + if tname == "source_attributions": + row_labels = [t.token for t in self.source] + elif tname == "target_attributions": + row_labels = [t.token for t in self.target] + elif tname.startswith("sequence_scores"): + tname = tname[17:].split("_")[0] + if tname.startswith("encoder"): + row_labels = [t.token for t in self.source] + column_labels = [t.token for t in self.source] + elif tname.startswith("decoder"): + row_labels = [t.token for t in self.target] + column_labels = [t.token for t in self.target] + adapter = ts.type_registries.lookup_ndarray_adapter(value) + if value.ndim >= 2: + return ts.IPythonVisualization( + ts.figures.inline( + adapter.get_array_summary(value, fast=False), + get_saliency_heatmap_treescope( + scores=value.numpy(), + column_labels=column_labels, + row_labels=row_labels, + dim_names=self._attribution_dim_names.get(tname, None), + ), + ), + replace=True, + ) + else: + return ts.IPythonVisualization( + ts.figures.inline( + adapter.get_array_summary(value, fast=False) + "\n\n", + ts.figures.figure_from_treescope_rendering_part( + ts.rendering_parts.indented_children( + [ + get_tokens_heatmap_treescope( + tokens=column_labels, + scores=value.numpy(), + max_val=value.max().item(), + ) + ] + ) + ), + ), + replace=True, + ) + + with ts.active_autovisualizer.set_scoped(granular_attribution_visualizer): + return ts.repr_lib.render_object_constructor( + object_type=type(self), + attributes=self.__dict__, + path=path, + subtree_renderer=subtree_renderer, + roundtrippable=True, + ) + def _convert_to_safetensors(self, scores_precision: ScorePrecision = "float32"): """ Converts tensor attributes within the class to the specified precision. 
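
        Example:
            Precision conversion is applied when saving attribution outputs, trading score
            precision for a smaller disk footprint. A minimal sketch, assuming ``save``
            exposes the ``scores_precision`` option that feeds this conversion (``model``
            and the file name are illustrative):

            .. code-block:: python

                out = model.attribute("Hello world!")
                # float16 halves the size of stored scores; float8 compresses further
                out.save("attribution_fp16.json.gz", compress=True, scores_precision="float16")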
@@ -249,8 +330,8 @@ def from_step_attributions( cls, attributions: list["FeatureAttributionStepOutput"], tokenized_target_sentences: list[list[TokenWithId]], - pad_token: Optional[Any] = None, - attr_pos_end: Optional[int] = None, + pad_token: Any | None = None, + attr_pos_end: int | None = None, ) -> list["FeatureAttributionSequenceOutput"]: """Converts a list of :class:`~inseq.data.attribution.FeatureAttributionStepOutput` objects containing multiple examples outputs per step into a list of :class:`~inseq.data.attribution.FeatureAttributionSequenceOutput` with @@ -267,7 +348,6 @@ def from_step_attributions( num_sequences = len(attr.prefix) if not all(len(attr.prefix) == num_sequences for attr in attributions): raise ValueError("All the attributions must include the same number of sequences.") - seq_attributions: list[FeatureAttributionSequenceOutput] = [] sources = [] targets = [] pos_start = [] @@ -288,22 +368,13 @@ def from_step_attributions( # If the model is decoder-only, the source is the input prefix curr_pos_start = min(len(tokenized_target_sentences[seq_idx]), attr_pos_end) - len(targets[seq_idx]) pos_start.append(curr_pos_start) - source = tokenized_target_sentences[seq_idx][:curr_pos_start] if not sources else sources[seq_idx] - curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls( - source=source, - target=tokenized_target_sentences[seq_idx], - attr_pos_start=pos_start[seq_idx], - attr_pos_end=attr_pos_end, - ) - seq_attributions.append(curr_seq_attribution) if attr.source_attributions is not None: source_attributions = get_sequences_from_batched_steps([att.source_attributions for att in attributions]) for seq_id in range(num_sequences): # Remove padding from tensor - filtered_source_attribution = source_attributions[seq_id][ + source_attributions[seq_id] = source_attributions[seq_id][ : len(sources[seq_id]), : len(targets[seq_id]), ... ] - seq_attributions[seq_id].source_attributions = filtered_source_attribution if attr.target_attributions is not None: target_attributions = get_sequences_from_batched_steps( [att.target_attributions for att in attributions], padding_dims=[1] @@ -316,7 +387,6 @@ def from_step_attributions( ] if target_attributions[seq_id].shape[0] != len(tokenized_target_sentences[seq_id]): target_attributions[seq_id] = pad_with_nan(target_attributions[seq_id], dim=0, pad_size=1) - seq_attributions[seq_id].target_attributions = target_attributions[seq_id] if attr.step_scores is not None: step_scores = [{} for _ in range(num_sequences)] for step_score_name in attr.step_scores.keys(): @@ -325,8 +395,6 @@ def from_step_attributions( ) for seq_id in range(num_sequences): step_scores[seq_id][step_score_name] = out_step_scores[seq_id][: len(targets[seq_id])] - for seq_id in range(num_sequences): - seq_attributions[seq_id].step_scores = step_scores[seq_id] if attr.sequence_scores is not None: seq_scores = [{} for _ in range(num_sequences)] for seq_score_name in attr.sequence_scores.keys(): @@ -335,7 +403,7 @@ def from_step_attributions( # that are not source-to-target (default for encoder-decoder) or target-to-target # (default for decoder only). remove_pad_fn = cls.get_remove_pad_fn(attr, seq_score_name) - if seq_score_name.startswith("encoder"): + if seq_score_name.startswith("encoder") or seq_score_name.startswith("decoder"): out_seq_scores = [attr.sequence_scores[seq_score_name][i, ...] 
for i in range(num_sequences)] else: out_seq_scores = get_sequences_from_batched_steps( @@ -343,20 +411,37 @@ def from_step_attributions( ) for seq_id in range(num_sequences): seq_scores[seq_id][seq_score_name] = remove_pad_fn(out_seq_scores, sources, targets, seq_id) - for seq_id in range(num_sequences): - seq_attributions[seq_id].sequence_scores = seq_scores[seq_id] + seq_attributions: list[FeatureAttributionSequenceOutput] = [] + for seq_idx in range(num_sequences): + curr_seq_attribution: FeatureAttributionSequenceOutput = attr.get_sequence_cls( + source=deepcopy( + tokenized_target_sentences[seq_idx][: pos_start[seq_idx]] if not sources else sources[seq_idx] + ), + target=deepcopy(tokenized_target_sentences[seq_idx]), + source_attributions=source_attributions[seq_idx] if attr.source_attributions is not None else None, + target_attributions=target_attributions[seq_idx] if attr.target_attributions is not None else None, + step_scores=step_scores[seq_idx] if attr.step_scores is not None else None, + sequence_scores=seq_scores[seq_idx] if attr.sequence_scores is not None else None, + attr_pos_start=pos_start[seq_idx], + attr_pos_end=attr_pos_end, + ) + seq_attributions.append(curr_seq_attribution) return seq_attributions def show( self, - min_val: Optional[int] = None, - max_val: Optional[int] = None, + min_val: int | None = None, + max_val: int | None = None, + max_show_size: int | None = None, + show_dim: int | str | None = None, + slice_dims: dict[int | str, tuple[int, int]] | None = None, display: bool = True, - return_html: Optional[bool] = False, - aggregator: Union[AggregatorPipeline, type[Aggregator]] = None, + return_html: bool | None = False, + return_figure: bool = False, + aggregator: AggregatorPipeline | type[Aggregator] = None, do_aggregation: bool = True, **kwargs, - ) -> Optional[str]: + ) -> str | None: """Visualize the attributions. Args: @@ -366,11 +451,24 @@ def show( max_val (:obj:`int`, *optional*, defaults to None): Maximum value in the color range of the visualization. If None, the maximum value of the attributions across all visualized examples is used. + max_show_size (:obj:`int`, *optional*, defaults to None): + For granular visualization, this parameter specifies the maximum dimension size for additional dimensions + to be visualized. Default: 20. + show_dim (:obj:`int` or :obj:`str`, *optional*, defaults to None): + For granular visualization, this parameter specifies the dimension that should be visualized along with + the source and target tokens. Can be either the dimension index or the dimension name. Works only if + the dimension size is less than or equal to `max_show_size`. + slice_dims (:obj:`dict[int or str, tuple[int, int]]`, *optional*, defaults to None): + For granular visualization, this parameter specifies the dimensions that should be sliced and visualized + along with the source and target tokens. The dictionary should contain the dimension index or name as the + key and the slice range as the value. display (:obj:`bool`, *optional*, defaults to True): Whether to display the visualization. Can be set to False if the visualization is produced and stored for later use. return_html (:obj:`bool`, *optional*, defaults to False): Whether to return the HTML code of the visualization. + return_figure (:obj:`bool`, *optional*, defaults to False): + For granular visualization, whether to return the Treescope figure object for further manipulation. 
aggregator (:obj:`AggregatorPipeline`, *optional*, defaults to None): Aggregates attributions before visualizing them. If not specified, the default aggregator for the class is used. @@ -381,7 +479,7 @@ def show( Returns: :obj:`str`: The HTML code of the visualization if :obj:`return_html` is set to True, otherwise None. """ - from inseq import show_attributions + from inseq import show_attributions, show_granular_attributions # If no aggregator is specified, the default aggregator for the class is used aggregated = self.aggregate(aggregator, **kwargs) if do_aggregation else self @@ -390,8 +488,132 @@ def show( ): tokens = "".join(tid.token for tid in self.target) logger.warning(f"Found empty attributions, skipping attribution matching generation: {tokens}") + if ( + (aggregated.source_attributions is not None and aggregated.source_attributions.ndim == 2) + or (aggregated.target_attributions is not None and aggregated.target_attributions.ndim == 2) + or (aggregated.source_attributions is None and aggregated.target_attributions is None) + ): + return show_attributions( + attributions=aggregated, min_val=min_val, max_val=max_val, display=display, return_html=return_html + ) else: - return show_attributions(aggregated, min_val, max_val, display, return_html) + return show_granular_attributions( + attributions=aggregated, + max_show_size=max_show_size, + min_val=min_val, + max_val=max_val, + show_dim=show_dim, + display=display, + return_html=return_html, + return_figure=return_figure, + slice_dims=slice_dims, + ) + + def show_granular( + self, + min_val: int | None = None, + max_val: int | None = None, + max_show_size: int | None = None, + show_dim: int | str | None = None, + slice_dims: dict[int | str, tuple[int, int]] | None = None, + display: bool = True, + return_html: bool | None = False, + return_figure: bool = False, + ) -> str | None: + """Visualizes granular attribution heatmaps in HTML format. + + Args: + min_val (:obj:`int`, *optional*, defaults to None): + Lower attribution score threshold for color map. + max_val (:obj:`int`, *optional*, defaults to None): + Upper attribution score threshold for color map. + max_show_size (:obj:`int`, *optional*, defaults to None): + Maximum dimension size for additional dimensions to be visualized. Default: 20. + show_dim (:obj:`int` or :obj:`str`, *optional*, defaults to None): + Dimension to be visualized along with the source and target tokens. Can be either the dimension index or + the dimension name. Works only if the dimension size is less than or equal to `max_show_size`. + slice_dims (:obj:`dict[int or str, tuple[int, int]]`, *optional*, defaults to None): + Dimensions to be sliced and visualized along with the source and target tokens. The dictionary should + contain the dimension index or name as the key and the slice range as the value. + display (:obj:`bool`, *optional*, defaults to True): + Whether to show the output of the visualization function. + return_html (:obj:`bool`, *optional*, defaults to False): + If true, returns the HTML corresponding to the notebook visualization of the attributions in + string format, for saving purposes. + return_figure (:obj:`bool`, *optional*, defaults to False): + If true, returns the Treescope figure object for further manipulation. 
+ + Returns: + `str`: Returns the HTML output if `return_html=True` + """ + from inseq import show_granular_attributions + + return show_granular_attributions( + attributions=self, + max_show_size=max_show_size, + min_val=min_val, + max_val=max_val, + show_dim=show_dim, + slice_dims=slice_dims, + display=display, + return_html=return_html, + return_figure=return_figure, + ) + + def show_tokens( + self, + min_val: int | None = None, + max_val: int | None = None, + display: bool = True, + return_html: bool | None = False, + return_figure: bool = False, + replace_char: dict[str, str] | None = None, + wrap_after: int | str | list[str] | tuple[str] | None = None, + step_score_highlight: str | None = None, + aggregator: AggregatorPipeline | type[Aggregator] = None, + do_aggregation: bool = True, + **kwargs, + ) -> str | None: + """Visualizes token-level attributions in HTML format. + + Args: + min_val (:obj:`int`, *optional*, defaults to None): + Lower attribution score threshold for color map. + max_val (:obj:`int`, *optional*, defaults to None): + Upper attribution score threshold for color map. + display (:obj:`bool`, *optional*, defaults to True): + Whether to show the output of the visualization function. + return_html (:obj:`bool`, *optional*, defaults to False): + If true, returns the HTML corresponding to the notebook visualization of the attributions in string format, + for saving purposes. + return_figure (:obj:`bool`, *optional*, defaults to False): + If true, returns the Treescope figure object for further manipulation. + replace_char (:obj:`dict[str, str]`, *optional*, defaults to None): + Dictionary mapping strings to be replaced to replacement options, used for cleaning special characters. + Default: {}. + wrap_after (:obj:`int` or :obj:`str` or :obj:`list[str]` or :obj:`tuple[str]`, *optional*, defaults to None): + Token indices or tokens after which to wrap lines. E.g. 10 = wrap after every 10 tokens, "hi" = wrap after + the word "hi" occurs, [".", "!", "?"] or ".!?" = wrap after every sentence-ending punctuation. + step_score_highlight (`str`, *optional*, defaults to None): + Name of the step score to use to highlight generated tokens in the visualization. If None, no highlights are + shown. Default: None.
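+
+        Example:
+            A minimal sketch, assuming ``model`` was loaded with :func:`inseq.load_model` and
+            attribution was computed with ``step_scores=["probability"]``:
+
+            .. code-block:: python
+
+                out = model.attribute("Hello world!", step_scores=["probability"])
+                out.sequence_attributions[0].show_tokens(step_score_highlight="probability")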
+ """ + from inseq import show_token_attributions + + aggregated = self.aggregate(aggregator, **kwargs) if do_aggregation else self + return show_token_attributions( + attributions=aggregated, + min_val=min_val, + max_val=max_val, + display=display, + return_html=return_html, + return_figure=return_figure, + replace_char=replace_char, + wrap_after=wrap_after, + step_score_highlight=step_score_highlight, + ) @property def minimum(self) -> float: @@ -432,7 +654,7 @@ def weight_attributions(self, step_fn_id: str): def get_scores_dicts( self, - aggregator: Union[AggregatorPipeline, type[Aggregator]] = None, + aggregator: AggregatorPipeline | type[Aggregator] = None, do_aggregation: bool = True, **kwargs, ) -> dict[str, dict[str, dict[str, float]]]: @@ -467,13 +689,13 @@ def get_scores_dicts( class FeatureAttributionStepOutput(TensorWrapper): """Output of a single step of feature attribution, plus extra information related to what was attributed.""" - source_attributions: Optional[StepAttributionTensor] = None - step_scores: Optional[dict[str, SingleScorePerStepTensor]] = None - target_attributions: Optional[StepAttributionTensor] = None - sequence_scores: Optional[dict[str, MultipleScoresPerStepTensor]] = None - source: Optional[OneOrMoreTokenWithIdSequences] = None - prefix: Optional[OneOrMoreTokenWithIdSequences] = None - target: Optional[OneOrMoreTokenWithIdSequences] = None + source_attributions: StepAttributionTensor | None = None + step_scores: dict[str, SingleScorePerStepTensor] | None = None + target_attributions: StepAttributionTensor | None = None + sequence_scores: dict[str, MultipleScoresPerStepTensor] | None = None + source: OneOrMoreTokenWithIdSequences | None = None + prefix: OneOrMoreTokenWithIdSequences | None = None + target: OneOrMoreTokenWithIdSequences | None = None _sequence_cls: type["FeatureAttributionSequenceOutput"] = FeatureAttributionSequenceOutput def __post_init__(self): @@ -489,7 +711,7 @@ def get_sequence_cls(self, **kwargs): def remap_from_filtered( self, target_attention_mask: TargetIdsTensor, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, is_final_step_method: bool = False, ) -> None: """Remaps the attributions to the original shape of the input sequence.""" @@ -571,7 +793,7 @@ class FeatureAttributionOutput: ] sequence_attributions: list[FeatureAttributionSequenceOutput] - step_attributions: Optional[list[FeatureAttributionStepOutput]] = None + step_attributions: list[FeatureAttributionStepOutput] | None = None info: dict[str, Any] = field(default_factory=dict) def __str__(self): @@ -581,11 +803,11 @@ def __repr__(self): return self.__str__() def __eq__(self, other): - for self_seq, other_seq in zip(self.sequence_attributions, other.sequence_attributions): + for self_seq, other_seq in zip(self.sequence_attributions, other.sequence_attributions, strict=False): if self_seq != other_seq: return False if self.step_attributions is not None and other.step_attributions is not None: - for self_step, other_step in zip(self.step_attributions, other.step_attributions): + for self_step, other_step in zip(self.step_attributions, other.step_attributions, strict=False): if self_step != other_step: return False if self.info != other.info: @@ -665,7 +887,7 @@ def save( ] save_outs.append(self_out) paths.append(path) - for attr_out, path_out in zip(save_outs, paths): + for attr_out, path_out in zip(save_outs, paths, strict=False): with open(path_out, f"w{'b' if compress else ''}") as f: json_advanced_dump( attr_out, @@ 
-703,7 +925,7 @@ def load( def aggregate( self, - aggregator: Union[AggregatorPipeline, type[Aggregator]] = None, + aggregator: AggregatorPipeline | type[Aggregator] = None, **kwargs, ) -> "FeatureAttributionOutput": """Aggregate the sequence attributions using one or more aggregators. @@ -723,21 +945,29 @@ def aggregate( def show( self, - min_val: Optional[int] = None, - max_val: Optional[int] = None, + min_val: int | None = None, + max_val: int | None = None, + max_show_size: int | None = None, + show_dim: int | str | None = None, + slice_dims: dict[int | str, tuple[int, int]] | None = None, display: bool = True, - return_html: Optional[bool] = False, - aggregator: Union[AggregatorPipeline, type[Aggregator]] = None, + return_html: bool | None = False, + return_figure: bool = False, + aggregator: AggregatorPipeline | type[Aggregator] = None, do_aggregation: bool = True, **kwargs, - ) -> Optional[str]: + ) -> str | list | None: """Visualize the sequence attributions. Args: min_val (int, optional): Minimum value for color scale. max_val (int, optional): Maximum value for color scale. + max_show_size (int, optional): Maximum size of the dimension to show. + show_dim (int or str, optional): Dimension to show. + slice_dims (dict[int or str, tuple[int, int]], optional): Dimensions to slice. display (bool, optional): If True, display the attribution visualization. return_html (bool, optional): If True, return the attribution visualization as HTML. + return_figure (bool, optional): If True, return the Treescope figure object for further manipulation. aggregator (:obj:`AggregatorPipeline` or :obj:`Type[Aggregator]`, optional): Aggregator or pipeline to use. If not provided, the default aggregator for every sequence attribution is used. @@ -746,23 +976,163 @@ def show( attributions are already aggregated. Returns: - str: Attribution visualization as HTML if `return_html=True`, None otherwise. + str: Attribution visualization as HTML if `return_html=True` + list: List of Treescope figure objects if `return_figure=True` + None if `return_html=False` and `return_figure=False` + """ out_str = "" + out_figs = [] for attr in self.sequence_attributions: + curr_out = attr.show( + min_val=min_val, + max_val=max_val, + max_show_size=max_show_size, + show_dim=show_dim, + slice_dims=slice_dims, + display=display, + return_html=return_html, + return_figure=return_figure, + aggregator=aggregator, + do_aggregation=do_aggregation, + **kwargs, + ) if return_html: - out_str += attr.show(min_val, max_val, display, return_html, aggregator, do_aggregation, **kwargs) - else: - attr.show(min_val, max_val, display, return_html, aggregator, do_aggregation, **kwargs) + out_str += curr_out + if return_figure: + out_figs.append(curr_out) + if return_html: + return out_str + if return_figure: + return out_figs + + def show_granular( + self, + min_val: int | None = None, + max_val: int | None = None, + max_show_size: int | None = None, + show_dim: int | str | None = None, + slice_dims: dict[int | str, tuple[int, int]] | None = None, + display: bool = True, + return_html: bool = False, + return_figure: bool = False, + ) -> str | None: + """Visualizes granular attribution heatmaps in HTML format. + + Args: + min_val (:obj:`int`, *optional*, defaults to None): + Lower attribution score threshold for color map. + max_val (:obj:`int`, *optional*, defaults to None): + Upper attribution score threshold for color map. 
+ max_show_size (:obj:`int`, *optional*, defaults to None): + Maximum dimension size for additional dimensions to be visualized. Default: 20. + show_dim (:obj:`int` or :obj:`str`, *optional*, defaults to None): + Dimension to be visualized along with the source and target tokens. Can be either the dimension index or + the dimension name. Works only if the dimension size is less than or equal to `max_show_size`. + slice_dims (:obj:`dict[int or str, tuple[int, int]]`, *optional*, defaults to None): + Dimensions to be sliced and visualized along with the source and target tokens. The dictionary should + contain the dimension index or name as the key and the slice range as the value. + display (:obj:`bool`, *optional*, defaults to True): + Whether to show the output of the visualization function. + return_html (:obj:`bool`, *optional*, defaults to False): + If true, returns the HTML corresponding to the notebook visualization of the attributions in + string format, for saving purposes. + return_figure (:obj:`bool`, *optional*, defaults to False): + If true, returns the Treescope figure object for further manipulation. + + Returns: + `str`: Returns the HTML output if `return_html=True` + """ + out_str = "" + out_figs = [] + for attr in self.sequence_attributions: + curr_out = attr.show_granular( + min_val=min_val, + max_val=max_val, + max_show_size=max_show_size, + show_dim=show_dim, + slice_dims=slice_dims, + display=display, + return_html=return_html, + return_figure=return_figure, + ) + if return_html: + out_str += curr_out + if return_figure: + out_figs.append(curr_out) + if return_html: + return out_str + if return_figure: + return out_figs + + def show_tokens( + self, + min_val: int | None = None, + max_val: int | None = None, + display: bool = True, + return_html: bool = False, + return_figure: bool = False, + replace_char: dict[str, str] | None = None, + wrap_after: int | str | list[str] | tuple[str] | None = None, + step_score_highlight: str | None = None, + aggregator: AggregatorPipeline | type[Aggregator] = None, + do_aggregation: bool = True, + **kwargs, + ) -> str | None: + """Visualizes token-level attributions in HTML format. + + Args: + min_val (:obj:`int`, *optional*, defaults to None): + Lower attribution score threshold for color map. + max_val (:obj:`int`, *optional*, defaults to None): + Upper attribution score threshold for color map. + display (:obj:`bool`, *optional*, defaults to True): + Whether to show the output of the visualization function. + return_html (:obj:`bool`, *optional*, defaults to False): + If true, returns the HTML corresponding to the notebook visualization of the attributions in string format, + for saving purposes. + return_figure (:obj:`bool`, *optional*, defaults to False): + If true, returns the Treescope figure object for further manipulation. + replace_char (:obj:`dict[str, str]`, *optional*, defaults to None): + Dictionary mapping strings to be replaced to replacement options, used for cleaning special characters. + Default: {}. + wrap_after (:obj:`int` or :obj:`str` or :obj:`list[str]` or :obj:`tuple[str]`, *optional*, defaults to None): + Token indices or tokens after which to wrap lines. E.g. 10 = wrap after every 10 tokens, "hi" = wrap after + the word "hi" occurs, [".", "!", "?"] or ".!?" = wrap after every sentence-ending punctuation. + step_score_highlight (`str`, *optional*, defaults to None): + Name of the step score to use to highlight generated tokens in the visualization. If None, no highlights are + shown. Default: None.
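+
+        Example:
+            A minimal sketch, assuming ``out`` is a :class:`~inseq.FeatureAttributionOutput`
+            computed with ``step_scores=["probability"]``:
+
+            .. code-block:: python
+
+                # Display one token-level view per attributed sequence
+                out.show_tokens(step_score_highlight="probability", wrap_after=10)
+
+                # Collect the HTML instead of displaying it, e.g. for saving to disk
+                html = out.show_tokens(display=False, return_html=True)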
+ """ + out_str = "" + out_figs = [] + for attr in self.sequence_attributions: + curr_out = attr.show_tokens( + min_val=min_val, + max_val=max_val, + display=display, + return_html=return_html, + return_figure=return_figure, + replace_char=replace_char, + wrap_after=wrap_after, + step_score_highlight=step_score_highlight, + aggregator=aggregator, + do_aggregation=do_aggregation, + **kwargs, + ) + if return_html: + out_str += curr_out + if return_figure: + out_figs.append(curr_out) if return_html: return out_str + if return_figure: + return out_figs def weight_attributions(self, step_score_id: str): for i, attr in enumerate(self.sequence_attributions): self.sequence_attributions[i] = attr.weight_attributions(step_score_id) def get_scores_dicts( - self, aggregator: Union[AggregatorPipeline, type[Aggregator]] = None, do_aggregation: bool = True, **kwargs + self, aggregator: AggregatorPipeline | type[Aggregator] = None, do_aggregation: bool = True, **kwargs ) -> list[dict[str, dict[str, dict[str, float]]]]: """Get all computed scores (attributions and step scores) for all sequences as a list of dictionaries. @@ -801,6 +1171,10 @@ def __post_init__(self): self._dict_aggregate_fn["target_attributions"]["scores"] = "vnorm" if "deltas" not in self._dict_aggregate_fn["step_scores"]["spans"]: self._dict_aggregate_fn["step_scores"]["spans"]["deltas"] = "absmax" + self._attribution_dim_names = { + "source_attributions": {0: "Input Tokens", 1: "Generated Tokens", 2: "Embedding Dimension"}, + "target_attributions": {0: "Input Tokens", 1: "Generated Tokens", 2: "Embedding Dimension"}, + } @dataclass(eq=False, repr=False) @@ -844,6 +1218,15 @@ class MultiDimensionalFeatureAttributionSequenceOutput(FeatureAttributionSequenc def __post_init__(self): super().__post_init__() self._aggregator = ["mean"] * self._num_dimensions + self._attribution_dim_names = { + "source_attributions": {0: "Input Tokens", 1: "Generated Tokens", 2: "Model Layer"}, + "target_attributions": {0: "Input Tokens", 1: "Generated Tokens", 2: "Model Layer"}, + "encoder": {0: "Input Tokens", 1: "Input Tokens", 2: "Model Layer"}, + "decoder": {0: "Generated Tokens", 1: "Generated Tokens", 2: "Model Layer"}, + } + if self._num_dimensions == 2: + for key in self._attribution_dim_names.keys(): + self._attribution_dim_names[key][3] = "Attention Head" @dataclass(eq=False, repr=False) diff --git a/inseq/data/batch.py b/inseq/data/batch.py index 65c49a9e..4f9413dc 100644 --- a/inseq/data/batch.py +++ b/inseq/data/batch.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional, Union from ..utils import get_aligned_idx from ..utils.typing import EmbeddingsTensor, ExpandedTargetIdsTensor, IdsTensor, OneOrMoreTokenSequences @@ -22,8 +21,8 @@ class BatchEncoding(TensorWrapper): input_ids: IdsTensor attention_mask: IdsTensor - input_tokens: Optional[OneOrMoreTokenSequences] = None - baseline_ids: Optional[IdsTensor] = None + input_tokens: OneOrMoreTokenSequences | None = None + baseline_ids: IdsTensor | None = None def __len__(self) -> int: return len(self.input_tokens) @@ -40,10 +39,10 @@ class BatchEmbedding(TensorWrapper): ``[batch_size, longest_seq_length, embedding_size]`` for each sentence in the batch. 
""" - input_embeds: Optional[EmbeddingsTensor] = None - baseline_embeds: Optional[EmbeddingsTensor] = None + input_embeds: EmbeddingsTensor | None = None + baseline_embeds: EmbeddingsTensor | None = None - def __len__(self) -> Optional[int]: + def __len__(self) -> int | None: if self.input_embeds is not None: return self.input_embeds.shape[0] return None @@ -79,15 +78,15 @@ def attention_mask(self) -> IdsTensor: return self.encoding.attention_mask @property - def baseline_ids(self) -> Optional[IdsTensor]: + def baseline_ids(self) -> IdsTensor | None: return self.encoding.baseline_ids @property - def input_embeds(self) -> Optional[EmbeddingsTensor]: + def input_embeds(self) -> EmbeddingsTensor | None: return self.embedding.input_embeds @property - def baseline_embeds(self) -> Optional[EmbeddingsTensor]: + def baseline_embeds(self) -> EmbeddingsTensor | None: return self.embedding.baseline_embeds @input_ids.setter @@ -103,15 +102,15 @@ def attention_mask(self, value: IdsTensor): self.encoding.attention_mask = value @baseline_ids.setter - def baseline_ids(self, value: Optional[IdsTensor]): + def baseline_ids(self, value: IdsTensor | None): self.encoding.baseline_ids = value @input_embeds.setter - def input_embeds(self, value: Optional[EmbeddingsTensor]): + def input_embeds(self, value: EmbeddingsTensor | None): self.embedding.input_embeds = value @baseline_embeds.setter - def baseline_embeds(self, value: Optional[EmbeddingsTensor]): + def baseline_embeds(self, value: EmbeddingsTensor | None): self.embedding.baseline_embeds = value @@ -128,7 +127,7 @@ class EncoderDecoderBatch(TensorWrapper): sources: Batch targets: Batch - def __getitem__(self, subscript: Union[slice, int]) -> "EncoderDecoderBatch": + def __getitem__(self, subscript: slice | int) -> "EncoderDecoderBatch": return EncoderDecoderBatch(sources=self.sources, targets=self.targets[subscript]) @property @@ -169,7 +168,7 @@ def target_mask(self) -> IdsTensor: def get_step_target( self, step: int, with_attention: bool = False - ) -> Union[ExpandedTargetIdsTensor, tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]]: + ) -> ExpandedTargetIdsTensor | tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]: tgt = self.targets.input_ids[:, step] if with_attention: return tgt, self.targets.attention_mask[:, step] @@ -218,7 +217,7 @@ def target_mask(self) -> IdsTensor: def get_step_target( self, step: int, with_attention: bool = False - ) -> Union[ExpandedTargetIdsTensor, tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]]: + ) -> ExpandedTargetIdsTensor | tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]: tgt = self.input_ids[:, step] if with_attention: return tgt, self.attention_mask[:, step] @@ -233,7 +232,7 @@ def from_batch(self, batch: Batch) -> "DecoderOnlyBatch": def slice_batch_from_position( - batch: DecoderOnlyBatch, curr_idx: int, alignments: Optional[list[tuple[int, int]]] = None + batch: DecoderOnlyBatch, curr_idx: int, alignments: list[tuple[int, int]] | None = None ) -> tuple[DecoderOnlyBatch, IdsTensor]: if len(alignments) > 0 and isinstance(alignments[0], list): alignments = alignments[0] diff --git a/inseq/data/data_utils.py b/inseq/data/data_utils.py index d0f90203..39099986 100644 --- a/inseq/data/data_utils.py +++ b/inseq/data/data_utils.py @@ -38,7 +38,7 @@ def _slice_batch(attr, subscript): return attr[subscript] if attr.ndim >= 2: return attr[subscript, ...] 
- elif isinstance(attr, (TensorWrapper, list)): + elif isinstance(attr, TensorWrapper | list): return attr[subscript] elif isinstance(attr, dict): return {key: TensorWrapper._slice_batch(val, subscript) for key, val in attr.items()} @@ -69,7 +69,7 @@ def _select_active(attr, mask): @staticmethod def _to(attr, device: str): - if isinstance(attr, (torch.Tensor, TensorWrapper)): + if isinstance(attr, torch.Tensor | TensorWrapper): return attr.to(device) elif isinstance(attr, dict): return {key: TensorWrapper._to(val, device) for key, val in attr.items()} @@ -78,7 +78,7 @@ def _to(attr, device: str): @staticmethod def _detach(attr): - if isinstance(attr, (torch.Tensor, TensorWrapper)): + if isinstance(attr, torch.Tensor | TensorWrapper): return attr.detach() elif isinstance(attr, dict): return {key: TensorWrapper._detach(val) for key, val in attr.items()} @@ -87,7 +87,7 @@ def _detach(attr): @staticmethod def _numpy(attr): - if isinstance(attr, (torch.Tensor, TensorWrapper)): + if isinstance(attr, torch.Tensor | TensorWrapper): np_array = attr.numpy() if isinstance(np_array, np.ndarray): return np.ascontiguousarray(np_array, dtype=np_array.dtype) @@ -167,7 +167,7 @@ def clone(self: TensorClass) -> TensorClass: out_params = {} for field in fields(self.__class__): attr = getattr(self, field.name) - if isinstance(attr, (torch.Tensor, TensorWrapper)): + if isinstance(attr, torch.Tensor | TensorWrapper): out_params[field.name] = attr.clone() elif attr is not None: out_params[field.name] = deepcopy(attr) diff --git a/inseq/data/viz.py b/inseq/data/viz.py index 28d7004a..fc8707d3 100644 --- a/inseq/data/viz.py +++ b/inseq/data/viz.py @@ -18,9 +18,12 @@ import random import string -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING, Literal import numpy as np +import treescope as ts +import treescope.figures as fg +import treescope.rendering_parts as rp from matplotlib.colors import Colormap from rich import box from rich.color import Color @@ -36,26 +39,38 @@ from tqdm.std import tqdm from ..utils import isnotebook +from ..utils.misc import clean_tokens from ..utils.typing import TextSequences from ..utils.viz_utils import ( final_plot_html, get_colors, get_instance_html, + maybe_add_linebreak, red_transparent_blue_colormap, saliency_heatmap_html, saliency_heatmap_table_header, sanitize_html, + test_dim, + treescope_cmap, ) -from .attribution import FeatureAttributionSequenceOutput + +if TYPE_CHECKING: + from .attribution import FeatureAttributionSequenceOutput + +if isnotebook(): + cmap = treescope_cmap() + ts.basic_interactive_setup(autovisualize_arrays=True) + ts.default_diverging_colormap.set_globally(cmap) + ts.default_sequential_colormap.set_globally(cmap) def show_attributions( - attributions: FeatureAttributionSequenceOutput, - min_val: Optional[int] = None, - max_val: Optional[int] = None, + attributions: "FeatureAttributionSequenceOutput", + min_val: int | None = None, + max_val: int | None = None, display: bool = True, - return_html: Optional[bool] = False, -) -> Optional[str]: + return_html: bool | None = False, +) -> str | None: """Core function allowing for visualization of feature attribution maps in console/HTML format. 
Args: @@ -74,6 +89,8 @@ def show_attributions( Returns: `Optional[str]`: Returns the HTML output if `return_html=True` """ + from inseq.data.attribution import FeatureAttributionSequenceOutput + if isinstance(attributions, FeatureAttributionSequenceOutput): attributions = [attributions] html_out = "" @@ -128,14 +145,243 @@ def show_attributions( return html_out +def show_granular_attributions( + attributions: "FeatureAttributionSequenceOutput", + max_show_size: int = 20, + min_val: int | None = None, + max_val: int | None = None, + show_dim: int | str | None = None, + slice_dims: dict[int | str, tuple[int, int]] | None = None, + display: bool = True, + return_html: bool | None = False, + return_figure: bool = False, +) -> str | None: + """Visualizes granular attribution heatmaps in HTML format. + + Args: + attributions (:class:`~inseq.data.attribution.FeatureAttributionSequenceOutput`): + Sequence attributions to be visualized. Does not require pre-aggregation. + min_val (:obj:`int`, *optional*, defaults to None): + Lower attribution score threshold for color map. + max_val (:obj:`int`, *optional*, defaults to None): + Upper attribution score threshold for color map. + max_show_size (:obj:`int`, *optional*, defaults to 20): + Maximum dimension size for additional dimensions to be visualized. + show_dim (:obj:`int` or :obj:`str`, *optional*, defaults to None): + Dimension to be visualized along with the source and target tokens. Can be either the dimension index or + the dimension name. Works only if the dimension size is less than or equal to `max_show_size`. + slice_dims (:obj:`dict[int or str, tuple[int, int]]`, *optional*, defaults to None): + Dimensions to be sliced and visualized along with the source and target tokens. The dictionary should + contain the dimension index or name as the key and the slice range as the value. + display (:obj:`bool`, *optional*, defaults to True): + Whether to show the output of the visualization function. + return_html (:obj:`bool`, *optional*, defaults to False): + If true, returns the HTML corresponding to the notebook visualization of the attributions in + string format, for saving purposes. + return_figure (:obj:`bool`, *optional*, defaults to False): + If true, returns the Treescope figure object for further manipulation. + + Returns: + `str`: Returns the HTML output if `return_html=True` + """ + from inseq.data.attribution import FeatureAttributionSequenceOutput + + if isinstance(attributions, FeatureAttributionSequenceOutput): + attributions: list["FeatureAttributionSequenceOutput"] = [attributions] + if not isnotebook() and display: + raise ValueError( + "Granular attribution heatmap visualization is only supported in Jupyter notebooks. " + "Please set `display=False` and `return_html=True` to avoid this error."
+ ) + if return_html and return_figure: + raise ValueError("Only one of `return_html` and `return_figure` can be set to True.") + items_to_render = [] + for attribution in attributions: + if attribution.source_attributions is not None: + items_to_render += [ + fg.bolded("Source Saliency Heatmap"), + get_saliency_heatmap_treescope( + attribution.source_attributions.numpy(), + [t.token for t in attribution.target[attribution.attr_pos_start : attribution.attr_pos_end]], + [t.token for t in attribution.source], + attribution._attribution_dim_names["source_attributions"], + max_show_size=max_show_size, + max_val=max_val, + min_val=min_val, + show_dim=show_dim, + slice_dims=slice_dims, + ), + ] + if attribution.target_attributions is not None: + items_to_render += [ + fg.bolded("Target Saliency Heatmap"), + get_saliency_heatmap_treescope( + attribution.target_attributions.numpy(), + [t.token for t in attribution.target[attribution.attr_pos_start : attribution.attr_pos_end]], + [t.token for t in attribution.target], + attribution._attribution_dim_names["target_attributions"], + max_show_size=max_show_size, + max_val=max_val, + min_val=min_val, + show_dim=show_dim, + slice_dims=slice_dims, + ), + ] + items_to_render.append("") + fig = fg.inline(*items_to_render) + if return_figure: + return fig + if display: + ts.show(fig) + if return_html: + return ts.render_to_html(fig) + + +def show_token_attributions( + attributions: "FeatureAttributionSequenceOutput", + min_val: int | None = None, + max_val: int | None = None, + display: bool = True, + return_html: bool | None = False, + return_figure: bool = False, + replace_char: dict[str, str] | None = None, + wrap_after: int | str | list[str] | tuple[str] | None = None, + step_score_highlight: str | None = None, +): + """Visualizes token-level attributions in HTML format. + + Args: + attributions (:class:`~inseq.data.attribution.FeatureAttributionSequenceOutput`): + Sequence attributions to be visualized. + min_val (:obj:`Optional[int]`, *optional*, defaults to None): + Lower attribution score threshold for color map. + max_val (`Optional[int]`, *optional*, defaults to None): + Upper attribution score threshold for color map. + display (`bool`, *optional*, defaults to True): + Whether to show the output of the visualization function. + return_html (`Optional[bool]`, *optional*, defaults to False): + If true, returns the HTML corresponding to the notebook visualization of the attributions in string format, + for saving purposes. + return_figure (`Optional[bool]`, *optional*, defaults to False): + If true, returns the Treescope figure object for further manipulation. + replace_char (`Optional[dict[str, str]]`, *optional*, defaults to None): + Dictionary mapping strings to be replaced to replacement options, used for cleaning special characters. + Default: {}. + wrap_after (`Optional[int | str | list[str] | tuple[str]]`, *optional*, defaults to None): + Token indices or tokens after which to wrap lines. E.g. 10 = wrap after every 10 tokens, "hi" = wrap after + the word "hi" occurs, [".", "!", "?"] or ".!?" = wrap after every sentence-ending punctuation. + step_score_highlight (`Optional[str]`, *optional*, defaults to None): + Name of the step score to use to highlight generated tokens in the visualization. If None, no highlights are + shown. Default: None.
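+
+    Example:
+        A minimal sketch, assuming ``out`` holds the result of ``model.attribute`` (token-level
+        rendering expects 2D scores, hence the aggregation step):
+
+        .. code-block:: python
+
+            from inseq.data.viz import show_token_attributions
+
+            html = show_token_attributions(
+                out.aggregate().sequence_attributions,
+                wrap_after=(".", "!", "?"),  # new line after sentence-ending punctuation
+                display=False,
+                return_html=True,
+            )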
+ """ + from inseq.data.attribution import FeatureAttributionSequenceOutput + + if isinstance(attributions, FeatureAttributionSequenceOutput): + attributions: list["FeatureAttributionSequenceOutput"] = [attributions] + if not isnotebook() and display: + raise ValueError( + "Token attribution visualization is only supported in Jupyter notebooks. " + "Please set `display=False` and `return_html=True` to avoid this error." + ) + if return_html and return_figure: + raise ValueError("Only one of `return_html` and `return_figure` can be set to True.") + if replace_char is None: + replace_char = {} + if max_val is None: + max_val = max(attribution.maximum for attribution in attributions) + if step_score_highlight is not None and ( + attributions[0].step_scores is None or step_score_highlight not in attributions[0].step_scores + ): + raise ValueError( + f'The requested step score "{step_score_highlight}" is not available for highlights in the provided ' + "attribution object. Please set `step_score_highlight=None` or recompute `model.attribute` by passing " + f'`step_scores=["{step_score_highlight}"].' + ) + generated_token_parts = [] + for attr in attributions: + cleaned_generated_tokens = clean_tokens(t.token for t in attr.target[attr.attr_pos_start : attr.attr_pos_end]) + cleaned_input_tokens = clean_tokens(t.token for t in attr.source) + cleaned_target_tokens = clean_tokens(t.token for t in attr.target) + step_scores = None + title = "Generated text:\n\n" + if step_score_highlight is not None: + step_scores = attr.step_scores[step_score_highlight] + scores_vmax = step_scores.max().item() + # Use different cmap to differentiate from attribution scores + scores_cmap = ( + treescope_cmap("greens") if all(x >= 0 for x in step_scores) else treescope_cmap("brown_to_green") + ) + title = f"Generated text with {step_score_highlight} highlights:\n\n" + generated_token_parts.append(rp.custom_style(rp.text(title), css_style="font-weight: bold;")) + for gen_idx, curr_gen_tok in enumerate(cleaned_generated_tokens): + attributed_token_parts = [rp.text("\n")] + if attr.source_attributions is not None: + attributed_token_parts.append( + get_tokens_heatmap_treescope( + tokens=cleaned_input_tokens, + scores=attr.source_attributions[:, gen_idx].numpy(), + title=f'Source attributions for "{curr_gen_tok}"', + title_style="font-style: italic; color: #888888;", + min_val=min_val, + max_val=max_val, + wrap_after=wrap_after, + ) + ) + attributed_token_parts.append(rp.text("\n\n")) + if attr.target_attributions is not None: + attributed_token_parts.append( + get_tokens_heatmap_treescope( + tokens=cleaned_target_tokens[: attr.attr_pos_start + gen_idx], + scores=attr.target_attributions[:, gen_idx].numpy(), + title=f'Target attributions for "{curr_gen_tok}"', + title_style="font-style: italic; color: #888888;", + min_val=min_val, + max_val=max_val, + wrap_after=wrap_after, + ) + ) + attributed_token_parts.append(rp.text("\n\n")) + if step_scores is not None: + gen_tok_label = fg.treescope_part_from_display_object( + fg.text_on_color( + curr_gen_tok, + value=round(step_scores[gen_idx].item(), 4), + vmax=scores_vmax, + colormap=scores_cmap, + ) + ) + else: + gen_tok_label = rp.text(curr_gen_tok) + generated_token_parts.append( + rp.build_full_line_with_annotations( + rp.build_custom_foldable_tree_node( + label=gen_tok_label, + contents=rp.fold_condition( + collapsed=rp.text(" "), + expanded=rp.indented_children([rp.siblings(*attributed_token_parts)]), + ), + ) + ) + ) + fig = fg.figure_from_treescope_rendering_part( + 
rp.custom_style(rp.siblings(*generated_token_parts), css_style="white-space: pre-wrap") + ) + if return_figure: + return fig + if display: + ts.show(fig) + if return_html: + return ts.render_to_html(fig) + + def get_attribution_colors( - attributions: list[FeatureAttributionSequenceOutput], - min_val: Optional[int] = None, - max_val: Optional[int] = None, - cmap: Union[str, Colormap, None] = None, + attributions: list["FeatureAttributionSequenceOutput"], + min_val: int | None = None, + max_val: int | None = None, + cmap: str | Colormap | None = None, return_alpha: bool = True, return_strings: bool = True, -) -> list[list[list[Union[str, tuple[float, float, float]]]]]: +) -> list[list[list[str | tuple[float, float, float]]]]: """A list (one element = one sentence) of lists (one element = attributions for one token) of lists (one element = one attribution) of colors. Colors are either strings or RGB(A) tuples. """ @@ -161,7 +407,7 @@ def get_attribution_colors( def get_heatmap_type( - attribution: FeatureAttributionSequenceOutput, + attribution: "FeatureAttributionSequenceOutput", colors, heatmap_type: Literal["Source", "Target"] = "Source", use_html: bool = False, @@ -197,13 +443,13 @@ def get_heatmap_type( def get_saliency_heatmap_html( - scores: Union[np.ndarray, None], + scores: np.ndarray | None, column_labels: list[str], row_labels: list[str], input_colors: list[list[str]], - step_scores: Optional[dict[str, np.ndarray]] = None, + step_scores: dict[str, np.ndarray] | None = None, label: str = "", - step_scores_threshold: Union[float, dict[str, float]] = 0.5, + step_scores_threshold: float | dict[str, float] = 0.5, ): # unique ID added to HTML elements and functions to avoid collision of different instances uuid = "".join(random.choices(string.ascii_lowercase, k=20)) @@ -250,13 +496,13 @@ def get_saliency_heatmap_html( def get_saliency_heatmap_rich( - scores: Union[np.ndarray, None], + scores: np.ndarray | None, column_labels: list[str], row_labels: list[str], input_colors: list[list[str]], - step_scores: Optional[dict[str, np.ndarray]] = None, + step_scores: dict[str, np.ndarray] | None = None, label: str = "", - step_scores_threshold: Union[float, dict[str, float]] = 0.5, + step_scores_threshold: float | dict[str, float] = 0.5, ): columns = [ Column(header="", justify="right", overflow="fold"), @@ -297,6 +543,87 @@ def get_saliency_heatmap_rich( return table +def get_saliency_heatmap_treescope( + scores: np.ndarray | None, + column_labels: list[str], + row_labels: list[str], + dim_names: dict[int, str] | None = None, + max_show_size: int | None = None, + max_val: float | None = None, + min_val: float | None = None, + show_dim: int | str | None = None, + slice_dims: dict[int | str, tuple[int, int]] | None = None, +): + if max_show_size is None: + max_show_size = 20 + if dim_names is None: + dim_names = {} + item_labels_dict = {0: row_labels, 1: column_labels} + rev_dim_names = {v: k for k, v in dim_names.items()} + col_dims = [1] + slider_dims = [] + if slice_dims is not None: + slices = [slice(None)] * scores.ndim + for dim_name, slice_idxs in slice_dims.items(): + dim_idx = test_dim(dim_name, dim_names, rev_dim_names, scores) + slices[dim_idx] = slice(slice_idxs[0], slice_idxs[1]) + scores = scores[tuple(slices)] + if show_dim is not None: + show_dim_idx = test_dim(show_dim, dim_names, rev_dim_names, scores) + if scores.shape[show_dim_idx] > max_show_size: + raise ValueError( + f"Dimension {show_dim_idx} has size {scores.shape[show_dim_idx]} which is greater than the maximum " +
f"show size {max_show_size}. Please choose a different dimension or slice the tensor before " + "visualizing it using SliceAggregator." + ) + col_dims.append(show_dim_idx) + for dim_idx, dim_name in dim_names.items(): + if dim_idx > 1: + if scores.shape[dim_idx] <= max_show_size and len(col_dims) < 2: + col_dims.append(dim_idx) + else: + slider_dims.append(dim_idx) + item_labels_dict[dim_idx] = [f"{dim_name} #{i}" for i in range(scores.shape[dim_idx])] + return ts.render_array( + scores, + rows=[0], + columns=col_dims, + sliders=slider_dims, + axis_labels={k: f"{v}: {scores.shape[k]}" for k, v in dim_names.items()}, + axis_item_labels=item_labels_dict, + vmax=max_val, + vmin=min_val, + ) + + +def get_tokens_heatmap_treescope( + tokens: list[str], + scores: np.ndarray, + title: str | None = None, + title_style: str | None = None, + min_val: float | None = None, + max_val: float | None = None, + wrap_after: int | str | list[str] | tuple[str] | None = None, +): + parts = [] + if title is not None: + parts.append( + rp.custom_style( + rp.text(title + ":\n"), + css_style=title_style, + ) + ) + for idx, tok in enumerate(tokens): + if not np.isnan(scores[idx]): + parts.append( + fg.treescope_part_from_display_object( + fg.text_on_color(tok, value=round(scores[idx], 4), vmin=min_val, vmax=max_val) + ) + ) + parts += maybe_add_linebreak(tok, idx, wrap_after) + return rp.siblings(*parts) + + # Progress bar utilities @@ -308,7 +635,7 @@ def get_progress_bar( pretty: bool, attr_pos_start: int, attr_pos_end: int, -) -> Union[tqdm, tuple[Progress, Live], None]: +) -> tqdm | tuple[Progress, Live] | None: if not show: return None elif show and not pretty: @@ -324,7 +651,7 @@ def get_progress_bar( TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), TimeRemainingColumn(), ) - for idx, (tgt, tgt_len) in enumerate(zip(sequences.targets, target_lengths)): + for idx, (tgt, tgt_len) in enumerate(zip(sequences.targets, target_lengths, strict=False)): clean_tgt = escape(tgt.replace("\n", "\\n")) job_progress.add_task(f"{idx}. {clean_tgt}", total=tgt_len) progress_table = Table.grid() @@ -356,11 +683,11 @@ def get_progress_bar( def update_progress_bar( - pbar: Union[tqdm, tuple[Progress, Live], None], - skipped_prefixes: Optional[list[str]] = None, - attributed_sentences: Optional[list[str]] = None, - unattributed_suffixes: Optional[list[str]] = None, - skipped_suffixes: Optional[list[str]] = None, + pbar: tqdm | tuple[Progress, Live] | None, + skipped_prefixes: list[str] | None = None, + attributed_sentences: list[str] | None = None, + unattributed_suffixes: list[str] | None = None, + skipped_suffixes: list[str] | None = None, whitespace_indexes: list[list[int]] = None, show: bool = False, pretty: bool = False, @@ -376,7 +703,7 @@ def update_progress_bar( pbar[0].advance(job.id) formatted_desc = f"{job.id}. 
" past_length = 0 - for split, color in zip(split_targets, ["grey58", "green", "orange1", "grey58"]): + for split, color in zip(split_targets, ["grey58", "green", "orange1", "grey58"], strict=False): if split[job.id]: formatted_desc += f"[{color}]" + escape(split[job.id].replace("\n", "\\n")) + "[/]" past_length += len(split[job.id]) @@ -386,7 +713,7 @@ def update_progress_bar( pbar[0].update(job.id, description=formatted_desc, refresh=True) -def close_progress_bar(pbar: Union[tqdm, tuple[Progress, Live], None], show: bool, pretty: bool) -> None: +def close_progress_bar(pbar: tqdm | tuple[Progress, Live] | None, show: bool, pretty: bool) -> None: if not show: return elif show and not pretty: diff --git a/inseq/models/__init__.py b/inseq/models/__init__.py index 5959d5f9..83377607 100644 --- a/inseq/models/__init__.py +++ b/inseq/models/__init__.py @@ -19,8 +19,8 @@ def load_model( - model: Union[ModelIdentifier, ModelClass], - attribution_method: Optional[str] = None, + model: ModelIdentifier | ModelClass, + attribution_method: str | None = None, framework: str = "hf_transformers", **kwargs, ) -> AttributionModel: diff --git a/inseq/models/attribution_model.py b/inseq/models/attribution_model.py index 08e645de..ee84793c 100644 --- a/inseq/models/attribution_model.py +++ b/inseq/models/attribution_model.py @@ -1,7 +1,8 @@ import logging from abc import ABC, abstractmethod +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Optional, Protocol, TypeVar, Union +from typing import Any, Protocol, TypeVar import torch @@ -52,11 +53,11 @@ class ForwardMethod(Protocol): def __call__( self, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, target_ids: ExpandedTargetIdsTensor, attributed_fn: Callable[..., SingleScorePerStepTensor], use_embeddings: bool, - attributed_fn_argnames: Optional[list[str]], + attributed_fn_argnames: list[str] | None, *args, ) -> CustomForwardOutput: ... 
@@ -70,13 +71,13 @@ def prepare_inputs_for_attribution( inputs: FeatureAttributionInput, include_eos_baseline: bool = False, skip_special_tokens: bool = False, - ) -> Union[DecoderOnlyBatch, EncoderDecoderBatch]: + ) -> DecoderOnlyBatch | EncoderDecoderBatch: raise NotImplementedError() @staticmethod @abstractmethod def format_attribution_args( - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, target_ids: TargetIdsTensor, attributed_fn: Callable[..., SingleScorePerStepTensor], attribute_target: bool = False, @@ -84,7 +85,7 @@ def format_attribution_args( attribute_batch_ids: bool = False, forward_batch_embeds: bool = True, use_baselines: bool = False, - ) -> tuple[dict[str, Any], tuple[Union[IdsTensor, EmbeddingsTensor, None], ...]]: + ) -> tuple[dict[str, Any], tuple[IdsTensor | EmbeddingsTensor | None, ...]]: raise NotImplementedError() @staticmethod @@ -92,11 +93,11 @@ def format_attribution_args( def enrich_step_output( attribution_model: "AttributionModel", step_output: FeatureAttributionStepOutput, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, target_tokens: OneOrMoreTokenSequences, target_ids: TargetIdsTensor, - contrast_batch: Optional[DecoderOnlyBatch] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_batch: DecoderOnlyBatch | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, ) -> FeatureAttributionStepOutput: r"""Enriches the attribution output with token information, producing the finished :class:`~inseq.data.FeatureAttributionStepOutput` object. @@ -116,7 +117,7 @@ def enrich_step_output( @staticmethod @abstractmethod - def convert_args_to_batch(args: StepFunctionArgs = None, **kwargs) -> Union[DecoderOnlyBatch, EncoderDecoderBatch]: + def convert_args_to_batch(args: StepFunctionArgs = None, **kwargs) -> DecoderOnlyBatch | EncoderDecoderBatch: raise NotImplementedError() @staticmethod @@ -141,7 +142,7 @@ def format_step_function_args( @staticmethod @abstractmethod def get_text_sequences( - attribution_model: "AttributionModel", batch: Union[DecoderOnlyBatch, EncoderDecoderBatch] + attribution_model: "AttributionModel", batch: DecoderOnlyBatch | EncoderDecoderBatch ) -> TextSequences: raise NotImplementedError() @@ -152,15 +153,15 @@ def get_step_function_reserved_args() -> list[str]: @staticmethod def format_contrast_targets_alignments( - contrast_targets_alignments: Union[list[tuple[int, int]], list[list[tuple[int, int]]], str], + contrast_targets_alignments: list[tuple[int, int]] | list[list[tuple[int, int]]] | str, target_sequences: list[str], target_tokens: list[list[str]], contrast_sequences: list[str], contrast_tokens: list[list[str]], special_tokens: list[str] = [], start_pos: int = 0, - end_pos: Optional[int] = None, - ) -> tuple[DecoderOnlyBatch, Optional[list[list[tuple[int, int]]]]]: + end_pos: int | None = None, + ) -> tuple[DecoderOnlyBatch, list[list[tuple[int, int]]] | None]: # Ensure that the contrast_targets_alignments are in the correct format (list of lists of idxs pairs) if contrast_targets_alignments: if isinstance(contrast_targets_alignments, list) and len(contrast_targets_alignments) > 0: @@ -174,7 +175,7 @@ def format_contrast_targets_alignments( adjusted_alignments = [] aligns = contrast_targets_alignments for seq_idx, (tgt_seq, tgt_tok, c_seq, c_tok) in enumerate( - zip(target_sequences, target_tokens, contrast_sequences, contrast_tokens) + 
zip(target_sequences, target_tokens, contrast_sequences, contrast_tokens, strict=False) ): if isinstance(contrast_targets_alignments, list): aligns = contrast_targets_alignments[seq_idx] @@ -217,18 +218,18 @@ def __init__(self, **kwargs) -> None: self.model = None self.model_name: str = None self.is_encoder_decoder: bool = True - self.pad_token: Optional[str] = None - self.embed_scale: Optional[float] = None - self._device: Optional[str] = None - self.device_map: Optional[dict[str, Union[str, int, torch.device]]] = None - self.attribution_method: Optional[FeatureAttribution] = None + self.pad_token: str | None = None + self.embed_scale: float | None = None + self._device: str | None = None + self.device_map: dict[str, str | int | torch.device] | None = None + self.attribution_method: FeatureAttribution | None = None self.is_hooked: bool = False self._default_attributed_fn_id: str = "probability" - self.config: Optional[ModelConfig] = None - self.is_distributed: Optional[bool] = None + self.config: ModelConfig | None = None + self.is_distributed: bool | None = None @property - def device(self) -> Optional[str]: + def device(self) -> str | None: return self._device @device.setter @@ -238,7 +239,7 @@ def device(self, new_device: str) -> None: if self.model: self.model.to(self._device) - def setup(self, device: Optional[str] = None, attribution_method: Optional[str] = None, **kwargs) -> None: + def setup(self, device: str | None = None, attribution_method: str | None = None, **kwargs) -> None: """Move the model to device and in eval mode.""" self.device = device if device is not None else get_default_device() if self.model: @@ -258,7 +259,7 @@ def set_attributed_fn(self, fn: str): self._default_attributed_fn_id = fn @property - def info(self) -> dict[Optional[str], Optional[str]]: + def info(self) -> dict[str | None, str | None]: return { "model_name": self.model_name, "model_class": self.model.__class__.__name__ if self.model is not None else None, @@ -266,8 +267,8 @@ def info(self) -> dict[Optional[str], Optional[str]]: def get_attribution_method( self, - method: Optional[str] = None, - override_default_attribution: Optional[bool] = False, + method: str | None = None, + override_default_attribution: bool | None = False, **kwargs, ) -> FeatureAttribution: # No method present -> missing method error @@ -287,7 +288,7 @@ def get_attribution_method( return self.attribution_method def get_attributed_fn( - self, attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None + self, attributed_fn: str | Callable[..., SingleScorePerStepTensor] | None = None ) -> Callable[..., SingleScorePerStepTensor]: if attributed_fn is None: attributed_fn = self.default_attributed_fn_id @@ -302,20 +303,20 @@ def get_attributed_fn( def attribute( self, input_texts: TextInput, - generated_texts: Optional[TextInput] = None, - method: Optional[str] = None, - override_default_attribution: Optional[bool] = False, - attr_pos_start: Optional[int] = None, - attr_pos_end: Optional[int] = None, + generated_texts: TextInput | None = None, + method: str | None = None, + override_default_attribution: bool | None = False, + attr_pos_start: int | None = None, + attr_pos_end: int | None = None, show_progress: bool = True, pretty_progress: bool = True, output_step_attributions: bool = False, attribute_target: bool = False, step_scores: list[str] = [], include_eos_baseline: bool = False, - attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None, - device: Optional[str] = None, - batch_size: 
Optional[int] = None, + attributed_fn: str | Callable[..., SingleScorePerStepTensor] | None = None, + device: str | None = None, + batch_size: int | None = None, generate_from_target_prefix: bool = False, skip_special_tokens: bool = False, generation_args: dict[str, Any] = {}, @@ -500,7 +501,7 @@ def attribute( self.device = original_device return attribution_output - def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False, add_special_tokens: bool = True): + def embed(self, inputs: TextInput | IdsTensor, as_targets: bool = False, add_special_tokens: bool = True): if isinstance(inputs, str) or ( isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs) ): @@ -510,9 +511,9 @@ def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False, a def get_token_with_ids( self, - batch: Union[EncoderDecoderBatch, DecoderOnlyBatch], - contrast_target_tokens: Optional[OneOrMoreTokenSequences] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + batch: EncoderDecoderBatch | DecoderOnlyBatch, + contrast_target_tokens: OneOrMoreTokenSequences | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, ) -> list[list[TokenWithId]]: if contrast_target_tokens is not None: return join_token_ids( @@ -529,10 +530,10 @@ def get_token_with_ids( @abstractmethod def generate( self, - encodings: Union[TextInput, BatchEncoding], - return_generation_output: Optional[bool] = False, + encodings: TextInput | BatchEncoding, + return_generation_output: bool | None = False, **kwargs, - ) -> Union[list[str], tuple[list[str], Any]]: + ) -> list[str] | tuple[list[str], Any]: pass @staticmethod @@ -561,14 +562,14 @@ def embed_ids(self, ids: IdsTensor, as_targets: bool = False) -> EmbeddingsTenso @abstractmethod def convert_ids_to_tokens( - self, ids: torch.Tensor, skip_special_tokens: Optional[bool] = True + self, ids: torch.Tensor, skip_special_tokens: bool | None = True ) -> OneOrMoreTokenSequences: pass @abstractmethod def convert_tokens_to_ids( self, - tokens: Union[list[str], list[list[str]]], + tokens: list[str] | list[list[str]], ) -> OneOrMoreIdSequences: pass @@ -576,7 +577,7 @@ def convert_tokens_to_ids( def convert_tokens_to_string( self, tokens: OneOrMoreTokenSequences, - skip_special_tokens: Optional[bool] = True, + skip_special_tokens: bool | None = True, as_targets: bool = False, ) -> TextInput: pass @@ -668,11 +669,11 @@ def get_hidden_states_dict(output: ModelOutput) -> dict[str, torch.Tensor]: def _forward( self, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, target_ids: ExpandedTargetIdsTensor, attributed_fn: Callable[..., SingleScorePerStepTensor], use_embeddings: bool = True, - attributed_fn_argnames: Optional[list[str]] = None, + attributed_fn_argnames: list[str] | None = None, *args, **kwargs, ) -> LogitsTensor: @@ -683,12 +684,12 @@ def _forward( step_fn_args = self.formatter.format_step_function_args( attribution_model=self, forward_output=output, target_ids=target_ids, is_attributed_fn=True, batch=batch ) - step_fn_extra_args = {k: v for k, v in zip(attributed_fn_argnames, args) if v is not None} + step_fn_extra_args = {k: v for k, v in zip(attributed_fn_argnames, args, strict=False) if v is not None} return attributed_fn(step_fn_args, **step_fn_extra_args) def _forward_with_output( self, - batch: Union[DecoderOnlyBatch, EncoderDecoderBatch], + batch: DecoderOnlyBatch | EncoderDecoderBatch, use_embeddings: bool = 
True, *args, **kwargs, diff --git a/inseq/models/decoder_only.py b/inseq/models/decoder_only.py index 48aefce5..c22641fd 100644 --- a/inseq/models/decoder_only.py +++ b/inseq/models/decoder_only.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar import torch @@ -60,7 +61,7 @@ def format_attribution_args( attribute_batch_ids: bool = False, forward_batch_embeds: bool = True, use_baselines: bool = False, - ) -> tuple[dict[str, Any], tuple[Union[IdsTensor, EmbeddingsTensor, None], ...]]: + ) -> tuple[dict[str, Any], tuple[IdsTensor | EmbeddingsTensor | None, ...]]: if attribute_batch_ids: inputs = (batch.input_ids,) baselines = (batch.baseline_ids,) @@ -97,8 +98,8 @@ def enrich_step_output( batch: DecoderOnlyBatch, target_tokens: OneOrMoreTokenSequences, target_ids: TargetIdsTensor, - contrast_batch: Optional[DecoderOnlyBatch] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_batch: DecoderOnlyBatch | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, ) -> FeatureAttributionStepOutput: r"""Enriches the attribution output with token information, producing the finished :class:`~inseq.data.FeatureAttributionStepOutput` object. @@ -153,9 +154,9 @@ def format_step_function_args( @staticmethod def convert_args_to_batch( args: StepFunctionDecoderOnlyArgs = None, - decoder_input_ids: Optional[IdsTensor] = None, - decoder_attention_mask: Optional[IdsTensor] = None, - decoder_input_embeds: Optional[EmbeddingsTensor] = None, + decoder_input_ids: IdsTensor | None = None, + decoder_attention_mask: IdsTensor | None = None, + decoder_input_embeds: EmbeddingsTensor | None = None, **kwargs, ) -> DecoderOnlyBatch: if args is not None: @@ -175,9 +176,9 @@ def formatted_forward_input_wrapper( input_ids: IdsTensor, target_ids: ExpandedTargetIdsTensor, attributed_fn: Callable[..., SingleScorePerStepTensor], - attention_mask: Optional[IdsTensor] = None, + attention_mask: IdsTensor | None = None, use_embeddings: bool = True, - attributed_fn_argnames: Optional[list[str]] = None, + attributed_fn_argnames: list[str] | None = None, *args, **kwargs, ) -> CustomForwardOutput: diff --git a/inseq/models/encoder_decoder.py b/inseq/models/encoder_decoder.py index 81022be3..8cf04502 100644 --- a/inseq/models/encoder_decoder.py +++ b/inseq/models/encoder_decoder.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Callable from functools import wraps -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar from ..attr.feat import join_token_ids from ..attr.step_functions import StepFunctionEncoderDecoderArgs @@ -89,7 +90,7 @@ def format_attribution_args( attribute_batch_ids: bool = False, forward_batch_embeds: bool = True, use_baselines: bool = False, - ) -> tuple[dict[str, Any], tuple[Union[IdsTensor, EmbeddingsTensor, None], ...]]: + ) -> tuple[dict[str, Any], tuple[IdsTensor | EmbeddingsTensor | None, ...]]: if attribute_batch_ids: inputs = (batch.sources.input_ids,) baselines = (batch.sources.baseline_ids,) @@ -139,8 +140,8 @@ def enrich_step_output( batch: EncoderDecoderBatch, target_tokens: OneOrMoreTokenSequences, target_ids: TargetIdsTensor, - contrast_batch: Optional[DecoderOnlyBatch] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_batch: DecoderOnlyBatch | None = None, + contrast_targets_alignments: 
list[list[tuple[int, int]]] | None = None, ) -> FeatureAttributionStepOutput: r"""Enriches the attribution output with token information, producing the finished :class:`~inseq.data.FeatureAttributionStepOutput` object. @@ -198,12 +199,12 @@ def format_step_function_args( @staticmethod def convert_args_to_batch( args: StepFunctionEncoderDecoderArgs = None, - encoder_input_ids: Optional[IdsTensor] = None, - decoder_input_ids: Optional[IdsTensor] = None, - encoder_attention_mask: Optional[IdsTensor] = None, - decoder_attention_mask: Optional[IdsTensor] = None, - encoder_input_embeds: Optional[EmbeddingsTensor] = None, - decoder_input_embeds: Optional[EmbeddingsTensor] = None, + encoder_input_ids: IdsTensor | None = None, + decoder_input_ids: IdsTensor | None = None, + encoder_attention_mask: IdsTensor | None = None, + decoder_attention_mask: IdsTensor | None = None, + encoder_input_embeds: EmbeddingsTensor | None = None, + decoder_input_embeds: EmbeddingsTensor | None = None, **kwargs, ) -> EncoderDecoderBatch: if args is not None: @@ -232,10 +233,10 @@ def formatted_forward_input_wrapper( decoder_input_ids: IdsTensor, target_ids: ExpandedTargetIdsTensor, attributed_fn: Callable[..., SingleScorePerStepTensor], - encoder_attention_mask: Optional[IdsTensor] = None, - decoder_attention_mask: Optional[IdsTensor] = None, + encoder_attention_mask: IdsTensor | None = None, + decoder_attention_mask: IdsTensor | None = None, use_embeddings: bool = True, - attributed_fn_argnames: Optional[list[str]] = None, + attributed_fn_argnames: list[str] | None = None, *args, **kwargs, ) -> CustomForwardOutput: diff --git a/inseq/models/huggingface_model.py b/inseq/models/huggingface_model.py index fe8a7bca..86937920 100644 --- a/inseq/models/huggingface_model.py +++ b/inseq/models/huggingface_model.py @@ -1,7 +1,7 @@ """HuggingFace Seq2seq model.""" import logging from abc import abstractmethod -from typing import Any, NoReturn, Optional, Union +from typing import Any, NoReturn import torch from torch import long @@ -65,12 +65,12 @@ class HuggingfaceModel(AttributionModel): def __init__( self, - model: Union[str, PreTrainedModel], - attribution_method: Optional[str] = None, - tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, - device: Optional[str] = None, - model_kwargs: Optional[dict[str, Any]] = {}, - tokenizer_kwargs: Optional[dict[str, Any]] = {}, + model: str | PreTrainedModel, + attribution_method: str | None = None, + tokenizer: str | PreTrainedTokenizerBase | None = None, + device: str | None = None, + model_kwargs: dict[str, Any] | None = {}, + tokenizer_kwargs: dict[str, Any] | None = {}, **kwargs, ) -> None: """AttributionModel subclass for Huggingface-compatible models. 
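For context, a minimal end-to-end sketch of the loading and attribution entry points whose signatures are rewritten above; the model identifier is illustrative, and any HuggingFace causal LM should behave the same way:

import inseq

# load_model wraps HuggingfaceModel.load for the default hf_transformers framework
model = inseq.load_model("gpt2", attribution_method="saliency")
out = model.attribute("The capital of France is", step_scores=["probability"])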
@@ -142,12 +142,12 @@ def __init__( @staticmethod def load( - model: Union[str, PreTrainedModel], - attribution_method: Optional[str] = None, - tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, + model: str | PreTrainedModel, + attribution_method: str | None = None, + tokenizer: str | PreTrainedTokenizerBase | None = None, device: str = None, - model_kwargs: Optional[dict[str, Any]] = {}, - tokenizer_kwargs: Optional[dict[str, Any]] = {}, + model_kwargs: dict[str, Any] | None = {}, + tokenizer_kwargs: dict[str, Any] | None = {}, **kwargs, ) -> "HuggingfaceModel": """Loads a HuggingFace model and tokenizer and wraps them in the appropriate AttributionModel.""" @@ -204,12 +204,12 @@ def info(self) -> dict[str, str]: @batched def generate( self, - inputs: Union[TextInput, BatchEncoding], + inputs: TextInput | BatchEncoding, return_generation_output: bool = False, skip_special_tokens: bool = True, output_generated_only: bool = False, **kwargs, - ) -> Union[list[str], tuple[list[str], ModelOutput]]: + ) -> list[str] | tuple[list[str], ModelOutput]: """Wrapper of model.generate to handle tokenization and decoding. Args: @@ -244,7 +244,7 @@ def generate( return texts @staticmethod - def output2logits(forward_output: Union[Seq2SeqLMOutput, CausalLMOutput]) -> LogitsTensor: + def output2logits(forward_output: Seq2SeqLMOutput | CausalLMOutput) -> LogitsTensor: # Full logits for last position of every sentence: # (batch_size, tgt_seq_len, vocab_size) => (batch_size, vocab_size) return forward_output.logits[:, -1, :].squeeze(1) @@ -302,7 +302,7 @@ def encode( def decode( self, - ids: Union[list[int], list[list[int]], IdsTensor], + ids: list[int] | list[list[int]] | IdsTensor, skip_special_tokens: bool = True, ) -> list[str]: return self.tokenizer.batch_decode( @@ -331,7 +331,7 @@ def _convert_ids_to_tokens(self, ids: IdsTensor, skip_special_tokens: bool = Tru return tokens def convert_ids_to_tokens( - self, ids: IdsTensor, skip_special_tokens: Optional[bool] = True + self, ids: IdsTensor, skip_special_tokens: bool | None = True ) -> OneOrMoreTokenSequences: if ids.ndim < 2: return self._convert_ids_to_tokens(ids, skip_special_tokens) @@ -350,7 +350,7 @@ def convert_tokens_to_string( ) -> TextInput: if isinstance(tokens, list) and len(tokens) == 0: return "" - elif isinstance(tokens[0], (bytes, str)): + elif isinstance(tokens[0], bytes | str): tmp_decode_state = self.tokenizer._decode_use_source_tokenizer self.tokenizer._decode_use_source_tokenizer = not as_targets out_strings = self.tokenizer.convert_tokens_to_string( @@ -396,7 +396,7 @@ def clean_tokens( """ if isinstance(tokens, list) and len(tokens) == 0: return [] - elif isinstance(tokens[0], (bytes, str)): + elif isinstance(tokens[0], bytes | str): clean_tokens = [] for tok in tokens: clean_tok = self.convert_tokens_to_string( @@ -489,12 +489,12 @@ class HuggingfaceDecoderOnlyModel(HuggingfaceModel, DecoderOnlyAttributionModel) def __init__( self, - model: Union[str, PreTrainedModel], - attribution_method: Optional[str] = None, - tokenizer: Union[str, PreTrainedTokenizerBase, None] = None, + model: str | PreTrainedModel, + attribution_method: str | None = None, + tokenizer: str | PreTrainedTokenizerBase | None = None, device: str = None, - model_kwargs: Optional[dict[str, Any]] = {}, - tokenizer_kwargs: Optional[dict[str, Any]] = {}, + model_kwargs: dict[str, Any] | None = {}, + tokenizer_kwargs: dict[str, Any] | None = {}, **kwargs, ) -> NoReturn: super().__init__(model, attribution_method, tokenizer, device, model_kwargs, 
tokenizer_kwargs, **kwargs) diff --git a/inseq/models/model_config.py b/inseq/models/model_config.py index 52d9d47b..8616a0a0 100644 --- a/inseq/models/model_config.py +++ b/inseq/models/model_config.py @@ -1,7 +1,6 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Optional import yaml @@ -29,7 +28,7 @@ class ModelConfig: self_attention_module: str value_vector: str - cross_attention_module: Optional[str] = None + cross_attention_module: str | None = None MODEL_CONFIGS = { diff --git a/inseq/models/model_decorators.py b/inseq/models/model_decorators.py index 8141f432..ea3de8c3 100644 --- a/inseq/models/model_decorators.py +++ b/inseq/models/model_decorators.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from functools import wraps -from typing import Any, Callable +from typing import Any def unhooked(f: Callable[..., Any]) -> Callable[..., Any]: diff --git a/inseq/utils/alignment_utils.py b/inseq/utils/alignment_utils.py index c465f1e3..daede31c 100644 --- a/inseq/utils/alignment_utils.py +++ b/inseq/utils/alignment_utils.py @@ -4,7 +4,6 @@ from enum import Enum from functools import lru_cache from itertools import chain -from typing import Optional, Union import torch from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase @@ -91,8 +90,8 @@ def _get_aligner_subword_aligns( def compute_word_aligns( - src: Union[str, list[str]], - tgt: Union[str, list[str]], + src: str | list[str], + tgt: str | list[str], split_pattern: str = r"\s+|\b", align_layer: int = 8, score_threshold: float = 1e-3, @@ -207,10 +206,10 @@ def add_alignment_extra_positions( def auto_align_sequences( - a_sequence: Optional[str] = None, - a_tokens: Optional[list[str]] = None, - b_sequence: Optional[str] = None, - b_tokens: Optional[list[str]] = None, + a_sequence: str | None = None, + a_tokens: list[str] | None = None, + b_sequence: str | None = None, + b_tokens: list[str] | None = None, filter_special_tokens: list[str] = [], split_pattern: str = r"\s+|\b", ) -> AlignedSequences: @@ -228,8 +227,8 @@ def auto_align_sequences( a_to_b_word_align = compute_word_aligns(a_words, b_words) # 2. Align word-level alignments to token-level alignments from the generative model tokenizer. # Requires cleaning up the model tokens from special tokens (special characters already removed) - clean_a_tokens, removed_a_token_idxs = clean_tokens(a_tokens, filter_special_tokens) - clean_b_tokens, removed_b_token_idxs = clean_tokens(b_tokens, filter_special_tokens) + clean_a_tokens, removed_a_token_idxs = clean_tokens(a_tokens, filter_special_tokens, return_removed_idxs=True) + clean_b_tokens, removed_b_token_idxs = clean_tokens(b_tokens, filter_special_tokens, return_removed_idxs=True) if len(removed_a_token_idxs) != len(removed_b_token_idxs): logger.debug( "The number of special tokens in the target and contrast sequences do not match. " @@ -246,7 +245,7 @@ def auto_align_sequences( rm_b_idx = removed_b_token_idxs[removed_b_tokens.index(rm_a)] aligned_special_tokens.append((rm_a_idx, rm_b_idx)) else: - aligned_special_tokens = list(zip(removed_a_token_idxs, removed_b_token_idxs)) + aligned_special_tokens = list(zip(removed_a_token_idxs, removed_b_token_idxs, strict=False)) a_word_to_token_align = align_tokenizations(a_words, clean_a_tokens) b_word_to_token_align = align_tokenizations(b_words, clean_b_tokens) # 3. Propagate word-level alignments to token-level alignments. 
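The `strict=False` arguments added to `zip()` here and elsewhere in the patch use the `strict` keyword introduced in Python 3.10: they preserve the historical truncate-at-the-shortest-input behavior while making that choice explicit. A standalone illustration with invented values:

removed_a_token_idxs, removed_b_token_idxs = [0, 5, 9], [0, 5]

# strict=False: silently stops at the shorter input, as plain zip() always did
assert list(zip(removed_a_token_idxs, removed_b_token_idxs, strict=False)) == [(0, 0), (5, 5)]

# strict=True would instead raise on the length mismatch
try:
    list(zip(removed_a_token_idxs, removed_b_token_idxs, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1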
@@ -274,15 +273,15 @@ def auto_align_sequences( def get_adjusted_alignments( - alignments: Union[list[tuple[int, int]], str], - target_sequence: Optional[str] = None, - target_tokens: Optional[list[str]] = None, - contrast_sequence: Optional[str] = None, - contrast_tokens: Optional[list[str]] = None, + alignments: list[tuple[int, int]] | str, + target_sequence: str | None = None, + target_tokens: list[str] | None = None, + contrast_sequence: str | None = None, + contrast_tokens: list[str] | None = None, fill_missing: bool = False, special_tokens: list[str] = [], start_pos: int = 0, - end_pos: Optional[int] = None, + end_pos: int | None = None, ) -> list[tuple[int, int]]: is_auto_aligned = False if fill_missing and not target_tokens: diff --git a/inseq/utils/argparse.py b/inseq/utils/argparse.py index f8e27d98..3ba53f9b 100644 --- a/inseq/utils/argparse.py +++ b/inseq/utils/argparse.py @@ -17,12 +17,12 @@ import sys import types from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError -from collections.abc import Iterable +from collections.abc import Callable, Iterable from copy import copy from enum import Enum from inspect import isclass from pathlib import Path -from typing import Any, Callable, Literal, NewType, Optional, Union, get_type_hints +from typing import Any, Literal, NewType, Union, get_type_hints import yaml @@ -61,7 +61,7 @@ def make_choice_type_function(choices: list) -> Callable[[str], Any]: def cli_arg( *, - aliases: Union[str, list[str]] = None, + aliases: str | list[str] = None, help: str = None, default: Any = dataclasses.MISSING, default_factory: Callable[[], Any] = dataclasses.MISSING, @@ -121,7 +121,7 @@ class InseqArgumentParser(ArgumentParser): dataclass_types: Iterable[DataClassType] - def __init__(self, dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, **kwargs): + def __init__(self, dataclass_types: DataClassType | Iterable[DataClassType] | None = None, **kwargs): """ Args: dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`, *optional*): @@ -192,7 +192,7 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): kwargs["default"] = field.default else: kwargs["required"] = True - elif field.type is bool or field.type == Optional[bool]: + elif field.type is bool or field.type == bool | None: # Copy the currect kwargs to use to instantiate a `no_*` complement argument below. # We do not initialize it here because the `no_*` alternative must be instantiated after the real argument bool_kwargs = copy(kwargs) @@ -232,7 +232,7 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): # Order is important for arguments with the same destination! # We use a copy of earlier kwargs because the original kwargs have changed a lot before reaching down # here and we do not need those changes/additional keys. 
- if field.default is True and (field.type is bool or field.type == Optional[bool]): + if field.default is True and (field.type is bool or field.type == bool | None): bool_kwargs["default"] = False parser.add_argument(f"--no_{field.name}", action="store_false", dest=field.name, **bool_kwargs) @@ -250,19 +250,6 @@ def _add_dataclass_arguments(self, dtype: DataClassType): "removing line of `from __future__ import annotations` which opts in Postponed " "Evaluation of Annotations (PEP 563)" ) from ex - except TypeError as ex: - # Remove this block when we drop Python 3.9 support - if sys.version_info[:2] < (3, 10) and "unsupported operand type(s) for |" in str(ex): - python_version = ".".join(map(str, sys.version_info[:3])) - raise RuntimeError( - f"Type resolution failed for {dtype} on Python {python_version}. Try removing " - "line of `from __future__ import annotations` which opts in union types as " - "`X | Y` (PEP 604) via Postponed Evaluation of Annotations (PEP 563). To " - "support Python versions that lower than 3.10, you need to use " - "`typing.Union[X, Y]` instead of `X | Y` and `typing.Optional[X]` instead of " - "`X | None`." - ) from ex - raise for field in dataclasses.fields(dtype): if not field.init: continue diff --git a/inseq/utils/cache.py b/inseq/utils/cache.py index cab7c39a..638e906d 100644 --- a/inseq/utils/cache.py +++ b/inseq/utils/cache.py @@ -1,9 +1,10 @@ import logging import os import pickle +from collections.abc import Callable from functools import wraps from pathlib import Path -from typing import Any, Callable +from typing import Any logger = logging.getLogger(__name__) diff --git a/inseq/utils/contrast_utils.py b/inseq/utils/contrast_utils.py index 576fef00..94a6bb87 100644 --- a/inseq/utils/contrast_utils.py +++ b/inseq/utils/contrast_utils.py @@ -1,7 +1,8 @@ import logging import warnings +from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING import torch @@ -53,15 +54,15 @@ def docstring_decorator(fn: "StepFunction") -> "StepFunction": @dataclass class ContrastInputs: - batch: Union[EncoderDecoderBatch, DecoderOnlyBatch, None] = None - target_ids: Optional[TargetIdsTensor] = None + batch: EncoderDecoderBatch | DecoderOnlyBatch | None = None + target_ids: TargetIdsTensor | None = None def _get_contrast_inputs( args: "StepFunctionArgs", - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, return_contrastive_target_ids: bool = False, return_contrastive_batch: bool = False, skip_special_tokens: bool = False, @@ -128,9 +129,9 @@ def _get_contrast_inputs( def _setup_contrast_args( args: "StepFunctionArgs", - contrast_sources: Optional[FeatureAttributionInput] = None, - contrast_targets: Optional[FeatureAttributionInput] = None, - contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None, + contrast_sources: FeatureAttributionInput | None = None, + contrast_targets: FeatureAttributionInput | None = None, + contrast_targets_alignments: list[list[tuple[int, int]]] | None = None, contrast_force_inputs: bool = False, skip_special_tokens: bool = False, ): diff --git a/inseq/utils/hooks.py 
b/inseq/utils/hooks.py index 98fd07e5..ec245015 100644 --- a/inseq/utils/hooks.py +++ b/inseq/utils/hooks.py @@ -1,8 +1,8 @@ import re +from collections.abc import Callable from inspect import getsourcelines from sys import gettrace, settrace from types import FrameType -from typing import Callable, Optional from torch import nn @@ -13,7 +13,7 @@ def get_last_variable_assignment_position( module: nn.Module, varname: str, fname: str = "forward", -) -> Optional[int]: +) -> int | None: """Extract the code line number of the last variable assignment for a variable of interest in the specified method of a `nn.Module` object. diff --git a/inseq/utils/misc.py b/inseq/utils/misc.py index 2f0c178e..5843f074 100644 --- a/inseq/utils/misc.py +++ b/inseq/utils/misc.py @@ -5,14 +5,14 @@ import math from base64 import standard_b64decode, standard_b64encode from collections import OrderedDict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import contextmanager from functools import wraps from importlib import import_module from inspect import signature from numbers import Number from os import PathLike, fsync -from typing import Any, Callable, Optional, Union +from typing import Any from numpy import asarray, frombuffer from torch import Tensor @@ -44,7 +44,7 @@ def _pretty_list_contents(l: Sequence[Any]) -> str: ) -def _pretty_list(l: Optional[Sequence[Any]], lpad: int = 8) -> str: +def _pretty_list(l: Sequence[Any] | None, lpad: int = 8) -> str: if all(isinstance(x, list) for x in l): line_sep = f" ],\n{' ' * lpad}[ " contents = " " * lpad + "[ " + line_sep.join([_pretty_list_contents(subl) for subl in l]) + " ]" @@ -57,7 +57,7 @@ def _pretty_list(l: Optional[Sequence[Any]], lpad: int = 8) -> str: return "[\n" + contents + f"\n{' ' * (lpad - 4)}]" -def pretty_list(l: Optional[Sequence[Any]], lpad: int = 8) -> str: +def pretty_list(l: Sequence[Any] | None, lpad: int = 8) -> str: if l is None: return "None" if len(l) == 0: @@ -74,7 +74,7 @@ def pretty_list(l: Optional[Sequence[Any]], lpad: int = 8) -> str: return f"{out_txt}: {_pretty_list(l, lpad)}" -def pretty_tensor(t: Optional[Tensor] = None, lpad: int = 8) -> str: +def pretty_tensor(t: Tensor | None = None, lpad: int = 8) -> str: if t is None: return "None" if t.ndim > 2 or any(x > 20 for x in t.shape): @@ -91,7 +91,7 @@ def pretty_dict(d: dict[str, Any], lpad: int = 4) -> str: out_txt = "{\n" for k, v in d.items(): out_txt += f"{' ' * lpad}{k}: " - if isinstance(v, (list, tuple)): + if isinstance(v, list | tuple): out_txt += pretty_list(v, lpad + 4) elif isinstance(v, Tensor): out_txt += pretty_tensor(v, lpad + 4) @@ -112,9 +112,9 @@ def pretty_dict(d: dict[str, Any], lpad: int = 4) -> str: def extract_signature_args( full_args: dict[str, Any], func: Callable[[Any], Any], - exclude_args: Optional[Sequence[str]] = None, + exclude_args: Sequence[str] | None = None, return_remaining: bool = False, -) -> Union[dict[str, Any], tuple[dict[str, Any], dict[str, Any]]]: +) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any]]: extracted_args = { k: v for k, v in full_args.items() @@ -201,7 +201,7 @@ def isnotebook(): def format_input_texts( texts: TextInput, - ref_texts: Optional[TextInput] = None, + ref_texts: TextInput | None = None, skip_special_tokens: bool = False, special_tokens: list[str] = [], ) -> tuple[list[str], list[str]]: @@ -246,7 +246,7 @@ def aggregate_token_pair(tokens: list[TokenWithId], other_tokens: list[TokenWith if not other_tokens: return tokens out_tokens = [] - for tok, 
other in zip(tokens, other_tokens): + for tok, other in zip(tokens, other_tokens, strict=False): if tok.token == other.token: out_tokens.append(TokenWithId(tok.token, tok.id)) else: @@ -310,13 +310,13 @@ def save_to_file(f: Callable[[Any], Any]) -> Callable[[Any], Any]: @wraps(f) def save_to_file_wrapper( obj: Any, - fp: Union[str, bytes, PathLike] = None, + fp: str | bytes | PathLike = None, *args, - compression: Union[int, bool] = None, + compression: int | bool = None, force_flush: bool = False, return_output: bool = True, **kwargs, - ) -> Optional[Any]: + ) -> Any | None: if "compression" in signature(f).parameters: kwargs["compression"] = compression txt = f(obj, *args, **kwargs) @@ -431,16 +431,27 @@ def get_cls_from_instance_type(mod, name, cls_lookup_map): return curr_class -def clean_tokens(tokens: list[str], remove_tokens: list[str]) -> tuple[list[str], list[int]]: +def clean_tokens( + tokens: list[str], + remove_tokens: list[str] = [], + return_removed_idxs: bool = False, + replace_chars: dict[str, str] | None = None, +) -> list[str] | tuple[list[str], list[int]]: """Removes tokens from a list of tokens and returns the cleaned list and the removed token indexes.""" clean_tokens = [] removed_token_idxs = [] for idx, tok in enumerate(tokens): - if tok not in remove_tokens: - clean_tokens += [tok.strip()] - else: + new_tok = tok + if new_tok in remove_tokens: removed_token_idxs += [idx] - return clean_tokens, removed_token_idxs + else: + if replace_chars is not None: + for k, v in replace_chars.items(): + new_tok = new_tok.replace(k, v) + clean_tokens += [new_tok.strip()] + if return_removed_idxs: + return clean_tokens, removed_token_idxs + return clean_tokens def get_left_padding(text: str): diff --git a/inseq/utils/serialization.py b/inseq/utils/serialization.py index f7966191..e220986d 100644 --- a/inseq/utils/serialization.py +++ b/inseq/utils/serialization.py @@ -32,9 +32,10 @@ import base64 import json from collections import OrderedDict +from collections.abc import Callable from json import JSONEncoder from os import PathLike -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, TypeVar from numpy import generic, ndarray @@ -58,7 +59,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k """Encodes a class instance to json. Note that it can only be recovered if the environment allows the class to be imported in the same way. 
""" - if isinstance(obj, (list, dict)): + if isinstance(obj, list | dict): return obj if isinstance(obj, bytes): return base64.b64encode(obj).decode("UTF8") @@ -94,7 +95,7 @@ def class_instance_encode(obj: EncodableObject, use_primitives: bool = True, **k def ndarray_encode( obj: EncodableObject, use_primitives: bool = True, - ndarray_compact: Optional[bool] = None, + ndarray_compact: bool | None = None, compression: bool = False, **kwargs, ) -> dict[str, Any]: @@ -136,9 +137,9 @@ def ndarray_encode( class AttributionSerializer(JSONEncoder): def __init__( self, - encoders: Optional[list[Callable]] = None, + encoders: list[Callable] | None = None, use_primitives: bool = True, - ndarray_compact: Optional[bool] = None, + ndarray_compact: bool | None = None, compression: bool = False, **json_kwargs, ): @@ -178,7 +179,7 @@ def json_advanced_dumps( encoders: list[Callable] = ENCODE_HOOKS, use_primitives: bool = True, allow_nan: bool = True, - ndarray_compact: Optional[bool] = None, + ndarray_compact: bool | None = None, compression: bool = False, **jsonkwargs, ) -> str: @@ -225,7 +226,7 @@ def json_advanced_dump( encoders: list[Callable] = ENCODE_HOOKS, use_primitives: bool = False, allow_nan: bool = True, - ndarray_compact: Optional[bool] = None, + ndarray_compact: bool | None = None, compression: bool = False, **jsonkwargs, ) -> str: @@ -302,7 +303,7 @@ def ndarray_hook(dct: Any, **kwargs) -> DecodableObject: return scalar_to_numpy(data_json, nptype) -def class_instance_hook(dct: Any, cls_lookup_map: Optional[dict[str, type]] = None, **kwargs) -> DecodableObject: +def class_instance_hook(dct: Any, cls_lookup_map: dict[str, type] | None = None, **kwargs) -> DecodableObject: """This hook tries to convert json encoded by class_instance_encoder back to it's original instance. It only works if the environment is the same, e.g. the class is similarly importable and hasn't changed. @@ -344,8 +345,8 @@ class AttributionDeserializer: def __init__( self, ordered: bool = False, - hooks: Optional[list[Callable]] = None, - cls_lookup_map: Optional[dict[str, type]] = None, + hooks: list[Callable] | None = None, + cls_lookup_map: dict[str, type] | None = None, ): self.map_type = OrderedDict if ordered else dict self.hooks = hooks if hooks else [] @@ -367,7 +368,7 @@ def json_advanced_loads( ordered: bool = False, decompression: bool = False, hooks: list[Callable] = DECODE_HOOKS, - cls_lookup_map: Optional[dict[str, type]] = None, + cls_lookup_map: dict[str, type] | None = None, **jsonkwargs, ) -> Any: """Load a complex object containing classes and arrays object from a string. @@ -403,11 +404,11 @@ def json_advanced_loads( def json_advanced_load( - fp: Union[str, bytes, PathLike], + fp: str | bytes | PathLike, ordered: bool = True, decompression: bool = False, hooks: list[Callable] = DECODE_HOOKS, - cls_lookup_map: Optional[dict[str, type]] = None, + cls_lookup_map: dict[str, type] | None = None, **jsonkwargs, ) -> Any: """Load a complex object containing classes and arrays from a JSON file. @@ -431,7 +432,7 @@ def json_advanced_load( The loaded object. 
""" try: - if isinstance(fp, (PathLike, bytes, str)): + if isinstance(fp, PathLike | bytes | str): with open(fp, "rb" if decompression else "r") as fh: string = fh.read() else: diff --git a/inseq/utils/torch_utils.py b/inseq/utils/torch_utils.py index d3c13b90..ba0e431f 100644 --- a/inseq/utils/torch_utils.py +++ b/inseq/utils/torch_utils.py @@ -1,8 +1,8 @@ import logging -from collections.abc import Sequence +from collections.abc import Callable, Sequence from functools import wraps from inspect import signature -from typing import TYPE_CHECKING, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal import safetensors import torch @@ -76,8 +76,8 @@ def convert_from_safetensor(safetensor: bytes) -> torch.Tensor: def postprocess_attribution_scores(func: Callable) -> Callable: @wraps(func) def postprocess_scores_wrapper( - attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]], dim: int = 0, *args, **kwargs - ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + attributions: torch.Tensor | tuple[torch.Tensor, ...], dim: int = 0, *args, **kwargs + ) -> torch.Tensor | tuple[torch.Tensor, ...]: multi_input = False if isinstance(attributions, tuple): orig_sizes = [a.shape[dim] for a in attributions] @@ -98,17 +98,17 @@ def postprocess_scores_wrapper( @postprocess_attribution_scores def normalize( - attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]], + attributions: torch.Tensor | tuple[torch.Tensor, ...], dim: int = 0, norm_ord: int = 1, -) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: +) -> torch.Tensor | tuple[torch.Tensor, ...]: return F.normalize(attributions, p=norm_ord, dim=dim) @postprocess_attribution_scores def rescale( - attributions: Union[torch.Tensor, tuple[torch.Tensor, ...]], -) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + attributions: torch.Tensor | tuple[torch.Tensor, ...], +) -> torch.Tensor | tuple[torch.Tensor, ...]: return attributions / attributions.abs().max() @@ -140,9 +140,9 @@ def top_k_logits_mask(logits: torch.Tensor, top_k: int, min_tokens_to_keep: int) def get_logits_from_filter_strategy( - filter_strategy: Union[Literal["original"], Literal["contrast"], Literal["merged"]], + filter_strategy: Literal["original"] | Literal["contrast"] | Literal["merged"], original_logits: torch.Tensor, - contrast_logits: Optional[torch.Tensor] = None, + contrast_logits: torch.Tensor | None = None, ) -> torch.Tensor: if filter_strategy == "original": return original_logits @@ -154,12 +154,12 @@ def get_logits_from_filter_strategy( def filter_logits( original_logits: torch.Tensor, - contrast_logits: Optional[torch.Tensor] = None, + contrast_logits: torch.Tensor | None = None, top_p: float = 1.0, top_k: int = 0, min_tokens_to_keep: int = 1, - filter_strategy: Union[Literal["original"], Literal["contrast"], Literal["merged"], None] = None, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + filter_strategy: Literal["original"] | Literal["contrast"] | Literal["merged"] | None = None, +) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """Applies top-k and top-p filtering to logits, and optionally to an additional set of contrastive logits.""" if top_k > original_logits.size(-1) or top_k < 0: raise ValueError(f"`top_k` has to be a positive integer < {original_logits.size(-1)}, but is {top_k}") @@ -201,7 +201,7 @@ def euclidean_distance(vec_a: torch.Tensor, vec_b: torch.Tensor) -> torch.Tensor def aggregate_contiguous( t: torch.Tensor, spans: Sequence[tuple[int, int]], - aggregate_fn: Optional[Callable] = None, + 
aggregate_fn: Callable | None = None, aggregate_dim: int = 0, ): """Given a tensor, aggregate contiguous spans of the tensor along a given dimension using the provided @@ -269,7 +269,7 @@ def get_sequences_from_batched_steps( padding_dims = set(padding_dims) max_dims = tuple(max([bstep.shape[dim] for bstep in bsteps]) for dim in padding_dims) for bstep_idx, bstep in enumerate(bsteps): - for curr_dim, max_dim in zip(padding_dims, max_dims): + for curr_dim, max_dim in zip(padding_dims, max_dims, strict=False): bstep_dim = bstep.shape[curr_dim] if bstep_dim < max_dim: # Pad the end of curr_dim with nans @@ -331,7 +331,7 @@ def find_block_stack(module): def validate_indices( scores: torch.Tensor, dim: int = -1, - indices: Optional[OneOrMoreIndices] = None, + indices: OneOrMoreIndices | None = None, ) -> OneOrMoreIndices: """Validates a set of indices for a given dimension of a tensor of scores. Supports single indices, spans and lists of indices, including negative indices to specify positions relative to the end of the tensor. @@ -350,7 +350,7 @@ def validate_indices( if dim >= scores.ndim: raise IndexError(f"Dimension {dim} is greater than tensor dimension {scores.ndim}") n_units = scores.shape[dim] - if not isinstance(indices, (int, tuple, list)) and indices is not None: + if not isinstance(indices, int | tuple | list) and indices is not None: raise TypeError( "Indices must be an integer, a (start, end) tuple of indices representing a span, a list of individual" " indices or a single index." @@ -405,7 +405,7 @@ def pad_with_nan(t: torch.Tensor, dim: int, pad_size: int, front: bool = False) return torch.cat([t, nan_tensor], dim=dim) -def recursive_get_submodule(parent: nn.Module, target: str) -> Optional[nn.Module]: +def recursive_get_submodule(parent: nn.Module, target: str) -> nn.Module | None: if target == "": return parent mod = None diff --git a/inseq/utils/typing.py b/inseq/utils/typing.py index 4673c079..8c48f9bf 100644 --- a/inseq/utils/typing.py +++ b/inseq/utils/typing.py @@ -1,13 +1,13 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Literal, Optional, Union +from typing import TYPE_CHECKING, Literal import torch from captum.attr._utils.attribution import Attribution from jaxtyping import Float, Float32, Int64 from transformers import PreTrainedModel -TextInput = Union[str, Sequence[str]] +TextInput = str | Sequence[str] if TYPE_CHECKING: from inseq.models import AttributionModel @@ -21,7 +21,7 @@ class TokenWithId: def __str__(self): return self.token - def __eq__(self, other: Union[str, int, "TokenWithId"]): + def __eq__(self, other: "str | int | TokenWithId"): if isinstance(other, str): return self.token == other elif isinstance(other, int): @@ -63,7 +63,7 @@ def get_name(cls: type["InseqAttribution"]) -> str: @dataclass class TextSequences: targets: TextInput - sources: Optional[TextInput] = None + sources: TextInput | None = None OneOrMoreIdSequences = Sequence[Sequence[int]] @@ -73,8 +73,8 @@ class TextSequences: ScorePrecision = Literal["float32", "float16", "float8"] -IndexSpan = Union[tuple[int, int], Sequence[tuple[int, int]]] -OneOrMoreIndices = Union[int, list[int], tuple[int, int]] +IndexSpan = tuple[int, int] | Sequence[tuple[int, int]] +OneOrMoreIndices = int | list[int] | tuple[int, int] OneOrMoreIndicesDict = dict[int, OneOrMoreIndices] IdsTensor = Int64[torch.Tensor, "batch_size seq_len"] @@ -107,7 +107,7 @@ class TextSequences: # or produced by 
methods that work at token-level (e.g. attention) TokenStepAttributionTensor = MultipleScoresPerStepTensor -StepAttributionTensor = Union[GranularStepAttributionTensor, TokenStepAttributionTensor] +StepAttributionTensor = GranularStepAttributionTensor | TokenStepAttributionTensor # One attribution score per embedding value for every attributed token in attributed_seq # for all generated tokens in generated_seq. Produced by aggregating GranularStepAttributionTensor @@ -119,7 +119,7 @@ class TextSequences: # or by aggregating TokenStepAttributionTensor across multiple steps and separating batches. TokenSequenceAttributionTensor = MultipleScoresPerSequenceTensor -SequenceAttributionTensor = Union[GranularSequenceAttributionTensor, TokenSequenceAttributionTensor] +SequenceAttributionTensor = GranularSequenceAttributionTensor | TokenSequenceAttributionTensor # For Huggingface it's a string identifier e.g. "t5-base", "Helsinki-NLP/opus-mt-en-it" # For Fairseq it's a tuple of strings containing repo and model name @@ -127,9 +127,6 @@ class TextSequences: ModelIdentifier = str # Union[str, Tuple[str, str]] ModelClass = PreTrainedModel -AttributionForwardInputs = Union[IdsTensor, EmbeddingsTensor] -AttributionForwardInputsPair = Union[ - tuple[IdsTensor, IdsTensor], - tuple[EmbeddingsTensor, EmbeddingsTensor], -] -OneOrTwoAttributionForwardInputs = Union[AttributionForwardInputs, AttributionForwardInputsPair] +AttributionForwardInputs = IdsTensor | EmbeddingsTensor +AttributionForwardInputsPair = tuple[IdsTensor, IdsTensor] | tuple[EmbeddingsTensor, EmbeddingsTensor] +OneOrTwoAttributionForwardInputs = AttributionForwardInputs | AttributionForwardInputsPair diff --git a/inseq/utils/viz_utils.py b/inseq/utils/viz_utils.py index 69e4a94f..7238b9a6 100644 --- a/inseq/utils/viz_utils.py +++ b/inseq/utils/viz_utils.py @@ -16,34 +16,71 @@ # WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE # OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -from typing import Union + +from collections.abc import Callable +from functools import wraps +from typing import Any, Literal import matplotlib.pyplot as plt import numpy as np +import treescope as ts from matplotlib.colors import Colormap, LinearSegmentedColormap from numpy.typing import NDArray -from .misc import ordinal_str +from .misc import isnotebook, ordinal_str from .typing import TokenWithId +red = (178, 24, 43) +beige = (247, 252, 253) +blue = (33, 102, 172) +green = (0, 109, 44) +brown = (140, 81, 10) + def get_instance_html(i: int): return "
" + ordinal_str(i) + " instance:
" +def interpolate_color(color1, color2, t): + return tuple(int(c1 + (c2 - c1) * t) for c1, c2 in zip(color1, color2, strict=False)) + + +def generate_colormap(start_color, end_color, num_colors): + return [interpolate_color(start_color, end_color, t) for t in np.linspace(0, 1, num_colors)] + + def red_transparent_blue_colormap(): colors = [] for l in np.linspace(1, 0, 100): - colors.append((30.0 / 255, 136.0 / 255, 229.0 / 255, l)) + colors.append((*(float(c) / 255 for c in blue), l)) for l in np.linspace(0, 1, 100): - colors.append((255.0 / 255, 13.0 / 255, 87.0 / 255, l)) + colors.append((*(float(c) / 255 for c in red), l)) return LinearSegmentedColormap.from_list("red_transparent_blue", colors) +def treescope_cmap(colors: Literal["blue_to_red", "brown_to_green", "greens", "blues"] = "blue_to_red", n: int = 200): + match colors: + case "blue_to_red": + first_half = generate_colormap(blue, beige, n // 2) + second_half = generate_colormap(beige, red, n - len(first_half)) + cmap = first_half + second_half + case "brown_to_green": + first_half = generate_colormap(brown, beige, n // 2) + second_half = generate_colormap(beige, green, n - len(first_half)) + cmap = first_half + second_half + case "greens": + cmap = generate_colormap(beige, green, n) + case "blues": + cmap = generate_colormap(beige, blue, n) + case _: + raise ValueError(f"Invalid color scheme {colors}: valid options are 'blue_to_red', 'greens', 'blues'") + return cmap + + def get_color( score: float, - min_value: Union[float, int], - max_value: Union[float, int], + min_value: float | int, + max_value: float | int, cmap: Colormap, return_alpha: bool = True, return_string: bool = True, @@ -62,7 +99,7 @@ def get_color( return color -def sanitize_html(txt: Union[str, TokenWithId]) -> str: +def sanitize_html(txt: str | TokenWithId) -> str: if isinstance(txt, TokenWithId): txt = txt.token return txt.replace("<", "<").replace(">", ">") @@ -70,9 +107,9 @@ def sanitize_html(txt: Union[str, TokenWithId]) -> str: def get_colors( scores: NDArray, - min_value: Union[float, int], - max_value: Union[float, int], - cmap: Union[str, Colormap, None] = None, + min_value: float | int, + max_value: float | int, + cmap: str | Colormap | None = None, return_alpha: bool = True, return_strings: bool = True, ): @@ -90,6 +127,46 @@ def get_colors( return input_colors +def test_dim(dim: int | str, dim_names: dict[int, str], rev_dim_names: dict[str, int], scores: np.ndarray) -> int: + if isinstance(dim, str): + if dim not in rev_dim_names: + raise ValueError(f"Invalid dimension name {dim}: valid names are {list(rev_dim_names.keys())}") + dim_idx = rev_dim_names[dim] + else: + dim_idx = dim + if dim_idx <= 1 or dim_idx > scores.ndim or dim_idx not in dim_names: + raise ValueError(f"Invalid dimension {dim_idx}: valid indices are {list(range(2, scores.ndim))}") + return dim_idx + + +def maybe_add_linebreak(tok: str, i: int, wrap_after: int | str | list[str] | tuple[str]) -> list[str]: + if isinstance(wrap_after, str) and tok == wrap_after: + return [ts.rendering_parts.text("\n")] + elif isinstance(wrap_after, list | tuple) and tok in wrap_after: + return [ts.rendering_parts.text("\n")] + elif isinstance(wrap_after, int) and i % wrap_after == 0: + return [ts.rendering_parts.text("\n")] + else: + return [] + + +def treescope_ignore(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def treescope_unhooked_wrapper(self, *args, **kwargs): + if isnotebook(): + # Unhook the treescope visualization to allow `rich.jupyter.JupyterRenderable` to render correctly 
+ import IPython + + del IPython.get_ipython().display_formatter.formatters["text/html"].type_printers[object] + out = f(self, *args, **kwargs) + if isnotebook(): + # Re-hook the treescope visualization + ts.register_as_default() + return out + + return treescope_unhooked_wrapper + + # Full plot final_plot_html = """ diff --git a/pyproject.toml b/pyproject.toml index fcfd8af0..363c3c07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "inseq" version = "0.7.0.dev0" description = "Interpretability for Sequence Generation Models 🔍" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license = {file = "LICENSE"} keywords = ["generative AI", "transformers", "natural language processing", "XAI", "explainable ai", "interpretability", "feature attribution", "machine translation"] authors = [ @@ -31,7 +31,6 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -50,6 +49,7 @@ dependencies = [ "torch>=2.0", "matplotlib>=3.5.3", "tqdm>=4.64.0", + "treescope>=0.1.0", "nvidia-cublas-cu11>=11.10.3.66; sys_platform=='Linux'", "nvidia-cuda-cupti-cu11>=11.7.101; sys_platform=='Linux'", "nvidia-cuda-nvrtc-cu11>=11.7.99; sys_platform=='Linux'", @@ -110,7 +110,7 @@ changelog = "https://github.com/inseq-team/inseq/blob/main/CHANGELOG.md" [tool.mypy] # https://mypy.readthedocs.io/en/latest/config_file.html#using-a-pyproject-toml-file -python_version = "3.9" +python_version = "3.10" strict = true @@ -164,7 +164,7 @@ packages = ["inseq"] [tool.ruff] -target-version = "py39" +target-version = "py310" exclude = [ ".git", ".vscode", diff --git a/requirements-dev.txt b/requirements-dev.txt index 4374e00b..40a3eb2c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,7 +24,7 @@ beautifulsoup4==4.12.3 # via furo captum==0.7.0 # via inseq (pyproject.toml) -certifi==2024.6.2 +certifi==2024.7.4 # via requests cffi==1.16.0 # via cryptography @@ -190,6 +190,7 @@ numpy==1.26.4 # scikit-learn # scipy # transformers + # treescope packaging==23.2 # via # datasets @@ -379,7 +380,7 @@ torch==2.3.1 # via # inseq (pyproject.toml) # captum -tornado==6.4 +tornado==6.4.1 # via # ipykernel # jupyter-client @@ -402,6 +403,8 @@ traitlets==5.14.1 # matplotlib-inline transformers==4.38.1 # via inseq (pyproject.toml) +treescope==0.1.0 + # via inseq (pyproject.toml) typeguard==2.13.3 # via # inseq (pyproject.toml) diff --git a/requirements.txt b/requirements.txt index 27e3e89c..34fa69bf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ # uv pip compile pyproject.toml -o requirements.txt captum==0.7.0 # via inseq (pyproject.toml) -certifi==2024.6.2 +certifi==2024.7.4 # via requests charset-normalizer==3.3.2 # via requests @@ -55,6 +55,7 @@ numpy==1.26.4 # jaxtyping # matplotlib # transformers + # treescope packaging==23.2 # via # huggingface-hub @@ -106,6 +107,8 @@ tqdm==4.66.4 # transformers transformers==4.38.1 # via inseq (pyproject.toml) +treescope==0.1.0 + # via inseq (pyproject.toml) typeguard==2.13.3 # via # inseq (pyproject.toml) diff --git a/tests/attr/feat/test_attribution_utils.py b/tests/attr/feat/test_attribution_utils.py index 956ae637..f13c0419 100644 --- a/tests/attr/feat/test_attribution_utils.py +++ b/tests/attr/feat/test_attribution_utils.py @@ -29,7 +29,7 @@ def 
test_get_step_prediction_probabilities(m2m100_model, encoder_decoder_batches ] # fmt: on for i, (batch, next_batch) in enumerate( - zip(encoder_decoder_batches["batches"][1:], encoder_decoder_batches["batches"][2:]) + zip(encoder_decoder_batches["batches"][1:], encoder_decoder_batches["batches"][2:], strict=False) ): output = m2m100_model.get_forward_output( batch.to(m2m100_model.device), use_embeddings=m2m100_model.attribution_method.forward_batch_embeds @@ -43,7 +43,9 @@ def test_get_step_prediction_probabilities(m2m100_model, encoder_decoder_batches def test_crossentropy_nlll_equivalence(m2m100_model, encoder_decoder_batches): - for batch, next_batch in zip(encoder_decoder_batches["batches"][1:], encoder_decoder_batches["batches"][2:]): + for batch, next_batch in zip( + encoder_decoder_batches["batches"][1:], encoder_decoder_batches["batches"][2:], strict=False + ): batch.to(m2m100_model.device) next_batch.to(m2m100_model.device) output = m2m100_model.model( diff --git a/tests/attr/feat/test_feature_attribution.py b/tests/attr/feat/test_feature_attribution.py index 80856176..78473c9b 100644 --- a/tests/attr/feat/test_feature_attribution.py +++ b/tests/attr/feat/test_feature_attribution.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any import torch from captum._utils.typing import TensorOrTupleOfTensorsGeneric @@ -147,9 +147,9 @@ def attribute( self, inputs: TensorOrTupleOfTensorsGeneric, additional_forward_args: TensorOrTupleOfTensorsGeneric, - encoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, - decoder_self_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, - cross_attentions: Optional[MultiLayerMultiUnitScoreTensor] = None, + encoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None, + decoder_self_attentions: MultiLayerMultiUnitScoreTensor | None = None, + cross_attentions: MultiLayerMultiUnitScoreTensor | None = None, ) -> MultiDimensionalFeatureAttributionStepOutput: # We adopt the format [batch_size, sequence_length, num_layers, num_heads] # for consistency with other multi-unit methods (e.g. 
gradient attribution) diff --git a/tests/attr/feat/test_step_functions.py b/tests/attr/feat/test_step_functions.py index e9fe78be..916b8ea1 100644 --- a/tests/attr/feat/test_step_functions.py +++ b/tests/attr/feat/test_step_functions.py @@ -32,7 +32,7 @@ def test_contrast_prob_consistency_decoder(saliency_gpt2: DecoderOnlyAttribution step_scores=["probability"], ) regular_prob = out_regular.sequence_attributions[0].step_scores["probability"] - assert all(c == r for c, r in zip(contrast_prob, regular_prob)) + assert all(c == r for c, r in zip(contrast_prob, regular_prob, strict=False)) def test_contrast_prob_consistency_enc_dec(saliency_mt_model: EncoderDecoderAttributionModel): @@ -52,7 +52,7 @@ def test_contrast_prob_consistency_enc_dec(saliency_mt_model: EncoderDecoderAttr step_scores=["probability"], ) regular_prob = out_regular.sequence_attributions[0].step_scores["probability"] - assert all(c == r for c, r in zip(contrast_prob, regular_prob[-len(contrast_prob) :])) + assert all(c == r for c, r in zip(contrast_prob, regular_prob[-len(contrast_prob) :], strict=False)) def attr_prob_diff_fn( diff --git a/tests/models/test_huggingface_model.py b/tests/models/test_huggingface_model.py index 72da4a2f..2d2373a3 100644 --- a/tests/models/test_huggingface_model.py +++ b/tests/models/test_huggingface_model.py @@ -101,11 +101,11 @@ def test_cuda_attribution_consistency_seq2seq(texts, reference_texts, attribute_ ) assert isinstance(out[device], FeatureAttributionOutput) assert isinstance(out[device].sequence_attributions[0], FeatureAttributionSequenceOutput) - for out_cpu, out_gpu in zip(out["cpu"].sequence_attributions, out["cuda"].sequence_attributions): - assert all(tok_cpu == tok_gpu for tok_cpu, tok_gpu in zip(out_cpu.target, out_gpu.target)) + for out_cpu, out_gpu in zip(out["cpu"].sequence_attributions, out["cuda"].sequence_attributions, strict=False): + assert all(tok_cpu == tok_gpu for tok_cpu, tok_gpu in zip(out_cpu.target, out_gpu.target, strict=False)) attr_score_matches = [ torch.allclose(cpu_attr, gpu_attr, atol=1e-3) - for cpu_attr, gpu_attr in zip(out_cpu.source_attributions, out_gpu.source_attributions) + for cpu_attr, gpu_attr in zip(out_cpu.source_attributions, out_gpu.source_attributions, strict=False) ] assert all(attr_score_matches)
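Closing with a short sketch of the new colormap helper added in inseq/utils/viz_utils.py above; the import path is taken from this patch, and the expected values follow directly from the `interpolate_color`/`generate_colormap` definitions:

from inseq.utils.viz_utils import treescope_cmap

cmap = treescope_cmap("brown_to_green", n=10)
assert len(cmap) == 10                     # n (R, G, B) tuples in total
assert all(len(rgb) == 3 for rgb in cmap)
assert cmap[0] == (140, 81, 10)            # starts at brown ...
assert cmap[-1] == (0, 109, 44)            # ... and ends at green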