Skip to content

Commit a0bcd0e

Browse files
committed
feat(storage:s3): multi-part upload: upload parts concurrently
1 parent b4d97d8 commit a0bcd0e

File tree

3 files changed

+119
-99
lines changed

3 files changed

+119
-99
lines changed

checkstyle/suppressions.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
<suppress checks="ClassDataAbstractionCoupling" files=".*Test\.java"/>
2323
<suppress checks="ClassFanOutComplexity" files=".*Test\.java"/>
2424
<suppress checks="ClassFanOutComplexity" files="RemoteStorageManager.java"/>
25+
<suppress checks="ClassDataAbstractionCoupling" files="S3MultiPartOutputStream.java"/>
2526
<suppress checks="ClassDataAbstractionCoupling" files="S3StorageConfig.java"/>
2627
<suppress checks="ClassDataAbstractionCoupling" files="RemoteStorageManager.java"/>
2728
</suppressions>

storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,32 @@
1818

1919
import java.io.ByteArrayInputStream;
2020
import java.io.IOException;
21-
import java.io.InputStream;
2221
import java.io.OutputStream;
2322
import java.nio.ByteBuffer;
2423
import java.util.ArrayList;
25-
import java.util.List;
2624
import java.util.Objects;
25+
import java.util.concurrent.CompletableFuture;
26+
import java.util.concurrent.ConcurrentLinkedQueue;
27+
import java.util.concurrent.atomic.AtomicInteger;
2728

2829
import com.amazonaws.services.s3.AmazonS3;
2930
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
3031
import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest;
31-
import com.amazonaws.services.s3.model.CompleteMultipartUploadResult;
3232
import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest;
3333
import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
3434
import com.amazonaws.services.s3.model.PartETag;
3535
import com.amazonaws.services.s3.model.UploadPartRequest;
36-
import com.amazonaws.services.s3.model.UploadPartResult;
3736
import org.slf4j.Logger;
3837
import org.slf4j.LoggerFactory;
3938

4039
/**
4140
* S3 multipart output stream.
4241
* Enable uploads to S3 with unknown size by feeding input bytes to multiple parts and upload them on close.
4342
*
43+
* <p>OutputStream is used to write sequentially, but
44+
* uploading parts happens asynchronously to reduce full upload latency.
45+
* Concurrency happens within the output stream implementation and does not require changes on the callers.
46+
*
4447
* <p>Requires S3 client and starts a multipart transaction when instantiated. Do not reuse.
4548
*/
4649
public class S3MultiPartOutputStream extends OutputStream {
@@ -54,7 +57,10 @@ public class S3MultiPartOutputStream extends OutputStream {
5457
final int partSize;
5558

5659
private final String uploadId;
57-
private final List<PartETag> partETags = new ArrayList<>();
60+
// holds async part upload operations building a list of partETags required when committing
61+
private CompletableFuture<ConcurrentLinkedQueue<PartETag>> partUploads =
62+
CompletableFuture.completedFuture(new ConcurrentLinkedQueue<>());
63+
private final AtomicInteger partNumber = new AtomicInteger(0);
5864

5965
private boolean closed;
6066

@@ -88,32 +94,42 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
8894
}
8995
final ByteBuffer source = ByteBuffer.wrap(b, off, len);
9096
while (source.hasRemaining()) {
91-
final int transferred = Math.min(partBuffer.remaining(), source.remaining());
92-
final int offset = source.arrayOffset() + source.position();
93-
// TODO: get rid of this array copying
94-
partBuffer.put(source.array(), offset, transferred);
95-
source.position(source.position() + transferred);
97+
final int toCopy = Math.min(partBuffer.remaining(), source.remaining());
98+
final int positionAfterCopying = source.position() + toCopy;
99+
source.limit(positionAfterCopying);
100+
partBuffer.put(source.slice());
101+
source.clear(); // reset limit
102+
source.position(positionAfterCopying);
96103
if (!partBuffer.hasRemaining()) {
97-
flushBuffer(0, partSize);
104+
partBuffer.position(0);
105+
partBuffer.limit(partSize);
106+
flushBuffer(partBuffer.slice(), partSize);
107+
partBuffer.clear();
98108
}
99109
}
100110
}
101111

102112
@Override
103113
public void close() throws IOException {
104114
if (partBuffer.position() > 0) {
105-
flushBuffer(partBuffer.arrayOffset(), partBuffer.position());
115+
final int actualPartSize = partBuffer.position();
116+
partBuffer.position(0);
117+
partBuffer.limit(actualPartSize);
118+
flushBuffer(partBuffer.slice(), actualPartSize);
106119
}
107120
if (Objects.nonNull(uploadId)) {
108-
if (!partETags.isEmpty()) {
121+
if (partNumber.get() > 0) {
109122
try {
110-
final CompleteMultipartUploadRequest request =
111-
new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
112-
final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
123+
// wait for all uploads to complete successfully before committing
124+
final ConcurrentLinkedQueue<PartETag> tagsQueue = partUploads.get(); // TODO: maybe set a timeout?
125+
final ArrayList<PartETag> partETags = new ArrayList<>(tagsQueue);
126+
final var request = new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
127+
final var result = client.completeMultipartUpload(request);
113128
log.debug("Completed multipart upload {} with result {}", uploadId, result);
114129
} catch (final Exception e) {
115130
log.error("Failed to complete multipart upload {}, aborting transaction", uploadId, e);
116131
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
132+
throw new IOException(e);
117133
}
118134
} else {
119135
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -122,31 +138,24 @@ public void close() throws IOException {
122138
closed = true;
123139
}
124140

125-
private void flushBuffer(final int offset,
126-
final int actualPartSize) throws IOException {
127-
try {
128-
final ByteArrayInputStream in = new ByteArrayInputStream(partBuffer.array(), offset, actualPartSize);
129-
uploadPart(in, actualPartSize);
130-
partBuffer.clear();
131-
} catch (final Exception e) {
132-
log.error("Failed to upload part in multipart upload {}, aborting transaction", uploadId, e);
133-
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
134-
closed = true;
135-
throw new IOException(e);
136-
}
137-
}
141+
private void flushBuffer(final ByteBuffer partBuffer, final int actualPartSize) {
142+
final byte[] partContent = new byte[actualPartSize];
143+
partBuffer.get(partContent, 0, actualPartSize);
144+
145+
final var uploadPartRequest = new UploadPartRequest()
146+
.withBucketName(bucketName)
147+
.withKey(key)
148+
.withUploadId(uploadId)
149+
.withPartSize(actualPartSize)
150+
.withPartNumber(partNumber.incrementAndGet())
151+
.withInputStream(new ByteArrayInputStream(partContent));
138152

139-
private void uploadPart(final InputStream in, final int actualPartSize) {
140-
final int partNumber = partETags.size() + 1;
141-
final UploadPartRequest uploadPartRequest =
142-
new UploadPartRequest()
143-
.withBucketName(bucketName)
144-
.withKey(key)
145-
.withUploadId(uploadId)
146-
.withPartSize(actualPartSize)
147-
.withPartNumber(partNumber)
148-
.withInputStream(in);
149-
final UploadPartResult uploadResult = client.uploadPart(uploadPartRequest);
150-
partETags.add(uploadResult.getPartETag());
153+
// Run request async
154+
partUploads = partUploads.thenCombine(
155+
CompletableFuture.supplyAsync(() -> client.uploadPart(uploadPartRequest)),
156+
(partETags, result) -> {
157+
partETags.add(result.getPartETag());
158+
return partETags;
159+
});
151160
}
152161
}

storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java

Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
package io.aiven.kafka.tieredstorage.storage.s3;
1818

19-
import java.io.ByteArrayInputStream;
2019
import java.io.IOException;
21-
import java.util.ArrayList;
22-
import java.util.List;
20+
import java.util.HashMap;
21+
import java.util.Map;
2322
import java.util.Random;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.stream.Collectors;
2425

2526
import com.amazonaws.services.s3.AmazonS3;
2627
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -83,7 +84,7 @@ void sendAbortForAnyExceptionWhileWriting() {
8384
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
8485
out.write(new byte[] {1, 2, 3});
8586
}
86-
}).isInstanceOf(IOException.class).hasCause(testException);
87+
}).isInstanceOf(IOException.class).hasRootCause(testException);
8788

8889
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
8990
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
@@ -101,7 +102,7 @@ void sendAbortForAnyExceptionWhenClose() throws Exception {
101102
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
102103
.thenThrow(RuntimeException.class);
103104

104-
final S3MultiPartOutputStream out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 10, mockedS3);
105+
final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 10, mockedS3);
105106

106107
final byte[] buffer = new byte[5];
107108
random.nextBytes(buffer);
@@ -127,110 +128,124 @@ void writesOneByte() throws Exception {
127128
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
128129
.thenReturn(new CompleteMultipartUploadResult());
129130

130-
try (final S3MultiPartOutputStream out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
131+
try (final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
131132
out.write(1);
132133
}
133134

134135
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
135136
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
136137
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
137138

139+
final UploadPartRequest value = uploadPartRequestCaptor.getValue();
138140
assertUploadPartRequest(
139-
uploadPartRequestCaptor.getValue(),
141+
value,
142+
value.getInputStream().readAllBytes(),
140143
1,
141144
1,
142145
new byte[] {1});
143146
assertCompleteMultipartUploadRequest(
144147
completeMultipartUploadRequestCaptor.getValue(),
145-
List.of(new PartETag(1, "SOME_ETAG"))
148+
Map.of(1, "SOME_ETAG")
146149
);
147150
}
148151

149152
@Test
150153
void writesMultipleMessages() throws Exception {
151154
final int bufferSize = 10;
152-
final byte[] message = new byte[bufferSize];
153155

154156
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
155157
.thenReturn(newInitiateMultipartUploadResult());
156-
when(mockedS3.uploadPart(uploadPartRequestCaptor.capture()))
157-
.thenAnswer(a -> {
158-
final UploadPartRequest up = a.getArgument(0);
158+
159+
// capturing requests and contents from concurrent threads
160+
final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
161+
final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
162+
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
163+
.thenAnswer(answer -> {
164+
final UploadPartRequest up = answer.getArgument(0);
165+
//emulate behave of S3 client otherwise we will get wrong array in the memory
166+
uploadPartRequests.put(up.getPartNumber(), up);
167+
uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
168+
159169
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
160170
});
161-
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
171+
when(mockedS3.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
162172
.thenReturn(new CompleteMultipartUploadResult());
163173

164-
final List<byte[]> expectedMessagesList = new ArrayList<>();
165-
try (final S3MultiPartOutputStream out =
166-
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
174+
final Map<Integer, byte[]> expectedMessageParts = new HashMap<>();
175+
try (final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
167176
for (int i = 0; i < 3; i++) {
177+
final byte[] message = new byte[bufferSize];
168178
random.nextBytes(message);
169179
out.write(message, 0, message.length);
170-
expectedMessagesList.add(message);
180+
expectedMessageParts.put(i + 1, message);
171181
}
172182
}
173183

174184
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
175185
verify(mockedS3, times(3)).uploadPart(any(UploadPartRequest.class));
176186
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
177187

178-
final List<UploadPartRequest> uploadRequests = uploadPartRequestCaptor.getAllValues();
179-
int counter = 0;
180-
for (final byte[] expectedMessage : expectedMessagesList) {
188+
for (final Integer part : expectedMessageParts.keySet()) {
181189
assertUploadPartRequest(
182-
uploadRequests.get(counter),
190+
uploadPartRequests.get(part),
191+
uploadPartContents.get(part),
183192
bufferSize,
184-
counter + 1,
185-
expectedMessage);
186-
counter++;
193+
part,
194+
expectedMessageParts.get(part)
195+
);
187196
}
188197
assertCompleteMultipartUploadRequest(
189198
completeMultipartUploadRequestCaptor.getValue(),
190-
List.of(new PartETag(1, "SOME_TAG#1"),
191-
new PartETag(2, "SOME_TAG#2"),
192-
new PartETag(3, "SOME_TAG#3"))
199+
Map.of(1, "SOME_TAG#1",
200+
2, "SOME_TAG#2",
201+
3, "SOME_TAG#3")
193202
);
194203
}
195204

196205
@Test
197206
void writesTailMessages() throws Exception {
198207
final int messageSize = 20;
199208

200-
final List<UploadPartRequest> uploadPartRequests = new ArrayList<>();
201-
202209
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
203210
.thenReturn(newInitiateMultipartUploadResult());
211+
212+
// capturing requests and contents from concurrent threads
213+
final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
214+
final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
204215
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
205-
.thenAnswer(a -> {
206-
final UploadPartRequest up = a.getArgument(0);
216+
.thenAnswer(answer -> {
217+
final UploadPartRequest up = answer.getArgument(0);
207218
// emulate the behavior of the S3 client; otherwise we would read the wrong array contents from memory
208-
up.setInputStream(new ByteArrayInputStream(up.getInputStream().readAllBytes()));
209-
uploadPartRequests.add(up);
219+
uploadPartRequests.put(up.getPartNumber(), up);
220+
uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
210221

211222
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
212223
});
213-
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
224+
when(mockedS3.completeMultipartUpload(any(CompleteMultipartUploadRequest.class)))
214225
.thenReturn(new CompleteMultipartUploadResult());
215226

216-
final byte[] message = new byte[messageSize];
217-
218227
final byte[] expectedFullMessage = new byte[messageSize + 10];
219228
final byte[] expectedTailMessage = new byte[10];
220229

221-
final S3MultiPartOutputStream
222-
out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
223-
random.nextBytes(message);
224-
out.write(message);
225-
System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
226-
random.nextBytes(message);
227-
out.write(message);
228-
System.arraycopy(message, 0, expectedFullMessage, 20, 10);
229-
System.arraycopy(message, 10, expectedTailMessage, 0, 10);
230+
final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
231+
{
232+
final byte[] message = new byte[messageSize];
233+
random.nextBytes(message);
234+
out.write(message);
235+
System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
236+
}
237+
{
238+
final byte[] message = new byte[messageSize];
239+
random.nextBytes(message);
240+
out.write(message);
241+
System.arraycopy(message, 0, expectedFullMessage, 20, 10);
242+
System.arraycopy(message, 10, expectedTailMessage, 0, 10);
243+
}
230244
out.close();
231245

232-
assertUploadPartRequest(uploadPartRequests.get(0), 30, 1, expectedFullMessage);
233-
assertUploadPartRequest(uploadPartRequests.get(1), 10, 2, expectedTailMessage);
246+
assertThat(uploadPartRequests).hasSize(2);
247+
assertUploadPartRequest(uploadPartRequests.get(1), uploadPartContents.get(1), 30, 1, expectedFullMessage);
248+
assertUploadPartRequest(uploadPartRequests.get(2), uploadPartContents.get(2), 10, 2, expectedTailMessage);
234249

235250
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
236251
verify(mockedS3, times(2)).uploadPart(any(UploadPartRequest.class));
@@ -251,6 +266,7 @@ private static UploadPartResult newUploadPartResult(final int partNumber, final
251266
}
252267

253268
private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
269+
final byte[] bytes,
254270
final int expectedPartSize,
255271
final int expectedPartNumber,
256272
final byte[] expectedBytes) {
@@ -259,23 +275,17 @@ private static void assertUploadPartRequest(final UploadPartRequest uploadPartRe
259275
assertThat(uploadPartRequest.getPartNumber()).isEqualTo(expectedPartNumber);
260276
assertThat(uploadPartRequest.getBucketName()).isEqualTo(BUCKET_NAME);
261277
assertThat(uploadPartRequest.getKey()).isEqualTo(FILE_KEY);
262-
assertThat(uploadPartRequest.getInputStream()).hasBinaryContent(expectedBytes);
278+
assertThat(bytes).isEqualTo(expectedBytes);
263279
}
264280

265281
private static void assertCompleteMultipartUploadRequest(final CompleteMultipartUploadRequest request,
266-
final List<PartETag> expectedETags) {
282+
final Map<Integer, String> expectedETags) {
267283
assertThat(request.getBucketName()).isEqualTo(BUCKET_NAME);
268284
assertThat(request.getKey()).isEqualTo(FILE_KEY);
269285
assertThat(request.getUploadId()).isEqualTo(UPLOAD_ID);
270-
assertThat(request.getPartETags()).hasSameSizeAs(expectedETags);
271-
272-
for (int i = 0; i < expectedETags.size(); i++) {
273-
final PartETag expectedETag = expectedETags.get(i);
274-
final PartETag etag = request.getPartETags().get(i);
275-
276-
assertThat(etag.getPartNumber()).isEqualTo(expectedETag.getPartNumber());
277-
assertThat(etag.getETag()).isEqualTo(expectedETag.getETag());
278-
}
286+
final Map<Integer, String> tags = request.getPartETags().stream()
287+
.collect(Collectors.toMap(PartETag::getPartNumber, PartETag::getETag));
288+
assertThat(tags).containsExactlyInAnyOrderEntriesOf(expectedETags);
279289
}
280290

281291
private static void assertAbortMultipartUploadRequest(final AbortMultipartUploadRequest request) {

0 commit comments

Comments
 (0)