Skip to content

Commit 6959377

Browse files
committed
Only attempt RFC8441 upgrade if we know the strategy supports it
Signed-off-by: Jared Wiltshire <[email protected]>
1 parent 6c95dc6 commit 6959377

File tree

3 files changed

+57
-20
lines changed

3 files changed

+57
-20
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/server/RequestUpgradeStrategy.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
package org.springframework.web.socket.server;
1818

1919
import java.security.Principal;
20+
import java.util.EnumSet;
2021
import java.util.List;
2122
import java.util.Map;
23+
import java.util.Set;
2224

25+
import org.springframework.http.HttpMethod;
2326
import org.springframework.http.server.ServerHttpRequest;
2427
import org.springframework.http.server.ServerHttpResponse;
2528
import org.springframework.lang.Nullable;
@@ -35,6 +38,25 @@
3538
*/
3639
public interface RequestUpgradeStrategy {
3740

41+
enum OpeningHandshake {
42+
RFC6455(HttpMethod.GET),
43+
RFC8441(HttpMethod.valueOf("CONNECT"));
44+
45+
private final HttpMethod method;
46+
47+
OpeningHandshake(HttpMethod method) {
48+
this.method = method;
49+
}
50+
51+
public HttpMethod getMethod() {
52+
return method;
53+
}
54+
}
55+
56+
default Set<OpeningHandshake> getSupportedOpeningHandshake() {
57+
return EnumSet.of(OpeningHandshake.RFC6455);
58+
}
59+
3860
/**
3961
* Return the supported WebSocket protocol versions.
4062
*/

spring-websocket/src/main/java/org/springframework/web/socket/server/jetty/JettyRequestUpgradeStrategy.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import java.lang.reflect.UndeclaredThrowableException;
2020
import java.security.Principal;
2121
import java.util.Collections;
22+
import java.util.EnumSet;
2223
import java.util.List;
2324
import java.util.Map;
25+
import java.util.Set;
2426
import java.util.function.Consumer;
2527

2628
import jakarta.servlet.ServletContext;
@@ -59,6 +61,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv
5961
private Consumer<Configurable> webSocketConfigurer;
6062

6163

64+
@Override
65+
public Set<OpeningHandshake> getSupportedOpeningHandshake() {
66+
return EnumSet.of(OpeningHandshake.RFC6455, OpeningHandshake.RFC8441);
67+
}
68+
6269
@Override
6370
public String[] getSupportedVersions() {
6471
return SUPPORTED_VERSIONS;

spring-websocket/src/main/java/org/springframework/web/socket/server/support/AbstractHandshakeHandler.java

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.Locale;
2727
import java.util.Map;
2828
import java.util.Set;
29+
import java.util.stream.Collectors;
2930

3031
import org.apache.commons.logging.Log;
3132
import org.apache.commons.logging.LogFactory;
@@ -48,6 +49,7 @@
4849
import org.springframework.web.socket.server.HandshakeFailureException;
4950
import org.springframework.web.socket.server.HandshakeHandler;
5051
import org.springframework.web.socket.server.RequestUpgradeStrategy;
52+
import org.springframework.web.socket.server.RequestUpgradeStrategy.OpeningHandshake;
5153
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
5254
import org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy;
5355
import org.springframework.web.socket.server.standard.StandardWebSocketUpgradeStrategy;
@@ -78,9 +80,6 @@
7880
*/
7981
public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {
8082

81-
// For WebSocket upgrades in HTTP/2 (see RFC 8441)
82-
private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT");
83-
8483
private static final boolean tomcatWsPresent;
8584

8685
private static final boolean jettyWsPresent;
@@ -215,15 +214,16 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
215214
}
216215
try {
217216
HttpMethod httpMethod = request.getMethod();
218-
if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) {
219-
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
220-
response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
221-
if (logger.isErrorEnabled()) {
222-
logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
217+
Set<OpeningHandshake> supportedHandshakes = requestUpgradeStrategy.getSupportedOpeningHandshake();
218+
OpeningHandshake handshake = null;
219+
for (OpeningHandshake h : supportedHandshakes) {
220+
if (h.getMethod().equals(httpMethod)) {
221+
handshake = h;
222+
break;
223223
}
224-
return false;
225224
}
226-
if (HttpMethod.GET == httpMethod) {
225+
226+
if (handshake == OpeningHandshake.RFC6455) {
227227
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
228228
handleInvalidUpgradeHeader(request, response);
229229
return false;
@@ -232,16 +232,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
232232
handleInvalidConnectHeader(request, response);
233233
return false;
234234
}
235-
}
236-
if (!isWebSocketVersionSupported(headers)) {
237-
handleWebSocketVersionNotSupported(request, response);
238-
return false;
239-
}
240-
if (!isValidOrigin(request)) {
241-
response.setStatusCode(HttpStatus.FORBIDDEN);
242-
return false;
243-
}
244-
if (HttpMethod.GET == httpMethod) {
245235
String wsKey = headers.getSecWebSocketKey();
246236
if (wsKey == null) {
247237
if (logger.isErrorEnabled()) {
@@ -250,6 +240,24 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
250240
response.setStatusCode(HttpStatus.BAD_REQUEST);
251241
return false;
252242
}
243+
} else if (handshake == null) {
244+
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
245+
Set<HttpMethod> methods = supportedHandshakes.stream()
246+
.map(OpeningHandshake::getMethod)
247+
.collect(Collectors.toUnmodifiableSet());
248+
response.getHeaders().setAllow(methods);
249+
if (logger.isErrorEnabled()) {
250+
logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
251+
}
252+
return false;
253+
}
254+
if (!isWebSocketVersionSupported(headers)) {
255+
handleWebSocketVersionNotSupported(request, response);
256+
return false;
257+
}
258+
if (!isValidOrigin(request)) {
259+
response.setStatusCode(HttpStatus.FORBIDDEN);
260+
return false;
253261
}
254262
}
255263
catch (IOException ex) {

0 commit comments

Comments
 (0)