Skip to content

Commit 73ce5df

Browse files
Fixed bug in input symbol substitution
- Input symbols substitutions did not take elementary function names into account - Added exception input symbols handling so that "I" is treated as the imaginary constant if "complexNumbers" is set to true, regardless of if there is an input symbol with code "I" or not.
1 parent de81f33 commit 73ce5df

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

app/expression_utilities.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,17 @@ def substitute_input_symbols(exprs, params):
222222

223223
substitutions = [(expr, expr) for expr in params.get("reserved_keywords",[])]
224224

225+
if params.get("elementary_functions", False) is True:
226+
alias_substitutions = []
227+
for expr in exprs:
228+
for (name, alias_list) in elementary_functions_names+special_symbols_names:
229+
if name in expr:
230+
alias_substitutions += [(name, " "+name)]
231+
for alias in alias_list:
232+
if alias in expr:
233+
alias_substitutions += [(alias, " "+name)]
234+
substitutions += alias_substitutions
235+
225236
input_symbols = params.get("symbols",dict())
226237

227238
if "symbols" in params.keys():
@@ -497,12 +508,11 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass
497508
parse_expression function.
498509
'''
499510

511+
unsplittable_symbols = list(unsplittable_symbols)
500512
if "symbols" in params.keys():
501-
to_keep = []
502513
for symbol in params["symbols"].keys():
503514
if len(symbol) > 1:
504-
to_keep.append(symbol)
505-
unsplittable_symbols += tuple(to_keep)
515+
unsplittable_symbols.append(symbol)
506516

507517
if params.get("specialFunctions", False) is True:
508518
from sympy import beta, gamma, zeta
@@ -512,6 +522,12 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass
512522
zeta = Symbol("zeta")
513523
if params.get("complexNumbers", False) is True:
514524
from sympy import I
525+
# imaginary_constant_index = None
526+
# for (k, symbol) in enumerate(unsplittable_symbols):
527+
# if "I" == symbol[0]:
528+
# imaginary_constant_index = k
529+
# if imaginary_constant_index is not None:
530+
# unsplittable_symbols = unsplittable_symbols[0:imaginary_constant_index]+unsplittable_symbols[imaginary_constant_index+1:]
515531
else:
516532
I = Symbol("I")
517533
if params.get("elementary_functions", False) is True:
@@ -535,10 +551,10 @@ def create_sympy_parsing_params(params, unsplittable_symbols=tuple(), symbol_ass
535551
"E": E
536552
}
537553

538-
for symbol in unsplittable_symbols:
539-
symbol_dict.update({symbol: Symbol(symbol)})
554+
# for symbol in unsplittable_symbols:
555+
# symbol_dict.update({symbol: Symbol(symbol)})
540556

541-
symbol_dict.update(sympy_symbols(params.get("symbols", {})))
557+
symbol_dict.update(sympy_symbols(unsplittable_symbols))
542558

543559
strict_syntax = params.get("strict_syntax", True)
544560

app/symbolic_comparison_evaluation_tests.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,32 @@ def test_no_reserved_keywords_in_old_format_input_symbol_alternatives(self):
10751075
("5*exp(lambda*x)/(1+5*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", True, [], {}),
10761076
("6*exp(lambda*x)/(1+7*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", False, [], {}),
10771077
("c*exp(lambda*x)/(1+c*exp(lambda*x))", "c*exp(lambda*x)/(1+c*exp(lambda*x))", "diff(response,x)=lambda*response*(1-response)", True, [], {}),
1078-
("-A/r^2*cos(omega*t-k*r)+k*A/r*sin(omega*t-k*r)", "(-A/(r**2))*exp(I*(omega*t-k*r))*(1+I*k*r)", "re(response)=re(answer)", True, [], {"complexNumbers": True, "symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')"}),
1078+
("-A/r^2*cos(omega*t-k*r)+k*A/r*sin(omega*t-k*r)", "(-A/(r**2))*exp(i*(omega*t-k*r))*(1+i*k*r)", "re(response)=re(answer)", True, [],
1079+
{
1080+
"complexNumbers": True,
1081+
"symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')",
1082+
'symbols': {
1083+
'r': {'aliases': ['R'], 'latex': r'\(r\)'},
1084+
'A': {'aliases': ['a'], 'latex': r'\(A\)'},
1085+
'omega': {'aliases': ['OMEGA', 'Omega'], 'latex': r'\(\omega\)'},
1086+
'k': {'aliases': ['K'], 'latex': r'\(k\)'},
1087+
't': {'aliases': ['T'], 'latex': r'\(t\)'},
1088+
'I': {'aliases': ['i'], 'latex': r'\(i\)'},
1089+
}
1090+
}),
1091+
("-A/r^2*(cos(omega*t-kr)+I*sin(omega*t-kr))*(1+Ikr)", "(-A/(r**2))*exp(I*(omega*t-k*r))*(1+I*k*r)", "re(response)=re(answer)", True, [],
1092+
{
1093+
"complexNumbers": True,
1094+
"symbol_assumptions": "('k','real') ('r','real') ('omega','real') ('t','real') ('A','real')",
1095+
'symbols': {
1096+
'r': {'aliases': ['R'], 'latex': r'\(r\)'},
1097+
'A': {'aliases': ['a'], 'latex': r'\(A\)'},
1098+
'omega': {'aliases': ['OMEGA', 'Omega'], 'latex': r'\(\omega\)'},
1099+
'k': {'aliases': ['K'], 'latex': r'\(k\)'},
1100+
't': {'aliases': ['T'], 'latex': r'\(t\)'},
1101+
'I': {'aliases': ['i'], 'latex': r'\(i\)'},
1102+
}
1103+
}),
10791104
]
10801105
)
10811106
def test_criteria_based_comparison(self, response, answer, criteria, value, feedback_tags, additional_params):

0 commit comments

Comments
 (0)