Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle empty Subscription DataReports correctly #248

Merged
merged 11 commits into from
Mar 11, 2023
10 changes: 5 additions & 5 deletions src/error/TryCatchHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/
import { ClassExtends } from "../util/Type";

type ErrorHandler<T> = (error: Error) => T;
type ErrorHandler<E extends Error, T> = (error: E) => T;

/**
* Try to execute the code block and catch the error if it is of the given type.
Expand All @@ -16,13 +16,13 @@ type ErrorHandler<T> = (error: Error) => T;
* @param errorType Errortype to catch and handle
* @param fallbackValueOrFunction Fallback value or function to compute the fallback value
*/
export function tryCatch<T>(codeBlock: () => T, errorType: ClassExtends<Error>, fallbackValueOrFunction: ErrorHandler<T> | T): T {
export function tryCatch<E extends Error, T>(codeBlock: () => T, errorType: ClassExtends<E>, fallbackValueOrFunction: ErrorHandler<E, T> | T): T {
try {
return codeBlock();
} catch (error) {
if (error instanceof errorType) {
if (typeof fallbackValueOrFunction === "function") {
return (fallbackValueOrFunction as ErrorHandler<T>)(error);
return (fallbackValueOrFunction as ErrorHandler<E, T>)(error);
} else {
return fallbackValueOrFunction;
}
Expand All @@ -40,13 +40,13 @@ export function tryCatch<T>(codeBlock: () => T, errorType: ClassExtends<Error>,
* @param errorType Errortype to catch and handle
* @param fallbackValueOrFunction Fallback value or function to compute the fallback value
*/
export async function tryCatchAsync<T>(codeBlock: () => Promise<T>, errorType: ClassExtends<Error>, fallbackValueOrFunction: ErrorHandler<T> | T): Promise<T> {
export async function tryCatchAsync<E extends Error, T>(codeBlock: () => Promise<T>, errorType: ClassExtends<E>, fallbackValueOrFunction: ErrorHandler<E, T> | T): Promise<T> {
try {
return await codeBlock();
} catch (error) {
if (error instanceof errorType) {
if (typeof fallbackValueOrFunction === "function") {
return (fallbackValueOrFunction as ErrorHandler<T>)(error);
return (fallbackValueOrFunction as ErrorHandler<E, T>)(error);
} else {
return fallbackValueOrFunction;
}
Expand Down
27 changes: 21 additions & 6 deletions src/matter/common/MessageExchange.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,19 @@ import { Logger } from "../../log/Logger";
import { NodeId } from "./NodeId";
import { ByteArray } from "@project-chip/matter.js";
import { SecureChannelProtocol } from "../session/secure/SecureChannelProtocol";
import { MatterError } from "../../error/MatterError";

const logger = Logger.get("MessageExchange");

export class UnexpectedMessageError extends MatterError {
public constructor(
message: string,
public readonly receivedMessage: Message,
) {
super(`(${MessageCodec.messageToString(receivedMessage)}) ${message}`);
}
}

/** The base number for the exponential backoff equation. */
const MRP_BACKOFF_BASE = 1.6;

Expand Down Expand Up @@ -146,8 +156,8 @@ export class MessageExchange<ContextT> {
}
} else {
// The other side has received our previous message
this.sentMessageAckSuccess?.();
this.retransmissionTimer?.stop();
this.sentMessageAckSuccess?.(message);
this.sentMessageAckSuccess = undefined;
this.sentMessageAckFailure = undefined;
this.sentMessageToAck = undefined;
Expand All @@ -166,7 +176,7 @@ export class MessageExchange<ContextT> {
await this.messagesQueue.write(message);
}

async send(messageType: number, payload: ByteArray) {
async send(messageType: number, payload: ByteArray, expectAckOnly: boolean = false) {
if (this.sentMessageToAck !== undefined) throw new Error("The previous message has not been acked yet, cannot send a new message");

this.session.notifyActivity(false);
Expand All @@ -192,11 +202,11 @@ export class MessageExchange<ContextT> {
if (messageType !== MessageType.StandaloneAck) {
this.receivedMessageToAck = undefined;
}
let ackPromise: Promise<void> | undefined;
let ackPromise: Promise<Message> | undefined;
if (message.payloadHeader.requiresAck) {
this.sentMessageToAck = message;
this.retransmissionTimer = Time.getTimer(this.getResubmissionBackOffTime(0), () => this.retransmitMessage(message, 0));
const { promise, resolver, rejecter } = await getPromiseResolver<void>();
const { promise, resolver, rejecter } = await getPromiseResolver<Message>();
ackPromise = promise;
this.sentMessageAckSuccess = resolver;
this.sentMessageAckFailure = rejecter;
Expand All @@ -206,10 +216,15 @@ export class MessageExchange<ContextT> {

if (ackPromise !== undefined) {
this.retransmissionTimer?.start();
await ackPromise;
this.retransmissionTimer?.stop();
// Await Response to be received (or Message retransmit limit reached which rejects the promise)
const responseMessage = await ackPromise;
this.sentMessageAckSuccess = undefined;
this.sentMessageAckFailure = undefined;
// If we only expect an Ack without data but got data, throw an error
const { payloadHeader: { protocolId, messageType } } = responseMessage;
if (expectAckOnly && !SecureChannelProtocol.isStandaloneAck(protocolId, messageType)) {
throw new UnexpectedMessageError("Expected ack only", responseMessage);
}
}
}

Expand Down
20 changes: 16 additions & 4 deletions src/matter/interaction/InteractionMessenger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/

import { Logger } from "../../log/Logger";
import { MessageExchange } from "../common/MessageExchange";
import { MessageExchange, UnexpectedMessageError } from "../common/MessageExchange";
import { MatterController } from "../MatterController";
import { MatterDevice } from "../MatterDevice";
import {
Expand All @@ -25,6 +25,7 @@ import {
import { ByteArray, TlvSchema, TypeFromSchema } from "@project-chip/matter.js";
import { Message } from "../../codec/MessageCodec";
import { MatterError } from "../../error/MatterError";
import { tryCatchAsync } from "../../error/TryCatchHandler";

export const enum MessageType {
StatusResponse = 0x01,
Expand Down Expand Up @@ -82,7 +83,7 @@ class InteractionMessenger<ContextT> {
async nextMessage(expectedMessageType?: number) {
const message = await this.exchangeBase.nextMessage();
const messageType = message.payloadHeader.messageType;
this.throwIfError(messageType, message.payload);
this.throwIfErrorStatusMessage(message);
if (expectedMessageType !== undefined && messageType !== expectedMessageType) throw new Error(`Received unexpected message type: ${messageType}, expected: ${expectedMessageType}`);
return message;
}
Expand All @@ -91,7 +92,9 @@ class InteractionMessenger<ContextT> {
this.exchangeBase.close();
}

protected throwIfError(messageType: number, payload: ByteArray) {
protected throwIfErrorStatusMessage(message: Message) {
const { payloadHeader: { messageType}, payload } = message;

if (messageType !== MessageType.StatusResponse) return;
const { status } = TlvStatusResponse.decode(payload);
if (status !== StatusCode.Success) throw new StatusResponseError(`Received error status: ${ status }`, status);
Expand Down Expand Up @@ -198,7 +201,16 @@ export class InteractionServerMessenger extends InteractionMessenger<MatterDevic
}
}

await this.exchange.send(MessageType.ReportData, TlvDataReport.encode(dataReport));
if (dataReport.suppressResponse) {
// We do not expect a response other than a Standalone Ack, so if we receive anything else, we throw an error
await tryCatchAsync(async () => await this.exchange.send(MessageType.ReportData, TlvDataReport.encode(dataReport), true), UnexpectedMessageError, error => {
const { receivedMessage } = error;
this.throwIfErrorStatusMessage(receivedMessage);
});
} else {
await this.exchange.send(MessageType.ReportData, TlvDataReport.encode(dataReport));
await this.waitForSuccess();
}
}

async send(messageType: number, payload: ByteArray) {
Expand Down
6 changes: 0 additions & 6 deletions src/matter/interaction/SubscriptionHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ export class SubscriptionHandler {
},
})),
});

await messenger.waitForSuccess();
}

async attributeChangeListener(path: AttributePath, schema: TlvSchema<any>, version: number, value: any) {
Expand Down Expand Up @@ -157,10 +155,6 @@ export class SubscriptionHandler {
})),
});

// Only expect answer for non-empty data reports
if (values.length) {
await messenger.waitForSuccess();
}
messenger.close();
}
}
6 changes: 4 additions & 2 deletions src/matter/session/secure/SecureChannelMessenger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { GeneralStatusCode, ProtocolStatusCode, MessageType, SECURE_CHANNEL_PROT
import { ByteArray, TlvSchema } from "@project-chip/matter.js";
import { MatterError } from "../../../error/MatterError";
import { TlvSecureChannelStatusMessage } from "./SecureChannelStatusMessageSchema";
import {Message} from "../../../codec/MessageCodec";

/** Error base Class for all errors related to the status response messages. */
export class ChannelStatusResponseError extends MatterError {
Expand All @@ -29,7 +30,7 @@ export class SecureChannelMessenger<ContextT> {
async nextMessage(expectedMessageType?: number) {
const message = await this.exchange.nextMessage();
const messageType = message.payloadHeader.messageType;
this.throwIfError(messageType, message.payload);
this.throwIfErrorStatusReport(message);
if (expectedMessageType !== undefined && messageType !== expectedMessageType) throw new Error(`Received unexpected message type: ${messageType}, expected: ${expectedMessageType}`);
return message;
}
Expand Down Expand Up @@ -73,7 +74,8 @@ export class SecureChannelMessenger<ContextT> {
}));
}

protected throwIfError(messageType: number, payload: ByteArray) {
protected throwIfErrorStatusReport(message: Message) {
const { payloadHeader: { messageType }, payload } = message;
if (messageType !== MessageType.StatusReport) return;

const { generalStatus, protocolId, protocolStatus } = TlvSecureChannelStatusMessage.decode(payload);
Expand Down