Commit 3c2f0e9

PR #200 finetune VectorCube.apply_dimension
ref: #197, Open-EO/openeo-geopyspark-driver#437
1 parent 30a422e commit 3c2f0e9

File tree

3 files changed: +160 -33 lines changed

openeo_driver/datacube.py

Lines changed: 37 additions & 33 deletions
@@ -18,13 +18,14 @@
 
 from openeo.metadata import CollectionMetadata
 from openeo.util import ensure_dir, str_truncate
+import openeo.udf
 from openeo_driver.datastructs import SarBackscatterArgs, ResolutionMergeArgs, StacAsset
 from openeo_driver.errors import FeatureUnsupportedException, InternalException
 from openeo_driver.util.geometry import GeometryBufferer, validate_geojson_coordinates
 from openeo_driver.util.ioformats import IOFORMATS
+from openeo_driver.util.pgparsing import SingleRunUDFProcessGraph
 from openeo_driver.util.utm import area_in_square_meters
 from openeo_driver.utils import EvalEnv
-from openeogeotrellis.backend import SingleNodeUDFProcessGraphVisitor
 
 log = logging.getLogger(__name__)
 
@@ -248,38 +249,6 @@ def with_cube(self, cube: xarray.DataArray, flatten_prefix: str = FLATTEN_PREFIX
             geometries=self._geometries, cube=cube, flatten_prefix=flatten_prefix
         )
 
-    def apply_dimension(
-        self,
-        process: dict,
-        *,
-        dimension: str,
-        target_dimension: Optional[str] = None,
-        context: Optional[dict] = None,
-        env: EvalEnv,
-    ) -> "DriverVectorCube":
-        if dimension == "bands" and target_dimension == None and len(process) == 1 and next(iter(process.values())).get('process_id') == 'run_udf':
-            visitor = SingleNodeUDFProcessGraphVisitor().accept_process_graph(process)
-            udf = visitor.udf_args.get('udf', None)
-
-            from openeo.udf import FeatureCollection, UdfData
-            collection = FeatureCollection(id='VectorCollection', data=self._as_geopandas_df())
-            data = UdfData(
-                proj={"EPSG": self._geometries.crs.to_epsg()}, feature_collection_list=[collection], user_context=context
-            )
-
-            log.info(f"[run_udf] Running UDF {str_truncate(udf, width=256)!r} on {data!r}")
-            result_data = env.backend_implementation.processing.run_udf(udf, data)
-            log.info(f"[run_udf] UDF resulted in {result_data!r}")
-
-            if isinstance(result_data, UdfData):
-                if(result_data.get_feature_collection_list() is not None and len(result_data.get_feature_collection_list()) == 1):
-                    return DriverVectorCube(geometries=result_data.get_feature_collection_list()[0].data)
-
-                raise ValueError(f"Could not handle UDF result: {result_data}")
-
-        else:
-            raise FeatureUnsupportedException()
-
     @classmethod
     def from_fiona(
         cls,
@@ -537,6 +506,41 @@ def buffer_points(self, distance: float = 10) -> "DriverVectorCube":
             ]
         )
 
+    def apply_dimension(
+        self,
+        process: dict,
+        *,
+        dimension: str,
+        target_dimension: Optional[str] = None,
+        context: Optional[dict] = None,
+        env: EvalEnv,
+    ) -> "DriverVectorCube":
+        single_run_udf = SingleRunUDFProcessGraph.parse_or_none(process)
+
+        if single_run_udf:
+            # Process with single "run_udf" node
+            if self._cube is None and dimension == self.DIM_GEOMETRIES and target_dimension is None:
+                # TODO: this is a non-standard special case: vector cube with only geometries, but no "cube" data
+                feature_collection = openeo.udf.FeatureCollection(id="_", data=self._as_geopandas_df())
+                udf_data = openeo.udf.UdfData(
+                    proj={"EPSG": self._geometries.crs.to_epsg()},
+                    feature_collection_list=[feature_collection],
+                    user_context=context,
+                )
+                log.info(f"[run_udf] Running UDF {str_truncate(single_run_udf.udf, width=256)!r} on {udf_data!r}")
+                result_data = env.backend_implementation.processing.run_udf(udf=single_run_udf.udf, data=udf_data)
+                log.info(f"[run_udf] UDF resulted in {result_data!r}")
+
+                if isinstance(result_data, openeo.udf.UdfData):
+                    result_features = result_data.get_feature_collection_list()
+                    if result_features and len(result_features) == 1:
+                        return DriverVectorCube(geometries=result_features[0].data)
+                raise ValueError(f"Could not handle UDF result: {result_data}")
+
+        else:
+            raise FeatureUnsupportedException()
+
 
 class DriverMlModel:
     """Base class for driver-side 'ml-model' data structures"""

openeo_driver/util/pgparsing.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import dataclasses
+from typing import Optional
+
+
+class NotASingleRunUDFProcessGraph(ValueError):
+    pass
+
+
+@dataclasses.dataclass(frozen=True)
+class SingleRunUDFProcessGraph:
+    """
+    Container (and parser) for a callback process graph containing only a single `run_udf` node.
+    """
+
+    data: dict
+    udf: str
+    runtime: str
+    version: Optional[str] = None
+    context: Optional[dict] = None
+
+    @classmethod
+    def parse(cls, process_graph: dict) -> "SingleRunUDFProcessGraph":
+        try:
+            (node,) = process_graph.values()
+            assert node["process_id"] == "run_udf"
+            assert node["result"] is True
+            arguments = node["arguments"]
+            assert {"data", "udf", "runtime"}.issubset(arguments.keys())
+
+            return cls(
+                data=arguments["data"],
+                udf=arguments["udf"],
+                runtime=arguments["runtime"],
+                version=arguments.get("version"),
+                context=arguments.get("context") or {},
+            )
+        except Exception as e:
+            raise NotASingleRunUDFProcessGraph(str(e)) from e
+
+    @classmethod
+    def parse_or_none(cls, process_graph: dict) -> Optional["SingleRunUDFProcessGraph"]:
+        try:
+            return cls.parse(process_graph=process_graph)
+        except NotASingleRunUDFProcessGraph:
+            return None
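
A quick usage sketch of the parser's two entry points (the valid process graph literal matches the tests below; the invalid one is made up for illustration):

    from openeo_driver.util.pgparsing import (
        NotASingleRunUDFProcessGraph,
        SingleRunUDFProcessGraph,
    )

    pg = {
        "runudf1": {
            "process_id": "run_udf",
            "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
            "result": True,
        }
    }

    # parse() returns the frozen dataclass, with defaults filled in.
    node = SingleRunUDFProcessGraph.parse(pg)
    assert (node.udf, node.runtime, node.version, node.context) == ("x = 4", "Python", None, {})

    # parse() raises NotASingleRunUDFProcessGraph on anything else ...
    try:
        SingleRunUDFProcessGraph.parse({"add1": {"process_id": "add", "arguments": {}, "result": True}})
    except NotASingleRunUDFProcessGraph:
        pass

    # ... while parse_or_none() maps that failure to None.
    assert SingleRunUDFProcessGraph.parse_or_none({}) is None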

tests/util/test_pgparsing.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+import pytest
+
+from openeo_driver.util.pgparsing import SingleRunUDFProcessGraph, NotASingleRunUDFProcessGraph
+
+
+class TestSingleRunUDFProcessGraph:
+    def test_parse_basic(self):
+        pg = {
+            "runudf1": {
+                "process_id": "run_udf",
+                "arguments": {
+                    "data": {"from_parameter": "data"},
+                    "udf": "print('Hello world')",
+                    "runtime": "Python",
+                },
+                "result": True,
+            }
+        }
+        run_udf = SingleRunUDFProcessGraph.parse(pg)
+        assert run_udf.data == {"from_parameter": "data"}
+        assert run_udf.udf == "print('Hello world')"
+        assert run_udf.runtime == "Python"
+        assert run_udf.version is None
+        assert run_udf.context == {}
+
+    @pytest.mark.parametrize(
+        "pg",
+        [
+            {
+                "runudf1": {
+                    "process_id": "run_udffffffffffffffff",
+                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
+                    "result": True,
+                }
+            },
+            {
+                "runudf1": {
+                    "process_id": "run_udf",
+                    "arguments": {"udf": "x = 4", "runtime": "Python"},
+                    "result": True,
+                }
+            },
+            {
+                "runudf1": {
+                    "process_id": "run_udf",
+                    "arguments": {"data": {"from_parameter": "data"}, "runtime": "Python"},
+                    "result": True,
+                }
+            },
+            {
+                "runudf1": {
+                    "process_id": "run_udf",
+                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4"},
+                    "result": True,
+                }
+            },
+            {
+                "runudf1": {
+                    "process_id": "run_udf",
+                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
+                }
+            },
+            {
+                "runudf1": {
+                    "process_id": "run_udf",
+                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
+                    "result": True,
+                },
+                "runudf2": {
+                    "process_id": "run_udf",
+                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
+                },
+            },
+        ],
+    )
+    def test_parse_invalid(self, pg):
+        with pytest.raises(NotASingleRunUDFProcessGraph):
+            _ = SingleRunUDFProcessGraph.parse(pg)
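
To round out the picture, a hypothetical UDF that this run_udf path could execute on a geometries-only vector cube. The `apply_udf_data` entry point and the in-place mutation are assumptions about the openeo UDF runtime conventions, not something this commit defines:

    # Hypothetical UDF source: the string that would be passed as the "udf"
    # argument of the run_udf node. It receives the UdfData built in
    # DriverVectorCube.apply_dimension and replaces each geometry by its centroid.
    udf_code = '''
    from openeo.udf import UdfData

    def apply_udf_data(data: UdfData) -> UdfData:
        for feature_collection in data.get_feature_collection_list():
            gdf = feature_collection.data  # a geopandas GeoDataFrame
            gdf.geometry = gdf.geometry.centroid
        return data
    '''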
