Skip to content

Commit 55f6b67

Browse files
Merge pull request #146 from lambda-feedback/tr126-proof-of-concept-error-identification
Tr126 proof of concept error identification
2 parents a5d602d + b1c9dde commit 55f6b67

File tree

3 files changed

+122
-15
lines changed

3 files changed

+122
-15
lines changed

app/evaluation.py

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from .unit_system_conversions import set_of_SI_prefixes, set_of_SI_base_unit_dimensions
99
from .preview import preview_function
1010

11-
1211
def evaluation_function(response, answer, params, include_test_data=False) -> dict:
1312
"""
1413
Function that allows for various types of comparison of various kinds of expressions.

app/symbolic_comparison_evaluation.py

+116-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from sympy.parsing.sympy_parser import T as parser_transformations
2-
from sympy import Abs, Equality, latex, pi, Symbol
2+
from sympy import Abs, Equality, latex, pi, Symbol, Add, Pow
33
from sympy.printing.latex import LatexPrinter
4+
from copy import deepcopy
45

56
from .expression_utilities import (
67
substitute_input_symbols,
@@ -161,6 +162,44 @@ def evaluation_node_internal(unused_input):
161162
graph.attach(label+"_FALSE", END.label)
162163
return graph
163164

165+
def find_coords_for_node_type(expression, node_type):
    """Locate every node of a given type inside an expression tree.

    The tree is walked iteratively with an explicit stack. A coordinate is a
    tuple of child indices: starting at ``expression`` and following
    ``.args[i]`` for each index ``i`` in the tuple reaches the node. The root
    itself has the empty tuple as its coordinate.

    Args:
        expression: Root of the tree. Children of a node are read from its
            ``args`` attribute; objects without ``args`` are treated as leaves.
        node_type: Class (or tuple of classes) handed to ``isinstance``.

    Returns:
        A list with one coordinate tuple per matching node, in the order the
        stack-based walk visits them (a node before its descendants, later
        children before earlier ones).
    """
    stack = [(expression, ())]
    node_coords = []
    while stack:
        expr, coord = stack.pop()
        if isinstance(expr, node_type):
            node_coords.append(coord)
        # Objects that are not tree nodes (no ``args``) simply have no
        # children instead of raising AttributeError.
        for k, arg in enumerate(getattr(expr, "args", ())):
            stack.append((arg, coord + (k,)))
    return node_coords
175+
176+
def replace_node_variations(expression, type_of_node, replacement_function):
    """Build copies of ``expression`` in which one node has been replaced.

    For every node of type ``type_of_node`` found in ``expression`` and for
    every argument index ``k`` of that node, ``replacement_function(node, k)``
    is called and its result is spliced back into the tree in place of the
    node. Each ancestor on the path to the root is rebuilt by calling its own
    type with the updated arguments.

    Args:
        expression: Root of the expression tree.
        type_of_node: Node type whose occurrences are to be replaced.
        replacement_function: Callable ``(node, k)`` returning the replacement
            for ``node``.

    Returns:
        List of the rebuilt expressions, one per (matching node, argument
        index) pair.
    """
    results = []
    for path in find_coords_for_node_type(expression, type_of_node):
        # ancestors[i] is the node reached after following path[:i];
        # the last entry is the matching node itself.
        ancestors = [expression]
        for index in path:
            ancestors.append(ancestors[-1].args[index])
        target = ancestors[-1]
        for k in range(len(target.args)):
            rebuilt = replacement_function(target, k)
            # Walk from the parent of the target back up to the root,
            # swapping the modified child into each ancestor in turn.
            for depth in range(len(path) - 1, -1, -1):
                parent = ancestors[depth]
                children = list(parent.args)
                children[path[depth]] = rebuilt
                rebuilt = type(parent)(*children)
            results.append(rebuilt)
    return results
190+
191+
def one_addition_to_subtraction(expression):
    """Return variations of ``expression`` with the sign of one term flipped.

    Each ``Add`` node in the expression contributes one variation per term:
    the variation equals the original expression except that the chosen term
    has been negated.
    """
    def flip_term_sign(node, k):
        # Subtracting a term twice from the sum turns +term into -term.
        return node - 2*node.args[k]

    return replace_node_variations(expression, Add, flip_term_sign)
196+
197+
def one_exponent_flip(expression):
    """Return variations of ``expression`` with one power inverted.

    Each ``Pow`` node in the expression is replaced, one node at a time, by
    that node raised to the power ``-1``.

    Returns:
        List of distinct varied expressions, in first-seen order.
    """
    def exponent_flip(node, k):
        # ``k`` is part of the callback signature expected by
        # replace_node_variations; the flip does not depend on it.
        return node**(-1)

    variations = replace_node_variations(expression, Pow, exponent_flip)
    # The callback is invoked once per argument of every Pow node and ignores
    # the argument index, so each flipped expression is produced repeatedly.
    # Drop the repeats while keeping the original order (relies on the
    # expressions being hashable, as SymPy expressions are).
    return list(dict.fromkeys(variations))
202+
164203
def criterion_where_node(criterion, parameters_dict, label=None):
165204
parsing_params = parameters_dict["parsing_params"]
166205
expression = criterion.children[0]
@@ -189,6 +228,17 @@ def expression_check(unused_input):
189228
else:
190229
return {label+"_FALSE"}
191230
return expression_check
231+
232+
graph = CriteriaGraph(label)
233+
END = CriteriaGraph.END
234+
graph.add_node(END)
235+
graph.add_evaluation_node(label, summary=label, details="Checks if "+str(expression)+" where "+str(subs)+".", evaluate=create_expression_check(expression))
236+
graph.attach(label, label+"_TRUE", summary=str(expression)+" where "+str(subs), details=str(expression)+" where "+str(subs)+"is true.")
237+
graph.attach(label+"_TRUE", END.label)
238+
graph.attach(label, label+"_FALSE", summary="not "+str(expression), details=str(expression)+" is not true with"+str(subs)+".")
239+
240+
reserved_expressions = list(parameters_dict["reserved_expressions"].items())
241+
response = parameters_dict["reserved_expressions"]["response"]
192242
expression_to_vary = None
193243
if expression.children[0].content_string().strip() == "response":
194244
expression_to_vary = expression.children[1]
@@ -197,16 +247,69 @@ def expression_check(unused_input):
197247
if "response" in expression_to_vary.content_string():
198248
expression_to_vary = None
199249
if expression_to_vary is not None:
200-
expression_to_vary = parse_expression(expression_to_vary.content_string(), parsing_params)
201-
expression_variations = []
202-
graph = CriteriaGraph(label)
203-
END = CriteriaGraph.END
204-
graph.add_node(END)
205-
graph.add_evaluation_node(label, summary=label, details="Checks if "+str(expression)+" where "+str(subs)+".", evaluate=create_expression_check(expression))
206-
graph.attach(label, label+"_TRUE", summary=str(expression)+" where "+str(subs), details=str(expression)+" where "+str(subs)+"is true.")
207-
graph.attach(label+"_TRUE", END.label)
208-
graph.attach(label, label+"_FALSE", summary="not "+str(expression), details=str(expression)+" is not true with"+str(subs)+".")
209-
graph.attach(label+"_FALSE", END.label)
250+
response_value = response.subs(local_subs)
251+
expression_to_vary = parse_expression(expression_to_vary.content_string(), parsing_params).subs(reserved_expressions)
252+
variation_groups = {
253+
"ONE_ADDITION_TO_SUBTRACTION": {
254+
"variations": one_addition_to_subtraction(expression_to_vary),
255+
"summary": lambda expression, variations: str(expression)+" is true if one addition is changed to a subtraction or vice versa.",
256+
"details": lambda expression, variations: "The following expressions are checked: "+", ".join([str(e) for e in variations]),
257+
},
258+
"ONE_EXPONENT_FLIP": {
259+
"variations": one_exponent_flip(expression_to_vary),
260+
"summary": lambda expression, variations: str(expression)+" is true if one exponent has its sign changed.",
261+
"details": lambda expression, variations: "The following expressions are checked: "+", ".join([str(e) for e in variations]),
262+
}
263+
}
264+
values_and_expressions = {expression_to_vary.subs(local_subs): set([expression_to_vary])}
265+
values_and_variations_group = {expression_to_vary.subs(local_subs): set(["UNDETECTABLE"])}
266+
for (group_label, info) in variation_groups.items():
267+
for variation in info["variations"]:
268+
value = variation.subs(local_subs)
269+
values_and_expressions.update({value: values_and_expressions.get(value, set()).union(set([variation]))})
270+
if value == expression_to_vary.subs(local_subs):
271+
values_and_variations_group["UNDETECTABLE"].add(variation)
272+
else:
273+
values_and_variations_group.update({value: values_and_variations_group.get(value, set()).union(set([group_label]))})
274+
if len(values_and_expressions) > 1:
275+
def identify_reason(unused_input):
276+
reasons = {label+"_"+group_label for group_label in values_and_variations_group.get(response_value, {"UNKNOWN"})}
277+
return reasons
278+
graph.attach(label+"_FALSE", label+"_IDENTIFY_REASON", summary="Identify reason.", details="Attempt to identify why the response is incorrect.", evaluate=identify_reason)
279+
graph.attach(label+"_IDENTIFY_REASON", label+"_UNKNOWN", summary="Unknown reason", details="No candidates for how the response was computed were found.")
280+
graph.attach(label+"_UNKNOWN", END.label)
281+
282+
def get_candidates(unused_input):
283+
candidates = set(["response candidates "+", ".join([str(e) for e in values_and_expressions[response_value]])])
284+
return candidates
285+
for (group_label, group_info) in variation_groups.items():
286+
graph.attach(
287+
label+"_IDENTIFY_REASON",
288+
label+"_"+group_label,
289+
summary=group_info["summary"](expression_to_vary, group_info["variations"]),
290+
details=group_info["details"](expression_to_vary, group_info["variations"])
291+
)
292+
graph.attach(
293+
label+"_"+group_label,
294+
label+"_GET_CANDIDATES_"+group_label,
295+
summary="Get candidate responses that satisfy "+str(expression),
296+
details="Get candidate responses that satisfy "+str(expression), evaluate=get_candidates
297+
)
298+
299+
for (value, expressions) in values_and_expressions.items():
300+
expressions_string = ", ".join([str(e) for e in expressions])
301+
for group_label in values_and_variations_group[value]:
302+
if group_label != "UNDETECTABLE":
303+
graph.attach(
304+
label+"_GET_CANDIDATES_"+group_label,
305+
"response candidates "+expressions_string,
306+
summary="Response candidates: "+expressions_string,
307+
details="Response candidates: "+expressions_string
308+
)
309+
graph.attach(
310+
"response candidates "+expressions_string,
311+
END.label
312+
)
210313
return graph
211314

212315
def create_criteria_list(criteria_string, criteria_parser, parsing_params):
@@ -411,6 +514,8 @@ def symbolic_comparison(response, answer, params, eval_response) -> dict:
411514
#is_correct = is_correct and check_criterion(criterion, parameters_dict)
412515
is_correct = is_correct and main_criteria in criteria_feedback
413516
result = main_criteria in criteria_feedback
517+
for item in criteria_feedback:
518+
eval_response.add_feedback((item, item))
414519
for (reference_tag, reference_strings) in reference_criteria_strings.items():
415520
if reference_tag in eval_response.get_tags():
416521
continue

app/symbolic_comparison_evaluation_tests.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -1101,9 +1101,12 @@ def test_no_reserved_keywords_in_old_format_input_symbol_alternatives(self):
11011101
'I': {'aliases': ['i'], 'latex': r'\(i\)'},
11021102
}
11031103
}),
1104-
("3", "x+1", "response=answer where x=2", True, [], {}),
1105-
("6", "x+y+1", "response=answer where x=2; y=3", True, [], {}),
1106-
("15", "x+y*z+1", "response=answer where x=2; y=3; z=4", True, [], {}),
1104+
("3", "x+1", "response=answer where x=2", True, ["response=answer where x=2_TRUE"], {}),
1105+
("1", "x+1", "response=answer where x=2", False, ["response=answer where x=2_ONE_ADDITION_TO_SUBTRACTION", "response candidates x - 1"], {}),
1106+
("5/3", "x/y+1", "response=answer where x=2; y=3", True, ["response=answer where x=2; y=3_TRUE"], {}),
1107+
("15", "x/y+1", "response=answer where x=2; y=3", False, ["response=answer where x=2; y=3_ONE_EXPONENT_FLIP"], {}), #NOTE: Sympy reporesents input as (x+y)/y so flipping the exponent gives (x+y)*y instead of x*y+1
1108+
("-1/3", "x/y+1", "response=answer where x=2; y=3", False, ["response=answer where x=2; y=3_ONE_ADDITION_TO_SUBTRACTION"], {}),
1109+
("13", "x+y*z-1", "response=answer where x=2; y=3; z=4", True, [], {}),
11071110
]
11081111
)
11091112
def test_criteria_based_comparison(self, response, answer, criteria, value, feedback_tags, additional_params):

0 commit comments

Comments
 (0)