|
1 | | -from ctypes import POINTER |
2 | | - |
3 | 1 | import numpy as np |
4 | 2 | from sympy.printing.cxx import CXX11CodePrinter |
5 | 3 |
|
6 | | -from devito import Real, Imag |
7 | | -from devito.exceptions import InvalidOperator |
8 | | -from devito.ir import Call, UsingNamespace, BasePrinter, DummyExpr, List |
| 4 | +from devito.ir import Call, UsingNamespace, BasePrinter |
9 | 5 | from devito.passes.iet.definitions import DataManager |
10 | 6 | from devito.passes.iet.orchestration import Orchestrator |
11 | 7 | from devito.passes.iet.langbase import LangBB |
| 8 | +from devito.passes.iet.languages.utils import _atomic_add_split |
12 | 9 | from devito.symbolics import c_complex, c_double_complex, IndexedPointer, cast, Byref |
13 | | -from devito.tools import dtype_to_cstr, dtype_to_ctype |
14 | | -from devito.types import Pointer |
| 10 | +from devito.tools import dtype_to_cstr |
15 | 11 |
|
16 | 12 | __all__ = ['CXXBB', 'CXXDataManager', 'CXXOrchestrator'] |
17 | 13 |
|
@@ -70,49 +66,30 @@ def std_arith(prefix=None): |
70 | 66 | """ |
71 | 67 |
|
72 | 68 |
|
| 69 | +def split_pointer(i, idx): |
| 70 | + dtype = i.dtype(0).real.__class__ |
| 71 | + ptr = cast(dtype, stars='*')(Byref(i), reinterpret=True) |
| 72 | + return IndexedPointer(ptr, idx) |
| 73 | + |
| 74 | + |
| 75 | +cxx_imag = lambda i: split_pointer(i, 1) |
| 76 | +cxx_real = lambda i: split_pointer(i, 0) |
| 77 | + |
| 78 | + |
73 | 79 | def atomic_add(i, pragmas, split=False): |
74 | 80 | # Base case, real reduction |
75 | 81 | if not split: |
76 | 82 | return i._rebuild(pragmas=pragmas) |
77 | 83 | # Complex reduction, split using a temp pointer |
78 | 84 | # Transforns lhs += rhs into |
79 | 85 | # { |
80 | | - # float * lhs = reinterpret_cast<float*>(&lhs); |
81 | 86 | # pragmas |
82 | | - # lhs[0] += std::real(rhs); |
| 87 | + # reinterpret_cast<float*>(&lhs)[0] += std::real(rhs); |
83 | 88 | # pragmas |
84 | | - # lhs[1] += std::imag(rhs); |
| 89 | + # reinterpret_cast<float*>(&lhs)[1] += std::imag(rhs); |
85 | 90 | # } |
86 | 91 | # Make a temp pointer |
87 | | - lhs, rhs = i.expr.lhs, i.expr.rhs |
88 | | - rdtype = lhs.dtype(0).real.__class__ |
89 | | - plhs = Pointer(name=f'p{lhs.name}', dtype=POINTER(dtype_to_ctype(rdtype))) |
90 | | - peq = DummyExpr(plhs, cast(rdtype, stars='*')(Byref(lhs), reinterpret=True)) |
91 | | - |
92 | | - if (np.issubdtype(lhs.dtype, np.complexfloating) |
93 | | - and np.issubdtype(rhs.dtype, np.complexfloating)): |
94 | | - # Complex i, complex j |
95 | | - # Atomic add real and imaginary parts separately |
96 | | - lhsr, rhsr = IndexedPointer(plhs, 0), Real(rhs) |
97 | | - lhsi, rhsi = IndexedPointer(plhs, 1), Imag(rhs) |
98 | | - real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr), |
99 | | - pragmas=pragmas) |
100 | | - imag = i._rebuild(expr=i.expr._rebuild(lhs=lhsi, rhs=rhsi), |
101 | | - pragmas=pragmas) |
102 | | - return List(body=[peq, real, imag]) |
103 | | - |
104 | | - elif (np.issubdtype(lhs.dtype, np.complexfloating) |
105 | | - and not np.issubdtype(rhs.dtype, np.complexfloating)): |
106 | | - # Complex i, real j |
107 | | - # Atomic add j to real part of i |
108 | | - lhsr, rhsr = IndexedPointer(plhs, 0), rhs |
109 | | - real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr), |
110 | | - pragmas=pragmas) |
111 | | - return List(body=[peq, real]) |
112 | | - else: |
113 | | - # Real i, complex j |
114 | | - raise InvalidOperator("Atomic add not implemented for real " |
115 | | - "Functions with complex increments") |
| 92 | + return _atomic_add_split(i, pragmas, cxx_real, cxx_imag) |
116 | 93 |
|
117 | 94 |
|
118 | 95 | class CXXBB(LangBB): |
|
0 commit comments