Skip to content

Commit 2023515

Browse files
Changed preview sympy code generation for responses in latex format with plus_minus
1 parent dcb4dbb commit 2023515

File tree

3 files changed

+48
-12
lines changed

3 files changed

+48
-12
lines changed

app/expression_utilities.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,11 @@ def create_expression_set(exprs, params):
7878
expr = expr.replace(params["minus_plus"], "minus_plus")
7979

8080
if ("plus_minus" in expr) or ("minus_plus" in expr):
81-
expr_set.add(expr.replace("plus_minus", "+").replace("minus_plus", "-"))
82-
expr_set.add(expr.replace("plus_minus", "-").replace("minus_plus", "+"))
81+
for pm_mp_ops in [("+","-"),("-","+")]:
82+
expr_string = expr.replace("plus_minus", pm_mp_ops[0]).replace("minus_plus", pm_mp_ops[1]).strip()
83+
while expr_string[0] == "+":
84+
expr_string = expr_string[1:]
85+
expr_set.add(expr_string.strip())
8386
else:
8487
expr_set.add(expr)
8588

app/preview_utilities.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
extract_latex,
99
SymbolDict,
1010
find_matching_parenthesis,
11+
create_expression_set,
1112
)
1213

1314
class Params(TypedDict):
@@ -26,7 +27,7 @@ class Result(TypedDict):
2627
preview: Preview
2728

2829

29-
def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
30+
def parse_latex(response: str, symbols: SymbolDict, simplify : bool, parameters=None) -> str:
3031
"""Parse a LaTeX string to a sympy string while preserving custom symbols.
3132
3233
Args:
@@ -40,6 +41,9 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
4041
Returns:
4142
str: The expression in sympy syntax.
4243
"""
44+
if parameters is None:
45+
parameters = dict()
46+
4347
substitutions = {}
4448

4549
pm_placeholder = None
@@ -56,8 +60,10 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
5660

5761
if pm_placeholder is not None:
5862
response = response.replace(r"\pm ", pm_placeholder)
63+
substitutions[pm_placeholder] = sympy.Symbol(pm_placeholder, commutative=False)
5964
if mp_placeholder is not None:
6065
response = response.replace(r"\mp ", mp_placeholder)
66+
substitutions[mp_placeholder] = sympy.Symbol(mp_placeholder, commutative=False)
6167

6268
for sympy_symbol_str in symbols:
6369
symbol_str = symbols[sympy_symbol_str]["latex"]
@@ -73,8 +79,6 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
7379
)
7480
substitutions[latex_symbol] = sympy.Symbol(sympy_symbol_str)
7581

76-
substitutions.update({r"\pm ": pm_placeholder, r"\mp ": mp_placeholder})
77-
7882
try:
7983
expression = latex2sympy(response, substitutions)
8084
if isinstance(expression, list):
@@ -84,13 +88,20 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
8488
except Exception as e:
8589
raise ValueError(str(e))
8690

87-
result_str = str(expression.xreplace(substitutions))
88-
for ph in [(pm_placeholder, "plus_minus"), (mp_placeholder, "minus_plus")]:
89-
if ph[0] is not None:
90-
result_str = result_str.replace("*"+ph[0]+"*", " "+ph[1]+" ")
91-
result_str = result_str.replace(ph[0]+"*", " "+ph[1]+" ")
92-
result_str = result_str.replace("*"+ph[0], " "+ph[1]+" ")
93-
result_str = result_str.replace(ph[0], " "+ph[1]+" ")
91+
if (pm_placeholder is not None) or (mp_placeholder is not None):
92+
result_str_set = set()
93+
result_str = str(expression)
94+
for ph in [(pm_placeholder, "plus_minus"), (mp_placeholder, "minus_plus")]:
95+
if ph[0] is not None:
96+
result_str = result_str.replace("*"+ph[0]+"*", " "+ph[1]+" ")
97+
result_str = result_str.replace(ph[0]+"*", " "+ph[1]+" ")
98+
result_str = result_str.replace("*"+ph[0], " "+ph[1]+" ")
99+
result_str = result_str.replace(ph[0], " "+ph[1]+" ")
100+
for expr in create_expression_set(result_str, parameters):
101+
result_str_set.add(expr)
102+
result_str = '{'+', '.join(result_str_set)+'}'
103+
else:
104+
result_str = str(expression.xreplace(substitutions))
94105

95106
return result_str
96107

app/symbolic_comparison_preview_tests.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,28 @@ def test_sympy_with_equality_symbol(self):
7777
preview = result["preview"]
7878
assert preview.get("latex") == "\\frac{x^{2} + x + x}{x} = 1"
7979

80+
def test_latex_with_plus_minus(self):
81+
response = r"\pm \frac{3}{\sqrt{5}} i"
82+
params = Params(
83+
is_latex=True,
84+
simplify=False,
85+
complexNumbers=True,
86+
symbols={
87+
"I": {
88+
"latex": "$i$",
89+
"aliases": ["i"],
90+
},
91+
"plus_minus": {
92+
"latex": "$\\pm$",
93+
"aliases": ["pm", "+-"],
94+
},
95+
}
96+
)
97+
result = preview_function(response, params)
98+
preview = result["preview"]
99+
assert preview.get("sympy") in ['{3*i/sqrt(5), - 3*i/sqrt(5)}', '{- 3*i/sqrt(5), 3*i/sqrt(5)}']
100+
assert preview.get("latex") == r'\pm \frac{3}{\sqrt{5}} i'
101+
80102
def test_latex_conversion_preserves_default_symbols(self):
81103
response = "\\mu + x + 1"
82104
params = Params(is_latex=True, simplify=False)

0 commit comments

Comments
 (0)