@@ -30,9 +30,8 @@ use crate::serde::protobuf::{
     ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
     PartitionLocation,
 };
-use crate::utils::WrappedStream;
 
-use datafusion::arrow::datatypes::{Schema, SchemaRef};
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -43,12 +42,14 @@ use datafusion::physical_plan::{
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
 use async_trait::async_trait;
+use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::TaskContext;
-use futures::future;
-use futures::StreamExt;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
 use log::{error, info};
 
-/// This operator sends a logial plan to a Ballista scheduler for execution and
+/// This operator sends a logical plan to a Ballista scheduler for execution and
 /// polls the scheduler until the query is complete and then fetches the resulting
 /// batches directly from the executors that hold the results from the final
 /// query stage.
@@ -168,15 +169,6 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);
 
-        info!("Connecting to Ballista scheduler at {}", self.scheduler_url);
-        // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
-
-        let mut scheduler = SchedulerGrpcClient::connect(self.scheduler_url.clone())
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-
-        let schema: Schema = self.plan.schema().as_ref().clone().into();
-
         let mut buf: Vec<u8> = vec![];
         let plan_message =
             T::try_from_logical_plan(&self.plan, self.extension_codec.as_ref()).map_err(
@@ -191,88 +183,30 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
             DataFusionError::Execution(format!("failed to encode logical plan: {:?}", e))
         })?;
 
-        let query_result = scheduler
-            .execute_query(ExecuteQueryParams {
-                query: Some(Query::LogicalPlan(buf)),
-                settings: self
-                    .config
-                    .settings()
-                    .iter()
-                    .map(|(k, v)| KeyValuePair {
-                        key: k.to_owned(),
-                        value: v.to_owned(),
-                    })
-                    .collect::<Vec<_>>(),
-                optional_session_id: Some(OptionalSessionId::SessionId(
-                    self.session_id.clone(),
-                )),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-            .into_inner();
-
-        let response_session_id = query_result.session_id;
-        assert_eq!(
-            self.session_id.clone(),
-            response_session_id,
-            "Session id inconsistent between Client and Server side in DistributedQueryExec."
-        );
+        let query = ExecuteQueryParams {
+            query: Some(Query::LogicalPlan(buf)),
+            settings: self
+                .config
+                .settings()
+                .iter()
+                .map(|(k, v)| KeyValuePair {
+                    key: k.to_owned(),
+                    value: v.to_owned(),
+                })
+                .collect::<Vec<_>>(),
+            optional_session_id: Some(OptionalSessionId::SessionId(
+                self.session_id.clone(),
+            )),
+        };
 
-        let job_id = query_result.job_id;
-        let mut prev_status: Option<job_status::Status> = None;
+        let stream = futures::stream::once(
+            execute_query(self.scheduler_url.clone(), self.session_id.clone(), query)
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+        )
+        .try_flatten();
 
-        loop {
-            let GetJobStatusResult { status } = scheduler
-                .get_job_status(GetJobStatusParams {
-                    job_id: job_id.clone(),
-                })
-                .await
-                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-                .into_inner();
-            let status = status.and_then(|s| s.status).ok_or_else(|| {
-                DataFusionError::Internal("Received empty status message".to_owned())
-            })?;
-            let wait_future = tokio::time::sleep(Duration::from_millis(100));
-            let has_status_change = prev_status.map(|x| x != status).unwrap_or(true);
-            match status {
-                job_status::Status::Queued(_) => {
-                    if has_status_change {
-                        info!("Job {} still queued...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Running(_) => {
-                    if has_status_change {
-                        info!("Job {} is running...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Failed(err) => {
-                    let msg = format!("Job {} failed: {}", job_id, err.error);
-                    error!("{}", msg);
-                    break Err(DataFusionError::Execution(msg));
-                }
-                job_status::Status::Completed(completed) => {
-                    let result = future::join_all(
-                        completed
-                            .partition_location
-                            .into_iter()
-                            .map(fetch_partition),
-                    )
-                    .await
-                    .into_iter()
-                    .collect::<Result<Vec<_>>>()?;
-
-                    let result = WrappedStream::new(
-                        Box::pin(futures::stream::iter(result).flatten()),
-                        Arc::new(schema),
-                    );
-                    break Ok(Box::pin(result));
-                }
-            };
-        }
+        let schema = self.schema();
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }
 
     fn fmt_as(
@@ -299,6 +233,79 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     }
 }
 
+async fn execute_query(
+    scheduler_url: String,
+    session_id: String,
+    query: ExecuteQueryParams,
+) -> Result<impl Stream<Item = ArrowResult<RecordBatch>> + Send> {
+    info!("Connecting to Ballista scheduler at {}", scheduler_url);
+    // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
+
+    let mut scheduler = SchedulerGrpcClient::connect(scheduler_url.clone())
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+
+    let query_result = scheduler
+        .execute_query(query)
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+        .into_inner();
+
+    assert_eq!(
+        session_id, query_result.session_id,
+        "Session id inconsistent between Client and Server side in DistributedQueryExec."
+    );
+
+    let job_id = query_result.job_id;
+    let mut prev_status: Option<job_status::Status> = None;
+
+    loop {
+        let GetJobStatusResult { status } = scheduler
+            .get_job_status(GetJobStatusParams {
+                job_id: job_id.clone(),
+            })
+            .await
+            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+            .into_inner();
+        let status = status.and_then(|s| s.status).ok_or_else(|| {
+            DataFusionError::Internal("Received empty status message".to_owned())
+        })?;
+        let wait_future = tokio::time::sleep(Duration::from_millis(100));
+        let has_status_change = prev_status.map(|x| x != status).unwrap_or(true);
+        match status {
+            job_status::Status::Queued(_) => {
+                if has_status_change {
+                    info!("Job {} still queued...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Running(_) => {
+                if has_status_change {
+                    info!("Job {} is running...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Failed(err) => {
+                let msg = format!("Job {} failed: {}", job_id, err.error);
+                error!("{}", msg);
+                break Err(DataFusionError::Execution(msg));
+            }
+            job_status::Status::Completed(completed) => {
+                let streams = completed.partition_location.into_iter().map(|p| {
+                    let f = fetch_partition(p)
+                        .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+
+                    futures::stream::once(f).try_flatten()
+                });
+
+                break Ok(futures::stream::iter(streams).flatten());
+            }
+        };
+    }
+}
+
 async fn fetch_partition(
     location: PartitionLocation,
 ) -> Result<SendableRecordBatchStream> {
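The heart of this change is that `execute()` no longer drives the query itself: the async `execute_query` future is wrapped in `futures::stream::once(..).try_flatten()` and handed to `RecordBatchStreamAdapter`, so no scheduler connection or polling happens until the consumer polls the returned stream, and any setup error surfaces as the first stream item. Below is a minimal, self-contained sketch of that lazy-stream pattern, assuming only the `futures` and `tokio` crates; the `produce` function and the `u64`/`std::io::Error` item type are illustrative placeholders for `execute_query` and `ArrowResult<RecordBatch>`, not Ballista APIs.

```rust
use futures::{Stream, StreamExt, TryStreamExt};

// Placeholder item type standing in for ArrowResult<RecordBatch>.
type Item = std::result::Result<u64, std::io::Error>;

// Stands in for the async `execute_query` above: it may fail before any
// data is produced, and on success yields a stream of results.
async fn produce() -> std::result::Result<impl Stream<Item = Item> + Send, std::io::Error> {
    let items: Vec<Item> = vec![Ok(1), Ok(2), Ok(3)];
    Ok(futures::stream::iter(items))
}

// Wrapping the future in `stream::once` and flattening it defers all work
// until the first poll; a setup error simply becomes the first stream item.
fn lazy_stream() -> impl Stream<Item = Item> + Send {
    futures::stream::once(produce()).try_flatten()
}

#[tokio::main]
async fn main() {
    let collected: Vec<Item> = lazy_stream().collect().await;
    println!("{:?}", collected);
}
```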