Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};

use bytes::Bytes;
use futures::{StreamExt, future::BoxFuture};
use http::{Method, Request, Response, header::ALLOW};
use http::{HeaderMap, Method, Request, Response, Uri, header::ALLOW};
use http_body::Body;
use http_body_util::{BodyExt, Full, combinators::BoxBody};
use tokio_stream::wrappers::ReceiverStream;
Expand Down Expand Up @@ -48,6 +48,12 @@ pub struct StreamableHttpServerConfig {
/// When this token is cancelled, all active sessions are terminated and
/// the server stops accepting new requests.
pub cancellation_token: CancellationToken,
/// Allowed hostnames for inbound `Host` / `Origin` validation.
///
/// By default, Streamable HTTP servers only accept loopback hosts to
/// prevent DNS rebinding attacks against locally running servers. Public
/// deployments should override this list with their own hostnames.
pub allowed_hosts: Vec<String>,
Comment on lines +52 to +56
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the new pub allowed_hosts field to StreamableHttpServerConfig is a semver-breaking change for downstream users constructing the config via struct literals (they will now fail to compile). If this is intended, it should be paired with a major version bump / explicit release note; otherwise consider making the config #[non_exhaustive] and/or moving toward a constructor/builder pattern to avoid future breakage.

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is especially important given that PR #715 was recently merged to prepare for 1.0 stable release. StreamableHttpServerConfig may have been missed in that effort. Adding #[non_exhaustive] here would be consistent with that direction and prevent this class of breakage going forward.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this security update should be introduced in 1.0 version , and I will add the #[non_exhaustive].

}

impl Default for StreamableHttpServerConfig {
Expand All @@ -58,10 +64,26 @@ impl Default for StreamableHttpServerConfig {
stateful_mode: true,
json_response: false,
cancellation_token: CancellationToken::new(),
allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()],
}
}
}

impl StreamableHttpServerConfig {
pub fn with_allowed_hosts(
mut self,
allowed_hosts: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect();
self
}
/// Disable allowed hosts. This will allow requests with any `Host` or `Origin` header, which is NOT recommended for public deployments.
pub fn disable_allowed_hosts(mut self) -> Self {
self.allowed_hosts.clear();
self
}
}

#[expect(
clippy::result_large_err,
reason = "BoxResponse is intentionally large; matches other handlers in this file"
Expand Down Expand Up @@ -102,6 +124,99 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
Ok(())
}

fn forbidden_response(message: impl Into<String>) -> BoxResponse {
Response::builder()
.status(http::StatusCode::FORBIDDEN)
.body(Full::new(Bytes::from(message.into())).boxed())
.expect("valid response")
}

fn normalize_host(host: &str) -> String {
host.trim_matches('[')
.trim_matches(']')
.to_ascii_lowercase()
}
Comment on lines +134 to +138
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

normalize_host does not preserve the port, and parse_host_header uses authority.host(), which also strips the port. As a result, both localhost:8000 and localhost:9999 normalize to localhost and pass validation equally. This is significant because multiple services can run on different ports of the same host. Without port validation, a DNS rebinding attack could target a different service on localhost:9999, and the check would pass since localhost is in the allow list. Would you consider adding similar port-aware validation here?

Copy link
Member Author

@jokemanfire jokemanfire Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes , It should be considerd.


fn host_is_allowed(host: &str, allowed_hosts: &[String]) -> bool {
if allowed_hosts.is_empty() {
// If the allowed hosts list is empty, allow all hosts (not recommended).
return true;
}
let normalized = normalize_host(host);
allowed_hosts
.iter()
.any(|allowed| normalize_host(allowed) == normalized)
}

fn parse_host_header(headers: &HeaderMap) -> Result<Option<String>, BoxResponse> {
let Some(host) = headers.get(http::header::HOST) else {
return Ok(None);
};

let host = host
.to_str()
.map_err(|_| forbidden_response("Forbidden: Invalid Host header encoding"))?;
let authority = http::uri::Authority::try_from(host)
.map_err(|_| forbidden_response("Forbidden: Invalid Host header"))?;
Ok(Some(normalize_host(authority.host())))
}

fn parse_origin_host(headers: &HeaderMap) -> Result<Option<String>, BoxResponse> {
let Some(origin) = headers.get(http::header::ORIGIN) else {
return Ok(None);
};

let origin = origin
.to_str()
.map_err(|_| forbidden_response("Forbidden: Invalid Origin header encoding"))?;
if origin.eq_ignore_ascii_case("null") {
return Err(forbidden_response("Forbidden: Invalid Origin header"));
}

let uri: Uri = origin
.parse()
.map_err(|_| forbidden_response("Forbidden: Invalid Origin header"))?;
let Some(authority) = uri.authority() else {
return Err(forbidden_response("Forbidden: Invalid Origin header"));
};
let Some(scheme) = uri.scheme_str() else {
return Err(forbidden_response("Forbidden: Invalid Origin header"));
};
if !matches!(scheme, "http" | "https") {
return Err(forbidden_response("Forbidden: Invalid Origin header"));
}

Ok(Some(normalize_host(authority.host())))
}
Comment on lines +189 to +190
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse_host_header/parse_origin_host currently discard the port by extracting only authority.host(). That makes the later Host/Origin comparison effectively “hostname-only”, allowing Host: localhost:8080 with Origin: http://localhost:9999 to pass. Consider preserving and comparing the full authority (host + optional port), or at least requiring ports to match when present (accounting for default ports).

Copilot uses AI. Check for mistakes.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great


fn validate_dns_rebinding_headers(
headers: &HeaderMap,
config: &StreamableHttpServerConfig,
) -> Result<(), BoxResponse> {
let host = parse_host_header(headers)?;
if let Some(host) = host.as_deref() {
if !host_is_allowed(host, &config.allowed_hosts) {
return Err(forbidden_response("Forbidden: Host header is not allowed"));
}
}

let origin = parse_origin_host(headers)?;
if let Some(origin) = origin.as_deref() {
if !host_is_allowed(origin, &config.allowed_hosts) {
Copy link
Member

@DaleSeo DaleSeo Mar 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it intentional to checking Origin against allowed_hosts here? allowed_hosts answers "what hostnames is this server known as?" while Origin represents the source page that initiated the request. These are different concepts.

For example, consider an MCP server at mcp.example.com being called by a frontend at app.example.com:

allowed_hosts: ["mcp.example.com"]

Host: mcp.example.com           // ✅ this IS the server
Origin: http://app.example.com  // ❌ rejected, but this is a legitimate caller

From my understanding, restricting which origins can call the server is a CORS concern. For DNS rebinding protection, validating only the Host header should be sufficient.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well ,thanks ,I will take some time to resolve this.

return Err(forbidden_response(
"Forbidden: Origin header is not allowed",
));
}
if let Some(host) = host.as_deref() {
if origin != host {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Host header identifies the destination server while Origin identifies the source page. They will naturally differ in any cross-origin scenario.

Would you consider removing the Origin validation entirely from this DNS rebinding guard and leaving origin-based restrictions to CORS middleware where they belong?

return Err(forbidden_response("Forbidden: Origin does not match Host"));
}
}
}

Ok(())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to reject requests that are missing the Host header when allowed_hosts is non-empty?

TypeScript SDK's DNS rebinding guard rejects requests with missing headers as well: https://github.com/modelcontextprotocol/typescript-sdk/blob/main/packages/server/src/server/middleware/hostHeaderValidation.ts

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep same with ts.

}

/// # Streamable HTTP server
///
/// An HTTP service that implements the
Expand Down Expand Up @@ -251,6 +366,9 @@ where
B: Body + Send + 'static,
B::Error: Display,
{
if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) {
return response;
}
let method = request.method().clone();
let allowed_methods = match self.config.stateful_mode {
true => "GET, POST, DELETE",
Expand Down
83 changes: 83 additions & 0 deletions crates/rmcp/tests/test_custom_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,86 @@ fn test_protocol_version_utilities() {
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26));
assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18));
}

/// Integration test: Verify server validates Host and Origin headers for DNS rebinding protection
#[tokio::test]
#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))]
async fn test_server_validates_host_and_origin_headers() {
use std::sync::Arc;

use bytes::Bytes;
use http::{Method, Request, header::CONTENT_TYPE};
use http_body_util::Full;
use rmcp::{
handler::server::ServerHandler,
model::{ServerCapabilities, ServerInfo},
transport::streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
};
use serde_json::json;

#[derive(Clone)]
struct TestHandler;

impl ServerHandler for TestHandler {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().build())
}
}

let service = StreamableHttpService::new(
|| Ok(TestHandler),
Arc::new(LocalSessionManager::default()),
StreamableHttpServerConfig::default(),
);

let init_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-03-26",
"capabilities": {},
"clientInfo": {
"name": "test-client",
"version": "1.0.0"
}
}
});

let allowed_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
.header("Origin", "http://localhost:8080")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(allowed_request).await;
assert_eq!(response.status(), http::StatusCode::OK);

let bad_host_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "attacker.example")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(bad_host_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);

let bad_origin_request = Request::builder()
.method(Method::POST)
.header("Accept", "application/json, text/event-stream")
.header(CONTENT_TYPE, "application/json")
.header("Host", "localhost:8080")
Comment on lines +944 to +948
Copy link

Copilot AI Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a test case for the same-hostname/different-port scenario (e.g., Host: localhost:8080 with Origin: http://localhost:9999) and assert it is rejected. This guards against port-mismatch bypasses of the Host/Origin validation logic.

Copilot uses AI. Check for mistakes.
.header("Origin", "http://attacker.example")
.body(Full::new(Bytes::from(init_body.to_string())))
.unwrap();

let response = service.handle(bad_origin_request).await;
assert_eq!(response.status(), http::StatusCode::FORBIDDEN);
}
Loading