Skip to content

Commit

Permalink
RANGER-4955: Add support to retrieve group information from JWT
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaab committed Oct 9, 2024
1 parent f06d0e7 commit 5614d24
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 64 deletions.
6 changes: 6 additions & 0 deletions ranger-authn/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@
<version>${nimbus-jose-jwt.version}</version>
</dependency>

<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-core</artifactId>
<version>${springframework.security.version}</version>
</dependency>

<!-- Test -->
<dependency>
<groupId>org.junit.jupiter</groupId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ranger.authz.authority;

import java.util.Set;
import org.springframework.security.core.GrantedAuthority;

public final class JwtAuthority implements GrantedAuthority {
private static final long serialVersionUID = 12323L;
private final String role;
private final Set<String> groups;

public JwtAuthority(String role, Set<String> groups) {
this.role = role;
this.groups = groups;
}

public String getAuthority() {
return this.role;
}

public Set<String> getGroups() { return this.groups; }

public String toString() {
return this.role;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import java.net.URL;
import java.text.ParseException;
import java.util.Set;
import java.util.HashSet;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
Expand All @@ -43,20 +45,25 @@
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.SignedJWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;

public abstract class RangerJwtAuthHandler implements RangerAuthHandler {
private static final Logger LOG = LoggerFactory.getLogger(RangerJwtAuthHandler.class);

private JWSVerifier verifier = null;
protected SignedJWT signedJWT = null;
private String jwksProviderUrl = null;
public static final String TYPE = "ranger-jwt"; // Constant that identifies the authentication mechanism.
public static final String KEY_PROVIDER_URL = "jwks.provider-url"; // JWKS provider URL
public static final String KEY_JWT_PUBLIC_KEY = "jwt.public-key"; // JWT token provider public key
public static final String KEY_JWT_COOKIE_NAME = "jwt.cookie-name"; // JWT cookie name
public static final String KEY_JWT_AUDIENCES = "jwt.audiences";
public static final String JWT_AUTHZ_PREFIX = "Bearer ";
public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM = "custom.jwt.claim.group.key";
public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT = "knox.groups";

public String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = null;
protected List<String> audiences = null;
protected JWKSource<SecurityContext> keySource = null;

Expand All @@ -76,6 +83,7 @@ public void initialize(final Properties config) throws Exception {

// optional configurations
String pemPublicKey = config.getProperty(KEY_JWT_PUBLIC_KEY);
CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = config.getProperty(CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM, CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT);

// setup JWT provider public key if configured
if (StringUtils.isNotBlank(pemPublicKey)) {
Expand Down Expand Up @@ -112,30 +120,36 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str

if (StringUtils.isNotBlank(serializedJWT)) {
try {
final SignedJWT jwtToken = SignedJWT.parse(serializedJWT);
boolean valid = validateToken(jwtToken);
signedJWT = SignedJWT.parse(serializedJWT);
JWTClaimsSet claimsSet = getJWTClaimsSet();

if(LOG.isDebugEnabled()){
LOG.debug("RangerJwtAuthHandler.authenticate(): JWTClaimsSet - {}", claimsSet);
}

boolean valid = validateToken();
if (valid) {
String userName;

if (StringUtils.isNotBlank(doAsUser)) {
userName = doAsUser.trim();
} else {
userName = jwtToken.getJWTClaimsSet().getSubject();
userName = claimsSet.getSubject();
}

if (LOG.isDebugEnabled()) {
LOG.debug("RangerJwtAuthHandler.authenticate(): Issuing AuthenticationToken for user: [{}]", userName);
LOG.debug("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]", jwtToken.getJWTClaimsSet().getSubject(), doAsUser);
LOG.debug("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]", claimsSet.getSubject(), doAsUser);
}
token = new AuthenticationToken(userName, userName, TYPE);
} else {
LOG.warn("RangerJwtAuthHandler.authenticate(): Validation failed for JWT token: [{}] ", jwtToken.serialize());
LOG.warn("RangerJwtAuthHandler.authenticate(): Validation failed for JWT: [{}] ", signedJWT.serialize());
}
} catch (ParseException pe) {
LOG.warn("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT token", pe);
} catch (ParseException | RuntimeException exp) {
LOG.warn("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT", exp);
}
} else {
LOG.warn("RangerJwtAuthHandler.authenticate(): JWT token not found.");
LOG.warn("RangerJwtAuthHandler.authenticate(): JWT not found");
}
}

Expand All @@ -145,6 +159,31 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str

return token;
}

protected JWTClaimsSet getJWTClaimsSet() throws ParseException {
return signedJWT.getJWTClaimsSet();
}

public Set<String> getGroupsFromClaimSet() {
List<String> groupsClaim = null;
try {
groupsClaim = (List<String>) getJWTClaimsSet().getClaim(CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE);
} catch (ParseException e) {
LOG.error("Unable to parse JWT claim set", e);
}

if (groupsClaim == null) {
LOG.warn("No group claim found!");
return new HashSet<>();
}

Set<String> groups = new HashSet<>(groupsClaim);
if (LOG.isDebugEnabled()) {
LOG.debug("Groups present in Claim [{}]: {}", CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE, groups);
}
return groups;
}


protected String getJWT(final String jwtAuthHeader, final String jwtCookie) {
String serializedJWT = null;
Expand All @@ -171,19 +210,18 @@ protected String getJWT(final String jwtAuthHeader, final String jwtCookie) {
* implementation through submethods used within but also allows for the
* override of the entire token validation algorithm.
*
* @param jwtToken the token to validate
* @return true if valid
*/
protected boolean validateToken(final SignedJWT jwtToken) {
boolean expValid = validateExpiration(jwtToken);
protected boolean validateToken() throws ParseException {
boolean expValid = validateExpiration();
boolean sigValid = false;
boolean audValid = false;

if (expValid) {
sigValid = validateSignature(jwtToken);
sigValid = validateSignature();

if (sigValid) {
audValid = validateAudiences(jwtToken);
audValid = validateAudiences();
}
}

Expand All @@ -195,41 +233,40 @@ protected boolean validateToken(final SignedJWT jwtToken) {
}

/**
* Verify the signature of the JWT token in this method. This method depends on
* Verify the signature of the JWT in this method. This method depends on
* the public key that was established during init based upon the provisioned
* public key. Override this method in subclasses in order to customize the
* signature verification behavior.
*
* @param jwtToken the token that contains the signature to be validated
* @return valid true if signature verifies successfully; false otherwise
*/
protected boolean validateSignature(final SignedJWT jwtToken) {
protected boolean validateSignature() {
boolean valid = false;

if (JWSObject.State.SIGNED == jwtToken.getState()) {
if (JWSObject.State.SIGNED == signedJWT.getState()) {
if (LOG.isDebugEnabled()) {
LOG.debug("JWT token is in a SIGNED state");
LOG.debug("JWT is in a SIGNED state");
}

if (jwtToken.getSignature() != null) {
if (signedJWT.getSignature() != null) {
try {
if (StringUtils.isNotBlank(jwksProviderUrl)) {
JWSKeySelector<SecurityContext> keySelector = new JWSVerificationKeySelector<>(jwtToken.getHeader().getAlgorithm(), keySource);
JWSKeySelector<SecurityContext> keySelector = new JWSVerificationKeySelector<>(signedJWT.getHeader().getAlgorithm(), keySource);

// Create a JWT processor for the access tokens
ConfigurableJWTProcessor<SecurityContext> jwtProcessor = getJwtProcessor(keySelector);

// Process the token
jwtProcessor.process(jwtToken, null);
jwtProcessor.process(signedJWT, null);
valid = true;
if (LOG.isDebugEnabled()) {
LOG.debug("JWT token has been successfully verified.");
LOG.debug("JWT has been successfully verified.");
}
} else if (verifier != null) {
if (jwtToken.verify(verifier)) {
if (signedJWT.verify(verifier)) {
valid = true;
if (LOG.isDebugEnabled()) {
LOG.debug("JWT token has been successfully verified.");
LOG.debug("JWT has been successfully verified.");
}
} else {
LOG.warn("JWT signature verification failed.");
Expand Down Expand Up @@ -257,61 +294,51 @@ protected boolean validateSignature(final SignedJWT jwtToken) {
* token claims list for audience. Override this method in subclasses in order
* to customize the audience validation behavior.
*
* @param jwtToken the JWT token where the allowed audiences will be found
* @return true if an expected audience is present, otherwise false
*/
protected boolean validateAudiences(final SignedJWT jwtToken) {
protected boolean validateAudiences() throws ParseException {
boolean valid = false;
try {
List<String> tokenAudienceList = jwtToken.getJWTClaimsSet().getAudience();
// if there were no expected audiences configured then just
// consider any audience acceptable
if (audiences == null) {
valid = true;
} else {
// if any of the configured audiences is found then consider it
// acceptable
for (String aud : tokenAudienceList) {
if (audiences.contains(aud)) {
if (LOG.isDebugEnabled()) {
LOG.debug("JWT token audience has been successfully validated.");
}
valid = true;
break;
JWTClaimsSet claimsSet = getJWTClaimsSet();
List<String> tokenAudienceList = claimsSet.getAudience();
// if there were no expected audiences configured then just consider any audience acceptable
if (audiences == null) {
valid = true;
} else {
// if any of the configured audiences is found then consider it acceptable
for (String aud : tokenAudienceList) {
if (audiences.contains(aud)) {
if (LOG.isDebugEnabled()) {
LOG.debug("JWT audience has been successfully validated.");
}
}
if (!valid) {
LOG.warn("JWT audience validation failed.");
valid = true;
break;
}
}
} catch (ParseException pe) {
LOG.warn("Unable to parse the JWT token.", pe);
}

if (!valid) {
LOG.warn("JWT audience validation failed.");
}
return valid;
}

/**
* Validate that the expiration time of the JWT token has not been violated. If
* it has then throw an AuthenticationException. Override this method in
* Validate that the expiration time of the JWT has not been violated. If
* it has, then throw an AuthenticationException. Override this method in
* subclasses in order to customize the expiration validation behavior.
*
* @param jwtToken the token that contains the expiration date to validate
* @return valid true if the token has not expired; false otherwise
*/
protected boolean validateExpiration(final SignedJWT jwtToken) {
protected boolean validateExpiration() throws ParseException {
boolean valid = false;
try {
Date expires = jwtToken.getJWTClaimsSet().getExpirationTime();
if (expires == null || new Date().before(expires)) {
valid = true;
if (LOG.isDebugEnabled()) {
LOG.debug("JWT token expiration date has been successfully validated.");
}
} else {
LOG.warn("JWT token provided is expired.");
Date expires = getJWTClaimsSet().getExpirationTime();
if (expires == null || new Date().before(expires)) {
valid = true;
if (LOG.isDebugEnabled()) {
LOG.debug("JWT expiration date has been successfully validated.");
}
} catch (ParseException pe) {
LOG.warn("Failed to validate JWT expiry.", pe);
} else {
LOG.warn("JWT provided has expired.");
}

return valid;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.ranger.security.web.filter;

import java.io.IOException;
import java.util.Set;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
Expand All @@ -33,6 +34,7 @@
import javax.servlet.http.HttpServletRequest;

import org.apache.log4j.Logger;
import org.apache.ranger.authz.authority.JwtAuthority;
import org.apache.ranger.authz.handler.RangerAuth;
import org.apache.ranger.authz.handler.jwt.RangerDefaultJwtAuthHandler;
import org.apache.ranger.authz.handler.jwt.RangerJwtAuthHandler;
Expand All @@ -42,7 +44,6 @@
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
Expand Down Expand Up @@ -101,7 +102,8 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
RangerAuth rangerAuth = authenticate(httpServletRequest);

if (rangerAuth != null) {
final List<GrantedAuthority> grantedAuths = Arrays.asList(new SimpleGrantedAuthority(DEFAULT_RANGER_ROLE));
final Set<String> groups = getGroupsFromClaimSet();
final List<GrantedAuthority> grantedAuths = Arrays.asList(new JwtAuthority(DEFAULT_RANGER_ROLE, groups));
final UserDetails principal = new User(rangerAuth.getUserName(), "", grantedAuths);
final Authentication finalAuthentication = new UsernamePasswordAuthenticationToken(principal, "", grantedAuths);
final WebAuthenticationDetails webDetails = new WebAuthenticationDetails(httpServletRequest);
Expand Down

0 comments on commit 5614d24

Please sign in to comment.