Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ratis.netty;
package org.apache.ratis.util;

import org.apache.ratis.security.TlsConf;
import org.apache.ratis.security.TlsConf.CertificatesConf;
Expand All @@ -36,13 +36,14 @@
import org.apache.ratis.thirdparty.io.netty.channel.socket.nio.NioSocketChannel;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContextBuilder;
import org.apache.ratis.util.ConcurrentUtils;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.thirdparty.io.netty.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.KeyManager;
import javax.net.ssl.TrustManager;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
Expand Down Expand Up @@ -81,6 +82,39 @@ static EventLoopGroup newEventLoopGroup(String name, int size, boolean useEpoll)
return new NioEventLoopGroup(size, ConcurrentUtils.newThreadFactory(name + "-"));
}

static void shutdownGracefully(EventLoopGroup... groups) {
shutdownGracefully(CLOSE_TIMEOUT, groups);
}

static void shutdownGracefully(TimeDuration awaitTime, EventLoopGroup... groups) {
if (groups == null || groups.length == 0) {
return;
}

final List<EventLoopGroup> nonNullGroups = new ArrayList<>(groups.length);
final List<Future<?>> futures = new ArrayList<>(groups.length);
for (EventLoopGroup group : groups) {
if (group != null) {
nonNullGroups.add(group);
futures.add(group.shutdownGracefully());
}
}

for (int i = 0; i < futures.size(); i++) {
final EventLoopGroup group = nonNullGroups.get(i);
try {
if (!futures.get(i).await(awaitTime.getDuration(), awaitTime.getUnit())) {
LOG.warn("Failed to shut down EventLoopGroup {} in {}", group, awaitTime);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOG.warn("Interrupted while shutting down EventLoopGroup {}", group, e);
} catch (Exception e) {
LOG.warn("Failed to shut down EventLoopGroup {} in {}", group, awaitTime, e);
}
}
}

static void setTrustManager(SslContextBuilder b, TrustManagerConf trustManagerConfig) {
if (trustManagerConfig == null) {
return;
Expand Down Expand Up @@ -196,4 +230,4 @@ static void closeChannel(Channel channel, String name) {
LOG.warn("closeChannel {} is not yet completed in {}", name, CLOSE_TIMEOUT);
}
}
}
}
39 changes: 39 additions & 0 deletions ratis-grpc/src/main/java/org/apache/ratis/grpc/GrpcConfigKeys.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ static Consumer<String> getDefaultLog() {

String PREFIX = "raft.grpc";

String USE_EPOLL_KEY = PREFIX + ".use-epoll";
boolean USE_EPOLL_DEFAULT = true;
static boolean useEpoll(RaftProperties properties) {
return getBoolean(properties::getBoolean, USE_EPOLL_KEY, USE_EPOLL_DEFAULT, getDefaultLog());
}
static void setUseEpoll(RaftProperties properties, boolean useEpoll) {
setBoolean(properties::setBoolean, USE_EPOLL_KEY, useEpoll);
}

interface TLS {
String PREFIX = GrpcConfigKeys.PREFIX + ".tls";

Expand Down Expand Up @@ -155,6 +164,16 @@ static GrpcTlsConfig tlsConf(Parameters parameters) {
static void setTlsConf(Parameters parameters, GrpcTlsConfig conf) {
parameters.put(TLS_CONF_PARAMETER, conf, TLS_CONF_CLASS);
}

String WORKER_GROUP_SIZE_KEY = PREFIX + ".worker-group.size";
int WORKER_GROUP_SIZE_DEFAULT = 0;
static int workerGroupSize(RaftProperties properties) {
return getInt(properties::getInt, WORKER_GROUP_SIZE_KEY,
WORKER_GROUP_SIZE_DEFAULT, getDefaultLog(), requireMin(0), requireMax(65536));
}
static void setWorkerGroupSize(RaftProperties properties, int size) {
setInt(properties::setInt, WORKER_GROUP_SIZE_KEY, size);
}
}

interface Server {
Expand Down Expand Up @@ -291,6 +310,26 @@ static int stubPoolSize(RaftProperties properties) {
static void setStubPoolSize(RaftProperties properties, int size) {
setInt(properties::setInt, STUB_POOL_SIZE_KEY, size);
}

String BOSS_GROUP_SIZE_KEY = PREFIX + ".boss-group.size";
int BOSS_GROUP_SIZE_DEFAULT = 0;
static int bossGroupSize(RaftProperties properties) {
return getInt(properties::getInt, BOSS_GROUP_SIZE_KEY,
BOSS_GROUP_SIZE_DEFAULT, getDefaultLog(), requireMin(0), requireMax(65536));
}
static void setBossGroupSize(RaftProperties properties, int size) {
setInt(properties::setInt, BOSS_GROUP_SIZE_KEY, size);
}

String WORKER_GROUP_SIZE_KEY = PREFIX + ".worker-group.size";
int WORKER_GROUP_SIZE_DEFAULT = 0;
static int workerGroupSize(RaftProperties properties) {
return getInt(properties::getInt, WORKER_GROUP_SIZE_KEY,
WORKER_GROUP_SIZE_DEFAULT, getDefaultLog(), requireMin(0), requireMax(65536));
}
static void setWorkerGroupSize(RaftProperties properties, int size) {
setInt(properties::setInt, WORKER_GROUP_SIZE_KEY, size);
}
}

String MESSAGE_SIZE_MAX_KEY = PREFIX + ".message.size.max";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,6 @@ public GrpcClientRpc newRaftClientRpc(ClientId clientId, RaftProperties properti
checkPooledByteBufAllocatorUseCacheForAllThreads(LOG::debug);

final SslContexts forClient = forClientSupplier.get();
return new GrpcClientRpc(clientId, properties, forClient.adminSslContext, forClient.clientSslContext);
return GrpcClientRpc.create(clientId, properties, forClient.adminSslContext, forClient.clientSslContext);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.grpc.GrpcConfigKeys;
import org.apache.ratis.grpc.GrpcUtil;
import org.apache.ratis.util.NettyUtils;
import org.apache.ratis.grpc.metrics.intercept.client.MetricClientInterceptor;
import org.apache.ratis.proto.RaftProtos.GroupInfoReplyProto;
import org.apache.ratis.proto.RaftProtos.GroupInfoRequestProto;
Expand Down Expand Up @@ -51,9 +52,12 @@
import org.apache.ratis.thirdparty.io.grpc.netty.NegotiationType;
import org.apache.ratis.thirdparty.io.grpc.netty.NettyChannelBuilder;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.thirdparty.io.netty.channel.EventLoopGroup;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.SizeInBytes;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.TimeoutExecutor;
Expand Down Expand Up @@ -94,14 +98,16 @@ public class GrpcClientProtocolClient implements Closeable {

private final AtomicReference<AsyncStreamObservers> unorderedStreamObservers = new AtomicReference<>();
private final MetricClientInterceptor metricClientInterceptor;
private final MemoizedSupplier<EventLoopGroup> clientWorkers;

GrpcClientProtocolClient(ClientId id, RaftPeer target, RaftProperties properties,
SslContext adminSslContext, SslContext clientSslContext) {
SslContext adminSslContext, SslContext clientSslContext, MemoizedSupplier<EventLoopGroup> clientWorkers) {
this.name = JavaUtils.memoize(() -> id + "->" + target.getId());
this.target = target;
final SizeInBytes flowControlWindow = GrpcConfigKeys.flowControlWindow(properties, LOG::debug);
this.maxMessageSize = GrpcConfigKeys.messageSizeMax(properties, LOG::debug);
metricClientInterceptor = new MetricClientInterceptor(getName());
this.clientWorkers = clientWorkers;

final String clientAddress = Optional.ofNullable(target.getClientAddress())
.filter(x -> !x.isEmpty()).orElse(target.getAddress());
Expand Down Expand Up @@ -135,6 +141,12 @@ private ManagedChannel buildChannel(String address, SslContext sslContext,
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}

if (clientWorkers != null) {
final EventLoopGroup eventLoopGroup = clientWorkers.get();
channelBuilder.channelType(NettyUtils.getSocketChannelClass(eventLoopGroup))
.eventLoopGroup(eventLoopGroup);
}

return channelBuilder.flowControlWindow(flowControlWindow.getSizeInt())
.maxInboundMessageSize(maxMessageSize.getSizeInt())
.intercept(metricClientInterceptor)
Expand All @@ -156,6 +168,12 @@ public void close() {
metricClientInterceptor.close();
}

EventLoopGroup getClientWorkersForTesting() {
Preconditions.assertTrue(clientWorkers != null);
Preconditions.assertTrue(clientWorkers.isInitialized());
return clientWorkers.get();
}

RaftClientReplyProto groupAdd(GroupManagementRequestProto request) throws IOException {
return blockingCall(() -> adminBlockingStub
.withDeadlineAfter(requestTimeoutDuration.getDuration(), requestTimeoutDuration.getUnit())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.ratis.thirdparty.io.grpc.Status;
import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
import org.apache.ratis.thirdparty.io.netty.channel.EventLoopGroup;
import org.apache.ratis.proto.RaftProtos.GroupInfoRequestProto;
import org.apache.ratis.proto.RaftProtos.GroupListRequestProto;
import org.apache.ratis.proto.RaftProtos.GroupManagementRequestProto;
Expand All @@ -41,6 +42,8 @@
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
import org.apache.ratis.util.IOUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.NettyUtils;
import org.apache.ratis.util.PeerProxyMap;
import org.apache.ratis.util.TimeDuration;
import org.slf4j.Logger;
Expand All @@ -53,18 +56,30 @@
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class GrpcClientRpc extends RaftClientRpcWithProxy<GrpcClientProtocolClient> {
public final class GrpcClientRpc extends RaftClientRpcWithProxy<GrpcClientProtocolClient> {
public static final Logger LOG = LoggerFactory.getLogger(GrpcClientRpc.class);

public static GrpcClientRpc create(ClientId clientId, RaftProperties properties,
SslContext adminSslContext, SslContext clientSslContext) {
final int workerGroupSize = GrpcConfigKeys.Client.workerGroupSize(properties);
final MemoizedSupplier<EventLoopGroup> eventLoopGroup = workerGroupSize > 0 ? MemoizedSupplier.valueOf(
() -> NettyUtils.newEventLoopGroup(
clientId + "-client-workers", workerGroupSize, GrpcConfigKeys.useEpoll(properties))) : null;
return new GrpcClientRpc(clientId, properties, adminSslContext, clientSslContext, eventLoopGroup);
}

private final ClientId clientId;
private final int maxMessageSize;
private final TimeDuration requestTimeoutDuration;
private final TimeDuration watchRequestTimeoutDuration;
private final MemoizedSupplier<EventLoopGroup> clientWorkers;

public GrpcClientRpc(ClientId clientId, RaftProperties properties,
SslContext adminSslContext, SslContext clientSslContext) {
private GrpcClientRpc(ClientId clientId, RaftProperties properties,
SslContext adminSslContext, SslContext clientSslContext, MemoizedSupplier<EventLoopGroup> clientWorkers) {
super(new PeerProxyMap<>(clientId.toString(),
p -> new GrpcClientProtocolClient(clientId, p, properties, adminSslContext, clientSslContext)));
p -> new GrpcClientProtocolClient(clientId, p, properties, adminSslContext, clientSslContext, clientWorkers)));
this.clientWorkers = clientWorkers;

this.clientId = clientId;
this.maxMessageSize = GrpcConfigKeys.messageSizeMax(properties, LOG::debug).getSizeInt();
this.requestTimeoutDuration = RaftClientConfigKeys.Rpc.requestTimeout(properties);
Expand Down Expand Up @@ -213,6 +228,17 @@ private RaftClientRequestProto toRaftClientRequestProto(RaftClientRequest reques
return proto;
}

@Override
public void close() {
try {
super.close();
} finally {
if (clientWorkers != null && clientWorkers.isInitialized()) {
NettyUtils.shutdownGracefully(clientWorkers.get());
}
}
}

@Override
public boolean shouldReconnect(Throwable e) {
final Throwable cause = e.getCause();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.ratis.grpc.server;

import org.apache.ratis.grpc.GrpcUtil;
import org.apache.ratis.util.NettyUtils;
import org.apache.ratis.grpc.util.StreamObserverWithTimeout;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.util.ServerStringUtils;
Expand All @@ -31,6 +32,7 @@
import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceBlockingStub;
import org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.RaftServerProtocolServiceStub;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.thirdparty.io.netty.channel.EventLoopGroup;
import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext;
import org.apache.ratis.util.TimeDuration;
import org.slf4j.Logger;
Expand All @@ -57,12 +59,15 @@ class GrpcServerProtocolClient implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolClient.class);
//visible for using in log / error messages AND to use in instrumented tests
private final RaftPeerId raftPeerId;
private final EventLoopGroup eventLoopGroup;

GrpcServerProtocolClient(RaftPeer target, int connections, int flowControlWindow,
TimeDuration requestTimeout, SslContext sslContext, boolean separateHBChannel) {
TimeDuration requestTimeout, SslContext sslContext, boolean separateHBChannel,
EventLoopGroup eventLoopGroup) {
raftPeerId = target.getId();
LOG.info("Build channel for {}", target);
useSeparateHBChannel = separateHBChannel;
this.eventLoopGroup = eventLoopGroup;
channel = buildChannel(target, flowControlWindow, sslContext);
blockingStub = RaftServerProtocolServiceGrpc.newBlockingStub(channel);
asyncStub = RaftServerProtocolServiceGrpc.newStub(channel);
Expand All @@ -75,7 +80,8 @@ class GrpcServerProtocolClient implements Closeable {
}

GrpcStubPool<RaftServerProtocolServiceStub> newGrpcStubPool(String address, SslContext sslContext, int connections) {
return new GrpcStubPool<>(connections, address, sslContext, RaftServerProtocolServiceGrpc::newStub, 16);
return new GrpcStubPool<>(connections, address, sslContext, RaftServerProtocolServiceGrpc::newStub, 16,
eventLoopGroup);
}

private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, SslContext sslContext) {
Expand All @@ -90,6 +96,10 @@ private ManagedChannel buildChannel(RaftPeer target, int flowControlWindow, SslC
channelBuilder.negotiationType(NegotiationType.PLAINTEXT);
}
channelBuilder.disableRetry();
if (eventLoopGroup != null) {
channelBuilder.channelType(NettyUtils.getSocketChannelClass(eventLoopGroup))
.eventLoopGroup(eventLoopGroup);
}
return channelBuilder.flowControlWindow(flowControlWindow).build();
}

Expand Down
Loading
Loading