Fix MultiDimensional attribution aggregation #292

Merged · 3 commits · Nov 10, 2024
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -55,6 +55,17 @@ out_sliced = out.aggregate("slices", target_spans=(13,73))
out_sliced = out[13:73]
```

- A new `StringSplitAggregator` (`"split"`) is added to support aggregation procedures more complex than simple subword merging in `FeatureAttributionSequenceOutput` objects. Specifically, splitting supports regular expressions to match split points, even when these are (potentially overlapping) parts of existing tokens. The `split_mode` parameter can be set to `"single"` (default) to keep tokens containing matched split points separate while aggregating the rest, or to `"start"`/`"end"` to concatenate them to the preceding/following aggregated token sequence. [#290](https://github.com/inseq-team/inseq/pull/290)

```python
# Split on newlines. Default split_mode = "single".
out.aggregate("split", split_pattern="\n").aggregate("sum").show(do_aggregation=False)

# Split on whitespace-separated words of length 5.
# Note: this works if clean_special_chars = True is used; otherwise,
# split_pattern should be adjusted to split on special characters like "Ġ" or "▁".
out.aggregate("split", split_pattern=r"\s(\w{5})(?=\s)", split_mode="end")
```

- The `__sub__` method in `FeatureAttributionSequenceOutput` is now used as a shortcut for `PairAggregator` ([#282](https://github.com/inseq-team/inseq/pull/282)).
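
  A minimal usage sketch, assuming `out_a` and `out_b` are two comparable `FeatureAttributionSequenceOutput` objects over identical token sequences (the `"pair"` aggregator name and the `paired_attr` parameter are assumed here):

  ```python
  # Contrastive comparison via the new shortcut:
  diff = out_a - out_b

  # Assumed equivalent explicit form using the registered "pair" aggregator:
  diff = out_a.aggregate("pair", paired_attr=out_b)
  diff.show()
  ```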


@@ -90,6 +101,8 @@ out_female = attrib_model.attribute(

- Fix support for multi-EOS tokens (e.g. LLaMA 3.2, see [#287](https://github.com/inseq-team/inseq/issues/287)).

- Fix copying configuration parameters to aggregated `FeatureAttributionSequenceOutput` objects ([#292](https://github.com/inseq-team/inseq/pull/292)).
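
  A sketch of the fixed behavior, assuming `out` is a subword-tokenized `FeatureAttributionSequenceOutput` (the `"subwords"` aggregator name is used for illustration):

  ```python
  # Private configuration fields (e.g. _num_dimensions) were previously dropped
  # when constructing the aggregated object; they are now copied over:
  agg = out.aggregate("subwords")
  assert agg.config.keys() == out.config.keys()
  ```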

## 📝 Documentation and Tutorials

- Updated tutorial with `treescope` usage examples.
8 changes: 7 additions & 1 deletion inseq/data/aggregator.py
@@ -302,7 +302,7 @@ def _aggregate(
                kwargs["aggregate_fn"] = kwargs["aggregate_fn"][cls.aggregator_family]
            field_func = getattr(cls, f"aggregate_{field}")
            aggregated_sequence_attribution_fields[field] = field_func(attr, **kwargs)
-        return attr.__class__(**aggregated_sequence_attribution_fields)
+        return attr.__class__(**aggregated_sequence_attribution_fields, **attr.config)

@classmethod
def _process_attribution_scores(
@@ -346,6 +346,12 @@ def _process_attribution_scores(
    @classmethod
    def post_aggregate_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs):
        super().post_aggregate_hook(attr, **kwargs)
+        if attr.source_attributions is not None:
+            attr._num_dimensions = attr.source_attributions.ndim
+        elif attr.target_attributions is not None:
+            attr._num_dimensions = attr.target_attributions.ndim
+        else:
+            attr._num_dimensions = 0
        cls.is_compatible(attr)

@classmethod
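
The hook above recomputes `_num_dimensions` from the aggregated tensors rather than relying on a fixed class default. A standalone sketch of why that matters, with illustrative shapes not taken from the library:

```python
import torch

# Granular attributions can carry extra axes beyond (source, target),
# e.g. one per layer and one per attention head:
granular = torch.rand(5, 7, 12, 16)  # src x tgt x layers x heads

# Aggregating over the extra axes changes the tensor's dimensionality:
aggregated = granular.mean(dim=(2, 3))

print(granular.ndim)    # 4
print(aggregated.ndim)  # 2 -- reading .ndim keeps the metadata in sync,
                        # where a hardcoded value would go stale
```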
13 changes: 10 additions & 3 deletions inseq/data/attribution.py
@@ -167,6 +167,7 @@ class FeatureAttributionSequenceOutput(TensorWrapper, AggregableMixin):
    _aggregator: str | list[str] | None = None
    _dict_aggregate_fn: dict[str, str] | None = None
    _attribution_dim_names: dict[str, dict[int, str]] | None = None
+    _num_dimensions: int | None = None

    def __post_init__(self):
        if self._dict_aggregate_fn is None:
@@ -181,6 +182,8 @@ def __post_init__(self):
            self._attribution_dim_names = default_dim_names
        if self._aggregator is None:
            self._aggregator = "scores"
+        if self._num_dimensions is None:
+            self._num_dimensions = 0
        if self.attr_pos_end is None or self.attr_pos_end > len(self.target):
            self.attr_pos_end = len(self.target)

@@ -309,6 +312,10 @@ def _recover_from_safetensors(self):
        }
        return self
+
+    @property
+    def config(self) -> dict[str, Any]:
+        return {k: v for k, v in self.__dict__.items() if k.startswith("_")}

    @staticmethod
    def get_remove_pad_fn(attr: "FeatureAttributionStepOutput", name: str) -> Callable:
        if attr.source_attributions is None or name.startswith("decoder"):
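
The new `config` property gathers every underscore-prefixed field so that it can be splatted back into the constructor, as done with `**attr.config` in `aggregator.py` above. A self-contained sketch of the pattern, with hypothetical field names:

```python
class Example:
    def __init__(self):
        self._aggregator = "scores"  # hypothetical private config fields
        self._num_dimensions = 2
        self.scores = [0.1, 0.2]     # public data, excluded from config

    @property
    def config(self) -> dict:
        # Same pattern as the new property: keep underscore-prefixed entries
        return {k: v for k, v in self.__dict__.items() if k.startswith("_")}


ex = Example()
print(ex.config)  # {'_aggregator': 'scores', '_num_dimensions': 2}
```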
@@ -696,10 +703,13 @@ class FeatureAttributionStepOutput(TensorWrapper):
    source: OneOrMoreTokenWithIdSequences | None = None
    prefix: OneOrMoreTokenWithIdSequences | None = None
    target: OneOrMoreTokenWithIdSequences | None = None
+    _num_dimensions: int | None = None
    _sequence_cls: type["FeatureAttributionSequenceOutput"] = FeatureAttributionSequenceOutput

    def __post_init__(self):
        self.to(torch.float32)
+        if self._num_dimensions is None:
+            self._num_dimensions = 0
        if self.step_scores is None:
            self.step_scores = {}
        if self.sequence_scores is None:
@@ -1213,8 +1223,6 @@ class MultiDimensionalFeatureAttributionSequenceOutput(FeatureAttributionSequenceOutput):
    attention head and per layer for every source-target token pair in the source attributions (i.e. 2 dimensions).
    """

-    _num_dimensions: int = 2
-
    def __post_init__(self):
        super().__post_init__()
        self._aggregator = ["mean"] * self._num_dimensions
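
With the hardcoded `_num_dimensions: int = 2` default removed above, the number of default `"mean"` reductions built in `__post_init__` follows whatever value was propagated through `config` or recomputed in the post-aggregation hook. A toy illustration of the line above:

```python
# One "mean" reduction per extra attribution dimension, e.g. 2 for
# (layers, heads) in attention attribution:
num_dimensions = 2
aggregator = ["mean"] * num_dimensions
print(aggregator)  # ['mean', 'mean']
```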
@@ -1233,7 +1241,6 @@ def __post_init__(self):
class MultiDimensionalFeatureAttributionStepOutput(FeatureAttributionStepOutput):
    """Raw output of a single step of multi-dimensional feature attribution."""

-    _num_dimensions: int = 2
    _sequence_cls: type["FeatureAttributionSequenceOutput"] = MultiDimensionalFeatureAttributionSequenceOutput

    def get_sequence_cls(self, **kwargs):