Skip to content

Commit 7073f8f

Browse files
Merge pull request #206 from lambda-feedback/tr173-changed-preview-code-gen-for-latex-response-with-pm
Found a case where generating a set was still necessary.
2 parents 8617803 + 933c567 commit 7073f8f

File tree

3 files changed

+47
-30
lines changed

3 files changed

+47
-30
lines changed

app/expression_utilities.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ def _print_log(self, expr, exp=None):
6767
# -------- String Manipulation Utilities
6868
def create_expression_set(exprs, params):
6969
if isinstance(exprs, str):
70-
exprs = [exprs]
70+
if exprs.startswith('{') and exprs.endswith('}'):
71+
exprs = [expr.strip() for expr in exprs[1:-1].split(',')]
72+
else:
73+
exprs = [exprs]
7174
expr_set = set()
7275
for expr in exprs:
7376
expr = substitute_input_symbols(expr, params)[0]

app/preview_utilities.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -49,21 +49,24 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool, parameters=
4949
pm_placeholder = None
5050
mp_placeholder = None
5151

52+
results = set()
53+
5254
if r"\pm " in response or r"\mp " in response:
55+
response_set = set()
5356
for char in 'abcdefghjkoqrtvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ':
5457
if char not in response and pm_placeholder is None:
5558
pm_placeholder = char
59+
substitutions[pm_placeholder] = sympy.Symbol(pm_placeholder, commutative=False)
5660
elif char not in response and mp_placeholder is None:
5761
mp_placeholder = char
62+
substitutions[mp_placeholder] = sympy.Symbol(mp_placeholder, commutative=False)
5863
if pm_placeholder is not None and mp_placeholder is not None:
5964
break
60-
61-
if pm_placeholder is not None:
62-
response = response.replace(r"\pm ", pm_placeholder)
63-
substitutions[pm_placeholder] = sympy.Symbol(pm_placeholder, commutative=False)
64-
if mp_placeholder is not None:
65-
response = response.replace(r"\mp ", mp_placeholder)
66-
substitutions[mp_placeholder] = sympy.Symbol(mp_placeholder, commutative=False)
65+
for expr in create_expression_set(response.replace(r"\pm ",'plus_minus').replace(r"\mp ",'minus_plus'), parameters):
66+
response_set.add(expr)
67+
response = response_set
68+
else:
69+
response_set = {response}
6770

6871
for sympy_symbol_str in symbols:
6972
symbol_str = symbols[sympy_symbol_str]["latex"]
@@ -79,28 +82,23 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool, parameters=
7982
)
8083
substitutions[latex_symbol] = sympy.Symbol(sympy_symbol_str)
8184

82-
try:
83-
expression = latex2sympy(response, substitutions)
84-
if isinstance(expression, list):
85-
expression = expression.pop()
86-
if simplify is True:
87-
expression = expression.simplify()
88-
except Exception as e:
89-
raise ValueError(str(e))
90-
91-
if (pm_placeholder is not None) or (mp_placeholder is not None):
92-
result_str = str(expression)
93-
for ph in [(pm_placeholder, "plus_minus"), (mp_placeholder, "minus_plus")]:
94-
if ph[0] is not None:
95-
result_str = result_str.replace("*"+ph[0]+"*", " "+ph[1]+" ")
96-
result_str = result_str.replace(ph[0]+"*", " "+ph[1]+" ")
97-
result_str = result_str.replace("*"+ph[0], " "+ph[1]+" ")
98-
result_str = result_str.replace(ph[0], " "+ph[1]+" ")
85+
parsed_responses = set()
86+
for expression in response_set:
87+
try:
88+
expression = latex2sympy(expression, substitutions)
89+
if isinstance(expression, list):
90+
expression = expression.pop()
91+
if simplify is True:
92+
expression = expression.simplify()
93+
except Exception as e:
94+
raise ValueError(str(e))
95+
96+
parsed_responses.add(str(expression.xreplace(substitutions)))
97+
98+
if len(parsed_responses) < 2:
99+
return parsed_responses.pop()
99100
else:
100-
result_str = str(expression.xreplace(substitutions))
101-
102-
return result_str
103-
101+
return '{'+', '.join(parsed_responses)+'}'
104102

105103
def sanitise_latex(response):
106104
response = "".join(response.split())

app/symbolic_comparison_preview_tests.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,24 @@ def test_latex_with_plus_minus(self):
9696
)
9797
result = preview_function(response, params)
9898
preview = result["preview"]
99-
assert preview.get("sympy") == ' plus_minus 3*i/sqrt(5)'
99+
assert preview.get("sympy") in {'{3*(sqrt(5)/5)*I, -3*sqrt(5)/5*I}', '{-3*sqrt(5)/5*I, 3*(sqrt(5)/5)*I}'}
100100
assert preview.get("latex") == r'\pm \frac{3}{\sqrt{5}} i'
101+
response = r"4 \pm \sqrt{6}}"
102+
params = Params(
103+
is_latex=True,
104+
simplify=False,
105+
complexNumbers=True,
106+
symbols={
107+
"plus_minus": {
108+
"latex": "$\\pm$",
109+
"aliases": ["pm", "+-"],
110+
},
111+
}
112+
)
113+
result = preview_function(response, params)
114+
preview = result["preview"]
115+
assert preview.get("sympy") in {'{sqrt(6) + 4, 4 - sqrt(6)}', '{4 - sqrt(6), sqrt(6) + 4}'}
116+
assert preview.get("latex") == r'4 \pm \sqrt{6}}'
101117

102118
def test_latex_conversion_preserves_default_symbols(self):
103119
response = "\\mu + x + 1"

0 commit comments

Comments
 (0)