Skip to content

Commit 45550f5

Browse files
committed
misc: Thread-safe symbol cache
1 parent 23c8c2d commit 45550f5

File tree

2 files changed

+38
-22
lines changed

2 files changed

+38
-22
lines changed

devito/types/basic.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from devito.tools import (Pickable, as_tuple, dtype_to_ctype,
1717
frozendict, memoized_meth, sympy_mutex, CustomDtype)
1818
from devito.types.args import ArgProvider
19-
from devito.types.caching import Cached, Uncached
19+
from devito.types.caching import CacheManager, Cached, Uncached
2020
from devito.types.lazy import Evaluable
2121
from devito.types.utils import DimensionTuple
2222

@@ -559,24 +559,31 @@ def _cache_key(cls, *args, **kwargs):
559559
def __new__(cls, *args, **kwargs):
560560
assumptions, kwargs = cls._filter_assumptions(**kwargs)
561561
key = cls._cache_key(*args, **{**assumptions, **kwargs})
562-
obj = cls._cache_get(key)
563562

563+
# Initial cache lookup (not locked)
564+
obj = cls._cache_get(key)
564565
if obj is not None:
565566
return obj
566567

567-
# Not in cache. Create a new Symbol via sympy.Symbol
568-
args = list(args)
569-
name = kwargs.pop('name', None) or args.pop(0)
570-
newobj = cls.__xnew__(cls, name, **assumptions)
568+
# Lock against the symbol cache and double-check the cache
569+
with CacheManager.lock():
570+
obj = cls._cache_get(key)
571+
if obj is not None:
572+
return obj
571573

572-
# Initialization
573-
newobj._dtype = cls.__dtype_setup__(**kwargs)
574-
newobj.__init_finalize__(name, *args, **kwargs)
574+
# Not in cache. Create a new Symbol via sympy.Symbol
575+
args = list(args)
576+
name = kwargs.pop('name', None) or args.pop(0)
577+
newobj = cls.__xnew__(cls, name, **assumptions)
575578

576-
# Store new instance in symbol cache
577-
Cached.__init__(newobj, key)
579+
# Initialization
580+
newobj._dtype = cls.__dtype_setup__(**kwargs)
581+
newobj.__init_finalize__(name, *args, **kwargs)
578582

579-
return newobj
583+
# Store new instance in symbol cache
584+
Cached.__init__(newobj, key)
585+
586+
return newobj
580587

581588
__hash__ = Cached.__hash__
582589

devito/types/caching.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import gc
2+
from threading import RLock
23
import weakref
34

45
import sympy
@@ -10,6 +11,7 @@
1011
__all__ = ['Cached', 'Uncached', '_SymbolCache', 'CacheManager']
1112

1213
_SymbolCache = {}
14+
_cache_lock = RLock()
1315
"""The symbol cache."""
1416

1517

@@ -76,8 +78,10 @@ def _cache_get(cls, key):
7678
obj = obj_cached()
7779
if obj is None:
7880
# Cleanup _SymbolCache (though practically unnecessary)
79-
# does not fail if it's already gone
80-
_SymbolCache.pop(key, None)
81+
with _cache_lock:
82+
# Ensure another thread hasn't replaced the ref we're evicting
83+
if _SymbolCache.get(key) is obj_cached:
84+
_SymbolCache.pop(key, None)
8185
return None
8286
else:
8387
return obj
@@ -196,16 +200,21 @@ def clear(cls, force=True):
196200
# We won't call gc.collect() this time
197201
cls.ncalls_w_force_false += 1
198202

199-
for key in cache_copied:
200-
obj = _SymbolCache.get(key)
201-
if obj is None:
202-
# deleted by another thread since we took the copy
203-
continue
204-
if obj() is None:
205-
# (key could be removed in another thread since get() above)
206-
_SymbolCache.pop(key, None)
203+
for key, obj_cached in cache_copied.items():
204+
if obj_cached() is None:
205+
with _cache_lock:
206+
# Check if our snapshot of the cached object is still live
207+
if _SymbolCache.get(key) is obj_cached:
208+
_SymbolCache.pop(key, None)
207209

208210
# Maybe trigger garbage collection
209211
if force:
210212
del cache_copied
211213
gc.collect()
214+
215+
@staticmethod
216+
def lock():
217+
"""
218+
Returns the global symbol cache lock for atomic construction.
219+
"""
220+
return _cache_lock

0 commit comments

Comments
 (0)