Skip to content

Commit f5544e4

Browse files
authored
Multi-dimensional GPU-based data filtering (#335)
Filter on up to 4 numeric attributes at the same time! This is working locally, but I need to figure out a clean way to connect it to the filter_value, since that's a list of lists. The best solution (to enable jslink) might be to create our own custom widget that can display 1-4 float sliders and stores its value as a list of lists?
1 parent f8e057e commit f5544e4

File tree

5 files changed

+416
-49
lines changed

5 files changed

+416
-49
lines changed

docs/api/experimental/traits.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# lonboard.experimental.traits
22

33
::: lonboard.experimental.traits.PointAccessor
4+
::: lonboard.experimental.traits.GetFilterValueAccessor

lonboard/experimental/layer_extension.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import traitlets
22

33
from lonboard._base import BaseExtension
4-
from lonboard.experimental.traits import PointAccessor
4+
from lonboard.experimental.traits import GetFilterValueAccessor, PointAccessor
55
from lonboard.traits import FloatAccessor
66

77

@@ -160,12 +160,19 @@ class DataFilterExtension(BaseExtension):
160160
161161
## `filter_range`
162162
163-
The (min, max) bounds which defines whether an object should be rendered.
163+
The bounds which define whether an object should be rendered. If an object's
164+
filtered value is within the bounds, the object will be rendered; otherwise it will
165+
be hidden. This prop can be updated on user input or animation with very little
166+
cost.
164167
165-
If an object's filtered value is within the bounds, the object will be rendered;
166-
otherwise it will be hidden.
168+
Format:
167169
168-
- Type: Tuple[float, float], optional
170+
If `filter_size` is 1, provide a single tuple of `(min, max)`.
171+
172+
If `filter_size` is 2 to 4, provide a list of tuples: `[(min0, max0), (min1,
173+
max1), ...]` for each filtered property, respectively.
174+
175+
- Type: either Tuple[float, float] or List[Tuple[float, float]], optional
169176
- Default: `(-1, 1)`
170177
171178
## `filter_soft_range`
@@ -199,8 +206,9 @@ class DataFilterExtension(BaseExtension):
199206
200207
Accessor to retrieve the value for each object that it will be filtered by.
201208
202-
- Type: [FloatAccessor][lonboard.traits.FloatAccessor]
203-
- If a number is provided, it is used as the value for all objects.
209+
- Type:
210+
[GetFilterValueAccessor][lonboard.experimental.traits.GetFilterValueAccessor]
211+
- If a scalar value is provided, it is used as the value for all objects.
204212
- If an array is provided, each value in the array will be used as the value
205213
for the object at the same row index.
206214
"""
@@ -209,22 +217,25 @@ class DataFilterExtension(BaseExtension):
209217

210218
_layer_traits = {
211219
"filter_enabled": traitlets.Bool(True).tag(sync=True),
212-
"filter_range": traitlets.Tuple(
213-
traitlets.Float(), traitlets.Float(), default_value=(-1, 1)
220+
"filter_range": traitlets.Union(
221+
[
222+
traitlets.List(traitlets.Float(), minlen=2, maxlen=2),
223+
traitlets.List(
224+
traitlets.List(traitlets.Float(), minlen=2, maxlen=2),
225+
minlen=2,
226+
maxlen=4,
227+
),
228+
]
214229
).tag(sync=True),
215230
"filter_soft_range": traitlets.Tuple(
216231
traitlets.Float(), traitlets.Float(), default_value=None, allow_none=True
217232
).tag(sync=True),
218233
"filter_transform_size": traitlets.Bool(True).tag(sync=True),
219234
"filter_transform_color": traitlets.Bool(True).tag(sync=True),
220-
"get_filter_value": FloatAccessor(None, allow_none=False),
235+
"get_filter_value": GetFilterValueAccessor(None, allow_none=False),
221236
}
222237

223-
# TODO: support filterSize > 1
224-
# In order to support filterSize > 1, we need to allow the get_filter_value accessor
225-
# to be either a single float or a fixed size list of up to 4 floats.
226-
227-
# filter_size = traitlets.Int(1).tag(sync=True)
238+
filter_size = traitlets.Int(1, min=1, max=4).tag(sync=True)
228239
"""The size of the filter (number of columns to filter by).
229240
230241
The data filter can show/hide data based on 1-4 numeric properties of each object.

lonboard/experimental/traits.py

Lines changed: 196 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,39 +44,37 @@ def validate(
4444
) -> Union[Tuple[int, ...], List[int], pa.ChunkedArray, pa.FixedSizeListArray]:
4545
if isinstance(value, np.ndarray):
4646
if value.ndim != 2:
47-
self.error(obj, value, info="Point array must have 2 dimensions.")
47+
self.error(obj, value, info="Point array to have 2 dimensions")
4848

4949
list_size = value.shape[1]
5050
if list_size not in (2, 3):
5151
self.error(
5252
obj,
5353
value,
54-
info="Point array must have 2 or 3 as its second dimension.",
54+
info="Point array to have 2 or 3 as its second dimension",
5555
)
5656

5757
return pa.FixedSizeListArray.from_arrays(value.flatten("C"), list_size)
5858

5959
if isinstance(value, (pa.ChunkedArray, pa.Array)):
6060
if not pa.types.is_fixed_size_list(value.type):
61-
self.error(
62-
obj, value, info="Point pyarrow array must be a FixedSizeList."
63-
)
61+
self.error(obj, value, info="Point pyarrow array to be a FixedSizeList")
6462

6563
if value.type.list_size not in (2, 3):
6664
self.error(
6765
obj,
6866
value,
6967
info=(
70-
"Color pyarrow array must have a FixedSizeList inner size of "
71-
"2 or 3."
68+
"Color pyarrow array to be a FixedSizeList with list size of "
69+
"2 or 3"
7270
),
7371
)
7472

7573
if not pa.types.is_floating(value.type.value_type):
7674
self.error(
7775
obj,
7876
value,
79-
info="Point pyarrow array must have a floating point child.",
77+
info="Point pyarrow array to have a floating point child",
8078
)
8179

8280
return value
@@ -89,8 +87,8 @@ def validate(
8987
obj,
9088
value,
9189
info=(
92-
"Color string must be a hex string interpretable by "
93-
"matplotlib.colors.to_rgba."
90+
"Color string to be a hex string interpretable by "
91+
"matplotlib.colors.to_rgba"
9492
),
9593
)
9694
return
@@ -99,3 +97,191 @@ def validate(
9997

10098
self.error(obj, value)
10199
assert False
100+
101+
102+
class GetFilterValueAccessor(FixedErrorTraitType):
103+
"""
104+
A trait to validate input for the `get_filter_value` accessor added by the
105+
[`DataFilterExtension`][lonboard.experimental.DataFilterExtension], which can have
106+
between 1 and 4 float values per row.
107+
108+
Various inputs are allowed:
109+
110+
- An `int` or `float`. This will be used as the value for all objects. The
111+
`filter_size` of the
112+
[`DataFilterExtension`][lonboard.experimental.DataFilterExtension] instance must
113+
be 1.
114+
- A one-dimensional numpy `ndarray` with a numeric data type. This will be cast to
115+
an array of data type [`np.float32`][numpy.float32]. Each value in the array will
116+
be used as the value for the object at the same row index. The `filter_size` of
117+
the [`DataFilterExtension`][lonboard.experimental.DataFilterExtension] instance
118+
must be 1.
119+
- A two-dimensional numpy `ndarray` with a numeric data type. This will be cast to
120+
an array of data type [`np.float32`][numpy.float32]. Each value in the array will
121+
be used as the value for the object at the same row index. The `filter_size` of
122+
the [`DataFilterExtension`][lonboard.experimental.DataFilterExtension] instance
123+
must match the size of the second dimension of the array.
124+
- A pandas `Series` with a numeric data type. This will be cast to an array of
125+
data type [`np.float32`][numpy.float32]. Each value in the array will be used as
126+
the value for the object at the same row index. The `filter_size` of
127+
the [`DataFilterExtension`][lonboard.experimental.DataFilterExtension] instance
128+
must be 1.
129+
- A pyarrow [`FloatArray`][pyarrow.FloatArray], [`DoubleArray`][pyarrow.DoubleArray]
130+
or [`ChunkedArray`][pyarrow.ChunkedArray] containing either a `FloatArray` or
131+
`DoubleArray`. Each value in the array will be used as the value for the object at
132+
the same row index. The `filter_size` of the
133+
[`DataFilterExtension`][lonboard.experimental.DataFilterExtension] instance must
134+
be 1.
135+
136+
Alternatively, you can pass any corresponding Arrow data structure from a library
137+
that implements the [Arrow PyCapsule
138+
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
139+
- A pyarrow [`FixedSizeListArray`][pyarrow.FixedSizeListArray] or
140+
[`ChunkedArray`][pyarrow.ChunkedArray] containing `FixedSizeListArray`s. The child
141+
array of the fixed size list must be of floating point type. The `filter_size` of
142+
the [`DataFilterExtension`][lonboard.experimental.DataFilterExtension] instance
143+
must match the list size.
144+
145+
Alternatively, you can pass any corresponding Arrow data structure from a library
146+
that implements the [Arrow PyCapsule
147+
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
148+
"""
149+
150+
default_value = float(0)
151+
info_text = (
152+
"a float value or numpy ndarray or pyarrow array representing an array"
153+
" of floats"
154+
)
155+
156+
def __init__(
157+
self: TraitType,
158+
*args,
159+
**kwargs: Any,
160+
) -> None:
161+
super().__init__(*args, **kwargs)
162+
self.tag(sync=True, **ACCESSOR_SERIALIZATION)
163+
164+
def validate(self, obj, value) -> Union[float, pa.ChunkedArray, pa.DoubleArray]:
165+
# Find the data filter extension in the attributes of the parent object so we
166+
# can validate against the filter size.
167+
data_filter_extension = [
168+
ext for ext in obj.extensions if ext._extension_type == "data-filter"
169+
]
170+
assert len(data_filter_extension) == 1
171+
filter_size = data_filter_extension[0].filter_size
172+
173+
if isinstance(value, (int, float)):
174+
if filter_size != 1:
175+
self.error(obj, value, info="filter_size==1 with scalar value")
176+
177+
return float(value)
178+
179+
if isinstance(value, (tuple, list)):
180+
if filter_size != len(value):
181+
self.error(
182+
obj,
183+
value,
184+
info=f"filter_size ({filter_size}) to match length of tuple/list",
185+
)
186+
187+
if any(not isinstance(v, (int, float)) for v in value):
188+
self.error(
189+
obj,
190+
value,
191+
info="all values in tuple or list to be numeric",
192+
)
193+
194+
return value
195+
196+
# pandas Series
197+
if (
198+
value.__class__.__module__.startswith("pandas")
199+
and value.__class__.__name__ == "Series"
200+
):
201+
# Assert that filter_size == 1 for a pandas series.
202+
# Pandas series can technically contain Python list objects inside them, but
203+
# for simplicity we disallow that.
204+
if filter_size != 1:
205+
self.error(obj, value, info="filter_size==1 with pandas Series")
206+
207+
# Cast pandas Series to numpy ndarray
208+
value = np.asarray(value)
209+
210+
if isinstance(value, np.ndarray):
211+
if not np.issubdtype(value.dtype, np.number):
212+
self.error(obj, value, info="numeric dtype")
213+
214+
# Cast to float32
215+
value = value.astype(np.float32)
216+
217+
if len(value.shape) == 1:
218+
if filter_size != 1:
219+
self.error(obj, value, info="filter_size==1 with 1-D numpy array")
220+
221+
return pa.array(value)
222+
223+
if len(value.shape) != 2:
224+
self.error(obj, value, info="1-D or 2-D numpy array")
225+
226+
if value.shape[1] != filter_size:
227+
self.error(
228+
obj,
229+
value,
230+
info=(
231+
f"filter_size ({filter_size}) to match 2nd dimension of "
232+
"numpy array"
233+
),
234+
)
235+
236+
return pa.FixedSizeListArray.from_arrays(value.flatten("C"), filter_size)
237+
238+
# Check for Arrow PyCapsule Interface
239+
# https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
240+
# TODO: with pyarrow v16 also import chunked array from stream
241+
if not isinstance(value, (pa.ChunkedArray, pa.Array)):
242+
if hasattr(value, "__arrow_c_array__"):
243+
value = pa.array(value)
244+
245+
if isinstance(value, (pa.ChunkedArray, pa.Array)):
246+
# Allowed inputs are either a FixedSizeListArray or numeric array.
247+
# If not a fixed size list array, check for floating and cast to float32
248+
if not pa.types.is_fixed_size_list(value.type):
249+
if filter_size != 1:
250+
self.error(
251+
obj,
252+
value,
253+
info="filter_size==1 with non-FixedSizeList type arrow array",
254+
)
255+
256+
if not pa.types.is_floating(value.type):
257+
self.error(
258+
obj,
259+
value,
260+
info="arrow array to be a floating point type",
261+
)
262+
263+
return value.cast(pa.float32())
264+
265+
# We have a FixedSizeListArray
266+
if filter_size != value.type.list_size:
267+
self.error(
268+
obj,
269+
value,
270+
info=(
271+
f"filter_size ({filter_size}) to match list size of "
272+
"FixedSizeList arrow array"
273+
),
274+
)
275+
276+
if not pa.types.is_floating(value.type.value_type):
277+
self.error(
278+
obj,
279+
value,
280+
info="arrow array to have floating point child type",
281+
)
282+
283+
# Cast values to float32
284+
return value.cast(pa.list_(pa.float32(), value.type.list_size))
285+
286+
self.error(obj, value)
287+
assert False

src/model/extension.ts

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import { LayerExtension } from "@deck.gl/core/typed";
22
import {
3-
BrushingExtensionProps,
4-
CollisionFilterExtensionProps,
5-
DataFilterExtensionProps,
63
BrushingExtension as _BrushingExtension,
74
CollisionFilterExtension as _CollisionFilterExtension,
85
DataFilterExtension as _DataFilterExtension,
@@ -26,11 +23,6 @@ export class BrushingExtension extends BaseExtensionModel {
2623

2724
extensionInstance: _BrushingExtension;
2825

29-
protected getBrushingTarget?: BrushingExtensionProps["getBrushingTarget"];
30-
protected brushingEnabled?: BrushingExtensionProps["brushingEnabled"];
31-
protected brushingTarget?: BrushingExtensionProps["brushingTarget"];
32-
protected brushingRadius?: BrushingExtensionProps["brushingRadius"];
33-
3426
constructor(
3527
model: WidgetModel,
3628
layerModel: BaseLayerModel,
@@ -65,11 +57,6 @@ export class CollisionFilterExtension extends BaseExtensionModel {
6557

6658
extensionInstance: _CollisionFilterExtension;
6759

68-
protected collisionEnabled?: CollisionFilterExtensionProps["collisionEnabled"];
69-
protected collisionGroup?: CollisionFilterExtensionProps["collisionGroup"];
70-
protected collisionTestProps?: CollisionFilterExtensionProps["collisionTestProps"];
71-
protected getCollisionPriority?: CollisionFilterExtensionProps["getCollisionPriority"];
72-
7360
constructor(
7461
model: WidgetModel,
7562
layerModel: BaseLayerModel,
@@ -107,22 +94,17 @@ export class DataFilterExtension extends BaseExtensionModel {
10794

10895
extensionInstance: _DataFilterExtension;
10996

110-
protected getFilterValue?: DataFilterExtensionProps["getFilterValue"];
111-
112-
protected filterEnabled?: DataFilterExtensionProps["filterEnabled"];
113-
protected filterRange?: DataFilterExtensionProps["filterRange"];
114-
protected filterSoftRange?: DataFilterExtensionProps["filterSoftRange"];
115-
protected filterTransformSize?: DataFilterExtensionProps["filterTransformSize"];
116-
protected filterTransformColor?: DataFilterExtensionProps["filterTransformColor"];
117-
11897
constructor(
11998
model: WidgetModel,
12099
layerModel: BaseLayerModel,
121100
updateStateCallback: () => void,
122101
) {
123102
super(model, updateStateCallback);
103+
124104
// TODO: set filterSize, fp64, countItems in constructor
125-
this.extensionInstance = new _DataFilterExtension();
105+
// TODO: should filter_size automatically update from python?
106+
const filterSize = this.model.get("filter_size");
107+
this.extensionInstance = new _DataFilterExtension({ filterSize });
126108

127109
// Properties added by the extension onto the layer
128110
layerModel.initRegularAttribute("filter_enabled", "filterEnabled");

0 commit comments

Comments
 (0)