Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

impl: Add @Sql annotation support for R2DBC in Spring tests #34350

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions spring-test/spring-test.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ dependencies {
optional(project(":spring-beans"))
optional(project(":spring-context"))
optional(project(":spring-jdbc"))
optional(project(":spring-r2dbc"))
optional(project(":spring-orm"))
optional(project(":spring-tx"))
optional(project(":spring-web"))
Expand Down Expand Up @@ -80,6 +81,7 @@ dependencies {
testImplementation("org.hibernate.orm:hibernate-core")
testImplementation("org.hibernate.validator:hibernate-validator")
testImplementation("org.hsqldb:hsqldb")
testImplementation("io.r2dbc:r2dbc-h2")
testImplementation("org.junit.platform:junit-platform-testkit")
testRuntimeOnly("com.sun.xml.bind:jaxb-core")
testRuntimeOnly("com.sun.xml.bind:jaxb-impl")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.jdbc;

import java.util.List;

import io.r2dbc.spi.ConnectionFactory;
import reactor.core.publisher.Mono;

import org.springframework.core.io.Resource;
import org.springframework.r2dbc.connection.init.ResourceDatabasePopulator;

/**
* R2dbcPopulatorUtils is a separate class to avoid name conflicts with existing
* jdbc-related classes.
*
* <p><b>NOTE:</b> In the current architecture, MergedSqlConfig is implemented
* as a package-private method, so it has been placed in
* org.springframework.test.context.jdbc.
*
* @author jonghoon park
* @since 7.0
* @see SqlScriptsTestExecutionListener
* @see MergedSqlConfig
*/
public abstract class R2dbcPopulatorUtils {

static void execute(MergedSqlConfig mergedSqlConfig, ConnectionFactory connectionFactory, List<Resource> scriptResources) {
ResourceDatabasePopulator populator = createResourceDatabasePopulator(mergedSqlConfig);
populator.setScripts(scriptResources.toArray(new Resource[0]));

Mono.from(connectionFactory.create())
.flatMap(populator::populate)
.block();
}

private static ResourceDatabasePopulator createResourceDatabasePopulator(MergedSqlConfig mergedSqlConfig) {
ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefixes(mergedSqlConfig.getCommentPrefixes());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == SqlConfig.ErrorMode.IGNORE_FAILED_DROPS);
return populator;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import javax.sql.DataSource;

import io.r2dbc.spi.ConnectionFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
Expand All @@ -45,6 +46,7 @@
import org.springframework.test.context.jdbc.SqlMergeMode.MergeMode;
import org.springframework.test.context.support.AbstractTestExecutionListener;
import org.springframework.test.context.transaction.TestContextTransactionUtils;
import org.springframework.test.context.transaction.reactive.TestContextReactiveTransactionUtils;
import org.springframework.test.context.util.TestContextResourceUtils;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
Expand Down Expand Up @@ -332,8 +334,13 @@ else if (logger.isDebugEnabled()) {
Assert.state(!newTxRequired, () -> String.format("Failed to execute SQL scripts for test context %s: " +
"cannot execute SQL scripts using Transaction Mode " +
"[%s] without a PlatformTransactionManager.", testContext, TransactionMode.ISOLATED));
Assert.state(dataSource != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
"supply at least a DataSource or PlatformTransactionManager.", testContext));
if (dataSource == null) {
ConnectionFactory connectionFactory = TestContextReactiveTransactionUtils.retrieveConnectionFactory(testContext);
Assert.state(connectionFactory != null, () -> String.format("Failed to execute SQL scripts for test context %s: " +
"supply at least a DataSource or PlatformTransactionManager or ConnectionFactory.", testContext));
R2dbcPopulatorUtils.execute(mergedSqlConfig, connectionFactory, scriptResources);
return;
}
// Execute scripts directly against the DataSource
populator.execute(dataSource);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.transaction.reactive;

import java.util.Map;

import io.r2dbc.spi.Connection;
import io.r2dbc.spi.ConnectionFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.test.context.TestContext;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.util.Assert;

/**
* Utility methods for working with transactions and data access related beans
* within the <em>Spring TestContext Framework</em>.
*
* <p>Mainly for internal use within the framework.
*
* @author jonghoon park
* @since 7.0
*/
public abstract class TestContextReactiveTransactionUtils {

/**
* Default bean name for a {@link ConnectionFactory}:
* {@code "connectionFactory"}.
*/
public static final String DEFAULT_CONNECTION_FACTORY_NAME = "connectionFactory";


private static final Log logger = LogFactory.getLog(TestContextReactiveTransactionUtils.class);

/**
* Retrieve the {@link ConnectionFactory} to use for the supplied {@linkplain TestContext
* test context}.
* <p>The following algorithm is used to retrieve the {@code ConnectionFactory} from
* the {@link org.springframework.context.ApplicationContext ApplicationContext}
* of the supplied test context:
* <ol>
* <li>Attempt to look up the single {@code ConnectionFactory} by type.
* <li>Attempt to look up the <em>primary</em> {@code ConnectionFactory} by type.
* <li>Attempt to look up the {@code ConnectionFactory} by type and the
* {@linkplain #DEFAULT_CONNECTION_FACTORY_NAME default data source name}.
* </ol>
* @param testContext the test context for which the {@code ConnectionFactory}
* should be retrieved; never {@code null}
* @return the {@code DataSource} to use, or {@code null} if not found
*/
@Nullable
public static ConnectionFactory retrieveConnectionFactory(TestContext testContext) {
Assert.notNull(testContext, "TestContext must not be null");
BeanFactory bf = testContext.getApplicationContext().getAutowireCapableBeanFactory();

try {
if (bf instanceof ListableBeanFactory lbf) {
// Look up single bean by type
Map<String, ConnectionFactory> ConnectionFactories =
BeanFactoryUtils.beansOfTypeIncludingAncestors(lbf, ConnectionFactory.class);
if (ConnectionFactories.size() == 1) {
return ConnectionFactories.values().iterator().next();
}

try {
// look up single bean by type, with support for 'primary' beans
return bf.getBean(ConnectionFactory.class);
}
catch (BeansException ex) {
logBeansException(testContext, ex, PlatformTransactionManager.class);
}
}

// look up by type and default name
return bf.getBean(DEFAULT_CONNECTION_FACTORY_NAME, ConnectionFactory.class);
}
catch (BeansException ex) {
logBeansException(testContext, ex, Connection.class);
return null;
}
}

private static void logBeansException(TestContext testContext, BeansException ex, Class<?> beanType) {
if (logger.isTraceEnabled()) {
logger.trace("Caught exception while retrieving %s for test context %s"
.formatted(beanType.getSimpleName(), testContext), ex);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
* JDBC support classes for the <em>Spring TestContext Framework</em>,
* including support for declarative SQL script execution via {@code @Sql}.
*/
@NullMarked
package org.springframework.test.context.transaction.reactive;

import org.jspecify.annotations.NullMarked;
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.r2dbc;

import java.util.Objects;

import io.r2dbc.spi.ConnectionFactory;
import org.jspecify.annotations.Nullable;
import reactor.core.publisher.Mono;

import org.springframework.r2dbc.core.DatabaseClient;
import org.springframework.util.StringUtils;

/**
* {@code R2dbcTestUtils} is a collection of R2DBC related utility functions
* intended to simplify standard database testing scenarios.
*
* @author jonghoon park
* @since 7.0
* @see org.springframework.r2dbc.core.DatabaseClient
*/
public abstract class R2dbcTestUtils {

/**
* Count the rows in the given table.
* @param connectionFactory the {@link ConnectionFactory} with which to perform R2DBC
* operations
* @param tableName name of the table to count rows in
* @return the number of rows in the table
*/
public static Mono<Integer> countRowsInTable(ConnectionFactory connectionFactory, String tableName) {
return countRowsInTable(DatabaseClient.create(connectionFactory), tableName);
}

/**
* Count the rows in the given table.
* @param databaseClient the {@link DatabaseClient} with which to perform R2DBC
* operations
* @param tableName name of the table to count rows in
* @return the number of rows in the table
*/
public static Mono<Integer> countRowsInTable(DatabaseClient databaseClient, String tableName) {
return countRowsInTableWhere(databaseClient, tableName, null);
}

/**
* Count the rows in the given table, using the provided {@code WHERE} clause.
* <p>If the provided {@code WHERE} clause contains text, it will be prefixed
* with {@code " WHERE "} and then appended to the generated {@code SELECT}
* statement. For example, if the provided table name is {@code "person"} and
* the provided where clause is {@code "name = 'Bob' and age > 25"}, the
* resulting SQL statement to execute will be
* {@code "SELECT COUNT(0) FROM person WHERE name = 'Bob' and age > 25"}.
* @param databaseClient the {@link DatabaseClient} with which to perform JDBC
* operations
* @param tableName the name of the table to count rows in
* @param whereClause the {@code WHERE} clause to append to the query
* @return the number of rows in the table that match the provided
* {@code WHERE} clause
*/
public static Mono<Integer> countRowsInTableWhere(
DatabaseClient databaseClient, String tableName, @Nullable String whereClause) {

String sql = "SELECT COUNT(0) FROM " + tableName;
if (StringUtils.hasText(whereClause)) {
sql += " WHERE " + whereClause;
}
return databaseClient.sql(sql)
.map(row -> Objects.requireNonNull(row.get(0, Long.class)).intValue())
.one();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/**
* Support classes for tests based on R2DBC.
*/
@NullMarked
package org.springframework.test.r2dbc;

import org.jspecify.annotations.NullMarked;
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2002-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.test.context.aot.samples.r2dbc;

import io.r2dbc.spi.ConnectionFactory;
import org.junit.jupiter.api.Test;
import reactor.test.StepVerifier;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.TestPropertySource;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.test.context.jdbc.SqlMergeMode;
import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
import org.springframework.test.context.reactive.EmptyReactiveDatabaseConfig;

import static org.assertj.core.api.Assertions.assertThat;
import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.MERGE;
import static org.springframework.test.r2dbc.R2dbcTestUtils.countRowsInTable;

/**
* @author jonghoon park
* @since 7.0
*/
@SpringJUnitConfig(EmptyReactiveDatabaseConfig.class)
@SqlMergeMode(MERGE)
@Sql("/org/springframework/test/context/r2dbc/schema.sql")
@DirtiesContext
@TestPropertySource(properties = "test.engine = jupiter")
public class R2dbcSqlScriptsSpringJupiterTests {

@Test
@Sql // default script --> org/springframework/test/context/aot/samples/r2dbc/R2dbcSqlScriptsSpringJupiterTests.test.sql
void test(@Autowired ConnectionFactory connectionFactory) {
StepVerifier.create(countRowsInTable(connectionFactory, "users"))
.assertNext(count -> assertThat(count).isEqualTo(1))
.verifyComplete();
}

}
Loading