Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1057,10 +1057,11 @@ class Dereference(ExprStmt, Node):

is_Dereference = True

def __init__(self, pointee, pointer, flat=None):
def __init__(self, pointee, pointer, flat=None, offset=None):
self.pointee = pointee
self.pointer = pointer
self.flat = flat
self.offset = offset

def __repr__(self):
return "<Dereference(%s,%s)>" % (self.pointee, self.pointer)
Expand Down Expand Up @@ -1088,6 +1089,9 @@ def expr_symbols(self):
else:
assert False, f"Unexpected pointer type {type(self.pointer)}"

if self.offset is not None:
ret.append(self.offset)

return tuple(filter_ordered(ret))

@property
Expand Down
14 changes: 10 additions & 4 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,12 @@ def visit_PointerCast(self, o):

def visit_Dereference(self, o):
a0, a1 = o.functions

if o.offset:
ptr = f'({a1.name} + {o.offset})'
else:
ptr = a1.name

if a0.is_AbstractFunction:
cstr = self.ccode(a0.indexed._C_typedata)

Expand All @@ -517,17 +523,17 @@ def visit_Dereference(self, o):

if o.flat is None:
shape = ''.join(f"[{self.ccode(i)}]" for i in a0.symbolic_shape[1:])
rvalue = f'({cstr} (*){shape}) {a1.name}{cdim}'
rvalue = f'({cstr} (*){shape}) {ptr}{cdim}'
lvalue = c.Value(cstr, f'(*{self._restrict_keyword} {a0.name}){shape}')
else:
rvalue = f'({cstr} *) {a1.name}{cdim}'
rvalue = f'({cstr} *) {ptr}{cdim}'
lvalue = c.Value(cstr, f'*{self._restrict_keyword} {a0.name}')

else:
if a1.is_Symbol:
rvalue = f'*{a1.name}'
rvalue = f'*{ptr}'
else:
rvalue = f'{a1.name}->{a0._C_name}'
rvalue = f'{ptr}->{a0._C_name}'
lvalue = self._gen_value(a0, 0)

return c.Initializer(lvalue, rvalue)
Expand Down
19 changes: 16 additions & 3 deletions tests/test_iet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
switchconfig)
from devito.ir.iet import (
Call, Callable, Conditional, Definition, DeviceCall, DummyExpr, Iteration, List,
KernelLaunch, Lambda, ElementalFunction, CGen, FindSymbols, filter_iterations,
make_efunc, retrieve_iteration_tree, Transformer
KernelLaunch, Dereference, Lambda, ElementalFunction, CGen, FindSymbols,
filter_iterations, make_efunc, retrieve_iteration_tree, Transformer
)
from devito.ir import SymbolRegistry
from devito.passes.iet.engine import Graph
from devito.passes.iet.languages.C import CDataManager
from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class,
String, FLOAT)
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
from devito.types import CustomDimension, Array, LocalObject, Symbol
from devito.types import CustomDimension, Array, LocalObject, Symbol, Pointer


@pytest.fixture
Expand Down Expand Up @@ -496,3 +496,16 @@ def test_list_inline():

lst = List(body=[expr0, expr1], inline=True)
assert str(lst) == """a = 1; b = 2;"""


def test_dereference_base_plus_off():
ptr = Pointer(name='p', dtype=np.float32)
off = Symbol(name='offs', dtype=np.int32)

dim0 = CustomDimension(name='d0', symbolic_size=2)
dim1 = CustomDimension(name='d1', symbolic_size=3)
x = Array(name='x', dimensions=(dim0, dim1), dtype=np.float32)

deref = Dereference(x, ptr, offset=off)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any case where you could have a static offset (i.e. off = 3)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot see one today, but as long as it's a sympy.Number it'd create no problems


assert str(deref) == "float (*restrict x)[3] = (float (*)[3]) (p + offs);"
Loading