17
17
package org .apache .rocketmq .proxy .grpc ;
18
18
19
19
import io .grpc .netty .shaded .io .grpc .netty .GrpcHttp2ConnectionHandler ;
20
+ import io .grpc .netty .shaded .io .grpc .netty .GrpcSslContexts ;
20
21
import io .grpc .netty .shaded .io .grpc .netty .InternalProtocolNegotiationEvent ;
21
22
import io .grpc .netty .shaded .io .grpc .netty .InternalProtocolNegotiator ;
22
23
import io .grpc .netty .shaded .io .grpc .netty .InternalProtocolNegotiators ;
23
24
import io .grpc .netty .shaded .io .netty .buffer .ByteBuf ;
24
25
import io .grpc .netty .shaded .io .netty .channel .ChannelHandler ;
25
26
import io .grpc .netty .shaded .io .netty .channel .ChannelHandlerContext ;
26
27
import io .grpc .netty .shaded .io .netty .handler .codec .ByteToMessageDecoder ;
28
+ import io .grpc .netty .shaded .io .netty .handler .ssl .ClientAuth ;
27
29
import io .grpc .netty .shaded .io .netty .handler .ssl .SslContext ;
28
30
import io .grpc .netty .shaded .io .netty .handler .ssl .SslHandler ;
31
+ import io .grpc .netty .shaded .io .netty .handler .ssl .util .InsecureTrustManagerFactory ;
32
+ import io .grpc .netty .shaded .io .netty .handler .ssl .util .SelfSignedCertificate ;
29
33
import io .grpc .netty .shaded .io .netty .util .AsciiString ;
34
+ import java .io .InputStream ;
35
+ import java .nio .file .Files ;
36
+ import java .nio .file .Paths ;
30
37
import java .util .List ;
31
38
import org .apache .rocketmq .common .constant .LoggerName ;
32
39
import org .apache .rocketmq .logging .org .slf4j .Logger ;
33
40
import org .apache .rocketmq .logging .org .slf4j .LoggerFactory ;
41
+ import org .apache .rocketmq .proxy .config .ConfigurationManager ;
42
+ import org .apache .rocketmq .proxy .config .ProxyConfig ;
43
+ import org .apache .rocketmq .remoting .common .TlsMode ;
44
+ import org .apache .rocketmq .remoting .netty .TlsSystemConfig ;
34
45
35
46
public class OptionalSSLProtocolNegotiator implements InternalProtocolNegotiator .ProtocolNegotiator {
36
- private static final Logger log = LoggerFactory .getLogger (LoggerName .PROXY_LOGGER_NAME );
37
- private final SslContext sslContext ;
47
+ protected static final Logger log = LoggerFactory .getLogger (LoggerName .PROXY_LOGGER_NAME );
48
+
38
49
/**
39
50
* the length of the ssl record header (in bytes)
40
51
*/
41
52
private static final int SSL_RECORD_HEADER_LENGTH = 5 ;
42
53
43
- public OptionalSSLProtocolNegotiator (SslContext sslContext ) {
44
- this .sslContext = sslContext ;
54
+ private static SslContext sslContext ;
55
+
56
+ public OptionalSSLProtocolNegotiator () {
57
+ sslContext = loadSslContext ();
45
58
}
46
59
47
60
@ Override
@@ -50,43 +63,81 @@ public AsciiString scheme() {
50
63
}
51
64
52
65
@ Override
53
- public ChannelHandler newHandler (GrpcHttp2ConnectionHandler grpcHttp2ConnectionHandler ) {
54
- ChannelHandler plaintext =
55
- InternalProtocolNegotiators .serverPlaintext ().newHandler (grpcHttp2ConnectionHandler );
56
- ChannelHandler ssl =
57
- InternalProtocolNegotiators .serverTls (sslContext ).newHandler (grpcHttp2ConnectionHandler );
58
- return new PortUnificationServerHandler (ssl , plaintext );
66
+ public ChannelHandler newHandler (GrpcHttp2ConnectionHandler grpcHandler ) {
67
+ return new PortUnificationServerHandler (grpcHandler );
59
68
}
60
69
61
70
@ Override
62
71
public void close () {}
63
72
73
+ private static SslContext loadSslContext () {
74
+ try {
75
+ ProxyConfig proxyConfig = ConfigurationManager .getProxyConfig ();
76
+ if (proxyConfig .isTlsTestModeEnable ()) {
77
+ SelfSignedCertificate selfSignedCertificate = new SelfSignedCertificate ();
78
+ return GrpcSslContexts .forServer (selfSignedCertificate .certificate (),
79
+ selfSignedCertificate .privateKey ())
80
+ .trustManager (InsecureTrustManagerFactory .INSTANCE )
81
+ .clientAuth (ClientAuth .NONE )
82
+ .build ();
83
+ } else {
84
+ String tlsKeyPath = ConfigurationManager .getProxyConfig ().getTlsKeyPath ();
85
+ String tlsCertPath = ConfigurationManager .getProxyConfig ().getTlsCertPath ();
86
+ try (InputStream serverKeyInputStream = Files .newInputStream (
87
+ Paths .get (tlsKeyPath ));
88
+ InputStream serverCertificateStream = Files .newInputStream (
89
+ Paths .get (tlsCertPath ))) {
90
+ SslContext res = GrpcSslContexts .forServer (serverCertificateStream ,
91
+ serverKeyInputStream )
92
+ .trustManager (InsecureTrustManagerFactory .INSTANCE )
93
+ .clientAuth (ClientAuth .NONE )
94
+ .build ();
95
+ log .info ("grpc load TLS configured OK" );
96
+ return res ;
97
+ }
98
+ }
99
+ } catch (Exception e ) {
100
+ log .error ("grpc tls set failed. msg: {}, e:" , e .getMessage (), e );
101
+ throw new RuntimeException ("grpc tls set failed: " + e .getMessage ());
102
+ }
103
+ }
104
+
64
105
public static class PortUnificationServerHandler extends ByteToMessageDecoder {
106
+
65
107
private final ChannelHandler ssl ;
66
108
private final ChannelHandler plaintext ;
67
109
68
- public PortUnificationServerHandler (ChannelHandler ssl , ChannelHandler plaintext ) {
69
- this .ssl = ssl ;
70
- this .plaintext = plaintext ;
110
+ public PortUnificationServerHandler (GrpcHttp2ConnectionHandler grpcHandler ) {
111
+ this .ssl = InternalProtocolNegotiators .serverTls (sslContext )
112
+ .newHandler (grpcHandler );
113
+ this .plaintext = InternalProtocolNegotiators .serverPlaintext ()
114
+ .newHandler (grpcHandler );
71
115
}
72
116
73
117
@ Override
74
118
protected void decode (ChannelHandlerContext ctx , ByteBuf in , List <Object > out )
75
- throws Exception {
119
+ throws Exception {
76
120
try {
77
- // in SslHandler.isEncrypted, it need at least 5 bytes to judge is encrypted or not
78
- if (in .readableBytes () < SSL_RECORD_HEADER_LENGTH ) {
79
- return ;
80
- }
81
- if (SslHandler .isEncrypted (in )) {
121
+ TlsMode tlsMode = TlsSystemConfig .tlsMode ;
122
+ if (TlsMode .ENFORCING .equals (tlsMode )) {
82
123
ctx .pipeline ().addAfter (ctx .name (), null , this .ssl );
83
- } else {
124
+ } else if ( TlsMode . DISABLED . equals ( tlsMode )) {
84
125
ctx .pipeline ().addAfter (ctx .name (), null , this .plaintext );
126
+ } else {
127
+ // in SslHandler.isEncrypted, it need at least 5 bytes to judge is encrypted or not
128
+ if (in .readableBytes () < SSL_RECORD_HEADER_LENGTH ) {
129
+ return ;
130
+ }
131
+ if (SslHandler .isEncrypted (in )) {
132
+ ctx .pipeline ().addAfter (ctx .name (), null , this .ssl );
133
+ } else {
134
+ ctx .pipeline ().addAfter (ctx .name (), null , this .plaintext );
135
+ }
85
136
}
86
137
ctx .fireUserEventTriggered (InternalProtocolNegotiationEvent .getDefault ());
87
138
ctx .pipeline ().remove (this );
88
139
} catch (Exception e ) {
89
- log .error ("process protocol negotiator failed." , e );
140
+ log .error ("process ssl protocol negotiator failed." , e );
90
141
throw e ;
91
142
}
92
143
}
0 commit comments