Skip to content

Commit e3f4bd2

Browse files
authored
Merge pull request #126 from ShikharJ/MinMax
Wrapped Min and Max Functions
2 parents 3b69a22 + cb5be46 commit e3f4bd2

File tree

4 files changed

+121
-3
lines changed

4 files changed

+121
-3
lines changed

symengine/lib/symengine.pxd

+12
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ cdef extern from "<symengine/symengine_rcp.h>" namespace "SymEngine":
143143
RCP[const FunctionSymbol] rcp_static_cast_FunctionSymbol "SymEngine::rcp_static_cast<const SymEngine::FunctionSymbol>"(RCP[const Basic] &b) nogil
144144
RCP[const FunctionWrapper] rcp_static_cast_FunctionWrapper "SymEngine::rcp_static_cast<const SymEngine::FunctionWrapper>"(RCP[const Basic] &b) nogil
145145
RCP[const Abs] rcp_static_cast_Abs "SymEngine::rcp_static_cast<const SymEngine::Abs>"(RCP[const Basic] &b) nogil
146+
RCP[const Max] rcp_static_cast_Max "SymEngine::rcp_static_cast<const SymEngine::Max>"(RCP[const Basic] &b) nogil
147+
RCP[const Min] rcp_static_cast_Min "SymEngine::rcp_static_cast<const SymEngine::Min>"(RCP[const Basic] &b) nogil
146148
RCP[const Gamma] rcp_static_cast_Gamma "SymEngine::rcp_static_cast<const SymEngine::Gamma>"(RCP[const Basic] &b) nogil
147149
RCP[const Derivative] rcp_static_cast_Derivative "SymEngine::rcp_static_cast<const SymEngine::Derivative>"(RCP[const Basic] &b) nogil
148150
RCP[const Subs] rcp_static_cast_Subs "SymEngine::rcp_static_cast<const SymEngine::Subs>"(RCP[const Basic] &b) nogil
@@ -251,6 +253,8 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
251253
bool is_a_ASech "SymEngine::is_a<SymEngine::ASech>"(const Basic &b) nogil
252254
bool is_a_FunctionSymbol "SymEngine::is_a<SymEngine::FunctionSymbol>"(const Basic &b) nogil
253255
bool is_a_Abs "SymEngine::is_a<SymEngine::Abs>"(const Basic &b) nogil
256+
bool is_a_Max "SymEngine::is_a<SymEngine::Max>"(const Basic &b) nogil
257+
bool is_a_Min "SymEngine::is_a<SymEngine::Min>"(const Basic &b) nogil
254258
bool is_a_Gamma "SymEngine::is_a<SymEngine::Gamma>"(const Basic &b) nogil
255259
bool is_a_Derivative "SymEngine::is_a<SymEngine::Derivative>"(const Basic &b) nogil
256260
bool is_a_Subs "SymEngine::is_a<SymEngine::Subs>"(const Basic &b) nogil
@@ -433,6 +437,8 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
433437
cdef RCP[const Basic] asech(RCP[const Basic] &arg) nogil except+
434438
cdef RCP[const Basic] function_symbol(string name, const vec_basic &arg) nogil except+
435439
cdef RCP[const Basic] abs(RCP[const Basic] &arg) nogil except+
440+
cdef RCP[const Basic] max(const vec_basic &arg) nogil except+
441+
cdef RCP[const Basic] min(const vec_basic &arg) nogil except+
436442
cdef RCP[const Basic] gamma(RCP[const Basic] &arg) nogil except+
437443
cdef RCP[const Basic] atan2(RCP[const Basic] &num, RCP[const Basic] &den) nogil except+
438444

@@ -539,6 +545,12 @@ cdef extern from "<symengine/functions.h>" namespace "SymEngine":
539545
cdef cppclass Abs(Function):
540546
RCP[const Basic] get_arg() nogil
541547

548+
cdef cppclass Max(Function):
549+
pass
550+
551+
cdef cppclass Min(Function):
552+
pass
553+
542554
cdef cppclass Gamma(Function):
543555
pass
544556

symengine/lib/symengine_wrapper.pyx

+74
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ cdef c2py(RCP[const symengine.Basic] o):
4545
r = FunctionSymbol.__new__(FunctionSymbol)
4646
elif (symengine.is_a_Abs(deref(o))):
4747
r = Abs.__new__(Abs)
48+
elif (symengine.is_a_Max(deref(o))):
49+
r = Max.__new__(Max)
50+
elif (symengine.is_a_Min(deref(o))):
51+
r = Min.__new__(Min)
4852
elif (symengine.is_a_Gamma(deref(o))):
4953
r = Gamma.__new__(Gamma)
5054
elif (symengine.is_a_Derivative(deref(o))):
@@ -208,6 +212,10 @@ def sympy2symengine(a, raise_error=False):
208212
return log(a.args[0])
209213
elif isinstance(a, sympy.Abs):
210214
return abs(sympy2symengine(a.args[0], raise_error))
215+
elif isinstance(a, sympy.Max):
216+
return _max(*a.args)
217+
elif isinstance(a, sympy.Min):
218+
return _min(*a.args)
211219
elif isinstance(a, sympy.gamma):
212220
return gamma(a.args[0])
213221
elif isinstance(a, sympy.Derivative):
@@ -659,6 +667,20 @@ cdef class Basic(object):
659667
def has(self, *symbols):
660668
return any([has_symbol(self, symbol) for symbol in symbols])
661669

670+
def args_as_sage(Basic self):
671+
cdef symengine.vec_basic Y = deref(self.thisptr).get_args()
672+
s = []
673+
for i in range(Y.size()):
674+
s.append(c2py(<RCP[const symengine.Basic]>(Y[i]))._sage_())
675+
return s
676+
677+
def args_as_sympy(Basic self):
678+
cdef symengine.vec_basic Y = deref(self.thisptr).get_args()
679+
s = []
680+
for i in range(Y.size()):
681+
s.append(c2py(<RCP[const symengine.Basic]>(Y[i]))._sympy_())
682+
return s
683+
662684
def series(ex, x=None, x0=0, n=6, as_deg_coef_pair=False):
663685
# TODO: check for x0 an infinity, see sympy/core/expr.py
664686
# TODO: nonzero x0
@@ -1398,6 +1420,42 @@ cdef class Abs(Function):
13981420
return abs(arg)
13991421

14001422

1423+
class Max(Function):
1424+
1425+
def __new__(cls, *args):
1426+
if not args:
1427+
return super(Max, cls).__new__(cls)
1428+
return _max(*args)
1429+
1430+
def _sympy_(self):
1431+
import sympy
1432+
s = self.args_as_sympy()
1433+
return sympy.Max(*s)
1434+
1435+
def _sage_(self):
1436+
import sage.all as sage
1437+
s = self.args_as_sage()
1438+
return sage.max(*s)
1439+
1440+
1441+
class Min(Function):
1442+
1443+
def __new__(cls, *args):
1444+
if not args:
1445+
return super(Min, cls).__new__(cls)
1446+
return _min(*args)
1447+
1448+
def _sympy_(self):
1449+
import sympy
1450+
s = self.args_as_sympy()
1451+
return sympy.Min(*s)
1452+
1453+
def _sage_(self):
1454+
import sage.all as sage
1455+
s = self.args_as_sage()
1456+
return sage.min(*s)
1457+
1458+
14011459
cdef class Derivative(Basic):
14021460

14031461
@property
@@ -2342,6 +2400,22 @@ def log(x, y = None):
23422400
cdef Basic Y = _sympify(y)
23432401
return c2py(symengine.log(X.thisptr, Y.thisptr))
23442402

2403+
def _max(*args):
2404+
cdef symengine.vec_basic v
2405+
cdef Basic e_
2406+
for e in args:
2407+
e_ = sympify(e)
2408+
v.push_back(e_.thisptr)
2409+
return c2py(symengine.max(v))
2410+
2411+
def _min(*args):
2412+
cdef symengine.vec_basic v
2413+
cdef Basic e_
2414+
for e in args:
2415+
e_ = sympify(e)
2416+
v.push_back(e_.thisptr)
2417+
return c2py(symengine.min(v))
2418+
23452419
def gamma(x):
23462420
cdef Basic X = _sympify(x)
23472421
return c2py(symengine.gamma(X.thisptr))

symengine/sympy_compat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
ImmutableDenseMatrix, DenseMatrix, Matrix, Derivative, exp,
77
nextprime, mod_inverse, primitive_root, Lambdify as lambdify,
88
symarray, diff, eye, diag, ones, zeros, expand, Subs,
9-
FunctionSymbol as AppliedUndef)
9+
FunctionSymbol as AppliedUndef, Max, Min)
1010
from types import ModuleType
1111
import sys
1212

symengine/tests/test_sympy_compat.py

+34-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from symengine.sympy_compat import (Integer, Rational, S, Basic, Add, Mul,
2-
Pow, symbols, Symbol, log, sin, sech, csch, zeros, atan2, Number, Float,
3-
symengine)
2+
Pow, symbols, Symbol, log, sin, cos, sech, csch, zeros, atan2, Number, Float,
3+
symengine, Min, Max)
44
from symengine.utilities import raises
55

66

@@ -86,6 +86,38 @@ def test_Pow():
8686
assert isinstance(i, Basic)
8787

8888

89+
def test_Max():
90+
x = Symbol("x")
91+
y = Symbol("y")
92+
z = Symbol("z")
93+
assert Max(Integer(6)/3, 1) == 2
94+
assert Max(-2, 2) == 2
95+
assert Max(2, 2) == 2
96+
assert Max(0.2, 0.3) == 0.3
97+
assert Max(x, x) == x
98+
assert Max(x, y) == Max(y, x)
99+
assert Max(x, y, z) == Max(z, y, x)
100+
assert Max(x, Max(y, z)) == Max(z, y, x)
101+
assert Max(1000, 100, -100, x, y, z) == Max(x, y, z, 1000)
102+
assert Max(cos(x), sin(x)) == Max(sin(x), cos(x))
103+
104+
105+
def test_Min():
106+
x = Symbol("x")
107+
y = Symbol("y")
108+
z = Symbol("z")
109+
assert Min(Integer(6)/3, 1) == 1
110+
assert Min(-2, 2) == -2
111+
assert Min(2, 2) == 2
112+
assert Min(0.2, 0.3) == 0.2
113+
assert Min(x, x) == x
114+
assert Min(x, y) == Min(y, x)
115+
assert Min(x, y, z) == Min(z, y, x)
116+
assert Min(x, Min(y, z)) == Min(z, y, x)
117+
assert Min(1000, 100, -100, x, y, z) == Min(x, y, z, -100)
118+
assert Min(cos(x), sin(x)) == Min(cos(x), sin(x))
119+
120+
89121
def test_sin():
90122
x = symbols("x")
91123
i = sin(0)

0 commit comments

Comments
 (0)