1- """
2- Simple client library for testing Databend UDF servers.
3- """
1+ """Simple client library for testing Databend UDF servers."""
2+
3+ import json
4+ from typing import Any , Dict , Iterable , List , Sequence , Tuple
45
56import pyarrow as pa
67import pyarrow .flight as fl
7- from typing import List , Any
88
99
1010class UDFClient :
@@ -135,7 +135,69 @@ def get_function_info(self, function_name: str) -> fl.FlightInfo:
135135 descriptor = fl .FlightDescriptor .for_path (function_name )
136136 return self .client .get_flight_info (descriptor )
137137
138- def call_function (self , function_name : str , * args ) -> List [Any ]:
138+ @staticmethod
139+ def format_stage_mapping (stage_locations : Iterable [Dict [str , Any ]]) -> str :
140+ """Serialize stage mapping entries to the Databend header payload."""
141+
142+ serialized_entries : List [Dict [str , Any ]] = []
143+ for entry in stage_locations :
144+ if not isinstance (entry , dict ):
145+ raise ValueError ("stage_locations entries must be dictionaries" )
146+ if "param_name" not in entry :
147+ raise ValueError ("stage_locations entry requires 'param_name'" )
148+ serialized_entries .append (entry )
149+
150+ return json .dumps (serialized_entries )
151+
152+ @staticmethod
153+ def _build_flight_headers (
154+ headers : Dict [str , Any ] = None ,
155+ stage_locations : Iterable [Dict [str , Any ]] = None ,
156+ ) -> Sequence [Tuple [str , str ]]:
157+ """Construct Flight headers for a UDF call.
158+
159+ ``stage_locations`` becomes a single header named ``databend-stage-mapping``
160+ whose value is a JSON array. This mirrors what Databend Query sends to
161+ external UDF servers. Example HTTP-style representation::
162+
163+ databend-stage-mapping: [
164+ {
165+ "param_name": "stage_loc",
166+ "relative_path": "input/2024/",
167+ "stage_info": { ... StageInfo JSON ... }
168+ }
169+ ]
170+
171+ Multiple stage parameters simply append more objects to the array.
172+ Additional custom headers can be supplied through ``headers``.
173+ """
174+ headers = headers or {}
175+ flight_headers : List [Tuple [bytes , bytes ]] = []
176+
177+ for key , value in headers .items ():
178+ if isinstance (value , (list , tuple )):
179+ for item in value :
180+ flight_headers .append (
181+ (str (key ).encode ("utf-8" ), str (item ).encode ("utf-8" ))
182+ )
183+ else :
184+ flight_headers .append (
185+ (str (key ).encode ("utf-8" ), str (value ).encode ("utf-8" ))
186+ )
187+
188+ if stage_locations :
189+ payload = UDFClient .format_stage_mapping (stage_locations )
190+ flight_headers .append ((b"databend-stage-mapping" , payload .encode ("utf-8" )))
191+
192+ return flight_headers
193+
194+ def call_function (
195+ self ,
196+ function_name : str ,
197+ * args ,
198+ headers : Dict [str , Any ] = None ,
199+ stage_locations : Iterable [Dict [str , Any ]] = None ,
200+ ) -> List [Any ]:
139201 """
140202 Call a UDF function with given arguments.
141203
@@ -150,7 +212,11 @@ def call_function(self, function_name: str, *args) -> List[Any]:
150212
151213 # Call function
152214 descriptor = fl .FlightDescriptor .for_path (function_name )
153- writer , reader = self .client .do_exchange (descriptor = descriptor )
215+ flight_headers = self ._build_flight_headers (headers , stage_locations )
216+ options = (
217+ fl .FlightCallOptions (headers = flight_headers ) if flight_headers else None
218+ )
219+ writer , reader = self .client .do_exchange (descriptor = descriptor , options = options )
154220
155221 with writer :
156222 writer .begin (input_schema )
@@ -166,7 +232,13 @@ def call_function(self, function_name: str, *args) -> List[Any]:
166232
167233 return results
168234
169- def call_function_batch (self , function_name : str , ** kwargs ) -> List [Any ]:
235+ def call_function_batch (
236+ self ,
237+ function_name : str ,
238+ headers : Dict [str , Any ] = None ,
239+ stage_locations : Iterable [Dict [str , Any ]] = None ,
240+ ** kwargs ,
241+ ) -> List [Any ]:
170242 """
171243 Call a UDF function with batch data.
172244
@@ -181,7 +253,11 @@ def call_function_batch(self, function_name: str, **kwargs) -> List[Any]:
181253
182254 # Call function
183255 descriptor = fl .FlightDescriptor .for_path (function_name )
184- writer , reader = self .client .do_exchange (descriptor = descriptor )
256+ flight_headers = self ._build_flight_headers (headers , stage_locations )
257+ options = (
258+ fl .FlightCallOptions (headers = flight_headers ) if flight_headers else None
259+ )
260+ writer , reader = self .client .do_exchange (descriptor = descriptor , options = options )
185261
186262 with writer :
187263 writer .begin (input_schema )
0 commit comments