Skip to content

Commit 9200e53

Browse files
committed
add tests
1 parent 15a6613 commit 9200e53

File tree

1 file changed

+67
-0
lines changed

1 file changed

+67
-0
lines changed

test/test_arraycontext.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737

3838
from arraycontext import (
3939
ArrayContextFactory,
40+
ArrayOrScalar,
4041
BcastUntilActxArray,
4142
EagerJAXArrayContext,
4243
NumpyArrayContext,
4344
PyOpenCLArrayContext,
45+
PytatoJAXArrayContext,
4446
PytatoPyOpenCLArrayContext,
4547
dataclass_array_container,
4648
pytest_generate_tests_for_array_contexts,
@@ -656,6 +658,71 @@ def test_array_context_einsum_array_tripleprod(actx_factory: ArrayContextFactory
656658
# }}}
657659

658660

661+
def test_array_context_csr_matmul(actx_factory: ArrayContextFactory):
662+
actx = actx_factory()
663+
664+
if isinstance(actx, (EagerJAXArrayContext, PytatoJAXArrayContext)):
665+
pytest.skip(f"not implemented for '{type(actx).__name__}'")
666+
667+
n = 100
668+
669+
x = actx.from_numpy(np.arange(n, dtype=np.float64))
670+
ary_of_x = obj_array.new_1d([x] * 3)
671+
dc_of_x = MyContainer(
672+
name="container",
673+
mass=x,
674+
momentum=obj_array.new_1d([x] * 3),
675+
enthalpy=x)
676+
677+
elem_values = actx.zeros((n//2,), dtype=np.float64) + 1.
678+
elem_col_indices = actx.from_numpy(2*np.arange(n//2, dtype=np.int32))
679+
row_starts = actx.from_numpy(np.arange(n//2 + 1, dtype=np.int32))
680+
681+
mat = actx.make_csr_matrix(
682+
shape=(n//2, n),
683+
elem_values=elem_values,
684+
elem_col_indices=elem_col_indices,
685+
row_starts=row_starts)
686+
687+
expected_mat_x = actx.from_numpy(2 * np.arange(n//2, dtype=np.float64))
688+
689+
def _check_allclose(
690+
arg1: ArrayOrScalar, arg2: ArrayOrScalar, atol: float = 1.0e-14):
691+
from arraycontext import NotAnArrayContainerError
692+
try:
693+
arg1_iterable = serialize_container(arg1)
694+
arg2_iterable = serialize_container(arg2)
695+
except NotAnArrayContainerError:
696+
assert np.linalg.norm(actx.to_numpy(arg1 - arg2)) < atol
697+
else:
698+
arg1_subarrays = [
699+
subarray for _, subarray in arg1_iterable]
700+
arg2_subarrays = [
701+
subarray for _, subarray in arg2_iterable]
702+
for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays,
703+
strict=True):
704+
_check_allclose(subarray1, subarray2)
705+
706+
# single array
707+
res = mat @ x
708+
expected_res = expected_mat_x
709+
_check_allclose(res, expected_res)
710+
711+
# array of arrays
712+
res = mat @ ary_of_x
713+
expected_res = obj_array.new_1d([expected_mat_x] * 3)
714+
_check_allclose(res, expected_res)
715+
716+
# container of arrays
717+
res = mat @ dc_of_x
718+
expected_res = MyContainer(
719+
name="container",
720+
mass=expected_mat_x,
721+
momentum=obj_array.new_1d([expected_mat_x] * 3),
722+
enthalpy=expected_mat_x)
723+
_check_allclose(res, expected_res)
724+
725+
659726
# {{{ array container classes for test
660727

661728

0 commit comments

Comments
 (0)