8
8
extract_latex ,
9
9
SymbolDict ,
10
10
find_matching_parenthesis ,
11
+ create_expression_set ,
11
12
)
12
13
13
14
class Params (TypedDict ):
@@ -26,7 +27,7 @@ class Result(TypedDict):
26
27
preview : Preview
27
28
28
29
29
- def parse_latex (response : str , symbols : SymbolDict , simplify : bool ) -> str :
30
+ def parse_latex (response : str , symbols : SymbolDict , simplify : bool , parameters = None ) -> str :
30
31
"""Parse a LaTeX string to a sympy string while preserving custom symbols.
31
32
32
33
Args:
@@ -40,6 +41,9 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
40
41
Returns:
41
42
str: The expression in sympy syntax.
42
43
"""
44
+ if parameters is None :
45
+ parameters = dict ()
46
+
43
47
substitutions = {}
44
48
45
49
pm_placeholder = None
@@ -56,8 +60,10 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
56
60
57
61
if pm_placeholder is not None :
58
62
response = response .replace (r"\pm " , pm_placeholder )
63
+ substitutions [pm_placeholder ] = sympy .Symbol (pm_placeholder , commutative = False )
59
64
if mp_placeholder is not None :
60
65
response = response .replace (r"\mp " , mp_placeholder )
66
+ substitutions [mp_placeholder ] = sympy .Symbol (mp_placeholder , commutative = False )
61
67
62
68
for sympy_symbol_str in symbols :
63
69
symbol_str = symbols [sympy_symbol_str ]["latex" ]
@@ -73,8 +79,6 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
73
79
)
74
80
substitutions [latex_symbol ] = sympy .Symbol (sympy_symbol_str )
75
81
76
- substitutions .update ({r"\pm " : pm_placeholder , r"\mp " : mp_placeholder })
77
-
78
82
try :
79
83
expression = latex2sympy (response , substitutions )
80
84
if isinstance (expression , list ):
@@ -84,13 +88,20 @@ def parse_latex(response: str, symbols: SymbolDict, simplify : bool) -> str:
84
88
except Exception as e :
85
89
raise ValueError (str (e ))
86
90
87
- result_str = str (expression .xreplace (substitutions ))
88
- for ph in [(pm_placeholder , "plus_minus" ), (mp_placeholder , "minus_plus" )]:
89
- if ph [0 ] is not None :
90
- result_str = result_str .replace ("*" + ph [0 ]+ "*" , " " + ph [1 ]+ " " )
91
- result_str = result_str .replace (ph [0 ]+ "*" , " " + ph [1 ]+ " " )
92
- result_str = result_str .replace ("*" + ph [0 ], " " + ph [1 ]+ " " )
93
- result_str = result_str .replace (ph [0 ], " " + ph [1 ]+ " " )
91
+ if (pm_placeholder is not None ) or (mp_placeholder is not None ):
92
+ result_str_set = set ()
93
+ result_str = str (expression )
94
+ for ph in [(pm_placeholder , "plus_minus" ), (mp_placeholder , "minus_plus" )]:
95
+ if ph [0 ] is not None :
96
+ result_str = result_str .replace ("*" + ph [0 ]+ "*" , " " + ph [1 ]+ " " )
97
+ result_str = result_str .replace (ph [0 ]+ "*" , " " + ph [1 ]+ " " )
98
+ result_str = result_str .replace ("*" + ph [0 ], " " + ph [1 ]+ " " )
99
+ result_str = result_str .replace (ph [0 ], " " + ph [1 ]+ " " )
100
+ for expr in create_expression_set (result_str , parameters ):
101
+ result_str_set .add (expr )
102
+ result_str = '{' + ', ' .join (result_str_set )+ '}'
103
+ else :
104
+ result_str = str (expression .xreplace (substitutions ))
94
105
95
106
return result_str
96
107
0 commit comments