Skip to content

Commit 988298e

Browse files
committed
Implement DelayedConversionNode as a generalization on evaluate-time dependency injection
Used for units initially, presented as an alternative to #34 which provides more general tools
1 parent c588266 commit 988298e

File tree

4 files changed

+43
-35
lines changed

4 files changed

+43
-35
lines changed

Diff for: data_prototype/axes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .containers import ArrayContainer, DataUnion
1212
from .conversion_node import (
13-
MatplotlibUnitConversion,
13+
DelayedConversionNode,
1414
FunctionConversionNode,
1515
RenameConversionNode,
1616
)
@@ -96,8 +96,8 @@ def scatter(
9696
cont = DataUnion(defaults, inputs)
9797

9898
pipeline = []
99-
xconvert = MatplotlibUnitConversion.from_keys(("x",), axis=self.xaxis)
100-
yconvert = MatplotlibUnitConversion.from_keys(("y",), axis=self.yaxis)
99+
xconvert = DelayedConversionNode.from_keys(("x",), converter_key="xunits")
100+
yconvert = DelayedConversionNode.from_keys(("y",), converter_key="yunits")
101101
pipeline.extend([xconvert, yconvert])
102102
pipeline.append(lambda x: np.ma.ravel(x))
103103
pipeline.append(lambda y: np.ma.ravel(y))

Diff for: data_prototype/conversion_node.py

+22-11
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,22 @@
66
import inspect
77
from functools import cached_property
88

9-
from matplotlib.axis import Axis
10-
119
from typing import Any
1210

1311

14-
def evaluate_pipeline(nodes: Sequence[ConversionNode], input: dict[str, Any]):
12+
def evaluate_pipeline(
13+
nodes: Sequence[ConversionNode],
14+
input: dict[str, Any],
15+
delayed_converters: dict[str, Callable] | None = None,
16+
):
1517
for node in nodes:
1618
if isinstance(node, Callable):
1719
k = list(inspect.signature(node).parameters.keys())[0]
1820
node = FunctionConversionNode.from_funcs({k: node})
19-
20-
input = node.evaluate(input)
21+
if isinstance(node, DelayedConversionNode):
22+
input = node.evaluate(input, delayed_converters)
23+
else:
24+
input = node.evaluate(input)
2125
return input
2226

2327

@@ -122,17 +126,24 @@ def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
122126

123127

124128
@dataclass
125-
class MatplotlibUnitConversion(ConversionNode):
126-
axis: Axis
129+
class DelayedConversionNode(ConversionNode):
130+
converter_key: str
127131

128132
@classmethod
129-
def from_keys(cls, keys: Sequence[str], axis: Axis):
130-
return cls(tuple(keys), tuple(keys), trim_keys=False, axis=axis)
133+
def from_keys(cls, keys: Sequence[str], converter_key: str):
134+
return cls(
135+
tuple(keys), tuple(keys), trim_keys=False, converter_key=converter_key
136+
)
131137

132-
def evaluate(self, input: dict[str, Any]) -> dict[str, Any]:
138+
def evaluate(
139+
self, input: dict[str, Any], converters: dict[str, Callable] | None = None
140+
) -> dict[str, Any]:
133141
return super().evaluate(
134142
{
135143
**input,
136-
**{k: self.axis.convert_units(input[k]) for k in self.required_keys},
144+
**{
145+
k: converters[self.converter_key](input[k])
146+
for k in self.required_keys
147+
},
137148
}
138149
)

Diff for: data_prototype/wrappers.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Any, Protocol, Tuple, get_type_hints
1+
from typing import Dict, Any, Protocol, Tuple, get_type_hints
22
import inspect
33

44
import numpy as np
@@ -121,17 +121,13 @@ def draw(self, renderer):
121121
def _update_wrapped(self, data):
122122
raise NotImplementedError
123123

124-
def _query_and_transform(
125-
self, renderer, *, xunits: List[str], yunits: List[str]
126-
) -> Dict[str, Any]:
124+
def _query_and_transform(self, renderer) -> Dict[str, Any]:
127125
"""
128126
Helper to centralize the data querying and python-side transforms
129127
130128
Parameters
131129
----------
132130
renderer : RendererBase
133-
xunits, yunits : List[str]
134-
The list of keys that need to be run through the x and y unit machinery.
135131
"""
136132
# extract what we need to about the axes to query the data
137133
ax = self.axes
@@ -153,8 +149,11 @@ def _query_and_transform(
153149
return self._cache[cache_key]
154150
except KeyError:
155151
...
156-
# TODO units
157-
transformed_data = evaluate_pipeline(self._converters, data)
152+
delayed_conversion = {
153+
"xunits": ax.xaxis.convert_units,
154+
"yunits": ax.yaxis.convert_units,
155+
}
156+
transformed_data = evaluate_pipeline(self._converters, data, delayed_conversion)
158157

159158
self._cache[cache_key] = transformed_data
160159
return transformed_data
@@ -232,7 +231,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
232231
@_stale_wrapper
233232
def draw(self, renderer):
234233
self._update_wrapped(
235-
self._query_and_transform(renderer, xunits=["x"], yunits=["y"]),
234+
self._query_and_transform(renderer),
236235
)
237236
return self._wrapped_instance.draw(renderer)
238237

@@ -265,7 +264,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
265264
@_stale_wrapper
266265
def draw(self, renderer):
267266
self._update_wrapped(
268-
self._query_and_transform(renderer, xunits=[], yunits=[]),
267+
self._query_and_transform(renderer),
269268
)
270269
return self._wrapped_instance.draw(renderer)
271270

@@ -304,7 +303,7 @@ def __init__(
304303
@_stale_wrapper
305304
def draw(self, renderer):
306305
self._update_wrapped(
307-
self._query_and_transform(renderer, xunits=["xextent"], yunits=["yextent"]),
306+
self._query_and_transform(renderer),
308307
)
309308
return self._wrapped_instance.draw(renderer)
310309

@@ -325,7 +324,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
325324
@_stale_wrapper
326325
def draw(self, renderer):
327326
self._update_wrapped(
328-
self._query_and_transform(renderer, xunits=["edges"], yunits=["density"]),
327+
self._query_and_transform(renderer),
329328
)
330329
return self._wrapped_instance.draw(renderer)
331330

@@ -344,7 +343,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
344343
@_stale_wrapper
345344
def draw(self, renderer):
346345
self._update_wrapped(
347-
self._query_and_transform(renderer, xunits=[], yunits=[]),
346+
self._query_and_transform(renderer),
348347
)
349348
return self._wrapped_instance.draw(renderer)
350349

@@ -425,11 +424,7 @@ def __init__(self, data: DataContainer, converters=None, /, **kwargs):
425424
@_stale_wrapper
426425
def draw(self, renderer):
427426
self._update_wrapped(
428-
self._query_and_transform(
429-
renderer,
430-
xunits=["x", "xupper", "xlower"],
431-
yunits=["y", "yupper", "ylower"],
432-
),
427+
self._query_and_transform(renderer),
433428
)
434429
for k, v in self._wrapped_instances.items():
435430
v.draw(renderer)

Diff for: examples/units.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import matplotlib.markers as mmarkers
1313

1414
from data_prototype.containers import ArrayContainer
15-
from data_prototype.conversion_node import MatplotlibUnitConversion
15+
from data_prototype.conversion_node import DelayedConversionNode
1616

1717
from data_prototype.wrappers import PathCollectionWrapper
1818

@@ -35,8 +35,10 @@
3535
fig, ax = plt.subplots()
3636
ax.set_xlim(-0.5, 7)
3737
ax.set_ylim(0, 5)
38-
conv = MatplotlibUnitConversion.from_keys(("x",), axis=ax.xaxis)
39-
lw = PathCollectionWrapper(cont, [conv], offset_transform=ax.transData)
38+
xconv = DelayedConversionNode.from_keys(("x",), converter_key="xunits")
39+
yconv = DelayedConversionNode.from_keys(("y",), converter_key="yunits")
40+
lw = PathCollectionWrapper(cont, [xconv, yconv], offset_transform=ax.transData)
4041
ax.add_artist(lw)
4142
ax.xaxis.set_units(ureg.feet)
43+
ax.yaxis.set_units(ureg.m)
4244
plt.show()

0 commit comments

Comments
 (0)