Skip to content

Commit 30ea6b8

Browse files
committed
feat(search): Multishard cutoffs
Signed-off-by: Vladislav Oleshko <[email protected]>
1 parent 1d02e12 commit 30ea6b8

13 files changed

+571
-173
lines changed

src/core/search/ast_expr.h

+5
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,11 @@ struct AstNode : public NodeVariants {
108108
const NodeVariants& Variant() const& {
109109
return *this;
110110
}
111+
112+
// Aggregations reduce and re-order result sets.
113+
bool IsAggregation() const {
114+
return std::holds_alternative<AstKnnNode>(Variant());
115+
}
111116
};
112117

113118
using AstExpr = AstNode;

src/core/search/search.cc

+4
Original file line numberDiff line numberDiff line change
@@ -599,4 +599,8 @@ void SearchAlgorithm::EnableProfiling() {
599599
profiling_enabled_ = true;
600600
}
601601

602+
bool SearchAlgorithm::IsProfilingEnabled() const {
603+
return profiling_enabled_;
604+
}
605+
602606
} // namespace dfly::search

src/core/search/search.h

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class SearchAlgorithm {
133133
std::optional<AggregationInfo> HasAggregation() const;
134134

135135
void EnableProfiling();
136+
bool IsProfilingEnabled() const;
136137

137138
private:
138139
bool profiling_enabled_ = false;

src/facade/reply_capture.cc

+11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "facade/reply_capture.h"
55

66
#include "base/logging.h"
7+
#include "facade/conn_context.h"
78
#include "reply_capture.h"
89

910
#define SKIP_LESS(needed) \
@@ -150,6 +151,16 @@ void CapturingReplyBuilder::CollapseFilledCollections() {
150151
}
151152
}
152153

154+
CapturingReplyBuilder::ScopeCapture::ScopeCapture(CapturingReplyBuilder* crb,
155+
ConnectionContext* cntx)
156+
: cntx_{cntx} {
157+
old_ = cntx->Inject(crb);
158+
}
159+
160+
CapturingReplyBuilder::ScopeCapture::~ScopeCapture() {
161+
cntx_->Inject(old_);
162+
}
163+
153164
CapturingReplyBuilder::CollectionPayload::CollectionPayload(unsigned len, CollectionType type)
154165
: len{len}, type{type}, arr{} {
155166
arr.reserve(type == MAP ? len * 2 : len);

src/facade/reply_capture.h

+11
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
namespace facade {
1616

17+
class ConnectionContext;
1718
struct CaptureVisitor;
1819

1920
// CapturingReplyBuilder allows capturing replies and retrieving them with Take().
@@ -66,6 +67,16 @@ class CapturingReplyBuilder : public RedisReplyBuilder {
6667
bool with_scores;
6768
};
6869

70+
public:
71+
struct ScopeCapture {
72+
ScopeCapture(CapturingReplyBuilder* crb, ConnectionContext* cntx);
73+
~ScopeCapture();
74+
75+
private:
76+
SinkReplyBuilder* old_;
77+
ConnectionContext* cntx_;
78+
};
79+
6980
public:
7081
CapturingReplyBuilder(ReplyMode mode = ReplyMode::FULL)
7182
: RedisReplyBuilder{nullptr}, reply_mode_{mode}, stack_{}, current_{} {

src/server/common.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -214,11 +214,13 @@ template <typename RandGen> std::string GetRandomHex(RandGen& gen, size_t len) {
214214
// truthy value;
215215
template <typename T> struct AggregateValue {
216216
bool operator=(T val) {
217+
if (!bool(val))
218+
return false;
219+
217220
std::lock_guard l{mu_};
218-
if (!bool(current_) && bool(val)) {
221+
if (!bool(current_))
219222
current_ = val;
220-
}
221-
return bool(val);
223+
return true;
222224
}
223225

224226
T operator*() {

src/server/main_service.cc

+3-9
Original file line numberDiff line numberDiff line change
@@ -1363,13 +1363,6 @@ void Service::Unwatch(CmdArgList args, ConnectionContext* cntx) {
13631363
return (*cntx)->SendOk();
13641364
}
13651365

1366-
template <typename F> void WithReplies(CapturingReplyBuilder* crb, ConnectionContext* cntx, F&& f) {
1367-
SinkReplyBuilder* old_rrb = nullptr;
1368-
old_rrb = cntx->Inject(crb);
1369-
f();
1370-
cntx->Inject(old_rrb);
1371-
}
1372-
13731366
optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionContext* cntx,
13741367
bool force) {
13751368
auto& info = cntx->conn_state.script_info;
@@ -1385,9 +1378,10 @@ optional<CapturingReplyBuilder::Payload> Service::FlushEvalAsyncCmds(ConnectionC
13851378
cntx->transaction->MultiSwitchCmd(eval_cid);
13861379

13871380
CapturingReplyBuilder crb{ReplyMode::ONLY_ERR};
1388-
WithReplies(&crb, cntx, [&] {
1381+
{
1382+
CapturingReplyBuilder::ScopeCapture capture{&crb, cntx};
13891383
MultiCommandSquasher::Execute(absl::MakeSpan(info->async_cmds), cntx, this, true, true);
1390-
});
1384+
}
13911385

13921386
info->async_cmds_heap_mem = 0;
13931387
info->async_cmds.clear();

src/server/search/doc_index.cc

+84-22
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,33 @@ const absl::flat_hash_map<string_view, search::SchemaField::FieldType> kSchemaTy
5656
{"NUMERIC"sv, search::SchemaField::NUMERIC},
5757
{"VECTOR"sv, search::SchemaField::VECTOR}};
5858

59+
size_t GetProbabilisticBound(size_t shards, size_t hits, size_t requested, bool is_aggregation) {
60+
auto intlog2 = [](size_t x) {
61+
size_t l = 0;
62+
while (x >>= 1)
63+
++l;
64+
return l;
65+
};
66+
size_t avg_shard_min = hits * intlog2(hits) / (12 + shards / 10);
67+
avg_shard_min -= min(avg_shard_min, min(hits, size_t(5)));
68+
69+
// VLOG(0) << "PROB BOUND " << hits << " " << shards << " " << requested << " => " <<
70+
// avg_shard_min
71+
// << " diffb " << requested / shards + 1 << " & " << requested;
72+
73+
if (!is_aggregation && avg_shard_min * shards >= requested)
74+
return requested / shards + 1;
75+
76+
return min(hits, requested);
77+
}
78+
5979
} // namespace
6080

61-
bool SerializedSearchDoc::operator<(const SerializedSearchDoc& other) const {
81+
bool DocResult::operator<(const DocResult& other) const {
6282
return this->score < other.score;
6383
}
6484

65-
bool SerializedSearchDoc::operator>=(const SerializedSearchDoc& other) const {
85+
bool DocResult::operator>=(const DocResult& other) const {
6686
return this->score >= other.score;
6787
}
6888

@@ -162,10 +182,11 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
162182
}
163183

164184
ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
165-
: base_{std::move(index)}, indices_{{}, nullptr}, key_index_{} {
185+
: base_{std::move(index)}, write_epoch_{0}, indices_{{}, nullptr}, key_index_{} {
166186
}
167187

168188
void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) {
189+
write_epoch_++;
169190
key_index_ = DocKeyIndex{};
170191
indices_ = search::FieldIndices{base_->schema, mr};
171192

@@ -174,11 +195,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
174195
}
175196

176197
void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
198+
write_epoch_++;
177199
auto accessor = GetAccessor(db_cntx, pv);
178200
indices_.Add(key_index_.Add(key), accessor.get());
179201
}
180202

181203
void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
204+
write_epoch_++;
182205
auto accessor = GetAccessor(db_cntx, pv);
183206
DocId id = key_index_.Remove(key);
184207
indices_.Remove(id, accessor.get());
@@ -188,38 +211,77 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
188211
return base_->Matches(key, obj_code);
189212
}
190213

191-
SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
192-
search::SearchAlgorithm* search_algo) const {
214+
io::Result<SearchResult, facade::ErrorReply> ShardDocIndex::Search(
215+
const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const {
216+
auto search_results = search_algo->Search(&indices_);
217+
if (!search_results.error.empty())
218+
return nonstd::make_unexpected(facade::ErrorReply(std::move(search_results.error)));
219+
220+
size_t requested_count = params.limit_offset + params.limit_total;
221+
size_t serialize_count = min(requested_count, search_results.ids.size());
222+
223+
size_t cuttoff_bound = serialize_count;
224+
if (params.enable_cutoff && !params.IdsOnly())
225+
cuttoff_bound =
226+
GetProbabilisticBound(params.num_shards, search_results.ids.size(), requested_count,
227+
search_algo->HasAggregation().has_value());
228+
229+
VLOG(0) << "Requested " << requested_count << " got " << search_results.ids.size() << " cutoff "
230+
<< cuttoff_bound;
231+
232+
vector<DocResult> out(serialize_count);
233+
auto shard_id = EngineShard::tlocal()->shard_id();
234+
for (size_t i = 0; i < out.size(); i++) {
235+
out[i].value = DocResult::DocReference{shard_id, search_results.ids[i], i < cuttoff_bound};
236+
out[i].score =
237+
search_results.scores.empty() ? search::ResultScore{} : std::move(search_results.scores[i]);
238+
}
239+
240+
Serialize(op_args, params, absl::MakeSpan(out));
241+
242+
return SearchResult{write_epoch_, search_results.ids.size(), std::move(out),
243+
std::move(search_results.profile)};
244+
}
245+
246+
bool ShardDocIndex::Refill(const OpArgs& op_args, const SearchParams& params,
247+
search::SearchAlgorithm* search_algo, SearchResult* result) const {
248+
if (result->write_epoch == write_epoch_) {
249+
Serialize(op_args, params, absl::MakeSpan(result->docs));
250+
return true;
251+
}
252+
253+
DCHECK(!params.enable_cutoff);
254+
auto new_result = Search(op_args, params, search_algo);
255+
CHECK(new_result.has_value());
256+
*result = std::move(new_result.value());
257+
return false;
258+
}
259+
260+
void ShardDocIndex::Serialize(const OpArgs& op_args, const SearchParams& params,
261+
absl::Span<DocResult> docs) const {
193262
auto& db_slice = op_args.shard->db_slice();
194-
auto search_results = search_algo->Search(&indices_, params.limit_offset + params.limit_total);
195263

196-
if (!search_results.error.empty())
197-
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
264+
for (auto& doc : docs) {
265+
if (!holds_alternative<DocResult::DocReference>(doc.value))
266+
continue;
198267

199-
vector<SerializedSearchDoc> out;
200-
out.reserve(search_results.ids.size());
268+
auto ref = get<DocResult::DocReference>(doc.value);
269+
if (!ref.requested)
270+
return;
201271

202-
size_t expired_count = 0;
203-
for (size_t i = 0; i < search_results.ids.size(); i++) {
204-
auto key = key_index_.Get(search_results.ids[i]);
205-
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
272+
string key{key_index_.Get(ref.doc_id)};
206273

274+
auto it = db_slice.Find(op_args.db_cntx, key, base_->GetObjCode());
207275
if (!it || !IsValid(*it)) { // Item must have expired
208-
expired_count++;
276+
doc.value = DocResult::SerializedValue{std::move(key), {}};
209277
continue;
210278
}
211279

212280
auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
213281
auto doc_data = params.return_fields ? accessor->Serialize(base_->schema, *params.return_fields)
214282
: accessor->Serialize(base_->schema);
215-
216-
auto score =
217-
search_results.scores.empty() ? std::monostate{} : std::move(search_results.scores[i]);
218-
out.push_back(SerializedSearchDoc{string{key}, std::move(doc_data), std::move(score)});
283+
doc.value = DocResult::SerializedValue{std::move(key), std::move(doc_data)};
219284
}
220-
221-
return SearchResult{search_results.total - expired_count, std::move(out),
222-
std::move(search_results.profile)};
223285
}
224286

225287
DocIndexInfo ShardDocIndex::GetInfo() const {

src/server/search/doc_index.h

+38-18
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,39 @@ using SearchDocData = absl::flat_hash_map<std::string /*field*/, std::string /*v
2525
std::optional<search::SchemaField::FieldType> ParseSearchFieldType(std::string_view name);
2626
std::string_view SearchFieldTypeToString(search::SchemaField::FieldType);
2727

28-
struct SerializedSearchDoc {
29-
std::string key;
30-
SearchDocData values;
28+
struct DocResult {
29+
struct SerializedValue {
30+
std::string key;
31+
SearchDocData values;
32+
};
33+
34+
struct DocReference {
35+
ShardId shard_id;
36+
search::DocId doc_id;
37+
bool requested;
38+
};
39+
40+
std::variant<SerializedValue, DocReference> value;
3141
search::ResultScore score;
3242

33-
bool operator<(const SerializedSearchDoc& other) const;
34-
bool operator>=(const SerializedSearchDoc& other) const;
43+
bool operator<(const DocResult& other) const;
44+
bool operator>=(const DocResult& other) const;
3545
};
3646

3747
struct SearchResult {
38-
SearchResult() = default;
48+
size_t write_epoch = 0; // Write epoch of the index at the time the result was created
3949

40-
SearchResult(size_t total_hits, std::vector<SerializedSearchDoc> docs,
41-
std::optional<search::AlgorithmProfile> profile)
42-
: total_hits{total_hits}, docs{std::move(docs)}, profile{std::move(profile)} {
43-
}
50+
size_t total_hits = 0; // total number of hits in shard
51+
std::vector<DocResult> docs; // serialized documents of first hits
4452

45-
SearchResult(facade::ErrorReply error) : error{std::move(error)} {
46-
}
53+
// After combining results from multiple shards and accumulating more documents than initially
54+
// requested, only a subset of all documents will be sent back to the client,
55+
// so it doesn't make sense to serialize strictly all documents in every shard ahead.
56+
// Instead, only documents up to a probabilistic bound are serialized; the
57+
// leftover ids and scores are stored in the cutoff tail for use in the "unlikely" scenario.
58+
// size_t num_cutoff = 0;
4759

48-
size_t total_hits;
49-
std::vector<SerializedSearchDoc> docs;
5060
std::optional<search::AlgorithmProfile> profile;
51-
52-
std::optional<facade::ErrorReply> error;
5361
};
5462

5563
struct SearchParams {
@@ -61,6 +69,10 @@ struct SearchParams {
6169
size_t limit_offset = 0;
6270
size_t limit_total = 10;
6371

72+
// Total number of shards, used in probabilistic queries
73+
size_t num_shards = 0;
74+
bool enable_cutoff = false;
75+
6476
// Set but empty means no fields should be returned
6577
std::optional<FieldReturnList> return_fields;
6678
std::optional<search::SortOption> sort_option;
@@ -123,8 +135,12 @@ class ShardDocIndex {
123135
ShardDocIndex(std::shared_ptr<DocIndex> index);
124136

125137
// Perform search on all indexed documents and return results.
126-
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
127-
search::SearchAlgorithm* search_algo) const;
138+
io::Result<SearchResult, facade::ErrorReply> Search(const OpArgs& op_args,
139+
const SearchParams& params,
140+
search::SearchAlgorithm* search_algo) const;
141+
142+
bool Refill(const OpArgs& op_args, const SearchParams& params,
143+
search::SearchAlgorithm* search_algo, SearchResult* result) const;
128144

129145
// Return whether base index matches
130146
bool Matches(std::string_view key, unsigned obj_code) const;
@@ -138,8 +154,12 @@ class ShardDocIndex {
138154
// Clears internal data. Traverses all matching documents and assigns ids.
139155
void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr);
140156

157+
void Serialize(const OpArgs& op_args, const SearchParams& params,
158+
absl::Span<DocResult> docs) const;
159+
141160
private:
142161
std::shared_ptr<const DocIndex> base_;
162+
size_t write_epoch_;
143163
search::FieldIndices indices_;
144164
DocKeyIndex key_index_;
145165
};

0 commit comments

Comments
 (0)