Extensions API for third-party table/catalog providers #43

Draft · wants to merge 2 commits into main
2,035 changes: 1,830 additions & 205 deletions Cargo.lock

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion Cargo.toml
@@ -25,19 +25,22 @@ version = "0.1.0"
edition = "2021"
readme = "README.md"
license = "Apache-2.0"
rust-version = "1.62"
rust-version = "1.70"
build = "build.rs"

[dependencies]
datafusion = { version = "42.0.0", features = ["pyarrow", "avro"] }
datafusion-proto = "42.0.0"
datafusion-python = "42.0.0"
futures = "0.3"
glob = "0.3.1"
log = "0.4"
prost = "0.13"
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
tokio = { version = "1.40", features = ["macros", "rt", "rt-multi-thread", "sync"] }
uuid = "1.11.0"
async-trait = "0.1.83"
datafusion-table-providers = { version = "0.2.3", default-features = false, optional = true }

[build-dependencies]
prost-types = "0.13"
@@ -59,3 +62,6 @@ name = "datafusion_ray._datafusion_ray_internal"
[profile.release]
codegen-units = 1
lto = true

[features]
flight-sql-tables = ["datafusion-table-providers/flight"]
1 change: 1 addition & 0 deletions datafusion_ray/__init__.py
@@ -25,6 +25,7 @@
ExecutionGraph,
QueryStage,
execute_partition,
extended_session_context,
)
from .context import DatafusionRayContext

27 changes: 27 additions & 0 deletions examples/flight.py
@@ -0,0 +1,27 @@
import ray

import datafusion_ray
from datafusion_ray import DatafusionRayContext

## Prerequisites:
## $ brew install roapi
## $ roapi --table taxi=https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2024-01.parquet
## $ maturin develop --features flight-sql-tables

ray.init()

ctx = datafusion_ray.extended_session_context({})

ray_ctx = DatafusionRayContext(ctx)

ray_ctx.sql("""
CREATE EXTERNAL TABLE trip_data
STORED AS FLIGHT_SQL
LOCATION 'http://localhost:32010'
OPTIONS (
'flight.sql.query' 'SELECT * FROM taxi LIMIT 25'
)
""")

df = ray_ctx.sql("SELECT tpep_pickup_datetime FROM trip_data LIMIT 10")
print(df.to_pandas())
55 changes: 32 additions & 23 deletions src/context.rs
@@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use crate::ext::Extensions;
use crate::planner::{make_execution_graph, PyExecutionGraph};
use crate::shuffle::ShuffleCodec;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
@@ -26,6 +26,7 @@ use datafusion::physical_plan::{displayable, ExecutionPlan};
use datafusion::prelude::*;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::protobuf;
use datafusion_python::physical_plan::PyExecutionPlan;
use futures::StreamExt;
use prost::Message;
use pyo3::exceptions::PyRuntimeError;
@@ -45,22 +46,30 @@ pub struct PyContext {

pub(crate) fn execution_plan_from_pyany(
py_plan: &Bound<PyAny>,
py: Python,
) -> PyResult<Arc<dyn ExecutionPlan>> {
let py_proto = py_plan.call_method0("to_proto")?;
let plan_bytes: &[u8] = py_proto.extract()?;
let plan_node = protobuf::PhysicalPlanNode::try_decode(plan_bytes).map_err(|e| {
PyRuntimeError::new_err(format!(
"Unable to decode physical plan protobuf message: {}",
e
))
})?;

let codec = ShuffleCodec {};
let runtime = RuntimeEnv::default();
let registry = SessionContext::new();
plan_node
.try_into_physical_plan(&registry, &runtime, &codec)
.map_err(|e| e.into())
if let Ok(py_plan) = py_plan.to_object(py).downcast_bound::<PyExecutionPlan>(py) {
// For session contexts created with datafusion_ray.extended_session_context(), the inner
// execution plan can be used as such (and the enabled extensions are all available).
Ok(py_plan.borrow().plan.clone())
} else {
// The session context originates from outside our library, so we'll grab the protobuf plan
// by calling the python method with no extension codecs.
let py_proto = py_plan.call_method0("to_proto")?;
let plan_bytes: &[u8] = py_proto.extract()?;
let plan_node = protobuf::PhysicalPlanNode::try_decode(plan_bytes).map_err(|e| {
PyRuntimeError::new_err(format!(
"Unable to decode physical plan protobuf message: {}",
e
))
})?;

let runtime = RuntimeEnv::default();
let registry = SessionContext::new();
plan_node
.try_into_physical_plan(&registry, &runtime, Extensions::codec())
.map_err(|e| e.into())
}
Comment on lines +51 to +72

timsaucer (Contributor): One of my goals was to remove the datafusion-python dependency from datafusion-ray so that you wouldn't have hard requirements about having the exact same version between the two. It can be even worse, in that you also have to use the same compiler for both. Now for datafusion-ray we may be able to get away with it for official releases, since we control the build pipeline for both, but this does place a restriction on end users in that they have to make sure they keep these versions synced on their machine. In my opinion it would be better to lean on things like the FFI interface that is coming in datafusion-python 43.0.0. I know that right now it doesn't solve the problem of having all extensions, though.

Author: Thanks, @timsaucer! Makes sense, and I'm entirely on board with this goal, which is why I tried to preserve the guarantee by keeping the existing code within the else branch here, where no assumption is made about the datafusion-python version/compiler.

The "embedded" datafusion-python dependency is only meant to be an opt-in alternative for users who decide to call the new function for creating an ABI-compatible context preconfigured with the enabled extensions (that's the only way the above downcast can succeed, if I'm not mistaken). So any existing or future code that doesn't switch to the new extended_session_context() function will continue to work without any compatibility restrictions. However, if I somehow failed to preserve this guarantee, or if I missed something that introduces potential risks, please let me know and I'll revisit the approach.

timsaucer (Contributor): I don't want to hold up progress, but this feels like a step in the wrong direction. I don't have enough time right now to offer a better solution, though; I think I'd want to do something like use the FFI_ExecutionPlan in df43 to share these across packages.

Author: Got it, I'll look into that. Thanks!

}

#[pymethods]
@@ -87,14 +96,14 @@ impl PyContext {
}

/// Plan a distributed SELECT query for executing against the Ray workers
pub fn plan(&self, plan: &Bound<PyAny>) -> PyResult<PyExecutionGraph> {
pub fn plan(&self, plan: &Bound<PyAny>, py: Python) -> PyResult<PyExecutionGraph> {
// println!("Planning {}", sql);
// let df = wait_for_future(py, self.ctx.sql(sql))?;
// let py_df = self.run_sql(sql, py)?;
// let py_plan = py_df.call_method0(py, "execution_plan")?;
// let py_plan = py_plan.bind(py);

let plan = execution_plan_from_pyany(plan)?;
let plan = execution_plan_from_pyany(plan, py)?;
let graph = make_execution_graph(plan.clone())?;

// debug logging
@@ -140,9 +149,10 @@ pub fn serialize_execution_plan(
plan: Arc<dyn ExecutionPlan>,
py: Python,
) -> PyResult<Bound<'_, PyBytes>> {
let codec = ShuffleCodec {};
let proto =
datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(plan.clone(), &codec)?;
let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(
plan.clone(),
Extensions::codec(),
)?;

let bytes = proto.encode_to_vec();
Ok(PyBytes::new_bound(py, &bytes))
@@ -159,9 +169,8 @@ pub fn deserialize_execution_plan(proto_msg: &Bound<PyBytes>) -> PyResult<Arc<dyn ExecutionPlan>> {
})?;

let ctx = SessionContext::new();
let codec = ShuffleCodec {};
let plan = proto_plan
.try_into_physical_plan(&ctx, &ctx.runtime_env(), &codec)
.try_into_physical_plan(&ctx, &ctx.runtime_env(), Extensions::codec())
.map_err(DataFusionError::from)?;

Ok(plan)
126 changes: 126 additions & 0 deletions src/ext.rs
@@ -0,0 +1,126 @@
use async_trait::async_trait;
use datafusion::common::DataFusionError;
use datafusion::common::Result;
use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_python::context::PySessionContext;
use datafusion_python::utils::wait_for_future;
use pyo3::{pyfunction, PyResult, Python};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, OnceLock};

mod built_in;

#[cfg(feature = "flight-sql-tables")]
mod flight;

/// Creates a datafusion session context preconfigured with the enabled extensions
/// that will register additional table providers, catalogs etc.
/// If no extensions are required, the plain `datafusion.SessionContext()` will work just fine.
/// # Arguments
/// * `settings` - dictionary containing extension-specific key/value config options
#[pyfunction]
pub fn extended_session_context(
settings: HashMap<String, String>,
py: Python,
) -> PyResult<PySessionContext> {
let future_context = Extensions::session_context(&settings);
let ctx = wait_for_future(py, future_context)?;
Ok(ctx.into())
}

/// Allows third party table/catalog providers, object stores, etc.
/// to be registered with the DataFusion context.
#[async_trait]
trait Extension: Debug + Send + Sync + 'static {
/// SessionContext initialization, using the provided key/value settings if needed.
/// Declared async to allow implementers to perform network or other I/O operations.
async fn init(&self, ctx: &SessionContext, settings: &HashMap<String, String>) -> Result<()> {
let _ = ctx;
let _ = settings;
Ok(())
}

/// Codecs for the custom physical plan nodes created by this extension, if any.
fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
vec![]
}
}

/// A composite extension registry for enabled extensions.
#[derive(Debug)]
pub(crate) struct Extensions(Box<[Box<dyn Extension>]>);

Reviewer: Why not Vec<Box<dyn Extension>>? What is the advantage of Box<[T]>?

Author: It's a bit lighter (hardly relevant in this scenario) and it's immutable by design (arguably also irrelevant, though I personally prefer it when I can choose); I think I got this habit after watching this video. :) But I don't mind switching to Vec if you think that would improve readability.
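
For reference, a minimal standalone Rust sketch (not part of this diff) of the layout trade-off under discussion:

// Vec<T> stores pointer + length + capacity and can grow in place;
// Box<[T]> is a fat pointer storing only pointer + length, fixed at creation.
fn main() {
    let v: Vec<i32> = Vec::with_capacity(16);
    let b: Box<[i32]> = vec![1, 2, 3].into_boxed_slice();
    assert!(v.capacity() >= 16); // spare capacity retained for growth
    assert_eq!(b.len(), 3); // exact fit; the length can no longer change
    // On typical 64-bit targets the handle itself shrinks from 24 to 16 bytes.
    assert_eq!(std::mem::size_of::<Box<[i32]>>(), 16);
}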


#[async_trait]
impl Extension for Extensions {
async fn init(&self, ctx: &SessionContext, settings: &HashMap<String, String>) -> Result<()> {
for ext in &self.0 {
ext.init(ctx, settings).await?;
}
Ok(())
}

fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
self.0.iter().flat_map(|ext| ext.codecs()).collect()
}
}

impl Extensions {
fn new() -> Self {
Self(Box::new([
Box::new(built_in::DefaultExtension::default()),
#[cfg(feature = "flight-sql-tables")]
Box::new(flight::FlightSqlTables::default()),
]))
}

fn singleton() -> &'static Self {
static EXTENSIONS: OnceLock<Extensions> = OnceLock::new();
EXTENSIONS.get_or_init(Self::new)
}

pub(crate) async fn session_context(
settings: &HashMap<String, String>,
) -> Result<SessionContext> {
let ctx = SessionContext::new();
Self::singleton().init(&ctx, settings).await?;
Ok(ctx)
}

pub(crate) fn codec() -> &'static CompositeCodec {
static COMPOSITE_CODEC: OnceLock<CompositeCodec> = OnceLock::new();
COMPOSITE_CODEC.get_or_init(|| CompositeCodec(Extensions::singleton().codecs().into()))
}
}

/// For both encoding and decoding, tries all the registered extension codecs and returns the first successful result.
#[derive(Debug)]
pub(crate) struct CompositeCodec(Box<[Box<dyn PhysicalExtensionCodec>]>);

impl PhysicalExtensionCodec for CompositeCodec {
fn try_decode(
&self,
buf: &[u8],
inputs: &[Arc<dyn ExecutionPlan>],
registry: &dyn FunctionRegistry,
) -> Result<Arc<dyn ExecutionPlan>> {
self.0
.iter()
.filter_map(|codec| codec.try_decode(buf, inputs, registry).ok())
.next()
.ok_or_else(|| DataFusionError::Execution("No compatible codec found".into()))
}

fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
self.0
.iter()
.filter_map(|codec| codec.try_encode(node.clone(), buf).ok())
.next()
.ok_or_else(|| {
DataFusionError::Execution(format!("No compatible codec found for {}", node.name()))
})
}
}
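
To make the new extension point concrete, here is a hypothetical sketch of an additional in-crate extension; the struct name and the settings key are invented for illustration, and like the built-in ones it would have to be added to the array in Extensions::new() to take effect:

use crate::ext::Extension;
use async_trait::async_trait;
use datafusion::arrow::array::Int64Array;
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::Result;
use datafusion::datasource::MemTable;
use datafusion::prelude::SessionContext;
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Debug, Default)]
pub(super) struct StaticTable {}

#[async_trait]
impl Extension for StaticTable {
    async fn init(&self, ctx: &SessionContext, settings: &HashMap<String, String>) -> Result<()> {
        // Only register the table if the (hypothetical) setting is present.
        if let Some(name) = settings.get("static_table.name") {
            let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
            let batch = RecordBatch::try_new(
                schema.clone(),
                vec![Arc::new(Int64Array::from(vec![1, 2, 3]))],
            )?;
            let table = MemTable::try_new(schema, vec![vec![batch]])?;
            ctx.register_table(name.as_str(), Arc::new(table))?;
        }
        Ok(())
    }
    // No custom physical plan nodes, so the default empty `codecs()` applies.
}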
14 changes: 14 additions & 0 deletions src/ext/built_in.rs
@@ -0,0 +1,14 @@
use crate::ext::Extension;
use crate::shuffle::ShuffleCodec;
use async_trait::async_trait;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

#[derive(Debug, Default)]
pub(super) struct DefaultExtension {}

#[async_trait]
impl Extension for DefaultExtension {
fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
vec![Box::new(ShuffleCodec {})]
}
}
33 changes: 33 additions & 0 deletions src/ext/flight.rs
@@ -0,0 +1,33 @@
use crate::ext::Extension;
use async_trait::async_trait;
use datafusion::prelude::SessionContext;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_table_providers::flight::codec::FlightPhysicalCodec;
use datafusion_table_providers::flight::sql::FlightSqlDriver;
use datafusion_table_providers::flight::FlightTableFactory;
use std::collections::HashMap;
use std::sync::Arc;

#[derive(Debug, Default)]
pub(super) struct FlightSqlTables {}

#[async_trait]
impl Extension for FlightSqlTables {
async fn init(
&self,
ctx: &SessionContext,
_settings: &HashMap<String, String>,
) -> datafusion::common::Result<()> {
ctx.state_ref().write().table_factories_mut().insert(
"FLIGHT_SQL".into(),
Arc::new(FlightTableFactory::new(
Arc::new(FlightSqlDriver::default()),
)),
);
Ok(())
}

fn codecs(&self) -> Vec<Box<dyn PhysicalExtensionCodec>> {
vec![Box::new(FlightPhysicalCodec::default())]
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
@@ -23,6 +23,9 @@ mod proto;
use crate::context::execute_partition;
pub use proto::generated::protobuf;

mod ext;
use crate::ext::extended_session_context;

pub mod context;
pub mod planner;
pub mod query_stage;
@@ -36,5 +39,6 @@ fn _datafusion_ray_internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<planner::PyExecutionGraph>()?;
m.add_class::<query_stage::PyQueryStage>()?;
m.add_function(wrap_pyfunction!(execute_partition, m)?)?;
m.add_function(wrap_pyfunction!(extended_session_context, m)?)?;
Ok(())
}
7 changes: 4 additions & 3 deletions src/query_stage.rs
@@ -16,7 +16,8 @@
// under the License.

use crate::context::serialize_execution_plan;
use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
use crate::ext::Extensions;
use crate::shuffle::ShuffleReaderExec;
use datafusion::error::Result;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion::prelude::SessionContext;
@@ -41,8 +42,8 @@ impl PyQueryStage {
#[new]
pub fn new(id: usize, bytes: Vec<u8>) -> Result<Self> {
let ctx = SessionContext::new();
let codec = ShuffleCodec {};
let plan = physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
let plan =
physical_plan_from_bytes_with_extension_codec(&bytes, &ctx, Extensions::codec())?;
Ok(PyQueryStage {
stage: Arc::new(QueryStage { id, plan }),
})