diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 25607c03c9c86..43c304444bbd9 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -889,7 +889,7 @@ public PlanOptimizers( // MergeJoinForSortedInputOptimizer can avoid the local exchange for a join operation // Should be placed after AddExchanges, but before AddLocalExchange // To replace the JoinNode to MergeJoin ahead of AddLocalExchange to avoid adding extra local exchange - builder.add(new MergeJoinForSortedInputOptimizer(metadata)); + builder.add(new MergeJoinForSortedInputOptimizer(metadata, featuresConfig.isNativeExecutionEnabled())); // Optimizers above this don't understand local exchanges, so be careful moving this. builder.add(new AddLocalExchanges(metadata, featuresConfig.isNativeExecutionEnabled())); @@ -958,7 +958,7 @@ public PlanOptimizers( statsCalculator, costCalculator, ImmutableList.of(), - ImmutableSet.of(new RuntimeReorderJoinSides(metadata)))); + ImmutableSet.of(new RuntimeReorderJoinSides(metadata, featuresConfig.isNativeExecutionEnabled())))); this.runtimeOptimizers = runtimeBuilder.build(); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java index 670ce15d9a344..f548fc8ac5b2c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/JoinSwappingUtils.java @@ -64,7 +64,8 @@ public static Optional createRuntimeSwappedJoinNode( Metadata metadata, Lookup lookup, Session session, - PlanNodeIdAllocator idAllocator) + PlanNodeIdAllocator idAllocator, + boolean nativeExecution) { JoinNode swapped = joinNode.flipChildren(); @@ -76,7 +77,7 @@ public static Optional createRuntimeSwappedJoinNode( PlanNode resolvedSwappedLeft = lookup.resolve(newLeft); if (resolvedSwappedLeft instanceof ExchangeNode && resolvedSwappedLeft.getSources().size() == 1) { // Ensure the new probe after skipping the local exchange will satisfy the required probe side property - if (checkProbeSidePropertySatisfied(resolvedSwappedLeft.getSources().get(0), metadata, lookup, session)) { + if (checkProbeSidePropertySatisfied(resolvedSwappedLeft.getSources().get(0), metadata, lookup, session, nativeExecution)) { newLeft = resolvedSwappedLeft.getSources().get(0); // The HashGenerationOptimizer will generate hashVariables and append to the output layout of the nodes following the same order. Therefore, // we use the index of the old hashVariable in the ExchangeNode output layout to retrieve the hashVariable from the new left node, and feed @@ -100,7 +101,7 @@ public static Optional createRuntimeSwappedJoinNode( .map(EquiJoinClause::getRight) .collect(toImmutableList()); PlanNode newRight = swapped.getRight(); - if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, metadata, lookup, session)) { + if (!checkBuildSidePropertySatisfied(swapped.getRight(), buildJoinVariables, metadata, lookup, session, nativeExecution)) { if (getTaskConcurrency(session) > 1) { newRight = systemPartitionedExchange( idAllocator.getNextId(), @@ -132,7 +133,7 @@ public static Optional createRuntimeSwappedJoinNode( } // Check if the new probe side after removing unnecessary local exchange is valid. - public static boolean checkProbeSidePropertySatisfied(PlanNode node, Metadata metadata, Lookup lookup, Session session) + public static boolean checkProbeSidePropertySatisfied(PlanNode node, Metadata metadata, Lookup lookup, Session session, boolean nativeExecution) { StreamPreferredProperties requiredProbeProperty; if (isSpillEnabled(session) && isJoinSpillingEnabled(session)) { @@ -141,7 +142,7 @@ public static boolean checkProbeSidePropertySatisfied(PlanNode node, Metadata me else { requiredProbeProperty = defaultParallelism(session); } - StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session); + StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session, nativeExecution); return requiredProbeProperty.isSatisfiedBy(nodeProperty); } @@ -151,7 +152,8 @@ private static boolean checkBuildSidePropertySatisfied( List partitioningColumns, Metadata metadata, Lookup lookup, - Session session) + Session session, + boolean nativeExecution) { StreamPreferredProperties requiredBuildProperty; if (getTaskConcurrency(session) > 1) { @@ -160,7 +162,7 @@ private static boolean checkBuildSidePropertySatisfied( else { requiredBuildProperty = singleStream(); } - StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session); + StreamPropertyDerivations.StreamProperties nodeProperty = derivePropertiesRecursively(node, metadata, lookup, session, nativeExecution); return requiredBuildProperty.isSatisfiedBy(nodeProperty); } @@ -168,13 +170,14 @@ private static StreamPropertyDerivations.StreamProperties derivePropertiesRecurs PlanNode node, Metadata metadata, Lookup lookup, - Session session) + Session session, + boolean nativeExecution) { PlanNode actual = lookup.resolve(node); List inputProperties = actual.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, lookup, session)) + .map(source -> derivePropertiesRecursively(source, metadata, lookup, session, nativeExecution)) .collect(toImmutableList()); - return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, session); + return StreamPropertyDerivations.deriveProperties(actual, inputProperties, metadata, session, nativeExecution); } public static boolean isBelowBroadcastLimit(PlanNode planNode, Rule.Context context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java index 923350ad74774..e956e84d15535 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RuntimeReorderJoinSides.java @@ -44,10 +44,12 @@ public class RuntimeReorderJoinSides private static final Pattern PATTERN = join(); private final Metadata metadata; + private final boolean nativeExecution; - public RuntimeReorderJoinSides(Metadata metadata) + public RuntimeReorderJoinSides(Metadata metadata, boolean nativeExecution) { this.metadata = requireNonNull(metadata, "metadata is null"); + this.nativeExecution = nativeExecution; } @Override @@ -97,7 +99,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) return Result.empty(); } - Optional rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, context.getLookup(), context.getSession(), context.getIdAllocator()); + Optional rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, context.getLookup(), context.getSession(), context.getIdAllocator(), nativeExecution); if (rewrittenNode.isPresent()) { log.debug(format("Probe size: %.2f is smaller than Build size: %.2f => invoke runtime join swapping on JoinNode ID: %s.", leftOutputSizeInBytes, rightOutputSizeInBytes, joinNode.getId())); return Result.ofPlanNode(rewrittenNode.get()); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java index 975130d4307d0..7158a92de6592 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/AddLocalExchanges.java @@ -887,7 +887,7 @@ public PlanWithProperties visitIndexJoin(IndexJoinNode node, StreamPreferredProp parentPreferences.constrainTo(node.getProbeSource().getOutputVariables()).withDefaultParallelism(session)); // index source does not support local parallel and must produce a single stream - StreamProperties indexStreamProperties = derivePropertiesRecursively(node.getIndexSource(), metadata, session); + StreamProperties indexStreamProperties = derivePropertiesRecursively(node.getIndexSource(), metadata, session, nativeExecution); checkArgument(indexStreamProperties.getDistribution() == SINGLE, "index source must be single stream"); PlanWithProperties index = new PlanWithProperties(node.getIndexSource(), indexStreamProperties); @@ -983,12 +983,12 @@ private PlanWithProperties rebaseAndDeriveProperties(PlanNode node, List inputProperties) { - return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session)); + return new PlanWithProperties(result, StreamPropertyDerivations.deriveProperties(result, inputProperties, metadata, session, nativeExecution)); } private PlanWithProperties accept(PlanNode node, StreamPreferredProperties context) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java index 4b0e6b7564a7e..fc1735f7e7e70 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/MergeJoinForSortedInputOptimizer.java @@ -40,11 +40,13 @@ public class MergeJoinForSortedInputOptimizer implements PlanOptimizer { private final Metadata metadata; + private final boolean nativeExecution; private boolean isEnabledForTesting; - public MergeJoinForSortedInputOptimizer(Metadata metadata) + public MergeJoinForSortedInputOptimizer(Metadata metadata, boolean nativeExecution) { this.metadata = requireNonNull(metadata, "metadata is null"); + this.nativeExecution = nativeExecution; } @Override @@ -139,8 +141,8 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) private boolean meetsDataRequirement(PlanNode left, PlanNode right, JoinNode node) { // Acquire data properties for both left and right side - StreamPropertyDerivations.StreamProperties leftProperties = StreamPropertyDerivations.derivePropertiesRecursively(left, metadata, session); - StreamPropertyDerivations.StreamProperties rightProperties = StreamPropertyDerivations.derivePropertiesRecursively(right, metadata, session); + StreamPropertyDerivations.StreamProperties leftProperties = StreamPropertyDerivations.derivePropertiesRecursively(left, metadata, session, nativeExecution); + StreamPropertyDerivations.StreamProperties rightProperties = StreamPropertyDerivations.derivePropertiesRecursively(right, metadata, session, nativeExecution); List leftJoinColumns = node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); List rightJoinColumns = node.getCriteria().stream() diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java index 8dbaa277c1e19..d99f33439d1f5 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StreamPropertyDerivations.java @@ -41,6 +41,7 @@ import com.facebook.presto.spi.plan.WindowNode; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties.StreamDistribution; import com.facebook.presto.sql.planner.plan.ApplyNode; import com.facebook.presto.sql.planner.plan.AssignUniqueId; import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode; @@ -65,8 +66,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import javax.annotation.concurrent.Immutable; - import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -96,20 +95,20 @@ public final class StreamPropertyDerivations { private StreamPropertyDerivations() {} - public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session) + public static StreamProperties derivePropertiesRecursively(PlanNode node, Metadata metadata, Session session, boolean nativeExecution) { List inputProperties = node.getSources().stream() - .map(source -> derivePropertiesRecursively(source, metadata, session)) + .map(source -> derivePropertiesRecursively(source, metadata, session, nativeExecution)) .collect(toImmutableList()); - return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session); + return StreamPropertyDerivations.deriveProperties(node, inputProperties, metadata, session, nativeExecution); } - public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session) + public static StreamProperties deriveProperties(PlanNode node, StreamProperties inputProperties, Metadata metadata, Session session, boolean nativeExecution) { - return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session); + return deriveProperties(node, ImmutableList.of(inputProperties), metadata, session, nativeExecution); } - public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session) + public static StreamProperties deriveProperties(PlanNode node, List inputProperties, Metadata metadata, Session session, boolean nativeExecution) { requireNonNull(node, "node is null"); requireNonNull(inputProperties, "inputProperties is null"); @@ -127,7 +126,7 @@ public static StreamProperties deriveProperties(PlanNode node, List @@ -147,11 +146,13 @@ private static class Visitor { private final Metadata metadata; private final Session session; + private final boolean nativeExecution; - private Visitor(Metadata metadata, Session session) + private Visitor(Metadata metadata, Session session, boolean nativeExecution) { this.metadata = metadata; this.session = session; + this.nativeExecution = nativeExecution; } @Override @@ -291,13 +292,16 @@ public StreamProperties visitTableScan(TableScanNode node, List> streamPartitionSymbols = layout.getStreamPartitioningColumns() .flatMap(columns -> getNonConstantVariables(columns, assignments, constants)); + // Native execution creates a fixed number of drivers for TableScan pipelines + StreamDistribution streamDistribution = nativeExecution ? FIXED : MULTIPLE; + // if we are partitioned on empty set, we must say multiple of unknown partitioning, because // the connector does not guarantee a single split in this case (since it might not understand // that the value is a constant). if (streamPartitionSymbols.isPresent() && streamPartitionSymbols.get().isEmpty()) { - return new StreamProperties(MULTIPLE, Optional.empty(), false); + return new StreamProperties(streamDistribution, Optional.empty(), false); } - return new StreamProperties(MULTIPLE, streamPartitionSymbols, false); + return new StreamProperties(streamDistribution, streamPartitionSymbols, false); } private Optional> getNonConstantVariables(Set columnHandles, Map assignments, Set globalConstants) @@ -633,7 +637,6 @@ public StreamProperties visitRemoteSource(RemoteSourceNode node, List visitAggregation(AggregationNode node, Void conte if (!seenExchanges.localRepartitionExchange) { // No local repartition exchange between final and partial aggregation. // Make sure that final aggregation operators are executed by single thread. - StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session); + StreamProperties localProperties = StreamPropertyDerivations.derivePropertiesRecursively(node, metadata, session, nativeExecution); checkArgument(localProperties.isSingleStream(), "Final aggregation with default value not separated from partial aggregation by local hash exchange"); } diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java index 39ae75000ab4a..cfe040826bdcf 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingAggregations.java @@ -41,10 +41,17 @@ public class ValidateStreamingAggregations implements Checker { + private final boolean nativeExecution; + + public ValidateStreamingAggregations(boolean nativeExecution) + { + this.nativeExecution = nativeExecution; + } + @Override public void validate(PlanNode planNode, Session session, Metadata metadata, WarningCollector warningCollector) { - planNode.accept(new Visitor(session, metadata), null); + planNode.accept(new Visitor(session, metadata, nativeExecution), null); } private static final class Visitor @@ -52,11 +59,13 @@ private static final class Visitor { private final Session session; private final Metadata metadata; + private final boolean nativeExecution; - private Visitor(Session session, Metadata metadata) + private Visitor(Session session, Metadata metadata, boolean nativeExecution) { this.session = session; this.metadata = metadata; + this.nativeExecution = nativeExecution; } @Override @@ -73,7 +82,7 @@ public Void visitAggregation(AggregationNode node, Void context) return null; } - StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session); + StreamProperties properties = derivePropertiesRecursively(node.getSource(), metadata, session, nativeExecution); List> desiredProperties = ImmutableList.of(new GroupingProperty<>(node.getPreGroupedVariables())); Iterator>> matchIterator = LocalProperties.match(properties.getLocalProperties(), desiredProperties).iterator(); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingJoins.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingJoins.java index 52196b42ad564..9baf188be023b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingJoins.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/sanity/ValidateStreamingJoins.java @@ -100,7 +100,7 @@ public Void visitJoin(JoinNode node, Void context) else { requiredBuildProperty = singleStream(); } - StreamProperties buildProperties = derivePropertiesRecursively(node.getRight(), metadata, session); + StreamProperties buildProperties = derivePropertiesRecursively(node.getRight(), metadata, session, nativeExecutionEnabled); checkArgument(requiredBuildProperty.isSatisfiedBy(buildProperties), "Build side needs an additional local exchange for join: %s", node.getId()); StreamPreferredProperties requiredProbeProperty; @@ -110,7 +110,7 @@ public Void visitJoin(JoinNode node, Void context) else { requiredProbeProperty = defaultParallelism(session); } - StreamProperties probeProperties = derivePropertiesRecursively(node.getLeft(), metadata, session); + StreamProperties probeProperties = derivePropertiesRecursively(node.getLeft(), metadata, session, nativeExecutionEnabled); checkArgument(requiredProbeProperty.isSatisfiedBy(probeProperties), "Probe side needs an additional local exchange for join: %s", node.getId()); } return null; diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRuntimeReorderJoinSides.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRuntimeReorderJoinSides.java index f67b95de41c15..d5e4d97f6e17a 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRuntimeReorderJoinSides.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestRuntimeReorderJoinSides.java @@ -276,6 +276,6 @@ public void testFlipsAndAdjustExchangeWhenProbeSideSmaller() private RuleAssert assertReorderJoinSides() { - return tester.assertThat(new RuntimeReorderJoinSides(tester.getMetadata())); + return tester.assertThat(new RuntimeReorderJoinSides(tester.getMetadata(), false)); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java index 1f3d7ec8f109f..04c8bd65513fc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateAggregationsWithDefaultValues.java @@ -188,7 +188,7 @@ private void validatePlan(PlanNode root, boolean noExchange) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateAggregationsWithDefaultValues(noExchange).validate(root, session, metadata, WarningCollector.NOOP); + new ValidateAggregationsWithDefaultValues(noExchange, false).validate(root, session, metadata, WarningCollector.NOOP); return null; }); } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java index 831b210228da0..46028c7178011 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/sanity/TestValidateStreamingAggregations.java @@ -112,7 +112,8 @@ private void validatePlan(Function planProvider) getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - new ValidateStreamingAggregations().validate(planNode, session, metadata, WarningCollector.NOOP); + new ValidateStreamingAggregations(true).validate(planNode, session, metadata, WarningCollector.NOOP); + new ValidateStreamingAggregations(false).validate(planNode, session, metadata, WarningCollector.NOOP); return null; }); } diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java index b4d3e1053b042..03f708a061fe9 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/AdaptivePlanOptimizers.java @@ -17,6 +17,7 @@ import com.facebook.presto.cost.CostCalculator; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.sql.analyzer.FeaturesConfig; import com.facebook.presto.sql.planner.OptimizerStatsRecorder; import com.facebook.presto.sql.planner.RuleStatsRecorder; import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; @@ -47,10 +48,16 @@ public AdaptivePlanOptimizers( MBeanExporter exporter, Metadata metadata, StatsCalculator statsCalculator, - CostCalculator costCalculator) + CostCalculator costCalculator, + FeaturesConfig featuresConfig) { this.exporter = exporter; - this.adaptiveOptimizers = ImmutableList.of(new IterativeOptimizer(metadata, ruleStats, statsCalculator, costCalculator, ImmutableSet.of(new PickJoinSides(metadata)))); + this.adaptiveOptimizers = ImmutableList.of(new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of(new PickJoinSides(metadata, featuresConfig.isNativeExecutionEnabled())))); } @PostConstruct diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java index f2abc6b253640..12f600247f85b 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/planner/optimizers/PickJoinSides.java @@ -71,11 +71,13 @@ public class PickJoinSides // changing the distribution type too && !(joinNode.getCriteria().isEmpty() && (joinNode.getType() == LEFT || joinNode.getType() == RIGHT))); - private Metadata metadata; + private final Metadata metadata; + private final boolean nativeExecution; - public PickJoinSides(Metadata metadata) + public PickJoinSides(Metadata metadata, boolean nativeExecution) { this.metadata = requireNonNull(metadata, "metadata is null"); + this.nativeExecution = nativeExecution; } @Override @@ -101,7 +103,7 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) // if we don't have exact costs for the join, but based on source tables we think the left side // is very small or much smaller than the right, then flip the join. if (rightSize > leftSize || (isSizeBasedJoinDistributionTypeEnabled(context.getSession()) && (Double.isNaN(leftSize) || Double.isNaN(rightSize)) && isLeftSideSmall(joinNode, context))) { - rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, context.getLookup(), context.getSession(), context.getIdAllocator()); + rewrittenNode = createRuntimeSwappedJoinNode(joinNode, metadata, context.getLookup(), context.getSession(), context.getIdAllocator(), nativeExecution); } return rewrittenNode.map(Result::ofPlanNode).orElseGet(Result::empty); diff --git a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java index 30d93ce26eb80..c1565f3dbfb42 100644 --- a/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java +++ b/presto-spark-base/src/test/java/com/facebook/presto/spark/planner/optimizers/TestPickJoinSides.java @@ -381,7 +381,7 @@ public void testDoesNotFireWhenDisabled() { int aSize = 100; int bSize = 10_000; - tester.assertThat(new PickJoinSides(tester.getMetadata())) + tester.assertThat(new PickJoinSides(tester.getMetadata(), false)) .setSystemProperty(ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED, "false") .overrideStats("valuesA", PlanNodeStatsEstimate.builder() .setTotalSize(aSize) @@ -472,7 +472,7 @@ public void testDoesNotFireForRightCrossJoin() private RuleAssert assertPickJoinSides() { - return tester.assertThat(new PickJoinSides(tester.getMetadata())) + return tester.assertThat(new PickJoinSides(tester.getMetadata(), false)) .setSystemProperty(JOIN_MAX_BROADCAST_TABLE_SIZE, "100MB") .setSystemProperty(ADAPTIVE_JOIN_SIDE_SWITCHING_ENABLED, "true"); }