Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions crates/core/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_python_util::{
create_logical_extension_capsule, create_physical_extension_capsule,
ffi_logical_codec_from_pycapsule, get_global_ctx, get_tokio_runtime,
physical_codec_from_pycapsule, spawn_future, wait_for_future,
physical_codec_from_pycapsule, physical_optimizer_rule_from_pycapsule, spawn_future,
wait_for_future,
};
use object_store::ObjectStore;
use pyo3::IntoPyObjectExt;
Expand Down Expand Up @@ -375,11 +376,12 @@ pub struct PySessionContext {

#[pymethods]
impl PySessionContext {
#[pyo3(signature = (config=None, runtime=None))]
#[pyo3(signature = (config=None, runtime=None, physical_optimizer_rules=None))]
#[new]
pub fn new(
config: Option<PySessionConfig>,
runtime: Option<PyRuntimeEnvBuilder>,
physical_optimizer_rules: Option<Vec<Bound<'_, PyAny>>>,
) -> PyDataFusionResult<Self> {
let config = if let Some(c) = config {
c.config
Expand All @@ -392,11 +394,15 @@ impl PySessionContext {
RuntimeEnvBuilder::default()
};
let runtime = Arc::new(runtime_env_builder.build()?);
let session_state = SessionStateBuilder::new()
let mut state_builder = SessionStateBuilder::new()
.with_config(config)
.with_runtime_env(runtime)
.with_default_features()
.build();
.with_default_features();
for rule in physical_optimizer_rules.unwrap_or_default() {
let rule = physical_optimizer_rule_from_pycapsule(&rule)?;
state_builder = state_builder.with_physical_optimizer_rule(rule);
}
let session_state = state_builder.build();
let ctx = Arc::new(SessionContext::new_with_state(session_state));
Ok(PySessionContext {
ctx,
Expand Down
9 changes: 9 additions & 0 deletions crates/util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use datafusion::datasource::TableProvider;
use datafusion::execution::TaskContext;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion_ffi::execution::FFI_TaskContextProvider;
use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule;
use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec;
use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec;
use datafusion_ffi::table_provider::FFI_TableProvider;
Expand Down Expand Up @@ -332,6 +334,13 @@ from_pycapsule!(
dyn PhysicalExtensionCodec
);

from_pycapsule!(
physical_optimizer_rule_from_pycapsule,
"datafusion_physical_optimizer_rule",
FFI_PhysicalOptimizerRule,
dyn PhysicalOptimizerRule + Send + Sync
);

try_from_pycapsule!(
task_context_from_pycapsule,
"datafusion_task_context_provider",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pyarrow as pa
from datafusion import SessionContext
from datafusion_ffi_example import MyPhysicalOptimizerRule


def test_ffi_physical_optimizer_rule_runs_during_planning():
"""A rule supplied via physical_optimizer_rules is invoked while the
physical plan is built, and the query still returns correct results."""
rule = MyPhysicalOptimizerRule()
ctx = SessionContext(physical_optimizer_rules=[rule])
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3])],
names=["a"],
)
ctx.register_record_batches("t", [[batch]])

before = rule.optimize_calls()
result = ctx.sql("SELECT a FROM t").collect()
after = rule.optimize_calls()

assert after > before, (
f"Expected user FFI physical optimizer rule to fire, "
f"before={before} after={after}"
)
assert result[0].column(0).to_pylist() == [1, 2, 3]


def test_ffi_physical_optimizer_rule_export():
"""The rule object exposes the FFI capsule entry point."""
rule = MyPhysicalOptimizerRule()
capsule = rule.__datafusion_physical_optimizer_rule__()
assert capsule is not None
3 changes: 3 additions & 0 deletions examples/datafusion-ffi-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP
use crate::config::MyConfig;
use crate::logical_extension_codec::MyLogicalExtensionCodec;
use crate::physical_extension_codec::MyPhysicalExtensionCodec;
use crate::physical_optimizer::MyPhysicalOptimizerRule;
use crate::scalar_udf::IsNullUDF;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
Expand All @@ -33,6 +34,7 @@ pub(crate) mod catalog_provider;
pub(crate) mod config;
pub(crate) mod logical_extension_codec;
pub(crate) mod physical_extension_codec;
pub(crate) mod physical_optimizer;
pub(crate) mod scalar_udf;
pub(crate) mod table_function;
pub(crate) mod table_provider;
Expand All @@ -55,5 +57,6 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<MyConfig>()?;
m.add_class::<MyLogicalExtensionCodec>()?;
m.add_class::<MyPhysicalExtensionCodec>()?;
m.add_class::<MyPhysicalOptimizerRule>()?;
Ok(())
}
98 changes: 98 additions & 0 deletions examples/datafusion-ffi-example/src/physical_optimizer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use datafusion::common::Result;
use datafusion::common::config::ConfigOptions;
use datafusion::physical_optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule;
use datafusion_python_util::get_tokio_runtime;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

/// A physical optimizer rule that leaves every plan unchanged but bumps a
/// shared counter each time it runs. Tests use the counter to prove that a
/// session built with this rule actually routed physical planning through a
/// user-supplied [`PhysicalOptimizerRule`] over FFI.
#[derive(Debug)]
struct CountingPhysicalOptimizerRule {
optimize_calls: Arc<AtomicUsize>,
}

impl PhysicalOptimizerRule for CountingPhysicalOptimizerRule {
fn optimize(
&self,
plan: Arc<dyn ExecutionPlan>,
_config: &ConfigOptions,
) -> Result<Arc<dyn ExecutionPlan>> {
self.optimize_calls.fetch_add(1, Ordering::SeqCst);
Ok(plan)
}

fn name(&self) -> &str {
"counting_physical_optimizer_rule"
}

fn schema_check(&self) -> bool {
// The plan is returned unchanged, so the schema is preserved.
true
}
}

/// Python-visible handle that produces an [`FFI_PhysicalOptimizerRule`] and
/// exposes the shared call counter.
#[pyclass(
from_py_object,
name = "MyPhysicalOptimizerRule",
module = "datafusion_ffi_example",
subclass
)]
#[derive(Debug, Default, Clone)]
pub(crate) struct MyPhysicalOptimizerRule {
optimize_calls: Arc<AtomicUsize>,
}

#[pymethods]
impl MyPhysicalOptimizerRule {
#[new]
fn new() -> Self {
Self::default()
}

fn optimize_calls(&self) -> usize {
self.optimize_calls.load(Ordering::SeqCst)
}

fn __datafusion_physical_optimizer_rule__<'py>(
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let rule: Arc<dyn PhysicalOptimizerRule + Send + Sync> =
Arc::new(CountingPhysicalOptimizerRule {
optimize_calls: Arc::clone(&self.optimize_calls),
});

let runtime = get_tokio_runtime().handle().clone();
let ffi = FFI_PhysicalOptimizerRule::new(rule, Some(runtime));

let name = cr"datafusion_physical_optimizer_rule".into();
PyCapsule::new(py, ffi, Some(name))
}
}
28 changes: 27 additions & 1 deletion python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105


class PhysicalOptimizerRuleExportable(Protocol):
"""Type hint for object that has __datafusion_physical_optimizer_rule__ PyCapsule.

The method returns a PyCapsule wrapping an ``FFI_PhysicalOptimizerRule``,
typically produced by a separate compiled extension.
"""

def __datafusion_physical_optimizer_rule__(self) -> object: ... # noqa: D105


class SessionConfig:
"""Session configuration options."""

Expand Down Expand Up @@ -524,6 +534,7 @@ def __init__(
self,
config: SessionConfig | None = None,
runtime: RuntimeEnvBuilder | None = None,
physical_optimizer_rules: list[PhysicalOptimizerRuleExportable] | None = None,
) -> None:
"""Main interface for executing queries with DataFusion.

Expand All @@ -534,6 +545,11 @@ def __init__(
Args:
config: Session configuration options.
runtime: Runtime configuration options.
physical_optimizer_rules: User-defined physical optimizer rules to
append to the default set, each a
:class:`PhysicalOptimizerRuleExportable`. There is no upstream
API to add physical rules to a live context, so these can only
be supplied at construction time.

Example usage:

Expand All @@ -544,11 +560,21 @@ def __init__(

ctx = SessionContext()
df = ctx.read_csv("data.csv")

To register a physical optimizer rule supplied by a compiled
extension, pass it via ``physical_optimizer_rules``::

from datafusion import SessionContext
from my_extension import MyPhysicalOptimizerRule

ctx = SessionContext(
physical_optimizer_rules=[MyPhysicalOptimizerRule()]
)
"""
config = config.config_internal if config is not None else None
runtime = runtime.config_internal if runtime is not None else None

self.ctx = SessionContextInternal(config, runtime)
self.ctx = SessionContextInternal(config, runtime, physical_optimizer_rules)

def __repr__(self) -> str:
"""Print a string representation of the Session Context."""
Expand Down
Loading