|
2 | 2 | import pytest |
3 | 3 | import sympy |
4 | 4 |
|
| 5 | +try: |
| 6 | + from ..conftest import skipif |
| 7 | +except ImportError: |
| 8 | + from conftest import skipif |
| 9 | + |
5 | 10 | from devito import ( |
6 | 11 | Constant, Eq, Function, Grid, Operator, exp, log, sin, configuration |
7 | 12 | ) |
| 13 | +from devito.arch.compiler import GNUCompiler, CustomCompiler |
| 14 | +from devito.exceptions import InvalidOperator |
8 | 15 | from devito.ir.cgen.printer import BasePrinter |
9 | 16 | from devito.passes.iet.langbase import LangBB |
10 | 17 | from devito.passes.iet.languages.C import CBB, CPrinter |
11 | 18 | from devito.passes.iet.languages.openacc import AccBB, AccPrinter |
12 | 19 | from devito.passes.iet.languages.openmp import OmpBB |
13 | 20 | from devito.symbolics.extended_dtypes import ctypes_vector_mapper |
| 21 | +from devito.tools import dtype_to_cstr |
14 | 22 | from devito.types.basic import Basic, Scalar, Symbol |
15 | 23 | from devito.types.dense import TimeFunction |
| 24 | +from devito.types.sparse import SparseTimeFunction |
16 | 25 |
|
17 | 26 | # Mappers for language-specific types and headers |
18 | 27 | _languages: dict[str, type[LangBB]] = { |
@@ -274,3 +283,56 @@ def test_complex_space_deriv(dtype: np.dtype[np.complexfloating]) -> None: |
274 | 283 | dfdy = h.data.T[1:-1, 1:-1] |
275 | 284 | assert np.allclose(dfdx, np.ones((5, 5), dtype=dtype)) |
276 | 285 | assert np.allclose(dfdy, np.ones((5, 5), dtype=dtype)) |
| 286 | + |
| 287 | + |
| 288 | +@skipif(['noomp', 'device']) |
| 289 | +@pytest.mark.parametrize('dtypeu', [np.float32, np.complex64, np.complex128]) |
| 290 | +def test_complex_reduction(dtypeu: np.dtype[np.complexfloating]) -> None: |
| 291 | + """ |
| 292 | + Tests reductions over complex-valued functions. |
| 293 | + """ |
| 294 | + grid = Grid((11, 11)) |
| 295 | + |
| 296 | + u = TimeFunction(name="u", grid=grid, space_order=2, time_order=1, dtype=dtypeu) |
| 297 | + for dtypes in [dtypeu, dtypeu(0).real.__class__]: |
| 298 | + u.data.fill(0) |
| 299 | + s = SparseTimeFunction(name="s", grid=grid, npoint=1, nt=10, dtype=dtypes) |
| 300 | + if np.issubdtype(dtypes, np.complexfloating): |
| 301 | + s.data[:] = 1 + 2j |
| 302 | + expected = 8. + 16.j |
| 303 | + else: |
| 304 | + s.data[:] = 1 |
| 305 | + expected = 8. |
| 306 | + s.coordinates.data[:] = [.5, .5] |
| 307 | + |
| 308 | + # s complex and u real should error |
| 309 | + if np.issubdtype(dtypeu, np.floating) and \ |
| 310 | + np.issubdtype(dtypes, np.complexfloating): |
| 311 | + with pytest.raises(InvalidOperator): |
| 312 | + op = Operator([Eq(u.forward, u)] + s.inject(u.forward, expr=s)) |
| 313 | + continue |
| 314 | + else: |
| 315 | + op = Operator([Eq(u.forward, u)] + s.inject(u.forward, expr=s)) |
| 316 | + op() |
| 317 | + |
| 318 | + if op._options['linearize']: |
| 319 | + ustr = 'uL0(t1, rsx + posx + 2, rsy + posy + 2)' |
| 320 | + else: |
| 321 | + ustr = 'u[t1][rsx + posx + 2][rsy + posy + 2]' |
| 322 | + |
| 323 | + compiler = configuration['compiler'] |
| 324 | + gnu = isinstance(compiler, GNUCompiler) or \ |
| 325 | + (isinstance(compiler, CustomCompiler) and compiler._base is GNUCompiler) |
| 326 | + if gnu and np.issubdtype(dtypeu, np.complexfloating): |
| 327 | + if 'CXX' in op._language: |
| 328 | + rd = dtype_to_cstr(dtypeu(0).real.__class__) |
| 329 | + assert f'{rd} * p{u.name} = reinterpret_cast<{rd}*>(&uL0' in str(op) |
| 330 | + assert f'p{u.name}[0] += std::real(r0)' in str(op) |
| 331 | + assert f'p{u.name}[1] += std::imag(r0)' in str(op) |
| 332 | + else: |
| 333 | + assert f'__real__ {ustr} += __real__ r0' in str(op) |
| 334 | + assert f'__imag__ {ustr} += __imag__ r0' in str(op) |
| 335 | + else: |
| 336 | + assert f'{ustr} += r0' in str(op) |
| 337 | + |
| 338 | + assert np.isclose(u.data[0, 5, 5], expected) |
0 commit comments