Commit b7bb2cf

Fix Ballista executing during plan (#2428)
1 parent dc76ec1 commit b7bb2cf

5 files changed (+314 -318 lines)

ballista/rust/core/src/execution_plans/distributed_query.rs (+101 -94)
@@ -30,9 +30,8 @@ use crate::serde::protobuf::{
     ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
     PartitionLocation,
 };
-use crate::utils::WrappedStream;
 
-use datafusion::arrow::datatypes::{Schema, SchemaRef};
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -43,12 +42,14 @@ use datafusion::physical_plan::{
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec};
 use async_trait::async_trait;
+use datafusion::arrow::error::{ArrowError, Result as ArrowResult};
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::execution::context::TaskContext;
-use futures::future;
-use futures::StreamExt;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
 use log::{error, info};
 
-/// This operator sends a logial plan to a Ballista scheduler for execution and
+/// This operator sends a logical plan to a Ballista scheduler for execution and
 /// polls the scheduler until the query is complete and then fetches the resulting
 /// batches directly from the executors that hold the results from the final
 /// query stage.
@@ -168,15 +169,6 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     ) -> Result<SendableRecordBatchStream> {
         assert_eq!(0, partition);
 
-        info!("Connecting to Ballista scheduler at {}", self.scheduler_url);
-        // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
-
-        let mut scheduler = SchedulerGrpcClient::connect(self.scheduler_url.clone())
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-
-        let schema: Schema = self.plan.schema().as_ref().clone().into();
-
         let mut buf: Vec<u8> = vec![];
         let plan_message =
             T::try_from_logical_plan(&self.plan, self.extension_codec.as_ref()).map_err(
@@ -191,88 +183,30 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
                 DataFusionError::Execution(format!("failed to encode logical plan: {:?}", e))
             })?;
 
-        let query_result = scheduler
-            .execute_query(ExecuteQueryParams {
-                query: Some(Query::LogicalPlan(buf)),
-                settings: self
-                    .config
-                    .settings()
-                    .iter()
-                    .map(|(k, v)| KeyValuePair {
-                        key: k.to_owned(),
-                        value: v.to_owned(),
-                    })
-                    .collect::<Vec<_>>(),
-                optional_session_id: Some(OptionalSessionId::SessionId(
-                    self.session_id.clone(),
-                )),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-            .into_inner();
-
-        let response_session_id = query_result.session_id;
-        assert_eq!(
-            self.session_id.clone(),
-            response_session_id,
-            "Session id inconsistent between Client and Server side in DistributedQueryExec."
-        );
+        let query = ExecuteQueryParams {
+            query: Some(Query::LogicalPlan(buf)),
+            settings: self
+                .config
+                .settings()
+                .iter()
+                .map(|(k, v)| KeyValuePair {
+                    key: k.to_owned(),
+                    value: v.to_owned(),
+                })
+                .collect::<Vec<_>>(),
+            optional_session_id: Some(OptionalSessionId::SessionId(
+                self.session_id.clone(),
+            )),
+        };
 
-        let job_id = query_result.job_id;
-        let mut prev_status: Option<job_status::Status> = None;
+        let stream = futures::stream::once(
+            execute_query(self.scheduler_url.clone(), self.session_id.clone(), query)
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+        )
+        .try_flatten();
 
-        loop {
-            let GetJobStatusResult { status } = scheduler
-                .get_job_status(GetJobStatusParams {
-                    job_id: job_id.clone(),
-                })
-                .await
-                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
-                .into_inner();
-            let status = status.and_then(|s| s.status).ok_or_else(|| {
-                DataFusionError::Internal("Received empty status message".to_owned())
-            })?;
-            let wait_future = tokio::time::sleep(Duration::from_millis(100));
-            let has_status_change = prev_status.map(|x| x != status).unwrap_or(true);
-            match status {
-                job_status::Status::Queued(_) => {
-                    if has_status_change {
-                        info!("Job {} still queued...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Running(_) => {
-                    if has_status_change {
-                        info!("Job {} is running...", job_id);
-                    }
-                    wait_future.await;
-                    prev_status = Some(status);
-                }
-                job_status::Status::Failed(err) => {
-                    let msg = format!("Job {} failed: {}", job_id, err.error);
-                    error!("{}", msg);
-                    break Err(DataFusionError::Execution(msg));
-                }
-                job_status::Status::Completed(completed) => {
-                    let result = future::join_all(
-                        completed
-                            .partition_location
-                            .into_iter()
-                            .map(fetch_partition),
-                    )
-                    .await
-                    .into_iter()
-                    .collect::<Result<Vec<_>>>()?;
-
-                    let result = WrappedStream::new(
-                        Box::pin(futures::stream::iter(result).flatten()),
-                        Arc::new(schema),
-                    );
-                    break Ok(Box::pin(result));
-                }
-            };
-        }
+        let schema = self.schema();
+        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
     }
 
     fn fmt_as(
@@ -299,6 +233,79 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
     }
 }
 
+async fn execute_query(
+    scheduler_url: String,
+    session_id: String,
+    query: ExecuteQueryParams,
+) -> Result<impl Stream<Item = ArrowResult<RecordBatch>> + Send> {
+    info!("Connecting to Ballista scheduler at {}", scheduler_url);
+    // TODO reuse the scheduler to avoid connecting to the Ballista scheduler again and again
+
+    let mut scheduler = SchedulerGrpcClient::connect(scheduler_url.clone())
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+
+    let query_result = scheduler
+        .execute_query(query)
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+        .into_inner();
+
+    assert_eq!(
+        session_id, query_result.session_id,
+        "Session id inconsistent between Client and Server side in DistributedQueryExec."
+    );
+
+    let job_id = query_result.job_id;
+    let mut prev_status: Option<job_status::Status> = None;
+
+    loop {
+        let GetJobStatusResult { status } = scheduler
+            .get_job_status(GetJobStatusParams {
+                job_id: job_id.clone(),
+            })
+            .await
+            .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?
+            .into_inner();
+        let status = status.and_then(|s| s.status).ok_or_else(|| {
+            DataFusionError::Internal("Received empty status message".to_owned())
+        })?;
+        let wait_future = tokio::time::sleep(Duration::from_millis(100));
+        let has_status_change = prev_status.map(|x| x != status).unwrap_or(true);
+        match status {
+            job_status::Status::Queued(_) => {
+                if has_status_change {
+                    info!("Job {} still queued...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Running(_) => {
+                if has_status_change {
+                    info!("Job {} is running...", job_id);
+                }
+                wait_future.await;
+                prev_status = Some(status);
+            }
+            job_status::Status::Failed(err) => {
+                let msg = format!("Job {} failed: {}", job_id, err.error);
+                error!("{}", msg);
+                break Err(DataFusionError::Execution(msg));
+            }
+            job_status::Status::Completed(completed) => {
+                let streams = completed.partition_location.into_iter().map(|p| {
+                    let f = fetch_partition(p)
+                        .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+
+                    futures::stream::once(f).try_flatten()
+                });
+
+                break Ok(futures::stream::iter(streams).flatten());
+            }
+        };
+    }
+}
+
 async fn fetch_partition(
     location: PartitionLocation,
 ) -> Result<SendableRecordBatchStream> {
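The hunks above carry the actual fix: ExecutionPlan::execute no longer connects to the scheduler, submits the job, and polls it to completion before returning. It only builds the ExecuteQueryParams and wraps the async execute_query call in futures::stream::once(..).try_flatten() behind a RecordBatchStreamAdapter, so the scheduler work runs when the returned stream is first polled. Below is a minimal sketch of that deferral pattern using plain futures/tokio types; slow_query is a hypothetical stand-in for Ballista's execute_query helper, not part of this commit.

use futures::{stream, Stream, TryStreamExt};
use std::io;

// Hypothetical stand-in for the new execute_query helper: an async fn that
// eventually produces a stream of results (plain integers instead of record batches).
async fn slow_query() -> Result<impl Stream<Item = Result<i32, io::Error>>, io::Error> {
    println!("connecting to scheduler, polling job status...");
    Ok(stream::iter(vec![Ok::<i32, io::Error>(1), Ok(2), Ok(3)]))
}

#[tokio::main]
async fn main() -> Result<(), io::Error> {
    // Wrapping the not-yet-polled future in stream::once(..).try_flatten() builds
    // the stream without doing any work, which is what execute() now returns.
    let lazy = stream::once(slow_query()).try_flatten();
    println!("stream constructed, nothing has executed yet");

    // The query only runs here, when the stream is first polled.
    let results: Vec<i32> = lazy.try_collect().await?;
    assert_eq!(results, vec![1, 2, 3]);
    Ok(())
}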

ballista/rust/core/src/execution_plans/shuffle_reader.rs (+20 -17)
@@ -15,30 +15,28 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::Any;
 use std::sync::Arc;
-use std::{any::Any, pin::Pin};
 
 use crate::client::BallistaClient;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
-use crate::utils::WrappedStream;
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::SchemaRef;
 
+use datafusion::error::{DataFusionError, Result};
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
 };
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
-use datafusion::{
-    error::{DataFusionError, Result},
-    physical_plan::RecordBatchStream,
-};
-use futures::{future, StreamExt};
+use futures::{StreamExt, TryStreamExt};
 
+use datafusion::arrow::error::ArrowError;
 use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use log::info;
 
 /// ShuffleReaderExec reads partitions that have already been materialized by a ShuffleWriterExec
@@ -112,18 +110,23 @@ impl ExecutionPlan for ShuffleReaderExec {
 
         let fetch_time =
             MetricBuilder::new(&self.metrics).subset_time("fetch_time", partition);
-        let timer = fetch_time.timer();
 
-        let partition_locations = &self.partition[partition];
-        let result = future::join_all(partition_locations.iter().map(fetch_partition))
-            .await
-            .into_iter()
-            .collect::<Result<Vec<_>>>()?;
-        timer.done();
+        let locations = self.partition[partition].clone();
+        let stream = locations.into_iter().map(move |p| {
+            let fetch_time = fetch_time.clone();
+            futures::stream::once(async move {
+                let timer = fetch_time.timer();
+                let r = fetch_partition(&p).await;
+                timer.done();
+
+                r.map_err(|e| ArrowError::ExternalError(Box::new(e)))
+            })
+            .try_flatten()
+        });
 
-        let result = WrappedStream::new(
-            Box::pin(futures::stream::iter(result).flatten()),
+        let result = RecordBatchStreamAdapter::new(
             Arc::new(self.schema.as_ref().clone()),
+            futures::stream::iter(stream).flatten(),
         );
         Ok(Box::pin(result))
     }
@@ -201,7 +204,7 @@ fn stats_for_partitions(
 
 async fn fetch_partition(
     location: &PartitionLocation,
-) -> Result<Pin<Box<dyn RecordBatchStream + Send>>> {
+) -> Result<SendableRecordBatchStream> {
    let metadata = &location.executor_meta;
    let partition_id = &location.partition_id;
    let mut ballista_client =
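ShuffleReaderExec gets the same treatment per partition, and the fetch_time metric moves inside the lazily built stream so the timer wraps the fetch when it actually runs instead of timing an eager join_all in execute(). A rough sketch of that per-partition pattern follows, again with hypothetical stand-ins: fetch plus Instant-based timing in place of Ballista's fetch_partition and MetricBuilder timer.

use futures::{stream, Stream, StreamExt, TryStreamExt};
use std::io;
use std::time::Instant;

// Hypothetical per-partition fetch: resolves to a stream of that partition's rows.
async fn fetch(p: u32) -> Result<impl Stream<Item = Result<u32, io::Error>>, io::Error> {
    Ok(stream::iter((0..3u32).map(move |i| Ok::<u32, io::Error>(p * 10 + i))))
}

#[tokio::main]
async fn main() -> Result<(), io::Error> {
    let partitions = vec![1_u32, 2, 3];

    // One lazy stream per partition; the timing code only runs when that
    // partition's stream is first polled, mirroring how the diff wraps the
    // timer and fetch_partition call inside futures::stream::once(async move { .. }).
    let per_partition = partitions.into_iter().map(|p| {
        stream::once(async move {
            let start = Instant::now();
            let r = fetch(p).await;
            println!("partition {} fetched in {:?}", p, start.elapsed());
            r
        })
        .try_flatten()
    });

    // Chain the partition streams in order, as the diff does with
    // futures::stream::iter(stream).flatten().
    let all: Vec<u32> = stream::iter(per_partition).flatten().try_collect().await?;
    assert_eq!(all.len(), 9);
    Ok(())
}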
