Skip to content

Commit 949b272

Browse files
Support pyarrow.Table with geoarrow.pyarrow extension types as geometry columns (#218)
Work in progress; not yet tested. The idea is that, when encountering a `geoarrow.pyarrow` extension type, we convert it to its raw storage field with metadata (i.e. what the extension type looks like in IPC, or when the Python wrapper class is not registered). In the short term this is the easiest approach, because a lot of the existing code relies on checking e.g. the "ARROW:extension:name" key in the field metadata. --------- Co-authored-by: Kyle Barron <[email protected]>
1 parent 655e55d commit 949b272

File tree

6 files changed

+184
-7
lines changed

6 files changed

+184
-7
lines changed

lonboard/_geoarrow/sanitize.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Remove custom geoarrow.pyarrow types from input geoarrow data
2+
"""
3+
import json
4+
from typing import Tuple
5+
6+
import pyarrow as pa
7+
from pyproj import CRS
8+
9+
10+
def sanitize_table(table: pa.Table) -> pa.Table:
    """
    Convert any registered geoarrow.pyarrow extension fields and arrays to plain
    storage columns that carry the extension information as field metadata.

    Columns whose type is not a ``geoarrow.*`` extension type (including other,
    unrelated Arrow extension types) are left untouched.

    Args:
        table: The input table, possibly holding geoarrow.pyarrow extension columns.

    Returns:
        A table with the same column names and order, in which every geoarrow
        extension column has been replaced by its storage column.
    """
    # The schema is captured once here. ``set_column`` only swaps the column at
    # ``field_idx``, so positions and the remaining fields stay valid while
    # ``table`` is rebound inside the loop.
    for field_idx, field in enumerate(table.schema):
        if not isinstance(field.type, pa.ExtensionType):
            continue

        # Only geoarrow extension types expose the ``crs`` / ``edge_type``
        # attributes that ``sanitize_column`` reads; any other extension type
        # would fail there, so pass it through unchanged.
        if not field.type.extension_name.startswith("geoarrow."):
            continue

        column = table.column(field_idx)

        # Internal invariant: chunks of an extension-typed column are expected
        # to be extension arrays.
        assert all(isinstance(chunk, pa.ExtensionArray) for chunk in column.chunks)

        new_field, new_column = sanitize_column(field, column)
        table = table.set_column(field_idx, new_field, new_column)

    return table
25+
26+
27+
def sanitize_column(
    field: pa.Field, column: pa.ChunkedArray
) -> Tuple[pa.Field, pa.ChunkedArray]:
    """
    Convert a registered geoarrow.pyarrow extension field and column to plain metadata

    The extension type is replaced by its storage type, and the extension name
    (plus the CRS and spherical edges, when present) is written to the field
    metadata under the ``ARROW:extension:name`` and ``ARROW:extension:metadata``
    keys.

    Args:
        field: A field whose type is a geoarrow.pyarrow extension type.
        column: The chunked array belonging to ``field``.

    Returns:
        A ``(new_field, new_column)`` tuple, where ``new_field`` has the storage
        type and the extension metadata, and ``new_column`` holds the storage
        arrays of ``column``.
    """
    # Imported inside the function so that importing this module does not by
    # itself require geoarrow.pyarrow.
    import geoarrow.pyarrow as gap

    extension_type = field.type

    extension_metadata = {}
    if extension_type.crs:
        extension_metadata["crs"] = CRS.from_user_input(extension_type.crs).to_json()

    if extension_type.edge_type == gap.EdgeType.SPHERICAL:
        extension_metadata["edges"] = "spherical"

    metadata = {
        "ARROW:extension:name": extension_type.extension_name,
    }
    # Only emit the metadata key when there is something to record.
    if extension_metadata:
        metadata["ARROW:extension:metadata"] = json.dumps(extension_metadata)

    new_field = pa.field(
        field.name,
        extension_type.storage_type,
        nullable=field.nullable,
        metadata=metadata,
    )

    # Extension arrays expose their underlying data as ``.storage``; anything
    # else is cast to the storage type instead.
    new_chunks = [
        chunk.storage if hasattr(chunk, "storage") else chunk.cast(new_field.type)
        for chunk in column.chunks
    ]

    # Pass the type explicitly: without it, ``pa.chunked_array`` cannot infer a
    # type from an empty list of chunks and raises for zero-chunk columns.
    return new_field, pa.chunked_array(new_chunks, type=new_field.type)

lonboard/_layer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from lonboard._geoarrow.ops import reproject_table
2424
from lonboard._geoarrow.ops.bbox import Bbox, total_bounds
2525
from lonboard._geoarrow.ops.centroid import WeightedCentroid, weighted_centroid
26+
from lonboard._geoarrow.sanitize import sanitize_table
2627
from lonboard._serialization import infer_rows_per_chunk
2728
from lonboard._utils import auto_downcast as _auto_downcast
2829
from lonboard._utils import get_geometry_column_index, remove_extension_kwargs
@@ -231,6 +232,13 @@ class BaseArrowLayer(BaseLayer):
231232
def __init__(
232233
self, *, table: pa.Table, _rows_per_chunk: Optional[int] = None, **kwargs
233234
):
235+
# Check for Arrow PyCapsule Interface
236+
# https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
237+
if not isinstance(table, pa.Table) and hasattr(table, "__arrow_c_stream__"):
238+
table = pa.table(table)
239+
240+
table = sanitize_table(table)
241+
234242
# Reproject table to WGS84 if needed
235243
# Note this must happen before calculating the default viewport
236244
table = reproject_table(table, to_crs=OGC_84)

lonboard/traits.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ACCESSOR_SERIALIZATION,
2222
TABLE_SERIALIZATION,
2323
)
24+
from lonboard._utils import get_geometry_column_index
2425

2526

2627
# This is a custom subclass of traitlets.TraitType because its `error` method ignores
@@ -139,21 +140,19 @@ def __init__(
139140
)
140141

141142
def validate(self, obj: Self, value: Any):
142-
# Check for Arrow PyCapsule Interface
143-
# https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
144-
if not isinstance(value, pa.Table) and hasattr(value, "__arrow_c_stream__"):
145-
value = pa.table(value)
146-
147143
if not isinstance(value, pa.Table):
148144
self.error(obj, value)
149145

150146
allowed_geometry_types = self.metadata.get("allowed_geometry_types")
147+
# No restriction on the allowed geometry types in this table
151148
if not allowed_geometry_types:
152149
return value
153150

154-
geometry_extension_type = value.schema.field("geometry").metadata.get(
151+
geom_col_idx = get_geometry_column_index(value.schema)
152+
geometry_extension_type = value.schema.field(geom_col_idx).metadata.get(
155153
b"ARROW:extension:name"
156154
)
155+
157156
if (
158157
allowed_geometry_types
159158
and geometry_extension_type not in allowed_geometry_types

poetry.lock

Lines changed: 97 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ black = "^23.10.1"
3232
geoarrow-rust-core = "^0.1.0"
3333
geodatasets = "^2023.12.0"
3434
pyogrio = "^0.7.2"
35+
geoarrow-pyarrow = "^0.1.1"
3536

3637
[tool.poetry.group.docs.dependencies]
3738
mkdocs = "^1.4.3"

tests/test_layer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import geopandas as gpd
22
import numpy as np
3+
import pyarrow as pa
34
import pytest
45
import shapely
56
from traitlets import TraitError
@@ -58,3 +59,16 @@ def test_layer_outside_4326_range():
5859

5960
with pytest.raises(ValueError, match="outside of WGS84 bounds"):
6061
_layer = ScatterplotLayer.from_geopandas(gdf)
62+
63+
64+
def test_layer_from_geoarrow_pyarrow():
    """A layer can be built from a table with a geoarrow.pyarrow geometry column."""
    ga = pytest.importorskip("geoarrow.pyarrow")

    series = gpd.GeoSeries(shapely.points([1, 2], [3, 4]))

    # Convert to a geoarrow.pyarrow array; interleaved coordinates currently
    # have to be requested manually.
    geoarrow_points = ga.as_geoarrow(series)
    interleaved_points = ga.with_coord_type(
        geoarrow_points, ga.CoordType.INTERLEAVED
    )
    table = pa.table({"geometry": interleaved_points})

    _layer = ScatterplotLayer(table=table)

0 commit comments

Comments
 (0)