Skip to content

Commit 9cb7ea0

Browse files
authored
Merge pull request #292 from isuruf/ctypes
as_ctypes function for Lambdify
2 parents 45d649e + cb37158 commit 9cb7ea0

File tree

3 files changed

+61
-27
lines changed

3 files changed

+61
-27
lines changed

symengine/lib/symengine_wrapper.pxd

+3-6
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,9 @@ cdef class _Lambdify(object):
3939

4040
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
4141
cdef _load(self, const string &s)
42-
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
4342
cpdef unsafe_real(self,
4443
double[::1] inp, double[::1] out,
4544
int inp_offset=*, int out_offset=*)
46-
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil
4745
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out,
4846
int inp_offset=*, int out_offset=*)
4947
cpdef eval_real(self, inp, out)
@@ -53,17 +51,16 @@ cdef class LambdaDouble(_Lambdify):
5351
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
5452
cdef vector[symengine.LambdaComplexDoubleVisitor] lambda_double_complex
5553
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
56-
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
5754
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
58-
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil
5955
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=*, int out_offset=*)
6056
cpdef as_scipy_low_level_callable(self)
57+
cpdef as_ctypes(self)
6158

6259
IF HAVE_SYMENGINE_LLVM:
6360
cdef class LLVMDouble(_Lambdify):
6461
cdef vector[symengine.LLVMDoubleVisitor] lambda_double
6562
cdef _init(self, symengine.vec_basic& args_, symengine.vec_basic& outs_, cppbool cse)
6663
cdef _load(self, const string &s)
67-
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
6864
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
69-
cpdef as_scipy_low_level_callable(self)
65+
cpdef as_scipy_low_level_callable(self)
66+
cpdef as_ctypes(self)

symengine/lib/symengine_wrapper.pyx

+45-21
Original file line numberDiff line numberDiff line change
@@ -4507,19 +4507,11 @@ cdef class _Lambdify(object):
45074507
cdef _load(self, const string &s):
45084508
raise ValueError("Not supported")
45094509

4510-
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil:
4511-
with gil:
4512-
raise ValueError("Not supported")
4513-
45144510
cpdef unsafe_real(self,
45154511
double[::1] inp, double[::1] out,
45164512
int inp_offset=0, int out_offset=0):
45174513
raise ValueError("Not supported")
45184514

4519-
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil:
4520-
with gil:
4521-
raise ValueError("Not supported")
4522-
45234515
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out,
45244516
int inp_offset=0, int out_offset=0):
45254517
raise ValueError("Not supported")
@@ -4670,6 +4662,9 @@ cdef double _scipy_callback_lambda_real(int n, double *x, void *user_data) nogil
46704662
deref(lamb).call(&result, x)
46714663
return result
46724664

4665+
cdef void _ctypes_callback_lambda_real(double *output, const double *input, void *user_data) nogil:
4666+
cdef symengine.LambdaRealDoubleVisitor* lamb = <symengine.LambdaRealDoubleVisitor *>user_data
4667+
deref(lamb).call(output, input)
46734668

46744669
IF HAVE_SYMENGINE_LLVM:
46754670
cdef double _scipy_callback_llvm_real(int n, double *x, void *user_data) nogil:
@@ -4678,6 +4673,10 @@ IF HAVE_SYMENGINE_LLVM:
46784673
deref(lamb).call(&result, x)
46794674
return result
46804675

4676+
cdef void _ctypes_callback_llvm_real(double *output, const double *input, void *user_data) nogil:
4677+
cdef symengine.LLVMDoubleVisitor* lamb = <symengine.LLVMDoubleVisitor *>user_data
4678+
deref(lamb).call(output, input)
4679+
46814680

46824681
def create_low_level_callable(lambdify, *args):
46834682
from scipy import LowLevelCallable
@@ -4698,17 +4697,11 @@ cdef class LambdaDouble(_Lambdify):
46984697
self.lambda_double_complex.resize(1)
46994698
self.lambda_double_complex[0].init(args_, outs_, cse)
47004699

4701-
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil:
4702-
self.lambda_double[0].call(out, inp)
4703-
47044700
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
4705-
self.unsafe_real_ptr(&inp[inp_offset], &out[out_offset])
4706-
4707-
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil:
4708-
self.lambda_double_complex[0].call(out, inp)
4701+
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
47094702

47104703
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=0, int out_offset=0):
4711-
self.unsafe_complex_ptr(&inp[inp_offset], &out[out_offset])
4704+
self.lambda_double_complex[0].call(&out[out_offset], &inp[inp_offset])
47124705

47134706
cpdef as_scipy_low_level_callable(self):
47144707
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
@@ -4721,6 +4714,23 @@ cdef class LambdaDouble(_Lambdify):
47214714
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
47224715
return create_low_level_callable(self, addr1, addr2)
47234716

4717+
cpdef as_ctypes(self):
4718+
"""
4719+
Returns a tuple with first element being a ctypes function with signature
4720+
4721+
void func(double * output, const double *input, void *user_data)
4722+
4723+
and second element being a ctypes void pointer. This void pointer needs to be
4724+
passed as input to the function as the third argument `user_data`.
4725+
"""
4726+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4727+
if not self.real:
4728+
raise RuntimeError("Lambda function has to be real")
4729+
addr1 = cast(<size_t>&_ctypes_callback_lambda_real,
4730+
CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
4731+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4732+
return addr1, addr2
4733+
47244734

47254735
IF HAVE_SYMENGINE_LLVM:
47264736
cdef class LLVMDouble(_Lambdify):
@@ -4740,23 +4750,37 @@ IF HAVE_SYMENGINE_LLVM:
47404750
return llvm_loading_func, (self.args_size, self.tot_out_size, self.out_shapes, self.real, \
47414751
self.n_exprs, self.order, self.accum_out_sizes, self.numpy_dtype, s)
47424752

4743-
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil:
4744-
self.lambda_double[0].call(out, inp)
4745-
47464753
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=0, int out_offset=0):
4747-
self.unsafe_real_ptr(&inp[inp_offset], &out[out_offset])
4754+
self.lambda_double[0].call(&out[out_offset], &inp[inp_offset])
47484755

47494756
cpdef as_scipy_low_level_callable(self):
47504757
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
47514758
if not self.real:
47524759
raise RuntimeError("Lambda function has to be real")
47534760
if self.tot_out_size > 1:
47544761
raise RuntimeError("SciPy LowLevelCallable supports only functions with 1 output")
4755-
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
4762+
addr1 = cast(<size_t>&_scipy_callback_llvm_real,
47564763
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
47574764
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
47584765
return create_low_level_callable(self, addr1, addr2)
47594766

4767+
cpdef as_ctypes(self):
4768+
"""
4769+
Returns a tuple with first element being a ctypes function with signature
4770+
4771+
void func(double * output, const double *input, void *user_data)
4772+
4773+
and second element being a ctypes void pointer. This void pointer needs to be
4774+
passed as input to the function as the third argument `user_data`.
4775+
"""
4776+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4777+
if not self.real:
4778+
raise RuntimeError("Lambda function has to be real")
4779+
addr1 = cast(<size_t>&_ctypes_callback_llvm_real,
4780+
CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
4781+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4782+
return addr1, addr2
4783+
47604784
def llvm_loading_func(*args):
47614785
return LLVMDouble(args, _load=True)
47624786

symengine/tests/test_lambdify.py

+13
Original file line numberDiff line numberDiff line change
@@ -806,3 +806,16 @@ def test_scipy():
806806
lmb = se.Lambdify(args, [se.exp(-x*t)/t**5], as_scipy=True)
807807
res = integrate.nquad(lmb, [[1, np.inf], [0, np.inf]])
808808
assert abs(res[0] - 0.2) < 1e-7
809+
810+
811+
@unittest.skipUnless(have_numpy, "Numpy not installed")
812+
def test_as_ctypes():
813+
import numpy as np
814+
import ctypes
815+
x, y, z = se.symbols('x, y, z')
816+
l = se.Lambdify([x, y, z], [x+y+z, x*y*z+1])
817+
addr1, addr2 = l.as_ctypes()
818+
inp = np.array([1,2,3], dtype=np.double)
819+
out = np.array([0, 0], dtype=np.double)
820+
addr1(out.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), inp.ctypes.data_as(ctypes.POINTER(ctypes.c_double)), addr2)
821+
assert np.all(out == [6, 7])

0 commit comments

Comments
 (0)