Skip to content

Commit

Permalink
Add QueryType to AccessControlContext
Browse files Browse the repository at this point in the history
  • Loading branch information
RohanSidhu authored and NikhilCollooru committed Sep 27, 2024
1 parent 375dd5a commit 8385163
Show file tree
Hide file tree
Showing 17 changed files with 187 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
public class TestRangerBasedAccessControl
{
public static final ConnectorTransactionHandle TRANSACTION_HANDLE = new ConnectorTransactionHandle() {};
public static final AccessControlContext CONTEXT = new AccessControlContext(new QueryId("query_id"), Optional.empty(), Collections.emptySet(), Optional.empty(), WarningCollector.NOOP, new RuntimeStats());
public static final AccessControlContext CONTEXT = new AccessControlContext(new QueryId("query_id"), Optional.empty(), Collections.emptySet(), Optional.empty(), WarningCollector.NOOP, new RuntimeStats(), Optional.empty());

@Test
public void testTablePriviledgesRolesNotAllowed()
Expand Down
157 changes: 97 additions & 60 deletions presto-main/src/main/java/com/facebook/presto/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.resourceGroups.QueryType;
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.cost.PlanCostEstimate;
Expand Down Expand Up @@ -99,6 +100,7 @@ public final class Session
private final Optional<Tracer> tracer;
private final WarningCollector warningCollector;
private final RuntimeStats runtimeStats;
private final Optional<QueryType> queryType;

private final OptimizerInformationCollector optimizerInformationCollector = new OptimizerInformationCollector();
private final OptimizerResultCollector optimizerResultCollector = new OptimizerResultCollector();
Expand Down Expand Up @@ -131,7 +133,8 @@ public Session(
Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
Optional<Tracer> tracer,
WarningCollector warningCollector,
RuntimeStats runtimeStats)
RuntimeStats runtimeStats,
Optional<QueryType> queryType)
{
this.queryId = requireNonNull(queryId, "queryId is null");
this.transactionId = requireNonNull(transactionId, "transactionId is null");
Expand Down Expand Up @@ -172,7 +175,8 @@ public Session(
this.tracer = requireNonNull(tracer, "tracer is null");
this.warningCollector = requireNonNull(warningCollector, "warningCollector is null");
this.runtimeStats = requireNonNull(runtimeStats, "runtimeStats is null");
this.context = new AccessControlContext(queryId, clientInfo, clientTags, source, warningCollector, runtimeStats);
this.queryType = requireNonNull(queryType, "queryType is null");
this.context = new AccessControlContext(queryId, clientInfo, clientTags, source, warningCollector, runtimeStats, queryType);
}

public QueryId getQueryId()
Expand Down Expand Up @@ -353,6 +357,11 @@ public Map<PlanNodeId, PlanCostEstimate> getPlanNodeCostMap()
return planNodeCostMap;
}

public Optional<QueryType> getQueryType()
{
return queryType;
}

public Session beginTransactionId(TransactionId transactionId, TransactionManager transactionManager, AccessControl accessControl)
{
requireNonNull(transactionId, "transactionId is null");
Expand Down Expand Up @@ -447,63 +456,8 @@ public Session beginTransactionId(TransactionId transactionId, TransactionManage
sessionFunctions,
tracer,
warningCollector,
runtimeStats);
}

public Session withDefaultProperties(
SystemSessionPropertyConfiguration systemPropertyConfiguration,
Map<String, Map<String, String>> catalogPropertyDefaults)
{
requireNonNull(systemPropertyConfiguration, "systemPropertyConfiguration is null");
requireNonNull(catalogPropertyDefaults, "catalogPropertyDefaults is null");

// to remove this check properties must be authenticated and validated as in beginTransactionId
checkState(
!this.transactionId.isPresent() && this.connectorProperties.isEmpty(),
"Session properties cannot be overridden once a transaction is active");

Map<String, String> systemProperties = new HashMap<>();
systemProperties.putAll(systemPropertyConfiguration.systemPropertyDefaults);
systemProperties.putAll(this.systemProperties);
systemProperties.putAll(systemPropertyConfiguration.systemPropertyOverrides);

Map<String, Map<String, String>> connectorProperties = catalogPropertyDefaults.entrySet().stream()
.map(entry -> Maps.immutableEntry(entry.getKey(), new HashMap<>(entry.getValue())))
.collect(Collectors.toMap(Entry::getKey, Entry::getValue));
for (Entry<String, Map<String, String>> catalogProperties : this.unprocessedCatalogProperties.entrySet()) {
String catalog = catalogProperties.getKey();
for (Entry<String, String> entry : catalogProperties.getValue().entrySet()) {
connectorProperties.computeIfAbsent(catalog, id -> new HashMap<>())
.put(entry.getKey(), entry.getValue());
}
}

return new Session(
queryId,
transactionId,
clientTransactionSupport,
identity,
source,
catalog,
schema,
traceToken,
timeZoneKey,
locale,
remoteUserAddress,
userAgent,
clientInfo,
clientTags,
resourceEstimates,
startTime,
systemProperties,
ImmutableMap.of(),
connectorProperties,
sessionPropertyManager,
preparedStatements,
sessionFunctions,
tracer,
warningCollector,
runtimeStats);
runtimeStats,
queryType);
}

public ConnectorSession toConnectorSession()
Expand Down Expand Up @@ -630,6 +584,7 @@ public static class SessionBuilder
private final SessionPropertyManager sessionPropertyManager;
private final Map<String, String> preparedStatements = new HashMap<>();
private final Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions = new HashMap<>();
private Optional<QueryType> queryType = Optional.empty();
private WarningCollector warningCollector = WarningCollector.NOOP;
private RuntimeStats runtimeStats = new RuntimeStats();

Expand Down Expand Up @@ -665,6 +620,7 @@ private SessionBuilder(Session session)
this.tracer = requireNonNull(session.tracer, "tracer is null");
this.warningCollector = requireNonNull(session.warningCollector, "warningCollector is null");
this.runtimeStats = requireNonNull(session.runtimeStats, "runtimeStats is null");
this.queryType = requireNonNull(session.queryType, "queryType is null");
}

public SessionBuilder setQueryId(QueryId queryId)
Expand Down Expand Up @@ -821,11 +777,57 @@ public SessionBuilder setRuntimeStats(RuntimeStats runtimeStats)
return this;
}

public SessionBuilder setQueryType(Optional<QueryType> queryType)
{
this.queryType = requireNonNull(queryType, "queryType is null");
return this;
}

public <T> T getSystemProperty(String name, Class<T> type)
{
return sessionPropertyManager.decodeSystemPropertyValue(name, systemProperties.get(name), type);
}

public WarningCollector getWarningCollector()
{
return this.warningCollector;
}

public Map<String, String> getPreparedStatements()
{
return this.preparedStatements;
}

public Identity getIdentity()
{
return this.identity;
}

public Optional<String> getSource()
{
return Optional.ofNullable(this.source);
}

public Set<String> getClientTags()
{
return this.clientTags;
}

public Optional<String> getClientInfo()
{
return Optional.ofNullable(this.clientInfo);
}

public Map<String, String> getSystemProperties()
{
return this.systemProperties;
}

public Map<String, Map<String, String>> getUnprocessedCatalogProperties()
{
return this.catalogSessionProperties;
}

public Session build()
{
return new Session(
Expand Down Expand Up @@ -853,7 +855,42 @@ public Session build()
sessionFunctions,
tracer,
warningCollector,
runtimeStats);
runtimeStats,
queryType);
}

public void applyDefaultProperties(SystemSessionPropertyConfiguration systemPropertyConfiguration, Map<String, Map<String, String>> catalogPropertyDefaults)
{
requireNonNull(systemPropertyConfiguration, "systemPropertyConfiguration is null");
requireNonNull(catalogPropertyDefaults, "catalogPropertyDefaults is null");

// to remove this check properties must be authenticated and validated as in beginTransactionId
checkState(
this.transactionId == null && this.connectorProperties.isEmpty(),
"Session properties cannot be overridden once a transaction is active");

Map<String, String> systemProperties = new HashMap<>();
systemProperties.putAll(systemPropertyConfiguration.systemPropertyDefaults);
systemProperties.putAll(this.systemProperties);
systemProperties.putAll(systemPropertyConfiguration.systemPropertyOverrides);
this.systemProperties.putAll(systemProperties);

Map<String, Map<String, String>> connectorProperties = catalogPropertyDefaults.entrySet().stream()
.map(entry -> Maps.immutableEntry(entry.getKey(), new HashMap<>(entry.getValue())))
.collect(Collectors.toMap(Entry::getKey, Entry::getValue));
for (Entry<String, Map<String, String>> catalogProperties : this.catalogSessionProperties.entrySet()) {
String catalog = catalogProperties.getKey();
for (Entry<String, String> entry : catalogProperties.getValue().entrySet()) {
connectorProperties.computeIfAbsent(catalog, id -> new HashMap<>()).put(entry.getKey(), entry.getValue());
}
}

for (Entry<String, Map<String, String>> catalogProperties : connectorProperties.entrySet()) {
String catalog = catalogProperties.getKey();
for (Entry<String, String> entry : catalogProperties.getValue().entrySet()) {
setCatalogSessionProperty(catalog, entry.getKey(), entry.getValue());
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ public Session toSession(SessionPropertyManager sessionPropertyManager, Map<Stri
Optional.empty(),
// we use NOOP to create a session from the representation as worker does not require warning collectors
WarningCollector.NOOP,
new RuntimeStats());
new RuntimeStats(),
Optional.empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import java.util.Optional;
import java.util.concurrent.Executor;

import static com.facebook.presto.Session.SessionBuilder;
import static com.facebook.presto.SystemSessionProperties.getAnalyzerType;
import static com.facebook.presto.spi.StandardErrorCode.QUERY_TEXT_TOO_LARGE;
import static com.facebook.presto.util.AnalyzerUtil.createAnalyzerOptions;
Expand Down Expand Up @@ -259,6 +260,7 @@ public ListenableFuture<?> createQuery(QueryId queryId, String slug, int retryCo
private <C> void createQueryInternal(QueryId queryId, String slug, int retryCount, SessionContext sessionContext, String query, ResourceGroupManager<C> resourceGroupManager)
{
Session session = null;
SessionBuilder sessionBuilder = null;
PreparedQuery preparedQuery;
try {
if (query.length() > maxQueryLength) {
Expand All @@ -268,16 +270,18 @@ private <C> void createQueryInternal(QueryId queryId, String slug, int retryCoun
}

// decode session
session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory);
sessionBuilder = sessionSupplier.createSessionBuilder(queryId, sessionContext, warningCollectorFactory);
session = sessionBuilder.build();

// prepare query
AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, session.getWarningCollector());
AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, sessionBuilder.getWarningCollector());
QueryPreparerProvider queryPreparerProvider = queryPreparerProviderManager.getQueryPreparerProvider(getAnalyzerType(session));
preparedQuery = queryPreparerProvider.getQueryPreparer().prepareQuery(analyzerOptions, query, session.getPreparedStatements(), session.getWarningCollector());
preparedQuery = queryPreparerProvider.getQueryPreparer().prepareQuery(analyzerOptions, query, sessionBuilder.getPreparedStatements(), sessionBuilder.getWarningCollector());
query = preparedQuery.getFormattedQuery().orElse(query);

// select resource group
Optional<QueryType> queryType = preparedQuery.getQueryType();
sessionBuilder.setQueryType(queryType);
SelectionContext<C> selectionContext = resourceGroupManager.selectGroup(new SelectionCriteria(
sessionContext.getIdentity().getPrincipal().isPresent(),
sessionContext.getIdentity().getUser(),
Expand All @@ -290,7 +294,12 @@ private <C> void createQueryInternal(QueryId queryId, String slug, int retryCoun
sessionContext.getIdentity().getPrincipal().map(Principal::getName)));

// apply system default session properties (does not override user set properties)
session = sessionPropertyDefaults.newSessionWithDefaultProperties(session, queryType.map(Enum::name), Optional.of(selectionContext.getResourceGroupId()));
sessionPropertyDefaults.applyDefaultProperties(sessionBuilder, queryType.map(Enum::name), Optional.of(selectionContext.getResourceGroupId()));

session = sessionBuilder.build();
if (sessionContext.getTransactionId().isPresent()) {
session = session.beginTransactionId(sessionContext.getTransactionId().get(), transactionManager, accessControl);
}

// mark existing transaction as active
transactionManager.activateTransaction(session, preparedQuery.isTransactionControlStatement(), accessControl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ public static void checkPermissions(AccessControl accessControl, SecurityConfig
sessionContext.getClientTags(),
Optional.ofNullable(sessionContext.getSource()),
WarningCollector.NOOP,
sessionContext.getRuntimeStats()),
sessionContext.getRuntimeStats(),
Optional.empty()),
identity.getPrincipal(),
identity.getUser());
}
Expand All @@ -71,7 +72,8 @@ public static Optional<AuthorizedIdentity> getAuthorizedIdentity(AccessControl a
sessionContext.getClientTags(),
Optional.ofNullable(sessionContext.getSource()),
WarningCollector.NOOP,
sessionContext.getRuntimeStats()),
sessionContext.getRuntimeStats(),
Optional.empty()),
identity.getUser(),
sessionContext.getCertificates());
return Optional.of(authorizedIdentity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.facebook.presto.execution.warnings.WarningCollectorFactory;
import com.facebook.presto.spi.QueryId;

import static com.facebook.presto.Session.SessionBuilder;

/**
* Used on workers.
*/
Expand All @@ -28,4 +30,10 @@ public Session createSession(QueryId queryId, SessionContext context, WarningCol
{
throw new UnsupportedOperationException();
}

@Override
public SessionBuilder createSessionBuilder(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory)
{
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ public QuerySessionSupplier(

@Override
public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory)
{
Session session = createSessionBuilder(queryId, context, warningCollectorFactory).build();
if (context.getTransactionId().isPresent()) {
session = session.beginTransactionId(context.getTransactionId().get(), transactionManager, accessControl);
}
return session;
}

@Override
public SessionBuilder createSessionBuilder(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory)
{
SessionBuilder sessionBuilder = Session.builder(sessionPropertyManager)
.setQueryId(queryId)
Expand Down Expand Up @@ -128,11 +138,7 @@ else if (context.getTimeZoneId() != null) {
WarningCollector warningCollector = warningCollectorFactory.create(sessionBuilder.getSystemProperty(WARNING_HANDLING, WarningHandlingLevel.class));
sessionBuilder.setWarningCollector(warningCollector);

Session session = sessionBuilder.build();
if (context.getTransactionId().isPresent()) {
session = session.beginTransactionId(context.getTransactionId().get(), transactionManager, accessControl);
}
return session;
return sessionBuilder;
}

private Identity authenticateIdentity(QueryId queryId, SessionContext context)
Expand Down
Loading

0 comments on commit 8385163

Please sign in to comment.