From 6c95dc6e4776f05dcf63dc096de2c228abbc2243 Mon Sep 17 00:00:00 2001 From: Jared Wiltshire Date: Mon, 3 Feb 2025 16:45:47 -0700 Subject: [PATCH] Fix HTTP/2 CONNECT WebSocket upgrades (RFC 8441) Closes gh-34362 Signed-off-by: Jared Wiltshire --- .../support/HandshakeWebSocketService.java | 24 +++++++------- .../web/socket/WebSocketHttpHeaders.java | 2 +- .../support/AbstractHandshakeHandler.java | 32 +++++++++++-------- 3 files changed, 32 insertions(+), 26 deletions(-) diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java index c54f38d9bc5b..d7577a02517e 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/socket/server/support/HandshakeWebSocketService.java @@ -205,23 +205,25 @@ public Mono handleRequest(ServerWebExchange exchange, WebSocketHandler han HttpMethod method = request.getMethod(); HttpHeaders headers = request.getHeaders(); - if (HttpMethod.GET != method && CONNECT_METHOD != method) { + if (HttpMethod.GET != method && !CONNECT_METHOD.equals(method)) { return Mono.error(new MethodNotAllowedException( request.getMethod(), Set.of(HttpMethod.GET, CONNECT_METHOD))); } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); - } + if (HttpMethod.GET == method) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + return handleBadRequest(exchange, "Invalid 'Upgrade' header: " + headers); + } - List connectionValue = headers.getConnection(); - if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { - return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); - } + List connectionValue = headers.getConnection(); + if (!connectionValue.contains("Upgrade") && !connectionValue.contains("upgrade")) { + return handleBadRequest(exchange, "Invalid 'Connection' header: " + headers); + } - String key = headers.getFirst(SEC_WEBSOCKET_KEY); - if (key == null) { - return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); + String key = headers.getFirst(SEC_WEBSOCKET_KEY); + if (key == null) { + return handleBadRequest(exchange, "Missing \"Sec-WebSocket-Key\" header"); + } } String protocol = selectProtocol(headers, handler); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java index fa4c9037b83c..1a4fac7f881e 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/WebSocketHttpHeaders.java @@ -151,7 +151,7 @@ public void setSecWebSocketProtocol(List secWebSocketProtocols) { } /** - * Returns the value of the {@code Sec-WebSocket-Key} header. + * Returns the value of the {@code Sec-WebSocket-Protocol} header. * @return the value of the header */ public List getSecWebSocketProtocol() { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java index acde43c3cc59..fce20644c16f 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java @@ -215,7 +215,7 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r } try { HttpMethod httpMethod = request.getMethod(); - if (HttpMethod.GET != httpMethod && CONNECT_METHOD != httpMethod) { + if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) { response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED); response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD)); if (logger.isErrorEnabled()) { @@ -223,13 +223,15 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r } return false; } - if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { - handleInvalidUpgradeHeader(request, response); - return false; - } - if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { - handleInvalidConnectHeader(request, response); - return false; + if (HttpMethod.GET == httpMethod) { + if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) { + handleInvalidUpgradeHeader(request, response); + return false; + } + if (!headers.getConnection().contains("Upgrade") && !headers.getConnection().contains("upgrade")) { + handleInvalidConnectHeader(request, response); + return false; + } } if (!isWebSocketVersionSupported(headers)) { handleWebSocketVersionNotSupported(request, response); @@ -239,13 +241,15 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r response.setStatusCode(HttpStatus.FORBIDDEN); return false; } - String wsKey = headers.getSecWebSocketKey(); - if (wsKey == null) { - if (logger.isErrorEnabled()) { - logger.error("Missing \"Sec-WebSocket-Key\" header"); + if (HttpMethod.GET == httpMethod) { + String wsKey = headers.getSecWebSocketKey(); + if (wsKey == null) { + if (logger.isErrorEnabled()) { + logger.error("Missing \"Sec-WebSocket-Key\" header"); + } + response.setStatusCode(HttpStatus.BAD_REQUEST); + return false; } - response.setStatusCode(HttpStatus.BAD_REQUEST); - return false; } } catch (IOException ex) {