Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port most of the OAUTH2 support from Trino to Presto #24443

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 81 additions & 14 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,33 @@
</properties>

<dependencies>
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>okhttp</artifactId>
<version>3.9.0</version>
<exclusions>
<exclusion>
<groupId>org.jetbrains.kotlin</groupId>
<artifactId>kotlin-stdlib</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>net.jodah</groupId>
<artifactId>failsafe</artifactId>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
</dependency>

<dependency>
<groupId>com.esri.geometry</groupId>
<artifactId>esri-geometry-api</artifactId>
Expand Down Expand Up @@ -134,6 +161,12 @@
<dependency>
<groupId>com.facebook.airlift</groupId>
<artifactId>http-server</artifactId>
<exclusions>
<exclusion>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
</exclusion>
</exclusions>
</dependency>

<dependency>
Expand Down Expand Up @@ -357,20 +390,6 @@
<artifactId>jts-core</artifactId>
</dependency>

<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
</dependency>

<dependency>
<groupId>org.apache.datasketches</groupId>
<artifactId>datasketches-memory</artifactId>
Expand Down Expand Up @@ -514,6 +533,51 @@
<artifactId>mockwebserver</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
<version>9.14</version>
</dependency>

<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
<version>9.18</version>
<exclusions>
<exclusion>
<groupId>org.aw2</groupId>
<artifactId>asm</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-impl</artifactId>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<exclusions>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</exclusion>
</exclusions>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down Expand Up @@ -574,6 +638,9 @@
<ignorePackages>
<ignorePackage>com.facebook.presto.testing.assertions</ignorePackage>
</ignorePackages>
<ignoreClassNamePatterns>
<ignoreClassNamePattern>com/facebook/presto/server/MockHttpServletRequest</ignoreClassNamePattern>
</ignoreClassNamePatterns>
</configuration>
</plugin>
<plugin>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
import com.facebook.presto.security.AccessControlModule;
import com.facebook.presto.server.security.PasswordAuthenticatorManager;
import com.facebook.presto.server.security.PrestoAuthenticatorManager;
import com.facebook.presto.server.security.SecurityConfig;
import com.facebook.presto.server.security.ServerSecurityModule;
import com.facebook.presto.server.security.oauth2.OAuth2Client;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager;
Expand Down Expand Up @@ -83,6 +85,7 @@
import static com.facebook.airlift.json.JsonBinder.jsonBinder;
import static com.facebook.presto.server.PrestoSystemRequirements.verifyJvmRequirements;
import static com.facebook.presto.server.PrestoSystemRequirements.verifySystemTimeIsReasonable;
import static com.facebook.presto.server.security.SecurityConfig.AuthenticationType.OAUTH2;
import static com.google.common.base.Strings.nullToEmpty;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -147,7 +150,7 @@ public void run()

modules.addAll(getAdditionalModules());

Bootstrap app = new Bootstrap(modules.build());
Bootstrap app = new Bootstrap((Module) modules.build());

try {
Injector injector = app.initialize();
Expand Down Expand Up @@ -198,6 +201,11 @@ public void run()
injector.getInstance(ClientRequestFilterManager.class).loadClientRequestFilters();
startAssociatedProcesses(injector);

SecurityConfig securityConfig = injector.getInstance(SecurityConfig.class);
if (securityConfig.getAuthenticationTypes().contains(OAUTH2)) {
injector.getInstance(OAuth2Client.class).load();
}

injector.getInstance(Announcer.class).start();

log.info("======== SERVER STARTED ========");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
*/
package com.facebook.presto.server;

import com.facebook.presto.server.security.oauth2.OAuthWebUiCookie;

import javax.annotation.security.RolesAllowed;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
Expand All @@ -21,15 +23,19 @@
import javax.ws.rs.core.Response;
import javax.ws.rs.core.UriInfo;

import java.util.Optional;

import static com.facebook.presto.server.security.RoleType.ADMIN;
import static com.facebook.presto.server.security.oauth2.OAuth2Utils.getLastURLParameter;
import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.net.HttpHeaders.X_FORWARDED_PROTO;
import static javax.ws.rs.core.Response.Status.MOVED_PERMANENTLY;

@Path("/")
@Path(WebUiResource.UI_ENDPOINT)
@RolesAllowed(ADMIN)
public class WebUiResource
{
public static final String UI_ENDPOINT = "/";

@GET
public Response redirectIndexHtml(
@HeaderParam(X_FORWARDED_PROTO) String proto,
Expand All @@ -38,9 +44,30 @@ public Response redirectIndexHtml(
if (isNullOrEmpty(proto)) {
proto = uriInfo.getRequestUri().getScheme();
}
Optional<String> lastURL = getLastURLParameter(uriInfo.getQueryParameters());
if (lastURL.isPresent()) {
return Response
.seeOther(uriInfo.getRequestUriBuilder().scheme(proto).uri(lastURL.get()).build())
.build();
}

return Response
.seeOther(uriInfo.getRequestUriBuilder().scheme(proto).path("/ui/").replaceQuery("").build())
.build();
}

return Response.status(MOVED_PERMANENTLY)
.location(uriInfo.getRequestUriBuilder().scheme(proto).path("/ui/").build())
@GET
@Path("/logout")
public Response logout(
@HeaderParam(X_FORWARDED_PROTO) String proto,
@Context UriInfo uriInfo)
{
if (isNullOrEmpty(proto)) {
proto = uriInfo.getRequestUri().getScheme();
}
return Response
.seeOther(uriInfo.getBaseUriBuilder().scheme(proto).path("/ui/logout.html").build())
.cookie(OAuthWebUiCookie.delete())
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.server.security.oauth2.OAuth2Authenticator;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.PrestoException;
import com.google.common.base.Joiner;
Expand Down Expand Up @@ -46,7 +47,11 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.server.WebUiResource.UI_ENDPOINT;
import static com.facebook.presto.server.security.oauth2.OAuth2CallbackResource.CALLBACK_ENDPOINT;
import static com.facebook.presto.server.security.oauth2.OAuth2TokenExchangeResource.TOKEN_ENDPOINT;
import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT;
import static com.google.common.io.ByteStreams.copy;
import static com.google.common.io.ByteStreams.nullOutputStream;
Expand All @@ -61,32 +66,47 @@ public class AuthenticationFilter
implements Filter
{
private static final String HTTPS_PROTOCOL = "https";
private final List<Authenticator> authenticators;
private final boolean allowForwardedHttps;
private static List<Authenticator> authenticators;
private static boolean allowForwardedHttps;
private final ClientRequestFilterManager clientRequestFilterManager;
private final List<String> headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token");
private final WebUiAuthenticationManager webUiAuthenticationManager;
private final boolean isOauth2Enabled;

@Inject
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager)
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, WebUiAuthenticationManager webUiAuthenticationManager, ClientRequestFilterManager clientRequestFilterManager)
{
allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();

this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null"));
this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
this.webUiAuthenticationManager = requireNonNull(webUiAuthenticationManager, "webUiAuthenticationManager is null");
this.isOauth2Enabled = this.authenticators.stream()
.anyMatch(a -> a.getClass().equals(OAuth2Authenticator.class));
}

@Override
public void init(FilterConfig filterConfig) {}
public void init(FilterConfig filterConfig)
{
}

@Override
public void destroy() {}
public void destroy()
{
}

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain nextFilter)
throws IOException, ServletException
{
HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;

// Check if it's a request going to the web UI side.
if (isWebUiRequest(request) && isOauth2Enabled) {
// call web authenticator
this.webUiAuthenticationManager.handleRequest(request, response, nextFilter);
return;
}
// skip authentication if non-secure or not configured
if (!doesRequestSupportAuthentication(request)) {
nextFilter.doFilter(request, response);
Expand Down Expand Up @@ -118,6 +138,10 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
// authentication failed
skipRequestBody(request);

// Browsers have special handling for the BASIC challenge authenticate header so we need to filter them out if the WebUI Oauth Token is present.
if (isOauth2Enabled && OAuth2Authenticator.extractTokenFromCookie(request).isPresent()) {
authenticateHeaders = authenticateHeaders.stream().filter(value -> value.contains("x_token_server")).collect(Collectors.toSet());
}
for (String value : authenticateHeaders) {
response.addHeader(WWW_AUTHENTICATE, value);
}
Expand All @@ -140,6 +164,12 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
}
}

private boolean isWebUiRequest(HttpServletRequest request)
{
String pathInfo = request.getPathInfo();
return pathInfo == null || pathInfo.equals(UI_ENDPOINT) || pathInfo.startsWith("/ui");
}

public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principal principal)
{
List<ClientRequestFilter> clientRequestFilters = clientRequestFilterManager.getClientRequestFilters();
Expand Down Expand Up @@ -195,11 +225,10 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request)
return false;
}

private static ServletRequest withPrincipal(HttpServletRequest request, Principal principal)
public static ServletRequest withPrincipal(HttpServletRequest request, Principal principal)
{
requireNonNull(principal, "principal is null");
return new HttpServletRequestWrapper(request)
{
return new HttpServletRequestWrapper(request) {
@Override
public Principal getUserPrincipal()
{
Expand All @@ -208,6 +237,12 @@ public Principal getUserPrincipal()
};
}

public static boolean isPublic(HttpServletRequest request)
{
return request.getPathInfo().startsWith(TOKEN_ENDPOINT)
|| request.getPathInfo().startsWith(CALLBACK_ENDPOINT);
}

private static void skipRequestBody(HttpServletRequest request)
throws IOException
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.server.security;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import java.io.IOException;

public class DefaultWebUiAuthenticationManager
implements WebUiAuthenticationManager
{
@Override
public void handleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain nextFilter)
throws IOException, ServletException
{
nextFilter.doFilter(request, response);
}
}
Loading
Loading