Add RPC header for access token #7803

Open · wants to merge 6 commits into base: trunk

Changes from all commits
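At a high level, the change carries an opaque token from a client-side thread-local (AuthorizationContext) into RpcRequestHeaderProto.authorizationHeader, then restores it into a server-side thread-local for the duration of the call. Below is a minimal client-side sketch of the intended usage, assuming an HDFS FileSystem as the illustrative RPC client; any RPC issued while the header is set would carry it.

import java.nio.charset.StandardCharsets;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.security.AuthorizationContext;

public class AuthHeaderClientSketch {
  public static void main(String[] args) throws Exception {
    FileSystem fs = FileSystem.get(new Configuration());
    byte[] token = "opaque-access-token".getBytes(StandardCharsets.UTF_8);
    AuthorizationContext.setCurrentAuthorizationHeader(token);
    try {
      // Any RPC made on this thread now carries the header:
      // ProtoUtil.makeRpcRequestHeader reads the thread-local and copies it
      // into RpcRequestHeaderProto.authorizationHeader (see the hunks below).
      fs.mkdirs(new Path("/tmp/authz-demo"));
    } finally {
      // Clear so unrelated calls on this thread do not reuse the token.
      AuthorizationContext.clear();
    }
  }
}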
@@ -139,6 +139,7 @@
import org.apache.hadoop.thirdparty.protobuf.Message;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hadoop.security.AuthorizationContext;

/** An abstract IPC service. IPC calls take a single {@link Writable} as a
* parameter, and return a {@link Writable} as their value. A service runs on
@@ -1004,6 +1005,7 @@ public static class Call implements Schedulable,
final byte[] clientId;
private final Span span; // the trace span on the server side
private final CallerContext callerContext; // the call context
private final byte[] authHeader; // opaque authorization header, may be null
private boolean deferredResponse = false;
private int priorityLevel;
// the priority level assigned by scheduler, 0 by default
@@ -1035,6 +1037,11 @@ public Call(int id, int retryCount, Void ignore1, Void ignore2,

Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId,
Span span, CallerContext callerContext) {
this(id, retryCount, kind, clientId, span, callerContext, null);
}

Call(int id, int retryCount, RPC.RpcKind kind, byte[] clientId,
Span span, CallerContext callerContext, byte[] authHeader) {
this.callId = id;
this.retryCount = retryCount;
this.timestampNanos = Time.monotonicNowNanos();
@@ -1043,6 +1050,7 @@ public Call(int id, int retryCount, Void ignore1, Void ignore2,
this.clientId = clientId;
this.span = span;
this.callerContext = callerContext;
this.authHeader = authHeader;
this.clientStateId = Long.MIN_VALUE;
this.isCallCoordinated = false;
}
@@ -1243,7 +1251,14 @@ private class RpcCall extends Call {
RpcCall(Connection connection, int id, int retryCount,
Writable param, RPC.RpcKind kind, byte[] clientId,
Span span, CallerContext context) {
this(connection, id, retryCount, param, kind, clientId,
span, context, new byte[0]);
}

RpcCall(Connection connection, int id, int retryCount,
Writable param, RPC.RpcKind kind, byte[] clientId,
Span span, CallerContext context, byte[] authHeader) {
super(id, retryCount, kind, clientId, span, context, authHeader);
this.connection = connection;
this.rpcRequest = param;
}
@@ -2975,51 +2990,61 @@ private void processRpcRequest(RpcRequestHeaderProto header,
.build();
}

// Extract the authorization header from the request, if present; it rides
// on the Call so the handler thread can install it in AuthorizationContext.
byte[] authHeader = null;
try {
  if (header.hasAuthorizationHeader()) {
    authHeader = header.getAuthorizationHeader().toByteArray();
  }

  RpcCall call = new RpcCall(this, header.getCallId(),
      header.getRetryCount(), rpcRequest,
      ProtoUtil.convert(header.getRpcKind()),
      header.getClientId().toByteArray(), span, callerContext, authHeader);

  // Save the priority level assignment by the scheduler
  call.setPriorityLevel(callQueue.getPriorityLevel(call));
  call.markCallCoordinated(false);
  if (alignmentContext != null && call.rpcRequest != null &&
      (call.rpcRequest instanceof ProtobufRpcEngine2.RpcProtobufRequest)) {
    // If call.rpcRequest is not a RpcProtobufRequest, skip this step and
    // treat the call as uncoordinated: currently only certain ClientProtocol
    // methods issued through protobuf RPC need to be coordinated.
    String methodName;
    String protoName;
    ProtobufRpcEngine2.RpcProtobufRequest req =
        (ProtobufRpcEngine2.RpcProtobufRequest) call.rpcRequest;
    try {
      methodName = req.getRequestHeader().getMethodName();
      protoName = req.getRequestHeader().getDeclaringClassProtocolName();
      if (alignmentContext.isCoordinatedCall(protoName, methodName)) {
        call.markCallCoordinated(true);
        long stateId;
        stateId = alignmentContext.receiveRequestState(
            header, getMaxIdleTime());
        call.setClientStateId(stateId);
        if (header.hasRouterFederatedState()) {
          call.setFederatedNamespaceState(header.getRouterFederatedState());
        }
      }
    } catch (IOException ioe) {
      throw new RpcServerException("Processing RPC request caught ", ioe);
    }
  }

  try {
    internalQueueCall(call);
  } catch (RpcServerException rse) {
    throw rse;
  } catch (IOException ioe) {
    throw new FatalRpcServerException(
        RpcErrorCodeProto.ERROR_RPC_SERVER, ioe);
  }
  incRpcCount(); // Increment the rpc count
} finally {
  AuthorizationContext.clear();
}
}

/**
@@ -3245,6 +3270,7 @@ public void run() {
}
// always update the current call context
CallerContext.setCurrent(call.callerContext);
AuthorizationContext.setCurrentAuthorizationHeader(call.authHeader);
UserGroupInformation remoteUser = call.getRemoteUser();
connDropped = !call.isOpen();
if (remoteUser != null) {
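Because the handler installs the header before dispatching the call (the setCurrentAuthorizationHeader call above), any code running on the handler thread during the call can read it back. A hedged sketch of a server-side consumer, modeled on the capturing audit logger in the tests below; the class name and logging logic here are illustrative, not part of this change:

import java.net.InetAddress;
import java.nio.charset.StandardCharsets;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.hdfs.server.namenode.AuditLogger;
import org.apache.hadoop.security.AuthorizationContext;

// Runs on the RPC handler thread, so the thread-local installed in
// Server.run() above is visible here for the current call only.
public class TokenLoggingAuditLogger implements AuditLogger {
  @Override
  public void initialize(Configuration conf) {
  }

  @Override
  public void logAuditEvent(boolean succeeded, String userName,
      InetAddress addr, String cmd, String src, String dst, FileStatus stat) {
    byte[] header = AuthorizationContext.getCurrentAuthorizationHeader();
    if (header != null) {
      // The bytes are opaque to the RPC layer; interpretation is left to
      // the consumer. Here we only decode them for the log line.
      System.out.println(cmd + " carried authorization header: "
          + new String(header, StandardCharsets.UTF_8));
    }
  }
}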
@@ -0,0 +1,39 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;

/**
* Utility for managing a thread-local authorization header for RPC calls.
*/
public final class AuthorizationContext {
private static final ThreadLocal<byte[]> AUTH_HEADER = new ThreadLocal<>();

private AuthorizationContext() {}

public static void setCurrentAuthorizationHeader(byte[] header) {
AUTH_HEADER.set(header);
}

public static byte[] getCurrentAuthorizationHeader() {
return AUTH_HEADER.get();
}

public static void clear() {
AUTH_HEADER.remove();
}
}
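One caveat: setCurrentAuthorizationHeader stores, and getCurrentAuthorizationHeader returns, the caller's array reference without copying, so a consumer that keeps the header past the current call should take a defensive copy (as the capturing audit logger in the tests below does with Arrays.copyOf). A small illustration of the aliasing:

import java.util.Arrays;

import org.apache.hadoop.security.AuthorizationContext;

public class HeaderAliasingSketch {
  public static void main(String[] args) {
    byte[] header = "token".getBytes();
    AuthorizationContext.setCurrentAuthorizationHeader(header);

    // The getter returns the same array instance, not a copy, so mutations
    // through either reference are visible to both.
    byte[] seen = AuthorizationContext.getCurrentAuthorizationHeader();
    header[0] = 'X';
    System.out.println(seen[0] == 'X'); // prints true

    // Take a defensive copy if the bytes must outlive the call.
    byte[] safe = Arrays.copyOf(seen, seen.length);
    AuthorizationContext.clear();
    System.out.println(Arrays.toString(safe));
  }
}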
@@ -32,6 +32,7 @@
import org.apache.hadoop.tracing.Span;
import org.apache.hadoop.tracing.Tracer;
import org.apache.hadoop.tracing.TraceUtils;
import org.apache.hadoop.security.AuthorizationContext;

import org.apache.hadoop.thirdparty.protobuf.ByteString;

@@ -203,6 +204,12 @@ public static RpcRequestHeaderProto makeRpcRequestHeader(RPC.RpcKind rpcKind,
result.setCallerContext(contextBuilder);
}

// Add authorization header if present
byte[] authzHeader = AuthorizationContext.getCurrentAuthorizationHeader();
if (authzHeader != null) {
result.setAuthorizationHeader(ByteString.copyFrom(authzHeader));
}

// Add alignment context if it is not null
if (alignmentContext != null) {
alignmentContext.updateRequestState(result);
@@ -95,6 +95,8 @@ message RpcRequestHeaderProto { // the header for the RpcRequest
// The client should not interpret these bytes, but only forward bytes
// received from RpcResponseHeaderProto.routerFederatedState.
optional bytes routerFederatedState = 9;
// Authorization header for passing opaque credentials or tokens
optional bytes authorizationHeader = 10;
}
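Since the field is optional bytes, consumers go through the standard generated has/get/set accessors. A hedged round-trip sketch against the builder, assuming the usual generated class RpcHeaderProtos in org.apache.hadoop.ipc.protobuf and placeholder values for the message's required callId and clientId fields; in the real flow, ProtoUtil.makeRpcRequestHeader sets the field as shown above:

import java.nio.charset.StandardCharsets;

import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
import org.apache.hadoop.thirdparty.protobuf.ByteString;

public class AuthorizationHeaderFieldSketch {
  public static void main(String[] args) {
    byte[] token = "opaque-token".getBytes(StandardCharsets.UTF_8);

    RpcRequestHeaderProto header = RpcRequestHeaderProto.newBuilder()
        .setCallId(1)                    // required field, placeholder value
        .setClientId(ByteString.EMPTY)   // required field, placeholder value
        .setAuthorizationHeader(ByteString.copyFrom(token))
        .build();

    // An optional field distinguishes "absent" from "empty".
    if (header.hasAuthorizationHeader()) {
      byte[] roundTripped = header.getAuthorizationHeader().toByteArray();
      System.out.println(new String(roundTripped, StandardCharsets.UTF_8));
    }
  }
}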


@@ -0,0 +1,69 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class TestAuthorizationContext {

@Test
public void testSetAndGetAuthorizationHeader() {
byte[] header = "my-auth-header".getBytes();
AuthorizationContext.setCurrentAuthorizationHeader(header);
Assertions.assertArrayEquals(header, AuthorizationContext.getCurrentAuthorizationHeader());
AuthorizationContext.clear();
}

@Test
public void testClearAuthorizationHeader() {
byte[] header = "clear-me".getBytes();
AuthorizationContext.setCurrentAuthorizationHeader(header);
AuthorizationContext.clear();
Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader());
}

@Test
public void testThreadLocalIsolation() throws Exception {
byte[] mainHeader = "main-thread".getBytes();
AuthorizationContext.setCurrentAuthorizationHeader(mainHeader);
Thread t = new Thread(() -> {
Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader());
byte[] threadHeader = "other-thread".getBytes();
AuthorizationContext.setCurrentAuthorizationHeader(threadHeader);
Assertions.assertArrayEquals(threadHeader, AuthorizationContext.getCurrentAuthorizationHeader());
AuthorizationContext.clear();
Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader());
});
t.start();
t.join();
// Main thread should still have its header
Assertions.assertArrayEquals(mainHeader, AuthorizationContext.getCurrentAuthorizationHeader());
AuthorizationContext.clear();
}

@Test
public void testNullAndEmptyHeader() {
AuthorizationContext.setCurrentAuthorizationHeader(null);
Assertions.assertNull(AuthorizationContext.getCurrentAuthorizationHeader());
byte[] empty = new byte[0];
AuthorizationContext.setCurrentAuthorizationHeader(empty);
Assertions.assertArrayEquals(empty, AuthorizationContext.getCurrentAuthorizationHeader());
AuthorizationContext.clear();
}
}
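testThreadLocalIsolation above also points at a pitfall: a plain ThreadLocal is not inherited by worker threads, so a client that fans RPCs out to an executor must propagate the header explicitly. A hedged sketch of one way to do that; the wrap helper is illustrative, not part of this change:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import org.apache.hadoop.security.AuthorizationContext;

public class HeaderPropagatingExecutorSketch {

  // Captures the submitting thread's header and installs it around the task.
  static Runnable wrap(Runnable task) {
    byte[] captured = AuthorizationContext.getCurrentAuthorizationHeader();
    return () -> {
      AuthorizationContext.setCurrentAuthorizationHeader(captured);
      try {
        task.run();
      } finally {
        AuthorizationContext.clear();
      }
    };
  }

  public static void main(String[] args) throws Exception {
    AuthorizationContext.setCurrentAuthorizationHeader("token".getBytes());
    ExecutorService pool = Executors.newFixedThreadPool(2);
    try {
      // Without wrap(), the pool thread would observe a null header.
      pool.submit(wrap(() -> System.out.println(new String(
          AuthorizationContext.getCurrentAuthorizationHeader()))));
    } finally {
      pool.shutdown();
      AuthorizationContext.clear();
    }
  }
}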
@@ -0,0 +1,81 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hdfs.server.namenode;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.hdfs.HdfsConfiguration;
import org.apache.hadoop.hdfs.MiniDFSCluster;
import org.apache.hadoop.security.AuthorizationContext;
import org.junit.jupiter.api.Test;

import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.apache.hadoop.hdfs.DFSConfigKeys.DFS_NAMENODE_AUDIT_LOGGERS_KEY;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

public class TestAuthorizationHeaderPropagation {

public static class HeaderCapturingAuditLogger implements AuditLogger {
public static final List<byte[]> capturedHeaders = new ArrayList<>();
@Override
public void initialize(Configuration conf) {}
@Override
public void logAuditEvent(boolean succeeded, String userName, InetAddress addr,
String cmd, String src, String dst, FileStatus stat) {
byte[] header = AuthorizationContext.getCurrentAuthorizationHeader();
capturedHeaders.add(header == null ? null : Arrays.copyOf(header, header.length));
}
}

@Test
public void testAuthorizationHeaderPerRpc() throws Exception {
Configuration conf = new HdfsConfiguration();
conf.set(DFS_NAMENODE_AUDIT_LOGGERS_KEY, HeaderCapturingAuditLogger.class.getName());
MiniDFSCluster cluster = new MiniDFSCluster.Builder(conf).build();
try {
cluster.waitClusterUp();
HeaderCapturingAuditLogger.capturedHeaders.clear();
FileSystem fs = cluster.getFileSystem();
// First RPC with header1
byte[] header1 = "header-one".getBytes();
AuthorizationContext.setCurrentAuthorizationHeader(header1);
fs.mkdirs(new Path("/authz1"));
AuthorizationContext.clear();
// Second RPC with header2
byte[] header2 = "header-two".getBytes();
AuthorizationContext.setCurrentAuthorizationHeader(header2);
fs.mkdirs(new Path("/authz2"));
AuthorizationContext.clear();
// Third RPC with no header
fs.mkdirs(new Path("/authz3"));
// Now assert
assertArrayEquals(header1, HeaderCapturingAuditLogger.capturedHeaders.get(0));
assertArrayEquals(header2, HeaderCapturingAuditLogger.capturedHeaders.get(1));
assertNull(HeaderCapturingAuditLogger.capturedHeaders.get(2));
} finally {
cluster.shutdown();
}
}
}