16
16
package org .springframework .data .mongodb .repository .query ;
17
17
18
18
import java .util .ArrayList ;
19
+ import java .util .Collection ;
19
20
import java .util .Iterator ;
20
21
import java .util .List ;
21
22
import java .util .function .Supplier ;
26
27
import org .springframework .data .domain .Page ;
27
28
import org .springframework .data .domain .Pageable ;
28
29
import org .springframework .data .domain .Range ;
29
- import org .springframework .data .domain .Score ;
30
30
import org .springframework .data .domain .SearchResult ;
31
31
import org .springframework .data .domain .SearchResults ;
32
+ import org .springframework .data .domain .Similarity ;
32
33
import org .springframework .data .domain .Slice ;
33
34
import org .springframework .data .domain .SliceImpl ;
34
- import org .springframework .data .domain .Sort ;
35
- import org .springframework .data .domain .Vector ;
36
35
import org .springframework .data .geo .Distance ;
37
36
import org .springframework .data .geo .GeoPage ;
38
37
import org .springframework .data .geo .GeoResult ;
46
45
import org .springframework .data .mongodb .core .ExecutableRemoveOperation .TerminatingRemove ;
47
46
import org .springframework .data .mongodb .core .ExecutableUpdateOperation .ExecutableUpdate ;
48
47
import org .springframework .data .mongodb .core .MongoOperations ;
49
- import org .springframework .data .mongodb .core .aggregation .Aggregation ;
50
48
import org .springframework .data .mongodb .core .aggregation .AggregationOperation ;
51
49
import org .springframework .data .mongodb .core .aggregation .AggregationResults ;
52
50
import org .springframework .data .mongodb .core .aggregation .TypedAggregation ;
53
- import org .springframework .data .mongodb .core .aggregation .VectorSearchOperation ;
54
51
import org .springframework .data .mongodb .core .query .NearQuery ;
55
52
import org .springframework .data .mongodb .core .query .Query ;
56
53
import org .springframework .data .mongodb .core .query .UpdateDefinition ;
@@ -225,7 +222,7 @@ private static boolean isListOfGeoResult(TypeInformation<?> returnType) {
225
222
}
226
223
227
224
/**
228
- * {@link MongoQueryExecution} to execute vector search
225
+ * {@link MongoQueryExecution} to execute vector search.
229
226
*
230
227
* @author Mark Paluch
231
228
* @since 5.0
@@ -235,118 +232,64 @@ class VectorSearchExecution implements MongoQueryExecution {
235
232
private final MongoOperations operations ;
236
233
private final MongoQueryMethod method ;
237
234
private final String collectionName ;
238
- private final @ Nullable Integer numCandidates ;
239
- private final VectorSearchOperation .SearchType searchType ;
240
- private final MongoParameterAccessor accessor ;
241
- private final Class <Object > outputType ;
242
- private final String path ;
235
+ private final VectorSearchDelegate .QueryMetadata queryMetadata ;
236
+ private final List <AggregationOperation > pipeline ;
243
237
244
238
public VectorSearchExecution (MongoOperations operations , MongoQueryMethod method , String collectionName ,
245
- String path , @ Nullable Integer numCandidates , VectorSearchOperation .SearchType searchType ,
246
- MongoParameterAccessor accessor , Class <Object > outputType ) {
239
+ VectorSearchDelegate .QueryMetadata queryMetadata , MongoParameterAccessor accessor ) {
247
240
248
241
this .operations = operations ;
249
242
this .collectionName = collectionName ;
250
- this .path = path ;
251
- this .numCandidates = numCandidates ;
243
+ this .queryMetadata = queryMetadata ;
252
244
this .method = method ;
253
- this .searchType = searchType ;
254
- this .accessor = accessor ;
255
- this .outputType = outputType ;
245
+ this .pipeline = queryMetadata .getAggregationPipeline (method , accessor );
256
246
}
257
247
258
248
@ Override
259
249
public Object execute (Query query ) {
260
250
261
- SearchResults <?> results = doExecuteQuery ( query );
262
- return isListOfSearchResult ( method . getReturnType ()) ? results . getContent () : results ;
263
- }
251
+ AggregationResults <?> aggregated = operations . aggregate (
252
+ TypedAggregation . newAggregation ( queryMetadata . outputType (), pipeline ), collectionName ,
253
+ queryMetadata . outputType ());
264
254
265
- @ SuppressWarnings ("unchecked" )
266
- SearchResults <Object > doExecuteQuery (Query query ) {
255
+ List <?> mappedResults = aggregated .getMappedResults ();
267
256
268
- Vector vector = accessor .getVector ();
269
- Score score = accessor .getScore ();
270
- Range <Score > distance = accessor .getScoreRange ();
271
- int limit ;
257
+ if (isSearchResult (method .getReturnType ())) {
272
258
273
- if (query .isLimited ()) {
274
- limit = query .getLimit ();
275
- } else {
276
- limit = Math .max (1 , numCandidates != null ? numCandidates / 20 : 1 );
277
- }
259
+ List <org .bson .Document > rawResults = aggregated .getRawResults ().getList ("results" , org .bson .Document .class );
260
+ List <SearchResult <Object >> result = new ArrayList <>(mappedResults .size ());
278
261
279
- List <AggregationOperation > stages = new ArrayList <>();
280
- VectorSearchOperation $vectorSearch = Aggregation .vectorSearch (method .getAnnotatedHint ()).path (path )
281
- .vector (vector ).limit (limit );
262
+ for (int i = 0 ; i < mappedResults .size (); i ++) {
263
+ Document document = rawResults .get (i );
264
+ SearchResult <Object > searchResult = new SearchResult <>(mappedResults .get (i ),
265
+ Similarity .raw (document .getDouble ("__score__" ), queryMetadata .scoringFunction ()));
282
266
283
- if (numCandidates != null ) {
284
- $vectorSearch = $vectorSearch .numCandidates (numCandidates );
285
- }
267
+ result .add (searchResult );
268
+ }
286
269
287
- $vectorSearch = $vectorSearch .filter (query .getQueryObject ());
288
- $vectorSearch = $vectorSearch .searchType (searchType );
289
- $vectorSearch = $vectorSearch .withSearchScore ("__score__" );
290
-
291
- if (score != null ) {
292
- $vectorSearch = $vectorSearch .withFilterBySore (c -> {
293
- c .gt (score .getValue ());
294
- });
295
- } else if (distance .getLowerBound ().isBounded () || distance .getUpperBound ().isBounded ()) {
296
- $vectorSearch = $vectorSearch .withFilterBySore (c -> {
297
- Range .Bound <Score > lower = distance .getLowerBound ();
298
- if (lower .isBounded ()) {
299
- double value = lower .getValue ().get ().getValue ();
300
- if (lower .isInclusive ()) {
301
- c .gte (value );
302
- } else {
303
- c .gt (value );
304
- }
305
- }
306
-
307
- Range .Bound <Score > upper = distance .getUpperBound ();
308
- if (upper .isBounded ()) {
309
-
310
- double value = upper .getValue ().get ().getValue ();
311
- if (upper .isInclusive ()) {
312
- c .lte (value );
313
- } else {
314
- c .lt (value );
315
- }
316
- }
317
- });
270
+ return isListOfSearchResult (method .getReturnType ()) ? result : new SearchResults <>(result );
318
271
}
319
272
320
- stages .add ($vectorSearch );
321
-
322
- if (query .isSorted ()) {
323
- // TODO stages.add(Aggregation.sort(query.with()));
324
- } else {
325
- stages .add (Aggregation .sort (Sort .Direction .DESC , "__score__" ));
326
- }
327
-
328
- AggregationResults <Object > aggregated = operations
329
- .aggregate (TypedAggregation .<Object > newAggregation (outputType , stages ), collectionName , outputType );
330
-
331
- List <Object > mappedResults = aggregated .getMappedResults ();
332
- List <org .bson .Document > rawResults = aggregated .getRawResults ().getList ("results" , org .bson .Document .class );
333
-
334
- List <SearchResult <Object >> result = new ArrayList <>(mappedResults .size ());
273
+ return mappedResults ;
274
+ }
335
275
336
- for (int i = 0 ; i < mappedResults .size (); i ++) {
337
- Document document = rawResults .get (i );
338
- SearchResult <Object > searchResult = new SearchResult <>(mappedResults .get (i ),
339
- Score .of (document .getDouble ("__score__" )));
276
+ private static boolean isListOfSearchResult (TypeInformation <?> returnType ) {
340
277
341
- result .add (searchResult );
278
+ if (!Collection .class .isAssignableFrom (returnType .getType ())) {
279
+ return false ;
342
280
}
343
281
344
- return new SearchResults <>(result );
282
+ TypeInformation <?> componentType = returnType .getComponentType ();
283
+ return componentType != null && SearchResult .class .equals (componentType .getType ());
345
284
}
346
285
347
- private static boolean isListOfSearchResult (TypeInformation <?> returnType ) {
286
+ private static boolean isSearchResult (TypeInformation <?> returnType ) {
348
287
349
- if (!returnType .getType ().equals (List .class )) {
288
+ if (SearchResults .class .isAssignableFrom (returnType .getType ())) {
289
+ return true ;
290
+ }
291
+
292
+ if (!Iterable .class .isAssignableFrom (returnType .getType ())) {
350
293
return false ;
351
294
}
352
295
0 commit comments