15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
+ use datafusion_expr:: utils:: exprlist_to_fields;
19
+ use datafusion_expr:: LogicalPlan ;
18
20
use pyo3:: { basic:: CompareOp , prelude:: * } ;
19
21
use std:: convert:: { From , Into } ;
22
+ use std:: sync:: Arc ;
20
23
21
- use datafusion:: arrow:: datatypes:: DataType ;
24
+ use datafusion:: arrow:: datatypes:: { DataType , Field } ;
22
25
use datafusion:: arrow:: pyarrow:: PyArrowType ;
23
26
use datafusion:: scalar:: ScalarValue ;
24
27
use datafusion_expr:: {
@@ -29,11 +32,12 @@ use datafusion_expr::{
29
32
} ;
30
33
31
34
use crate :: common:: data_type:: { DataTypeMap , RexType } ;
32
- use crate :: errors:: { py_datafusion_err, py_runtime_err, py_type_err} ;
35
+ use crate :: errors:: { py_datafusion_err, py_runtime_err, py_type_err, DataFusionError } ;
33
36
use crate :: expr:: aggregate_expr:: PyAggregateFunction ;
34
37
use crate :: expr:: binary_expr:: PyBinaryExpr ;
35
38
use crate :: expr:: column:: PyColumn ;
36
39
use crate :: expr:: literal:: PyLiteral ;
40
+ use crate :: sql:: logical:: PyLogicalPlan ;
37
41
38
42
use self :: alias:: PyAlias ;
39
43
use self :: bool_expr:: {
@@ -553,9 +557,40 @@ impl PyExpr {
553
557
}
554
558
} )
555
559
}
560
+
561
+ pub fn column_name ( & self , plan : PyLogicalPlan ) -> PyResult < String > {
562
+ self . _column_name ( & plan. plan ( ) ) . map_err ( py_runtime_err)
563
+ }
556
564
}
557
565
558
566
impl PyExpr {
567
+ pub fn _column_name ( & self , plan : & LogicalPlan ) -> Result < String , DataFusionError > {
568
+ let field = Self :: expr_to_field ( & self . expr , plan) ?;
569
+ Ok ( field. name ( ) . to_owned ( ) )
570
+ }
571
+
572
+ /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
573
+ pub fn expr_to_field (
574
+ expr : & Expr ,
575
+ input_plan : & LogicalPlan ,
576
+ ) -> Result < Arc < Field > , DataFusionError > {
577
+ match expr {
578
+ Expr :: Sort ( Sort { expr, .. } ) => {
579
+ // DataFusion does not support create_name for sort expressions (since they never
580
+ // appear in projections) so we just delegate to the contained expression instead
581
+ Self :: expr_to_field ( expr, input_plan)
582
+ }
583
+ Expr :: Wildcard { .. } => {
584
+ // Since * could be any of the valid column names just return the first one
585
+ Ok ( Arc :: new ( input_plan. schema ( ) . field ( 0 ) . clone ( ) ) )
586
+ }
587
+ _ => {
588
+ let fields =
589
+ exprlist_to_fields ( & [ expr. clone ( ) ] , input_plan) . map_err ( PyErr :: from) ?;
590
+ Ok ( fields[ 0 ] . 1 . clone ( ) )
591
+ }
592
+ }
593
+ }
559
594
fn _types ( expr : & Expr ) -> PyResult < DataTypeMap > {
560
595
match expr {
561
596
Expr :: BinaryExpr ( BinaryExpr {
0 commit comments