Skip to content

Commit 9a30d58

Browse files
committed
Add as_ctypes function to Lambdify
1 parent 45d649e commit 9a30d58

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

symengine/lib/symengine_wrapper.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ cdef class LambdaDouble(_Lambdify):
5858
cdef void unsafe_complex_ptr(self, double complex *inp, double complex *out) nogil
5959
cpdef unsafe_complex(self, double complex[::1] inp, double complex[::1] out, int inp_offset=*, int out_offset=*)
6060
cpdef as_scipy_low_level_callable(self)
61+
cpdef as_ctypes(self)
6162

6263
IF HAVE_SYMENGINE_LLVM:
6364
cdef class LLVMDouble(_Lambdify):
@@ -66,4 +67,5 @@ IF HAVE_SYMENGINE_LLVM:
6667
cdef _load(self, const string &s)
6768
cdef void unsafe_real_ptr(self, double *inp, double *out) nogil
6869
cpdef unsafe_real(self, double[::1] inp, double[::1] out, int inp_offset=*, int out_offset=*)
69-
cpdef as_scipy_low_level_callable(self)
70+
cpdef as_scipy_low_level_callable(self)
71+
cpdef as_ctypes(self)

symengine/lib/symengine_wrapper.pyx

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4670,6 +4670,9 @@ cdef double _scipy_callback_lambda_real(int n, double *x, void *user_data) nogil
46704670
deref(lamb).call(&result, x)
46714671
return result
46724672

4673+
cdef void _ctypes_callback_lambda_real(double *output, const double *input, void *user_data) nogil:
4674+
cdef symengine.LambdaRealDoubleVisitor* lamb = <symengine.LambdaRealDoubleVisitor *>user_data
4675+
deref(lamb).call(output, input)
46734676

46744677
IF HAVE_SYMENGINE_LLVM:
46754678
cdef double _scipy_callback_llvm_real(int n, double *x, void *user_data) nogil:
@@ -4678,6 +4681,10 @@ IF HAVE_SYMENGINE_LLVM:
46784681
deref(lamb).call(&result, x)
46794682
return result
46804683

4684+
cdef void _ctypes_callback_llvm_real(double *output, const double *input, void *user_data) nogil:
4685+
cdef symengine.LLVMDoubleVisitor* lamb = <symengine.LLVMDoubleVisitor *>user_data
4686+
deref(lamb).call(output, input)
4687+
46814688

46824689
def create_low_level_callable(lambdify, *args):
46834690
from scipy import LowLevelCallable
@@ -4721,6 +4728,23 @@ cdef class LambdaDouble(_Lambdify):
47214728
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
47224729
return create_low_level_callable(self, addr1, addr2)
47234730

4731+
cpdef as_ctypes(self):
4732+
"""
4733+
Returns a tuple with first element being a ctypes function with signature
4734+
4735+
void func(double * output, const double *input, void *user_data)
4736+
4737+
and second element being a ctypes void pointer. This void pointer needs to be
4738+
passed as input to the function as the third argument `user_data`.
4739+
"""
4740+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4741+
if not self.real:
4742+
raise RuntimeError("Lambda function has to be real")
4743+
addr1 = cast(<size_t>&_ctypes_callback_lambda_real,
4744+
CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
4745+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4746+
return addr1, addr2
4747+
47244748

47254749
IF HAVE_SYMENGINE_LLVM:
47264750
cdef class LLVMDouble(_Lambdify):
@@ -4752,11 +4776,28 @@ IF HAVE_SYMENGINE_LLVM:
47524776
raise RuntimeError("Lambda function has to be real")
47534777
if self.tot_out_size > 1:
47544778
raise RuntimeError("SciPy LowLevelCallable supports only functions with 1 output")
4755-
addr1 = cast(<size_t>&_scipy_callback_lambda_real,
4779+
addr1 = cast(<size_t>&_scipy_callback_llvm_real,
47564780
CFUNCTYPE(c_double, c_int, POINTER(c_double), c_void_p))
47574781
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
47584782
return create_low_level_callable(self, addr1, addr2)
47594783

4784+
cpdef as_ctypes(self):
4785+
"""
4786+
Returns a tuple with first element being a ctypes function with signature
4787+
4788+
void func(double * output, const double *input, void *user_data)
4789+
4790+
and second element being a ctypes void pointer. This void pointer needs to be
4791+
passed as input to the function as the third argument `user_data`.
4792+
"""
4793+
from ctypes import c_double, c_void_p, c_int, cast, POINTER, CFUNCTYPE
4794+
if not self.real:
4795+
raise RuntimeError("Lambda function has to be real")
4796+
addr1 = cast(<size_t>&_ctypes_callback_llvm_real,
4797+
CFUNCTYPE(c_void_p, POINTER(c_double), POINTER(c_double), c_void_p))
4798+
addr2 = cast(<size_t>&self.lambda_double[0], c_void_p)
4799+
return addr1, addr2
4800+
47604801
def llvm_loading_func(*args):
47614802
return LLVMDouble(args, _load=True)
47624803

0 commit comments

Comments
 (0)