
Commit 31f8833

feat: Reinstate disk-based shuffle (#47)
* old old shuffle reader/writer
* old old shuffle reader/writer
* remove ray shuffle
* revert more changes
* save progress
* update expected plans
* remove unused code
* fix regression
1 parent a86218c commit 31f8833

36 files changed: +872 -864 lines

Cargo.lock (+2)

Generated file; diff not rendered by default.

Cargo.toml (+2)

@@ -32,10 +32,12 @@ build = "build.rs"
 datafusion = { version = "42.0.0", features = ["pyarrow", "avro"] }
 datafusion-proto = "42.0.0"
 futures = "0.3"
+glob = "0.3.1"
 log = "0.4"
 prost = "0.13"
 pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
 tokio = { version = "1.40", features = ["macros", "rt", "rt-multi-thread", "sync"] }
+uuid = "1.11.0"
 
 [build-dependencies]
 prost-types = "0.13"

datafusion_ray/context.py (+46 -53)

@@ -29,23 +29,31 @@
 from datafusion import SessionContext
 
 
-def schedule_execution(
-    graph: ExecutionGraph,
-    stage_id: int,
-    is_final_stage: bool,
-) -> list[ray.ObjectRef]:
-    stage = graph.get_query_stage(stage_id)
+@ray.remote(num_cpus=0)
+def execute_query_stage(
+    query_stages: list[QueryStage],
+    stage_id: int
+) -> tuple[int, list[ray.ObjectRef]]:
+    """
+    Execute a query stage on the workers.
+
+    Returns the stage ID, and a list of futures for the output partitions of the query stage.
+    """
+    stage = QueryStage(stage_id, query_stages[stage_id])
 
     # execute child stages first
-    # A list of (stage ID, list of futures) for each child stage
-    # Each list is a 2-D array of (input partitions, output partitions).
-    child_outputs = []
+    child_futures = []
     for child_id in stage.get_child_stage_ids():
-        child_outputs.append((child_id, schedule_execution(graph, child_id, False)))
-        # child_outputs.append((child_id, schedule_execution(graph, child_id)))
+        child_futures.append(
+            execute_query_stage.remote(query_stages, child_id)
+        )
 
+    # if the query stage has a single output partition then we need to execute for the output
+    # partition, otherwise we need to execute in parallel for each input partition
     concurrency = stage.get_input_partition_count()
     output_partitions_count = stage.get_output_partition_count()
-    if is_final_stage:
+    if output_partitions_count == 1:
+        # reduce stage
         print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
         concurrency = 1
 
@@ -55,50 +63,33 @@ def schedule_execution(
         )
     )
 
-    def _get_worker_inputs(
-        part: int,
-    ) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
-        ids = []
-        futures = []
-        for child_stage_id, child_futures in child_outputs:
-            for i, lst in enumerate(child_futures):
-                if isinstance(lst, list):
-                    for j, f in enumerate(lst):
-                        if concurrency == 1 or j == part:
-                            # If concurrency is 1, pass in all shuffle partitions. Otherwise,
-                            # only pass in the partitions that match the current worker partition.
-                            ids.append((child_stage_id, i, j))
-                            futures.append(f)
-                elif concurrency == 1 or part == 0:
-                    ids.append((child_stage_id, i, 0))
-                    futures.append(lst)
-        return ids, futures
+    # A list of (stage ID, list of futures) for each child stage
+    # Each list is a 2-D array of (input partitions, output partitions).
+    child_outputs = ray.get(child_futures)
+
+    # if we are using disk-based shuffle, wait until the child stages to finish
+    # writing the shuffle files to disk first.
+    ray.get([f for _, lst in child_outputs for f in lst])
 
     # schedule the actual execution workers
     plan_bytes = stage.get_execution_plan_bytes()
     futures = []
     opt = {}
-    # TODO not sure why we had this but my Ray cluster could not find suitable resource
-    # until I commented this out
-    # opt["resources"] = {"worker": 1e-3}
-    opt["num_returns"] = output_partitions_count
     for part in range(concurrency):
-        ids, inputs = _get_worker_inputs(part)
         futures.append(
             execute_query_partition.options(**opt).remote(
-                stage_id, plan_bytes, part, ids, *inputs
+                stage_id, plan_bytes, part
             )
         )
-    return futures
+
+    return stage_id, futures
 
 
 @ray.remote
 def execute_query_partition(
     stage_id: int,
     plan_bytes: bytes,
-    part: int,
-    input_partition_ids: list[tuple[int, int, int]],
-    *input_partitions: list[pa.RecordBatch],
+    part: int
 ) -> Iterable[pa.RecordBatch]:
     start_time = time.time()
     # plan = datafusion_ray.deserialize_execution_plan(plan_bytes)
@@ -109,13 +100,10 @@ def execute_query_partition(
     #         input_partition_ids,
     #     )
     # )
-    partitions = [
-        (s, j, p) for (s, _, j), p in zip(input_partition_ids, input_partitions)
-    ]
     # This is delegating to DataFusion for execution, but this would be a good place
     # to plug in other execution engines by translating the plan into another engine's plan
     # (perhaps via Substrait, once DataFusion supports converting a physical plan to Substrait)
-    ret = datafusion_ray.execute_partition(plan_bytes, part, partitions)
+    ret = datafusion_ray.execute_partition(plan_bytes, part)
     duration = time.time() - start_time
     event = {
         "cat": f"{stage_id}-{part}",
@@ -153,19 +141,24 @@ def sql(self, sql: str) -> pa.RecordBatch:
             return []
 
         df = self.df_ctx.sql(sql)
-        execution_plan = df.execution_plan()
+        return self.plan(df.execution_plan())
+
+    def plan(self, execution_plan: Any) -> pa.RecordBatch:
 
         graph = self.ctx.plan(execution_plan)
         final_stage_id = graph.get_final_query_stage().id()
-        partitions = schedule_execution(graph, final_stage_id, True)
+        # serialize the query stages and store in Ray object store
+        query_stages = [
+            graph.get_query_stage(i).get_execution_plan_bytes()
+            for i in range(final_stage_id + 1)
+        ]
+        # schedule execution
+        future = execute_query_stage.remote(
+            query_stages,
+            final_stage_id
+        )
+        _, partitions = ray.get(future)
         # assert len(partitions) == 1, len(partitions)
         result_set = ray.get(partitions[0])
         return result_set
 
-    def plan(self, physical_plan: Any) -> pa.RecordBatch:
-        graph = self.ctx.plan(physical_plan)
-        final_stage_id = graph.get_final_query_stage().id()
-        partitions = schedule_execution(graph, final_stage_id, True)
-        # assert len(partitions) == 1, len(partitions)
-        result_set = ray.get(partitions[0])
-        return result_set

src/context.rs (+5 -64)

@@ -16,8 +16,7 @@
 // under the License.
 
 use crate::planner::{make_execution_graph, PyExecutionGraph};
-use crate::shuffle::{RayShuffleReaderExec, ShuffleCodec};
-use datafusion::arrow::pyarrow::FromPyArrow;
+use crate::shuffle::ShuffleCodec;
 use datafusion::arrow::pyarrow::ToPyArrow;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError, Result};
@@ -31,7 +30,7 @@ use futures::StreamExt;
 use prost::Message;
 use pyo3::exceptions::PyRuntimeError;
 use pyo3::prelude::*;
-use pyo3::types::{PyBytes, PyList, PyLong, PyTuple};
+use pyo3::types::{PyBytes, PyTuple};
 use std::collections::HashMap;
 use std::sync::Arc;
 use tokio::runtime::Runtime;
@@ -117,22 +116,20 @@ impl PyContext {
         &self,
         plan: &Bound<'_, PyBytes>,
         part: usize,
-        inputs: PyObject,
         py: Python,
     ) -> PyResult<PyResultSet> {
-        execute_partition(plan, part, inputs, py)
+        execute_partition(plan, part, py)
     }
 }
 
 #[pyfunction]
 pub fn execute_partition(
     plan_bytes: &Bound<'_, PyBytes>,
     part: usize,
-    inputs: PyObject,
     py: Python,
 ) -> PyResult<PyResultSet> {
     let plan = deserialize_execution_plan(plan_bytes)?;
-    _execute_partition(plan, part, inputs)
+    _execute_partition(plan, part)
         .unwrap()
         .into_iter()
         .map(|batch| batch.to_pyarrow(py))
@@ -170,59 +167,10 @@ pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dy
     Ok(plan)
 }
 
-/// Iterate down an ExecutionPlan and set the input objects for RayShuffleReaderExec.
-fn _set_inputs_for_ray_shuffle_reader(
-    plan: Arc<dyn ExecutionPlan>,
-    input_partitions: &Bound<'_, PyList>,
-) -> Result<()> {
-    if let Some(reader_exec) = plan.as_any().downcast_ref::<RayShuffleReaderExec>() {
-        let exec_stage_id = reader_exec.stage_id;
-        // iterate over inputs, wrap in PyBytes and set as input objects
-        for item in input_partitions.iter() {
-            let pytuple = item
-                .downcast::<PyTuple>()
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-            let stage_id = pytuple
-                .get_item(0)
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?
-                .downcast::<PyLong>()
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?
-                .extract::<usize>()
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-            if stage_id != exec_stage_id {
-                continue;
-            }
-            let part = pytuple
-                .get_item(1)
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?
-                .downcast::<PyLong>()
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?
-                .extract::<usize>()
-                .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-            let batch = RecordBatch::from_pyarrow_bound(
-                &pytuple
-                    .get_item(2)
-                    .map_err(|e| DataFusionError::Execution(format!("{}", e)))?,
-            )
-            .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-            reader_exec.add_input_partition(part, batch)?;
-        }
-    } else {
-        for child in plan.children() {
-            _set_inputs_for_ray_shuffle_reader(child.to_owned(), input_partitions)?;
-        }
-    }
-    Ok(())
-}
-
 /// Execute a partition of a query plan. This will typically be executing a shuffle write and
 /// write the results to disk, except for the final query stage, which will return the data.
 /// inputs is a list of tuples of (stage_id, partition_id, bytes) for each input partition.
-fn _execute_partition(
-    plan: Arc<dyn ExecutionPlan>,
-    part: usize,
-    inputs: PyObject,
-) -> Result<Vec<RecordBatch>> {
+fn _execute_partition(plan: Arc<dyn ExecutionPlan>, part: usize) -> Result<Vec<RecordBatch>> {
     let ctx = Arc::new(TaskContext::new(
         Some("task_id".to_string()),
         "session_id".to_string(),
@@ -233,13 +181,6 @@ fn _execute_partition(
         Arc::new(RuntimeEnv::default()),
     ));
 
-    Python::with_gil(|py| {
-        let input_partitions = inputs
-            .downcast_bound::<PyList>(py)
-            .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-        _set_inputs_for_ray_shuffle_reader(plan.clone(), input_partitions)
-    })?;
-
     // create a Tokio runtime to run the async code
     let rt = Runtime::new().unwrap();
 
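With the Python-side input plumbing gone, `_execute_partition` only needs the plan bytes and a partition index. As a rough sketch of the remaining control flow (not the committed code), the example below drives one partition of a DataFusion plan on a blocking Tokio runtime and collects the resulting record batches; it uses a default `TaskContext` where the real function builds one with explicit task and session ids and a `RuntimeEnv`.

```rust
use std::sync::Arc;

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::Result;
use datafusion::execution::TaskContext;
use datafusion::physical_plan::ExecutionPlan;
use futures::StreamExt;
use tokio::runtime::Runtime;

// Hedged sketch: collect one partition of an ExecutionPlan into memory.
fn collect_partition(plan: Arc<dyn ExecutionPlan>, part: usize) -> Result<Vec<RecordBatch>> {
    // Default task context stands in for the one the committed code builds by hand.
    let ctx = Arc::new(TaskContext::default());
    // Drive the async record batch stream to completion on a local runtime.
    let rt = Runtime::new().unwrap();
    rt.block_on(async move {
        let mut stream = plan.execute(part, ctx)?;
        let mut batches = Vec::new();
        while let Some(batch) = stream.next().await {
            batches.push(batch?);
        }
        Ok(batches)
    })
}
```

Per the doc comment retained above, a shuffle-write stage would mostly produce side effects on disk, while the final stage's batches are the actual query results handed back to Python.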

src/planner.rs (+17 -3)

@@ -17,7 +17,7 @@
 
 use crate::query_stage::PyQueryStage;
 use crate::query_stage::QueryStage;
-use crate::shuffle::{RayShuffleReaderExec, RayShuffleWriterExec};
+use crate::shuffle::{ShuffleReaderExec, ShuffleWriterExec};
 use datafusion::error::Result;
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::repartition::RepartitionExec;
@@ -29,6 +29,7 @@ use pyo3::prelude::*;
 use std::collections::HashMap;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
+use uuid::Uuid;
 
 #[pyclass(name = "ExecutionGraph", module = "datafusion_ray", subclass)]
 pub struct PyExecutionGraph {
@@ -200,11 +201,15 @@ fn create_shuffle_exchange(
     // introduce shuffle to produce one output partition
     let stage_id = graph.next_id();
 
+    // create temp dir for stage shuffle files
+    let temp_dir = create_temp_dir(stage_id)?;
+
     let shuffle_writer_input = plan.clone();
-    let shuffle_writer: Arc<dyn ExecutionPlan> = Arc::new(RayShuffleWriterExec::new(
+    let shuffle_writer: Arc<dyn ExecutionPlan> = Arc::new(ShuffleWriterExec::new(
         stage_id,
         shuffle_writer_input,
         partitioning_scheme.clone(),
+        &temp_dir,
     ));
 
     debug!(
@@ -214,13 +219,22 @@
 
     let stage_id = graph.add_query_stage(stage_id, shuffle_writer);
     // replace the plan with a shuffle reader
-    Ok(Arc::new(RayShuffleReaderExec::new(
+    Ok(Arc::new(ShuffleReaderExec::new(
         stage_id,
         plan.schema(),
         partitioning_scheme,
+        &temp_dir,
    )))
 }
 
+fn create_temp_dir(stage_id: usize) -> Result<String> {
+    let uuid = Uuid::new_v4();
+    let temp_dir = format!("/tmp/ray-sql-{uuid}-stage-{stage_id}");
+    debug!("Creating temp shuffle dir: {temp_dir}");
+    std::fs::create_dir(&temp_dir)?;
+    Ok(temp_dir)
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
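The new `uuid` dependency supplies the unique directory names produced by `create_temp_dir`, and `glob` (also added to Cargo.toml in this commit) is presumably how the shuffle reader discovers files under that directory. The file naming inside the temp dir is not part of this excerpt, so the pattern used in the sketch below is only an assumed example of how such a lookup might work.

```rust
use std::path::PathBuf;

use datafusion::error::{DataFusionError, Result};
use glob::glob;

// Hedged sketch: list shuffle files for one output partition under the
// per-stage temp dir (e.g. /tmp/ray-sql-<uuid>-stage-<id>). The file name
// pattern "shuffle_*_<output_partition>.arrow" is assumed for illustration;
// the real layout lives in the shuffle reader/writer, not shown here.
fn find_partition_files(shuffle_dir: &str, output_partition: usize) -> Result<Vec<PathBuf>> {
    let pattern = format!("{shuffle_dir}/shuffle_*_{output_partition}.arrow");
    glob(&pattern)
        .map_err(|e| DataFusionError::Execution(format!("{e}")))?
        .map(|entry| entry.map_err(|e| DataFusionError::Execution(format!("{e}"))))
        .collect()
}
```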

src/proto/datafusion_ray.proto (+8 -4)

@@ -11,25 +11,29 @@ import "datafusion.proto";
 
 message RaySqlExecNode {
   oneof PlanType {
-    RayShuffleReaderExecNode ray_shuffle_reader = 3;
-    RayShuffleWriterExecNode ray_shuffle_writer = 4;
+    ShuffleReaderExecNode shuffle_reader = 1;
+    ShuffleWriterExecNode shuffle_writer = 2;
   }
 }
 
-message RayShuffleReaderExecNode {
+message ShuffleReaderExecNode {
   // stage to read from
   uint32 stage_id = 1;
   // schema of the shuffle stage
   datafusion_common.Schema schema = 2;
   // this must match the output partitioning of the writer we are reading from
   datafusion.PhysicalHashRepartition partitioning = 3;
+  // directory for shuffle files
+  string shuffle_dir = 4;
 }
 
-message RayShuffleWriterExecNode {
+message ShuffleWriterExecNode {
   // stage that is writing the shuffle files
   uint32 stage_id = 1;
   // plan to execute
   datafusion.PhysicalPlanNode plan = 2;
   // output partitioning schema
   datafusion.PhysicalHashRepartition partitioning = 3;
+  // directory for shuffle files
+  string shuffle_dir = 4;
 }
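On the Rust side these messages are compiled by prost via build.rs, and the new `shuffle_dir` field rides along as an ordinary string. The sketch below is not the generated code or the real `ShuffleCodec`; it hand-mirrors just two fields of `ShuffleReaderExecNode` (matching tags 1 and 4) to show how the added field round-trips on the wire.

```rust
use prost::Message;

// Hedged sketch: a hand-written stand-in for the generated ShuffleReaderExecNode,
// keeping only stage_id (tag 1) and the new shuffle_dir (tag 4). Field numbers
// must match the .proto above for wire compatibility.
#[derive(Clone, PartialEq, Message)]
pub struct ShuffleReaderExecNode {
    #[prost(uint32, tag = "1")]
    pub stage_id: u32,
    #[prost(string, tag = "4")]
    pub shuffle_dir: String,
}

fn main() {
    let node = ShuffleReaderExecNode {
        stage_id: 2,
        shuffle_dir: "/tmp/ray-sql-<uuid>-stage-2".to_string(),
    };
    // Encode and decode to confirm the shuffle_dir field survives the round trip.
    let bytes = node.encode_to_vec();
    let decoded = ShuffleReaderExecNode::decode(bytes.as_slice()).unwrap();
    assert_eq!(decoded.shuffle_dir, node.shuffle_dir);
}
```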
