Skip to content

feature: add logging in orchestrator #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/svs/index/flat/flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ class FlatIndex {
QueryResultView<size_t> result,
const data::ConstSimpleDataView<QueryType>& queries,
const search_parameters_type& search_parameters,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
Pred predicate = lib::Returns(lib::Const<true>())
) {
Expand Down Expand Up @@ -346,6 +347,7 @@ class FlatIndex {
threads::UnitRange(start, stop),
scratch,
search_parameters,
logger,
cancel,
predicate
);
Expand Down Expand Up @@ -376,6 +378,7 @@ class FlatIndex {
const threads::UnitRange<size_t>& data_indices,
sorter_type& scratch,
const search_parameters_type& search_parameters,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
Pred predicate = lib::Returns(lib::Const<true>())
) {
Expand All @@ -397,6 +400,7 @@ class FlatIndex {
threads::UnitRange(query_indices),
scratch,
distances,
logger,
cancel,
predicate
);
Expand All @@ -419,6 +423,7 @@ class FlatIndex {
const threads::UnitRange<size_t>& query_indices,
sorter_type& scratch,
distance::BroadcastDistance<DistFull>& distance_functors,
logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
Pred predicate = lib::Returns(lib::Const<true>())
) {
Expand Down
13 changes: 9 additions & 4 deletions include/svs/index/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

// svs
#include "svs/concepts/data.h"
#include "svs/core/logging.h"
#include "svs/core/query_result.h"

// stl
Expand Down Expand Up @@ -47,11 +48,12 @@ void search_batch_into_with(
svs::QueryResultView<I> result,
const Queries& queries,
const search_parameters_t<Index>& search_parameters,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
// Assert pre-conditions.
assert(result.n_queries() == queries.size());
index.search(result, queries, search_parameters, cancel);
index.search(result, queries, search_parameters, logger, cancel);
}

// Apply default search parameters
Expand All @@ -60,10 +62,11 @@ void search_batch_into(
Index& index,
svs::QueryResultView<I> result,
const Queries& queries,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
svs::index::search_batch_into_with(
index, result, queries, index.get_search_parameters(), cancel
index, result, queries, index.get_search_parameters(), logger, cancel
);
}

Expand All @@ -74,11 +77,12 @@ svs::QueryResult<size_t> search_batch_with(
const Queries& queries,
size_t num_neighbors,
const search_parameters_t<Index>& search_parameters,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
auto result = svs::QueryResult<size_t>{queries.size(), num_neighbors};
svs::index::search_batch_into_with(
index, result.view(), queries, search_parameters, cancel
index, result.view(), queries, search_parameters, logger, cancel
);
return result;
}
Expand All @@ -89,10 +93,11 @@ svs::QueryResult<size_t> search_batch(
Index& index,
const Queries& queries,
size_t num_neighbors,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
return svs::index::search_batch_with(
index, queries, num_neighbors, index.get_search_parameters(), cancel
index, queries, num_neighbors, index.get_search_parameters(), logger, cancel
);
}
} // namespace svs::index
3 changes: 2 additions & 1 deletion include/svs/index/inverted/memory_based.h
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ template <typename Index, typename Cluster> class InvertedIndex {
QueryResultView<Idx> results,
const Queries& queries,
const search_parameters_type& search_parameters,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
threads::parallel_for(
Expand All @@ -431,7 +432,7 @@ template <typename Index, typename Cluster> class InvertedIndex {

auto&& query = queries.get_datum(i);
// Primary Index Search
index_.search(query, scratch, cancel);
index_.search(query, scratch, logger, cancel);

auto& d = scratch.scratch;
auto compare = distance::comparator(d);
Expand Down
13 changes: 9 additions & 4 deletions include/svs/index/vamana/consolidate.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class GraphConsolidator {
Pool& threadpool_;
const Distance& distance_;
ConsolidationParameters params_;
svs::logging::logger_ptr logger_;

public:
// Constructor
Expand All @@ -169,13 +170,15 @@ class GraphConsolidator {
const Data& data,
Pool& threadpool,
const Distance& distance,
const ConsolidationParameters& params
const ConsolidationParameters& params,
svs::logging::logger_ptr logger = svs::logging::get()
)
: graph_{graph}
, data_{data}
, threadpool_{threadpool}
, distance_{distance}
, params_{params} {
, params_{params}
, logger_{logger} {
assert(graph.n_nodes() == data.size());
}

Expand Down Expand Up @@ -362,10 +365,12 @@ void consolidate(
size_t max_candidate_pool_size,
float alpha,
const Distance& distance,
Deleted&& is_deleted
Deleted&& is_deleted,
svs::logging::logger_ptr logger = svs::logging::get()
) {
ConsolidationParameters params{200'000, prune_to, max_candidate_pool_size, alpha};
auto consolidator = GraphConsolidator{graph, data, threadpool, distance, params};
auto consolidator =
GraphConsolidator{graph, data, threadpool, distance, params, logger};
consolidator(is_deleted);
}

Expand Down
47 changes: 35 additions & 12 deletions include/svs/index/vamana/dynamic_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,11 @@ class MutableVamanaIndex {
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) const {
return [&, prefetch_parameters](
const auto& query, auto& accessor, auto& distance, auto& buffer
const auto& query,
auto& accessor,
auto& distance,
auto& buffer,
auto& logger
) {
// Perform the greedy search using the provided resources.
greedy_search(
Expand All @@ -460,6 +464,7 @@ class MutableVamanaIndex {
vamana::EntryPointInitializer<Idx>{lib::as_const_span(entry_point_)},
internal_search_builder(),
prefetch_parameters,
logger_ ? logger_ : logger,
cancel
);
// Take a pass over the search buffer to remove any deleted elements that
Expand All @@ -473,14 +478,16 @@ class MutableVamanaIndex {
void search(
const Query& query,
scratchspace_type& scratch,
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>()),
svs::logging::logger_ptr logger = svs::logging::get()
) const {
extensions::single_search(
data_,
scratch.buffer,
scratch.scratch,
query,
greedy_search_closure(scratch.prefetch_parameters, cancel)
greedy_search_closure(scratch.prefetch_parameters, cancel),
logger_ ? logger_ : logger
);
}

Expand All @@ -489,6 +496,7 @@ class MutableVamanaIndex {
QueryResultView<I> results,
const Queries& queries,
const search_parameters_type& sp,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
threads::parallel_for(
Expand Down Expand Up @@ -516,6 +524,7 @@ class MutableVamanaIndex {
results,
threads::UnitRange{is},
greedy_search_closure(prefetch_parameters, cancel),
logger,
cancel
);
}
Expand Down Expand Up @@ -616,7 +625,10 @@ class MutableVamanaIndex {
///
template <data::ImmutableMemoryDataset Points, class ExternalIds>
std::vector<size_t> add_points(
const Points& points, const ExternalIds& external_ids, bool reuse_empty = false
const Points& points,
const ExternalIds& external_ids,
bool reuse_empty = false,
svs::logging::logger_ptr logger = svs::logging::get()
) {
const size_t num_points = points.size();
const size_t num_ids = external_ids.size();
Expand Down Expand Up @@ -690,7 +702,9 @@ class MutableVamanaIndex {
GreedySearchPrefetchParameters{sp.prefetch_lookahead_, sp.prefetch_step_};
VamanaBuilder builder{
graph_, data_, distance_, parameters, threadpool_, prefetch_parameters};
builder.construct(alpha_, entry_point(), slots, logging::Level::Trace, logger_);
builder.construct(
alpha_, entry_point(), slots, logging::Level::Trace, logger_ ? logger_ : logger
);
// Mark all added entries as valid.
for (const auto& i : slots) {
status_[i] = SlotMetadata::Valid;
Expand Down Expand Up @@ -724,16 +738,20 @@ class MutableVamanaIndex {
/// Delete consolidation performs the actual removal of deleted entries from the
/// graph.
///
template <typename T> size_t delete_entries(const T& ids) {
template <typename T>
size_t
delete_entries(const T& ids, svs::logging::logger_ptr logger = svs::logging::get()) {
translator_.check_external_exist(ids.begin(), ids.end());
for (auto i : ids) {
delete_entry(translator_.get_internal(i));
delete_entry(translator_.get_internal(i), logger_ ? logger_ : logger);
}
translator_.delete_external(ids);
return ids.size();
}

void delete_entry(size_t i) {
void delete_entry(
size_t i, svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get()
) {
SlotMetadata& meta = getindex(status_, i);
assert(meta == SlotMetadata::Valid);
meta = SlotMetadata::Deleted;
Expand Down Expand Up @@ -768,7 +786,10 @@ class MutableVamanaIndex {
/// @param batch_size Granularity at which points are shuffled. Setting this higher can
/// improve performance but requires more working memory.
///
void compact(Idx batch_size = 1'000) {
void compact(
Idx batch_size = 1'000,
svs::logging::logger_ptr SVS_UNUSED(logger) = svs::logging::get()
) {
// Step 1: Compute a prefix-sum matching each valid internal index to its new
// internal index.
//
Expand Down Expand Up @@ -955,7 +976,8 @@ class MutableVamanaIndex {
max_candidates_,
alpha_,
distance_,
check_is_deleted
check_is_deleted,
logger_ ? logger_ : logger_
);

// After consolidation - set all `Deleted` slots to `Empty`.
Expand Down Expand Up @@ -1031,7 +1053,8 @@ class MutableVamanaIndex {
const GroundTruth& groundtruth,
size_t num_neighbors,
double target_recall,
const CalibrationParameters& calibration_parameters = {}
const CalibrationParameters& calibration_parameters = {},
svs::logging::logger_ptr logger = svs::logging::get()
) {
// Preallocate the destination for search.
// Further, reference the search lambda in the recall lambda.
Expand All @@ -1054,7 +1077,7 @@ class MutableVamanaIndex {
target_recall,
compute_recall,
do_search,
logger_
logger
);

set_search_parameters(p);
Expand Down
12 changes: 9 additions & 3 deletions include/svs/index/vamana/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "svs/concepts/distance.h"
#include "svs/core/data.h"
#include "svs/core/distance.h"
#include "svs/core/logging.h"
#include "svs/core/medioid.h"
#include "svs/core/query_result.h"
#include "svs/index/vamana/greedy_search.h"
Expand Down Expand Up @@ -417,9 +418,10 @@ struct VamanaSingleSearchType {
Scratch& scratch,
const Query& query,
const Search& search,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) const {
svs::svs_invoke(*this, data, search_buffer, scratch, query, search, cancel);
svs::svs_invoke(*this, data, search_buffer, scratch, query, search, logger, cancel);
}
};

Expand All @@ -442,6 +444,7 @@ SVS_FORCE_INLINE void svs_invoke(
Distance& distance,
const Query& query,
const Search& search,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
// Check if request to cancel the search
Expand All @@ -450,7 +453,7 @@ SVS_FORCE_INLINE void svs_invoke(
}
// Perform graph search.
auto accessor = data::GetDatumAccessor();
search(query, accessor, distance, search_buffer);
search(query, accessor, distance, search_buffer, logger);
}

///
Expand Down Expand Up @@ -497,6 +500,7 @@ struct VamanaPerThreadBatchSearchType {
QueryResultView<I>& result,
threads::UnitRange<size_t> thread_indices,
const Search& search,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) const {
svs::svs_invoke(
Expand All @@ -508,6 +512,7 @@ struct VamanaPerThreadBatchSearchType {
result,
thread_indices,
search,
logger,
cancel
);
}
Expand All @@ -533,6 +538,7 @@ void svs_invoke(
QueryResultView<I>& result,
threads::UnitRange<size_t> thread_indices,
const Search& search,
svs::logging::logger_ptr logger = svs::logging::get(),
const lib::DefaultPredicate& cancel = lib::Returns(lib::Const<false>())
) {
// Fallback implementation
Expand All @@ -544,7 +550,7 @@ void svs_invoke(
}
// Perform search - results will be queued in the search buffer.
single_search(
dataset, search_buffer, distance, queries.get_datum(i), search, cancel
dataset, search_buffer, distance, queries.get_datum(i), search, logger, cancel
);

// Copy back results.
Expand Down
Loading
Loading