Skip to content

Commit c6a0c8b

Browse files
schaudermp911de
authored andcommitted
Fix Composite ids for R2DBC.
R2DBC now has minimal support for embedded entities. We can read and write them. And we can use them as ids. Closes #2012 Original pull request: #2114
1 parent 3835b11 commit c6a0c8b

File tree

7 files changed

+417
-29
lines changed

7 files changed

+417
-29
lines changed

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
*
5858
* @author Mark Paluch
5959
* @author Oliver Drotbohm
60+
* @author Jens Schauder
6061
*/
6162
public class MappingR2dbcConverter extends MappingRelationalConverter implements R2dbcConverter {
6263

@@ -189,8 +190,17 @@ private void writeInternal(Object source, OutboundRow sink, Class<?> userClass)
189190
writeProperties(sink, entity, propertyAccessor);
190191
}
191192

193+
/**
194+
* write the values of the properties of an {@link RelationalPersistentEntity} to an {@link OutboundRow}.
195+
*
196+
* @param sink must not be {@literal null}.
197+
* @param entity must not be {@literal null}.
198+
* @param accessor used for accessing the property values of {@literal entity}. May be {@literal null}. A
199+
* {@literal null} value is used when this is an embedded {@literal null} entity, resulting in all its
200+
* property values to be {@literal null} as well.
201+
*/
192202
private void writeProperties(OutboundRow sink, RelationalPersistentEntity<?> entity,
193-
PersistentPropertyAccessor<?> accessor) {
203+
@Nullable PersistentPropertyAccessor<?> accessor) {
194204

195205
for (RelationalPersistentProperty property : entity) {
196206

@@ -200,11 +210,27 @@ private void writeProperties(OutboundRow sink, RelationalPersistentEntity<?> ent
200210

201211
Object value;
202212

203-
if (property.isIdProperty()) {
204-
IdentifierAccessor identifierAccessor = entity.getIdentifierAccessor(accessor.getBean());
205-
value = identifierAccessor.getIdentifier();
213+
if (accessor == null) {
214+
value = null;
206215
} else {
207-
value = accessor.getProperty(property);
216+
if (property.isIdProperty()) {
217+
IdentifierAccessor identifierAccessor = entity.getIdentifierAccessor(accessor.getBean());
218+
value = identifierAccessor.getIdentifier();
219+
} else {
220+
value = accessor.getProperty(property);
221+
}
222+
}
223+
224+
if (property.isEmbedded()) {
225+
226+
RelationalPersistentEntity<?> embeddedEntity = getMappingContext().getRequiredPersistentEntity(property);
227+
PersistentPropertyAccessor<Object> embeddedAccessor = null;
228+
if (value != null) {
229+
embeddedAccessor = embeddedEntity.getPropertyAccessor(value);
230+
}
231+
writeProperties(sink, embeddedEntity, embeddedAccessor);
232+
233+
continue;
208234
}
209235

210236
if (value == null) {

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategy.java

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.springframework.data.relational.core.dialect.ArrayColumns;
4444
import org.springframework.data.relational.core.dialect.Dialect;
4545
import org.springframework.data.relational.core.dialect.RenderContextFactory;
46+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
4647
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
4748
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
4849
import org.springframework.data.relational.core.sql.SqlIdentifier;
@@ -66,7 +67,7 @@ public class DefaultReactiveDataAccessStrategy implements ReactiveDataAccessStra
6667
private final R2dbcDialect dialect;
6768
private final R2dbcConverter converter;
6869
private final UpdateMapper updateMapper;
69-
private final MappingContext<RelationalPersistentEntity<?>, ? extends RelationalPersistentProperty> mappingContext;
70+
private final RelationalMappingContext mappingContext;
7071
private final StatementMapper statementMapper;
7172
private final NamedParameterExpander expander = new NamedParameterExpander();
7273

@@ -119,16 +120,14 @@ public static R2dbcConverter createConverter(R2dbcDialect dialect, Collection<?>
119120
* @param dialect the {@link R2dbcDialect} to use.
120121
* @param converter must not be {@literal null}.
121122
*/
122-
@SuppressWarnings("unchecked")
123123
public DefaultReactiveDataAccessStrategy(R2dbcDialect dialect, R2dbcConverter converter) {
124124

125125
Assert.notNull(dialect, "Dialect must not be null");
126126
Assert.notNull(converter, "RelationalConverter must not be null");
127127

128128
this.converter = converter;
129129
this.updateMapper = new UpdateMapper(dialect, converter);
130-
this.mappingContext = (MappingContext<RelationalPersistentEntity<?>, ? extends RelationalPersistentProperty>) this.converter
131-
.getMappingContext();
130+
this.mappingContext = (RelationalMappingContext) this.converter.getMappingContext();
132131
this.dialect = dialect;
133132

134133
RenderContextFactory factory = new RenderContextFactory(dialect);
@@ -141,13 +140,22 @@ public List<SqlIdentifier> getAllColumns(Class<?> entityType) {
141140

142141
RelationalPersistentEntity<?> persistentEntity = getPersistentEntity(entityType);
143142

143+
return getAllColumns(persistentEntity);
144+
}
145+
146+
private List<SqlIdentifier> getAllColumns(@Nullable RelationalPersistentEntity<?> persistentEntity) {
147+
144148
if (persistentEntity == null) {
145149
return Collections.singletonList(SqlIdentifier.unquoted("*"));
146150
}
147151

148152
List<SqlIdentifier> columnNames = new ArrayList<>();
149153
for (RelationalPersistentProperty property : persistentEntity) {
150-
columnNames.add(property.getColumnName());
154+
if (property.isEmbedded()) {
155+
columnNames.addAll(getAllColumns(mappingContext.getRequiredPersistentEntity(property)));
156+
} else {
157+
columnNames.add(property.getColumnName());
158+
}
151159
}
152160

153161
return columnNames;
@@ -159,12 +167,8 @@ public List<SqlIdentifier> getIdentifierColumns(Class<?> entityType) {
159167
RelationalPersistentEntity<?> persistentEntity = getRequiredPersistentEntity(entityType);
160168

161169
List<SqlIdentifier> columnNames = new ArrayList<>();
162-
for (RelationalPersistentProperty property : persistentEntity) {
163-
164-
if (property.isIdProperty()) {
165-
columnNames.add(property.getColumnName());
166-
}
167-
}
170+
mappingContext.getAggregatePath(persistentEntity).getTableInfo().idColumnInfos()
171+
.forEach((__, ci) -> columnNames.add(ci.name()));
168172

169173
return columnNames;
170174
}

spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import reactor.core.publisher.Mono;
2424

2525
import java.util.Collections;
26+
import java.util.HashMap;
2627
import java.util.LinkedHashSet;
2728
import java.util.List;
2829
import java.util.Map;
@@ -33,7 +34,6 @@
3334
import java.util.stream.Collectors;
3435

3536
import org.reactivestreams.Publisher;
36-
3737
import org.springframework.beans.BeansException;
3838
import org.springframework.beans.factory.BeanFactory;
3939
import org.springframework.beans.factory.BeanFactoryAware;
@@ -60,6 +60,7 @@
6060
import org.springframework.data.r2dbc.mapping.event.BeforeSaveCallback;
6161
import org.springframework.data.relational.core.conversion.AbstractRelationalConverter;
6262
import org.springframework.data.relational.core.mapping.PersistentPropertyTranslator;
63+
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
6364
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
6465
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
6566
import org.springframework.data.relational.core.query.Criteria;
@@ -96,6 +97,7 @@
9697
* @author Robert Heim
9798
* @author Sebastian Wieland
9899
* @author Mikhail Polivakha
100+
* @author Jens Schauder
99101
* @since 1.1
100102
*/
101103
public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware, ApplicationContextAware {
@@ -350,8 +352,8 @@ <T, P extends Publisher<T>> P doSelect(Query query, Class<?> entityClass, SqlIde
350352
return (P) ((Flux<?>) result).concatMap(it -> maybeCallAfterConvert(it, tableName));
351353
}
352354

353-
private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdentifier tableName,
354-
Class<T> returnType, Function<? super Statement, ? extends Statement> filterFunction) {
355+
private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdentifier tableName, Class<T> returnType,
356+
Function<? super Statement, ? extends Statement> filterFunction) {
355357

356358
StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityType);
357359

@@ -378,11 +380,8 @@ private <T> RowsFetchSpec<T> doSelect(Query query, Class<?> entityType, SqlIdent
378380

379381
PreparedOperation<?> operation = statementMapper.getMappedObject(selectSpec);
380382

381-
return getRowsFetchSpec(
382-
databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)),
383-
entityType,
384-
returnType
385-
);
383+
return getRowsFetchSpec(databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)),
384+
entityType, returnType);
386385
}
387386

388387
@Override
@@ -622,16 +621,26 @@ private <T> Mono<T> doUpdate(T entity, SqlIdentifier tableName) {
622621
return maybeCallBeforeSave(entityToUse, outboundRow, tableName) //
623622
.flatMap(onBeforeSave -> {
624623

625-
SqlIdentifier idColumn = persistentEntity.getRequiredIdProperty().getColumnName();
626-
Parameter id = outboundRow.remove(idColumn);
624+
Map<SqlIdentifier, Object> idValues = new HashMap<>();
625+
((RelationalMappingContext) mappingContext).getAggregatePath(persistentEntity).getTableInfo()
626+
.idColumnInfos().forEach((ap, ci) -> idValues.put(ci.name(), outboundRow.remove(ci.name())));
627627

628628
persistentEntity.forEach(p -> {
629629
if (p.isInsertOnly()) {
630630
outboundRow.remove(p.getColumnName());
631631
}
632632
});
633633

634-
Criteria criteria = Criteria.where(dataAccessStrategy.toSql(idColumn)).is(id);
634+
Assert.state(!idValues.isEmpty(), entityToUse + " has no id. Update is not possible");
635+
636+
Criteria criteria = null;
637+
for (Map.Entry<SqlIdentifier, Object> idAndValue : idValues.entrySet()) {
638+
if (criteria == null) {
639+
criteria = Criteria.where(dataAccessStrategy.toSql(idAndValue.getKey())).is(idAndValue.getValue());
640+
} else {
641+
criteria = criteria.and(dataAccessStrategy.toSql(idAndValue.getKey())).is(idAndValue.getValue());
642+
}
643+
}
635644

636645
if (matchingVersionCriteria != null) {
637646
criteria = criteria.and(matchingVersionCriteria);

spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.springframework.data.r2dbc.dialect.PostgresDialect;
4444
import org.springframework.data.r2dbc.mapping.OutboundRow;
4545
import org.springframework.data.r2dbc.mapping.R2dbcMappingContext;
46+
import org.springframework.data.relational.core.mapping.Embedded;
4647
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
4748
import org.springframework.data.relational.core.sql.SqlIdentifier;
4849
import org.springframework.r2dbc.core.Parameter;
@@ -261,6 +262,53 @@ void writeShouldObtainIdFromIdentifierAccessor() {
261262
assertThat(row).containsEntry(SqlIdentifier.unquoted("id"), Parameter.from(42L));
262263
}
263264

265+
@Test // GH-2096
266+
void shouldWriteSingleLevelEmbeddedEntity() {
267+
268+
Level1 entity = new Level1("root", new Level2("child", 23));
269+
270+
OutboundRow row = new OutboundRow();
271+
converter.write(entity, row);
272+
273+
assertThat(row).containsExactlyInAnyOrderEntriesOf(Map.of(
274+
SqlIdentifier.unquoted("name"), Parameter.from("root"),
275+
SqlIdentifier.unquoted("level2_name"), Parameter.from("child"),
276+
SqlIdentifier.unquoted("level2_number"), Parameter.from(23)
277+
));
278+
}
279+
280+
@Test // GH-2096
281+
void shouldWriteMultiLevelEmbeddedEntity() {
282+
283+
WithEmbedded entity = new WithEmbedded(4711L, new Level1("level1", new Level2("child", 23)));
284+
285+
OutboundRow row = new OutboundRow();
286+
converter.write(entity, row);
287+
288+
assertThat(row).containsExactlyInAnyOrderEntriesOf(Map.of(
289+
SqlIdentifier.unquoted("id"), Parameter.from(4711L),
290+
SqlIdentifier.unquoted("level1_name"), Parameter.from("level1"),
291+
SqlIdentifier.unquoted("level1_level2_name"), Parameter.from("child"),
292+
SqlIdentifier.unquoted("level1_level2_number"), Parameter.from(23)
293+
));
294+
}
295+
296+
@Test // GH-2096
297+
void shouldWriteNullEmbeddedEntity() {
298+
299+
WithEmbedded entity = new WithEmbedded(4711L, null);
300+
301+
OutboundRow row = new OutboundRow();
302+
converter.write(entity, row);
303+
304+
assertThat(row).containsExactlyInAnyOrderEntriesOf(Map.of(
305+
SqlIdentifier.unquoted("id"), Parameter.from(4711L),
306+
SqlIdentifier.unquoted("level1_name"), Parameter.empty(String.class),
307+
SqlIdentifier.unquoted("level1_level2_name"), Parameter.empty(String.class),
308+
SqlIdentifier.unquoted("level1_level2_number"), Parameter.empty(Integer.class)
309+
));
310+
}
311+
264312
static class Person {
265313
@Id String id;
266314
String firstname, lastname;
@@ -326,6 +374,13 @@ public PersonWithConversions(String id, Map<String, String> nested, NonMappableE
326374
record WithPrimitiveId(@Id long id) {
327375
}
328376

377+
record WithEmbedded(@Id long id, @Embedded.Empty(prefix = "level1_") Level1 one){}
378+
379+
record Level1(String name, @Embedded.Empty(prefix = "level2_") Level2 two) {
380+
381+
}
382+
record Level2(String name, Integer number){}
383+
329384
static class CustomConversionPerson {
330385

331386
String foo;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package org.springframework.data.r2dbc.core;
2+
3+
import java.util.Arrays;
4+
import java.util.List;
5+
6+
import org.assertj.core.api.SoftAssertions;
7+
import org.junit.jupiter.api.Test;
8+
import org.springframework.data.annotation.Id;
9+
import org.springframework.data.r2dbc.dialect.H2Dialect;
10+
import org.springframework.data.relational.core.mapping.Embedded;
11+
import org.springframework.data.relational.core.sql.SqlIdentifier;
12+
13+
/**
14+
* Unit tests for {@link DefaultReactiveDataAccessStrategy}.
15+
*
16+
* @author Jens Schauder
17+
*/
18+
class DefaultReactiveDataAccessStrategyUnitTests {
19+
20+
DefaultReactiveDataAccessStrategy dataAccessStrategy = new DefaultReactiveDataAccessStrategy(H2Dialect.INSTANCE);
21+
22+
@Test
23+
void getAllColumns() {
24+
25+
SoftAssertions.assertSoftly(softly -> {
26+
check(softly, SimpleEntity.class, "ID", "NAME");
27+
check(softly, WithEmbedded.class, "ID", "L1_NAME", "L1_L2_NAME", "L1_L2_NUMBER");
28+
check(softly, WithEmbeddedId.class, "ID_NAME", "ID_NUMBER", "NAME");
29+
});
30+
}
31+
32+
private void check(SoftAssertions softly, Class<?> entityType, String... columnNames) {
33+
34+
List<SqlIdentifier> sqlIdentifiers = Arrays.stream(columnNames).map(SqlIdentifier::quoted).toList();
35+
softly.assertThat(dataAccessStrategy.getAllColumns(entityType)).describedAs(entityType.getName())
36+
.containsExactlyInAnyOrder(sqlIdentifiers.toArray(new SqlIdentifier[0]));
37+
}
38+
39+
record SimpleEntity(int id, String name) {
40+
}
41+
42+
record WithEmbedded(int id, @Embedded.Empty(prefix = "L1_") Level1 level1) {
43+
}
44+
45+
record Level1(String name, @Embedded.Empty(prefix = "L2_") Level2 l2) {
46+
}
47+
48+
record Level2(String name, Integer number) {
49+
}
50+
51+
record WithEmbeddedId(@Id @Embedded.Empty(prefix = "ID_") Level2 id, String name) {
52+
}
53+
54+
}

0 commit comments

Comments
 (0)