Skip to content

Commit a19b75f

Browse files
committed
compiler: simplify complex reductions
1 parent 69048f9 commit a19b75f

File tree

6 files changed

+67
-75
lines changed

6 files changed

+67
-75
lines changed

devito/ir/cgen/printer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,8 @@ def _print_ListInitializer(self, expr):
349349
return f"{{{', '.join(self._print(i) for i in expr.params)}}}"
350350

351351
def _print_IndexedPointer(self, expr):
352-
return f"{expr.base}{''.join(f'[{self._print(i)}]' for i in expr.index)}"
352+
base = self._print(expr.base)
353+
return f"{base}{''.join(f'[{self._print(i)}]' for i in expr.index)}"
353354

354355
def _print_IntDiv(self, expr):
355356
lhs = self._print(expr.lhs)

devito/passes/iet/languages/C.py

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import numpy as np
22
from sympy.printing.c import C99CodePrinter
33

4-
from devito.exceptions import InvalidOperator
5-
from devito.ir import Call, BasePrinter, List
4+
from devito.ir import Call, BasePrinter
65
from devito.passes.iet.definitions import DataManager
76
from devito.passes.iet.orchestration import Orchestrator
87
from devito.passes.iet.langbase import LangBB
8+
from devito.passes.iet.languages.utils import _atomic_add_split
99
from devito.symbolics import c_complex, c_double_complex
1010
from devito.symbolics.extended_sympy import UnaryOp
1111
from devito.tools import dtype_to_cstr
@@ -28,38 +28,14 @@ def atomic_add(i, pragmas, split=False):
2828
if not split:
2929
return i._rebuild(pragmas=pragmas)
3030
# Complex reduction, split using a temp pointer
31-
# Transforns lhs += rhs into
31+
# Transforms lhs += rhs into
3232
# {
3333
# pragmas
3434
# __real__ lhs += __real__ rhs;
3535
# pragmas
3636
# __imag__ lhs += __imag__ rhs;
3737
# }
38-
lhs, rhs = i.expr.lhs, i.expr.rhs
39-
if (np.issubdtype(lhs.dtype, np.complexfloating)
40-
and np.issubdtype(rhs.dtype, np.complexfloating)):
41-
# Complex i, complex j
42-
# Atomic add real and imaginary parts separately
43-
lhsr, rhsr = RealExt(lhs), RealExt(rhs)
44-
lhsi, rhsi = ImagExt(lhs), ImagExt(rhs)
45-
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
46-
pragmas=pragmas)
47-
imag = i._rebuild(expr=i.expr._rebuild(lhs=lhsi, rhs=rhsi),
48-
pragmas=pragmas)
49-
return List(body=[real, imag])
50-
51-
elif (np.issubdtype(lhs.dtype, np.complexfloating)
52-
and not np.issubdtype(rhs.dtype, np.complexfloating)):
53-
# Complex i, real j
54-
# Atomic add j to real part of i
55-
lhsr, rhsr = RealExt(lhs), rhs
56-
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
57-
pragmas=pragmas)
58-
return real
59-
else:
60-
# Real i, complex j
61-
raise InvalidOperator("Atomic add not implemented for real "
62-
"Functions with complex increments")
38+
return _atomic_add_split(i, pragmas, RealExt, ImagExt)
6339

6440

6541
class CBB(LangBB):

devito/passes/iet/languages/CXX.py

Lines changed: 16 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,13 @@
1-
from ctypes import POINTER
2-
31
import numpy as np
42
from sympy.printing.cxx import CXX11CodePrinter
53

6-
from devito import Real, Imag
7-
from devito.exceptions import InvalidOperator
8-
from devito.ir import Call, UsingNamespace, BasePrinter, DummyExpr, List
4+
from devito.ir import Call, UsingNamespace, BasePrinter
95
from devito.passes.iet.definitions import DataManager
106
from devito.passes.iet.orchestration import Orchestrator
117
from devito.passes.iet.langbase import LangBB
8+
from devito.passes.iet.languages.utils import _atomic_add_split
129
from devito.symbolics import c_complex, c_double_complex, IndexedPointer, cast, Byref
13-
from devito.tools import dtype_to_cstr, dtype_to_ctype
14-
from devito.types import Pointer
10+
from devito.tools import dtype_to_cstr
1511

1612
__all__ = ['CXXBB', 'CXXDataManager', 'CXXOrchestrator']
1713

@@ -70,49 +66,30 @@ def std_arith(prefix=None):
7066
"""
7167

7268

69+
def split_pointer(i, idx):
    """
    Address a single scalar component of the complex-valued object ``i``.

    The address of ``i`` is reinterpreted as a pointer to the real scalar
    type underlying ``i.dtype`` and that pointer is then indexed with ``idx``.
    In the generated C++ this reads ``reinterpret_cast<T*>(&i)[idx]``.

    Parameters
    ----------
    i : expr-like
        Object exposing a complex ``dtype``; its address is taken via `Byref`.
    idx : int
        Component to select; the helpers below use 0 (real) and 1 (imaginary).

    Returns
    -------
    IndexedPointer
        The reinterpreted pointer indexed by ``idx``.
    """
    # Real scalar type matching the complex dtype, obtained from the `.real`
    # attribute of a zero-valued instance of `i.dtype`
    dtype = i.dtype(0).real.__class__
    ptr = cast(dtype, stars='*')(Byref(i), reinterpret=True)
    return IndexedPointer(ptr, idx)


def cxx_imag(i):
    """Imaginary component of ``i``: element 1 of the reinterpreted pointer."""
    return split_pointer(i, 1)


def cxx_real(i):
    """Real component of ``i``: element 0 of the reinterpreted pointer."""
    return split_pointer(i, 0)
77+
78+
7379
def atomic_add(i, pragmas, split=False):
7480
# Base case, real reduction
7581
if not split:
7682
return i._rebuild(pragmas=pragmas)
7783
# Complex reduction, split by reinterpreting lhs as a pointer to its real type
7884
# Transforms lhs += rhs into
7985
# {
80-
# float * lhs = reinterpret_cast<float*>(&lhs);
8186
# pragmas
82-
# lhs[0] += std::real(rhs);
87+
# reinterpret_cast<float*>(&lhs)[0] += std::real(rhs);
8388
# pragmas
84-
# lhs[1] += std::imag(rhs);
89+
# reinterpret_cast<float*>(&lhs)[1] += std::imag(rhs);
8590
# }
8691
# No temp pointer is declared: each component is addressed in place
87-
lhs, rhs = i.expr.lhs, i.expr.rhs
88-
rdtype = lhs.dtype(0).real.__class__
89-
plhs = Pointer(name=f'p{lhs.name}', dtype=POINTER(dtype_to_ctype(rdtype)))
90-
peq = DummyExpr(plhs, cast(rdtype, stars='*')(Byref(lhs), reinterpret=True))
91-
92-
if (np.issubdtype(lhs.dtype, np.complexfloating)
93-
and np.issubdtype(rhs.dtype, np.complexfloating)):
94-
# Complex i, complex j
95-
# Atomic add real and imaginary parts separately
96-
lhsr, rhsr = IndexedPointer(plhs, 0), Real(rhs)
97-
lhsi, rhsi = IndexedPointer(plhs, 1), Imag(rhs)
98-
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
99-
pragmas=pragmas)
100-
imag = i._rebuild(expr=i.expr._rebuild(lhs=lhsi, rhs=rhsi),
101-
pragmas=pragmas)
102-
return List(body=[peq, real, imag])
103-
104-
elif (np.issubdtype(lhs.dtype, np.complexfloating)
105-
and not np.issubdtype(rhs.dtype, np.complexfloating)):
106-
# Complex i, real j
107-
# Atomic add j to real part of i
108-
lhsr, rhsr = IndexedPointer(plhs, 0), rhs
109-
real = i._rebuild(expr=i.expr._rebuild(lhs=lhsr, rhs=rhsr),
110-
pragmas=pragmas)
111-
return List(body=[peq, real])
112-
else:
113-
# Real i, complex j
114-
raise InvalidOperator("Atomic add not implemented for real "
115-
"Functions with complex increments")
92+
return _atomic_add_split(i, pragmas, cxx_real, cxx_imag)
11693

11794

11895
class CXXBB(LangBB):
Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,42 @@
1-
__all__ = ['joins']
1+
import numpy as np
2+
3+
from devito import Real, Imag
4+
from devito.exceptions import InvalidOperator
5+
from devito.ir import List
6+
7+
__all__ = ['joins', '_atomic_add_split']
28

39

410
def joins(*symbols):
    """
    Return the names of ``symbols`` as a single comma-separated string.

    The names (each symbol's ``name`` attribute) are sorted before joining,
    so the result does not depend on the order of the arguments. With no
    arguments the result is the empty string.
    """
    # `sorted` accepts any iterable, no need to build a throwaway list first
    return ",".join(sorted(i.name for i in symbols))
12+
13+
14+
def _atomic_add_split(i, pragmas, real, imag):
    """
    Split the complex increment ``lhs += rhs`` carried by ``i`` into
    increments over the real and imaginary components.

    Parameters
    ----------
    i : node
        Node carrying the increment. It must expose ``expr`` (with ``lhs``,
        ``rhs`` and ``_rebuild``) and ``_rebuild``.
    pragmas : object
        Forwarded unchanged as ``pragmas=`` to every rebuilt node.
    real : callable
        Maps the left-hand side to the object addressing its real part.
    imag : callable
        Maps the left-hand side to the object addressing its imaginary part.

    Returns
    -------
    node or List
        * complex ``lhs``, complex ``rhs``: a ``List`` whose body holds the
          real-part increment followed by the imaginary-part increment;
        * complex ``lhs``, real ``rhs``: a single rebuilt node incrementing
          the real part of ``lhs`` by ``rhs``.

    Raises
    ------
    InvalidOperator
        If ``lhs`` does not have a complex dtype.
    """
    expr = i.expr
    lhs, rhs = expr.lhs, expr.rhs

    # A non-complex lhs has no components to split into
    if not np.issubdtype(lhs.dtype, np.complexfloating):
        raise InvalidOperator("Atomic add not implemented for real "
                              "Functions with complex increments")

    def rebuild(new_lhs, new_rhs):
        # Same node and pragmas, with only the increment operands swapped
        return i._rebuild(expr=expr._rebuild(lhs=new_lhs, rhs=new_rhs),
                          pragmas=pragmas)

    if not np.issubdtype(rhs.dtype, np.complexfloating):
        # Complex lhs, real rhs: only the real part of lhs is incremented
        return rebuild(real(lhs), rhs)

    # Complex lhs, complex rhs: increment each component separately
    return List(body=[rebuild(real(lhs), Real(rhs)),
                      rebuild(imag(lhs), Imag(rhs))])

devito/passes/iet/linearization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def key1(f, d):
7272
if f.is_regular:
7373
# For paddable objects the following holds:
7474
# `same dim + same halo + same padding_dtype => same (auto-)padding`
75-
return (d, f._size_halo[d], f.__padding_dtype__)
75+
return (d, f._size_halo[d], f._size_padding[d])
7676
else:
7777
return False
7878

tests/test_dtypes.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,13 @@ def test_complex_reduction(dtypeu: np.dtype[np.complexfloating]) -> None:
326326
if gnu and np.issubdtype(dtypeu, np.complexfloating):
327327
if 'CXX' in op._language:
328328
rd = dtype_to_cstr(dtypeu(0).real.__class__)
329-
assert f'{rd} * p{u.name} = reinterpret_cast<{rd}*>(&uL0' in str(op)
330-
assert f'p{u.name}[0] += std::real(r0)' in str(op)
331-
assert f'p{u.name}[1] += std::imag(r0)' in str(op)
329+
fu = f'reinterpret_cast<{rd}*>(&{ustr})'
330+
assert f'{fu}[0] += std::real(r0)' in str(op)
331+
assert f'{fu}[1] += std::imag(r0)' in str(op)
332332
else:
333-
assert f'__real__ {ustr} += __real__ r0' in str(op)
334-
assert f'__imag__ {ustr} += __imag__ r0' in str(op)
333+
ext = '' if dtypeu == np.complex128 else 'f'
334+
assert f'__real__ {ustr} += creal{ext}(r0)' in str(op)
335+
assert f'__imag__ {ustr} += cimag{ext}(r0)' in str(op)
335336
else:
336337
assert f'{ustr} += r0' in str(op)
337338

0 commit comments

Comments
 (0)