26
26
import java .util .Locale ;
27
27
import java .util .Map ;
28
28
import java .util .Set ;
29
+ import java .util .stream .Collectors ;
29
30
30
31
import org .apache .commons .logging .Log ;
31
32
import org .apache .commons .logging .LogFactory ;
48
49
import org .springframework .web .socket .server .HandshakeFailureException ;
49
50
import org .springframework .web .socket .server .HandshakeHandler ;
50
51
import org .springframework .web .socket .server .RequestUpgradeStrategy ;
52
+ import org .springframework .web .socket .server .RequestUpgradeStrategy .OpeningHandshake ;
51
53
import org .springframework .web .socket .server .jetty .JettyRequestUpgradeStrategy ;
52
54
import org .springframework .web .socket .server .standard .GlassFishRequestUpgradeStrategy ;
53
55
import org .springframework .web .socket .server .standard .StandardWebSocketUpgradeStrategy ;
78
80
*/
79
81
public abstract class AbstractHandshakeHandler implements HandshakeHandler , Lifecycle {
80
82
81
- // For WebSocket upgrades in HTTP/2 (see RFC 8441)
82
- private static final HttpMethod CONNECT_METHOD = HttpMethod .valueOf ("CONNECT" );
83
-
84
83
private static final boolean tomcatWsPresent ;
85
84
86
85
private static final boolean jettyWsPresent ;
@@ -215,15 +214,16 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
215
214
}
216
215
try {
217
216
HttpMethod httpMethod = request .getMethod ();
218
- if (HttpMethod .GET != httpMethod && !CONNECT_METHOD .equals (httpMethod )) {
219
- response .setStatusCode (HttpStatus .METHOD_NOT_ALLOWED );
220
- response .getHeaders ().setAllow (Set .of (HttpMethod .GET , CONNECT_METHOD ));
221
- if (logger .isErrorEnabled ()) {
222
- logger .error ("Handshake failed due to unexpected HTTP method: " + httpMethod );
217
+ Set <OpeningHandshake > supportedHandshakes = requestUpgradeStrategy .getSupportedOpeningHandshake ();
218
+ OpeningHandshake handshake = null ;
219
+ for (OpeningHandshake h : supportedHandshakes ) {
220
+ if (h .getMethod ().equals (httpMethod )) {
221
+ handshake = h ;
222
+ break ;
223
223
}
224
- return false ;
225
224
}
226
- if (HttpMethod .GET == httpMethod ) {
225
+
226
+ if (handshake == OpeningHandshake .RFC6455 ) {
227
227
if (!"WebSocket" .equalsIgnoreCase (headers .getUpgrade ())) {
228
228
handleInvalidUpgradeHeader (request , response );
229
229
return false ;
@@ -232,16 +232,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
232
232
handleInvalidConnectHeader (request , response );
233
233
return false ;
234
234
}
235
- }
236
- if (!isWebSocketVersionSupported (headers )) {
237
- handleWebSocketVersionNotSupported (request , response );
238
- return false ;
239
- }
240
- if (!isValidOrigin (request )) {
241
- response .setStatusCode (HttpStatus .FORBIDDEN );
242
- return false ;
243
- }
244
- if (HttpMethod .GET == httpMethod ) {
245
235
String wsKey = headers .getSecWebSocketKey ();
246
236
if (wsKey == null ) {
247
237
if (logger .isErrorEnabled ()) {
@@ -250,6 +240,24 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
250
240
response .setStatusCode (HttpStatus .BAD_REQUEST );
251
241
return false ;
252
242
}
243
+ } else if (handshake == null ) {
244
+ response .setStatusCode (HttpStatus .METHOD_NOT_ALLOWED );
245
+ Set <HttpMethod > methods = supportedHandshakes .stream ()
246
+ .map (OpeningHandshake ::getMethod )
247
+ .collect (Collectors .toUnmodifiableSet ());
248
+ response .getHeaders ().setAllow (methods );
249
+ if (logger .isErrorEnabled ()) {
250
+ logger .error ("Handshake failed due to unexpected HTTP method: " + httpMethod );
251
+ }
252
+ return false ;
253
+ }
254
+ if (!isWebSocketVersionSupported (headers )) {
255
+ handleWebSocketVersionNotSupported (request , response );
256
+ return false ;
257
+ }
258
+ if (!isValidOrigin (request )) {
259
+ response .setStatusCode (HttpStatus .FORBIDDEN );
260
+ return false ;
253
261
}
254
262
}
255
263
catch (IOException ex ) {
0 commit comments