Skip to content

Commit 42cf370

Browse files
committed
Refactor and use cdef
1 parent d527738 commit 42cf370

File tree

2 files changed

+27
-38
lines changed

2 files changed

+27
-38
lines changed

symengine/lib/symengine_wrapper.pyx

+23-38
Original file line numberDiff line numberDiff line change
@@ -4054,58 +4054,43 @@ def module_cleanup():
40544054
import atexit
40554055
atexit.register(module_cleanup)
40564056

4057-
def diff(ex, *args):
4058-
ex = sympify(ex)
4059-
prev = None
4057+
def diff(expr, *args):
4058+
cdef Basic ex = sympify(expr)
4059+
cdef Basic prev
40604060
cdef Basic b
40614061
cdef size_t i
4062-
length = len(args)
4062+
cdef size_t length = len(args)
40634063

40644064
if not length:
40654065
return ex
40664066

4067-
l = 0
4068-
x = args[l]
4069-
b = sympify(x)
4070-
l += 1
4067+
cdef size_t l = 0
4068+
cdef Basic cur_arg, next_arg
4069+
cur_arg = sympify(args[l])
40714070

4072-
while l <= length:
4073-
# Assume symbol 'x' or 'y' currently in b
4074-
# Pointer to next arg l is either derivative order or a separate symbol
4071+
while l < length:
4072+
if isinstance(cur_arg, Integer):
4073+
raise ValueError("Unexpected integer argument")
40754074

4076-
prev = b
4077-
4078-
if l == length:
4075+
if l + 1 == length:
40794076
# No next argument, differentiate with no integer argument
4080-
if isinstance(b, Integer):
4081-
raise ValueError("Unexpected integer argument")
4082-
ex = ex._diff(b)
4083-
break
4077+
return ex._diff(cur_arg)
40844078

4085-
x = args[l]
4086-
b = sympify(x)
4079+
next_arg = sympify(args[l + 1])
40874080
# Check if the next arg was derivative order
4088-
if isinstance(b, Integer):
4089-
i = int(b)
4090-
for j in range(i):
4091-
ex = ex._diff(prev)
4092-
4093-
# Move forward to point at next symbol
4094-
l += 1
4081+
if isinstance(next_arg, Integer):
4082+
i = int(next_arg)
4083+
for _ in range(i):
4084+
ex = ex._diff(cur_arg)
4085+
l += 2
40954086
if l == length:
4096-
break
4097-
4098-
x = args[l]
4099-
b = sympify(x)
4100-
if isinstance(b, Integer):
4101-
raise ValueError("Unexpected double integer argument")
4087+
return ex
4088+
cur_arg = sympify(args[l])
41024089
else:
4103-
# Separate symbol and no derivative order, differentiate now
4104-
ex = ex._diff(prev)
4105-
4106-
l += 1
4090+
ex = ex._diff(cur_arg)
4091+
l += 1
4092+
cur_arg = next_arg
41074093

4108-
return ex
41094094

41104095
def expand(x, deep=True):
41114096
return sympify(x).expand(deep)

symengine/tests/test_functions.py

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
loggamma, beta, polygamma, digamma, trigamma, sign, floor, ceiling, conjugate,
66
nan, Float, UnevaluatedExpr
77
)
8+
from symengine.utilities import raises
89

910
import unittest
1011

@@ -64,6 +65,9 @@ def test_derivative():
6465
assert f.diff(x).diff(x).args == (f, x, x)
6566
assert f.diff(x, 0) == f
6667
assert f.diff(x, 0) == Derivative(function_symbol("f", x), x, 0)
68+
raises(ValueError, lambda: f.diff(0))
69+
raises(ValueError, lambda: f.diff(x, 0, 0))
70+
raises(ValueError, lambda: f.diff(x, y, 0, 0, x))
6771

6872
g = function_symbol("f", y)
6973
assert g.diff(x) == 0

0 commit comments

Comments
 (0)