Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[YugabyteDB] Improve locking mechanism during migrations #76

Merged
merged 4 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,21 @@

import java.sql.Connection;
import java.sql.SQLException;
import java.util.List;


@CustomLog
public class YugabyteDBDatabase extends PostgreSQLDatabase {

public static final String LOCK_TABLE_NAME = "YB_FLYWAY_LOCK_TABLE";
private static final String LOCK_TABLE_SCHEMA_SQL = "SELECT table_name, column_name FROM information_schema.columns WHERE table_name = '" + LOCK_TABLE_NAME + "'";
private static final String DROP_LOCK_TABLE_IF_EXISTS_DDL = "DROP TABLE IF EXISTS " + LOCK_TABLE_NAME;
/**
* This table is used to enforce locking through SELECT ... FOR UPDATE on a
* token row inserted in this table. The token row is inserted with the name
* of the Flyway's migration history table as a token for simplicity.
*/
private static final String CREATE_LOCK_TABLE_DDL = "CREATE TABLE IF NOT EXISTS " + LOCK_TABLE_NAME + " (table_name varchar PRIMARY KEY, locked bool)";
private static final String CREATE_LOCK_TABLE_DDL = "CREATE TABLE IF NOT EXISTS " + LOCK_TABLE_NAME + " (table_name varchar PRIMARY KEY, lock_id bigint, ts timestamp)";

public YugabyteDBDatabase(Configuration configuration, JdbcConnectionFactory jdbcConnectionFactory, StatementInterceptor statementInterceptor) {
super(configuration, jdbcConnectionFactory, statementInterceptor);
Expand Down Expand Up @@ -84,7 +87,21 @@ public boolean useSingleConnection() {

private void createLockTable() {
try {
jdbcTemplate.execute(CREATE_LOCK_TABLE_DDL);
List<String> columns = jdbcTemplate.query(LOCK_TABLE_SCHEMA_SQL, rs -> rs.getString("column_name"));
if (columns.isEmpty()) {
LOG.debug("Lock table not found, creating it...");
jdbcTemplate.execute(CREATE_LOCK_TABLE_DDL);
} else {
for (String column : columns) {
if ("lock_id".equals(column)) {
LOG.debug("Lock table with expected schema already exists");
return;
}
}
LOG.info("Lock table exists but has old schema. Dropping and recreating it with new schema...");
jdbcTemplate.execute(DROP_LOCK_TABLE_IF_EXISTS_DDL);
jdbcTemplate.execute(CREATE_LOCK_TABLE_DDL);
}
} catch (SQLException e) {
throw new FlywaySqlException("Unable to initialize the lock table", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import org.flywaydb.core.api.FlywayException;
import org.flywaydb.core.internal.exception.FlywaySqlException;
import org.flywaydb.core.internal.jdbc.JdbcTemplate;
import org.flywaydb.core.internal.strategy.RetryStrategy;
import org.flywaydb.core.internal.util.FlywayDbWebsiteLinks;
import org.flywaydb.core.internal.util.SqlCallable;

import java.sql.*;
import java.time.Instant;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;

Expand All @@ -18,6 +20,10 @@ public class YugabyteDBExecutionTemplate {
private final JdbcTemplate jdbcTemplate;
private final String tableName;
private static final Map<String, Boolean> tableEntries = new ConcurrentHashMap<>();
private static final Random random = new Random();
public static final int DEFAULT_LOCK_ID_TTL = 1000 * 60 * 5;
public static final int MAX_LOCK_ID_TTL = 1000 * 60 * 60;
public static final String LOCK_ID_TTL_SYS_PROP_NAME = "flyway.yugabytedb.lock-id-ttl-ms";

YugabyteDBExecutionTemplate(JdbcTemplate jdbcTemplate, String tableName) {
this.jdbcTemplate = jdbcTemplate;
Expand All @@ -26,8 +32,9 @@ public class YugabyteDBExecutionTemplate {

public <T> T execute(Callable<T> callable) {
Exception error = null;
long lockId = 0;
try {
lock();
lockId = lock();
return callable.call();
} catch (RuntimeException e) {
error = e;
Expand All @@ -36,31 +43,36 @@ public <T> T execute(Callable<T> callable) {
error = e;
throw new FlywayException(e);
} finally {
unlock(error);
if (lockId != 0) {
unlock(lockId, error);
}
}
}

private void lock() throws SQLException {
RetryStrategy strategy = new RetryStrategy();
strategy.doWithRetries(this::tryLock, "Interrupted while attempting to acquire lock through SELECT ... FOR UPDATE",
private long lock() throws SQLException {
YBRetryStrategy strategy = new YBRetryStrategy();
return strategy.doWithRetries(this::tryLock, "Interrupted while attempting to acquire lock through SELECT ... FOR UPDATE",
"Number of retries exceeded while attempting to acquire lock through SELECT ... FOR UPDATE. " +
"Configure the number of retries with the 'lockRetryCount' configuration option: " + FlywayDbWebsiteLinks.LOCK_RETRY_COUNT);

}

private boolean tryLock() {
private long tryLock() {
Exception exception = null;
boolean txStarted = false, success = false;
boolean txStarted = false;
long lockIdToBeReturned = 0;
Statement statement = null;
try {
statement = jdbcTemplate.getConnection().createStatement();

if (!tableEntries.containsKey(tableName)) {
try {
String now = new Timestamp(Instant.now().getEpochSecond()).toString();
statement.executeUpdate("INSERT INTO "
+ YugabyteDBDatabase.LOCK_TABLE_NAME
+ " VALUES ('" + tableName + "', 'false')");
+ " VALUES ('" + tableName + "', 0, '" + now + "')");
tableEntries.put(tableName, true);
LOG.info("insert query ts: " + now);
LOG.info(Thread.currentThread().getName() + "> Inserted a token row for " + tableName + " in " + YugabyteDBDatabase.LOCK_TABLE_NAME);
} catch (SQLException e) {
if ("23505".equals(e.getSQLState())) {
Expand All @@ -72,38 +84,53 @@ private boolean tryLock() {
}
}

boolean locked;
String selectForUpdate = "SELECT locked FROM "
long lockIdRead = 0;
String selectForUpdate = "SELECT lock_id, ts FROM "
+ YugabyteDBDatabase.LOCK_TABLE_NAME
+ " WHERE table_name = '"
+ tableName
+ "' FOR UPDATE";
String updateLocked = "UPDATE " + YugabyteDBDatabase.LOCK_TABLE_NAME
+ " SET locked = true WHERE table_name = '"
+ tableName + "'";

statement.execute("BEGIN");
txStarted = true;
ResultSet rs = statement.executeQuery(selectForUpdate);
if (rs.next()) {
locked = rs.getBoolean("locked");
lockIdRead = rs.getLong("lock_id");
Timestamp tsRead = rs.getTimestamp("ts");
String current = new Timestamp(Instant.now().getEpochSecond()).toString();
long lockIdTtl = DEFAULT_LOCK_ID_TTL;
String sysProp = System.getProperty(LOCK_ID_TTL_SYS_PROP_NAME);
if (sysProp != null) {
try {
lockIdTtl = Long.parseLong(sysProp);
lockIdTtl = lockIdTtl < 0 || lockIdTtl > MAX_LOCK_ID_TTL ? DEFAULT_LOCK_ID_TTL : lockIdTtl;
} catch (NumberFormatException e) {
LOG.warn("Invalid value for " + LOCK_ID_TTL_SYS_PROP_NAME + ": " + sysProp + ". Using default value: " + DEFAULT_LOCK_ID_TTL + " ms");
}
}

if (locked) {
statement.execute("COMMIT");
txStarted = false;
LOG.debug(Thread.currentThread().getName() + "> Another Flyway operation is in progress. Allowing it to complete");
if (lockIdRead == 0 || Instant.now().getEpochSecond() - tsRead.getTime() > lockIdTtl) {
lockIdToBeReturned = random.nextLong();
if (lockIdRead == 0) {
LOG.debug(Thread.currentThread().getName() + "> Setting lock_id = " + lockIdToBeReturned);
} else {
LOG.warn(Thread.currentThread().getName() + "> Lock with lock_id " + lockIdRead + " is held for more than " + lockIdTtl + " millis. Resetting it with lock_id " + lockIdToBeReturned);
}
String updateLockId = "UPDATE " + YugabyteDBDatabase.LOCK_TABLE_NAME
+ " SET lock_id = " + lockIdToBeReturned + ", ts = '" + current + "' WHERE table_name = '"
+ tableName + "'";
LOG.debug(Thread.currentThread().getName() + "> executing query " + updateLockId);
statement.executeUpdate(updateLockId);
} else {
LOG.debug(Thread.currentThread().getName() + "> Setting locked = true");
statement.executeUpdate(updateLocked);
success = true;
LOG.debug(Thread.currentThread().getName() + "> Another Flyway operation is in progress. Allowing it to complete");
}
} else {
// For some reason the record was not found, retry
tableEntries.remove(tableName);
}

} catch (SQLException e) {
LOG.warn(Thread.currentThread().getName() + "> Unable to perform lock action, SQLState: " + e.getSQLState());
LOG.debug(Thread.currentThread().getName() + "> Unable to perform lock action, SQLState: " + e.getSQLState());
if (!"40001".equalsIgnoreCase(e.getSQLState())) {
exception = new FlywaySqlException("Unable to perform lock action", e);
throw (FlywaySqlException) exception;
Expand All @@ -112,56 +139,103 @@ private boolean tryLock() {
if (txStarted) {
try {
statement.execute("COMMIT");
LOG.debug(Thread.currentThread().getName() + "> Completed the tx to set locked = true");
// lock_id may not be set if there is exception in select for update
LOG.debug(Thread.currentThread().getName() + "> Completed the tx to attempt to set lock_id");
} catch (SQLException e) {
if (exception == null) {
throw new FlywaySqlException("Failed to commit the tx to set locked = true", e);
throw new FlywaySqlException("Failed to commit the tx to set lock_id ", e);
}
LOG.warn(Thread.currentThread().getName() + "> Failed to commit the tx to set locked = true: " + e);
LOG.warn(Thread.currentThread().getName() + "> Failed to commit the tx to set lock_id: " + e);
}
}
}
return success;
return lockIdToBeReturned;
}

private void unlock(Exception rethrow) {
private void unlock(long prevLockId, Exception rethrow) {
Statement statement = null;
try {
statement = jdbcTemplate.getConnection().createStatement();
statement.execute("BEGIN");
ResultSet rs = statement.executeQuery("SELECT locked FROM " + YugabyteDBDatabase.LOCK_TABLE_NAME + " WHERE table_name = '" + tableName + "' FOR UPDATE");
ResultSet rs = statement.executeQuery("SELECT lock_id FROM " + YugabyteDBDatabase.LOCK_TABLE_NAME + " WHERE table_name = '" + tableName + "' FOR UPDATE");

if (rs.next()) {
boolean locked = rs.getBoolean("locked");
if (locked) {
statement.executeUpdate("UPDATE " + YugabyteDBDatabase.LOCK_TABLE_NAME + " SET locked = false WHERE table_name = '" + tableName + "'");
long lockId = rs.getLong("lock_id");
if (lockId == prevLockId) {
statement.executeUpdate("UPDATE " + YugabyteDBDatabase.LOCK_TABLE_NAME + " SET lock_id = 0 WHERE table_name = '" + tableName + "'");
} else {
// Unexpected. This may happen only when callable took too long to complete
// and another thread forcefully reset it.
String msgLock = "Expected and actual lock_id mismatch. Expected: " + prevLockId + ", Actual: " + lockId;
String msg = "Unlock failed but the Flyway operation may have succeeded. Check your Flyway operation before re-trying";
LOG.warn(Thread.currentThread().getName() + "> " + msg);
LOG.warn(Thread.currentThread().getName() + "> " + msg + "\n" + msgLock);
if (rethrow == null) {
throw new FlywayException(msg);
}
}
}
} catch (SQLException e) {
if (rethrow == null) {
rethrow = new FlywayException("Unable to perform unlock action", e);
rethrow = new FlywaySqlException("Unable to perform unlock action for lock_id " + prevLockId, e);
throw (FlywaySqlException) rethrow;
}
LOG.warn("Unable to perform unlock action " + e);
LOG.warn("Unable to perform unlock action for lock_id " + prevLockId + ": " + e);
} finally {
try {
statement.execute("COMMIT");
LOG.debug(Thread.currentThread().getName() + "> Completed the tx to set locked = false");
LOG.debug(Thread.currentThread().getName() + "> Completed the tx to reset lock_id " + prevLockId);
} catch (SQLException e) {
if (rethrow == null) {
throw new FlywaySqlException("Failed to commit unlock action", e);
throw new FlywaySqlException("Failed to commit unlock action for lock_id " + prevLockId, e);
}
LOG.warn("Failed to commit unlock action: " + e);
LOG.warn("Failed to commit unlock action for lock_id " + prevLockId + ": " + e);
}
}
}

public static class YBRetryStrategy {
private static int numberOfRetries = 50;
private static boolean unlimitedRetries;
private int numberOfRetriesRemaining;

public YBRetryStrategy() {
this.numberOfRetriesRemaining = numberOfRetries;
}

public static void setNumberOfRetries(int retries) {
numberOfRetries = retries;
unlimitedRetries = retries < 0;
}

private boolean hasMoreRetries() {
return unlimitedRetries || this.numberOfRetriesRemaining > 0;
}

private void nextRetry() {
if (!unlimitedRetries) {
--this.numberOfRetriesRemaining;
}
}

private int nextWaitInMilliseconds() {
return 1000;
}

public long doWithRetries(SqlCallable<Long> callable, String interruptionMessage, String retriesExceededMessage) throws SQLException {
long id = 0;
while(id == 0) {
id = callable.call();
try {
Thread.sleep(this.nextWaitInMilliseconds());
} catch (InterruptedException e) {
throw new FlywayException(interruptionMessage, e);
}
if (!this.hasMoreRetries()) {
throw new FlywayException(retriesExceededMessage);
}
this.nextRetry();
}
return id;
}
}
}