Skip to content

Commit c8bacde

Browse files
Merge answer_choices and metadata. (#548)
* Combine answer_choices and answer_choices_key in Python code. * Remove span squad and sequence accuracy from app choices. * Test merge on community dataset. * Fix path parsing for community datasets. * Forgot to put the slash back in. * Just skip it. * Actually skipping. * Test just the a datasets. * Update all the rest. * some missing traces of `answer_choices_key` * Fix handling of answer_choices when not set. Co-authored-by: VictorSanh <[email protected]>
1 parent a3f95dd commit c8bacde

File tree

276 files changed

+1069
-4721
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

276 files changed

+1069
-4721
lines changed

promptsource/app.py

+8-24
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,6 @@ def get_infos(d_name):
348348
st.markdown("##### Metrics")
349349
st.text(", ".join(template.metadata.metrics) if template.metadata.metrics else None)
350350
st.markdown("##### Answer Choices")
351-
st.text(", ".join(template.answer_choices) if template.answer_choices is not None else None)
352-
st.markdown("##### Answer Choices Key")
353351
if template.get_answer_choices_expr() is not None:
354352
show_jinja(template.get_answer_choices_expr())
355353
else:
@@ -491,11 +489,9 @@ def get_infos(d_name):
491489
metrics_choices = [
492490
"BLEU",
493491
"ROUGE",
494-
"Span Squad",
495492
"Squad",
496493
"Trivia QA",
497494
"Accuracy",
498-
"Sequence Accuracy",
499495
"Pearson Correlation",
500496
"Spearman Correlation",
501497
"MultiRC",
@@ -518,21 +514,13 @@ def get_infos(d_name):
518514
)
519515

520516
# Answer choices
521-
state.answer_choices = st.text_input(
522-
"Answer Choices",
523-
value=" ||| ".join(template.answer_choices) if template.answer_choices is not None else "",
524-
help="A ||| separated list of possible outputs (or leave blank). "
525-
+ "Value is available in Jinja in a list called answer_choices.",
526-
)
527-
528-
# Answer choices key
529517
if template.get_answer_choices_expr() is not None:
530-
answer_choices_key = template.get_answer_choices_expr()
518+
answer_choices = template.get_answer_choices_expr()
531519
else:
532-
answer_choices_key = ""
533-
state.answer_choices_key = st.text_input(
534-
"Answer Choices Key",
535-
value=answer_choices_key,
520+
answer_choices = ""
521+
state.answer_choices = st.text_input(
522+
"Answer Choices",
523+
value=answer_choices,
536524
help="A Jinja expression for computing answer choices. "
537525
"Separate choices with a triple bar (|||).",
538526
)
@@ -553,14 +541,11 @@ def get_infos(d_name):
553541
elif updated_template_name == "":
554542
st.error("Need to provide a template name.")
555543
else:
556-
# Parses state.answer_choices and state.answer_choices_key
557-
updated_answer_choices = [x.strip() for x in state.answer_choices.split("|||")]
558-
if len(updated_answer_choices) == 0 or len(updated_answer_choices) == 1:
544+
# Parses state.answer_choices
545+
if state.answer_choices == "":
559546
updated_answer_choices = None
560-
if state.answer_choices_key == "":
561-
updated_answer_choices_key = None
562547
else:
563-
updated_answer_choices_key = state.answer_choices_key
548+
updated_answer_choices = state.answer_choices
564549

565550
dataset_templates.update_template(
566551
state.template_name,
@@ -569,7 +554,6 @@ def get_infos(d_name):
569554
state.reference,
570555
state.metadata,
571556
updated_answer_choices,
572-
updated_answer_choices_key,
573557
)
574558
# Update the state as well
575559
state.template_name = updated_template_name

promptsource/seqio_tasks/tasks.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def add_task(dataset_name, subset_name, template_name, task_name=None, split_map
118118
)
119119

120120
# Add rank classification eval task
121-
if template.answer_choices or template.answer_choices_key:
121+
if template.answer_choices:
122122
rank_classification_preprocessor = functools.partial(
123123
t5.data.preprocessors.rank_classification,
124124
inputs_fn=lambda ex: tf.fill((len(ex["answer_choices"]),), ex["inputs"]),
@@ -225,7 +225,7 @@ def add_task(dataset_name, subset_name, template_name, task_name=None, split_map
225225
if (dataset_name, subset_name) in d4_eval:
226226
if template.metadata.original_task:
227227
d4_eval_mixture.append(task_name)
228-
# TODO use template.metadata.answer_choices or answer_choice_keys here for rank eval
228+
# TODO use template.metadata.answer_choices here for rank eval
229229
if (dataset_name, subset_name) in bias_fairness_eval:
230230
bias_fairness_eval_mixture.append(task_name)
231231

@@ -245,7 +245,7 @@ def add_task(dataset_name, subset_name, template_name, task_name=None, split_map
245245
template = dataset[template_name]
246246
if template.metadata.original_task:
247247
d4_eval_mixture.append(task_name) # TODO or add to ANLI special mixture
248-
# TODO use template.metadata.answer_choices or answer_choice_keys here for rank eval
248+
# TODO use template.metadata.answer_choices here for rank eval
249249

250250

251251
TASK_BLACKLIST = [

promptsource/templates.py

+17-39
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Template(yaml.YAMLObject):
5858

5959
yaml_tag = "!Template"
6060

61-
def __init__(self, name, jinja, reference, metadata=None, answer_choices=None, answer_choices_key=None):
61+
def __init__(self, name, jinja, reference, metadata=None, answer_choices=None):
6262
"""
6363
Creates a prompt template.
6464
@@ -73,23 +73,19 @@ def __init__(self, name, jinja, reference, metadata=None, answer_choices=None, a
7373
:param jinja: template expressed in Jinja
7474
:param reference: string describing author or paper reference for template
7575
:param metadata: a Metadata object with template annotations
76-
:param answer_choices: list of strings that enumerates the possible completions
77-
for templates that should be evaluated as ranked
78-
completions. If None, then the template is open-ended.
79-
This list is accessible from within Jinja as the
80-
variable `answer_choices`.
81-
TODO: Merge answer_choices and answer_choices_key
82-
:param answer_choices_key: Jinja expression for answer choices, or None if
83-
no answer choices
84-
76+
:param answer_choices: Jinja expression for answer choices. Should produce
77+
a ||| delimited string of choices that enumerates
78+
the possible completions for templates that should
79+
be evaluated as ranked completions. If None, then
80+
the template is open-ended. This list is accessible
81+
from within Jinja as the variable `answer_choices`.
8582
"""
8683
self.id = str(uuid.uuid4())
8784
self.name = name
8885
self.jinja = jinja
8986
self.reference = reference
9087
self.metadata = metadata if metadata is not None else Template.Metadata()
9188
self.answer_choices = answer_choices
92-
self.answer_choices_key = answer_choices_key
9389

9490
def get_id(self):
9591
"""
@@ -115,36 +111,24 @@ def get_reference(self):
115111
"""
116112
return self.reference
117113

118-
def get_answer_choices(self):
119-
"""
120-
Returns a list of strings enumerating the possible completions for
121-
this template, or None if the template is open ended.
122-
123-
:return: List[String]
124-
"""
125-
# TODO: Replace answer_choices with answer_choices_key values
126-
return self.answer_choices
127-
128114
def get_answer_choices_expr(self):
129115
"""
130116
Returns a Jinja expression for computing the answer choices from an example.
131117
132118
:return: String, or None if no answer choices
133119
"""
134-
# TODO: Change to return answer_choices
135-
return self.answer_choices_key
120+
return self.answer_choices
136121

137122
def get_answer_choices_list(self, example):
138123
"""
139124
Returns a list of answer choices for a given example
140125
141126
:return: list of strings, or None if get_answer_choices_expr is None
142127
"""
143-
# TODO: remove when merging answer_choices and answer_choices_key
144-
if self.get_answer_choices_expr() is None:
145-
return self.get_answer_choices()
128+
jinja = self.get_answer_choices_expr()
129+
if jinja is None:
130+
return None
146131

147-
jinja = self.answer_choices_key
148132
rtemplate = env.from_string(jinja)
149133
protected_example = self._escape_pipe(example)
150134
rendered_choices = rtemplate.render(**protected_example)
@@ -156,11 +140,10 @@ def get_fixed_answer_choices_list(self):
156140
157141
:return: list of strings, or None if no static list exists
158142
"""
159-
# TODO: remove when merging answer_choices and answer_choices_key
160-
if self.get_answer_choices_expr() is None:
161-
return self.get_answer_choices()
143+
jinja = self.get_answer_choices_expr()
144+
if jinja is None:
145+
return None
162146

163-
jinja = self.answer_choices_key
164147
parse = env.parse(jinja)
165148
variables = meta.find_undeclared_variables(parse)
166149
if len(variables) == 0:
@@ -450,8 +433,7 @@ def update_template(
450433
jinja: str,
451434
reference: str,
452435
metadata: Template.Metadata,
453-
answer_choices: List[str],
454-
answer_choices_key: str,
436+
answer_choices: str,
455437
) -> None:
456438
"""
457439
Updates a pre-existing template and writes changes
@@ -461,16 +443,14 @@ def update_template(
461443
:param jinja: new jinja entry
462444
:param reference: new reference entry
463445
:param metadata: a Metadata object with template annotations
464-
:param answer_choices: new answer_choices list
465-
:param answer_choices_key: new answer_choices_key string
446+
:param answer_choices: new answer_choices string
466447
"""
467448
template_id = self.name_to_id_mapping[current_template_name]
468449
self.templates[template_id].name = new_template_name
469450
self.templates[template_id].jinja = jinja
470451
self.templates[template_id].reference = reference
471452
self.templates[template_id].metadata = metadata
472453
self.templates[template_id].answer_choices = answer_choices
473-
self.templates[template_id].answer_choices_key = answer_choices_key
474454

475455
self.write_to_file()
476456

@@ -512,7 +492,6 @@ def get_templates_data_frame():
512492
"choices_in_prompt": [],
513493
"metrics": [],
514494
"answer_choices": [],
515-
"answer_choices_key": [],
516495
"jinja": [],
517496
}
518497

@@ -530,8 +509,7 @@ def get_templates_data_frame():
530509
data["original_task"].append(template.metadata.original_task)
531510
data["choices_in_prompt"].append(template.metadata.choices_in_prompt)
532511
data["metrics"].append(template.metadata.metrics)
533-
data["answer_choices"].append(template.get_answer_choices())
534-
data["answer_choices_key"].append(template.get_answer_choices_expr())
512+
data["answer_choices"].append(template.get_answer_choices_expr())
535513
data["jinja"].append(template.jinja)
536514

537515
return pd.DataFrame(data)

promptsource/templates/Zaid/coqa_expanded/templates.yaml

-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ dataset: Zaid/coqa_expanded
22
templates:
33
12ad4331-d063-4b56-b0f6-76f59c690717: !Template
44
answer_choices: null
5-
answer_choices_key: null
65
id: 12ad4331-d063-4b56-b0f6-76f59c690717
76
jinja: "Below is a passage, followed by a series of questions and answers about\
87
\ the passage. Answer the last question based on the information contained in\
@@ -18,7 +17,6 @@ templates:
1817
reference: 'Metric: variant of SQuAD (Section 6.1 of the paper)'
1918
2f9fb20d-f4c9-4371-9cd4-db47607cb7a3: !Template
2019
answer_choices: null
21-
answer_choices_key: null
2220
id: 2f9fb20d-f4c9-4371-9cd4-db47607cb7a3
2321
jinja: "What is the answer to the last question in the dialogue below? If there\
2422
\ is no answer in the passage, say \"unknown\".\n\nPassage: {{story}}\n\nQ:\
@@ -33,7 +31,6 @@ templates:
3331
reference: 'Metric: variant of SQuAD (Section 6.1 of the paper)'
3432
9aff8967-d41c-4d79-8ef4-fc3650773735: !Template
3533
answer_choices: null
36-
answer_choices_key: null
3734
id: 9aff8967-d41c-4d79-8ef4-fc3650773735
3835
jinja: "Complete the dialogue based on the information contained in the passage.\
3936
\ If there is no answer in the passage, say \"unknown\".\n\nPassage: {{story}}\n\
@@ -48,7 +45,6 @@ templates:
4845
reference: 'Metric: variant of SQuAD (Section 6.1 of the paper)'
4946
9bc32f2e-eee6-4006-bce3-74a79403d33e: !Template
5047
answer_choices: null
51-
answer_choices_key: null
5248
id: 9bc32f2e-eee6-4006-bce3-74a79403d33e
5349
jinja: "Answer the last question based on the information contained in the passage.\
5450
\ If there is no answer in the passage, say \"unknown\".\n\nPassage: {{story}}\n\
@@ -63,7 +59,6 @@ templates:
6359
reference: 'Metric: variant of SQuAD (Section 6.1 of the paper)'
6460
bacb6534-e607-4afc-a412-ccfcd9fe38e2: !Template
6561
answer_choices: null
66-
answer_choices_key: null
6762
id: bacb6534-e607-4afc-a412-ccfcd9fe38e2
6863
jinja: 'In the passage below, extract the part which answers the last question.
6964
If there is no answer in the passage, say "unknown".
@@ -94,7 +89,6 @@ templates:
9489
reference: ''
9590
be39974f-aa86-4076-b444-bd3c2732b17b: !Template
9691
answer_choices: null
97-
answer_choices_key: null
9892
id: be39974f-aa86-4076-b444-bd3c2732b17b
9993
jinja: "Help me complete the dialogue about this passage. If there is no answer\
10094
\ in the passage, say \"unknown\".\n\nPassage: {{story}}\n\nQ: {{question}}\
@@ -109,7 +103,6 @@ templates:
109103
reference: 'Metric: variant of SQuAD (Section 6.1 of the paper)'
110104
d95440ce-d538-40f8-ae09-664e05852ca8: !Template
111105
answer_choices: null
112-
answer_choices_key: null
113106
id: d95440ce-d538-40f8-ae09-664e05852ca8
114107
jinja: "{{story}}\n\nQ: {{question}} \nA: ||| {% if answer[\"answer_start\"] !=\
115108
\ -1 %}\n{{answer[\"input_text\"]}}\n{% else %}\nunknown\n{% endif %}"

promptsource/templates/Zaid/quac_expanded/templates.yaml

-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ dataset: Zaid/quac_expanded
22
templates:
33
01d8c949-89a7-4a44-9a39-6cf2ac3e0a7b: !Template
44
answer_choices: null
5-
answer_choices_key: null
65
id: 01d8c949-89a7-4a44-9a39-6cf2ac3e0a7b
76
jinja: "What is the answer to the last question in the dialogue below? If there\
87
\ is no answer in the passage, say \"unknown\".\n\nPassage: {{context}}\n\n\
@@ -16,7 +15,6 @@ templates:
1615
reference: 'Metric: F1'
1716
1484c6e6-bf42-47ca-9ea7-c3c552a24de1: !Template
1817
answer_choices: null
19-
answer_choices_key: null
2018
id: 1484c6e6-bf42-47ca-9ea7-c3c552a24de1
2119
jinja: "{{context}}\n\nQ: {{question}} \nA: ||| {{answer[\"texts\"][0]}}"
2220
metadata: !TemplateMetadata
@@ -28,7 +26,6 @@ templates:
2826
reference: 'Brown et al. NeurIPS 2020. Metric: F1'
2927
2bca0532-01a3-4a64-a228-a57ae0965719: !Template
3028
answer_choices: null
31-
answer_choices_key: null
3229
id: 2bca0532-01a3-4a64-a228-a57ae0965719
3330
jinja: "Below is a passage, followed by a series of questions and answers about\
3431
\ the passage. Answer the last question based on the information contained in\
@@ -43,7 +40,6 @@ templates:
4340
reference: 'Metric: F1'
4441
4abd0379-dbc0-4f71-901b-dd0af3581157: !Template
4542
answer_choices: null
46-
answer_choices_key: null
4743
id: 4abd0379-dbc0-4f71-901b-dd0af3581157
4844
jinja: "Answer the last question based on the information contained in the passage.\
4945
\ If there is no answer in the passage, say \"unknown\".\n\nPassage: {{context}}\n\
@@ -57,7 +53,6 @@ templates:
5753
reference: 'Metric: F1'
5854
8ebbd098-b40c-4e69-8cbb-0ffecf0fe2a6: !Template
5955
answer_choices: null
60-
answer_choices_key: null
6156
id: 8ebbd098-b40c-4e69-8cbb-0ffecf0fe2a6
6257
jinja: "Complete the dialogue based on the information contained in the passage.\
6358
\ If there is no answer in the passage, say \"unknown\".\n\nPassage: {{context}}\n\
@@ -71,7 +66,6 @@ templates:
7166
reference: 'Metric: F1'
7267
e624695b-5d26-47cc-bdb4-ac2bee4ddaea: !Template
7368
answer_choices: null
74-
answer_choices_key: null
7569
id: e624695b-5d26-47cc-bdb4-ac2bee4ddaea
7670
jinja: "Help me complete the dialogue about this passage. If there is no answer\
7771
\ in the passage, say \"unknown\".\n\nPassage: {{context}}\n\nQ: {{question}}\

promptsource/templates/acronym_identification/templates.yaml

-6
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ dataset: acronym_identification
22
templates:
33
64f438f2-9968-459f-82d2-24bad632b358: !Template
44
answer_choices: null
5-
answer_choices_key: null
65
id: 64f438f2-9968-459f-82d2-24bad632b358
76
jinja: "{% set random_abbr = '' %}\n{% set _dummy = none %}\n{% set abbr_exp_dict\
87
\ = namespace(value = {}) %}\n{% set abbr_string=namespace(value='') %}\n{%\
@@ -48,7 +47,6 @@ templates:
4847
reference: Given the tokens, find the expansion of an abbreviation in the tokens.
4948
81babc83-18cd-4eed-a343-8ede56b21df5: !Template
5049
answer_choices: null
51-
answer_choices_key: null
5250
id: 81babc83-18cd-4eed-a343-8ede56b21df5
5351
jinja: "Given the BIO encoding as follows: \"{{\"B-short\"}}\" and \"{{\"I-short\"\
5452
}}\" represent the beginning and intermediate tokens for abbreviations.\"{{\"\
@@ -66,7 +64,6 @@ templates:
6664
reference: Given the comma separated tokens, generate BIO encoding for abbreviations.
6765
8832e5f7-7c45-46da-b85f-71fcb444f264: !Template
6866
answer_choices: null
69-
answer_choices_key: null
7067
id: 8832e5f7-7c45-46da-b85f-71fcb444f264
7168
jinja: 'List all the expansions of the acronyms present in the following comma-separated
7269
tokens. Return {{"No expansions found"}} if the expansions can''t be found.
@@ -122,7 +119,6 @@ templates:
122119
reference: Given the tokens, list the expansion tokens.
123120
cae58242-cde9-472d-ae9e-56fc7e79c0d1: !Template
124121
answer_choices: null
125-
answer_choices_key: null
126122
id: cae58242-cde9-472d-ae9e-56fc7e79c0d1
127123
jinja: "List all the acryonyms in the following comma-separated tokens: \n\n{{tokens|join(',\
128124
\ ')}}\n|||\n{% set abbr_string=namespace(value='') %}\n{% set answer_list=namespace(value=[])\
@@ -142,7 +138,6 @@ templates:
142138
reference: Given the tokens, list the abbreviations.
143139
e4e42433-0e37-4aa5-bbce-7f336ecac6a3: !Template
144140
answer_choices: null
145-
answer_choices_key: null
146141
id: e4e42433-0e37-4aa5-bbce-7f336ecac6a3
147142
jinja: "{% set _dummy = none %}\n{% set abbr_exp_dict = namespace(value = {})\
148143
\ %}\n{% set abbr_string=namespace(value='') %}\n{% set exp_string=namespace(value='')%}\n\
@@ -185,7 +180,6 @@ templates:
185180
reference: Given the tokens, find the abbreviation mapping.
186181
eed32ee4-ebc3-499f-ba61-e91461f56ccb: !Template
187182
answer_choices: null
188-
answer_choices_key: null
189183
id: eed32ee4-ebc3-499f-ba61-e91461f56ccb
190184
jinja: "{% set random_exp = '' %}{% set _dummy = none %}{% set exp_abbr_dict =\
191185
\ namespace(value = {}) %}{% set abbr_string=namespace(value='') %}{% set exp_string=namespace(value='')%}{%\

0 commit comments

Comments
 (0)