Skip to content

Commit

Permalink
Pass ConnectorSession to PlanChecker.validate
Browse files Browse the repository at this point in the history
This allows the planchecker to pass the ConnectorSession to any
connector calls that it makes.
  • Loading branch information
rschlussel committed Feb 18, 2025
1 parent 4435387 commit 7bdad6d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void validateFinalPlan(PlanNode planNode, Session session, Metadata metad
checkers.get(Stage.FINAL).forEach(checker -> checker.validate(planNode, session, metadata, warningCollector));
for (PlanCheckerProvider provider : planCheckerProviderManager.getPlanCheckerProviders()) {
for (com.facebook.presto.spi.plan.PlanChecker checker : provider.getFinalPlanCheckers()) {
checker.validate(planNode, warningCollector);
checker.validate(planNode, warningCollector, session.toConnectorSession());
}
}
}
Expand All @@ -97,7 +97,7 @@ public void validateIntermediatePlan(PlanNode planNode, Session session, Metadat
checkers.get(Stage.INTERMEDIATE).forEach(checker -> checker.validate(planNode, session, metadata, warningCollector));
for (PlanCheckerProvider provider : planCheckerProviderManager.getPlanCheckerProviders()) {
for (com.facebook.presto.spi.plan.PlanChecker checker : provider.getIntermediatePlanCheckers()) {
checker.validate(planNode, warningCollector);
checker.validate(planNode, warningCollector, session.toConnectorSession());
}
}
}
Expand All @@ -107,15 +107,18 @@ public void validatePlanFragment(PlanFragment planFragment, Session session, Met
checkers.get(Stage.FRAGMENT).forEach(checker -> checker.validateFragment(planFragment, session, metadata, warningCollector));
for (PlanCheckerProvider provider : planCheckerProviderManager.getPlanCheckerProviders()) {
for (com.facebook.presto.spi.plan.PlanChecker checker : provider.getFragmentPlanCheckers()) {
checker.validateFragment(new SimplePlanFragment(
planFragment.getId(),
planFragment.getRoot(),
planFragment.getVariables(),
planFragment.getPartitioning(),
planFragment.getTableScanSchedulingOrder(),
planFragment.getPartitioningScheme(),
planFragment.getStageExecutionDescriptor(),
planFragment.isOutputTableWriterFragment()), warningCollector);
checker.validateFragment(
new SimplePlanFragment(
planFragment.getId(),
planFragment.getRoot(),
planFragment.getVariables(),
planFragment.getPartitioning(),
planFragment.getTableScanSchedulingOrder(),
planFragment.getPartitioningScheme(),
planFragment.getStageExecutionDescriptor(),
planFragment.isOutputTableWriterFragment()),
warningCollector,
session.toConnectorSession());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.TableHandle;
Expand Down Expand Up @@ -61,13 +62,13 @@ public NativePlanChecker(NodeManager nodeManager, JsonCodec<SimplePlanFragment>
}

@Override
public void validate(PlanNode planNode, WarningCollector warningCollector)
public void validate(PlanNode planNode, WarningCollector warningCollector, ConnectorSession session)
{
// NO-OP, only validating fragments
}

@Override
public void validateFragment(SimplePlanFragment planFragment, WarningCollector warningCollector)
public void validateFragment(SimplePlanFragment planFragment, WarningCollector warningCollector, ConnectorSession session)
{
if (!planFragment.getPartitioning().isCoordinatorOnly() && !isInternalSystemConnector(planFragment.getRoot())) {
runValidation(planFragment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,15 @@ public void testNativePlanMockValidate()
PlanConversionResponse responseOk = new PlanConversionResponse(ImmutableList.of());
String responseOkString = PLAN_CONVERSION_RESPONSE_JSON_CODEC.toJson(responseOk);
server.enqueue(new MockResponse().setBody(responseOkString));
checker.validateFragment(fragment, null);
checker.validateFragment(fragment, null, null);

String errorMessage = "native conversion error";
ErrorCode errorCode = StandardErrorCode.NOT_SUPPORTED.toErrorCode();
PlanConversionResponse responseError = new PlanConversionResponse(ImmutableList.of(new PlanConversionFailureInfo("MockError", errorMessage, null, ImmutableList.of(), ImmutableList.of(), errorCode)));
String responseErrorString = PLAN_CONVERSION_RESPONSE_JSON_CODEC.toJson(responseError);
server.enqueue(new MockResponse().setResponseCode(500).setBody(responseErrorString));
PrestoException error = expectThrows(PrestoException.class,
() -> checker.validateFragment(fragment, null));
() -> checker.validateFragment(fragment, null, null));
assertEquals(error.getErrorCode(), errorCode);
assertTrue(error.getMessage().contains(errorMessage));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@

package com.facebook.presto.spi.plan;

import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.WarningCollector;

public interface PlanChecker
{
void validate(PlanNode planNode, WarningCollector warningCollector);
void validate(PlanNode planNode, WarningCollector warningCollector, ConnectorSession session);

default void validateFragment(SimplePlanFragment planFragment, WarningCollector warningCollector)
default void validateFragment(SimplePlanFragment planFragment, WarningCollector warningCollector, ConnectorSession session)
{
validate(planFragment.getRoot(), warningCollector);
validate(planFragment.getRoot(), warningCollector, session);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package com.facebook.presto.execution;

import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.WarningCollector;
Expand All @@ -36,7 +37,7 @@ public TriggerFailurePlanChecker(AtomicBoolean triggerValidationFailure)
}

@Override
public void validate(PlanNode planNode, WarningCollector warningCollector)
public void validate(PlanNode planNode, WarningCollector warningCollector, ConnectorSession session)
{
if (triggerValidationFailure.get()) {
throw new PrestoException(FAILURE_ERROR_CODE, FAILURE_MESSAGE);
Expand Down

0 comments on commit 7bdad6d

Please sign in to comment.