Add possibility to skip special tokens during attribution (#275)
* Add possibility to skip special tokens during attribution

* Bump pillow

* Bump idna
gsarti authored May 15, 2024
1 parent 04dde30 commit 96add78
Showing 10 changed files with 72 additions and 14 deletions.
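For orientation, here is a minimal usage sketch of the new flag from the public API. The model name, attribution method, and input text are placeholders chosen for illustration and are not part of this commit; the behavior shown (special tokens excluded from the attributed sequence when skip_special_tokens=True) follows from the changes below, where inputs are encoded with add_special_tokens=not skip_special_tokens.

import inseq

# Placeholder model and attribution method, used only for illustration.
model = inseq.load_model("gpt2", "saliency")

# With skip_special_tokens=True, input texts are tokenized with
# add_special_tokens=False, so tokens such as BOS/EOS are not part of
# the attributed sequence.
out = model.attribute(
    "Hello world, how are you today?",
    skip_special_tokens=True,
)
out.show()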
inseq/attr/feat/feature_attribution.py: 16 changes (14 additions & 2 deletions)
@@ -176,6 +176,7 @@ def prepare_and_attribute(
attribute_target: bool = False,
step_scores: list[str] = [],
include_eos_baseline: bool = False,
+ skip_special_tokens: bool = False,
attributed_fn: Union[str, Callable[..., SingleScorePerStepTensor], None] = None,
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
@@ -206,6 +207,8 @@ def prepare_and_attribute(
step scores can be added by using the :meth:`~inseq.register_step_function` function.
include_eos_baseline (:obj:`bool`, `optional`): Whether to include the EOS token in the baseline for
attribution. By default the EOS token is not used for attribution. Defaults to False.
+ skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when encoding the input.
+ Defaults to False.
attributed_fn (:obj:`str` or :obj:`Callable[..., SingleScorePerStepTensor]`, `optional`): The identifier or
function of model outputs representing what should be attributed (e.g. output probits of model best
prediction after softmax). If it is a string, it must be a valid function.
@@ -224,12 +227,14 @@ def prepare_and_attribute(
inputs = (sources, targets)
if not self.attribution_model.is_encoder_decoder:
inputs = targets
- encoded_sources = self.attribution_model.encode(sources, return_baseline=True)
+ encoded_sources = self.attribution_model.encode(
+ sources, return_baseline=True, add_special_tokens=not skip_special_tokens
+ )
# We do this here to support separate attr_pos_start for different sentences when batching
if attr_pos_start is None or attr_pos_start < encoded_sources.input_ids.shape[1]:
attr_pos_start = encoded_sources.input_ids.shape[1]
batch = self.attribution_model.formatter.prepare_inputs_for_attribution(
- self.attribution_model, inputs, include_eos_baseline
+ self.attribution_model, inputs, include_eos_baseline, skip_special_tokens
)
# If prepare_and_attribute was called from AttributionModel.attribute,
# attributed_fn is already a Callable. Keep here to allow for usage independently
@@ -245,6 +250,7 @@ def prepare_and_attribute(
output_step_attributions=output_step_attributions,
attribute_target=attribute_target,
step_scores=step_scores,
+ skip_special_tokens=skip_special_tokens,
attribution_args=attribution_args,
attributed_fn_args=attributed_fn_args,
step_scores_args=step_scores_args,
@@ -310,6 +316,7 @@ def format_contrastive_targets(
step_scores_args: dict[str, Any],
attr_pos_start: int,
attr_pos_end: int,
+ skip_special_tokens: bool = False,
) -> tuple[Optional[DecoderOnlyBatch], Optional[list[list[tuple[int, int]]]], dict[str, Any], dict[str, Any]]:
contrast_batch, contrast_targets_alignments = None, None
contrast_targets = attributed_fn_args.get("contrast_targets", None)
@@ -327,6 +334,7 @@ def format_contrastive_targets(
attribution_model=self.attribution_model,
inputs=contrast_targets,
as_targets=as_targets,
+ skip_special_tokens=skip_special_tokens,
)
contrast_batch = DecoderOnlyBatch.from_batch(contrast_batch)
clean_tgt_tokens = self.attribution_model.clean_tokens(target_tokens, as_targets=as_targets)
@@ -358,6 +366,7 @@ def attribute(
output_step_attributions: bool = False,
attribute_target: bool = False,
step_scores: list[str] = [],
+ skip_special_tokens: bool = False,
attribution_args: dict[str, Any] = {},
attributed_fn_args: dict[str, Any] = {},
step_scores_args: dict[str, Any] = {},
@@ -385,6 +394,8 @@ def attribute(
step_scores (:obj:`list` of `str`): List of identifiers for step scores that need to be computed during
attribution. The available step scores are defined in :obj:`inseq.attr.feat.STEP_SCORES_MAP` and new
step scores can be added by using the :meth:`~inseq.register_step_function` function.
+ skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when encoding the input.
+ Defaults to False.
attribution_args (:obj:`dict`, `optional`): Additional arguments to pass to the attribution method.
attributed_fn_args (:obj:`dict`, `optional`): Additional arguments to pass to the attributed function.
step_scores_args (:obj:`dict`, `optional`): Additional arguments to pass to the step scores function.
@@ -419,6 +430,7 @@ def attribute(
step_scores_args,
attr_pos_start,
attr_pos_end,
+ skip_special_tokens,
)
target_tokens_with_ids = self.attribution_model.get_token_with_ids(
batch,
inseq/attr/step_functions.py: 14 changes (14 additions & 0 deletions)
@@ -132,6 +132,7 @@ def contrast_logits_fn(
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
):
"""Returns the logit of a generation target given contrastive context or target prediction alternative.
If only ``contrast_targets`` are specified, the logit of the contrastive prediction is computed given same
@@ -144,6 +145,7 @@ def contrast_logits_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
contrast_force_inputs=contrast_force_inputs,
+ skip_special_tokens=skip_special_tokens,
)
return logit_fn(c_args)

@@ -156,6 +158,7 @@ def contrast_prob_fn(
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
logprob: bool = False,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
):
"""Returns the probability of a generation target given contrastive context or target prediction alternative.
If only ``contrast_targets`` are specified, the probability of the contrastive prediction is computed given same
@@ -168,6 +171,7 @@ def contrast_prob_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
contrast_force_inputs=contrast_force_inputs,
+ skip_special_tokens=skip_special_tokens,
)
return probability_fn(c_args, logprob=logprob)

@@ -179,6 +183,7 @@ def pcxmi_fn(
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
) -> SingleScorePerStepTensor:
"""Compute the pointwise conditional cross-mutual information (P-CXMI) of target ids given original and contrastive
input options. The P-CXMI is defined as the negative log-ratio between the conditional probability of the target
@@ -192,6 +197,7 @@ def pcxmi_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
contrast_force_inputs=contrast_force_inputs,
+ skip_special_tokens=skip_special_tokens,
).to(original_probs.device)
return -torch.log2(torch.div(original_probs, contrast_probs))

@@ -206,6 +212,7 @@ def kl_divergence_fn(
top_p: float = 1.0,
min_tokens_to_keep: int = 1,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
) -> SingleScorePerStepTensor:
"""Compute the pointwise Kullback-Leibler divergence of target ids given original and contrastive input options.
The KL divergence is the expectation of the log difference between the probabilities of regular (P) and contrastive
@@ -233,6 +240,7 @@ def kl_divergence_fn(
contrast_targets_alignments=contrast_targets_alignments,
return_contrastive_target_ids=False,
return_contrastive_batch=True,
+ skip_special_tokens=skip_special_tokens,
)
c_forward_output = args.attribution_model.get_forward_output(
contrast_inputs.batch, use_embeddings=args.attribution_model.is_encoder_decoder
@@ -263,6 +271,7 @@ def contrast_prob_diff_fn(
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
logprob: bool = False,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
):
"""Returns the difference between next step probability for a candidate generation target vs. a contrastive
alternative. Can be used as attribution target to answer the question: "Which features were salient in the
@@ -279,6 +288,7 @@ def contrast_prob_diff_fn(
contrast_targets_alignments=contrast_targets_alignments,
logprob=logprob,
contrast_force_inputs=contrast_force_inputs,
+ skip_special_tokens=skip_special_tokens,
).to(model_probs.device)
return model_probs - contrast_probs

@@ -290,6 +300,7 @@ def contrast_logits_diff_fn(
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
):
"""Equivalent to ``contrast_prob_diff_fn`` but for logits. The original target function used in
`Yin and Neubig (2022) <https://aclanthology.org/2022.emnlp-main.14>`__
@@ -301,6 +312,7 @@ def contrast_logits_diff_fn(
contrast_targets=contrast_targets,
contrast_targets_alignments=contrast_targets_alignments,
contrast_force_inputs=contrast_force_inputs,
+ skip_special_tokens=skip_special_tokens,
).to(model_logits.device)
return model_logits - contrast_logits

@@ -312,6 +324,7 @@ def in_context_pvi_fn(
contrast_targets: Optional[FeatureAttributionInput] = None,
contrast_targets_alignments: Optional[list[list[tuple[int, int]]]] = None,
contrast_force_inputs: bool = False,
+ skip_special_tokens: bool = False,
):
"""Returns the in-context pointwise V-usable information as defined by `Lu et al. (2023)
<https://arxiv.org/abs/2310.12300>`__. In-context PVI is a variant of P-CXMI that captures the amount of usable
@@ -330,6 +343,7 @@ def in_context_pvi_fn(
contrast_targets_alignments=contrast_targets_alignments,
logprob=True,
contrast_force_inputs=contrast_force_inputs,
+ skip_special_tokens=skip_special_tokens,
).to(orig_logprob.device)
return -orig_logprob + contrast_logprob

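All contrastive step functions above now accept and forward skip_special_tokens. A hedged sketch of how this could be exercised end to end follows; the model, method, and texts are placeholders, "contrast_prob_diff" is assumed to be the registered identifier of contrast_prob_diff_fn, and contrast_targets is assumed to be routed to the step function through attribute's keyword arguments.

import inseq

# Placeholder model and attribution method, used only for illustration.
model = inseq.load_model("gpt2", "input_x_gradient")

out = model.attribute(
    "The manager thanked the visitors and",
    attributed_fn="contrast_prob_diff",
    # Contrastive alternative passed to the step function (assumed kwarg routing).
    contrast_targets="The manager thanked the visitors but",
    # Forwarded to the contrastive batch preparation as well, so special tokens
    # are handled consistently for the regular and contrastive inputs.
    skip_special_tokens=True,
)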
inseq/data/attribution.py: 10 changes (8 additions & 2 deletions)
@@ -47,6 +47,7 @@ def get_batch_from_inputs(
inputs: FeatureAttributionInput,
include_eos_baseline: bool = False,
as_targets: bool = False,
+ skip_special_tokens: bool = False,
) -> Batch:
if isinstance(inputs, Batch):
batch = inputs
@@ -57,6 +58,7 @@ def get_batch_from_inputs(
as_targets=as_targets,
return_baseline=True,
include_eos_baseline=include_eos_baseline,
+ add_special_tokens=not skip_special_tokens,
)
elif isinstance(inputs, BatchEncoding):
encodings = inputs
@@ -66,8 +68,12 @@ def get_batch_from_inputs(
"Inputs must be either a string, a list of strings, a BatchEncoding or a Batch."
)
embeddings = BatchEmbedding(
- input_embeds=attribution_model.embed(encodings.input_ids, as_targets=as_targets),
- baseline_embeds=attribution_model.embed(encodings.baseline_ids, as_targets=as_targets),
+ input_embeds=attribution_model.embed(
+ encodings.input_ids, as_targets=as_targets, add_special_tokens=not skip_special_tokens
+ ),
+ baseline_embeds=attribution_model.embed(
+ encodings.baseline_ids, as_targets=as_targets, add_special_tokens=not skip_special_tokens
+ ),
)
batch = Batch(encodings, embeddings)
return batch
inseq/models/attribution_model.py: 23 changes (19 additions & 4 deletions)
@@ -69,6 +69,7 @@ def prepare_inputs_for_attribution(
attribution_model: "AttributionModel",
inputs: FeatureAttributionInput,
include_eos_baseline: bool = False,
+ skip_special_tokens: bool = False,
) -> Union[DecoderOnlyBatch, EncoderDecoderBatch]:
raise NotImplementedError()

@@ -316,6 +317,7 @@ def attribute(
device: Optional[str] = None,
batch_size: Optional[int] = None,
generate_from_target_prefix: bool = False,
+ skip_special_tokens: bool = False,
generation_args: dict[str, Any] = {},
**kwargs,
) -> FeatureAttributionOutput:
@@ -365,6 +367,8 @@ def attribute(
target prefixes for the generation process. If False, the ``generated_texts`` will be used as full
targets. This option is only available for encoder-decoder models, since the same behavior can be
achieved by modifying the input texts for decoder-only models. Default: False.
+ skip_special_tokens (:obj:`bool`, `optional`): Whether to skip special tokens when attributing the input
+ texts. Default: False.
**kwargs: Additional keyword arguments. These can include keyword arguments for the attribution method, for
the generation process or for the attributed function. Generation arguments can be provided explicitly
as a dictionary named ``generation_args``.
@@ -389,6 +393,8 @@ def attribute(
self.device = device
attribution_method = self.get_attribution_method(method, override_default_attribution)
attributed_fn = self.get_attributed_fn(attributed_fn)
+ if skip_special_tokens:
+ kwargs["skip_special_tokens"] = True
attribution_args, attributed_fn_args, step_scores_args = extract_args(
attribution_method,
attributed_fn,
@@ -418,9 +424,16 @@ def attribute(
logger.info(f"Splitting input texts into {n_batches} batches of size {batch_size}.")
# If constrained decoding is not enabled, output texts are generated from input texts.
if not has_generated_texts or generate_from_target_prefix:
- encoded_input = self.encode(input_texts, return_baseline=True, include_eos_baseline=include_eos_baseline)
+ encoded_input = self.encode(
+ input_texts,
+ return_baseline=True,
+ include_eos_baseline=include_eos_baseline,
+ add_special_tokens=not skip_special_tokens,
+ )
if generate_from_target_prefix:
- decoder_input = self.encode(generated_texts, as_targets=True)
+ decoder_input = self.encode(
+ generated_texts, as_targets=True, add_special_tokens=not skip_special_tokens
+ )
generation_args["decoder_input_ids"] = decoder_input.input_ids
generated_texts = self.generate(
encoded_input, return_generation_output=False, batch_size=batch_size, **generation_args
@@ -467,6 +480,7 @@ def attribute(
attribute_target=attribute_target,
step_scores=step_scores,
include_eos_baseline=include_eos_baseline,
+ skip_special_tokens=skip_special_tokens,
attributed_fn=attributed_fn,
attribution_args=attribution_args,
attributed_fn_args=attributed_fn_args,
@@ -484,11 +498,11 @@ def attribute(
self.device = original_device
return attribution_output

- def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False):
+ def embed(self, inputs: Union[TextInput, IdsTensor], as_targets: bool = False, add_special_tokens: bool = True):
if isinstance(inputs, str) or (
isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
):
- batch = self.encode(inputs, as_targets)
+ batch = self.encode(inputs, as_targets, add_special_tokens=add_special_tokens)
inputs = batch.input_ids
return self.embed_ids(inputs, as_targets=as_targets)

@@ -531,6 +545,7 @@ def encode(
as_targets: bool = False,
return_baseline: bool = False,
include_eos_baseline: bool = False,
+ add_special_tokens: bool = True,
) -> BatchEncoding:
pass

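At the tokenization boundary, the changes above reduce to flipping add_special_tokens when texts are encoded. A standalone sketch with a Hugging Face tokenizer (the checkpoint name is a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint

skip_special_tokens = True
# Mirrors the mapping used in AttributionModel.attribute/encode above: the
# user-facing skip_special_tokens flag becomes add_special_tokens=False at
# encoding time, so e.g. the trailing </s> id is not appended here.
encoded = tokenizer("the house is wonderful", add_special_tokens=not skip_special_tokens)
print(encoded.input_ids)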
inseq/models/decoder_only.py: 2 changes (2 additions & 0 deletions)
@@ -39,12 +39,14 @@ def prepare_inputs_for_attribution(
attribution_model: "DecoderOnlyAttributionModel",
inputs: FeatureAttributionInput,
include_eos_baseline: bool = False,
+ skip_special_tokens: bool = False,
) -> DecoderOnlyBatch:
batch = get_batch_from_inputs(
attribution_model,
inputs=inputs,
include_eos_baseline=include_eos_baseline,
as_targets=False,
+ skip_special_tokens=skip_special_tokens,
)
return DecoderOnlyBatch.from_batch(batch)

inseq/models/encoder_decoder.py: 3 changes (3 additions & 0 deletions)
@@ -38,6 +38,7 @@ def prepare_inputs_for_attribution(
attribution_model: "EncoderDecoderAttributionModel",
inputs: tuple[FeatureAttributionInput, FeatureAttributionInput],
include_eos_baseline: bool = False,
+ skip_special_tokens: bool = False,
) -> EncoderDecoderBatch:
r"""Prepares sources and target to produce an :class:`~inseq.data.EncoderDecoderBatch`.
There are two stages of preparation:
@@ -67,12 +68,14 @@ def prepare_inputs_for_attribution(
inputs=sources,
include_eos_baseline=include_eos_baseline,
as_targets=False,
+ skip_special_tokens=skip_special_tokens,
)
target_batch = get_batch_from_inputs(
attribution_model,
inputs=targets,
include_eos_baseline=include_eos_baseline,
as_targets=True,
+ skip_special_tokens=skip_special_tokens,
)
return EncoderDecoderBatch(source_batch, target_batch)

inseq/models/huggingface_model.py: 2 changes (1 addition & 1 deletion)
@@ -228,7 +228,7 @@ def generate(
if isinstance(inputs, str) or (
isinstance(inputs, list) and len(inputs) > 0 and all(isinstance(x, str) for x in inputs)
):
- inputs = self.encode(inputs)
+ inputs = self.encode(inputs, add_special_tokens=not skip_special_tokens)
inputs = inputs.to(self.device)
generation_out = self.model.generate(
inputs=inputs.input_ids,
Expand Down