Skip to content

Commit

Permalink
Simplify Stream Interface: Add versioning code as part of normal mess…
Browse files Browse the repository at this point in the history
…age reads. (#5732)

Follow up to #5721.

As part of the performance work, we introduced some dead code in the Versioned Airbyte Stream factory around messages skipping the actual protocol upgrade code.

At that time, this was deemed ok because:

- We were not using the Versioned Stream Factory's migration path and weren't going to use it for some time. (Still not using it now)
- Since that PR pushed a big validation logic change, it was slightly too confusing/complex to do everything in one PR.
- Since this hack would only live for a week or so, the risk was deemed low.
- In this PR, we combine the new validation logic change with Versioned Airbyte factory migration path and remove any dead code.
  • Loading branch information
davinchia committed Apr 11, 2023
1 parent 3f7088a commit ea0389e
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

package io.airbyte.commons.protocol.serde;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.commons.version.Version;
import java.util.Optional;

/**
* Airbyte Protocol deserialization interface.
Expand All @@ -14,7 +14,7 @@
*/
public interface AirbyteMessageDeserializer<T> {

T deserialize(final JsonNode json);
Optional<T> deserialize(final String json);

Version getTargetVersion();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

package io.airbyte.commons.protocol.serde;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.commons.json.Jsons;
import io.airbyte.commons.version.Version;
import java.util.Optional;
import lombok.Getter;

/**
Expand All @@ -26,8 +26,8 @@ public AirbyteMessageGenericDeserializer(final Version targetVersion, final Clas
}

@Override
public T deserialize(JsonNode json) {
return Jsons.object(json, typeClass);
public Optional<T> deserialize(String json) {
return Jsons.tryDeserialize(json, typeClass);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteMessage.Type;
import io.airbyte.protocol.models.ConnectorSpecification;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Optional;
import org.junit.jupiter.api.Test;

class AirbyteMessageV0SerDeTest {
Expand All @@ -29,9 +29,9 @@ void v0SerDeRoundTripTest() throws URISyntaxException {
.withDocumentationUrl(new URI("file:///tmp/doc")));

final String serializedMessage = ser.serialize(message);
final AirbyteMessage deserializedMessage = deser.deserialize(Jsons.deserialize(serializedMessage));
final Optional<AirbyteMessage> deserializedMessage = deser.deserialize(serializedMessage);

assertEquals(message, deserializedMessage);
assertEquals(message, deserializedMessage.get());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteMessage.Type;
import io.airbyte.protocol.models.ConnectorSpecification;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Optional;
import org.junit.jupiter.api.Test;

class AirbyteMessageV1SerDeTest {
Expand All @@ -29,9 +29,9 @@ void v1SerDeRoundTripTest() throws URISyntaxException {
.withDocumentationUrl(new URI("file:///tmp/doc")));

final String serializedMessage = ser.serialize(message);
final AirbyteMessage deserializedMessage = deser.deserialize(Jsons.deserialize(serializedMessage));
final Optional<AirbyteMessage> deserializedMessage = deser.deserialize(serializedMessage);

assertEquals(message, deserializedMessage);
assertEquals(message, deserializedMessage.get());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ public class VersionedAirbyteStreamFactory<T> implements AirbyteStreamFactory {

private static final Logger LOGGER = LoggerFactory.getLogger(VersionedAirbyteStreamFactory.class);
private static final double MAX_SIZE_RATIO = 0.8;
private static final long DEFAULT_MEMORY_LIMIT = Runtime.getRuntime().maxMemory();
private static final MdcScope.Builder DEFAULT_MDC_SCOPE = MdcScope.DEFAULT_BUILDER;

private static final Logger DEFAULT_LOGGER = LOGGER;
private static final Version fallbackVersion = new Version("0.2.0");

// Buffer size to use when detecting the protocol version.
Expand All @@ -80,8 +84,8 @@ public class VersionedAirbyteStreamFactory<T> implements AirbyteStreamFactory {
private final AirbyteMessageSerDeProvider serDeProvider;
private final AirbyteProtocolVersionedMigratorFactory migratorFactory;
private final Optional<ConfiguredAirbyteCatalog> configuredAirbyteCatalog;
private AirbyteMessageDeserializer<T> deserializer;
private AirbyteMessageVersionedMigrator<T> migrator;
private AirbyteMessageDeserializer<AirbyteMessage> deserializer;
private AirbyteMessageVersionedMigrator<AirbyteMessage> migrator;
private Version protocolVersion;

private boolean shouldDetectVersion = false;
Expand All @@ -98,7 +102,9 @@ public static VersionedAirbyteStreamFactory noMigrationVersionedAirbyteStreamFac
}

/**
* Similar to the above method, but allows for more testing related variables to be passed in.
* Same as above with additional config for testing.
*
* @return a VersionedAirbyteStreamFactory that does not perform any migration.
*/
@VisibleForTesting
public static VersionedAirbyteStreamFactory noMigrationVersionedAirbyteStreamFactory(Logger logger,
Expand All @@ -125,19 +131,19 @@ public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProv
final AirbyteProtocolVersionedMigratorFactory migratorFactory,
final Version protocolVersion,
final Optional<ConfiguredAirbyteCatalog> configuredAirbyteCatalog,
final MdcScope.Builder containerLogMdcBuilder,
final Optional<Class<? extends RuntimeException>> exceptionClass) {
this(serDeProvider, migratorFactory, protocolVersion, configuredAirbyteCatalog, LOGGER, MdcScope.DEFAULT_BUILDER, exceptionClass,
this(serDeProvider, migratorFactory, protocolVersion, configuredAirbyteCatalog, LOGGER, containerLogMdcBuilder, exceptionClass,
Runtime.getRuntime().maxMemory());
}

public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProvider,
final AirbyteProtocolVersionedMigratorFactory migratorFactory,
final Version protocolVersion,
final Optional<ConfiguredAirbyteCatalog> configuredAirbyteCatalog,
final MdcScope.Builder containerLogMdcBuilder,
final Optional<Class<? extends RuntimeException>> exceptionClass) {
this(serDeProvider, migratorFactory, protocolVersion, configuredAirbyteCatalog, LOGGER, containerLogMdcBuilder, exceptionClass,
Runtime.getRuntime().maxMemory());
this(serDeProvider, migratorFactory, protocolVersion, configuredAirbyteCatalog, DEFAULT_LOGGER, DEFAULT_MDC_SCOPE, exceptionClass,
DEFAULT_MEMORY_LIMIT);
}

public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProvider,
Expand Down Expand Up @@ -170,6 +176,17 @@ public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProv
@Trace(operationName = WORKER_OPERATION_NAME)
@Override
public Stream<AirbyteMessage> create(final BufferedReader bufferedReader) {
detectAndInitialiseMigrators(bufferedReader);
final boolean needMigration = !protocolVersion.getMajorVersion().equals(migratorFactory.getMostRecentVersion().getMajorVersion());
logger.info(
"Reading messages from protocol version {}{}",
protocolVersion.serialize(),
needMigration ? ", messages will be upgraded to protocol version " + migratorFactory.getMostRecentVersion().serialize() : "");

return addLineReadLogic(bufferedReader);
}

private void detectAndInitialiseMigrators(BufferedReader bufferedReader) {
if (shouldDetectVersion) {
final Optional<Version> versionMaybe;
try {
Expand All @@ -186,21 +203,10 @@ public Stream<AirbyteMessage> create(final BufferedReader bufferedReader) {
initializeForProtocolVersion(fallbackVersion);
}
}

final boolean needMigration = !protocolVersion.getMajorVersion().equals(migratorFactory.getMostRecentVersion().getMajorVersion());
logger.info(
"Reading messages from protocol version {}{}",
protocolVersion.serialize(),
needMigration ? ", messages will be upgraded to protocol version " + migratorFactory.getMostRecentVersion().serialize() : "");
return addLineReadLogic(bufferedReader);
}

/**
* Temporarily ported over from the {@link DefaultAirbyteStreamFactory} whole as we work on
* refactoring the code.
*/
@Trace(operationName = WORKER_OPERATION_NAME)
public Stream<AirbyteMessage> addLineReadLogic(final BufferedReader bufferedReader) {
private Stream<AirbyteMessage> addLineReadLogic(final BufferedReader bufferedReader) {
final var metricClient = MetricClientFactory.getMetricClient();
return bufferedReader
.lines()
Expand Down Expand Up @@ -283,7 +289,7 @@ public VersionedAirbyteStreamFactory<T> withDetectVersion(final boolean detectVe
}

protected final void initializeForProtocolVersion(final Version protocolVersion) {
this.deserializer = (AirbyteMessageDeserializer<T>) serDeProvider.getDeserializer(protocolVersion).orElseThrow();
this.deserializer = (AirbyteMessageDeserializer<AirbyteMessage>) serDeProvider.getDeserializer(protocolVersion).orElseThrow();
this.migrator = migratorFactory.getAirbyteMessageMigrator(protocolVersion);
this.protocolVersion = protocolVersion;
}
Expand Down Expand Up @@ -327,33 +333,40 @@ private String humanReadableByteCountSI(long bytes) {
return String.format("%.1f %cB", bytes / 1000.0, ci.current());
}

/**
* For every incoming message,
* <p>
* 1. deserialize the incoming JSON string to {@link AirbyteMessage}.
* <p>
* 2. validate the message.
* <p>
* 3. upgrade the message to the platform version, if needed.
*/
protected Stream<AirbyteMessage> toAirbyteMessage(final String line) {
Optional<AirbyteMessage> m = Jsons.tryDeserialize(line, AirbyteMessage.class);
// put back the deserializer.
Optional<AirbyteMessage> m = deserializer.deserialize(line);

if (m.isPresent()) {
m = BasicAirbyteMessageValidator.validate(m.get());

if (m.isEmpty()) {
logger.error("Validation failed: {}", Jsons.serialize(line));
return m.stream();
}

return upgradeMessage(m.get());
}
if (m.isEmpty()) {
logger.error("Deserialization failed: {}", Jsons.serialize(line));
}

logger.error("Deserialization failed: {}", Jsons.serialize(line));
return m.stream();
}

/**
* This was initially implemented as an override on the
* {@link DefaultAirbyteStreamFactory#toAirbyteMessage(String)}. However, rollout of the type system
* was deprioritized. In the mean time, the performance work uncovered some low-hanging fruit in the
* deserialization process that changes the signature of the toAirbyteMessage method.
* <p>
* This is left here for posterity when the types work is picked up. It is not currently used. This
* means message unwrapping is still done by DefaultAirbyteStreamFactory.
*/
protected Stream<AirbyteMessage> toAirbyteMessage(final JsonNode json) {
protected Stream<AirbyteMessage> upgradeMessage(final AirbyteMessage msg) {
try {
final AirbyteMessage message = migrator.upgrade(deserializer.deserialize(json), configuredAirbyteCatalog);
final AirbyteMessage message = migrator.upgrade(msg, configuredAirbyteCatalog);
return Stream.of(message);
} catch (final RuntimeException e) {
logger.warn("Failed to upgrade a message from version {}: {}", protocolVersion, Jsons.serialize(json), e);
logger.warn("Failed to upgrade a message from version {}: {}", protocolVersion, Jsons.serialize(msg), e);
return Stream.empty();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
import io.airbyte.commons.converters.ConnectorConfigUpdater;
import io.airbyte.commons.features.EnvVariableFeatureFlags;
import io.airbyte.commons.protocol.AirbyteMessageMigrator;
import io.airbyte.commons.protocol.AirbyteMessageSerDeProvider;
import io.airbyte.commons.protocol.AirbyteProtocolVersionedMigratorFactory;
import io.airbyte.commons.protocol.ConfiguredAirbyteCatalogMigrator;
import io.airbyte.commons.protocol.serde.AirbyteMessageV0Deserializer;
import io.airbyte.commons.protocol.serde.AirbyteMessageV0Serializer;
import io.airbyte.commons.version.Version;
import io.airbyte.config.JobSyncConfig.NamespaceDefinitionType;
import io.airbyte.config.ReplicationOutput;
Expand Down Expand Up @@ -39,7 +36,6 @@
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import lombok.extern.slf4j.Slf4j;
Expand Down Expand Up @@ -91,20 +87,14 @@ public static void executeOneSync() throws InterruptedException {
// final IntegrationLauncher integrationLauncher = new LimitedIntegrationLauncher(new
// LimitedThinRecordSourceProcess());
final IntegrationLauncher integrationLauncher = new LimitedIntegrationLauncher(new LimitedFatRecordSourceProcess());
final var serDeProvider = new AirbyteMessageSerDeProvider(
List.of(new AirbyteMessageV0Deserializer()),
List.of(new AirbyteMessageV0Serializer()));
serDeProvider.initialize();

final var msgMigrator = new AirbyteMessageMigrator(List.of());
msgMigrator.initialize();
final ConfiguredAirbyteCatalogMigrator catalogMigrator = new ConfiguredAirbyteCatalogMigrator(List.of());
catalogMigrator.initialize();
final var migratorFactory = new AirbyteProtocolVersionedMigratorFactory(msgMigrator, catalogMigrator);

final var versionFac =
new VersionedAirbyteStreamFactory(serDeProvider, migratorFactory, new Version("0.2.0"), Optional.empty(),
Optional.of(RuntimeException.class));
final var versionFac = VersionedAirbyteStreamFactory.noMigrationVersionedAirbyteStreamFactory();
final HeartbeatMonitor heartbeatMonitor = new HeartbeatMonitor(DEFAULT_HEARTBEAT_FRESHNESS_THRESHOLD);
final var versionedAbSource =
new DefaultAirbyteSource(integrationLauncher, versionFac, heartbeatMonitor, migratorFactory.getProtocolSerializer(new Version("0.2.0")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,73 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.Config;
import io.airbyte.workers.test_utils.AirbyteMessageUtils;
import java.util.Optional;
import org.junit.Test;

@SuppressWarnings("MissingJavadocType")
public class BasicAirbyteMessageValidatorTest {

@Test
void testValid() {
final AirbyteMessage record1 = AirbyteMessageUtils.createRecordMessage("stream_1", "field_1", "green");
void testObviousInvalid() {
final Optional<AirbyteMessage> bad = Jsons.tryDeserialize("{}", AirbyteMessage.class);

final var m = BasicAirbyteMessageValidator.validate(record1);
final var m = BasicAirbyteMessageValidator.validate(bad.get());
assertTrue(m.isEmpty());
}

@Test
void testValidRecord() {
final AirbyteMessage rec = AirbyteMessageUtils.createRecordMessage("stream_1", "field_1", "green");

final var m = BasicAirbyteMessageValidator.validate(rec);
assertTrue(m.isPresent());
assertEquals(rec, m.get());
}

@Test
void testSubtleInvalidRecord() {
final Optional<AirbyteMessage> bad = Jsons.tryDeserialize("{\"type\": \"RECORD\", \"record\": {}}", AirbyteMessage.class);

final var m = BasicAirbyteMessageValidator.validate(bad.get());
assertTrue(m.isEmpty());
}

@Test
void testValidState() {
final AirbyteMessage rec = AirbyteMessageUtils.createStateMessage(1);

final var m = BasicAirbyteMessageValidator.validate(rec);
assertTrue(m.isPresent());
assertEquals(record1, m.get());
assertEquals(rec, m.get());
}

@Test
void testSubtleInvalidState() {
final Optional<AirbyteMessage> bad = Jsons.tryDeserialize("{\"type\": \"STATE\", \"control\": {}}", AirbyteMessage.class);

final var m = BasicAirbyteMessageValidator.validate(bad.get());
assertTrue(m.isEmpty());
}

@Test
void testValidControl() {
final AirbyteMessage rec = AirbyteMessageUtils.createConfigControlMessage(new Config(), 1000.0);

final var m = BasicAirbyteMessageValidator.validate(rec);
assertTrue(m.isPresent());
assertEquals(rec, m.get());
}

@Test
void testSubtleInvalidControl() {
final Optional<AirbyteMessage> bad = Jsons.tryDeserialize("{\"type\": \"CONTROL\", \"state\": {}}", AirbyteMessage.class);

final var m = BasicAirbyteMessageValidator.validate(bad.get());
assertTrue(m.isEmpty());
}

}
Loading

0 comments on commit ea0389e

Please sign in to comment.