Added test to validate recovery when StorageRead client fails (#50)
The most common error when reading large streams occurs while iterating on the ServerStream. This new test demonstrates that the pipeline can recover execution and still obtain the expected results by restarting the read from the beginning of the split or from the last checkpointed offset.
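
For context, the recovery path this test exercises looks roughly like the following sketch. It is a minimal, hypothetical illustration (the StreamOpener interface, readWithRestart, and the offset bookkeeping are assumptions made for the example, not the connector's actual API), showing how a failed server-stream iteration can be resumed from the last recorded offset instead of re-reading the whole split.

import com.google.cloud.bigquery.storage.v1.ReadRowsResponse;
import java.util.Iterator;

/** Illustrative sketch only; the opener and the offset bookkeeping are hypothetical. */
final class RecoverableStreamRead {

    /** Hypothetical factory that opens a read stream starting at a given row offset. */
    interface StreamOpener {
        Iterator<ReadRowsResponse> open(String streamName, long offset);
    }

    static long readWithRestart(StreamOpener opener, String streamName, long checkpointedOffset) {
        long offset = checkpointedOffset;
        Iterator<ReadRowsResponse> it = opener.open(streamName, offset);
        while (true) {
            try {
                while (it.hasNext()) {
                    ReadRowsResponse response = it.next();
                    // Advance the offset only after the response is consumed, so a
                    // failure restarts from the last fully processed position.
                    offset += response.getRowCount();
                }
                return offset; // stream exhausted, all expected rows read
            } catch (RuntimeException failure) {
                // The kind of failure the FaultyIterator below injects: drop the
                // broken stream and re-open it from the last recorded offset.
                it = opener.open(streamName, offset);
            }
        }
    }
}

The FaultyIterator introduced in this change injects exactly that kind of mid-iteration failure, so the new testReaderRecovery case can assert that the full expected row count is still produced.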

---------

Co-authored-by: Jayant Jain <[email protected]>
Co-authored-by: Jayant Jain <[email protected]>
3 people authored Nov 28, 2023
1 parent f42f7ac commit 22e5a48
Showing 3 changed files with 125 additions and 16 deletions.
@@ -173,7 +173,7 @@ public RecordsWithSplitIds<GenericRecord> fetch() throws IOException {
ReadRowsResponse response = readStreamIterator.next();
if (!response.hasAvroRows()) {
LOG.info(
"[subtask #{}][hostname %s] The response contained"
"[subtask #{}][hostname {}] The response contained"
+ " no avro records for stream {}.",
readerContext.getIndexOfSubtask(),
readerContext.getLocalHostName(),
@@ -213,7 +213,7 @@ public RecordsWithSplitIds<GenericRecord> fetch() throws IOException {
}
Long itTimeMs = System.currentTimeMillis() - itStartTime;
LOG.debug(
"[subtask #{}][hostname %s] Completed reading iteration in {}ms,"
"[subtask #{}][hostname {}] Completed reading iteration in {}ms,"
+ " so far read {} from stream {}.",
readerContext.getIndexOfSubtask(),
readerContext.getLocalHostName(),
@@ -240,7 +240,7 @@ public RecordsWithSplitIds<GenericRecord> fetch() throws IOException {
Long splitTimeMs = System.currentTimeMillis() - splitStartFetch;
this.readSplitTimeMetric.ifPresent(m -> m.update(splitTimeMs));
LOG.info(
"[subtask #{}][hostname %s] Completed reading split, {} records in {}ms on stream {}.",
"[subtask #{}][hostname {}] Completed reading split, {} records in {}ms on stream {}.",
readerContext.getIndexOfSubtask(),
readerContext.getLocalHostName(),
readSoFar,
@@ -253,7 +253,7 @@ public RecordsWithSplitIds<GenericRecord> fetch() throws IOException {
} else {
Long fetchTimeMs = System.currentTimeMillis() - fetchStartTime;
LOG.debug(
"[subtask #{}][hostname %s] Completed a partial fetch in {}ms,"
"[subtask #{}][hostname {}] Completed a partial fetch in {}ms,"
+ " so far read {} from stream {}.",
readerContext.getIndexOfSubtask(),
readerContext.getLocalHostName(),
@@ -263,13 +263,15 @@ public RecordsWithSplitIds<GenericRecord> fetch() throws IOException {
}
return respBuilder.build();
} catch (Exception ex) {
// release the iterator just in case
readStreamIterator = null;
throw new IOException(
String.format(
"[subtask #%d][hostname %s] Problems while reading stream %s from BigQuery"
+ " with connection info %s. Current split offset %d,"
+ " reader offset %d. Flink options %s.",
readerContext.getIndexOfSubtask(),
readerContext.getLocalHostName(),
Optional.ofNullable(readerContext.getLocalHostName()).orElse("NA"),
Optional.ofNullable(assignedSplit.getStreamName()).orElse("NA"),
readOptions.toString(),
assignedSplit.getOffset(),
@@ -296,15 +298,15 @@ public void handleSplitsChanges(SplitsChange<BigQuerySourceSplit> splitsChanges)
@Override
public void wakeUp() {
LOG.debug(
"[subtask #{}][hostname %s] Wake up called.",
"[subtask #{}][hostname {}] Wake up called.",
readerContext.getIndexOfSubtask(), readerContext.getLocalHostName());
// do nothing, for now
}

@Override
public void close() throws Exception {
LOG.debug(
"[subtask #{}][hostname %s] Close called, assigned splits {}.",
"[subtask #{}][hostname {}] Close called, assigned splits {}.",
readerContext.getIndexOfSubtask(),
readerContext.getLocalHostName(),
assignedSplits.toString());
@@ -16,6 +16,7 @@

package com.google.cloud.flink.bigquery.fakes;

import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.SerializableFunction;

import com.google.api.services.bigquery.model.Job;
@@ -62,6 +63,7 @@
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -131,18 +133,60 @@ public Job dryRunQuery(String projectId, String query) {
};
}

static class FaultyIterator<T> implements Iterator<T> {

private final Iterator<T> realIterator;
private final Double errorPercentage;
private final Random random = new Random(42);

public FaultyIterator(Iterator<T> realIterator, Double errorPercentage) {
this.realIterator = realIterator;
Preconditions.checkState(
0 <= errorPercentage && errorPercentage <= 100,
"The error percentage should be between 0 and 100");
this.errorPercentage = errorPercentage;
}

@Override
public boolean hasNext() {
return realIterator.hasNext();
}

@Override
public T next() {
if (random.nextDouble() * 100 < errorPercentage) {
throw new RuntimeException(
"Faulty iterator has failed, it will happen with a chance of: "
+ errorPercentage);
}
return realIterator.next();
}

@Override
public void remove() {
realIterator.remove();
}

@Override
public void forEachRemaining(Consumer<? super T> action) {
realIterator.forEachRemaining(action);
}
}

/** Implementation of the server stream for testing purposes. */
public static class FakeBigQueryServerStream
implements BigQueryServices.BigQueryServerStream<ReadRowsResponse> {

private final List<ReadRowsResponse> toReturn;
private final Double errorPercentage;

public FakeBigQueryServerStream(
SerializableFunction<RecordGenerationParams, List<GenericRecord>> dataGenerator,
String schema,
String dataPrefix,
Long size,
Long offset) {
Long offset,
Double errorPercentage) {
this.toReturn =
createResponse(
schema,
@@ -153,11 +197,12 @@ public FakeBigQueryServerStream(
.collect(Collectors.toList()),
0,
size);
this.errorPercentage = errorPercentage;
}

@Override
public Iterator<ReadRowsResponse> iterator() {
return toReturn.iterator();
return new FaultyIterator<>(toReturn.iterator(), errorPercentage);
}

@Override
@@ -170,13 +215,22 @@ public static class FakeBigQueryStorageReadClient implements StorageReadClient {
private final ReadSession session;
private final SerializableFunction<RecordGenerationParams, List<GenericRecord>>
dataGenerator;
private final Double errorPercentage;

public FakeBigQueryStorageReadClient(
ReadSession session,
SerializableFunction<RecordGenerationParams, List<GenericRecord>>
dataGenerator) {
this(session, dataGenerator, 0D);
}

public FakeBigQueryStorageReadClient(
ReadSession session,
SerializableFunction<RecordGenerationParams, List<GenericRecord>> dataGenerator,
Double errorPercentage) {
this.session = session;
this.dataGenerator = dataGenerator;
this.errorPercentage = errorPercentage;
}

@Override
@@ -196,7 +250,8 @@ public BigQueryServerStream<ReadRowsResponse> readRows(ReadRowsRequest request)
session.getAvroSchema().getSchema(),
request.getReadStream(),
session.getEstimatedRowCount(),
request.getOffset());
request.getOffset(),
errorPercentage);
}

@Override
@@ -381,6 +436,17 @@ public static BigQueryReadOptions createReadOptions(
String avroSchemaString,
SerializableFunction<RecordGenerationParams, List<GenericRecord>> dataGenerator)
throws IOException {
return createReadOptions(
expectedRowCount, expectedReadStreamCount, avroSchemaString, dataGenerator, 0D);
}

public static BigQueryReadOptions createReadOptions(
Integer expectedRowCount,
Integer expectedReadStreamCount,
String avroSchemaString,
SerializableFunction<RecordGenerationParams, List<GenericRecord>> dataGenerator,
Double errorPercentage)
throws IOException {
return BigQueryReadOptions.builder()
.setBigQueryConnectOptions(
BigQueryConnectOptions.builder()
@@ -397,7 +463,8 @@ public static BigQueryReadOptions createReadOptions(
expectedRowCount,
expectedReadStreamCount,
avroSchemaString),
dataGenerator));
dataGenerator,
errorPercentage));
})
.build())
.build();
@@ -53,6 +53,7 @@ public class BigQuerySourceITCase {

private static final int PARALLELISM = 2;
private static final Integer TOTAL_ROW_COUNT_PER_STREAM = 10000;
private static final Integer STREAM_COUNT = 2;

@RegisterExtension
static final MiniClusterExtension MINI_CLUSTER_RESOURCE =
@@ -71,7 +72,7 @@ public void beforeTest() throws Exception {
readOptions =
StorageClientFaker.createReadOptions(
TOTAL_ROW_COUNT_PER_STREAM,
2,
STREAM_COUNT,
StorageClientFaker.SIMPLE_AVRO_SCHEMA_STRING);
}

@@ -108,7 +109,7 @@ public void testReadCount() throws Exception {
.executeAndCollect());

// we only create 2 streams as response
assertThat(results).hasSize(TOTAL_ROW_COUNT_PER_STREAM * PARALLELISM);
assertThat(results).hasSize(TOTAL_ROW_COUNT_PER_STREAM * STREAM_COUNT);
}

@Test
@@ -121,12 +122,12 @@ public void testLimit() throws Exception {
List<RowData> results =
env.fromSource(bqSource, WatermarkStrategy.noWatermarks(), "BigQuery-Source")
.executeAndCollect(TOTAL_ROW_COUNT_PER_STREAM);

// the limit is applied per task, and there are 2 reader contexts, hence limitSize * PARALLELISM
assertThat(results).hasSize(limitSize * PARALLELISM);
}

@Test
public void testRecovery() throws Exception {
public void testDownstreamRecovery() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.enableCheckpointing(300L);

@@ -142,7 +143,46 @@ public void testRecovery() throws Exception {
.map(new FailingMapper(failed))
.executeAndCollect());

assertThat(results).hasSize(TOTAL_ROW_COUNT_PER_STREAM * PARALLELISM);
assertThat(results).hasSize(TOTAL_ROW_COUNT_PER_STREAM * STREAM_COUNT);
}

@Test
public void testReaderRecovery() throws Exception {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.enableCheckpointing(300L);

ResolvedSchema schema =
ResolvedSchema.of(
Column.physical("name", DataTypes.STRING()),
Column.physical("number", DataTypes.BIGINT()));
RowType rowType = (RowType) schema.toPhysicalRowDataType().getLogicalType();

TypeInformation<RowData> typeInfo = InternalTypeInfo.of(rowType);

BigQuerySource<RowData> bqSource =
BigQuerySource.<RowData>builder()
.setReadOptions(
StorageClientFaker.createReadOptions(
// just put more rows, just in case
TOTAL_ROW_COUNT_PER_STREAM,
STREAM_COUNT,
StorageClientFaker.SIMPLE_AVRO_SCHEMA_STRING,
params -> StorageClientFaker.createRecordList(params),
// we want this to fail 10% of the time (1 in 10 times)
10D))
.setDeserializationSchema(
new AvroToRowDataDeserializationSchema(rowType, typeInfo))
.build();

List<RowData> results =
CollectionUtil.iteratorToList(
env.fromSource(
bqSource,
WatermarkStrategy.noWatermarks(),
"BigQuery-Source")
.executeAndCollect());

assertThat(results).hasSize(TOTAL_ROW_COUNT_PER_STREAM * STREAM_COUNT);
}

private static class FailingMapper