Skip to content

Commit

Permalink
Fix attribute_context
Browse files Browse the repository at this point in the history
  • Loading branch information
gsarti committed Aug 22, 2024
1 parent a33fae2 commit 9c118bb
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
7 changes: 5 additions & 2 deletions inseq/commands/attribute_context/attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
output_context=args.output_context_text,
output_context_tokens=output_context_tokens,
output_current=args.output_current_text,
output_current_tokens=output_current_tokens,
output_current_tokens=cti_tokens,
cti_scores=cti_scores,
info=args,
)
Expand Down Expand Up @@ -197,7 +197,10 @@ def attribute_context_with_model(args: AttributeContextArgs, model: HuggingfaceM
contrast_token = output_ctxless_tokens[tok_pos]
if args.attributed_fn == "kl_divergence" or output_ctx_tokens[tok_pos] == output_ctxless_tokens[tok_pos]:
cci_kwargs["contrast_force_inputs"] = True
bos_offset = int(model.is_encoder_decoder or output_ctx_tokens[0] == model.bos_token)
bos_offset = int(
model.is_encoder_decoder
or (output_ctx_tokens[0] == model.bos_token and model.bos_token not in args.special_tokens_to_keep)
)
pos_start = output_current_text_offset + cti_idx + bos_offset + int(has_lang_tag)
cci_attrib_out = model.attribute(
contextual_input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def visualize_attribute_context_treescope(
scores=cci.input_context_scores,
title=f'Input contextual cues for "{cleaned_curr_tok}"',
title_style="font-style: italic; color: #888888;",
min_val=output.min_cci,
max_val=output.max_cci,
min_val=cci.minimum,
max_val=cci.maximum,
rounding=10,
colormap=cmap_cci,
strip_chars=replace_chars,
Expand All @@ -337,8 +337,8 @@ def visualize_attribute_context_treescope(
scores=cci.output_context_scores,
title=f'Output contextual cue for "{cleaned_curr_tok}"',
title_style="font-style: italic; color: #888888;",
min_val=output.min_cci,
max_val=output.max_cci,
min_val=cci.minimum,
max_val=cci.maximum,
rounding=10,
colormap=cmap_cci,
strip_chars=replace_chars,
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ scipy==1.12.0
# via scikit-learn
sentencepiece==0.2.0
# via transformers
setuptools==69.1.0
setuptools==72.2.0
# via
# nodeenv
# safety
Expand Down
4 changes: 2 additions & 2 deletions tests/commands/test_attribute_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_in_ctx_deconly(deconly_model: GPT2LMHeadModel):
output_context=None,
output_context_tokens=None,
output_current="to the hospital. He said he was fine",
output_current_tokens=["to", "Ġthe", "Ġhospital", ".", "ĠHe", "Ġsaid", "Ġhe", "Ġwas", "Ġfine"],
output_current_tokens=["Ġto", "Ġthe", "Ġhospital", ".", "ĠHe", "Ġsaid", "Ġhe", "Ġwas", "Ġfine"],
cti_scores=[0.31, 0.25, 0.55, 0.16, 0.43, 0.19, 0.13, 0.07, 0.37],
cci_scores=[
CCIOutput(
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_in_out_ctx_deconly(deconly_model: GPT2LMHeadModel):
output_context="something was wrong. He said",
output_context_tokens=["something", "Ġwas", "Ġwrong", ".", "ĠHe", "Ġsaid"],
output_current="he was fine.",
output_current_tokens=["he", "Ġwas", "Ġfine", "."],
output_current_tokens=["Ġhe", "Ġwas", "Ġfine", "."],
cti_scores=[1.2, 0.72, 1.5, 0.49],
cci_scores=[
CCIOutput(
Expand Down

0 comments on commit 9c118bb

Please sign in to comment.