Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8ad1e83

Browse files
committed Nov 25, 2024
Extensions framework
1 parent 31f8833 commit 8ad1e83

File tree

8 files changed

+1236
-220
lines changed

8 files changed

+1236
-220
lines changed
 

‎Cargo.lock

+1,057 -193
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎Cargo.toml

+3 -1
Original file line number | Diff line number | Diff line change
@@ -25,19 +25,21 @@ version = "0.1.0"
2525
edition = "2021"
2626
readme = "README.md"
2727
license = "Apache-2.0"
28-
rust-version = "1.62"
28+
rust-version = "1.70"
2929
build = "build.rs"
3030

3131
[dependencies]
3232
datafusion = { version = "42.0.0", features = ["pyarrow", "avro"] }
3333
datafusion-proto = "42.0.0"
34+
datafusion-python = "42.0.0"
3435
futures = "0.3"
3536
glob = "0.3.1"
3637
log = "0.4"
3738
prost = "0.13"
3839
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
3940
tokio = { version = "1.40", features = ["macros", "rt", "rt-multi-thread", "sync"] }
4041
uuid = "1.11.0"
42+
async-trait = "0.1.83"
4143

4244
[build-dependencies]
4345
prost-types = "0.13"

‎datafusion_ray/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
ExecutionGraph,
2626
QueryStage,
2727
execute_partition,
28+
extended_session_context,
2829
)
2930
from .context import DatafusionRayContext
3031

‎src/context.rs

+32 -23
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::ext::Extensions;
1819
use crate::planner::{make_execution_graph, PyExecutionGraph};
19-
use crate::shuffle::ShuffleCodec;
2020
use datafusion::arrow::pyarrow::ToPyArrow;
2121
use datafusion::arrow::record_batch::RecordBatch;
2222
use datafusion::error::{DataFusionError, Result};
@@ -26,6 +26,7 @@ use datafusion::physical_plan::{displayable, ExecutionPlan};
2626
use datafusion::prelude::*;
2727
use datafusion_proto::physical_plan::AsExecutionPlan;
2828
use datafusion_proto::protobuf;
29+
use datafusion_python::physical_plan::PyExecutionPlan;
2930
use futures::StreamExt;
3031
use prost::Message;
3132
use pyo3::exceptions::PyRuntimeError;
@@ -45,22 +46,30 @@ pub struct PyContext {
4546

4647
pub(crate) fn execution_plan_from_pyany(
4748
py_plan: &Bound<PyAny>,
49+
py: Python,
4850
) -> PyResult<Arc<dyn ExecutionPlan>> {
49-
let py_proto = py_plan.call_method0("to_proto")?;
50-
let plan_bytes: &[u8] = py_proto.extract()?;
51-
let plan_node = protobuf::PhysicalPlanNode::try_decode(plan_bytes).map_err(|e| {
52-
PyRuntimeError::new_err(format!(
53-
"Unable to decode physical plan protobuf message: {}",
54-
e
55-
))
56-
})?;
57-
58-
let codec = ShuffleCodec {};
59-
let runtime = RuntimeEnv::default();
60-
let registry = SessionContext::new();
61-
plan_node
62-
.try_into_physical_plan(&registry, &runtime, &codec)
63-
.map_err(|e| e.into())
51+
if let Ok(py_plan) = py_plan.to_object(py).downcast_bound::<PyExecutionPlan>(py) {
52+
// For session contexts created with datafusion_ray.extended_session_context(), the inner
53+
// execution plan can be used as such (and the enabled extensions are all available).
54+
Ok(py_plan.borrow().plan.clone())
55+
} else {
56+
// The session context originates from outside our library, so we'll grab the protobuf plan
57+
// by calling the python method with no extension codecs.
58+
let py_proto = py_plan.call_method0("to_proto")?;
59+
let plan_bytes: &[u8] = py_proto.extract()?;
60+
let plan_node = protobuf::PhysicalPlanNode::try_decode(plan_bytes).map_err(|e| {
61+
PyRuntimeError::new_err(format!(
62+
"Unable to decode physical plan protobuf message: {}",
63+
e
64+
))
65+
})?;
66+
67+
let runtime = RuntimeEnv::default();
68+
let registry = SessionContext::new();
69+
plan_node
70+
.try_into_physical_plan(&registry, &runtime, Extensions::codec())
71+
.map_err(|e| e.into())
72+
}
6473
}
6574

6675
#[pymethods]
@@ -87,14 +96,14 @@ impl PyContext {
8796
}
8897

8998
/// Plan a distributed SELECT query for executing against the Ray workers
90-
pub fn plan(&self, plan: &Bound<PyAny>) -> PyResult<PyExecutionGraph> {
99+
pub fn plan(&self, plan: &Bound<PyAny>, py: Python) -> PyResult<PyExecutionGraph> {
91100
// println!("Planning {}", sql);
92101
// let df = wait_for_future(py, self.ctx.sql(sql))?;
93102
// let py_df = self.run_sql(sql, py)?;
94103
// let py_plan = py_df.call_method0(py, "execution_plan")?;
95104
// let py_plan = py_plan.bind(py);
96105

97-
let plan = execution_plan_from_pyany(plan)?;
106+
let plan = execution_plan_from_pyany(plan, py)?;
98107
let graph = make_execution_graph(plan.clone())?;
99108

100109
// debug logging
@@ -140,9 +149,10 @@ pub fn serialize_execution_plan(
140149
plan: Arc<dyn ExecutionPlan>,
141150
py: Python,
142151
) -> PyResult<Bound<'_, PyBytes>> {
143-
let codec = ShuffleCodec {};
144-
let proto =
145-
datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(plan.clone(), &codec)?;
152+
let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(
153+
plan.clone(),
154+
Extensions::codec(),
155+
)?;
146156

147157
let bytes = proto.encode_to_vec();
148158
Ok(PyBytes::new_bound(py, &bytes))
@@ -159,9 +169,8 @@ pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dy
159169
})?;
160170

161171
let ctx = SessionContext::new();
162-
let codec = ShuffleCodec {};
163172
let plan = proto_plan
164-
.try_into_physical_plan(&ctx, &ctx.runtime_env(), &codec)
173+
.try_into_physical_plan(&ctx, &ctx.runtime_env(), Extensions::codec())
165174
.map_err(DataFusionError::from)?;
166175

167176
Ok(plan)

‎src/ext.rs

+121
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,121 @@
1+
use async_trait::async_trait;
2+
use datafusion::common::DataFusionError;
3+
use datafusion::common::Result;
4+
use datafusion::execution::FunctionRegistry;
5+
use datafusion::physical_plan::ExecutionPlan;
6+
use datafusion::prelude::SessionContext;
7+
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
8+
use datafusion_python::context::PySessionContext;
9+
use datafusion_python::utils::wait_for_future;
10+
use pyo3::{pyfunction, PyResult, Python};
11+
use std::collections::HashMap;
12+
use std::fmt::Debug;
13+
use std::sync::{Arc, OnceLock};
14+
15+
mod built_in;
16+
17+
/// Creates a datafusion session context preconfigured with the enabled extensions
18+
/// that will register additional table providers, catalogs etc.
19+
/// If no extensions are required, the plain `datafusion.SessionContext()` will work just fine.
20+
/// # Arguments
21+
/// * `settings` - dictionary containing extension-specific key/value config options
/// * `py` - Python token passed on to `wait_for_future` while the context is being built
///
/// # Errors
/// Propagates (via `?`) any failure reported while waiting on `Extensions::session_context`.
22+
#[pyfunction]
23+
pub fn extended_session_context(
24+
settings: HashMap<String, String>,
25+
py: Python,
26+
) -> PyResult<PySessionContext> {
27+
// `Extensions::session_context` is an async fn, so this only creates the future.
let future_context = Extensions::session_context(&settings);
28+
// Obtain the future's result from this synchronous function; `?` surfaces a failure.
let ctx = wait_for_future(py, future_context)?;
29+
// Convert the context into the `PySessionContext` return type.
Ok(ctx.into())
30+
}
31+
32+
/// Allows third party table/catalog providers, object stores, etc.
33+
/// to be registered with the DataFusion context.
///
/// Both methods have default implementations, so an implementer only overrides what it needs.
34+
#[async_trait]
35+
trait Extension: Debug + Send + Sync + 'static {
36+
/// SessionContext initialization, using the provided key/value settings if needed.
37+
/// Declared async to allow implementers to perform network or other I/O operations.
///
/// The default implementation ignores both arguments and returns `Ok(())`.
38+
async fn init(&self, ctx: &SessionContext, settings: &HashMap<String, String>) -> Result<()> {
39+
// Explicitly discard the arguments; the default has nothing to initialize.
let _ = ctx;
40+
let _ = settings;
41+
Ok(())
42+
}
43+
44+
/// Codecs for the custom physical plan nodes created by this extension, if any.
///
/// The default implementation returns an empty list.
45+
fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
46+
vec![]
47+
}
48+
}
49+
50+
/// A composite extension registry for enabled extensions.
///
/// Stores the extensions as a boxed slice of `Extension` trait objects.
51+
#[derive(Debug)]
52+
pub(crate) struct Extensions(Box<[Box<dyn Extension>]>);
53+
54+
#[async_trait]
55+
// The registry is itself an `Extension`: each method delegates to every wrapped extension.
impl Extension for Extensions {
56+
/// Runs `init` on each wrapped extension in stored order, one after another.
///
/// # Errors
/// Stops at, and returns, the first error reported by an extension; later extensions are
/// not initialized in that case.
async fn init(&self, ctx: &SessionContext, settings: &HashMap<String, String>) -> Result<()> {
57+
for ext in &self.0 {
58+
ext.init(ctx, settings).await?;
59+
}
60+
Ok(())
61+
}
62+
63+
/// Collects the codecs of all wrapped extensions into one flat list, in stored order.
fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
64+
self.0.iter().flat_map(|ext| ext.codecs()).collect()
65+
}
66+
}
67+
68+
impl Extensions {
69+
/// Builds the registry; the list currently contains only `built_in::DefaultExtension`.
fn new() -> Self {
70+
Self(Box::new([
71+
Box::new(built_in::DefaultExtension::default()),
72+
]))
73+
}
74+
75+
/// Returns the process-wide registry, created with `Self::new` on first use.
fn singleton() -> &'static Self {
76+
// `OnceLock::get_or_init` runs `Self::new` at most once and then hands out the same value.
static EXTENSIONS: OnceLock<Extensions> = OnceLock::new();
77+
EXTENSIONS.get_or_init(Self::new)
78+
}
79+
80+
/// Creates a fresh `SessionContext` and initializes it with every registered extension.
///
/// # Arguments
/// * `settings` - key/value options passed unchanged to each extension's `init`
///
/// # Errors
/// Returns the error from `init` if any extension fails; the context is not returned then.
pub(crate) async fn session_context(
81+
settings: &HashMap<String, String>,
82+
) -> Result<SessionContext> {
83+
let ctx = SessionContext::new();
84+
Self::singleton().init(&ctx, settings).await?;
85+
Ok(ctx)
86+
}
87+
88+
/// Returns the shared `CompositeCodec` wrapping the codecs of all registered extensions.
pub(crate) fn codec() -> &'static CompositeCodec {
89+
// Built once on first call from `Extensions::singleton().codecs()`; the `Vec` is turned
// into a boxed slice via `.into()`.
static COMPOSITE_CODEC: OnceLock<CompositeCodec> = OnceLock::new();
90+
COMPOSITE_CODEC.get_or_init(|| CompositeCodec(Extensions::singleton().codecs().into()))
91+
}
92+
}
93+
94+
/// For both encoding and decoding, tries all the registered extension codecs and returns the first successful result.
///
/// Codecs are tried in the order in which they are stored in the wrapped boxed slice.
95+
#[derive(Debug)]
96+
pub(crate) struct CompositeCodec(Box<[Box<dyn PhysicalExtensionCodec>]>);
97+
98+
impl PhysicalExtensionCodec for CompositeCodec {
99+
/// Decodes `buf` with the first wrapped codec whose `try_decode` returns `Ok`.
///
/// # Errors
/// Returns `DataFusionError::Execution` with the message "No compatible codec found" when
/// every codec fails, or when no codecs are registered.
fn try_decode(
100+
&self,
101+
buf: &[u8],
102+
inputs: &[Arc<dyn ExecutionPlan>],
103+
registry: &dyn FunctionRegistry,
104+
) -> Result<Arc<dyn ExecutionPlan>> {
105+
self.0
106+
.iter()
107+
// `.ok()` discards each codec's own error. The chain is lazy, so `next()` stops at the
// first success and later codecs are never invoked.
// NOTE(review): nothing here records which codec produced `buf`, so this relies on each
// codec returning `Err` for payloads it did not write — confirm.
.filter_map(|codec| codec.try_decode(buf, inputs, registry).ok())
108+
.next()
109+
.ok_or_else(|| DataFusionError::Execution("No compatible codec found".into()))
110+
}
111+
112+
/// Encodes `node` into `buf` with the first wrapped codec whose `try_encode` returns `Ok`.
///
/// # Errors
/// Returns `DataFusionError::Execution` naming the node (via `node.name()`) when every codec
/// fails, or when no codecs are registered.
fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
113+
self.0
114+
.iter()
115+
// Each attempt gets its own `Arc` clone of the node; per-codec errors are discarded.
// NOTE(review): `buf` is not reset between attempts, so if a codec appends bytes and then
// returns `Err`, those bytes stay in front of the next codec's output — confirm codecs
// never write before failing.
.filter_map(|codec| codec.try_encode(node.clone(), buf).ok())
116+
.next()
117+
.ok_or_else(|| {
118+
DataFusionError::Execution(format!("No compatible codec found for {}", node.name()))
119+
})
120+
}
121+
}

‎src/ext/built_in.rs

+14
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,14 @@
1+
use crate::ext::Extension;
2+
use crate::shuffle::ShuffleCodec;
3+
use async_trait::async_trait;
4+
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
5+
6+
/// Built-in extension with no state of its own (empty struct).
#[derive(Debug, Default)]
7+
pub(super) struct DefaultExtension {}
8+
9+
#[async_trait]
10+
// Only `codecs` is overridden here; `init` is not defined in this impl.
impl Extension for DefaultExtension {
11+
/// Returns a single codec: `ShuffleCodec`.
fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
12+
vec![Box::new(ShuffleCodec {})]
13+
}
14+
}

‎src/lib.rs

+4
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,9 @@ mod proto;
2323
use crate::context::execute_partition;
2424
pub use proto::generated::protobuf;
2525

26+
mod ext;
27+
use crate::ext::extended_session_context;
28+
2629
pub mod context;
2730
pub mod planner;
2831
pub mod query_stage;
@@ -36,5 +39,6 @@ fn _datafusion_ray_internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
3639
m.add_class::<planner::PyExecutionGraph>()?;
3740
m.add_class::<query_stage::PyQueryStage>()?;
3841
m.add_function(wrap_pyfunction!(execute_partition, m)?)?;
42+
m.add_function(wrap_pyfunction!(extended_session_context, m)?)?;
3943
Ok(())
4044
}

‎src/query_stage.rs

+4 -3
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,8 @@
1616
// under the License.
1717

1818
use crate::context::serialize_execution_plan;
19-
use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
19+
use crate::ext::Extensions;
20+
use crate::shuffle::ShuffleReaderExec;
2021
use datafusion::error::Result;
2122
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
2223
use datafusion::prelude::SessionContext;
@@ -41,8 +42,8 @@ impl PyQueryStage {
4142
#[new]
4243
pub fn new(id: usize, bytes: Vec<u8>) -> Result<Self> {
4344
let ctx = SessionContext::new();
44-
let codec = ShuffleCodec {};
45-
let plan = physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
45+
let plan =
46+
physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, Extensions::codec())?;
4647
Ok(PyQueryStage {
4748
stage: Arc::new(QueryStage { id, plan }),
4849
})

0 commit comments

Comments
 (0)
Please sign in to comment.