Skip to content

Commit 7a4896c

Browse files
authored
Merge pull request #359 from GSam/master
symengine_wrapper.pyx: Calculate the zero-th derivative
2 parents 58d8678 + cb409e1 commit 7a4896c

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

symengine/lib/symengine_wrapper.pyx

+34-12
Original file line numberDiff line numberDiff line change
@@ -4063,21 +4063,43 @@ def module_cleanup():
40634063
import atexit
40644064
atexit.register(module_cleanup)
40654065

4066-
def diff(ex, *args):
4067-
ex = sympify(ex)
4068-
prev = 0
4066+
def diff(expr, *args):
4067+
cdef Basic ex = sympify(expr)
4068+
cdef Basic prev
40694069
cdef Basic b
40704070
cdef size_t i
4071-
for x in args:
4072-
b = sympify(x)
4073-
if isinstance(b, Integer):
4074-
i = int(b) - 1
4075-
for j in range(i):
4076-
ex = ex._diff(prev)
4071+
cdef size_t length = len(args)
4072+
4073+
if not length:
4074+
return ex
4075+
4076+
cdef size_t l = 0
4077+
cdef Basic cur_arg, next_arg
4078+
cur_arg = sympify(args[l])
4079+
4080+
while l < length:
4081+
if isinstance(cur_arg, Integer):
4082+
raise ValueError("Unexpected integer argument")
4083+
4084+
if l + 1 == length:
4085+
# No next argument, differentiate with no integer argument
4086+
return ex._diff(cur_arg)
4087+
4088+
next_arg = sympify(args[l + 1])
4089+
# Check if the next arg was derivative order
4090+
if isinstance(next_arg, Integer):
4091+
i = int(next_arg)
4092+
for _ in range(i):
4093+
ex = ex._diff(cur_arg)
4094+
l += 2
4095+
if l == length:
4096+
return ex
4097+
cur_arg = sympify(args[l])
40774098
else:
4078-
ex = ex._diff(b)
4079-
prev = b
4080-
return ex
4099+
ex = ex._diff(cur_arg)
4100+
l += 1
4101+
cur_arg = next_arg
4102+
40814103

40824104
def expand(x, deep=True):
40834105
return sympify(x).expand(deep)

symengine/tests/test_functions.py

+17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
loggamma, beta, polygamma, digamma, trigamma, sign, floor, ceiling, conjugate,
66
nan, Float, UnevaluatedExpr
77
)
8+
from symengine.utilities import raises
89

910
import unittest
1011

@@ -62,6 +63,11 @@ def test_derivative():
6263
assert f.diff(y) == 0
6364
assert f.diff(x).args == (f, x)
6465
assert f.diff(x).diff(x).args == (f, x, x)
66+
assert f.diff(x, 0) == f
67+
assert f.diff(x, 0) == Derivative(function_symbol("f", x), x, 0)
68+
raises(ValueError, lambda: f.diff(0))
69+
raises(ValueError, lambda: f.diff(x, 0, 0))
70+
raises(ValueError, lambda: f.diff(x, y, 0, 0, x))
6571

6672
g = function_symbol("f", y)
6773
assert g.diff(x) == 0
@@ -84,6 +90,17 @@ def test_derivative():
8490
assert g == fxy.diff(y, 1, x, 2)
8591
assert g == fxy.diff(y, x, 2)
8692

93+
h = Derivative(Function("f")(x, y), x, 0, y, 1)
94+
assert h == fxy.diff(x, 0, y)
95+
assert h == fxy.diff(y, x, 0)
96+
97+
i = Derivative(Function("f")(x, y), x, 0, y, 1, x, 1)
98+
assert i == fxy.diff(x, 0, y, x, 1)
99+
assert i == fxy.diff(x, 0, y, x)
100+
assert i == fxy.diff(y, x)
101+
assert i == fxy.diff(y, 1, x, 1)
102+
assert i == fxy.diff(y, 1, x)
103+
87104

88105
def test_abs():
89106
x = Symbol("x")

0 commit comments

Comments
 (0)