Skip to content

Commit 374d86c

Browse files
committed
compiler: add local array with fixed size type
1 parent 9d2e0c4 commit 374d86c

File tree

4 files changed

+10
-7
lines changed

4 files changed

+10
-7
lines changed

devito/arch/archinfo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ def get_advisor_path():
577577
@memoized_func
578578
def get_hip_path():
579579
# *** First try: via commonly used environment variables
580-
for i in ['HIP_HOME']:
580+
for i in ['HIP_HOME', 'ROCM_HOME']:
581581
hip_home = os.environ.get(i)
582582
if hip_home:
583583
return hip_home

devito/ir/iet/visitors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,10 @@ def _gen_rettype(self, obj):
353353
elif isinstance(obj, (FieldFromComposite, FieldFromPointer)):
354354
return self._gen_value(obj.function.base, 0).typename
355355
else:
356-
return None
356+
try:
357+
return obj._type_.__name__
358+
except AttributeError:
359+
return None
357360

358361
def _args_decl(self, args):
359362
"""Generate cgen declarations from an iterable of symbols and expressions."""

devito/mpi/routines.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ def _arg_values(self, args=None, **kwargs):
13951395
class AllreduceCall(Call):
13961396

13971397
def __init__(self, arguments, **kwargs):
1398-
super().__init__('MPI_Allreduce', arguments)
1398+
super().__init__('MPI_Allreduce', arguments, **kwargs)
13991399

14001400

14011401
class ReductionBuilder(object):
@@ -1422,6 +1422,6 @@ def make(self, dr):
14221422
op = self.mapper[dr.op]
14231423

14241424
arguments = [inplace, Byref(f), Integer(1), mpitype, op, comm]
1425-
allreduce = AllreduceCall(arguments)
1425+
allreduce = AllreduceCall(arguments, writes=f)
14261426

14271427
return allreduce

devito/symbolics/extended_sympy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,9 @@ def __new__(cls, call, pointer, params=None, **kwargs):
172172
pointer = Symbol(pointer)
173173
if isinstance(call, str):
174174
call = Symbol(call)
175-
elif not isinstance(call, Basic):
176-
raise ValueError("`call` must be a `devito.Basic` or a type "
177-
"with compatible interface")
175+
elif not isinstance(call, (BasicWrapperMixin, Basic)):
176+
raise ValueError(f"`call` {call} must be a `devito.Basic` or a type "
177+
f"with compatible interface, not {type(call)}")
178178
_params = []
179179
for p in as_tuple(params):
180180
if isinstance(p, str):

0 commit comments

Comments
 (0)