Skip to content

Commit 642396c

Browse files
committed
Add reactive search support.
1 parent c80f735 commit 642396c

File tree

10 files changed

+888
-326
lines changed

10 files changed

+888
-326
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,20 @@
1515
*/
1616
package org.springframework.data.mongodb.repository.aot;
1717

18-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder;
19-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder;
20-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder;
21-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder;
22-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder;
23-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder;
24-
import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder;
25-
2618
import java.lang.reflect.Method;
2719
import java.util.regex.Pattern;
2820

2921
import org.apache.commons.logging.Log;
3022
import org.apache.commons.logging.LogFactory;
3123
import org.jspecify.annotations.Nullable;
24+
3225
import org.springframework.core.annotation.AnnotatedElementUtils;
3326
import org.springframework.data.mongodb.core.MongoOperations;
3427
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
3528
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
3629
import org.springframework.data.mongodb.repository.Query;
3730
import org.springframework.data.mongodb.repository.Update;
38-
import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder;
31+
import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*;
3932
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
4033
import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder;
4134
import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata;
@@ -178,10 +171,11 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor
178171

179172
private static boolean backoff(MongoQueryMethod method) {
180173

181-
boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery();
174+
boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery()
175+
|| method.isSearchQuery();
182176

183177
if (skip && logger.isDebugEnabled()) {
184-
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming or scrolling query"
178+
logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming, search or scrolling query"
185179
.formatted(method.getName()));
186180
}
187181
return skip;
@@ -193,7 +187,6 @@ private static MethodContributor<MongoQueryMethod> aggregationMethodContributor(
193187
return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> {
194188

195189
CodeBlock.Builder builder = CodeBlock.builder();
196-
builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName())));
197190

198191
builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation)
199192
.usingAggregationVariableName("aggregation").build());

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java

Lines changed: 35 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package org.springframework.data.mongodb.repository.query;
1717

1818
import java.util.ArrayList;
19+
import java.util.Collection;
1920
import java.util.Iterator;
2021
import java.util.List;
2122
import java.util.function.Supplier;
@@ -26,13 +27,11 @@
2627
import org.springframework.data.domain.Page;
2728
import org.springframework.data.domain.Pageable;
2829
import org.springframework.data.domain.Range;
29-
import org.springframework.data.domain.Score;
3030
import org.springframework.data.domain.SearchResult;
3131
import org.springframework.data.domain.SearchResults;
32+
import org.springframework.data.domain.Similarity;
3233
import org.springframework.data.domain.Slice;
3334
import org.springframework.data.domain.SliceImpl;
34-
import org.springframework.data.domain.Sort;
35-
import org.springframework.data.domain.Vector;
3635
import org.springframework.data.geo.Distance;
3736
import org.springframework.data.geo.GeoPage;
3837
import org.springframework.data.geo.GeoResult;
@@ -46,11 +45,9 @@
4645
import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove;
4746
import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate;
4847
import org.springframework.data.mongodb.core.MongoOperations;
49-
import org.springframework.data.mongodb.core.aggregation.Aggregation;
5048
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
5149
import org.springframework.data.mongodb.core.aggregation.AggregationResults;
5250
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
53-
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
5451
import org.springframework.data.mongodb.core.query.NearQuery;
5552
import org.springframework.data.mongodb.core.query.Query;
5653
import org.springframework.data.mongodb.core.query.UpdateDefinition;
@@ -225,7 +222,7 @@ private static boolean isListOfGeoResult(TypeInformation<?> returnType) {
225222
}
226223

227224
/**
228-
* {@link MongoQueryExecution} to execute vector search
225+
* {@link MongoQueryExecution} to execute vector search.
229226
*
230227
* @author Mark Paluch
231228
* @since 5.0
@@ -235,118 +232,64 @@ class VectorSearchExecution implements MongoQueryExecution {
235232
private final MongoOperations operations;
236233
private final MongoQueryMethod method;
237234
private final String collectionName;
238-
private final @Nullable Integer numCandidates;
239-
private final VectorSearchOperation.SearchType searchType;
240-
private final MongoParameterAccessor accessor;
241-
private final Class<Object> outputType;
242-
private final String path;
235+
private final VectorSearchDelegate.QueryMetadata queryMetadata;
236+
private final List<AggregationOperation> pipeline;
243237

244238
public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName,
245-
String path, @Nullable Integer numCandidates, VectorSearchOperation.SearchType searchType,
246-
MongoParameterAccessor accessor, Class<Object> outputType) {
239+
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
247240

248241
this.operations = operations;
249242
this.collectionName = collectionName;
250-
this.path = path;
251-
this.numCandidates = numCandidates;
243+
this.queryMetadata = queryMetadata;
252244
this.method = method;
253-
this.searchType = searchType;
254-
this.accessor = accessor;
255-
this.outputType = outputType;
245+
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
256246
}
257247

258248
@Override
259249
public Object execute(Query query) {
260250

261-
SearchResults<?> results = doExecuteQuery(query);
262-
return isListOfSearchResult(method.getReturnType()) ? results.getContent() : results;
263-
}
251+
AggregationResults<?> aggregated = operations.aggregate(
252+
TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName,
253+
queryMetadata.outputType());
264254

265-
@SuppressWarnings("unchecked")
266-
SearchResults<Object> doExecuteQuery(Query query) {
255+
List<?> mappedResults = aggregated.getMappedResults();
267256

268-
Vector vector = accessor.getVector();
269-
Score score = accessor.getScore();
270-
Range<Score> distance = accessor.getScoreRange();
271-
int limit;
257+
if (isSearchResult(method.getReturnType())) {
272258

273-
if (query.isLimited()) {
274-
limit = query.getLimit();
275-
} else {
276-
limit = Math.max(1, numCandidates != null ? numCandidates / 20 : 1);
277-
}
259+
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
260+
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
278261

279-
List<AggregationOperation> stages = new ArrayList<>();
280-
VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(method.getAnnotatedHint()).path(path)
281-
.vector(vector).limit(limit);
262+
for (int i = 0; i < mappedResults.size(); i++) {
263+
Document document = rawResults.get(i);
264+
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
265+
Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction()));
282266

283-
if (numCandidates != null) {
284-
$vectorSearch = $vectorSearch.numCandidates(numCandidates);
285-
}
267+
result.add(searchResult);
268+
}
286269

287-
$vectorSearch = $vectorSearch.filter(query.getQueryObject());
288-
$vectorSearch = $vectorSearch.searchType(searchType);
289-
$vectorSearch = $vectorSearch.withSearchScore("__score__");
290-
291-
if (score != null) {
292-
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
293-
c.gt(score.getValue());
294-
});
295-
} else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) {
296-
$vectorSearch = $vectorSearch.withFilterBySore(c -> {
297-
Range.Bound<Score> lower = distance.getLowerBound();
298-
if (lower.isBounded()) {
299-
double value = lower.getValue().get().getValue();
300-
if (lower.isInclusive()) {
301-
c.gte(value);
302-
} else {
303-
c.gt(value);
304-
}
305-
}
306-
307-
Range.Bound<Score> upper = distance.getUpperBound();
308-
if (upper.isBounded()) {
309-
310-
double value = upper.getValue().get().getValue();
311-
if (upper.isInclusive()) {
312-
c.lte(value);
313-
} else {
314-
c.lt(value);
315-
}
316-
}
317-
});
270+
return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result);
318271
}
319272

320-
stages.add($vectorSearch);
321-
322-
if (query.isSorted()) {
323-
// TODO stages.add(Aggregation.sort(query.with()));
324-
} else {
325-
stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__"));
326-
}
327-
328-
AggregationResults<Object> aggregated = operations
329-
.aggregate(TypedAggregation.<Object> newAggregation(outputType, stages), collectionName, outputType);
330-
331-
List<Object> mappedResults = aggregated.getMappedResults();
332-
List<org.bson.Document> rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class);
333-
334-
List<SearchResult<Object>> result = new ArrayList<>(mappedResults.size());
273+
return mappedResults;
274+
}
335275

336-
for (int i = 0; i < mappedResults.size(); i++) {
337-
Document document = rawResults.get(i);
338-
SearchResult<Object> searchResult = new SearchResult<>(mappedResults.get(i),
339-
Score.of(document.getDouble("__score__")));
276+
private static boolean isListOfSearchResult(TypeInformation<?> returnType) {
340277

341-
result.add(searchResult);
278+
if (!Collection.class.isAssignableFrom(returnType.getType())) {
279+
return false;
342280
}
343281

344-
return new SearchResults<>(result);
282+
TypeInformation<?> componentType = returnType.getComponentType();
283+
return componentType != null && SearchResult.class.equals(componentType.getType());
345284
}
346285

347-
private static boolean isListOfSearchResult(TypeInformation<?> returnType) {
286+
private static boolean isSearchResult(TypeInformation<?> returnType) {
348287

349-
if (!returnType.getType().equals(List.class)) {
288+
if (SearchResults.class.isAssignableFrom(returnType.getType())) {
289+
return true;
290+
}
291+
292+
if (!Iterable.class.isAssignableFrom(returnType.getType())) {
350293
return false;
351294
}
352295

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,26 @@
1818
import reactor.core.publisher.Flux;
1919
import reactor.core.publisher.Mono;
2020

21+
import java.util.List;
22+
23+
import org.bson.Document;
2124
import org.jspecify.annotations.Nullable;
2225
import org.reactivestreams.Publisher;
2326

2427
import org.springframework.core.convert.converter.Converter;
2528
import org.springframework.data.convert.DtoInstantiatingConverter;
2629
import org.springframework.data.domain.Pageable;
2730
import org.springframework.data.domain.Range;
31+
import org.springframework.data.domain.SearchResult;
32+
import org.springframework.data.domain.Similarity;
2833
import org.springframework.data.geo.Distance;
2934
import org.springframework.data.geo.GeoResult;
3035
import org.springframework.data.geo.Point;
3136
import org.springframework.data.mapping.model.EntityInstantiators;
3237
import org.springframework.data.mongodb.core.ReactiveMongoOperations;
3338
import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate;
39+
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
40+
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
3441
import org.springframework.data.mongodb.core.query.NearQuery;
3542
import org.springframework.data.mongodb.core.query.Query;
3643
import org.springframework.data.mongodb.core.query.UpdateDefinition;
@@ -118,6 +125,57 @@ private boolean isStreamOfGeoResult() {
118125
}
119126
}
120127

128+
/**
129+
* {@link ReactiveMongoQueryExecution} to execute vector search.
130+
*
131+
* @author Mark Paluch
132+
* @since 5.0
133+
*/
134+
class VectorSearchExecution implements ReactiveMongoQueryExecution {
135+
136+
private final ReactiveMongoOperations operations;
137+
private final VectorSearchDelegate.QueryMetadata queryMetadata;
138+
private final List<AggregationOperation> pipeline;
139+
private final boolean returnSearchResult;
140+
141+
public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method,
142+
VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) {
143+
144+
this.operations = operations;
145+
this.queryMetadata = queryMetadata;
146+
this.pipeline = queryMetadata.getAggregationPipeline(method, accessor);
147+
this.returnSearchResult = isSearchResult(method.getReturnType());
148+
}
149+
150+
@Override
151+
public Publisher<? extends Object> execute(Query query, Class<?> type, String collection) {
152+
153+
Flux<Document> aggregate = operations
154+
.aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class);
155+
156+
return aggregate.map(document -> {
157+
158+
Object mappedResult = operations.getConverter().read(queryMetadata.outputType(), document);
159+
160+
return returnSearchResult
161+
? new SearchResult<>(mappedResult,
162+
Similarity.raw(document.getDouble(queryMetadata.scoreField()), queryMetadata.scoringFunction()))
163+
: mappedResult;
164+
});
165+
}
166+
167+
private static boolean isSearchResult(TypeInformation<?> returnType) {
168+
169+
if (!Publisher.class.isAssignableFrom(returnType.getType())) {
170+
return false;
171+
}
172+
173+
TypeInformation<?> componentType = returnType.getComponentType();
174+
return componentType != null && SearchResult.class.equals(componentType.getType());
175+
}
176+
177+
}
178+
121179
/**
122180
* {@link ReactiveMongoQueryExecution} removing documents matching the query.
123181
*

0 commit comments

Comments
 (0)