Skip to content

Allow custom fetch in SSEClientTransport and StreamableHTTPClientTransport #721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,38 @@ describe("SSEClientTransport", () => {
expect(lastServerRequest.headers.authorization).toBe(authToken);
});

it("uses custom fetch implementation from options", async () => {
const authToken = "Bearer custom-token";

const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return fetch(url.toString(), { ...init, headers });
});

transport = new SSEClientTransport(resourceBaseUrl, {
fetch: fetchWithAuth,
});

await transport.start();

expect(lastServerRequest.headers.authorization).toBe(authToken);

// Send a message to verify fetchWithAuth used for POST as well
const message: JSONRPCMessage = {
jsonrpc: "2.0",
id: "1",
method: "test",
params: {},
};

await transport.send(message);

expect(fetchWithAuth).toHaveBeenCalledTimes(2);
expect(lastServerRequest.method).toBe("POST");
expect(lastServerRequest.headers.authorization).toBe(authToken);
});

it("passes custom headers to fetch requests", async () => {
const customHeaders = {
Authorization: "Bearer test-token",
Expand Down
13 changes: 10 additions & 3 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
import { Transport } from "../shared/transport.js";
import { Transport, FetchLike } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";

Expand Down Expand Up @@ -47,6 +47,11 @@ export type SSEClientTransportOptions = {
* Customizes recurring POST requests to the server.
*/
requestInit?: RequestInit;

/**
* Custom fetch implementation used for all network requests.
*/
fetch?: FetchLike;
};

/**
Expand All @@ -62,6 +67,7 @@ export class SSEClientTransport implements Transport {
private _eventSourceInit?: EventSourceInit;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
private _fetch?: FetchLike;
private _protocolVersion?: string;

onclose?: () => void;
Expand All @@ -77,6 +83,7 @@ export class SSEClientTransport implements Transport {
this._eventSourceInit = opts?.eventSourceInit;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._fetch = opts?.fetch;
}

private async _authThenStart(): Promise<void> {
Expand Down Expand Up @@ -117,7 +124,7 @@ export class SSEClientTransport implements Transport {
}

private _startOrAuth(): Promise<void> {
const fetchImpl = (this?._eventSourceInit?.fetch || fetch) as typeof fetch
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
Expand Down Expand Up @@ -242,7 +249,7 @@ export class SSEClientTransport implements Transport {
signal: this._abortController?.signal,
};

const response = await fetch(this._endpoint, init);
const response = await (this._fetch ?? fetch)(this._endpoint, init);
if (!response.ok) {
if (response.status === 401 && this._authProvider) {

Expand Down
33 changes: 31 additions & 2 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js";
import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions, StartSSEOptions } from "./streamableHttp.js";
import { OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { JSONRPCMessage } from "../types.js";

Expand Down Expand Up @@ -443,6 +443,35 @@ describe("StreamableHTTPClientTransport", () => {
expect(errorSpy).toHaveBeenCalled();
});

it("uses custom fetch implementation", async () => {
const authToken = "Bearer custom-token";

const fetchWithAuth = jest.fn((url: string | URL, init?: RequestInit) => {
const headers = new Headers(init?.headers);
headers.set("Authorization", authToken);
return (global.fetch as jest.Mock)(url, { ...init, headers });
});

(global.fetch as jest.Mock)
.mockResolvedValueOnce(
new Response(null, { status: 200, headers: { "content-type": "text/event-stream" } })
)
.mockResolvedValueOnce(new Response(null, { status: 202 }));

transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { fetch: fetchWithAuth });

await transport.start();
await (transport as unknown as { _startOrAuthSse: (opts: StartSSEOptions) => Promise<void> })._startOrAuthSse({});

await transport.send({ jsonrpc: "2.0", method: "test", params: {}, id: "1" } as JSONRPCMessage);

expect(fetchWithAuth).toHaveBeenCalled();
for (const call of (global.fetch as jest.Mock).mock.calls) {
const headers = call[1].headers as Headers;
expect(headers.get("Authorization")).toBe(authToken);
}
});


it("should always send specified custom headers", async () => {
const requestInit = {
Expand Down Expand Up @@ -530,7 +559,7 @@ describe("StreamableHTTPClientTransport", () => {
// Second retry - should double (2^1 * 100 = 200)
expect(getDelay(1)).toBe(200);

// Third retry - should double again (2^2 * 100 = 400)
// Third retry - should double again (2^2 * 100 = 400)
expect(getDelay(2)).toBe(400);

// Fourth retry - should double again (2^3 * 100 = 800)
Expand Down
23 changes: 15 additions & 8 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Transport } from "../shared/transport.js";
import { Transport, FetchLike } from "../shared/transport.js";
import { isInitializedNotification, isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, extractResourceMetadataUrl, OAuthClientProvider, UnauthorizedError } from "./auth.js";
import { EventSourceParserStream } from "eventsource-parser/stream";
Expand All @@ -23,7 +23,7 @@ export class StreamableHTTPError extends Error {
/**
* Options for starting or authenticating an SSE connection
*/
interface StartSSEOptions {
export interface StartSSEOptions {
/**
* The resumption token used to continue long-running requests that were interrupted.
*
Expand Down Expand Up @@ -99,6 +99,11 @@ export type StreamableHTTPClientTransportOptions = {
*/
requestInit?: RequestInit;

/**
* Custom fetch implementation used for all network requests.
*/
fetch?: FetchLike;

/**
* Options to configure the reconnection behavior.
*/
Expand All @@ -122,6 +127,7 @@ export class StreamableHTTPClientTransport implements Transport {
private _resourceMetadataUrl?: URL;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
private _fetch?: FetchLike;
private _sessionId?: string;
private _reconnectionOptions: StreamableHTTPReconnectionOptions;
private _protocolVersion?: string;
Expand All @@ -138,6 +144,7 @@ export class StreamableHTTPClientTransport implements Transport {
this._resourceMetadataUrl = undefined;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
this._fetch = opts?.fetch;
this._sessionId = opts?.sessionId;
this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS;
}
Expand Down Expand Up @@ -200,7 +207,7 @@ export class StreamableHTTPClientTransport implements Transport {
headers.set("last-event-id", resumptionToken);
}

const response = await fetch(this._url, {
const response = await (this._fetch ?? fetch)(this._url, {
method: "GET",
headers,
signal: this._abortController?.signal,
Expand Down Expand Up @@ -251,15 +258,15 @@ export class StreamableHTTPClientTransport implements Transport {

private _normalizeHeaders(headers: HeadersInit | undefined): Record<string, string> {
if (!headers) return {};

if (headers instanceof Headers) {
return Object.fromEntries(headers.entries());
}

if (Array.isArray(headers)) {
return Object.fromEntries(headers);
}

return { ...headers as Record<string, string> };
}

Expand Down Expand Up @@ -414,7 +421,7 @@ export class StreamableHTTPClientTransport implements Transport {
signal: this._abortController?.signal,
};

const response = await fetch(this._url, init);
const response = await (this._fetch ?? fetch)(this._url, init);

// Handle session ID received during initialization
const sessionId = response.headers.get("mcp-session-id");
Expand Down Expand Up @@ -520,7 +527,7 @@ export class StreamableHTTPClientTransport implements Transport {
signal: this._abortController?.signal,
};

const response = await fetch(this._url, init);
const response = await (this._fetch ?? fetch)(this._url, init);

// We specifically handle 405 as a valid response according to the spec,
// meaning the server does not support explicit session termination
Expand Down
10 changes: 6 additions & 4 deletions src/shared/transport.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { JSONRPCMessage, MessageExtraInfo, RequestId } from "../types.js";

export type FetchLike = (url: string | URL, init?: RequestInit) => Promise<Response>;

/**
* Options for sending a JSON-RPC message.
*/
export type TransportSendOptions = {
/**
/**
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
*/
relatedRequestId?: RequestId;
Expand Down Expand Up @@ -38,7 +40,7 @@ export interface Transport {

/**
* Sends a JSON-RPC message (request or response).
*
*
* If present, `relatedRequestId` is used to indicate to the transport which incoming request to associate this outgoing message with.
*/
send(message: JSONRPCMessage, options?: TransportSendOptions): Promise<void>;
Expand All @@ -64,9 +66,9 @@ export interface Transport {

/**
* Callback for when a message (request or response) is received over the connection.
*
*
* Includes the requestInfo and authInfo if the transport is authenticated.
*
*
* The requestInfo can be used to get the original request information (headers, etc.)
*/
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;
Expand Down