Skip to content

Commit

Permalink
Only attempt RFC8441 upgrade if we know the strategy supports it
Browse files Browse the repository at this point in the history
Signed-off-by: Jared Wiltshire <[email protected]>
  • Loading branch information
jazdw committed Feb 5, 2025
1 parent 6c95dc6 commit 6959377
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
package org.springframework.web.socket.server;

import java.security.Principal;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.springframework.http.HttpMethod;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.lang.Nullable;
Expand All @@ -35,6 +38,25 @@
*/
public interface RequestUpgradeStrategy {

enum OpeningHandshake {
RFC6455(HttpMethod.GET),
RFC8441(HttpMethod.valueOf("CONNECT"));

private final HttpMethod method;

OpeningHandshake(HttpMethod method) {
this.method = method;
}

public HttpMethod getMethod() {
return method;
}
}

default Set<OpeningHandshake> getSupportedOpeningHandshake() {
return EnumSet.of(OpeningHandshake.RFC6455);
}

/**
* Return the supported WebSocket protocol versions.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import java.lang.reflect.UndeclaredThrowableException;
import java.security.Principal;
import java.util.Collections;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import jakarta.servlet.ServletContext;
Expand Down Expand Up @@ -59,6 +61,11 @@ public class JettyRequestUpgradeStrategy implements RequestUpgradeStrategy, Serv
private Consumer<Configurable> webSocketConfigurer;


@Override
public Set<OpeningHandshake> getSupportedOpeningHandshake() {
return EnumSet.of(OpeningHandshake.RFC6455, OpeningHandshake.RFC8441);
}

@Override
public String[] getSupportedVersions() {
return SUPPORTED_VERSIONS;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -48,6 +49,7 @@
import org.springframework.web.socket.server.HandshakeFailureException;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.RequestUpgradeStrategy.OpeningHandshake;
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
import org.springframework.web.socket.server.standard.GlassFishRequestUpgradeStrategy;
import org.springframework.web.socket.server.standard.StandardWebSocketUpgradeStrategy;
Expand Down Expand Up @@ -78,9 +80,6 @@
*/
public abstract class AbstractHandshakeHandler implements HandshakeHandler, Lifecycle {

// For WebSocket upgrades in HTTP/2 (see RFC 8441)
private static final HttpMethod CONNECT_METHOD = HttpMethod.valueOf("CONNECT");

private static final boolean tomcatWsPresent;

private static final boolean jettyWsPresent;
Expand Down Expand Up @@ -215,15 +214,16 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
}
try {
HttpMethod httpMethod = request.getMethod();
if (HttpMethod.GET != httpMethod && !CONNECT_METHOD.equals(httpMethod)) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
response.getHeaders().setAllow(Set.of(HttpMethod.GET, CONNECT_METHOD));
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
Set<OpeningHandshake> supportedHandshakes = requestUpgradeStrategy.getSupportedOpeningHandshake();
OpeningHandshake handshake = null;
for (OpeningHandshake h : supportedHandshakes) {
if (h.getMethod().equals(httpMethod)) {
handshake = h;
break;
}
return false;
}
if (HttpMethod.GET == httpMethod) {

if (handshake == OpeningHandshake.RFC6455) {
if (!"WebSocket".equalsIgnoreCase(headers.getUpgrade())) {
handleInvalidUpgradeHeader(request, response);
return false;
Expand All @@ -232,16 +232,6 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
handleInvalidConnectHeader(request, response);
return false;
}
}
if (!isWebSocketVersionSupported(headers)) {
handleWebSocketVersionNotSupported(request, response);
return false;
}
if (!isValidOrigin(request)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
if (HttpMethod.GET == httpMethod) {
String wsKey = headers.getSecWebSocketKey();
if (wsKey == null) {
if (logger.isErrorEnabled()) {
Expand All @@ -250,6 +240,24 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
response.setStatusCode(HttpStatus.BAD_REQUEST);
return false;
}
} else if (handshake == null) {
response.setStatusCode(HttpStatus.METHOD_NOT_ALLOWED);
Set<HttpMethod> methods = supportedHandshakes.stream()
.map(OpeningHandshake::getMethod)
.collect(Collectors.toUnmodifiableSet());
response.getHeaders().setAllow(methods);
if (logger.isErrorEnabled()) {
logger.error("Handshake failed due to unexpected HTTP method: " + httpMethod);
}
return false;
}
if (!isWebSocketVersionSupported(headers)) {
handleWebSocketVersionNotSupported(request, response);
return false;
}
if (!isValidOrigin(request)) {
response.setStatusCode(HttpStatus.FORBIDDEN);
return false;
}
}
catch (IOException ex) {
Expand Down

0 comments on commit 6959377

Please sign in to comment.