Skip to content

Commit 621caa2

Browse files
committed
ENH: add caching logic
1 parent 0f06f77 commit 621caa2

File tree

4 files changed

+60
-11
lines changed

4 files changed

+60
-11
lines changed

Diff for: data_prototype/containers.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,16 @@ def __init__(self, xfuncs, yfuncs=None, xyfuncs=None):
2121
self._xyfuncs = xyfuncs or {}
2222

2323
def query(self, data_bounds, size, xscale=None, yscale=None):
24+
hash_key = hash((data_bounds, size, xscale, yscale))
25+
if hash_key in self._cache:
26+
return self._cache[hash_key], hash_key
2427
xmin, xmax, ymin, ymax = data_bounds
2528
xpix, ypix = size
2629
x_data = np.linspace(xmin, xmax, int(xpix) * 2)
2730
y_data = np.linspace(ymin, ymax, int(ypix) * 2)
28-
return dict(
31+
ret = self._cache[hash_key] = dict(
2932
**{k: f(x_data) for k, f in self._xfuncs.items()},
3033
**{k: f(y_data) for k, f in self._yfuncs.items()},
3134
**{k: f(x_data, y_data) for k, f in self._xyfuncs.items()}
3235
)
36+
return ret, hash_key

Diff for: data_prototype/wrappers.py

+52-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
from typing import List, Dict, Any
2+
3+
import numpy as np
4+
5+
from cachetools import LFUCache
6+
17
from matplotlib.lines import Line2D as _Line2D
28
from matplotlib.image import AxesImage as _AxesImage
39

@@ -24,18 +30,54 @@ def __setattr__(self, key, value):
2430
else:
2531
super().__setattr__(key, value)
2632

27-
def _query_and_transform(self, renderer):
33+
def _query_and_transform(self, renderer, *, xunits: List[str], yunits: List[str]) -> Dict[str, Any]:
34+
"""
35+
Helper to centralize the data querying and python-side transforms
36+
37+
Parameters
38+
----------
39+
renderer : RendererBase
40+
xunits, yunits : List[str]
41+
The list of keys that need to be run through the x and y unit machinery.
42+
"""
43+
# extract what we need to about the axes to query the data
2844
ax = self._wrapped_instance.axes
45+
# TODO do we want to trust the implicit renderer on the Axes?
2946
ax_bbox = ax.get_window_extent(renderer)
30-
return {
31-
# doing this here is nice because we can write it once, but we really want to
32-
# push this computation down a layer
33-
k: self.nus.get(k, lambda x: x)(v)
34-
for k, v in self.data.query([*ax.get_xlim(), *ax.get_ylim()], ax_bbox.size).items()
35-
}
47+
48+
# actually query the underlying data. This returns both the (raw) data
49+
# and key to use for caching.
50+
data, cache_key = self.data.query(
51+
# TODO do this need to be (de) unitized
52+
(*ax.get_xlim(), *ax.get_ylim()),
53+
tuple(np.round(ax_bbox.size).astype(int)),
54+
# TODO sort out how to spell the x/y scale
55+
# TODO is scale enoguh? What do we have to do about non-trivial projection?
56+
xscale=None,
57+
yscale=None,
58+
)
59+
# see if we can short-circuit
60+
try:
61+
return self._cache[cache_key]
62+
except KeyError:
63+
...
64+
# TODO decide if units go pre-nu or post-nu?
65+
for x_like in xunits:
66+
data[x_like] = ax.xaxis.convert_units(data[x_like])
67+
for y_like in yunits:
68+
data[y_like] = ax.xaxis.convert_units(data[y_like])
69+
70+
# doing the nu work here is nice because we can write it once, but we
71+
# really want to push this computation down a layer
72+
# TODO sort out how this interaporates with the transform stack
73+
data = {k: self.nus.get(k, lambda x: x)(v) for k, v in data.items()}
74+
self._cache[cache_key] = data
75+
return data
3676

3777
def __init__(self, data, nus):
3878
self.data = data
79+
self._cache = LFUCache(64)
80+
# TODO make sure mutating this will invalidate the cache!
3981
self.nus = nus or {}
4082

4183

@@ -48,7 +90,7 @@ def __init__(self, data, nus=None, /, **kwargs):
4890
self._wrapped_instance = self._wrapped_class([], [], **kwargs)
4991

5092
def draw(self, renderer):
51-
data = self._query_and_transform(renderer)
93+
data = self._query_and_transform(renderer, xunits=["x"], yunits=["y"])
5294
self._wrapped_instance.set_data(data["x"], data["y"])
5395
return self._wrapped_instance.draw(renderer)
5496

@@ -62,7 +104,7 @@ def __init__(self, data, nus=None, /, **kwargs):
62104
self._wrapped_instance = self._wrapped_class(None, **kwargs)
63105

64106
def draw(self, renderer):
65-
data = self._query_and_transform(renderer)
107+
data = self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"])
66108
self._wrapped_instance.set_array(data["image"])
67-
self._wrapped_instance.set_extent(data["extent"])
109+
self._wrapped_instance.set_extent([*data["xextent"], *data["yextent"]])
68110
return self._wrapped_instance.draw(renderer)

Diff for: requirements-dev.txt

+2
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@ numpydoc
1616
sphinx-copybutton
1717
mpl-sphinx-theme
1818
sphinx-gallery
19+
# typing
20+
types-cachetools

Diff for: requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# List required packages in this file, one per line.
2+
cachetools
23
matplotlib
34
pandas
45
xarray

0 commit comments

Comments
 (0)