@@ -56,13 +56,33 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
56
56
{" NUMERIC" sv, search::SchemaField::NUMERIC},
57
57
{" VECTOR" sv, search::SchemaField::VECTOR}};
58
58
59
// Heuristically picks how many of the leading results a single shard should
// treat as "requested", given that the final reply is assembled from `shards`
// shards.
//
//   shards         - number of shards taking part in the query.
//   hits           - number of documents matched on this shard.
//   requested      - total number of documents the client asked for.
//   is_aggregation - when true, the reduced per-shard bound is never used.
//
// Returns `requested / shards + 1` when the estimate below says the shards
// together are expected to cover the request, and `min(hits, requested)`
// otherwise.
// NOTE(review): the constants 12, 10 and 5 look empirically tuned - confirm
// with the author before changing them.
size_t GetProbabilisticBound(size_t shards, size_t hits, size_t requested, bool is_aggregation) {
  // Integer floor(log2(x)); evaluates to 0 for both x == 0 and x == 1.
  auto intlog2 = [](size_t x) {
    size_t l = 0;
    while (x >>= 1)
      ++l;
    return l;
  };

  // With zero shards the per-shard split below would divide by zero (this was
  // reachable for requested == 0). Fall back to the plain bound, which is what
  // every other shards == 0 input already produced.
  if (shards == 0)
    return std::min(hits, requested);

  // Estimated minimum number of hits per shard, reduced by a safety margin of
  // up to 5 documents (never going below zero).
  size_t avg_shard_min = hits * intlog2(hits) / (12 + shards / 10);
  avg_shard_min -= std::min(avg_shard_min, std::min(hits, size_t(5)));

  if (!is_aggregation && avg_shard_min * shards >= requested)
    return requested / shards + 1;

  return std::min(hits, requested);
}
78
+
59
79
} // namespace
60
80
61
- bool SerializedSearchDoc ::operator <(const SerializedSearchDoc & other) const {
81
+ bool DocResult ::operator <(const DocResult & other) const {
62
82
return this ->score < other.score ;
63
83
}
64
84
65
- bool SerializedSearchDoc ::operator >=(const SerializedSearchDoc & other) const {
85
+ bool DocResult ::operator >=(const DocResult & other) const {
66
86
return this ->score >= other.score ;
67
87
}
68
88
@@ -162,10 +182,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
162
182
}
163
183
164
184
// Takes ownership of a reference to the shared index definition. The write
// epoch starts at zero and the field and key indices start out empty.
ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
    : base_{std::move(index)},
      write_epoch_{0},
      indices_{{}, nullptr},
      key_index_{} {
}
167
187
168
188
void ShardDocIndex::Rebuild (const OpArgs& op_args, PMR_NS::memory_resource* mr) {
189
+ write_epoch_++;
169
190
key_index_ = DocKeyIndex{};
170
191
indices_ = search::FieldIndices{base_->schema , mr};
171
192
@@ -174,11 +195,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
174
195
}
175
196
176
197
// Indexes the document stored under `key`. Bumps the write epoch first so that
// results computed against the previous index state can be recognised as stale.
void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
  ++write_epoch_;

  auto doc_accessor = GetAccessor(db_cntx, pv);
  auto doc_id = key_index_.Add(key);
  indices_.Add(doc_id, doc_accessor.get());
}
180
202
181
203
void ShardDocIndex::RemoveDoc (string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
204
+ write_epoch_++;
182
205
auto accessor = GetAccessor (db_cntx, pv);
183
206
DocId id = key_index_.Remove (key);
184
207
indices_.Remove (id, accessor.get ());
@@ -188,38 +211,77 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
188
211
return base_->Matches (key, obj_code);
189
212
}
190
213
191
- SearchResult ShardDocIndex::Search (const OpArgs& op_args, const SearchParams& params,
192
- search::SearchAlgorithm* search_algo) const {
214
+ io::Result<SearchResult, facade::ErrorReply> ShardDocIndex::Search (
215
+ const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const {
216
+ auto search_results = search_algo->Search (&indices_);
217
+ if (!search_results.error .empty ())
218
+ return nonstd::make_unexpected (facade::ErrorReply (std::move (search_results.error )));
219
+
220
+ size_t requested_count = params.limit_offset + params.limit_total ;
221
+ size_t serialize_count = min (requested_count, search_results.ids .size ());
222
+
223
+ size_t cuttoff_bound = serialize_count;
224
+ if (params.enable_cutoff && !params.IdsOnly ())
225
+ cuttoff_bound =
226
+ GetProbabilisticBound (params.num_shards , search_results.ids .size (), requested_count,
227
+ search_algo->HasAggregation ().has_value ());
228
+
229
+ VLOG (0 ) << " Requested " << requested_count << " got " << search_results.ids .size () << " cutoff "
230
+ << cuttoff_bound;
231
+
232
+ vector<DocResult> out (serialize_count);
233
+ auto shard_id = EngineShard::tlocal ()->shard_id ();
234
+ for (size_t i = 0 ; i < out.size (); i++) {
235
+ out[i].value = DocResult::DocReference{shard_id, search_results.ids [i], i < cuttoff_bound};
236
+ out[i].score =
237
+ search_results.scores .empty () ? search::ResultScore{} : std::move (search_results.scores [i]);
238
+ }
239
+
240
+ Serialize (op_args, params, absl::MakeSpan (out));
241
+
242
+ return SearchResult{write_epoch_, search_results.ids .size (), std::move (out),
243
+ std::move (search_results.profile )};
244
+ }
245
+
246
+ bool ShardDocIndex::Refill (const OpArgs& op_args, const SearchParams& params,
247
+ search::SearchAlgorithm* search_algo, SearchResult* result) const {
248
+ if (result->write_epoch == write_epoch_) {
249
+ Serialize (op_args, params, absl::MakeSpan (result->docs ));
250
+ return true ;
251
+ }
252
+
253
+ DCHECK (!params.enable_cutoff );
254
+ auto new_result = Search (op_args, params, search_algo);
255
+ CHECK (new_result.has_value ());
256
+ *result = std::move (new_result.value ());
257
+ return false ;
258
+ }
259
+
260
+ void ShardDocIndex::Serialize (const OpArgs& op_args, const SearchParams& params,
261
+ absl::Span<DocResult> docs) const {
193
262
auto & db_slice = op_args.shard ->db_slice ();
194
- auto search_results = search_algo->Search (&indices_, params.limit_offset + params.limit_total );
195
263
196
- if (!search_results.error .empty ())
197
- return SearchResult{facade::ErrorReply{std::move (search_results.error )}};
264
+ for (auto & doc : docs) {
265
+ if (!holds_alternative<DocResult::DocReference>(doc.value ))
266
+ continue ;
198
267
199
- vector<SerializedSearchDoc> out;
200
- out.reserve (search_results.ids .size ());
268
+ auto ref = get<DocResult::DocReference>(doc.value );
269
+ if (!ref.requested )
270
+ return ;
201
271
202
- size_t expired_count = 0 ;
203
- for (size_t i = 0 ; i < search_results.ids .size (); i++) {
204
- auto key = key_index_.Get (search_results.ids [i]);
205
- auto it = db_slice.Find (op_args.db_cntx , key, base_->GetObjCode ());
272
+ string key{key_index_.Get (ref.doc_id )};
206
273
274
+ auto it = db_slice.Find (op_args.db_cntx , key, base_->GetObjCode ());
207
275
if (!it || !IsValid (*it)) { // Item must have expired
208
- expired_count++ ;
276
+ doc. value = DocResult::SerializedValue{ std::move (key), {}} ;
209
277
continue ;
210
278
}
211
279
212
280
auto accessor = GetAccessor (op_args.db_cntx , (*it)->second );
213
281
auto doc_data = params.return_fields ? accessor->Serialize (base_->schema , *params.return_fields )
214
282
: accessor->Serialize (base_->schema );
215
-
216
- auto score =
217
- search_results.scores .empty () ? std::monostate{} : std::move (search_results.scores [i]);
218
- out.push_back (SerializedSearchDoc{string{key}, std::move (doc_data), std::move (score)});
283
+ doc.value = DocResult::SerializedValue{std::move (key), std::move (doc_data)};
219
284
}
220
-
221
- return SearchResult{search_results.total - expired_count, std::move (out),
222
- std::move (search_results.profile )};
223
285
}
224
286
225
287
DocIndexInfo ShardDocIndex::GetInfo () const {
0 commit comments