Skip to content

Commit 9ac7165

Browse files
authored
Handle outgoing requests from rust crypto SDK (#3019)
The rust matrix-sdk-crypto has an `outgoingRequests()` method which we need to poll, and make the requested requests.
1 parent 6168ced commit 9ac7165

File tree

4 files changed

+284
-12
lines changed

4 files changed

+284
-12
lines changed

spec/unit/rust-crypto.spec.ts

+164-5
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,21 @@ limitations under the License.
1616

1717
import "fake-indexeddb/auto";
1818
import { IDBFactory } from "fake-indexeddb";
19+
import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js";
20+
import {
21+
KeysBackupRequest,
22+
KeysClaimRequest,
23+
KeysQueryRequest,
24+
KeysUploadRequest,
25+
SignatureUploadRequest,
26+
} from "@matrix-org/matrix-sdk-crypto-js";
27+
import { Mocked } from "jest-mock";
28+
import MockHttpBackend from "matrix-mock-request";
1929

2030
import { RustCrypto } from "../../src/rust-crypto/rust-crypto";
2131
import { initRustCrypto } from "../../src/rust-crypto";
32+
import { HttpApiEvent, HttpApiEventHandlerMap, IHttpOpts, MatrixHttpApi } from "../../src";
33+
import { TypedEventEmitter } from "../../src/models/typed-event-emitter";
2234

2335
afterEach(() => {
2436
// reset fake-indexeddb after each test, to make sure we don't leak connections
@@ -31,16 +43,163 @@ describe("RustCrypto", () => {
3143
const TEST_USER = "@alice:example.com";
3244
const TEST_DEVICE_ID = "TEST_DEVICE";
3345

34-
let rustCrypto: RustCrypto;
46+
describe(".exportRoomKeys", () => {
47+
let rustCrypto: RustCrypto;
3548

36-
beforeEach(async () => {
37-
rustCrypto = (await initRustCrypto(TEST_USER, TEST_DEVICE_ID)) as RustCrypto;
38-
});
49+
beforeEach(async () => {
50+
const mockHttpApi = {} as MatrixHttpApi<IHttpOpts>;
51+
rustCrypto = (await initRustCrypto(mockHttpApi, TEST_USER, TEST_DEVICE_ID)) as RustCrypto;
52+
});
3953

40-
describe(".exportRoomKeys", () => {
4154
it("should return a list", async () => {
4255
const keys = await rustCrypto.exportRoomKeys();
4356
expect(Array.isArray(keys)).toBeTruthy();
4457
});
4558
});
59+
60+
describe("outgoing requests", () => {
61+
/** the RustCrypto implementation under test */
62+
let rustCrypto: RustCrypto;
63+
64+
/** A mock http backend which rustCrypto is connected to */
65+
let httpBackend: MockHttpBackend;
66+
67+
/** a mocked-up OlmMachine which rustCrypto is connected to */
68+
let olmMachine: Mocked<RustSdkCryptoJs.OlmMachine>;
69+
70+
/** A list of results to be returned from olmMachine.outgoingRequest. Each call will shift a result off
71+
* the front of the queue, until it is empty. */
72+
let outgoingRequestQueue: Array<Array<any>>;
73+
74+
/** wait for a call to olmMachine.markRequestAsSent */
75+
function awaitCallToMarkAsSent(): Promise<void> {
76+
return new Promise((resolve, _reject) => {
77+
olmMachine.markRequestAsSent.mockImplementationOnce(async () => {
78+
resolve(undefined);
79+
});
80+
});
81+
}
82+
83+
beforeEach(async () => {
84+
httpBackend = new MockHttpBackend();
85+
86+
await RustSdkCryptoJs.initAsync();
87+
88+
const dummyEventEmitter = new TypedEventEmitter<HttpApiEvent, HttpApiEventHandlerMap>();
89+
const httpApi = new MatrixHttpApi(dummyEventEmitter, {
90+
baseUrl: "https://example.com",
91+
prefix: "/_matrix",
92+
fetchFn: httpBackend.fetchFn as typeof global.fetch,
93+
});
94+
95+
// for these tests we use a mock OlmMachine, with an implementation of outgoingRequests that
96+
// returns objects from outgoingRequestQueue
97+
outgoingRequestQueue = [];
98+
olmMachine = {
99+
outgoingRequests: jest.fn().mockImplementation(() => {
100+
return Promise.resolve(outgoingRequestQueue.shift() ?? []);
101+
}),
102+
markRequestAsSent: jest.fn(),
103+
close: jest.fn(),
104+
} as unknown as Mocked<RustSdkCryptoJs.OlmMachine>;
105+
106+
rustCrypto = new RustCrypto(olmMachine, httpApi, TEST_USER, TEST_DEVICE_ID);
107+
});
108+
109+
it("should poll for outgoing messages", () => {
110+
rustCrypto.onSyncCompleted({});
111+
expect(olmMachine.outgoingRequests).toHaveBeenCalled();
112+
});
113+
114+
/* simple requests that map directly to the request body */
115+
const tests: Array<[any, "POST" | "PUT", string]> = [
116+
[KeysUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/upload"],
117+
[KeysQueryRequest, "POST", "https://example.com/_matrix/client/v3/keys/query"],
118+
[KeysClaimRequest, "POST", "https://example.com/_matrix/client/v3/keys/claim"],
119+
[SignatureUploadRequest, "POST", "https://example.com/_matrix/client/v3/keys/signatures/upload"],
120+
[KeysBackupRequest, "PUT", "https://example.com/_matrix/client/v3/room_keys/keys"],
121+
];
122+
123+
for (const [RequestClass, expectedMethod, expectedPath] of tests) {
124+
it(`should handle ${RequestClass.name}s`, async () => {
125+
const testBody = '{ "foo": "bar" }';
126+
const outgoingRequest = new RequestClass("1234", testBody);
127+
outgoingRequestQueue.push([outgoingRequest]);
128+
129+
const testResponse = '{ "result": 1 }';
130+
httpBackend
131+
.when(expectedMethod, "/_matrix")
132+
.check((req) => {
133+
expect(req.path).toEqual(expectedPath);
134+
expect(req.rawData).toEqual(testBody);
135+
expect(req.headers["Accept"]).toEqual("application/json");
136+
expect(req.headers["Content-Type"]).toEqual("application/json");
137+
})
138+
.respond(200, testResponse, true);
139+
140+
rustCrypto.onSyncCompleted({});
141+
142+
expect(olmMachine.outgoingRequests).toHaveBeenCalledTimes(1);
143+
144+
const markSentCallPromise = awaitCallToMarkAsSent();
145+
await httpBackend.flushAllExpected();
146+
147+
await markSentCallPromise;
148+
expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("1234", outgoingRequest.type, testResponse);
149+
httpBackend.verifyNoOutstandingRequests();
150+
});
151+
}
152+
153+
it("does not explode with unknown requests", async () => {
154+
const outgoingRequest = { id: "5678", type: 987 };
155+
outgoingRequestQueue.push([outgoingRequest]);
156+
157+
rustCrypto.onSyncCompleted({});
158+
159+
await awaitCallToMarkAsSent();
160+
expect(olmMachine.markRequestAsSent).toHaveBeenCalledWith("5678", 987, "");
161+
});
162+
163+
it("stops looping when stop() is called", async () => {
164+
const testResponse = '{ "result": 1 }';
165+
166+
for (let i = 0; i < 5; i++) {
167+
outgoingRequestQueue.push([new KeysQueryRequest("1234", "{}")]);
168+
httpBackend.when("POST", "/_matrix").respond(200, testResponse, true);
169+
}
170+
171+
rustCrypto.onSyncCompleted({});
172+
173+
expect(rustCrypto["outgoingRequestLoopRunning"]).toBeTruthy();
174+
175+
// go a couple of times round the loop
176+
await httpBackend.flush("/_matrix", 1);
177+
await awaitCallToMarkAsSent();
178+
179+
await httpBackend.flush("/_matrix", 1);
180+
await awaitCallToMarkAsSent();
181+
182+
// a second sync while this is going on shouldn't make any difference
183+
rustCrypto.onSyncCompleted({});
184+
185+
await httpBackend.flush("/_matrix", 1);
186+
await awaitCallToMarkAsSent();
187+
188+
// now stop...
189+
rustCrypto.stop();
190+
191+
// which should (eventually) cause the loop to stop with no further calls to outgoingRequests
192+
olmMachine.outgoingRequests.mockReset();
193+
194+
await new Promise((resolve) => {
195+
setTimeout(resolve, 100);
196+
});
197+
expect(rustCrypto["outgoingRequestLoopRunning"]).toBeFalsy();
198+
httpBackend.verifyNoOutstandingRequests();
199+
expect(olmMachine.outgoingRequests).not.toHaveBeenCalled();
200+
201+
// we sent three, so there should be 2 left
202+
expect(outgoingRequestQueue.length).toEqual(2);
203+
});
204+
});
46205
});

src/client.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -2148,7 +2148,7 @@ export class MatrixClient extends TypedEventEmitter<EmittedEvents, ClientEventHa
21482148
// importing rust-crypto will download the webassembly, so we delay it until we know it will be
21492149
// needed.
21502150
const RustCrypto = await import("./rust-crypto");
2151-
this.cryptoBackend = await RustCrypto.initRustCrypto(userId, deviceId);
2151+
this.cryptoBackend = await RustCrypto.initRustCrypto(this.http, userId, deviceId);
21522152
}
21532153

21542154
/**

src/rust-crypto/index.ts

+7-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,13 @@ import { RustCrypto } from "./rust-crypto";
2020
import { logger } from "../logger";
2121
import { CryptoBackend } from "../common-crypto/CryptoBackend";
2222
import { RUST_SDK_STORE_PREFIX } from "./constants";
23+
import { IHttpOpts, MatrixHttpApi } from "../http-api";
2324

24-
export async function initRustCrypto(userId: string, deviceId: string): Promise<CryptoBackend> {
25+
export async function initRustCrypto(
26+
http: MatrixHttpApi<IHttpOpts>,
27+
userId: string,
28+
deviceId: string,
29+
): Promise<CryptoBackend> {
2530
// initialise the rust matrix-sdk-crypto-js, if it hasn't already been done
2631
await RustSdkCryptoJs.initAsync();
2732

@@ -34,7 +39,7 @@ export async function initRustCrypto(userId: string, deviceId: string): Promise<
3439

3540
// TODO: use the pickle key for the passphrase
3641
const olmMachine = await RustSdkCryptoJs.OlmMachine.initialize(u, d, RUST_SDK_STORE_PREFIX, "test pass");
37-
const rustCrypto = new RustCrypto(olmMachine, userId, deviceId);
42+
const rustCrypto = new RustCrypto(olmMachine, http, userId, deviceId);
3843

3944
logger.info("Completed rust crypto-sdk setup");
4045
return rustCrypto;

src/rust-crypto/rust-crypto.ts

+112-4
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,28 @@ limitations under the License.
1515
*/
1616

1717
import * as RustSdkCryptoJs from "@matrix-org/matrix-sdk-crypto-js";
18+
import {
19+
KeysBackupRequest,
20+
KeysClaimRequest,
21+
KeysQueryRequest,
22+
KeysUploadRequest,
23+
SignatureUploadRequest,
24+
} from "@matrix-org/matrix-sdk-crypto-js";
1825

1926
import type { IEventDecryptionResult, IMegolmSessionData } from "../@types/crypto";
2027
import { MatrixEvent } from "../models/event";
2128
import { CryptoBackend, OnSyncCompletedData } from "../common-crypto/CryptoBackend";
29+
import { logger } from "../logger";
30+
import { IHttpOpts, IRequestOpts, MatrixHttpApi, Method } from "../http-api";
31+
import { QueryDict } from "../utils";
2232

23-
// import { logger } from "../logger";
33+
/**
34+
* Common interface for all the request types returned by `OlmMachine.outgoingRequests`.
35+
*/
36+
interface OutgoingRequest {
37+
readonly id: string | undefined;
38+
readonly type: number;
39+
}
2440

2541
/**
2642
* An implementation of {@link CryptoBackend} using the Rust matrix-sdk-crypto.
@@ -29,10 +45,18 @@ export class RustCrypto implements CryptoBackend {
2945
public globalBlacklistUnverifiedDevices = false;
3046
public globalErrorOnUnknownDevices = false;
3147

32-
/** whether stop() has been called */
48+
/** whether {@link stop} has been called */
3349
private stopped = false;
3450

35-
public constructor(private readonly olmMachine: RustSdkCryptoJs.OlmMachine, _userId: string, _deviceId: string) {}
51+
/** whether {@link outgoingRequestLoop} is currently running */
52+
private outgoingRequestLoopRunning = false;
53+
54+
public constructor(
55+
private readonly olmMachine: RustSdkCryptoJs.OlmMachine,
56+
private readonly http: MatrixHttpApi<IHttpOpts>,
57+
_userId: string,
58+
_deviceId: string,
59+
) {}
3660

3761
public stop(): void {
3862
// stop() may be called multiple times, but attempting to close() the OlmMachine twice
@@ -63,11 +87,95 @@ export class RustCrypto implements CryptoBackend {
6387
return [];
6488
}
6589

90+
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
91+
//
92+
// SyncCryptoCallbacks implementation
93+
//
94+
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
95+
6696
/** called by the sync loop after processing each sync.
6797
*
6898
* TODO: figure out something equivalent for sliding sync.
6999
*
70100
* @param syncState - information on the completed sync.
71101
*/
72-
public onSyncCompleted(syncState: OnSyncCompletedData): void {}
102+
public onSyncCompleted(syncState: OnSyncCompletedData): void {
103+
// Processing the /sync may have produced new outgoing requests which need sending, so kick off the outgoing
104+
// request loop, if it's not already running.
105+
this.outgoingRequestLoop();
106+
}
107+
108+
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
109+
//
110+
// Outgoing requests
111+
//
112+
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
113+
114+
private async outgoingRequestLoop(): Promise<void> {
115+
if (this.outgoingRequestLoopRunning) {
116+
return;
117+
}
118+
this.outgoingRequestLoopRunning = true;
119+
try {
120+
while (!this.stopped) {
121+
const outgoingRequests: Object[] = await this.olmMachine.outgoingRequests();
122+
if (outgoingRequests.length == 0 || this.stopped) {
123+
// no more messages to send (or we have been told to stop): exit the loop
124+
return;
125+
}
126+
for (const msg of outgoingRequests) {
127+
await this.doOutgoingRequest(msg as OutgoingRequest);
128+
}
129+
}
130+
} catch (e) {
131+
logger.error("Error processing outgoing-message requests from rust crypto-sdk", e);
132+
} finally {
133+
this.outgoingRequestLoopRunning = false;
134+
}
135+
}
136+
137+
private async doOutgoingRequest(msg: OutgoingRequest): Promise<void> {
138+
let resp: string;
139+
140+
/* refer https://docs.rs/matrix-sdk-crypto/0.6.0/matrix_sdk_crypto/requests/enum.OutgoingRequests.html
141+
* for the complete list of request types
142+
*/
143+
if (msg instanceof KeysUploadRequest) {
144+
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/upload", {}, msg.body);
145+
} else if (msg instanceof KeysQueryRequest) {
146+
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/query", {}, msg.body);
147+
} else if (msg instanceof KeysClaimRequest) {
148+
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/claim", {}, msg.body);
149+
} else if (msg instanceof SignatureUploadRequest) {
150+
resp = await this.rawJsonRequest(Method.Post, "/_matrix/client/v3/keys/signatures/upload", {}, msg.body);
151+
} else if (msg instanceof KeysBackupRequest) {
152+
resp = await this.rawJsonRequest(Method.Put, "/_matrix/client/v3/room_keys/keys", {}, msg.body);
153+
} else {
154+
// TODO: ToDeviceRequest, RoomMessageRequest
155+
logger.warn("Unsupported outgoing message", Object.getPrototypeOf(msg));
156+
resp = "";
157+
}
158+
159+
if (msg.id) {
160+
await this.olmMachine.markRequestAsSent(msg.id, msg.type, resp);
161+
}
162+
}
163+
164+
private async rawJsonRequest(
165+
method: Method,
166+
path: string,
167+
queryParams: QueryDict,
168+
body: string,
169+
opts: IRequestOpts = {},
170+
): Promise<string> {
171+
// unbeknownst to HttpApi, we are sending JSON
172+
opts.headers ??= {};
173+
opts.headers["Content-Type"] = "application/json";
174+
175+
// we use the full prefix
176+
opts.prefix ??= "";
177+
178+
const resp = await this.http.authedRequest(method, path, queryParams, body, opts);
179+
return await resp.text();
180+
}
73181
}

0 commit comments

Comments
 (0)