Skip to content

Commit bea2af2

Browse files
committed
Merge branch 'master' into gpu
2 parents e937806 + a33fedf commit bea2af2

26 files changed

+2120
-1733
lines changed

pyop2/caching.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,15 @@
3333

3434
"""Provides common base classes for cached objects."""
3535

36+
import hashlib
37+
import os
38+
from pathlib import Path
39+
import pickle
3640

41+
import cachetools
42+
43+
from pyop2.configuration import configuration
44+
from pyop2.mpi import hash_comm
3745
from pyop2.utils import cached_property
3846

3947

@@ -230,3 +238,108 @@ def _cache_key(cls, *args, **kwargs):
230238
def cache_key(self):
231239
"""Cache key."""
232240
return self._key
241+
242+
243+
cached = cachetools.cached
244+
"""Cache decorator for functions. See the cachetools documentation for more
245+
information.
246+
247+
.. note::
248+
If you intend to use this decorator to cache things that are collective
249+
across a communicator then you must include the communicator as part of
250+
the cache key. Since communicators are themselves not hashable you should
251+
use :func:`pyop2.mpi.hash_comm`.
252+
253+
You should also make sure to use unbounded caches as otherwise some ranks
254+
may evict results leading to deadlocks.
255+
"""
256+
257+
258+
def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
    """Decorator for wrapping a function in a cache that stores values in memory and to disk.

    Lookup order is: the in-memory ``cache``, then the on-disk cache, and
    finally a call to the wrapped function whose result is written to both.

    :arg cache: The in-memory cache, usually a :class:`dict`.
    :arg cachedir: The location of the cache directory. Defaults to the
        ``cache_dir`` configuration entry (``PYOP2_CACHE_DIR``), read once
        when the decorator is created.
    :arg key: Callable returning the cache key for the function inputs. If ``collective``
        is ``True`` then this function must return a 2-tuple where the first entry is the
        communicator to be collective over and the second is the key. This is required to ensure
        that deadlocks do not occur when using different subcommunicators.
    :arg collective: If ``True`` then cache lookup is done collectively over a communicator:
        only rank 0 touches the disk and the result of the disk lookup is broadcast
        to the other ranks.

    .. note::
        A disk-cache miss is signalled by ``None``, so a wrapped function that
        legitimately returns ``None`` is never served from disk and is
        recomputed whenever the in-memory cache does not hold it.
    """
    # Local import: keeps this function self-contained with respect to the
    # module's import block.
    import functools

    if cachedir is None:
        cachedir = configuration["cache_dir"]

    def decorator(func):
        # Preserve the wrapped function's name, docstring and module so that
        # introspection and error messages still refer to the real function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if collective:
                comm, disk_key = key(*args, **kwargs)
                disk_key = _as_hexdigest(disk_key)
                # Communicators are not hashable, so the in-memory key uses
                # hash_comm(comm) alongside the digest.
                k = hash_comm(comm), disk_key
            else:
                k = _as_hexdigest(key(*args, **kwargs))

            # first try the in-memory cache
            try:
                return cache[k]
            except KeyError:
                pass

            # then try to retrieve from disk
            if collective:
                # Only rank 0 reads the file; every rank takes part in the
                # broadcast so that all ranks agree on hit or miss.
                if comm.rank == 0:
                    v = _disk_cache_get(cachedir, disk_key)
                    comm.bcast(v, root=0)
                else:
                    v = comm.bcast(None, root=0)
            else:
                v = _disk_cache_get(cachedir, k)
            if v is not None:
                return cache.setdefault(k, v)

            # if all else fails call func and populate the caches
            v = func(*args, **kwargs)
            if collective:
                # A single writer per communicator is enough.
                if comm.rank == 0:
                    _disk_cache_set(cachedir, disk_key, v)
            else:
                _disk_cache_set(cachedir, k, v)
            return cache.setdefault(k, v)
        return wrapper
    return decorator
309+
310+
311+
def _as_hexdigest(key):
    """Return the MD5 hex digest of the string form of ``key``.

    :arg key: Any object; it is converted with :func:`str` and encoded with
        the default encoding before hashing.
    :returns: The digest as a hexadecimal string.
    """
    hasher = hashlib.md5()
    hasher.update(str(key).encode())
    return hasher.hexdigest()
313+
314+
315+
def _disk_cache_get(cachedir, key):
    """Look up ``key`` in the on-disk cache.

    The entry for a key lives at ``<cachedir>/<key[:2]>/<key[2:]>`` and is
    stored as a pickle.

    :arg cachedir: The cache directory.
    :arg key: The cache key (must be a string).
    :returns: The unpickled object if the file exists, else ``None``.
    """
    entry = Path(cachedir) / key[:2] / key[2:]
    try:
        # NOTE: unpickling executes arbitrary code, so the cache directory
        # must only contain files written by a trusted source.
        with entry.open("rb") as stream:
            return pickle.load(stream)
    except FileNotFoundError:
        return None
328+
329+
330+
def _disk_cache_set(cachedir, key, value):
    """Store a new value in the disk cache.

    The value is pickled to ``<cachedir>/<key[:2]>/<key[2:]>``. It is first
    written to a per-process temporary file in the same directory and then
    moved into place, so readers never observe a partially written entry.

    :arg cachedir: The cache directory.
    :arg key: The cache key (must be a string).
    :arg value: The new item to store in the cache.
    :raises Exception: Whatever :func:`pickle.dump` or the file system raises;
        the temporary file is removed before the error propagates.
    """
    k1, k2 = key[:2], key[2:]
    basedir = Path(cachedir, k1)
    basedir.mkdir(parents=True, exist_ok=True)

    # Named ``tmp_path`` rather than ``tempfile`` to avoid shadowing the
    # standard-library module of that name.
    tmp_path = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp")
    filepath = basedir.joinpath(k2)
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(value, f)
        # ``replace`` overwrites an existing entry on every platform, whereas
        # ``rename`` raises on Windows when the target already exists.
        tmp_path.replace(filepath)
    except BaseException:
        # Do not leave a stray ``.tmp`` file behind if pickling or the move
        # failed; cleanup is best-effort and the original error is re-raised.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise

pyop2/codegen/builder.py

Lines changed: 92 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import numpy
77
from loopy.types import OpaqueType
8+
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
9+
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg)
810
from pyop2.codegen.representation import (Accumulate, Argument, Comparison,
911
DummyInstruction, Extent, FixedIndex,
1012
FunctionCall, Index, Indexed,
@@ -16,7 +18,7 @@
1618
When, Zero)
1719
from pyop2.datatypes import IntType
1820
from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS,
19-
ON_TOP, READ, RW, WRITE, Subset, PermutedMap)
21+
ON_TOP, READ, RW, WRITE)
2022
from pyop2.utils import cached_property
2123

2224

@@ -32,18 +34,22 @@ class Map(object):
3234
"variable", "unroll", "layer_bounds",
3335
"prefetch", "_pmap_count")
3436

35-
def __init__(self, map_, interior_horizontal, layer_bounds,
36-
offset=None, unroll=False):
37-
self.variable = map_.iterset._extruded and not map_.iterset.constant_layers
37+
def __init__(self, interior_horizontal, layer_bounds,
38+
arity, dtype,
39+
offset=None, unroll=False,
40+
extruded=False, constant_layers=False):
41+
self.variable = extruded and not constant_layers
3842
self.unroll = unroll
3943
self.layer_bounds = layer_bounds
4044
self.interior_horizontal = interior_horizontal
4145
self.prefetch = {}
42-
offset = map_.offset
43-
shape = (None, ) + map_.shape[1:]
44-
values = Argument(shape, dtype=map_.dtype, pfx="map")
46+
47+
shape = (None, arity)
48+
values = Argument(shape, dtype=dtype, pfx="map")
4549
if offset is not None:
46-
if len(set(map_.offset)) == 1:
50+
assert type(offset) == tuple
51+
offset = numpy.array(offset, dtype=numpy.int32)
52+
if len(set(offset)) == 1:
4753
offset = Literal(offset[0], casting=True)
4854
else:
4955
offset = NamedLiteral(offset, parent=values, suffix="offset")
@@ -616,15 +622,18 @@ def emit_unpack_instruction(self, *,
616622

617623
class WrapperBuilder(object):
618624

619-
def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False,
625+
def __init__(self, *, kernel, subset, extruded, constant_layers, iteration_region=None, single_cell=False,
620626
pass_layer_to_kernel=False, forward_arg_types=()):
621627
self.kernel = kernel
628+
self.local_knl_args = iter(kernel.arguments)
622629
self.arguments = []
623630
self.argument_accesses = []
624631
self.packed_args = []
625632
self.indices = []
626633
self.maps = OrderedDict()
627-
self.iterset = iterset
634+
self.subset = subset
635+
self.extruded = extruded
636+
self.constant_layers = constant_layers
628637
if iteration_region is None:
629638
self.iteration_region = ALL
630639
else:
@@ -637,18 +646,6 @@ def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False,
637646
def requires_zeroed_output_arguments(self):
638647
return self.kernel.requires_zeroed_output_arguments
639648

640-
@property
641-
def subset(self):
642-
return isinstance(self.iterset, Subset)
643-
644-
@property
645-
def extruded(self):
646-
return self.iterset._extruded
647-
648-
@property
649-
def constant_layers(self):
650-
return self.extruded and self.iterset.constant_layers
651-
652649
@cached_property
653650
def loop_extents(self):
654651
return (Argument((), IntType, name="start"),
@@ -753,94 +750,98 @@ def loop_indices(self):
753750
return (self.loop_index, None, self._loop_index)
754751

755752
def add_argument(self, arg):
    """Register one wrapper argument and build the pack that feeds the kernel.

    The access descriptor and dtype come from the next entry of
    ``self.local_knl_args``, so calls must be made in the same order as the
    local kernel's arguments. Every ``Argument`` created here is appended to
    ``self.arguments``; the resulting pack and access are appended to
    ``self.packed_args`` and ``self.argument_accesses``.

    :arg arg: One of ``GlobalKernelArg``, ``DatKernelArg``,
        ``MixedDatKernelArg``, ``MatKernelArg`` or ``MixedMatKernelArg``.
    :raises ValueError: If ``arg`` is none of the types above.
    :raises StopIteration: If ``self.local_knl_args`` is already exhausted.
    """
    kernel_arg = next(self.local_knl_args)
    access, dtype = kernel_arg.access, kernel_arg.dtype
    interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS

    if isinstance(arg, GlobalKernelArg):
        argument = Argument(arg.dim, dtype, pfx="glob")
        pack = GlobalPack(
            argument, access,
            init_with_zero=self.requires_zeroed_output_arguments)
        self.arguments.append(argument)

    elif isinstance(arg, DatKernelArg):
        # A dat with empty ``dim`` is given a trailing extent of 1.
        shape = (None, 1) if arg.dim == () else (None, *arg.dim)
        argument = Argument(shape, dtype, pfx="dat")

        # Only indirect dats carry a map.
        map_ = self._add_map(arg.map_) if arg.is_indirect else None
        pack = arg.pack(
            argument, access, map_=map_,
            interior_horizontal=interior_horizontal,
            view_index=arg.index,
            init_with_zero=self.requires_zeroed_output_arguments)
        self.arguments.append(argument)

    elif isinstance(arg, MixedDatKernelArg):
        sub_packs = []
        for sub_arg in arg:
            shape = (None, 1) if sub_arg.dim == () else (None, *sub_arg.dim)
            argument = Argument(shape, dtype, pfx="mdat")

            map_ = self._add_map(sub_arg.map_) if sub_arg.is_indirect else None

            sub_packs.append(
                arg.pack(argument, access, map_,
                         interior_horizontal=interior_horizontal,
                         init_with_zero=self.requires_zeroed_output_arguments))
            self.arguments.append(argument)
        pack = MixedDatPack(sub_packs, access, dtype,
                            interior_horizontal=interior_horizontal)

    elif isinstance(arg, MatKernelArg):
        argument = Argument((), PetscMat(), pfx="mat")
        maps = tuple(self._add_map(m, arg.unroll) for m in arg.maps)
        pack = arg.pack(argument, access, maps,
                        arg.dims, dtype,
                        interior_horizontal=interior_horizontal)
        self.arguments.append(argument)

    elif isinstance(arg, MixedMatKernelArg):
        sub_packs = []
        for sub_arg in arg:
            argument = Argument((), PetscMat(), pfx="mat")
            maps = tuple(self._add_map(m, sub_arg.unroll) for m in sub_arg.maps)

            sub_packs.append(
                arg.pack(argument, access, maps,
                         sub_arg.dims, dtype,
                         interior_horizontal=interior_horizontal))
            self.arguments.append(argument)
        pack = MixedMatPack(sub_packs, access, dtype, arg.shape)

    else:
        raise ValueError("Unhandled argument type")

    # Bookkeeping shared by every argument kind.
    self.packed_args.append(pack)
    self.argument_accesses.append(access)
826+
827+
def _add_map(self, map_, unroll=False):
    """Return the codegen map for ``map_``, building and memoising it on first use.

    Results are stored in ``self.maps`` keyed by the kernel-argument map, so
    repeated requests for the same map return the same object.

    :arg map_: The map kernel argument, or ``None``.
    :arg unroll: Forwarded to the ``Map`` constructor (and to the recursive
        call for the base map of a permuted map).
    :returns: ``None`` if ``map_`` is ``None``; otherwise a ``PMap`` for a
        ``PermutedMapKernelArg`` and a ``Map`` for anything else.
    """
    if map_ is None:
        return None

    interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS

    if map_ in self.maps:
        return self.maps[map_]

    if isinstance(map_, PermutedMapKernelArg):
        # Build (or fetch) the underlying map first, then wrap it with the
        # permutation converted to an IntType array.
        inner = self._add_map(map_.base_map, unroll)
        built = PMap(inner, numpy.asarray(map_.permutation, dtype=IntType))
    else:
        built = Map(interior_horizontal,
                    (self.bottom_layer, self.top_layer),
                    arity=map_.arity, offset=map_.offset, dtype=IntType,
                    unroll=unroll,
                    extruded=self.extruded,
                    constant_layers=self.constant_layers)

    self.maps[map_] = built
    return built
846847

0 commit comments

Comments
 (0)