Skip to content

Commit 4c73dc4

Browse files
authored
Merge pull request #148 from ShikharJ/_set
Improvements for ImmutableMatrix
2 parents d98b268 + f7da26c commit 4c73dc4

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

symengine/lib/symengine_wrapper.pyx

+24-12
Original file line numberDiff line numberDiff line change
@@ -1798,26 +1798,32 @@ cdef class DenseMatrixBase(MatrixBase):
17981798
cdef DenseMatrixBase o = _sympify(rhs)
17991799
if self.rows != o.rows:
18001800
raise ShapeError("`self` and `rhs` must have the same number of rows.")
1801-
cdef DenseMatrixBase result = zeros(self.rows, self.cols + o.cols)
1801+
cdef DenseMatrixBase result = self.__class__(self.rows, self.cols + o.cols)
1802+
cdef Basic e_
18021803
for i in range(self.rows):
18031804
for j in range(self.cols):
1804-
result[i, j] = self[i, j]
1805+
e_ = self._get(i, j)
1806+
deref(result.thisptr).set(i, j, e_.thisptr)
18051807
for i in range(o.rows):
18061808
for j in range(o.cols):
1807-
result[i, j + self.cols] = o[i, j]
1809+
e_ = _sympify(o._get(i, j))
1810+
deref(result.thisptr).set(i, j + self.cols, e_.thisptr)
18081811
return result
18091812

18101813
def col_join(self, bott):
18111814
cdef DenseMatrixBase o = _sympify(bott)
18121815
if self.cols != o.cols:
18131816
raise ShapeError("`self` and `rhs` must have the same number of columns.")
1814-
cdef DenseMatrixBase result = zeros(self.rows + o.rows, self.cols)
1817+
cdef DenseMatrixBase result = self.__class__(self.rows + o.rows, self.cols)
1818+
cdef Basic e_
18151819
for i in range(self.rows):
18161820
for j in range(self.cols):
1817-
result[i, j] = self[i, j]
1821+
e_ = self._get(i, j)
1822+
deref(result.thisptr).set(i, j, e_.thisptr)
18181823
for i in range(o.rows):
18191824
for j in range(o.cols):
1820-
result[i + self.rows, j] = o[i, j]
1825+
e_ = _sympify(o._get(i, j))
1826+
deref(result.thisptr).set(i + self.rows, j, e_.thisptr)
18211827
return result
18221828

18231829
@property
@@ -1937,16 +1943,16 @@ cdef class DenseMatrixBase(MatrixBase):
19371943
def T(self):
19381944
return self.transpose()
19391945

1940-
def _applyfunc(self, f):
1946+
def applyfunc(self, f):
1947+
cdef DenseMatrixBase out = self.__class__(self)
19411948
cdef int nr = self.nrows()
19421949
cdef int nc = self.ncols()
1950+
cdef Basic e_;
19431951
for i in range(nr):
19441952
for j in range(nc):
1945-
self._set(i, j, f(self._get(i, j)))
1946-
1947-
def applyfunc(self, f):
1948-
cdef DenseMatrixBase out = self.__class__(self)
1949-
out._applyfunc(f)
1953+
e_ = _sympify(f(self._get(i, j)))
1954+
if e_ is not None:
1955+
deref(out.thisptr).set(i, j, e_.thisptr)
19501956
return out
19511957

19521958
def msubs(self, *args):
@@ -2160,6 +2166,12 @@ cdef class MutableDenseMatrix(DenseMatrixBase):
21602166
for k in range(0, self.cols):
21612167
self[i, k], self[j, k] = self[j, k], self[i, k]
21622168

2169+
def _applyfunc(self, f):
2170+
cdef int nr = self.nrows()
2171+
cdef int nc = self.ncols()
2172+
for i in range(nr):
2173+
for j in range(nc):
2174+
self._set(i, j, f(self._get(i, j)))
21632175

21642176
Matrix = DenseMatrix = MutableDenseMatrix
21652177

symengine/tests/test_matrices.py

+17
Original file line numberDiff line numberDiff line change
@@ -395,3 +395,20 @@ def test_immutablematrix():
395395
X = ImmutableMatrix([[1, 2], [3, 4]])
396396
Y = ImmutableMatrix([[1], [0]])
397397
assert type(X.LUsolve(Y)) == ImmutableMatrix
398+
399+
x = Symbol("x")
400+
X = ImmutableMatrix([[1, 2], [3, 4]])
401+
Y = ImmutableMatrix([[1, 2], [x, 4]])
402+
assert Y.subs(x, 3) == X
403+
404+
X = ImmutableMatrix([[1, 2], [3, 4]])
405+
Y = ImmutableMatrix([[5], [6]])
406+
Z = X.row_join(Y)
407+
assert isinstance(Z, ImmutableMatrix)
408+
assert Z == ImmutableMatrix([[1, 2, 5], [3, 4, 6]])
409+
410+
X = ImmutableMatrix([[1, 2], [3, 4]])
411+
Y = ImmutableMatrix([[5, 6]])
412+
Z = X.col_join(Y)
413+
assert isinstance(Z, ImmutableMatrix)
414+
assert Z == ImmutableMatrix([[1, 2], [3, 4], [5, 6]])

0 commit comments

Comments
 (0)