feat(control-plane): add support for handling multiple events in a single invocation #4603

Open · wants to merge 1 commit into base: main

Changes from all commits
2 changes: 2 additions & 0 deletions README.md
@@ -155,6 +155,8 @@ Join our discord community via [this invite link](https://discord.gg/bxgXW8jJGh)
| <a name="input_key_name"></a> [key\_name](#input\_key\_name) | Key pair name | `string` | `null` | no |
| <a name="input_kms_key_arn"></a> [kms\_key\_arn](#input\_kms\_key\_arn) | Optional CMK Key ARN to be used for Parameter Store. This key must be in the current account. | `string` | `null` | no |
| <a name="input_lambda_architecture"></a> [lambda\_architecture](#input\_lambda\_architecture) | AWS Lambda architecture. Lambda functions using Graviton processors ('arm64') tend to have better price/performance than 'x86\_64' functions. | `string` | `"arm64"` | no |
| <a name="input_lambda_event_source_mapping_batch_size"></a> [lambda\_event\_source\_mapping\_batch\_size](#input\_lambda\_event\_source\_mapping\_batch\_size) | Maximum number of records to pass to the lambda function in a single batch for the event source mapping. When not set, the AWS default of 10 events will be used. | `number` | `10` | no |
| <a name="input_lambda_event_source_mapping_maximum_batching_window_in_seconds"></a> [lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds](#input\_lambda\_event\_source\_mapping\_maximum\_batching\_window\_in\_seconds) | Maximum amount of time to gather records before invoking the lambda function, in seconds. AWS requires this to be greater than 0 if batch\_size is greater than 10. Defaults to 0. | `number` | `0` | no |
| <a name="input_lambda_principals"></a> [lambda\_principals](#input\_lambda\_principals) | (Optional) add extra principals to the role created for execution of the lambda, e.g. for local testing. | <pre>list(object({<br/> type = string<br/> identifiers = list(string)<br/> }))</pre> | `[]` | no |
| <a name="input_lambda_runtime"></a> [lambda\_runtime](#input\_lambda\_runtime) | AWS Lambda runtime. | `string` | `"nodejs22.x"` | no |
| <a name="input_lambda_s3_bucket"></a> [lambda\_s3\_bucket](#input\_lambda\_s3\_bucket) | S3 bucket from which to specify lambda functions. This is an alternative to providing local files directly. | `string` | `null` | no |
171 changes: 147 additions & 24 deletions lambdas/functions/control-plane/src/lambda.test.ts
@@ -70,19 +70,33 @@ vi.mock('@aws-github-runner/aws-powertools-util');
vi.mock('@aws-github-runner/aws-ssm-util');

describe('Test scale up lambda wrapper.', () => {
- it('Do not handle multiple record sets.', async () => {
- await testInvalidRecords([sqsRecord, sqsRecord]);
it('Do not handle empty record sets.', async () => {
const sqsEventEmptyRecords: SQSEvent = {
Records: [],
};

await expect(scaleUpHandler(sqsEventEmptyRecords, context)).resolves.not.toThrow();
});

- it('Do not handle empty record sets.', async () => {
- await testInvalidRecords([]);
it('Ignores non-sqs event sources.', async () => {
const record = {
...sqsRecord,
eventSource: 'aws:non-sqs',
};

const sqsEventNonSQS: SQSEvent = {
Records: [record],
};

await expect(scaleUpHandler(sqsEventNonSQS, context)).resolves.not.toThrow();
expect(scaleUp).toHaveBeenCalledWith([]);
});

it('Scale without error should resolve.', async () => {
const mock = vi.fn(scaleUp);
mock.mockImplementation(() => {
return new Promise((resolve) => {
- resolve();
resolve([]);
});
});
await expect(scaleUpHandler(sqsEvent, context)).resolves.not.toThrow();
@@ -104,28 +118,137 @@ describe('Test scale up lambda wrapper.', () => {
vi.mocked(scaleUp).mockImplementation(mock);
await expect(scaleUpHandler(sqsEvent, context)).rejects.toThrow(error);
});
});

- async function testInvalidRecords(sqsRecords: SQSRecord[]) {
- const mock = vi.fn(scaleUp);
- const logWarnSpy = vi.spyOn(logger, 'warn');
- mock.mockImplementation(() => {
- return new Promise((resolve) => {
- resolve();
describe('Batch processing', () => {
beforeEach(() => {
vi.clearAllMocks();
});

const createMultipleRecords = (count: number, eventSource = 'aws:sqs'): SQSRecord[] => {
return Array.from({ length: count }, (_, i) => ({
...sqsRecord,
eventSource,
messageId: `message-${i}`,
body: JSON.stringify({
...body,
id: i + 1,
}),
}));
};

it('Should handle multiple SQS records in a single invocation', async () => {
const records = createMultipleRecords(3);
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.resolve([]));
vi.mocked(scaleUp).mockImplementation(mock);

await expect(scaleUpHandler(multiRecordEvent, context)).resolves.not.toThrow();
expect(scaleUp).toHaveBeenCalledWith(
expect.arrayContaining([
expect.objectContaining({ messageId: 'message-0' }),
expect.objectContaining({ messageId: 'message-1' }),
expect.objectContaining({ messageId: 'message-2' }),
]),
);
});

it('Should return batch item failures for rejected messages', async () => {
const records = createMultipleRecords(3);
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.resolve(['message-1', 'message-2']));
vi.mocked(scaleUp).mockImplementation(mock);

const result = await scaleUpHandler(multiRecordEvent, context);
expect(result).toEqual({
batchItemFailures: [{ itemIdentifier: 'message-1' }, { itemIdentifier: 'message-2' }],
});
});

it('Should filter out non-SQS event sources', async () => {
const sqsRecords = createMultipleRecords(2, 'aws:sqs');
const nonSqsRecords = createMultipleRecords(1, 'aws:sns');
const mixedEvent: SQSEvent = {
Records: [...sqsRecords, ...nonSqsRecords],
};

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.resolve([]));
vi.mocked(scaleUp).mockImplementation(mock);

await scaleUpHandler(mixedEvent, context);
expect(scaleUp).toHaveBeenCalledWith(
expect.arrayContaining([
expect.objectContaining({ messageId: 'message-0' }),
expect.objectContaining({ messageId: 'message-1' }),
]),
);
expect(scaleUp).not.toHaveBeenCalledWith(
expect.arrayContaining([expect.objectContaining({ messageId: 'message-2' })]),
);
});

it('Should sort messages by retry count', async () => {
const records = [
{
...sqsRecord,
messageId: 'high-retry',
body: JSON.stringify({ ...body, retryCounter: 5 }),
},
{
...sqsRecord,
messageId: 'low-retry',
body: JSON.stringify({ ...body, retryCounter: 1 }),
},
{
...sqsRecord,
messageId: 'no-retry',
body: JSON.stringify({ ...body }),
},
];
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation((messages) => {
// Verify messages are sorted by retry count (ascending)
expect(messages[0].messageId).toBe('no-retry');
expect(messages[1].messageId).toBe('low-retry');
expect(messages[2].messageId).toBe('high-retry');
return Promise.resolve([]);
});
vi.mocked(scaleUp).mockImplementation(mock);

await scaleUpHandler(multiRecordEvent, context);
});

it('Should return no batch item failures when scaleUp throws a non-ScaleError', async () => {
const records = createMultipleRecords(2);
const multiRecordEvent: SQSEvent = { Records: records };

const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.reject(new Error('Generic error')));
vi.mocked(scaleUp).mockImplementation(mock);

const result = await scaleUpHandler(multiRecordEvent, context);
expect(result).toEqual({ batchItemFailures: [] });
});

it('Should throw when scaleUp throws ScaleError', async () => {
const records = createMultipleRecords(2);
const multiRecordEvent: SQSEvent = { Records: records };

const error = new ScaleError('Critical scaling error');
const mock = vi.fn(scaleUp);
mock.mockImplementation(() => Promise.reject(error));
vi.mocked(scaleUp).mockImplementation(mock);

await expect(scaleUpHandler(multiRecordEvent, context)).rejects.toThrow(error);
});
});
- const sqsEventMultipleRecords: SQSEvent = {
- Records: sqsRecords,
- };
-
- await expect(scaleUpHandler(sqsEventMultipleRecords, context)).resolves.not.toThrow();
-
- expect(logWarnSpy).toHaveBeenCalledWith(
- expect.stringContaining(
- 'Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.',
- ),
- );
- }
});

describe('Test scale down lambda wrapper.', () => {
it('Scaling down no error.', async () => {
62 changes: 50 additions & 12 deletions lambdas/functions/control-plane/src/lambda.ts
@@ -1,34 +1,72 @@
import middy from '@middy/core';
import { logger, setContext } from '@aws-github-runner/aws-powertools-util';
import { captureLambdaHandler, tracer } from '@aws-github-runner/aws-powertools-util';
- import { Context, SQSEvent } from 'aws-lambda';
import { Context, type SQSBatchItemFailure, type SQSBatchResponse, SQSEvent } from 'aws-lambda';

import { PoolEvent, adjust } from './pool/pool';
import ScaleError from './scale-runners/ScaleError';
import { scaleDown } from './scale-runners/scale-down';
- import { scaleUp } from './scale-runners/scale-up';
import { type ActionRequestMessage, type ActionRequestMessageSQS, scaleUp } from './scale-runners/scale-up';
import { SSMCleanupOptions, cleanSSMTokens } from './scale-runners/ssm-housekeeper';
import { checkAndRetryJob } from './scale-runners/job-retry';

- export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<void> {
export async function scaleUpHandler(event: SQSEvent, context: Context): Promise<SQSBatchResponse> {
setContext(context, 'lambda.ts');
logger.logEventIfEnabled(event);

- if (event.Records.length !== 1) {
- logger.warn('Event ignored, only one record at the time can be handled, ensure the lambda batch size is set to 1.');
- return Promise.resolve();
// Group the messages by their event source. We're only interested in
// `aws:sqs`-originated messages.
const groupedEvents = new Map<string, ActionRequestMessageSQS[]>();
for (const { body, eventSource, messageId } of event.Records) {
const group = groupedEvents.get(eventSource) || [];
const payload = JSON.parse(body) as ActionRequestMessage;

if (group.length === 0) {
groupedEvents.set(eventSource, group);
}

groupedEvents.get(eventSource)?.push({
...payload,
messageId,
});
}

for (const [eventSource, messages] of groupedEvents.entries()) {
if (eventSource === 'aws:sqs') {
continue;
}

logger.warn('Ignoring non-sqs event source', { eventSource, messages });
}

const sqsMessages = groupedEvents.get('aws:sqs') ?? [];

// Sort messages by their retry count, so that we retry the same messages if
// there's a persistent failure. This should cause messages to be dropped
// quicker than if we retried in an arbitrary order.
sqsMessages.sort((l, r) => {
return (l.retryCounter ?? 0) - (r.retryCounter ?? 0);
});

const batchItemFailures: SQSBatchItemFailure[] = [];

try {
- await scaleUp(event.Records[0].eventSource, JSON.parse(event.Records[0].body));
- return Promise.resolve();
const rejectedMessageIds = await scaleUp(sqsMessages);

for (const messageId of rejectedMessageIds) {
batchItemFailures.push({
itemIdentifier: messageId,
});
}

return { batchItemFailures };
} catch (e) {
if (e instanceof ScaleError) {
- return Promise.reject(e);
- } else {
- logger.warn(`Ignoring error: ${e}`);
- return Promise.resolve();
throw e;
A maintainer commented on this block:

I have not started a full review yet, but this part is a bit tricky.

In the current logic, runner.ts distinguishes between two types of errors: those that can be recovered by a retry (scale errors) and the rest. It also assumes the creation of a single runner. This means that when an error can auto-recover, the message simply goes back to the queue. Examples are hitting the limit (max runners), but also no spot instance being available.

With the changes introduced here, any error in runner.ts sends the whole batch back to the queue, and a retry is initiated for all messages.

To address this, runner.ts needs to be made smarter: it should be able to return the number of instances that failed to create and for which a retry makes sense. That number (or the message ids) can be added to the ScaleError object.

Keeping the exception logic as it is will also cause problems when createRunner creates only a few runners because the max is exceeded. If that is combined with a ScaleError, the runners that were not created because of the max will not be reported back, due to the exception.
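
A minimal sketch of that suggestion, assuming the ScaleError class in scale-runners/ScaleError.ts can be extended; the failedMessageIds field and constructor shape are illustrative, not the module's actual API:

```ts
// Hypothetical sketch: extend ScaleError so it carries the ids of the messages
// whose runner creation failed, letting the handler retry only those.
export default class ScaleError extends Error {
  // Ids of SQS messages that should be reported back as batch item failures.
  public readonly failedMessageIds: string[];

  constructor(message: string, failedMessageIds: string[] = []) {
    super(message);
    this.name = 'ScaleError';
    this.failedMessageIds = failedMessageIds;
    // Restore the prototype chain when targeting ES5 (common TS idiom for Error subclasses).
    Object.setPrototypeOf(this, ScaleError.prototype);
  }
}
```

With something like this, the catch block above could push e.failedMessageIds into batchItemFailures instead of rethrowing, so only the affected messages return to the queue.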

}

logger.warn(`Will retry error: ${e}`);
return { batchItemFailures };
}
}

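For orientation, the handler above depends on a batched scaleUp contract roughly like the sketch below. This is inferred from the diff: the field names come from the message bodies used in the tests, and the exact definitions in scale-runners/scale-up.ts may differ.

```ts
// Sketch (inferred from this diff) of the batched contract scaleUpHandler relies on.
export interface ActionRequestMessage {
  id: number;
  eventType: string;
  repositoryName: string;
  repositoryOwner: string;
  installationId: number;
  // Incremented by the job-retry flow; the handler sorts the batch by this value.
  retryCounter?: number;
}

// SQS-delivered variant: keeps the messageId so failures can be reported per item.
export interface ActionRequestMessageSQS extends ActionRequestMessage {
  messageId: string;
}

// Resolves with the ids of messages that should be retried; throws ScaleError
// only when the failure affects the whole batch.
export declare function scaleUp(messages: ActionRequestMessageSQS[]): Promise<string[]>;
```

Note that SQS only honors the returned batchItemFailures if the event source mapping enables the ReportBatchItemFailures function response type; otherwise a partial failure is treated as full success.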
42 changes: 32 additions & 10 deletions lambdas/functions/control-plane/src/local.ts
@@ -1,21 +1,21 @@
import { logger } from '@aws-github-runner/aws-powertools-util';

- import { ActionRequestMessage, scaleUp } from './scale-runners/scale-up';
import { scaleUpHandler } from './lambda';
import { Context, SQSEvent } from 'aws-lambda';

- const sqsEvent = {
const sqsEvent: SQSEvent = {
Records: [
{
messageId: 'e8d74d08-644e-42ca-bf82-a67daa6c4dad',
receiptHandle:
// eslint-disable-next-line max-len
'AQEBCpLYzDEKq4aKSJyFQCkJduSKZef8SJVOperbYyNhXqqnpFG5k74WygVAJ4O0+9nybRyeOFThvITOaS21/jeHiI5fgaM9YKuI0oGYeWCIzPQsluW5CMDmtvqv1aA8sXQ5n2x0L9MJkzgdIHTC3YWBFLQ2AxSveOyIHwW+cHLIFCAcZlOaaf0YtaLfGHGkAC4IfycmaijV8NSlzYgDuxrC9sIsWJ0bSvk5iT4ru/R4+0cjm7qZtGlc04k9xk5Fu6A+wRxMaIyiFRY+Ya19ykcevQldidmEjEWvN6CRToLgclk=',
- body: {
body: JSON.stringify({
repositoryName: 'self-hosted',
repositoryOwner: 'test-runners',
eventType: 'workflow_job',
id: 987654,
installationId: 123456789,
- },
}),
attributes: {
ApproximateReceiveCount: '1',
SentTimestamp: '1626450047230',
@@ -34,12 +34,34 @@ const sqsEvent = {
],
};

const context: Context = {
awsRequestId: '1',
callbackWaitsForEmptyEventLoop: false,
functionName: '',
functionVersion: '',
getRemainingTimeInMillis: () => 0,
invokedFunctionArn: '',
logGroupName: '',
logStreamName: '',
memoryLimitInMB: '',
done: () => {
return;
},
fail: () => {
return;
},
succeed: () => {
return;
},
};

export function run(): void {
- scaleUp(sqsEvent.Records[0].eventSource, sqsEvent.Records[0].body as ActionRequestMessage)
- .then()
- .catch((e) => {
- logger.error(e);
- });
scaleUpHandler(sqsEvent, context).catch((e: unknown) => {
// A try/catch around the un-awaited handler would miss async rejections; catch them here.
const message = e instanceof Error ? e.message : `${e}`;
logger.error(message, e instanceof Error ? { error: e } : {});
});
}

run();