Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add missing PyLogicalPlan to_variant #1085

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/datafusion/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,12 @@
SqlTable = common_internal.SqlTable
SqlType = common_internal.SqlType
SqlView = common_internal.SqlView
TableType = common_internal.TableType
TableSource = common_internal.TableSource
Constraints = common_internal.Constraints

__all__ = [
"Constraints",
"DFSchema",
"DataType",
"DataTypeMap",
Expand All @@ -47,6 +51,8 @@
"SqlTable",
"SqlType",
"SqlView",
"TableSource",
"TableType",
]


Expand Down
50 changes: 50 additions & 0 deletions python/datafusion/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,29 @@
Case = expr_internal.Case
Cast = expr_internal.Cast
Column = expr_internal.Column
CopyTo = expr_internal.CopyTo
CreateCatalog = expr_internal.CreateCatalog
CreateCatalogSchema = expr_internal.CreateCatalogSchema
CreateExternalTable = expr_internal.CreateExternalTable
CreateFunction = expr_internal.CreateFunction
CreateFunctionBody = expr_internal.CreateFunctionBody
CreateIndex = expr_internal.CreateIndex
CreateMemoryTable = expr_internal.CreateMemoryTable
CreateView = expr_internal.CreateView
Deallocate = expr_internal.Deallocate
DescribeTable = expr_internal.DescribeTable
Distinct = expr_internal.Distinct
DmlStatement = expr_internal.DmlStatement
DropCatalogSchema = expr_internal.DropCatalogSchema
DropFunction = expr_internal.DropFunction
DropTable = expr_internal.DropTable
DropView = expr_internal.DropView
EmptyRelation = expr_internal.EmptyRelation
Execute = expr_internal.Execute
Exists = expr_internal.Exists
Explain = expr_internal.Explain
Extension = expr_internal.Extension
FileType = expr_internal.FileType
Filter = expr_internal.Filter
GroupingSet = expr_internal.GroupingSet
Join = expr_internal.Join
Expand All @@ -83,21 +98,31 @@
Literal = expr_internal.Literal
Negative = expr_internal.Negative
Not = expr_internal.Not
OperateFunctionArg = expr_internal.OperateFunctionArg
Partitioning = expr_internal.Partitioning
Placeholder = expr_internal.Placeholder
Prepare = expr_internal.Prepare
Projection = expr_internal.Projection
RecursiveQuery = expr_internal.RecursiveQuery
Repartition = expr_internal.Repartition
ScalarSubquery = expr_internal.ScalarSubquery
ScalarVariable = expr_internal.ScalarVariable
SetVariable = expr_internal.SetVariable
SimilarTo = expr_internal.SimilarTo
Sort = expr_internal.Sort
Subquery = expr_internal.Subquery
SubqueryAlias = expr_internal.SubqueryAlias
TableScan = expr_internal.TableScan
TransactionAccessMode = expr_internal.TransactionAccessMode
TransactionConclusion = expr_internal.TransactionConclusion
TransactionEnd = expr_internal.TransactionEnd
TransactionIsolationLevel = expr_internal.TransactionIsolationLevel
TransactionStart = expr_internal.TransactionStart
TryCast = expr_internal.TryCast
Union = expr_internal.Union
Unnest = expr_internal.Unnest
UnnestExpr = expr_internal.UnnestExpr
Values = expr_internal.Values
WindowExpr = expr_internal.WindowExpr

__all__ = [
Expand All @@ -111,15 +136,30 @@
"CaseBuilder",
"Cast",
"Column",
"CopyTo",
"CreateCatalog",
"CreateCatalogSchema",
"CreateExternalTable",
"CreateFunction",
"CreateFunctionBody",
"CreateIndex",
"CreateMemoryTable",
"CreateView",
"Deallocate",
"DescribeTable",
"Distinct",
"DmlStatement",
"DropCatalogSchema",
"DropFunction",
"DropTable",
"DropView",
"EmptyRelation",
"Execute",
"Exists",
"Explain",
"Expr",
"Extension",
"FileType",
"Filter",
"GroupingSet",
"ILike",
Expand All @@ -142,22 +182,32 @@
"Literal",
"Negative",
"Not",
"OperateFunctionArg",
"Partitioning",
"Placeholder",
"Prepare",
"Projection",
"RecursiveQuery",
"Repartition",
"ScalarSubquery",
"ScalarVariable",
"SetVariable",
"SimilarTo",
"Sort",
"SortExpr",
"Subquery",
"SubqueryAlias",
"TableScan",
"TransactionAccessMode",
"TransactionConclusion",
"TransactionEnd",
"TransactionIsolationLevel",
"TransactionStart",
"TryCast",
"Union",
"Unnest",
"UnnestExpr",
"Values",
"Window",
"WindowExpr",
"WindowFrame",
Expand Down
86 changes: 86 additions & 0 deletions python/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,21 @@
AggregateFunction,
BinaryExpr,
Column,
CopyTo,
CreateIndex,
DescribeTable,
DmlStatement,
DropCatalogSchema,
Filter,
Limit,
Literal,
Projection,
RecursiveQuery,
Sort,
TableScan,
TransactionEnd,
TransactionStart,
Values,
)


Expand Down Expand Up @@ -247,3 +256,80 @@ def test_fill_null(df):
assert result.column(0) == pa.array([1, 2, 100])
assert result.column(1) == pa.array([4, 25, 6])
assert result.column(2) == pa.array([1234, 1234, 8])


def test_copy_to():
ctx = SessionContext()
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
df = ctx.sql("COPY foo TO bar STORED AS CSV")
plan = df.logical_plan()
plan = plan.to_variant()
assert isinstance(plan, CopyTo)


def test_create_index():
ctx = SessionContext()
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
plan = ctx.sql("create index idx on foo (a)").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, CreateIndex)


def test_describe_table():
ctx = SessionContext()
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
plan = ctx.sql("describe foo").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, DescribeTable)


def test_dml_statement():
ctx = SessionContext()
ctx.sql("CREATE TABLE foo (a int, b int)").collect()
plan = ctx.sql("insert into foo values (1, 2)").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, DmlStatement)


def drop_catalog_schema():
ctx = SessionContext()
plan = ctx.sql("drop schema cat").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, DropCatalogSchema)


def test_recursive_query():
ctx = SessionContext()
plan = ctx.sql(
"""
WITH RECURSIVE cte AS (
SELECT 1 as n
UNION ALL
SELECT n + 1 FROM cte WHERE n < 5
)
SELECT * FROM cte;
"""
).logical_plan()
plan = plan.inputs()[0].inputs()[0].to_variant()
assert isinstance(plan, RecursiveQuery)


def test_values():
ctx = SessionContext()
plan = ctx.sql("values (1, 'foo'), (2, 'bar')").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, Values)


def test_transaction_start():
ctx = SessionContext()
plan = ctx.sql("START TRANSACTION").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, TransactionStart)


def test_transaction_end():
ctx = SessionContext()
plan = ctx.sql("COMMIT").logical_plan()
plan = plan.to_variant()
assert isinstance(plan, TransactionEnd)
3 changes: 3 additions & 0 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,8 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<schema::SqlView>()?;
m.add_class::<schema::SqlStatistics>()?;
m.add_class::<function::SqlFunction>()?;
m.add_class::<schema::PyTableType>()?;
m.add_class::<schema::PyTableSource>()?;
m.add_class::<schema::PyConstraints>()?;
Ok(())
}
89 changes: 89 additions & 0 deletions src/common/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,22 @@
// specific language governing permissions and limitations
// under the License.

use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use std::{any::Any, borrow::Cow};

use arrow::datatypes::Schema;
use arrow::pyarrow::PyArrowType;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::Constraints;
use datafusion::datasource::TableType;
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableSource};
use pyo3::prelude::*;

use datafusion::logical_expr::utils::split_conjunction;

use crate::sql::logical::PyLogicalPlan;

use super::{data_type::DataTypeMap, function::SqlFunction};

#[pyclass(name = "SqlSchema", module = "datafusion.common", subclass)]
Expand Down Expand Up @@ -218,3 +226,84 @@ impl SqlStatistics {
self.row_count
}
}

#[pyclass(name = "Constraints", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyConstraints {
pub constraints: Constraints,
}

impl From<PyConstraints> for Constraints {
fn from(constraints: PyConstraints) -> Self {
constraints.constraints
}
}

impl From<Constraints> for PyConstraints {
fn from(constraints: Constraints) -> Self {
PyConstraints { constraints }
}
}

impl Display for PyConstraints {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "Constraints: {:?}", self.constraints)
}
}

#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(eq, eq_int, name = "TableType", module = "datafusion.common")]
pub enum PyTableType {
Base,
View,
Temporary,
}

impl From<PyTableType> for datafusion::logical_expr::TableType {
fn from(table_type: PyTableType) -> Self {
match table_type {
PyTableType::Base => datafusion::logical_expr::TableType::Base,
PyTableType::View => datafusion::logical_expr::TableType::View,
PyTableType::Temporary => datafusion::logical_expr::TableType::Temporary,
}
}
}

impl From<TableType> for PyTableType {
fn from(table_type: TableType) -> Self {
match table_type {
datafusion::logical_expr::TableType::Base => PyTableType::Base,
datafusion::logical_expr::TableType::View => PyTableType::View,
datafusion::logical_expr::TableType::Temporary => PyTableType::Temporary,
}
}
}

#[pyclass(name = "TableSource", module = "datafusion.common", subclass)]
#[derive(Clone)]
pub struct PyTableSource {
pub table_source: Arc<dyn TableSource>,
}

#[pymethods]
impl PyTableSource {
pub fn schema(&self) -> PyArrowType<Schema> {
(*self.table_source.schema()).clone().into()
}

pub fn constraints(&self) -> Option<PyConstraints> {
self.table_source.constraints().map(|c| PyConstraints {
constraints: c.clone(),
})
}

pub fn table_type(&self) -> PyTableType {
self.table_source.table_type().into()
}

pub fn get_logical_plan(&self) -> Option<PyLogicalPlan> {
self.table_source
.get_logical_plan()
.map(|plan| PyLogicalPlan::new(plan.into_owned()))
}
}
Loading