Skip to content

Commit dac523c

Browse files
authored
Improve stage mapping integration (#14)
1 parent 5f34881 commit dac523c

File tree

11 files changed

+1298
-214
lines changed

11 files changed

+1298
-214
lines changed

python/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,12 @@ python3 examples/server.py
147147

148148
### Acknowledgement
149149
Databend Python UDF Server API is inspired by [RisingWave Python API](https://pypi.org/project/risingwave/).
150+
151+
### Code Formatting
152+
153+
Use Ruff to keep the Python sources consistent:
154+
155+
```bash
156+
python -m pip install ruff # once
157+
python -m ruff format python/databend_udf python/tests
158+
```

python/README_CLIENT.md

Lines changed: 0 additions & 121 deletions
This file was deleted.

python/databend_udf/client.py

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""
2-
Simple client library for testing Databend UDF servers.
3-
"""
1+
"""Simple client library for testing Databend UDF servers."""
2+
3+
import json
4+
from typing import Any, Dict, Iterable, List, Sequence, Tuple
45

56
import pyarrow as pa
67
import pyarrow.flight as fl
7-
from typing import List, Any
88

99

1010
class UDFClient:
@@ -135,7 +135,69 @@ def get_function_info(self, function_name: str) -> fl.FlightInfo:
135135
descriptor = fl.FlightDescriptor.for_path(function_name)
136136
return self.client.get_flight_info(descriptor)
137137

138-
def call_function(self, function_name: str, *args) -> List[Any]:
138+
@staticmethod
139+
def format_stage_mapping(stage_locations: Iterable[Dict[str, Any]]) -> str:
140+
"""Serialize stage mapping entries to the Databend header payload."""
141+
142+
serialized_entries: List[Dict[str, Any]] = []
143+
for entry in stage_locations:
144+
if not isinstance(entry, dict):
145+
raise ValueError("stage_locations entries must be dictionaries")
146+
if "param_name" not in entry:
147+
raise ValueError("stage_locations entry requires 'param_name'")
148+
serialized_entries.append(entry)
149+
150+
return json.dumps(serialized_entries)
151+
152+
@staticmethod
153+
def _build_flight_headers(
154+
headers: Dict[str, Any] = None,
155+
stage_locations: Iterable[Dict[str, Any]] = None,
156+
) -> Sequence[Tuple[str, str]]:
157+
"""Construct Flight headers for a UDF call.
158+
159+
``stage_locations`` becomes a single header named ``databend-stage-mapping``
160+
whose value is a JSON array. This mirrors what Databend Query sends to
161+
external UDF servers. Example HTTP-style representation::
162+
163+
databend-stage-mapping: [
164+
{
165+
"param_name": "stage_loc",
166+
"relative_path": "input/2024/",
167+
"stage_info": { ... StageInfo JSON ... }
168+
}
169+
]
170+
171+
Multiple stage parameters simply append more objects to the array.
172+
Additional custom headers can be supplied through ``headers``.
173+
"""
174+
headers = headers or {}
175+
flight_headers: List[Tuple[bytes, bytes]] = []
176+
177+
for key, value in headers.items():
178+
if isinstance(value, (list, tuple)):
179+
for item in value:
180+
flight_headers.append(
181+
(str(key).encode("utf-8"), str(item).encode("utf-8"))
182+
)
183+
else:
184+
flight_headers.append(
185+
(str(key).encode("utf-8"), str(value).encode("utf-8"))
186+
)
187+
188+
if stage_locations:
189+
payload = UDFClient.format_stage_mapping(stage_locations)
190+
flight_headers.append((b"databend-stage-mapping", payload.encode("utf-8")))
191+
192+
return flight_headers
193+
194+
def call_function(
195+
self,
196+
function_name: str,
197+
*args,
198+
headers: Dict[str, Any] = None,
199+
stage_locations: Iterable[Dict[str, Any]] = None,
200+
) -> List[Any]:
139201
"""
140202
Call a UDF function with given arguments.
141203
@@ -150,7 +212,11 @@ def call_function(self, function_name: str, *args) -> List[Any]:
150212

151213
# Call function
152214
descriptor = fl.FlightDescriptor.for_path(function_name)
153-
writer, reader = self.client.do_exchange(descriptor=descriptor)
215+
flight_headers = self._build_flight_headers(headers, stage_locations)
216+
options = (
217+
fl.FlightCallOptions(headers=flight_headers) if flight_headers else None
218+
)
219+
writer, reader = self.client.do_exchange(descriptor=descriptor, options=options)
154220

155221
with writer:
156222
writer.begin(input_schema)
@@ -166,7 +232,13 @@ def call_function(self, function_name: str, *args) -> List[Any]:
166232

167233
return results
168234

169-
def call_function_batch(self, function_name: str, **kwargs) -> List[Any]:
235+
def call_function_batch(
236+
self,
237+
function_name: str,
238+
headers: Dict[str, Any] = None,
239+
stage_locations: Iterable[Dict[str, Any]] = None,
240+
**kwargs,
241+
) -> List[Any]:
170242
"""
171243
Call a UDF function with batch data.
172244
@@ -181,7 +253,11 @@ def call_function_batch(self, function_name: str, **kwargs) -> List[Any]:
181253

182254
# Call function
183255
descriptor = fl.FlightDescriptor.for_path(function_name)
184-
writer, reader = self.client.do_exchange(descriptor=descriptor)
256+
flight_headers = self._build_flight_headers(headers, stage_locations)
257+
options = (
258+
fl.FlightCallOptions(headers=flight_headers) if flight_headers else None
259+
)
260+
writer, reader = self.client.do_exchange(descriptor=descriptor, options=options)
185261

186262
with writer:
187263
writer.begin(input_schema)

0 commit comments

Comments
 (0)