Skip to content

Commit d8d9f7d

Browse files
committed
feat(storage:s3): multi-part upload: upload parts concurrently
1 parent b4d97d8 commit d8d9f7d

File tree

3 files changed

+123
-84
lines changed

3 files changed

+123
-84
lines changed

checkstyle/suppressions.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
<suppress checks="ClassDataAbstractionCoupling" files=".*Test\.java"/>
2323
<suppress checks="ClassFanOutComplexity" files=".*Test\.java"/>
2424
<suppress checks="ClassFanOutComplexity" files="RemoteStorageManager.java"/>
25+
<suppress checks="ClassDataAbstractionCoupling" files="S3MultiPartOutputStream.java"/>
2526
<suppress checks="ClassDataAbstractionCoupling" files="S3StorageConfig.java"/>
2627
<suppress checks="ClassDataAbstractionCoupling" files="RemoteStorageManager.java"/>
2728
</suppressions>

storage/s3/src/main/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStream.java

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,16 @@
1818

1919
import java.io.ByteArrayInputStream;
2020
import java.io.IOException;
21-
import java.io.InputStream;
2221
import java.io.OutputStream;
2322
import java.nio.ByteBuffer;
2423
import java.util.ArrayList;
2524
import java.util.List;
25+
import java.util.Map;
2626
import java.util.Objects;
27+
import java.util.concurrent.CompletableFuture;
28+
import java.util.concurrent.ConcurrentHashMap;
29+
import java.util.concurrent.ExecutionException;
30+
import java.util.stream.Collectors;
2731

2832
import com.amazonaws.services.s3.AmazonS3;
2933
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -33,14 +37,17 @@
3337
import com.amazonaws.services.s3.model.InitiateMultipartUploadResult;
3438
import com.amazonaws.services.s3.model.PartETag;
3539
import com.amazonaws.services.s3.model.UploadPartRequest;
36-
import com.amazonaws.services.s3.model.UploadPartResult;
3740
import org.slf4j.Logger;
3841
import org.slf4j.LoggerFactory;
3942

4043
/**
4144
* S3 multipart output stream.
4245
* Enable uploads to S3 with unknown size by feeding input bytes to multiple parts and upload them on close.
4346
*
47+
* <p>OutputStream is used to write sequentially, but
48+
uploading parts happens asynchronously to reduce the overall upload latency.
49+
Concurrency happens within the output stream implementation and does not require changes to the callers.
50+
*
4451
* <p>Requires S3 client and starts a multipart transaction when instantiated. Do not reuse.
4552
*/
4653
public class S3MultiPartOutputStream extends OutputStream {
@@ -54,7 +61,10 @@ public class S3MultiPartOutputStream extends OutputStream {
5461
final int partSize;
5562

5663
private final String uploadId;
57-
private final List<PartETag> partETags = new ArrayList<>();
64+
// Keep track of the async uploads to wait for before committing
65+
private final List<CompletableFuture<Void>> partUploads = new ArrayList<>();
66+
// a concurrent map is required, as multiple threads will write to it when upload results are returned
67+
private final Map<Integer, String> partETags = new ConcurrentHashMap<>();
5868

5969
private boolean closed;
6070

@@ -88,32 +98,49 @@ public void write(final byte[] b, final int off, final int len) throws IOExcepti
8898
}
8999
final ByteBuffer source = ByteBuffer.wrap(b, off, len);
90100
while (source.hasRemaining()) {
91-
final int transferred = Math.min(partBuffer.remaining(), source.remaining());
92-
final int offset = source.arrayOffset() + source.position();
93-
// TODO: get rid of this array copying
94-
partBuffer.put(source.array(), offset, transferred);
95-
source.position(source.position() + transferred);
101+
final int toCopy = Math.min(partBuffer.remaining(), source.remaining());
102+
final int positionAfterCopying = source.position() + toCopy;
103+
source.limit(positionAfterCopying);
104+
partBuffer.put(source.slice());
105+
source.clear(); // reset limit
106+
source.position(positionAfterCopying);
96107
if (!partBuffer.hasRemaining()) {
97-
flushBuffer(0, partSize);
108+
partBuffer.position(0);
109+
partBuffer.limit(partSize);
110+
flushBuffer(partBuffer.slice(), partSize);
111+
partBuffer.clear();
98112
}
99113
}
100114
}
101115

102116
@Override
103117
public void close() throws IOException {
104118
if (partBuffer.position() > 0) {
105-
flushBuffer(partBuffer.arrayOffset(), partBuffer.position());
119+
final int actualPartSize = partBuffer.position();
120+
partBuffer.position(0);
121+
partBuffer.limit(actualPartSize);
122+
flushBuffer(partBuffer.slice(), actualPartSize);
106123
}
107124
if (Objects.nonNull(uploadId)) {
108-
if (!partETags.isEmpty()) {
125+
if (!partUploads.isEmpty()) {
109126
try {
110-
final CompleteMultipartUploadRequest request =
111-
new CompleteMultipartUploadRequest(bucketName, key, uploadId, partETags);
112-
final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
113-
log.debug("Completed multipart upload {} with result {}", uploadId, result);
114-
} catch (final Exception e) {
127+
// wait for all uploads to complete successfully before committing
128+
CompletableFuture.allOf(partUploads.toArray(new CompletableFuture[0]))
129+
.thenAccept(unused -> {
130+
final List<PartETag> tags = partETags.entrySet()
131+
.stream()
132+
.map(entry -> new PartETag(entry.getKey(), entry.getValue()))
133+
.collect(Collectors.toList());
134+
final CompleteMultipartUploadRequest request =
135+
new CompleteMultipartUploadRequest(bucketName, key, uploadId, tags);
136+
final CompleteMultipartUploadResult result = client.completeMultipartUpload(request);
137+
log.debug("Completed multipart upload {} with result {}", uploadId, result);
138+
})
139+
.get(); // TODO: maybe set a timeout?
140+
} catch (final InterruptedException | ExecutionException e) {
115141
log.error("Failed to complete multipart upload {}, aborting transaction", uploadId, e);
116142
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
143+
throw new IOException(e);
117144
}
118145
} else {
119146
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
@@ -122,31 +149,32 @@ public void close() throws IOException {
122149
closed = true;
123150
}
124151

125-
private void flushBuffer(final int offset,
126-
final int actualPartSize) throws IOException {
152+
private void flushBuffer(final ByteBuffer partBuffer, final int actualPartSize) throws IOException {
127153
try {
128-
final ByteArrayInputStream in = new ByteArrayInputStream(partBuffer.array(), offset, actualPartSize);
129-
uploadPart(in, actualPartSize);
130-
partBuffer.clear();
154+
final byte[] array = new byte[actualPartSize];
155+
partBuffer.get(array, 0, actualPartSize);
156+
157+
final UploadPartRequest uploadPartRequest =
158+
new UploadPartRequest()
159+
.withBucketName(bucketName)
160+
.withKey(key)
161+
.withUploadId(uploadId)
162+
.withPartSize(actualPartSize)
163+
.withPartNumber(partUploads.size() + 1)
164+
.withInputStream(new ByteArrayInputStream(array));
165+
166+
// Run request async
167+
final CompletableFuture<Void> upload =
168+
CompletableFuture.supplyAsync(() -> client.uploadPart(uploadPartRequest))
169+
.thenAccept(result ->
170+
partETags.put(result.getPartETag().getPartNumber(), result.getPartETag().getETag()));
171+
172+
partUploads.add(upload);
131173
} catch (final Exception e) {
132174
log.error("Failed to upload part in multipart upload {}, aborting transaction", uploadId, e);
133175
client.abortMultipartUpload(new AbortMultipartUploadRequest(bucketName, key, uploadId));
134176
closed = true;
135177
throw new IOException(e);
136178
}
137179
}
138-
139-
private void uploadPart(final InputStream in, final int actualPartSize) {
140-
final int partNumber = partETags.size() + 1;
141-
final UploadPartRequest uploadPartRequest =
142-
new UploadPartRequest()
143-
.withBucketName(bucketName)
144-
.withKey(key)
145-
.withUploadId(uploadId)
146-
.withPartSize(actualPartSize)
147-
.withPartNumber(partNumber)
148-
.withInputStream(in);
149-
final UploadPartResult uploadResult = client.uploadPart(uploadPartRequest);
150-
partETags.add(uploadResult.getPartETag());
151-
}
152180
}

storage/s3/src/test/java/io/aiven/kafka/tieredstorage/storage/s3/S3MultiPartOutputStreamTest.java

Lines changed: 59 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
package io.aiven.kafka.tieredstorage.storage.s3;
1818

19-
import java.io.ByteArrayInputStream;
2019
import java.io.IOException;
21-
import java.util.ArrayList;
22-
import java.util.List;
20+
import java.util.HashMap;
21+
import java.util.Map;
2322
import java.util.Random;
23+
import java.util.concurrent.ConcurrentHashMap;
24+
import java.util.stream.Collectors;
2425

2526
import com.amazonaws.services.s3.AmazonS3;
2627
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
@@ -83,7 +84,7 @@ void sendAbortForAnyExceptionWhileWriting() {
8384
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, 100, mockedS3)) {
8485
out.write(new byte[] {1, 2, 3});
8586
}
86-
}).isInstanceOf(IOException.class).hasCause(testException);
87+
}).isInstanceOf(IOException.class).hasRootCause(testException);
8788

8889
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
8990
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
@@ -135,102 +136,116 @@ void writesOneByte() throws Exception {
135136
verify(mockedS3).uploadPart(any(UploadPartRequest.class));
136137
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
137138

139+
final UploadPartRequest value = uploadPartRequestCaptor.getValue();
138140
assertUploadPartRequest(
139-
uploadPartRequestCaptor.getValue(),
141+
value,
142+
value.getInputStream().readAllBytes(),
140143
1,
141144
1,
142145
new byte[] {1});
143146
assertCompleteMultipartUploadRequest(
144147
completeMultipartUploadRequestCaptor.getValue(),
145-
List.of(new PartETag(1, "SOME_ETAG"))
148+
Map.of(1, "SOME_ETAG")
146149
);
147150
}
148151

149152
@Test
150153
void writesMultipleMessages() throws Exception {
151154
final int bufferSize = 10;
152-
final byte[] message = new byte[bufferSize];
153155

154156
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
155157
.thenReturn(newInitiateMultipartUploadResult());
158+
159+
final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
160+
final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
156161
when(mockedS3.uploadPart(uploadPartRequestCaptor.capture()))
157-
.thenAnswer(a -> {
158-
final UploadPartRequest up = a.getArgument(0);
162+
.thenAnswer(answer -> {
163+
final UploadPartRequest up = answer.getArgument(0);
164+
// emulate the behavior of the S3 client; otherwise we would observe a wrong array in memory
165+
uploadPartRequests.put(up.getPartNumber(), up);
166+
uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
167+
159168
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
160169
});
161170
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
162171
.thenReturn(new CompleteMultipartUploadResult());
163172

164-
final List<byte[]> expectedMessagesList = new ArrayList<>();
173+
final Map<Integer, byte[]> expectedMessageParts = new HashMap<>();
165174
try (final S3MultiPartOutputStream out =
166175
new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, bufferSize, mockedS3)) {
167176
for (int i = 0; i < 3; i++) {
177+
final byte[] message = new byte[bufferSize];
168178
random.nextBytes(message);
169179
out.write(message, 0, message.length);
170-
expectedMessagesList.add(message);
180+
expectedMessageParts.put(i + 1, message);
171181
}
172182
}
173183

174184
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
175185
verify(mockedS3, times(3)).uploadPart(any(UploadPartRequest.class));
176186
verify(mockedS3).completeMultipartUpload(any(CompleteMultipartUploadRequest.class));
177187

178-
final List<UploadPartRequest> uploadRequests = uploadPartRequestCaptor.getAllValues();
179-
int counter = 0;
180-
for (final byte[] expectedMessage : expectedMessagesList) {
188+
for (final Integer part : expectedMessageParts.keySet()) {
181189
assertUploadPartRequest(
182-
uploadRequests.get(counter),
190+
uploadPartRequests.get(part),
191+
uploadPartContents.get(part),
183192
bufferSize,
184-
counter + 1,
185-
expectedMessage);
186-
counter++;
193+
part,
194+
expectedMessageParts.get(part)
195+
);
187196
}
188197
assertCompleteMultipartUploadRequest(
189198
completeMultipartUploadRequestCaptor.getValue(),
190-
List.of(new PartETag(1, "SOME_TAG#1"),
191-
new PartETag(2, "SOME_TAG#2"),
192-
new PartETag(3, "SOME_TAG#3"))
199+
Map.of(1, "SOME_TAG#1",
200+
2, "SOME_TAG#2",
201+
3, "SOME_TAG#3")
193202
);
194203
}
195204

196205
@Test
197206
void writesTailMessages() throws Exception {
198207
final int messageSize = 20;
199208

200-
final List<UploadPartRequest> uploadPartRequests = new ArrayList<>();
209+
final Map<Integer, UploadPartRequest> uploadPartRequests = new ConcurrentHashMap<>();
210+
final Map<Integer, byte[]> uploadPartContents = new ConcurrentHashMap<>();
201211

202212
when(mockedS3.initiateMultipartUpload(any(InitiateMultipartUploadRequest.class)))
203213
.thenReturn(newInitiateMultipartUploadResult());
204214
when(mockedS3.uploadPart(any(UploadPartRequest.class)))
205-
.thenAnswer(a -> {
206-
final UploadPartRequest up = a.getArgument(0);
215+
.thenAnswer(answer -> {
216+
final UploadPartRequest up = answer.getArgument(0);
207217
//emulate behave of S3 client otherwise we will get wrong array in the memory
208-
up.setInputStream(new ByteArrayInputStream(up.getInputStream().readAllBytes()));
209-
uploadPartRequests.add(up);
218+
uploadPartRequests.put(up.getPartNumber(), up);
219+
uploadPartContents.put(up.getPartNumber(), up.getInputStream().readAllBytes());
210220

211221
return newUploadPartResult(up.getPartNumber(), "SOME_TAG#" + up.getPartNumber());
212222
});
213223
when(mockedS3.completeMultipartUpload(completeMultipartUploadRequestCaptor.capture()))
214224
.thenReturn(new CompleteMultipartUploadResult());
215225

216-
final byte[] message = new byte[messageSize];
217226

218227
final byte[] expectedFullMessage = new byte[messageSize + 10];
219228
final byte[] expectedTailMessage = new byte[10];
220229

221-
final S3MultiPartOutputStream
222-
out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
223-
random.nextBytes(message);
224-
out.write(message);
225-
System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
226-
random.nextBytes(message);
227-
out.write(message);
228-
System.arraycopy(message, 0, expectedFullMessage, 20, 10);
229-
System.arraycopy(message, 10, expectedTailMessage, 0, 10);
230+
final var out = new S3MultiPartOutputStream(BUCKET_NAME, FILE_KEY, messageSize + 10, mockedS3);
231+
{
232+
final byte[] message = new byte[messageSize];
233+
random.nextBytes(message);
234+
out.write(message);
235+
System.arraycopy(message, 0, expectedFullMessage, 0, message.length);
236+
}
237+
{
238+
final byte[] message = new byte[messageSize];
239+
random.nextBytes(message);
240+
out.write(message);
241+
System.arraycopy(message, 0, expectedFullMessage, 20, 10);
242+
System.arraycopy(message, 10, expectedTailMessage, 0, 10);
243+
}
230244
out.close();
231245

232-
assertUploadPartRequest(uploadPartRequests.get(0), 30, 1, expectedFullMessage);
233-
assertUploadPartRequest(uploadPartRequests.get(1), 10, 2, expectedTailMessage);
246+
assertThat(uploadPartRequests).hasSize(2);
247+
assertUploadPartRequest(uploadPartRequests.get(1), uploadPartContents.get(1), 30, 1, expectedFullMessage);
248+
assertUploadPartRequest(uploadPartRequests.get(2), uploadPartContents.get(2), 10, 2, expectedTailMessage);
234249

235250
verify(mockedS3).initiateMultipartUpload(any(InitiateMultipartUploadRequest.class));
236251
verify(mockedS3, times(2)).uploadPart(any(UploadPartRequest.class));
@@ -251,6 +266,7 @@ private static UploadPartResult newUploadPartResult(final int partNumber, final
251266
}
252267

253268
private static void assertUploadPartRequest(final UploadPartRequest uploadPartRequest,
269+
final byte[] bytes,
254270
final int expectedPartSize,
255271
final int expectedPartNumber,
256272
final byte[] expectedBytes) {
@@ -259,23 +275,17 @@ private static void assertUploadPartRequest(final UploadPartRequest uploadPartRe
259275
assertThat(uploadPartRequest.getPartNumber()).isEqualTo(expectedPartNumber);
260276
assertThat(uploadPartRequest.getBucketName()).isEqualTo(BUCKET_NAME);
261277
assertThat(uploadPartRequest.getKey()).isEqualTo(FILE_KEY);
262-
assertThat(uploadPartRequest.getInputStream()).hasBinaryContent(expectedBytes);
278+
assertThat(bytes).isEqualTo(expectedBytes);
263279
}
264280

265281
private static void assertCompleteMultipartUploadRequest(final CompleteMultipartUploadRequest request,
266-
final List<PartETag> expectedETags) {
282+
final Map<Integer, String> expectedETags) {
267283
assertThat(request.getBucketName()).isEqualTo(BUCKET_NAME);
268284
assertThat(request.getKey()).isEqualTo(FILE_KEY);
269285
assertThat(request.getUploadId()).isEqualTo(UPLOAD_ID);
270-
assertThat(request.getPartETags()).hasSameSizeAs(expectedETags);
271-
272-
for (int i = 0; i < expectedETags.size(); i++) {
273-
final PartETag expectedETag = expectedETags.get(i);
274-
final PartETag etag = request.getPartETags().get(i);
275-
276-
assertThat(etag.getPartNumber()).isEqualTo(expectedETag.getPartNumber());
277-
assertThat(etag.getETag()).isEqualTo(expectedETag.getETag());
278-
}
286+
final Map<Integer, String> tags = request.getPartETags().stream()
287+
.collect(Collectors.toMap(PartETag::getPartNumber, PartETag::getETag));
288+
assertThat(tags).containsExactlyInAnyOrderEntriesOf(expectedETags);
279289
}
280290

281291
private static void assertAbortMultipartUploadRequest(final AbortMultipartUploadRequest request) {

0 commit comments

Comments
 (0)