Correct offloading of large batch entries to S3. Closes #154 #155

Open · wants to merge 2 commits into master
@@ -22,14 +22,18 @@
import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent;
import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge;
import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle;
import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.sizeOf;
import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -73,6 +77,7 @@
import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
import software.amazon.awssdk.services.sqs.model.SqsException;
import software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException;
import software.amazon.awssdk.utils.Pair;
import software.amazon.awssdk.utils.StringUtils;
import software.amazon.payloadoffloading.PayloadStore;
import software.amazon.payloadoffloading.S3BackedPayloadStore;
@@ -616,23 +621,38 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes
return super.sendMessageBatch(sendMessageBatchRequest);
}

List<SendMessageBatchRequestEntry> batchEntries = new ArrayList<>(sendMessageBatchRequest.entries().size());
List<SendMessageBatchRequestEntry> originalEntries = sendMessageBatchRequest.entries();
ArrayList<SendMessageBatchRequestEntry> alteredEntries = new ArrayList<>(originalEntries.size());
alteredEntries.addAll(originalEntries);

// Batch entry indices paired with their sizes, sorted by size in descending order
List<Pair<Integer, Long>> entrySizes = IntStream.range(0, originalEntries.size())
.boxed()
.map(i -> Pair.of(i, sizeOf(originalEntries.get(i))))
.sorted((p1, p2) -> Long.compare(p2.right(), p1.right()))
Contributor: What's the purpose of sorting here?

Contributor (author): The overall idea is that we want to offload as few batch entries as possible; fewer AWS calls mean less time spent on round trips. So we sort by size and then move entries to S3 from largest to smallest, stopping the moment the total size fits under the limit. We only offload according to the sort order, and we preserve the original order of entries in the resulting batch request, which is important for the FIFO use case. (A standalone sketch of this loop appears after this method's diff.)

.collect(Collectors.toList());

long totalSize = entrySizes.stream().map(Pair::right).mapToLong(Long::longValue).sum();

// Move messages to S3, starting from the largest, until the total size is under the threshold (if needed)
boolean hasS3Entries = false;
for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) {
//Check message attributes for ExtendedClient related constraints
checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes());

if (clientConfiguration.isAlwaysThroughS3()
|| isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) {
entry = storeMessageInS3(entry);
hasS3Entries = true;
for (Pair<Integer, Long> pair : entrySizes) {
// Stop offloading once the total size of the batch request fits, unless everything must go through S3
if (totalSize <= clientConfiguration.getPayloadSizeThreshold() && !clientConfiguration.isAlwaysThroughS3()) {
Contributor: Since we are checking the total size, does it make sense to move the if condition outside of the for loop? If it is true, we can throw an exception; otherwise, we keep adding entries to S3.

Contributor (author): The overall idea is described in the comment above. The code follows it: we move entries one by one, starting from the largest, and keep track of the total size after each offload. We stop offloading the moment the total is under the threshold, which prevents unnecessary offloading and lowers overall execution time and cost. For example, with the sizes used in the new test below, offloading just the three largest entries (180K, 170K, 160K) brings a 720K batch under the limit, so the other seven entries are sent inline.

@ziyanli-amazon the idea is to offload as few entries as possible, since each offload is a round trip to S3. So offloading goes from the biggest entries to the smallest until the batch fits the size limit.
break;
}
batchEntries.add(entry);
Integer entryIndex = pair.left();
Long originalEntrySize = pair.right();
SendMessageBatchRequestEntry originalEntry = originalEntries.get(entryIndex);
checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), originalEntry.messageAttributes());
SendMessageBatchRequestEntry alteredEntry = storeMessageInS3(originalEntry);
totalSize = totalSize - originalEntrySize + sizeOf(alteredEntry);
alteredEntries.set(entryIndex, alteredEntry);
hasS3Entries = true;
}

if (hasS3Entries) {
sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(batchEntries).build();
sendMessageBatchRequest = sendMessageBatchRequest.toBuilder().entries(alteredEntries).build();
}

return super.sendMessageBatch(sendMessageBatchRequest);
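
To make the greedy strategy discussed in the threads above easy to follow in isolation, here is a minimal, self-contained sketch. This is not the PR's code: plain longs stand in for the SDK entry types, and the threshold and S3 pointer size are assumed values for illustration.

    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;

    // Hypothetical standalone sketch of the greedy largest-first offloading loop.
    public class GreedyOffloadSketch {
        public static void main(String[] args) {
            long threshold = 262_144;                  // assumed 256 KB batch limit
            long[] sizes = {10_000, 10_000, 10_000, 150_000, 160_000,
                            170_000, 180_000, 10_000, 10_000, 10_000};

            long total = 0;
            for (long s : sizes) {
                total += s;                            // 720_000 in this example
            }

            // Entry indices sorted by size, descending, mirroring the Pair sort in the PR.
            List<Integer> order = new ArrayList<>();
            for (int i = 0; i < sizes.length; i++) {
                order.add(i);
            }
            order.sort(Comparator.comparingLong((Integer i) -> sizes[i]).reversed());

            List<Integer> offloaded = new ArrayList<>();
            for (int i : order) {
                if (total <= threshold) {
                    break;                             // batch fits; stop offloading early
                }
                long pointerSize = 300;                // assumed size of an S3 pointer message
                total = total - sizes[i] + pointerSize;
                offloaded.add(i);                      // original entry positions stay untouched
            }

            // Prints indices 6, 5, 4 (180K, 170K, 160K); total ends near 210_900.
            System.out.println("Offloaded " + offloaded + ", total now " + total);
        }
    }

With the sizes from the new test, the loop offloads only the three largest entries, leaving roughly 210K plus three small pointers under the 256K limit, which matches the three expected S3 puts in the test assertion.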
@@ -896,6 +916,6 @@ private static <T extends AwsRequest.Builder> T appendUserAgent(final T builder)
public void close() {
super.close();
this.clientConfiguration.getS3Client().close();
}
}

}
@@ -101,10 +101,14 @@ public static boolean isLarge(int payloadSizeThreshold, SendMessageRequest sendM
return (totalMsgSize > payloadSizeThreshold);
}

public static long sizeOf(SendMessageBatchRequestEntry batchRequestEntry) {
int msgAttributesSize = getMsgAttributesSize(batchRequestEntry.messageAttributes());
long msgBodySize = Util.getStringSizeInBytes(batchRequestEntry.messageBody());
return msgAttributesSize + msgBodySize;
}

public static boolean isLarge(int payloadSizeThreshold, SendMessageBatchRequestEntry batchEntry) {
int msgAttributesSize = getMsgAttributesSize(batchEntry.messageAttributes());
long msgBodySize = Util.getStringSizeInBytes(batchEntry.messageBody());
long totalMsgSize = msgAttributesSize + msgBodySize;
long totalMsgSize = sizeOf(batchEntry);
return (totalMsgSize > payloadSizeThreshold);
}
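
For reviewers skimming the refactor, here is a quick usage sketch of the two helpers; the threshold value and entry contents are illustrative only, not taken from the PR.

    // Hypothetical usage of the refactored helpers.
    SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder()
            .id("entry_0")
            .messageBody("hello world")
            .build();

    // sizeOf() is the new shared helper: message attribute bytes plus body bytes.
    long entrySize = AmazonSQSExtendedClientUtil.sizeOf(entry);

    // isLarge() now delegates to sizeOf() instead of recomputing the total itself.
    boolean needsOffload = AmazonSQSExtendedClientUtil.isLarge(262_144, entry);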

@@ -425,7 +425,7 @@ public void testReceiveMessage_when_MessageIsSmall() {

@Test
public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() {
// This creates 10 messages, out of which only two are below the threshold (100K and 200K),
// This creates 10 messages, out of which only two are below the threshold (100K and 150K),
// and the other 8 are above the threshold

int[] messageLengthForCounter = new int[] {
@@ -437,7 +437,7 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor
700_000,
800_000,
900_000,
200_000,
150_000,
1000_000
};

@@ -459,6 +459,44 @@ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStor
verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class));
}


@Test
public void testWhenMessageBatchWithTotalSizeOverTheLimitIsSentThenLargestEntriesAreStoredInS3() {
// This creates 10 messages that are each below the threshold individually,
// but whose combined size (720K) exceeds the limit for a single batch request

int[] messageLengthForCounter = new int[] {
10_000,
10_000,
10_000,
150_000,
160_000,
170_000,
180_000,
10_000,
10_000,
10_000
};

List<SendMessageBatchRequestEntry> batchEntries = new ArrayList<>();
for (int i = 0; i < 10; i++) {
int messageLength = messageLengthForCounter[i];
String messageBody = generateStringWithLength(messageLength);
SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder()
.id("entry_" + i)
.messageBody(messageBody)
.build();
batchEntries.add(entry);
}

SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build();
extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest);

// There should be 3 puts for the 3 largest messages, as the sum of the sizes of the remaining entries is within the limit
verify(mockS3, times(3)).putObject(isA(PutObjectRequest.class), isA(RequestBody.class));
}


@Test
public void testWhenMessageBatchIsLargeS3PointerIsCorrectlySentToSQSAndNotOriginalMessage() {
String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT);