Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

[Logical Optimizer] align schema #61

Merged
merged 3 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions datafusion-optd-cli/tests/cli_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,8 @@ fn cli_test_tpch() {
cmd.current_dir(".."); // all paths in `test.sql` assume we're in the base dir of the repo
cmd.args(["--enable-logical", "--file", "tpch/test.sql"]);
let status = cmd.status().unwrap();
assert!(status.success(), "should not have crashed when running tpch");
}
assert!(
status.success(),
"should not have crashed when running tpch"
);
}
4 changes: 1 addition & 3 deletions optd-core/src/cascades/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,7 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
group_id: GroupId,
mut on_produce: impl FnMut(RelNodeRef<T>, GroupId) -> RelNodeRef<T>,
) -> Result<RelNodeRef<T>> {
self
.memo
.get_best_group_binding(group_id, &mut on_produce)
self.memo.get_best_group_binding(group_id, &mut on_produce)
}

fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
Expand Down
7 changes: 1 addition & 6 deletions optd-core/src/cascades/tasks/optimize_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,7 @@ impl<T: RelNodeTyp> Task<T> for OptimizeInputsTask {
} else {
self.update_winner(
&cost.sum(
&cost.compute_cost(
&expr.typ,
&expr.data,
&input_cost,
Some(context),
),
&cost.compute_cost(&expr.typ, &expr.data, &input_cost, Some(context)),
&input_cost,
),
optimizer,
Expand Down
5 changes: 1 addition & 4 deletions optd-core/src/heuristics/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,7 @@ fn match_node<T: RelNodeTyp>(
assert!(res.is_none(), "dup pick");
}
RuleMatcher::PickMany { pick_to } => {
let res = pick.insert(
*pick_to,
RelNode::new_list(node.children[idx..].to_vec()),
);
let res = pick.insert(*pick_to, RelNode::new_list(node.children[idx..].to_vec()));
assert!(res.is_none(), "dup pick");
should_end = true;
}
Expand Down
28 changes: 17 additions & 11 deletions optd-datafusion-bridge/src/from_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::{physical_collector::CollectorExec, OptdPlanContext};
// TODO: current DataType and ConstantType are not 1 to 1 mapping
// optd schema stores constantType from data type in catalog.get
// for decimal128, the precision is lost
fn from_optd_schema(optd_schema: &OptdSchema) -> Schema {
fn from_optd_schema(optd_schema: OptdSchema) -> Schema {
let match_type = |typ: &ConstantType| match typ {
ConstantType::Any => unimplemented!(),
ConstantType::Bool => DataType::Boolean,
Expand All @@ -52,12 +52,14 @@ fn from_optd_schema(optd_schema: &OptdSchema) -> Schema {
ConstantType::Decimal => DataType::Float64,
ConstantType::Utf8String => DataType::Utf8,
};
let fields: Vec<_> = optd_schema
.0
.iter()
.enumerate()
.map(|(i, typ)| Field::new(format!("c{}", i), match_type(typ), false))
.collect();
let mut fields = Vec::with_capacity(optd_schema.len());
for field in optd_schema.fields {
fields.push(Field::new(
field.name,
match_type(&field.typ),
field.nullable,
));
}
Schema::new(fields)
}

Expand Down Expand Up @@ -351,7 +353,8 @@ impl OptdPlanContext<'_> {
Schema::new_with_metadata(fields, HashMap::new())
};

let physical_expr = Self::conv_from_optd_expr(node.cond(), &Arc::new(filter_schema.clone()))?;
let physical_expr =
Self::conv_from_optd_expr(node.cond(), &Arc::new(filter_schema.clone()))?;

if let JoinType::Cross = node.join_type() {
return Ok(Arc::new(CrossJoinExec::new(left_exec, right_exec))
Expand Down Expand Up @@ -436,7 +439,7 @@ impl OptdPlanContext<'_> {

#[async_recursion]
async fn conv_from_optd_plan_node(&mut self, node: PlanNode) -> Result<Arc<dyn ExecutionPlan>> {
let mut schema = OptdSchema(vec![]);
let mut schema = OptdSchema { fields: vec![] };
if node.typ() == OptRelNodeTyp::PhysicalEmptyRelation {
schema = node.schema(self.optimizer.unwrap().optd_optimizer());
}
Expand Down Expand Up @@ -484,7 +487,7 @@ impl OptdPlanContext<'_> {
}
OptRelNodeTyp::PhysicalEmptyRelation => {
let physical_node = PhysicalEmptyRelation::from_rel_node(rel_node).unwrap();
let datafusion_schema: Schema = from_optd_schema(&schema);
let datafusion_schema: Schema = from_optd_schema(schema);
Ok(Arc::new(datafusion::physical_plan::empty::EmptyExec::new(
physical_node.produce_one_row(),
Arc::new(datafusion_schema),
Expand All @@ -495,7 +498,10 @@ impl OptdPlanContext<'_> {
result.with_context(|| format!("when processing {}", rel_node_dbg))
}

pub async fn conv_from_optd(&mut self, root_rel: OptRelNodeRef) -> Result<Arc<dyn ExecutionPlan>> {
pub async fn conv_from_optd(
&mut self,
root_rel: OptRelNodeRef,
) -> Result<Arc<dyn ExecutionPlan>> {
self.conv_from_optd_plan_node(PlanNode::from_rel_node(root_rel).unwrap())
.await
}
Expand Down
17 changes: 12 additions & 5 deletions optd-datafusion-bridge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ impl Catalog for DatafusionCatalog {
let catalog = self.catalog.catalog("datafusion").unwrap();
let schema = catalog.schema("public").unwrap();
let table = futures_lite::future::block_on(schema.table(name.as_ref())).unwrap();
let fields = table.schema();
let mut optd_schema = vec![];
for field in fields.fields() {
let schema = table.schema();
let fields = schema.fields();
let mut optd_fields = Vec::with_capacity(fields.len());
for field in fields {
let dt = match field.data_type() {
DataType::Date32 => ConstantType::Date,
DataType::Int32 => ConstantType::Int32,
Expand All @@ -73,9 +74,15 @@ impl Catalog for DatafusionCatalog {
DataType::Decimal128(_, _) => ConstantType::Decimal,
dt => unimplemented!("{:?}", dt),
};
optd_schema.push(dt);
optd_fields.push(optd_datafusion_repr::properties::schema::Field {
name: field.name().to_string(),
typ: dt,
nullable: field.is_nullable(),
});
}
optd_datafusion_repr::properties::schema::Schema {
fields: optd_fields,
}
optd_datafusion_repr::properties::schema::Schema(optd_schema)
}
}

Expand Down
32 changes: 26 additions & 6 deletions optd-datafusion-repr/src/properties/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@ use optd_core::property::PropertyBuilder;
use crate::plan_nodes::{ConstantType, OptRelNodeTyp};

#[derive(Clone, Debug)]
pub struct Schema(pub Vec<ConstantType>);
pub struct Field {
pub name: String,
pub typ: ConstantType,
pub nullable: bool,
}
#[derive(Clone, Debug)]
pub struct Schema {
pub fields: Vec<Field>,
}

// TODO: add names, nullable to schema
impl Schema {
pub fn len(&self) -> usize {
self.0.len()
self.fields.len()
}

pub fn is_empty(&self) -> bool {
Expand Down Expand Up @@ -48,11 +55,24 @@ impl PropertyBuilder<OptRelNodeTyp> for SchemaPropertyBuilder {
OptRelNodeTyp::Filter => children[0].clone(),
OptRelNodeTyp::Join(_) => {
let mut schema = children[0].clone();
schema.0.extend(children[1].clone().0);
let schema2 = children[1].clone();
schema.fields.extend(schema2.fields);
schema
}
OptRelNodeTyp::List => Schema(vec![ConstantType::Any; children.len()]),
_ => Schema(vec![]),
OptRelNodeTyp::List => {
// TODO: calculate real is_nullable for aggregations
Schema {
fields: vec![
Field {
name: "unnamed".to_string(),
typ: ConstantType::Any,
nullable: true
};
children.len()
],
}
}
_ => Schema { fields: vec![] },
}
}

Expand Down
14 changes: 6 additions & 8 deletions optd-datafusion-repr/src/rules/joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ fn apply_join_commute(
cond,
JoinType::Inner,
);
let mut proj_expr = Vec::with_capacity(left_schema.0.len() + right_schema.0.len());
let mut proj_expr = Vec::with_capacity(left_schema.len() + right_schema.len());
for i in 0..left_schema.len() {
proj_expr.push(ColumnRefExpr::new(right_schema.len() + i).into_expr());
}
Expand Down Expand Up @@ -218,21 +218,19 @@ fn apply_hash_join(
let Some(mut right_expr) = ColumnRefExpr::from_rel_node(right_expr.into_rel_node()) else {
return vec![];
};
let can_convert = if left_expr.index() < left_schema.0.len()
&& right_expr.index() >= left_schema.0.len()
let can_convert = if left_expr.index() < left_schema.len()
&& right_expr.index() >= left_schema.len()
{
true
} else if right_expr.index() < left_schema.0.len()
&& left_expr.index() >= left_schema.0.len()
{
} else if right_expr.index() < left_schema.len() && left_expr.index() >= left_schema.len() {
(left_expr, right_expr) = (right_expr, left_expr);
true
} else {
false
};

if can_convert {
let right_expr = ColumnRefExpr::new(right_expr.index() - left_schema.0.len());
let right_expr = ColumnRefExpr::new(right_expr.index() - left_schema.len());
let node = PhysicalHashJoin::new(
PlanNode::from_group(left.into()),
PlanNode::from_group(right.into()),
Expand Down Expand Up @@ -342,7 +340,7 @@ fn apply_projection_pull_up_join(
.into_rel_node(),
);
}

Expr::from_rel_node(
RelNode {
typ: expr.typ.clone(),
Expand Down
Loading