Skip to content

Refactor how token refreshing works to be more resilient #4819

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion spec/unit/http-api/fetch.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ describe("FetchHttpApi", () => {
accessToken,
refreshToken,
});
const result = await api.authedRequest(Method.Post, "/account/password");
const result = await api.authedRequest(Method.Post, "/account/password", undefined, undefined, {
headers: {},
});
expect(result).toEqual(okayResponse);
expect(tokenRefreshFunction).toHaveBeenCalledWith(refreshToken);

Expand All @@ -372,6 +374,7 @@ describe("FetchHttpApi", () => {
const tokenRefreshFunction = jest.fn().mockResolvedValue({
accessToken: newAccessToken,
refreshToken: newRefreshToken,
expiry: new Date(Date.now() + 1000),
});

// fetch doesn't like our new or old tokens
Expand Down
50 changes: 30 additions & 20 deletions spec/unit/oidc/tokenRefresher.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,12 @@ describe("OidcTokenRefresher", () => {
method: "POST",
});

expect(result).toEqual({
accessToken: "new-access-token",
refreshToken: "new-refresh-token",
});
expect(result).toEqual(
expect.objectContaining({
accessToken: "new-access-token",
refreshToken: "new-refresh-token",
}),
);
});

it("should persist the new tokens", async () => {
Expand All @@ -144,10 +146,12 @@ describe("OidcTokenRefresher", () => {

await refresher.doRefreshAccessToken("refresh-token");

expect(refresher.persistTokens).toHaveBeenCalledWith({
accessToken: "new-access-token",
refreshToken: "new-refresh-token",
});
expect(refresher.persistTokens).toHaveBeenCalledWith(
expect.objectContaining({
accessToken: "new-access-token",
refreshToken: "new-refresh-token",
}),
);
});

it("should only have one inflight refresh request at once", async () => {
Expand Down Expand Up @@ -189,21 +193,25 @@ describe("OidcTokenRefresher", () => {

// only one call to token endpoint
expect(fetchMock).toHaveFetchedTimes(1, config.token_endpoint);
expect(result1).toEqual({
accessToken: "first-new-access-token",
refreshToken: "first-new-refresh-token",
});
expect(result1).toEqual(
expect.objectContaining({
accessToken: "first-new-access-token",
refreshToken: "first-new-refresh-token",
}),
);
// same response
expect(result1).toEqual(result2);

// call again after first request resolves
const third = await refresher.doRefreshAccessToken("first-new-refresh-token");

// called token endpoint, got new tokens
expect(third).toEqual({
accessToken: "second-new-access-token",
refreshToken: "second-new-refresh-token",
});
expect(third).toEqual(
expect.objectContaining({
accessToken: "second-new-access-token",
refreshToken: "second-new-refresh-token",
}),
);
});

it("should log and rethrow when token refresh fails", async () => {
Expand Down Expand Up @@ -261,10 +269,12 @@ describe("OidcTokenRefresher", () => {
const result = await refresher.doRefreshAccessToken("first-new-refresh-token");

// called token endpoint, got new tokens
expect(result).toEqual({
accessToken: "second-new-access-token",
refreshToken: "second-new-refresh-token",
});
expect(result).toEqual(
expect.objectContaining({
accessToken: "second-new-access-token",
refreshToken: "second-new-refresh-token",
}),
);
});

it("should throw TokenRefreshLogoutError when expired", async () => {
Expand Down
111 changes: 35 additions & 76 deletions src/http-api/fetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ limitations under the License.
* This is an internal module. See {@link MatrixHttpApi} for the public class.
*/

import { checkObjectHasKeys, encodeParams } from "../utils.ts";
import { checkObjectHasKeys, deepCopy, encodeParams } from "../utils.ts";
import { type TypedEventEmitter } from "../models/typed-event-emitter.ts";
import { Method } from "./method.ts";
import { ConnectionError, MatrixError, TokenRefreshError, TokenRefreshLogoutError } from "./errors.ts";
import { ConnectionError, MatrixError, TokenRefreshError } from "./errors.ts";
import {
HttpApiEvent,
type HttpApiEventHandlerMap,
Expand All @@ -31,7 +31,7 @@ import {
} from "./interface.ts";
import { anySignal, parseErrorResponse, timeoutSignal } from "./utils.ts";
import { type QueryDict } from "../utils.ts";
import { singleAsyncExecution } from "../utils/decorators.ts";
import { TokenRefresher, TokenRefreshOutcome } from "./refresh.ts";

interface TypedResponse<T> extends Response {
json(): Promise<T>;
Expand All @@ -43,14 +43,9 @@ export type ResponseType<T, O extends IHttpOpts> = O extends { json: false }
? T
: TypedResponse<T>;

const enum TokenRefreshOutcome {
Success = "success",
Failure = "failure",
Logout = "logout",
}

export class FetchHttpApi<O extends IHttpOpts> {
private abortController = new AbortController();
private readonly tokenRefresher: TokenRefresher;

public constructor(
private eventEmitter: TypedEventEmitter<HttpApiEvent, HttpApiEventHandlerMap>,
Expand All @@ -59,6 +54,8 @@ export class FetchHttpApi<O extends IHttpOpts> {
checkObjectHasKeys(opts, ["baseUrl", "prefix"]);
opts.onlyData = !!opts.onlyData;
opts.useAuthorizationHeader = opts.useAuthorizationHeader ?? true;

this.tokenRefresher = new TokenRefresher(opts);
}

public abort(): void {
Expand Down Expand Up @@ -113,12 +110,6 @@ export class FetchHttpApi<O extends IHttpOpts> {
return this.requestOtherUrl(method, fullUri, body, opts);
}

/**
* Promise used to block authenticated requests during a token refresh to avoid repeated expected errors.
* @private
*/
private tokenRefreshPromise?: Promise<unknown>;

/**
* Perform an authorised request to the homeserver.
* @param method - The HTTP method e.g. "GET".
Expand Down Expand Up @@ -146,36 +137,45 @@ export class FetchHttpApi<O extends IHttpOpts> {
* @returns Rejects with an error if a problem occurred.
* This includes network problems and Matrix-specific error JSON.
*/
public async authedRequest<T>(
public authedRequest<T>(
method: Method,
path: string,
queryParams?: QueryDict,
queryParams: QueryDict = {},
body?: Body,
paramOpts: IRequestOpts & { doNotAttemptTokenRefresh?: boolean } = {},
paramOpts: IRequestOpts = {},
): Promise<ResponseType<T, O>> {
if (!queryParams) queryParams = {};
return this.doAuthedRequest<T>(1, method, path, queryParams, body, paramOpts);
}

// Wrapper around public method authedRequest to allow for tracking retry attempt counts
private async doAuthedRequest<T>(
attempt: number,
method: Method,
path: string,
queryParams: QueryDict,
body?: Body,
paramOpts: IRequestOpts = {},
): Promise<ResponseType<T, O>> {
// avoid mutating paramOpts so they can be used on retry
const opts = { ...paramOpts };

// Await any ongoing token refresh before we build the headers/params
await this.tokenRefreshPromise;
const opts = deepCopy(paramOpts);
// we have to manually copy the abortSignal over as it is not a plain object
opts.abortSignal = paramOpts.abortSignal;

// Take a copy of the access token so we have a record of the token we used for this request if it fails
const accessToken = this.opts.accessToken;
if (accessToken) {
// Take a snapshot of the current token state before we start the request so we can reference it if we error
const requestSnapshot = await this.tokenRefresher.prepareForRequest();
if (requestSnapshot.accessToken) {
if (this.opts.useAuthorizationHeader) {
if (!opts.headers) {
opts.headers = {};
}
if (!opts.headers.Authorization) {
opts.headers.Authorization = `Bearer ${accessToken}`;
opts.headers.Authorization = `Bearer ${requestSnapshot.accessToken}`;
}
if (queryParams.access_token) {
delete queryParams.access_token;
}
} else if (!queryParams.access_token) {
queryParams.access_token = accessToken;
queryParams.access_token = requestSnapshot.accessToken;
}
}

Expand All @@ -187,33 +187,19 @@ export class FetchHttpApi<O extends IHttpOpts> {
throw error;
}

if (error.errcode === "M_UNKNOWN_TOKEN" && !opts.doNotAttemptTokenRefresh) {
// If the access token has changed since we started the request, but before we refreshed it,
// then it was refreshed due to another request failing, so retry before refreshing again.
let outcome: TokenRefreshOutcome | null = null;
if (accessToken === this.opts.accessToken) {
const tokenRefreshPromise = this.tryRefreshToken();
this.tokenRefreshPromise = tokenRefreshPromise;
outcome = await tokenRefreshPromise;
}

if (outcome === TokenRefreshOutcome.Success || outcome === null) {
if (error.errcode === "M_UNKNOWN_TOKEN") {
const outcome = await this.tokenRefresher.handleUnknownToken(requestSnapshot, attempt);
if (outcome === TokenRefreshOutcome.Success) {
// if we got a new token retry the request
return this.authedRequest(method, path, queryParams, body, {
...paramOpts,
// Only attempt token refresh once for each failed request
doNotAttemptTokenRefresh: outcome !== null,
});
return this.doAuthedRequest(attempt + 1, method, path, queryParams, body, paramOpts);
}
if (outcome === TokenRefreshOutcome.Failure) {
throw new TokenRefreshError(error);
}
// Fall through to SessionLoggedOut handler below
}

// otherwise continue with error handling
if (error.errcode == "M_UNKNOWN_TOKEN" && !opts?.inhibitLogoutEmit) {
this.eventEmitter.emit(HttpApiEvent.SessionLoggedOut, error);
if (!opts?.inhibitLogoutEmit) {
this.eventEmitter.emit(HttpApiEvent.SessionLoggedOut, error);
}
} else if (error.errcode == "M_CONSENT_NOT_GIVEN") {
this.eventEmitter.emit(HttpApiEvent.NoConsent, error.message, error.data.consent_uri);
}
Expand All @@ -222,33 +208,6 @@ export class FetchHttpApi<O extends IHttpOpts> {
}
}

/**
* Attempt to refresh access tokens.
* On success, sets new access and refresh tokens in opts.
* @returns Promise that resolves to a boolean - true when token was refreshed successfully
*/
@singleAsyncExecution
private async tryRefreshToken(): Promise<TokenRefreshOutcome> {
if (!this.opts.refreshToken || !this.opts.tokenRefreshFunction) {
return TokenRefreshOutcome.Logout;
}

try {
const { accessToken, refreshToken } = await this.opts.tokenRefreshFunction(this.opts.refreshToken);
this.opts.accessToken = accessToken;
this.opts.refreshToken = refreshToken;
// successfully got new tokens
return TokenRefreshOutcome.Success;
} catch (error) {
this.opts.logger?.warn("Failed to refresh token", error);
// If we get a TokenError or MatrixError, we should log out, otherwise assume transient
if (error instanceof TokenRefreshLogoutError || error instanceof MatrixError) {
return TokenRefreshOutcome.Logout;
}
return TokenRefreshOutcome.Failure;
}
}

/**
* Perform a request to the homeserver without any credentials.
* @param method - The HTTP method e.g. "GET".
Expand Down
11 changes: 11 additions & 0 deletions src/http-api/interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,20 @@ export type Body = Record<string, any> | BodyInit;
* Unencrypted access and (optional) refresh token
*/
export type AccessTokens = {
/**
* The new access token to use for authenticated requests
*/
accessToken: string;
/**
* The new refresh token to use for refreshing tokens, optional
*/
refreshToken?: string;
/**
* Approximate date when the access token will expire, optional
*/
expiry?: Date;
};

/**
* @experimental
* Function that performs token refresh using the given refreshToken.
Expand Down
Loading