Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ out

# Stores VSCode versions used for testing VSCode extensions
.vscode-test
.vscode/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure this belongs here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to make the .vscode folder untracked by Git.


# yarn v2
.yarn/cache
Expand Down
31 changes: 28 additions & 3 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createServer, type IncomingMessage, type Server } from "http";
import { createServer, IncomingMessage, Server, ServerResponse } from "http";
import { AddressInfo } from "net";
import { JSONRPCMessage } from "../types.js";
import { SSEClientTransport } from "./sse.js";
Expand All @@ -10,8 +10,21 @@ describe("SSEClientTransport", () => {
let transport: SSEClientTransport;
let baseUrl: URL;
let lastServerRequest: IncomingMessage;
const serverRequests: Record<string, IncomingMessage[]> = {};
let sendServerMessage: ((message: string) => void) | null = null;

const recordServerRequest = (req: IncomingMessage, res: ServerResponse) => {
lastServerRequest = req;

const key = `${req.method} ${req.url}`;
serverRequests[key] = serverRequests[key] || [];
serverRequests[key].push(req);

res.on('finish', () => {
console.log(`[server] ${req.method} ${req.url} -> ${res.statusCode} ${res.statusMessage}`);
});
};

beforeEach((done) => {
// Reset state
lastServerRequest = null as unknown as IncomingMessage;
Expand Down Expand Up @@ -487,7 +500,7 @@ describe("SSEClientTransport", () => {

let connectionAttempts = 0;
server = createServer((req, res) => {
lastServerRequest = req;
recordServerRequest(req, res);

if (req.url === "/token" && req.method === "POST") {
// Handle token refresh request
Expand All @@ -496,7 +509,7 @@ describe("SSEClientTransport", () => {
req.on("end", () => {
const params = new URLSearchParams(body);
if (params.get("grant_type") === "refresh_token" &&
params.get("refresh_token") === "refresh-token" &&
params.get("refresh_token")?.includes("refresh-token") &&
params.get("client_id") === "test-client-id" &&
params.get("client_secret") === "test-client-secret") {
res.writeHead(200, { "Content-Type": "application/json" });
Expand Down Expand Up @@ -531,6 +544,7 @@ describe("SSEClientTransport", () => {
});
res.write("event: endpoint\n");
res.write(`data: ${baseUrl.href}\n\n`);
res.end();
connectionAttempts++;
return;
}
Expand All @@ -548,6 +562,14 @@ describe("SSEClientTransport", () => {

transport = new SSEClientTransport(baseUrl, {
authProvider: mockAuthProvider,
eventSourceInit: {
fetch: (url, init) => {
return fetch(url, { ...init, headers: {
...init?.headers,
'X-Custom-Header': 'custom-value'
} });
}
},
});

await transport.start();
Expand All @@ -559,6 +581,9 @@ describe("SSEClientTransport", () => {
});
expect(connectionAttempts).toBe(1);
expect(lastServerRequest.headers.authorization).toBe("Bearer new-token");
expect(serverRequests["GET /"]).toHaveLength(2);
expect(serverRequests["GET /"]
.every(req => req.headers["x-custom-header"] === "custom-value")).toBe(true);
});

it("refreshes expired token during POST request", async () => {
Expand Down
58 changes: 40 additions & 18 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,6 @@ export type SSEClientTransportOptions = {

/**
* Customizes the initial SSE request to the server (the request that begins the stream).
*
* NOTE: Setting this property will prevent an `Authorization` header from
* being automatically attached to the SSE request, if an `authProvider` is
* also given. This can be worked around by setting the `Authorization` header
* manually.
*/
eventSourceInit?: EventSourceInit;

Expand Down Expand Up @@ -96,7 +91,7 @@ export class SSEClientTransport implements Transport {
return await this._startOrAuth();
}

private async _commonHeaders(): Promise<HeadersInit> {
private async _commonHeaders(): Promise<Record<string, string>> {
const headers: HeadersInit = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
Expand All @@ -110,18 +105,7 @@ export class SSEClientTransport implements Transport {

private _startOrAuth(): Promise<void> {
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
this._eventSourceInit ?? {
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
...init,
headers: {
...headers,
Accept: "text/event-stream"
}
})),
},
);
this._eventSource = new EventSource(this._url.href, this._getEventSourceInit());
this._abortController = new AbortController();

this._eventSource.onerror = (event) => {
Expand Down Expand Up @@ -175,6 +159,44 @@ export class SSEClientTransport implements Transport {
});
}

private _getEventSourceInit(): EventSourceInit {
let eventSourceInit: EventSourceInit;

if (this._eventSourceInit) {
const originalFetch = this._eventSourceInit.fetch;

if (originalFetch && this._authProvider) {
// merge the new headers with the existing headers
eventSourceInit = {
...this._eventSourceInit,
fetch: async (url, init) => {
const newHeaders: Record<string, string> = await this._commonHeaders();
return originalFetch(url, {
...init,
headers: {
...newHeaders,
...init?.headers
}
});
}
};
} else {
eventSourceInit = this._eventSourceInit;
}
} else {
eventSourceInit = {
fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, {
...init,
headers: {
...headers,
Accept: "text/event-stream"
}
})),
};
}
return eventSourceInit;
}

async start() {
if (this._eventSource) {
throw new Error(
Expand Down