Skip to content

Commit b66fbf3

Browse files
committed
refactor arithmetic power
1 parent 2aa00f0 commit b66fbf3

File tree

4 files changed

+199
-40
lines changed

4 files changed

+199
-40
lines changed

mathics/builtin/arithfns/basic.py

+62-21
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
88
"""
99

10+
1011
from mathics.builtin.arithmetic import create_infix
1112
from mathics.builtin.base import (
1213
BinaryOperator,
@@ -15,6 +16,10 @@
1516
PrefixOperator,
1617
SympyFunction,
1718
)
19+
20+
import sympy
21+
22+
1823
from mathics.core.atoms import (
1924
Complex,
2025
Integer,
@@ -45,7 +50,6 @@
4550
Symbol,
4651
SymbolDivide,
4752
SymbolHoldForm,
48-
SymbolNull,
4953
SymbolPower,
5054
SymbolTimes,
5155
)
@@ -56,10 +60,17 @@
5660
SymbolInfix,
5761
SymbolLeft,
5862
SymbolMinus,
63+
SymbolOverflow,
5964
SymbolPattern,
60-
SymbolSequence,
6165
)
62-
from mathics.eval.arithmetic import eval_Plus, eval_Times
66+
from mathics.eval.arithmetic import (
67+
associate_powers,
68+
eval_Exponential,
69+
eval_Plus,
70+
eval_Power_inexact,
71+
eval_Power_number,
72+
eval_Times,
73+
)
6374
from mathics.eval.nevaluator import eval_N
6475
from mathics.eval.numerify import numerify
6576

@@ -535,15 +546,15 @@ class Power(BinaryOperator, MPMathFunction):
535546
# Remember to up sympy doc link when this is corrected
536547
sympy_name = "Pow"
537548

549+
def eval_exp(self, x, evaluation):
550+
"Power[E, x]"
551+
return eval_Exponential(x)
552+
538553
def eval_check(self, x, y, evaluation):
539554
"Power[x_, y_]"
540-
541-
# Power uses MPMathFunction but does some error checking first
542-
if isinstance(x, Number) and x.is_zero:
543-
if isinstance(y, Number):
544-
y_err = y
545-
else:
546-
y_err = eval_N(y, evaluation)
555+
# if x is zero
556+
if x.is_zero:
557+
y_err = y if isinstance(y, Number) else eval_N(y, evaluation)
547558
if isinstance(y_err, Number):
548559
py_y = y_err.round_to_float(permit_complex=True).real
549560
if py_y > 0:
@@ -557,17 +568,47 @@ def eval_check(self, x, y, evaluation):
557568
evaluation.message(
558569
"Power", "infy", Expression(SymbolPower, x, y_err)
559570
)
560-
return SymbolComplexInfinity
561-
if isinstance(x, Complex) and x.real.is_zero:
562-
yhalf = Expression(SymbolTimes, y, RationalOneHalf)
563-
factor = self.eval(Expression(SymbolSequence, x.imag, y), evaluation)
564-
return Expression(
565-
SymbolTimes, factor, Expression(SymbolPower, IntegerM1, yhalf)
566-
)
567-
568-
result = self.eval(Expression(SymbolSequence, x, y), evaluation)
569-
if result is None or result != SymbolNull:
570-
return result
571+
return SymbolComplexInfinity
572+
573+
# If x and y are inexact numbers, use the numerical function
574+
575+
if x.is_inexact() and y.is_inexact():
576+
try:
577+
return eval_Power_inexact(x, y)
578+
except OverflowError:
579+
evaluation.message("General", "ovfl")
580+
return Expression(SymbolOverflow)
581+
582+
# Tries to associate powers a^b^c-> a^(b*c)
583+
assoc = associate_powers(x, y)
584+
if not assoc.has_form("Power", 2):
585+
return assoc
586+
587+
assoc = numerify(assoc, evaluation)
588+
x, y = assoc.elements
589+
# If x and y are numbers
590+
if isinstance(x, Number) and isinstance(y, Number):
591+
try:
592+
return eval_Power_number(x, y)
593+
except OverflowError:
594+
evaluation.message("General", "ovfl")
595+
return Expression(SymbolOverflow)
596+
597+
# if x or y are inexact, leave the expression
598+
# as it is:
599+
if x.is_inexact() or y.is_inexact():
600+
return assoc
601+
602+
# Finally, try to convert to sympy
603+
base_sp, exp_sp = x.to_sympy(), y.to_sympy()
604+
if base_sp is None or exp_sp is None:
605+
# If base or exp can not be converted to sympy,
606+
# returns the result of applying the associative
607+
# rule.
608+
return assoc
609+
610+
result = from_sympy(sympy.Pow(base_sp, exp_sp))
611+
return result.evaluate_elements(evaluation)
571612

572613

573614
class Sqrt(SympyFunction):

mathics/eval/arithmetic.py

+102-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# -*- coding: utf-8 -*-
22

33
"""
4-
arithmetic-related evaluation functions.
4+
helper functions for arithmetic evaluation, which do not
5+
depends on the evaluation context. Conversions to Sympy are
6+
used just as a last resource.
57
68
Many of these do do depend on the evaluation context. Conversions to Sympy are
79
used just as a last resource.
@@ -320,6 +322,28 @@ def eval_complex_sign(n: BaseElement) -> Optional[BaseElement]:
320322
return sign or eval_complex_sign(expr)
321323

322324

325+
def eval_Sign_number(n: Number) -> Number:
326+
"""
327+
Evals the absolute value of a number.
328+
"""
329+
if n.is_zero:
330+
return Integer0
331+
if isinstance(n, (Integer, Rational, Real)):
332+
return Integer1 if n.value > 0 else IntegerM1
333+
if isinstance(n, Complex):
334+
abs_sq = eval_add_numbers(
335+
*(eval_multiply_numbers(x, x) for x in (n.real, n.imag))
336+
)
337+
criteria = eval_add_numbers(abs_sq, IntegerM1)
338+
if test_zero_arithmetic_expr(criteria):
339+
return n
340+
if n.is_inexact():
341+
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RealM0p5))
342+
if test_zero_arithmetic_expr(criteria, numeric=True):
343+
return n
344+
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RationalMOneHalf))
345+
346+
323347
def eval_mpmath_function(
324348
mpmath_function: Callable, *args: Number, prec: Optional[int] = None
325349
) -> Optional[Number]:
@@ -347,6 +371,31 @@ def eval_mpmath_function(
347371
return call_mpmath(mpmath_function, tuple(mpmath_args), prec)
348372

349373

374+
def eval_Exponential(exp: BaseElement) -> BaseElement:
375+
"""
376+
Eval E^exp
377+
"""
378+
# If both base and exponent are exact quantities,
379+
# use sympy.
380+
381+
if not exp.is_inexact():
382+
exp_sp = exp.to_sympy()
383+
if exp_sp is None:
384+
return None
385+
return from_sympy(sympy.Exp(exp_sp))
386+
387+
prec = exp.get_precision()
388+
if prec is not None:
389+
if exp.is_machine_precision():
390+
number = mpmath.exp(exp.to_mpmath())
391+
result = from_mpmath(number)
392+
return result
393+
else:
394+
with mpmath.workprec(prec):
395+
number = mpmath.exp(exp.to_mpmath())
396+
return from_mpmath(number, prec)
397+
398+
350399
def eval_Plus(*items: BaseElement) -> BaseElement:
351400
"evaluate Plus for general elements"
352401
numbers, items_tuple = segregate_numbers_from_sorted_list(*items)
@@ -645,8 +694,58 @@ def eval_Times(*items: BaseElement) -> BaseElement:
645694
)
646695

647696

697+
def associate_powers(expr: BaseElement, power: BaseElement = Integer1) -> BaseElement:
698+
"""
699+
base^a^b^c^...^power -> base^(a*b*c*...power)
700+
provided one of the following cases
701+
* `a`, `b`, ... `power` are all integer numbers
702+
* `a`, `b`,... are Rational/Real number with absolute value <=1,
703+
and the other powers are not integer numbers.
704+
* `a` is not a Rational/Real number, and b, c, ... power are all
705+
integer numbers.
706+
"""
707+
powers = []
708+
base = expr
709+
if power is not Integer1:
710+
powers.append(power)
711+
712+
while base.has_form("Power", 2):
713+
previous_base, outer_power = base, power
714+
base, power = base.elements
715+
if len(powers) == 0:
716+
if power is not Integer1:
717+
powers.append(power)
718+
continue
719+
if power is IntegerM1:
720+
powers.append(power)
721+
continue
722+
if isinstance(power, (Rational, Real)):
723+
if abs(power.value) < 1:
724+
powers.append(power)
725+
continue
726+
# power is not rational/real and outer_power is integer,
727+
elif isinstance(outer_power, Integer):
728+
if power is not Integer1:
729+
powers.append(power)
730+
if isinstance(power, Integer):
731+
continue
732+
else:
733+
break
734+
# in any other case, use the previous base and
735+
# exit the loop
736+
base = previous_base
737+
break
738+
739+
if len(powers) == 0:
740+
return base
741+
elif len(powers) == 1:
742+
return Expression(SymbolPower, base, powers[0])
743+
result = Expression(SymbolPower, base, Expression(SymbolTimes, *powers))
744+
return result
745+
746+
648747
def eval_add_numbers(
649-
*numbers: Number,
748+
*numbers: List[Number],
650749
) -> BaseElement:
651750
"""
652751
Add the elements in ``numbers``.
@@ -693,7 +792,7 @@ def eval_inverse_number(n: Number) -> Number:
693792
return eval_Power_number(n, IntegerM1)
694793

695794

696-
def eval_multiply_numbers(*numbers: Number) -> Number:
795+
def eval_multiply_numbers(*numbers: Number) -> BaseElement:
697796
"""
698797
Multiply the elements in ``numbers``.
699798
"""

test/builtin/arithmetic/test_basic.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
197197
("I^(2/3)", "(-1) ^ (1 / 3)", None),
198198
# In WMA, the next test would return ``-(-I)^(2/3)``
199199
# which is less compact and elegant...
200-
# ("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
200+
("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
201201
("(2+3I)^3", "-46 + 9 I", None),
202202
("(1.+3. I)^.6", "1.46069 + 1.35921 I", None),
203203
("3^(1+2 I)", "3 ^ (1 + 2 I)", None),
@@ -208,15 +208,15 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
208208
# sympy, which produces the result
209209
("(3/Pi)^(-I)", "(3 / Pi) ^ (-I)", None),
210210
# Association rules
211-
# ('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
211+
('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
212212
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
213213
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
214214
("(a^2)^(1/2)", "Sqrt[a ^ 2]", None),
215215
("(a^(1/2))^2", "a", None),
216216
("(a^(1/2))^2", "a", None),
217217
("(a^(3/2))^3.", "(a ^ (3 / 2)) ^ 3.", None),
218-
# ("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
219-
# ("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
218+
("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
219+
("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
220220
("(a^(1.3))^3.", "(a ^ 1.3) ^ 3.", None),
221221
# Exponentials involving expressions
222222
("(a^(p-2 q))^3", "a ^ (3 p - 6 q)", None),

test/format/test_format.py

+31-12
Original file line numberDiff line numberDiff line change
@@ -456,34 +456,53 @@
456456
"Sqrt[1/(1+1/(1+1/a))]": {
457457
"msg": "SqrtBox",
458458
"text": {
459-
"System`StandardForm": "Sqrt[1 / (1+1 / (1+1 / a))]",
460-
"System`TraditionalForm": "Sqrt[1 / (1+1 / (1+1 / a))]",
461-
"System`InputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]",
462-
"System`OutputForm": "Sqrt[1 / (1 + 1 / (1 + 1 / a))]",
459+
"System`StandardForm": "1 / Sqrt[1+1 / (1+1 / a)]",
460+
"System`TraditionalForm": "1 / Sqrt[1+1 / (1+1 / a)]",
461+
"System`InputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]",
462+
"System`OutputForm": "1 / Sqrt[1 + 1 / (1 + 1 / a)]",
463463
},
464464
"mathml": {
465465
"System`StandardForm": (
466-
"<msqrt> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow></mfrac> </msqrt>",
466+
(
467+
r"<mfrac><mn>1</mn> <msqrt> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> "
468+
r"<mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow> "
469+
r"</msqrt></mfrac>"
470+
),
467471
"Fragile!",
468472
),
469473
"System`TraditionalForm": (
470-
"<msqrt> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow></mfrac> </msqrt>",
474+
(
475+
r"<mfrac><mn>1</mn> <msqrt> <mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> "
476+
r"<mrow><mn>1</mn> <mo>+</mo> <mfrac><mn>1</mn> <mi>a</mi></mfrac></mrow></mfrac></mrow> "
477+
r"</msqrt></mfrac>"
478+
),
471479
"Fragile!",
472480
),
473481
"System`InputForm": (
474-
"<mrow><mi>Sqrt</mi> <mo>[</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>)</mo></mrow></mrow> <mo>]</mo></mrow>",
482+
(
483+
r"<mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mi>Sqrt</mi> <mo>[</mo> "
484+
r"<mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>&nbsp;/&nbsp;</mtext> "
485+
r"<mrow><mo>(</mo> <mrow><mtext>1</mtext> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mtext>1</mtext> <mtext>"
486+
r"&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>]</mo></mrow></mrow>"
487+
),
475488
"Fragile!",
476489
),
477490
"System`OutputForm": (
478-
"<mrow><mi>Sqrt</mi> <mo>[</mo> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>)</mo></mrow></mrow> <mo>]</mo></mrow>",
491+
(
492+
r"<mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> <mrow><mi>Sqrt</mi> <mo>["
493+
r"</mo> <mrow><mn>1</mn> <mtext>&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> "
494+
r"<mtext>&nbsp;/&nbsp;</mtext> <mrow><mo>(</mo> <mrow><mn>1</mn> <mtext>"
495+
r"&nbsp;+&nbsp;</mtext> <mrow><mn>1</mn> <mtext>&nbsp;/&nbsp;</mtext> "
496+
r"<mi>a</mi></mrow></mrow> <mo>)</mo></mrow></mrow></mrow> <mo>]</mo></mrow></mrow>"
497+
),
479498
"Fragile!",
480499
),
481500
},
482501
"latex": {
483-
"System`StandardForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}",
484-
"System`TraditionalForm": "\\sqrt{\\frac{1}{1+\\frac{1}{1+\\frac{1}{a}}}}",
485-
"System`InputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]",
486-
"System`OutputForm": "\\text{Sqrt}\\left[1\\text{ / }\\left(1\\text{ + }1\\text{ / }\\left(1\\text{ + }1\\text{ / }a\\right)\\right)\\right]",
502+
"System`StandardForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}",
503+
"System`TraditionalForm": "\\frac{1}{\\sqrt{1+\\frac{1}{1+\\frac{1}{a}}}}",
504+
"System`InputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]",
505+
"System`OutputForm": r"1\text{ / }\text{Sqrt}\left[1\text{ + }1\text{ / }\left(1\text{ + }1\text{ / }a\right)\right]",
487506
},
488507
},
489508
# Grids, arrays and matrices

0 commit comments

Comments
 (0)