diff --git a/inseq/commands/attribute_context/attribute_context.py b/inseq/commands/attribute_context/attribute_context.py index 8c84ee2..49ec694 100644 --- a/inseq/commands/attribute_context/attribute_context.py +++ b/inseq/commands/attribute_context/attribute_context.py @@ -148,7 +148,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM output_context=args.output_context_text, output_context_tokens=output_context_tokens, output_current=args.output_current_text, - output_current_tokens=output_current_tokens, + output_current_tokens=cti_tokens, cti_scores=cti_scores, info=args, ) @@ -197,7 +197,10 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM contrast_token = output_ctxless_tokens[tok_pos] if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]: cci_kwargs["contrast_force_inputs"] = True - bos_offset = int(model.is_encoder_decoder or output_ctx_tokens[0] == model.bos_token) + bos_offset = int( + model.is_encoder_decoder + or (output_ctx_tokens[0] == model.bos_token and model.bos_token not in args.special_tokens_to_keep) + ) pos_start = output_current_text_offset + cti_idx + bos_offset + int(has_lang_tag) cci_attrib_out = model.attribute( contextual_input, diff --git a/inseq/commands/attribute_context/attribute_context_viz_helpers.py b/inseq/commands/attribute_context/attribute_context_viz_helpers.py index babed31..62329ef 100644 --- a/inseq/commands/attribute_context/attribute_context_viz_helpers.py +++ b/inseq/commands/attribute_context/attribute_context_viz_helpers.py @@ -321,8 +321,8 @@ def visualize_attribute_context_treescope( scores=cci.input_context_scores, title=f'Input contextual cues for "{cleaned_curr_tok}"', title_style="font-style: italic; color: #888888;", - min_val=output.min_cci, - max_val=output.max_cci, + min_val=cci.minimum, + max_val=cci.maximum, rounding=10, colormap=cmap_cci, strip_chars=replace_chars, @@ -337,8 +337,8 @@ def visualize_attribute_context_treescope( scores=cci.output_context_scores, title=f'Output contextual cue for "{cleaned_curr_tok}"', title_style="font-style: italic; color: #888888;", - min_val=output.min_cci, - max_val=output.max_cci, + min_val=cci.minimum, + max_val=cci.maximum, rounding=10, colormap=cmap_cci, strip_chars=replace_chars, diff --git a/requirements-dev.txt b/requirements-dev.txt index 40a3eb2..e826e32 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -318,7 +318,7 @@ scipy==1.12.0 # via scikit-learn sentencepiece==0.2.0 # via transformers -setuptools==69.1.0 +setuptools==72.2.0 # via # nodeenv # safety diff --git a/tests/commands/test_attribute_context.py b/tests/commands/test_attribute_context.py index a0aa05c..555f0f3 100644 --- a/tests/commands/test_attribute_context.py +++ b/tests/commands/test_attribute_context.py @@ -90,7 +90,7 @@ def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel): output_context=None, output_context_tokens=None, output_current="to the hospital. He said he was fine", - output_current_tokens=["to", "Ġthe", "Ġhospital", ".", "ĠHe", "Ġsaid", "Ġhe", "Ġwas", "Ġfine"], + output_current_tokens=["Ġto", "Ġthe", "Ġhospital", ".", "ĠHe", "Ġsaid", "Ġhe", "Ġwas", "Ġfine"], cti_scores=[0.31, 0.25, 0.55, 0.16, 0.43, 0.19, 0.13, 0.07, 0.37], cci_scores=[ CCIOutput( @@ -227,7 +227,7 @@ def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel): output_context="something was wrong. He said", output_context_tokens=["something", "Ġwas", "Ġwrong", ".", "ĠHe", "Ġsaid"], output_current="he was fine.", - output_current_tokens=["he", "Ġwas", "Ġfine", "."], + output_current_tokens=["Ġhe", "Ġwas", "Ġfine", "."], cti_scores=[1.2, 0.72, 1.5, 0.49], cci_scores=[ CCIOutput(