Skip to content

Commit 050e6ab

Browse files
authored
Merge pull request #87 from bjodah/more-Lambdify-tests
Add tests for more than 255 arguments
2 parents 97c5a21 + beedc06 commit 050e6ab

File tree

3 files changed

+75
-16
lines changed

3 files changed

+75
-16
lines changed

symengine/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
sin, cos, tan, cot, csc, sec, asin, acos, atan, acot, acsc, asec,
55
sinh, cosh, tanh, coth, asinh, acosh, atanh, acoth, Lambdify,
66
LambdifyCSE, DictBasic, series, symarray, diff, zeros, eye, diag,
7-
ones, zeros, expand, has_symbol, UndefFunction)
7+
ones, zeros, add, expand, has_symbol, UndefFunction)
88
from .utilities import var, symbols
99

1010
if have_mpfr:

symengine/lib/symengine_wrapper.pyx

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2575,7 +2575,7 @@ cdef class Lambdify(object):
25752575
Lambdify instances are callbacks that numerically evaluate their symbolic
25762576
expressions from user provided input (real or complex) into (possibly user
25772577
provided) output buffers (real or complex). Multidimensional data are
2578-
processed in their most cache-friendly way ("ravelled").
2578+
processed in their most cache-friendly way (i.e. "ravelled").
25792579
25802580
Parameters
25812581
----------
@@ -2599,7 +2599,7 @@ cdef class Lambdify(object):
25992599
[ 9., 24.]
26002600
26012601
"""
2602-
cdef size_t inp_size, out_size
2602+
cdef size_t args_size, out_size
26032603
cdef tuple out_shape
26042604
cdef readonly bool real
26052605
cdef vector[symengine.LambdaRealDoubleVisitor] lambda_double
@@ -2615,7 +2615,7 @@ cdef class Lambdify(object):
26152615
int idx = 0
26162616
self.real = real
26172617
self.out_shape = get_shape(exprs)
2618-
self.inp_size = _size(args)
2618+
self.args_size = _size(args)
26192619
self.out_size = reduce(mul, self.out_shape)
26202620

26212621
if isinstance(args, DenseMatrix):
@@ -2659,7 +2659,7 @@ cdef class Lambdify(object):
26592659
cdef vector[ValueType] inp_
26602660
cdef size_t idx, ninp = inp.size, nout = out.size
26612661

2662-
if inp.size != self.inp_size:
2662+
if inp.size != self.args_size:
26632663
raise ValueError("Size of inp incompatible with number of args.")
26642664
if out.size != self.out_size:
26652665
raise ValueError("Size of out incompatible with number of exprs.")
@@ -2720,10 +2720,10 @@ cdef class Lambdify(object):
27202720
inp = tuple(inp)
27212721
inp_shape = (len(inp),)
27222722
inp_size = reduce(mul, inp_shape)
2723-
if inp_size % self.inp_size != 0:
2723+
if inp_size % self.args_size != 0:
27242724
raise ValueError("Broadcasting failed")
2725-
nbroadcast = inp_size // self.inp_size
2726-
if nbroadcast > 1 and self.inp_size == 1 and inp_shape[-1] != 1: # Implicit reshape
2725+
nbroadcast = inp_size // self.args_size
2726+
if nbroadcast > 1 and self.args_size == 1 and inp_shape[-1] != 1: # Implicit reshape
27272727
inp_shape = inp_shape + (1,)
27282728
new_out_shape = inp_shape[:-1] + self.out_shape
27292729
new_out_size = nbroadcast * self.out_size
@@ -2758,7 +2758,6 @@ cdef class Lambdify(object):
27582758
if out is None:
27592759
# allocate output container
27602760
if use_numpy:
2761-
nbroadcast = inp.size // self.inp_size
27622761
out = np.empty(new_out_size, dtype=np.float64 if
27632762
self.real else np.complex128)
27642763
else:
@@ -2771,17 +2770,20 @@ cdef class Lambdify(object):
27712770
reshape_out = len(new_out_shape) > 1
27722771
else:
27732772
if use_numpy:
2774-
out = np.asarray(out, dtype=np.float64 if
2775-
self.real else np.complex128) # copy if needed
2773+
try:
2774+
out_dtype = out.dtype
2775+
except AttributeError:
2776+
out = np.asarray(out)
2777+
out_dtype = out.dtype
2778+
if out_dtype != (np.float64 if self.real else np.complex128):
2779+
raise TypeError("Output array is of incorrect type")
27762780
if out.size < new_out_size:
27772781
raise ValueError("Incompatible size of output argument")
27782782
if not out.flags['C_CONTIGUOUS']:
27792783
raise ValueError("Output argument needs to be C-contiguous")
27802784
for idx, length in enumerate(out.shape[-len(self.out_shape)::-1]):
27812785
if length < self.out_shape[-idx]:
27822786
raise ValueError("Incompatible shape of output argument")
2783-
if out.dtype != np.float64:
2784-
raise ValueError("Output argument dtype not float64: %s" % out.dtype)
27852787
if not out.flags['WRITEABLE']:
27862788
raise ValueError("Output argument needs to be writeable")
27872789
if out.ndim > 1:
@@ -2798,12 +2800,12 @@ cdef class Lambdify(object):
27982800
if self.real:
27992801
real_inp_view = inp # slicing cython.view.array does not give a memview
28002802
real_out_view = out
2801-
self.unsafe_real(real_inp_view[idx*self.inp_size:(idx+1)*self.inp_size],
2803+
self.unsafe_real(real_inp_view[idx*self.args_size:(idx+1)*self.args_size],
28022804
real_out_view[idx*self.out_size:(idx+1)*self.out_size])
28032805
else:
28042806
complex_inp_view = inp
28052807
complex_out_view = out
2806-
self.unsafe_complex(complex_inp_view[idx*self.inp_size:(idx+1)*self.inp_size],
2808+
self.unsafe_complex(complex_inp_view[idx*self.args_size:(idx+1)*self.args_size],
28072809
complex_out_view[idx*self.out_size:(idx+1)*self.out_size])
28082810

28092811
if use_numpy and reshape_out:

symengine/tests/test_lambdify.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import (absolute_import, division, print_function)
33

4+
from symengine.utilities import raises
5+
46
import array
57
import cmath
68
import itertools
79
import math
8-
import operator
910
import sys
1011

1112
try:
@@ -112,6 +113,38 @@ def test_array_out():
112113
assert np.allclose(out1[:], [-1]*len(exprs))
113114

114115

116+
#@pytest.mark.skipif(not HAVE_NUMPY, reason='requires NumPy')
117+
def test_numpy_array_out_exceptions():
118+
if not HAVE_NUMPY: # nosetests work-around
119+
return
120+
import numpy as np
121+
args, exprs, inp, check = _get_array()
122+
lmb = se.Lambdify(args, exprs)
123+
124+
all_right = np.empty(len(exprs))
125+
lmb(inp, all_right)
126+
127+
too_short = np.empty(len(exprs) - 1)
128+
raises(ValueError, lambda: (lmb(inp, too_short)))
129+
130+
wrong_dtype = np.empty(len(exprs), dtype=int)
131+
raises(TypeError, lambda: (lmb(inp, wrong_dtype)))
132+
133+
read_only = np.empty(len(exprs))
134+
read_only.flags['WRITEABLE'] = False
135+
raises(ValueError, lambda: (lmb(inp, read_only)))
136+
137+
all_right_broadcast = np.empty((2, len(exprs)))
138+
inp_bcast = [[1, 2, 3], [4, 5, 6]]
139+
lmb(np.array(inp_bcast), all_right_broadcast)
140+
141+
f_contig_broadcast = np.empty((2, len(exprs)), order='F')
142+
raises(ValueError, lambda: (lmb(inp_bcast, f_contig_broadcast)))
143+
144+
improper_bcast = np.empty((3, len(exprs)))
145+
raises(ValueError, lambda: (lmb(inp_bcast, improper_bcast)))
146+
147+
115148
def test_array_out_no_numpy():
116149
if sys.version_info[0] < 3:
117150
return # requires Py3
@@ -459,3 +492,27 @@ def test_complex_2():
459492
lmb = se.Lambdify([x], [3 + x - 1j], real=False)
460493
assert abs(lmb([11+13j])[0] -
461494
(14 + 12j)) < 1e-15
495+
496+
497+
def test_more_than_255_args():
498+
# SymPy's lambdify can handle at most 255 arguments
499+
# this is a proof of concept that this limitation does
500+
# not affect SymEngine's Lambdify class
501+
if not HAVE_NUMPY: # nosetests work-around
502+
return
503+
import numpy as np
504+
n = 257
505+
x = se.symarray('x', n)
506+
p, q, r = 17, 42, 13
507+
terms = [i*s for i, s in enumerate(x, p)]
508+
exprs = [se.add(*terms), r + x[0], -99]
509+
callback = se.Lambdify(x, exprs)
510+
input_arr = np.arange(q, q + n*n).reshape((n, n))
511+
out = callback(input_arr)
512+
ref = np.empty((n, 3))
513+
coeffs = np.arange(p, p + n)
514+
for i in range(n):
515+
ref[i, 0] = coeffs.dot(np.arange(q + n*i, q + n*(i+1)))
516+
ref[i, 1] = q + n*i + r
517+
ref[:, 2] = -99
518+
assert np.allclose(out, ref)

0 commit comments

Comments
 (0)