Skip to content

Commit 9b4f90a

Browse files
authored
Fix sort node deserialization from proto (#12626)
1 parent 1b3608d commit 9b4f90a

File tree

3 files changed

+46
-9
lines changed

3 files changed

+46
-9
lines changed

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,18 @@ impl LogicalPlanBuilder {
570570
)
571571
}
572572

573-
/// Apply a sort
574573
pub fn sort(
575574
self,
576575
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
576+
) -> Result<Self> {
577+
self.sort_with_limit(sorts, None)
578+
}
579+
580+
/// Apply a sort
581+
pub fn sort_with_limit(
582+
self,
583+
sorts: impl IntoIterator<Item = impl Into<SortExpr>> + Clone,
584+
fetch: Option<usize>,
577585
) -> Result<Self> {
578586
let sorts = rewrite_sort_cols_by_aggs(sorts, &self.plan)?;
579587

@@ -597,7 +605,7 @@ impl LogicalPlanBuilder {
597605
return Ok(Self::new(LogicalPlan::Sort(Sort {
598606
expr: normalize_sorts(sorts, &self.plan)?,
599607
input: self.plan,
600-
fetch: None,
608+
fetch,
601609
})));
602610
}
603611

@@ -613,7 +621,7 @@ impl LogicalPlanBuilder {
613621
let sort_plan = LogicalPlan::Sort(Sort {
614622
expr: normalize_sorts(sorts, &plan)?,
615623
input: Arc::new(plan),
616-
fetch: None,
624+
fetch,
617625
});
618626

619627
Projection::try_new(new_expr, Arc::new(sort_plan))
@@ -1202,7 +1210,7 @@ impl LogicalPlanBuilder {
12021210

12031211
/// Unnest the given columns with the given [`UnnestOptions`]
12041212
/// if one column is a list type, it can be recursively and simultaneously
1205-
/// unnested into the desired recursion levels
1213+
/// unnested into the desired recursion levels
12061214
/// e.g select unnest(list_col,depth=1), unnest(list_col,depth=2)
12071215
pub fn unnest_columns_recursive_with_options(
12081216
self,

datafusion/proto/src/logical_plan/mod.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -490,17 +490,20 @@ impl AsLogicalPlan for LogicalPlanNode {
490490
into_logical_plan!(sort.input, ctx, extension_codec)?;
491491
let sort_expr: Vec<SortExpr> =
492492
from_proto::parse_sorts(&sort.expr, ctx, extension_codec)?;
493-
LogicalPlanBuilder::from(input).sort(sort_expr)?.build()
493+
let fetch: Option<usize> = sort.fetch.try_into().ok();
494+
LogicalPlanBuilder::from(input)
495+
.sort_with_limit(sort_expr, fetch)?
496+
.build()
494497
}
495498
LogicalPlanType::Repartition(repartition) => {
496499
use datafusion::logical_expr::Partitioning;
497500
let input: LogicalPlan =
498501
into_logical_plan!(repartition.input, ctx, extension_codec)?;
499502
use protobuf::repartition_node::PartitionMethod;
500503
let pb_partition_method = repartition.partition_method.as_ref().ok_or_else(|| {
501-
DataFusionError::Internal(String::from(
502-
"Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'",
503-
))
504+
internal_datafusion_err!(
505+
"Protobuf deserialization error, RepartitionNode was missing required field 'partition_method'"
506+
)
504507
})?;
505508

506509
let partitioning_scheme = match pb_partition_method {
@@ -526,7 +529,7 @@ impl AsLogicalPlan for LogicalPlanNode {
526529
LogicalPlanType::CreateExternalTable(create_extern_table) => {
527530
let pb_schema = (create_extern_table.schema.clone()).ok_or_else(|| {
528531
DataFusionError::Internal(String::from(
529-
"Protobuf deserialization error, CreateExternalTableNode was missing required field schema.",
532+
"Protobuf deserialization error, CreateExternalTableNode was missing required field schema."
530533
))
531534
})?;
532535

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,32 @@ async fn roundtrip_logical_plan_aggregation() -> Result<()> {
341341
Ok(())
342342
}
343343

344+
#[tokio::test]
345+
async fn roundtrip_logical_plan_sort() -> Result<()> {
346+
let ctx = SessionContext::new();
347+
348+
let schema = Schema::new(vec![
349+
Field::new("a", DataType::Int64, true),
350+
Field::new("b", DataType::Decimal128(15, 2), true),
351+
]);
352+
353+
ctx.register_csv(
354+
"t1",
355+
"tests/testdata/test.csv",
356+
CsvReadOptions::default().schema(&schema),
357+
)
358+
.await?;
359+
360+
let query = "SELECT a, b FROM t1 ORDER BY b LIMIT 5";
361+
let plan = ctx.sql(query).await?.into_optimized_plan()?;
362+
363+
let bytes = logical_plan_to_bytes(&plan)?;
364+
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
365+
assert_eq!(format!("{plan}"), format!("{logical_round_trip}"));
366+
367+
Ok(())
368+
}
369+
344370
#[tokio::test]
345371
async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> {
346372
let ctx = SessionContext::new();

0 commit comments

Comments
 (0)