Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making BigQueryClientFactory Kryo serializable #1284

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

## Next

* PR #1281 : Configure alternative BigNumeric precision and scale defaults
* Issue #1175: Add details to schema mismatch message
* PR #1281 : Configure alternative BigNumeric precision and scale defaults
* PR #1284: Making BigQueryClientFactory Kryo serializable. Thanks @tom-s-powell !

## 0.40.0 - 2024-08-05

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,15 @@ public class BigQueryClientFactory implements Serializable {
private static final Map<BigQueryClientFactory, BigQueryWriteClient> writeClientMap =
new HashMap<>();

private final Credentials credentials;
// using the user agent as HeaderProvider is not serializable
private final HeaderProvider headerProvider;
private final BigQueryConfig bqConfig;

// GoogleCredentials are not compatible with Kryo serialization, so we serialize and deserialize
// when needed
private final byte[] serializedCredentials;
private transient volatile Credentials credentials;

private int cachedHashCode = 0;

@Inject
Expand All @@ -69,6 +74,7 @@ public BigQueryClientFactory(
BigQueryConfig bqConfig) {
// using Guava's optional as it is serializable
this.credentials = bigQueryCredentialsSupplier.getCredentials();
this.serializedCredentials = BigQueryUtil.getCredentialsByteArray(credentials);
this.headerProvider = headerProvider;
this.bqConfig = bqConfig;
}
Expand Down Expand Up @@ -114,15 +120,16 @@ public int hashCode() {
// Credentials). Subclasses of the abstract class ExternalAccountCredentials do not have the
// hashCode method defined on them and hence we get the byte array of the
// ExternalAccountCredentials first and then compare their hashCodes.
if (credentials instanceof ExternalAccountCredentials) {
if (getCredentials() instanceof ExternalAccountCredentials) {
cachedHashCode =
Objects.hashCode(
Arrays.hashCode(BigQueryUtil.getCredentialsByteArray(credentials)),
Arrays.hashCode(serializedCredentials),
headerProvider,
bqConfig.getClientCreationHashCode());
} else {
cachedHashCode =
Objects.hashCode(credentials, headerProvider, bqConfig.getClientCreationHashCode());
Objects.hashCode(
getCredentials(), headerProvider, bqConfig.getClientCreationHashCode());
}
}
return cachedHashCode;
Expand All @@ -148,12 +155,23 @@ public boolean equals(Object o) {
// ExternalAccountCredentials do not have an equals method defined on them and hence we
// serialize and compare byte arrays if either of the credentials are instances of
// ExternalAccountCredentials
return BigQueryUtil.areCredentialsEqual(credentials, that.credentials);
return BigQueryUtil.areCredentialsEqual(getCredentials(), that.getCredentials());
}

return false;
}

private Credentials getCredentials() {
if (credentials == null) {
synchronized (BigQueryClientFactory.class) {
if (credentials == null) {
credentials = BigQueryUtil.getCredentialsFromByteArray(serializedCredentials);
}
}
}
return credentials;
}

private BigQueryReadClient createBigQueryReadClient(
Optional<String> endpoint, int channelPoolSize, Optional<Integer> flowControlWindow) {
try {
Expand All @@ -178,7 +196,7 @@ private BigQueryReadClient createBigQueryReadClient(
BigQueryReadSettings.Builder clientSettings =
BigQueryReadSettings.newBuilder()
.setTransportChannelProvider(transportBuilder.build())
.setCredentialsProvider(FixedCredentialsProvider.create(credentials));
.setCredentialsProvider(FixedCredentialsProvider.create(getCredentials()));

bqConfig
.getCreateReadSessionTimeoutInSeconds()
Expand Down Expand Up @@ -211,7 +229,7 @@ private BigQueryWriteClient createBigQueryWriteClient(Optional<String> endpoint)
BigQueryWriteSettings.Builder clientSettings =
BigQueryWriteSettings.newBuilder()
.setTransportChannelProvider(transportBuilder.build())
.setCredentialsProvider(FixedCredentialsProvider.create(credentials));
.setCredentialsProvider(FixedCredentialsProvider.create(getCredentials()));
return BigQueryWriteClient.create(clientSettings.build());
} catch (IOException e) {
throw new BigQueryConnectorException("Error creating BigQueryWriteClient", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ static byte[] getCredentialsByteArray(Credentials credentials) {
return byteArrayOutputStream.toByteArray();
}

static Credentials getCredentialsFromByteArray(byte[] byteArray) {
try {
ObjectInputStream objectInputStream =
new ObjectInputStream(new ByteArrayInputStream(byteArray));
return (Credentials) objectInputStream.readObject();
} catch (IOException | ClassNotFoundException e) {
throw new RuntimeException(e);
}
}

// returns the first present optional, empty if all parameters are empty
public static <T> Optional<T> firstPresent(Optional<T>... optionals) {
for (Optional<T> o : optionals) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import com.google.auth.Credentials;
import com.google.auth.oauth2.UserCredentials;
import com.google.cloud.bigquery.BigLakeConfiguration;
import com.google.cloud.bigquery.BigQueryError;
import com.google.cloud.bigquery.BigQueryException;
Expand All @@ -43,7 +45,9 @@
import com.google.cloud.bigquery.storage.v1.ReadSession.TableReadOptions;
import com.google.cloud.bigquery.storage.v1.ReadStream;
import com.google.common.collect.ImmutableList;
import java.time.Instant;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -928,4 +932,19 @@ public void testAdjustField_numeric_to_bigNumeric() {
Field adjustedField = BigQueryUtil.adjustField(field, existingField, false);
assertThat(adjustedField.getType()).isEqualTo(LegacySQLTypeName.BIGNUMERIC);
}

@Test
public void testCredentialSerialization() {
Credentials expected =
UserCredentials.create(
AccessToken.newBuilder()
.setTokenValue("notarealtoken")
.setExpirationTime(Date.from(Instant.now()))
.build());

Credentials credentials =
BigQueryUtil.getCredentialsFromByteArray(BigQueryUtil.getCredentialsByteArray(expected));

assertThat(credentials).isEqualTo(expected);
}
}
Loading