|
37 | 37 |
|
38 | 38 | from arraycontext import ( |
39 | 39 | ArrayContextFactory, |
| 40 | + ArrayOrScalar, |
40 | 41 | BcastUntilActxArray, |
41 | 42 | EagerJAXArrayContext, |
42 | 43 | NumpyArrayContext, |
43 | 44 | PyOpenCLArrayContext, |
| 45 | + PytatoJAXArrayContext, |
44 | 46 | PytatoPyOpenCLArrayContext, |
45 | 47 | dataclass_array_container, |
46 | 48 | pytest_generate_tests_for_array_contexts, |
@@ -656,6 +658,71 @@ def test_array_context_einsum_array_tripleprod(actx_factory: ArrayContextFactory |
656 | 658 | # }}} |
657 | 659 |
|
658 | 660 |
|
| 661 | +def test_array_context_csr_matmul(actx_factory: ArrayContextFactory): |
| 662 | + actx = actx_factory() |
| 663 | + |
| 664 | + if isinstance(actx, (EagerJAXArrayContext, PytatoJAXArrayContext)): |
| 665 | + pytest.skip(f"not implemented for '{type(actx).__name__}'") |
| 666 | + |
| 667 | + n = 100 |
| 668 | + |
| 669 | + x = actx.from_numpy(np.arange(n, dtype=np.float64)) |
| 670 | + ary_of_x = obj_array.new_1d([x] * 3) |
| 671 | + dc_of_x = MyContainer( |
| 672 | + name="container", |
| 673 | + mass=x, |
| 674 | + momentum=obj_array.new_1d([x] * 3), |
| 675 | + enthalpy=x) |
| 676 | + |
| 677 | + elem_values = actx.zeros((n//2,), dtype=np.float64) + 1. |
| 678 | + elem_col_indices = actx.from_numpy(2*np.arange(n//2, dtype=np.int32)) |
| 679 | + row_starts = actx.from_numpy(np.arange(n//2 + 1, dtype=np.int32)) |
| 680 | + |
| 681 | + mat = actx.make_csr_matrix( |
| 682 | + shape=(n//2, n), |
| 683 | + elem_values=elem_values, |
| 684 | + elem_col_indices=elem_col_indices, |
| 685 | + row_starts=row_starts) |
| 686 | + |
| 687 | + expected_mat_x = actx.from_numpy(2 * np.arange(n//2, dtype=np.float64)) |
| 688 | + |
| 689 | + def _check_allclose( |
| 690 | + arg1: ArrayOrScalar, arg2: ArrayOrScalar, atol: float = 1.0e-14): |
| 691 | + from arraycontext import NotAnArrayContainerError |
| 692 | + try: |
| 693 | + arg1_iterable = serialize_container(arg1) |
| 694 | + arg2_iterable = serialize_container(arg2) |
| 695 | + except NotAnArrayContainerError: |
| 696 | + assert np.linalg.norm(actx.to_numpy(arg1 - arg2)) < atol |
| 697 | + else: |
| 698 | + arg1_subarrays = [ |
| 699 | + subarray for _, subarray in arg1_iterable] |
| 700 | + arg2_subarrays = [ |
| 701 | + subarray for _, subarray in arg2_iterable] |
| 702 | + for subarray1, subarray2 in zip(arg1_subarrays, arg2_subarrays, |
| 703 | + strict=True): |
| 704 | + _check_allclose(subarray1, subarray2) |
| 705 | + |
| 706 | + # single array |
| 707 | + res = mat @ x |
| 708 | + expected_res = expected_mat_x |
| 709 | + _check_allclose(res, expected_res) |
| 710 | + |
| 711 | + # array of arrays |
| 712 | + res = mat @ ary_of_x |
| 713 | + expected_res = obj_array.new_1d([expected_mat_x] * 3) |
| 714 | + _check_allclose(res, expected_res) |
| 715 | + |
| 716 | + # container of arrays |
| 717 | + res = mat @ dc_of_x |
| 718 | + expected_res = MyContainer( |
| 719 | + name="container", |
| 720 | + mass=expected_mat_x, |
| 721 | + momentum=obj_array.new_1d([expected_mat_x] * 3), |
| 722 | + enthalpy=expected_mat_x) |
| 723 | + _check_allclose(res, expected_res) |
| 724 | + |
| 725 | + |
659 | 726 | # {{{ array container classes for test |
660 | 727 |
|
661 | 728 |
|
|
0 commit comments