Skip to content

Commit fa92d42

Browse files
committed
compiler: fix buffering ispace
1 parent a19b75f commit fa92d42

File tree

5 files changed

+47
-8
lines changed

5 files changed

+47
-8
lines changed

devito/passes/clusters/buffering.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,11 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
403403
# Finally create the actual buffer
404404
cls = callback or Array
405405
name = sregistry.make_name(prefix='%sb' % f.name)
406+
# We specify the padding to match the input Function's one, so that
407+
# the array can be used in place of the Function with valid strides
406408
mapper[f] = cls(name=name, dimensions=dimensions, dtype=f.dtype,
407-
grid=f.grid, halo=f.halo, space='mapped', mapped=f, f=f)
409+
padding=f.padding, grid=f.grid, halo=f.halo,
410+
space='mapped', mapped=f, f=f)
408411

409412
return mapper
410413

@@ -439,8 +442,27 @@ def __init__(self, f, b, clusters):
439442
# The IterationSpace within which the buffer will be accessed
440443
# NOTE: The `key` is to avoid Clusters including `f` but not directly
441444
# using it in an expression, such as HaloTouch Clusters
442-
key = lambda c: any(i in c.ispace.dimensions for i in self.bdims)
443-
ispaces = {c.ispace for c in clusters if key(c)}
445+
def key(c):
446+
bufferdim = any(i in c.ispace.dimensions for i in self.bdims)
447+
timeonly = all(d.is_Time for d in c.ispace.dimensions)
448+
return bufferdim or timeonly
449+
450+
ispaces = set()
451+
for c in clusters:
452+
if not key(c):
453+
continue
454+
455+
# Skip wild clusters (e.g. HaloTouch Clusters)
456+
if c.is_wild:
457+
continue
458+
# Iterations space and buffering dims
459+
ispace = c.ispace
460+
edims = [d for d in self.bdims if d not in ispace.dimensions]
461+
if not edims:
462+
ispaces.add(ispace)
463+
else:
464+
# Add all missing buffering dimensions
465+
ispaces.add(ispace.insert(self.dim, edims).reorder())
444466

445467
if len(ispaces) > 1:
446468
# Best effort to make buffering work in the presence of multiple
@@ -449,6 +471,7 @@ def __init__(self, f, b, clusters):
449471
ispaces = {i.lift(self.bdims, v=stamp) for i in ispaces}
450472

451473
if len(ispaces) > 1:
474+
print(ispaces)
452475
raise CompilationError("Unsupported `buffering` over different "
453476
"IterationSpaces")
454477

devito/passes/iet/languages/C.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def atomic_add(i, pragmas, split=False):
2727
# Base case, real reduction
2828
if not split:
2929
return i._rebuild(pragmas=pragmas)
30+
3031
# Complex reduction, split using a temp pointer
3132
# Transforms lhs += rhs into
3233
# {

devito/passes/iet/languages/CXX.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def atomic_add(i, pragmas, split=False):
8080
# Base case, real reduction
8181
if not split:
8282
return i._rebuild(pragmas=pragmas)
83+
8384
# Complex reduction, split using a temp pointer
84-
# Transforns lhs += rhs into
85+
# Transforms lhs += rhs into
8586
# {
8687
# pragmas
8788
# reinterpret_cast<float*>(&lhs)[0] += std::real(rhs);

devito/passes/iet/linearization.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,21 @@ def key1(f, d):
7272
if f.is_regular:
7373
# For paddable objects the following holds:
7474
# `same dim + same halo + same padding_dtype => same (auto-)padding`
75-
return (d, f._size_halo[d], f._size_padding[d])
75+
if d is f.dimensions[-1]:
76+
# Only the last dimension is padded
77+
try:
78+
if f.padding == f.mapped.padding:
79+
# Padding set from the mapped Function
80+
# e.g. from buffering or fft temp array
81+
pad_key = f.mapped.__padding_dtype__
82+
else:
83+
pad_key = f.__padding_dtype__
84+
except AttributeError:
85+
pad_key = f.__padding_dtype__
86+
else:
87+
pad_key = None
88+
89+
return (d, f._size_halo[d], pad_key)
7690
else:
7791
return False
7892

devito/passes/iet/parpragma.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,9 @@ def _make_reductions(self, partree):
242242
# Implement reduction
243243
mapper = {partree.root: partree.root._rebuild(reduction=reductions)}
244244
elif all(i is OpInc for _, _, i in reductions):
245-
test2 = not self._support_complex_reduction(self.compiler) and \
246-
any(np.iscomplexobj(i.dtype(0)) for i, _, _ in reductions)
247-
mapper = {i: self.langbb['atomic'](i, test2) for i in exprs}
245+
flag = (not self._support_complex_reduction(self.compiler) and
246+
any(np.iscomplexobj(i.dtype(0)) for i, _, _ in reductions))
247+
mapper = {i: self.langbb['atomic'](i, flag) for i in exprs}
248248
else:
249249
raise NotImplementedError
250250

0 commit comments

Comments
 (0)