Skip to content

Commit 9a6805e

Browse files
authored
Add missing exports for wrapper modules (#782)
* Add imports in base file to match those in internal * Correct class capitalization and exports for substrait * Add exports for common to match internal * Add exports for Expr to match internal * Add __all__ to functions * Add exports for object store to match internal * Add pytest to ensure all pyo3 exposed objects are also exposed in our wrappers so we don't miss any functions or classes * Add license
1 parent 951d6b9 commit 9a6805e

File tree

8 files changed

+423
-14
lines changed

8 files changed

+423
-14
lines changed

python/datafusion/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,12 @@
3333
SQLOptions,
3434
)
3535

36+
from .catalog import Catalog, Database, Table
37+
3638
# The following imports are okay to remain as opaque to the user.
37-
from ._internal import Config
39+
from ._internal import Config, LogicalPlan, ExecutionPlan, runtime
40+
41+
from .record_batch import RecordBatchStream, RecordBatch
3842

3943
from .udf import ScalarUDF, AggregateUDF, Accumulator
4044

@@ -49,6 +53,8 @@
4953
WindowFrame,
5054
)
5155

56+
from . import functions, object_store, substrait
57+
5258
__version__ = importlib_metadata.version(__name__)
5359

5460
__all__ = [
@@ -65,6 +71,20 @@
6571
"column",
6672
"literal",
6773
"DFSchema",
74+
"runtime",
75+
"Catalog",
76+
"Database",
77+
"Table",
78+
"AggregateUDF",
79+
"LogicalPlan",
80+
"ExecutionPlan",
81+
"RecordBatch",
82+
"RecordBatchStream",
83+
"common",
84+
"expr",
85+
"functions",
86+
"object_store",
87+
"substrait",
6888
]
6989

7090

python/datafusion/common.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,34 @@
1616
# under the License.
1717
"""Common data types used throughout the DataFusion project."""
1818

19-
from ._internal import common
19+
from ._internal import common as common_internal
2020

21+
# TODO these should all have proper wrapper classes
2122

22-
def __getattr__(name):
23-
return getattr(common, name)
23+
DFSchema = common_internal.DFSchema
24+
DataType = common_internal.DataType
25+
DataTypeMap = common_internal.DataTypeMap
26+
NullTreatment = common_internal.NullTreatment
27+
PythonType = common_internal.PythonType
28+
RexType = common_internal.RexType
29+
SqlFunction = common_internal.SqlFunction
30+
SqlSchema = common_internal.SqlSchema
31+
SqlStatistics = common_internal.SqlStatistics
32+
SqlTable = common_internal.SqlTable
33+
SqlType = common_internal.SqlType
34+
SqlView = common_internal.SqlView
35+
36+
__all__ = [
37+
"DFSchema",
38+
"DataType",
39+
"DataTypeMap",
40+
"RexType",
41+
"PythonType",
42+
"SqlType",
43+
"NullTreatment",
44+
"SqlTable",
45+
"SqlSchema",
46+
"SqlView",
47+
"SqlStatistics",
48+
"SqlFunction",
49+
]

python/datafusion/expr.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
CrossJoin = expr_internal.CrossJoin
4848
Distinct = expr_internal.Distinct
4949
DropTable = expr_internal.DropTable
50+
EmptyRelation = expr_internal.EmptyRelation
5051
Exists = expr_internal.Exists
5152
Explain = expr_internal.Explain
5253
Extension = expr_internal.Extension
@@ -58,6 +59,7 @@
5859
InSubquery = expr_internal.InSubquery
5960
IsFalse = expr_internal.IsFalse
6061
IsNotTrue = expr_internal.IsNotTrue
62+
IsNull = expr_internal.IsNull
6163
IsTrue = expr_internal.IsTrue
6264
IsUnknown = expr_internal.IsUnknown
6365
IsNotFalse = expr_internal.IsNotFalse
@@ -83,6 +85,70 @@
8385
TableScan = expr_internal.TableScan
8486
TryCast = expr_internal.TryCast
8587
Union = expr_internal.Union
88+
Unnest = expr_internal.Unnest
89+
Window = expr_internal.Window
90+
91+
__all__ = [
92+
"Expr",
93+
"Column",
94+
"Literal",
95+
"BinaryExpr",
96+
"Literal",
97+
"AggregateFunction",
98+
"Not",
99+
"IsNotNull",
100+
"IsNull",
101+
"IsTrue",
102+
"IsFalse",
103+
"IsUnknown",
104+
"IsNotTrue",
105+
"IsNotFalse",
106+
"IsNotUnknown",
107+
"Negative",
108+
"Like",
109+
"ILike",
110+
"SimilarTo",
111+
"ScalarVariable",
112+
"Alias",
113+
"InList",
114+
"Exists",
115+
"Subquery",
116+
"InSubquery",
117+
"ScalarSubquery",
118+
"Placeholder",
119+
"GroupingSet",
120+
"Case",
121+
"CaseBuilder",
122+
"Cast",
123+
"TryCast",
124+
"Between",
125+
"Explain",
126+
"Limit",
127+
"Aggregate",
128+
"Sort",
129+
"Analyze",
130+
"EmptyRelation",
131+
"Join",
132+
"JoinType",
133+
"JoinConstraint",
134+
"CrossJoin",
135+
"Union",
136+
"Unnest",
137+
"Extension",
138+
"Filter",
139+
"Projection",
140+
"TableScan",
141+
"CreateMemoryTable",
142+
"CreateView",
143+
"Distinct",
144+
"SubqueryAlias",
145+
"DropTable",
146+
"Partitioning",
147+
"Repartition",
148+
"Window",
149+
"WindowFrame",
150+
"WindowFrameBound",
151+
]
86152

87153

88154
class Expr:
@@ -246,6 +312,14 @@ def __lt__(self, rhs: Any) -> Expr:
246312
rhs = Expr.literal(rhs)
247313
return Expr(self.expr.__lt__(rhs.expr))
248314

315+
__radd__ = __add__
316+
__rand__ = __and__
317+
__rmod__ = __mod__
318+
__rmul__ = __mul__
319+
__ror__ = __or__
320+
__rsub__ = __sub__
321+
__rtruediv__ = __truediv__
322+
249323
@staticmethod
250324
def literal(value: Any) -> Expr:
251325
"""Creates a new expression representing a scalar value.

0 commit comments

Comments
 (0)