Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions python/tests/detail/test_collection_dql.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,69 @@ def test_fetch_empty_ids(self, full_collection: Collection):
f"Expected 0 results for empty ID list, but got {len(result)}"
)

@pytest.mark.parametrize("doc_num", [3])
def test_fetch_with_output_fields(self, full_collection: Collection, doc_num):
    """Check that ``fetch`` honours the ``output_fields`` selection.

    Inserts ``doc_num`` generated documents, then fetches the first one four
    times (``None``, a single field, an empty list, two fields) and asserts
    which scalar fields are present on the returned document each time.
    """
    docs_to_insert = [
        generate_doc(idx, full_collection.schema) for idx in range(doc_num)
    ]
    for status in full_collection.insert(docs_to_insert):
        assert status.ok(), f"Insert failed: {status.code()}"

    target_id = docs_to_insert[0].id

    def fetch_target(fields):
        # Fetch the target document with the given field selection and make
        # sure it is present in the result mapping before handing it back.
        fetched = full_collection.fetch(ids=[target_id], output_fields=fields)
        assert target_id in fetched
        document = fetched[target_id]
        assert document is not None
        return document

    # No selection (None): every scalar field is expected on the document.
    unfiltered = fetch_target(None)
    assert unfiltered.has_field("int32_field"), (
        "int32_field should be present when output_fields=None"
    )
    assert unfiltered.has_field("string_field"), (
        "string_field should be present when output_fields=None"
    )

    # Single-field selection: only int32_field is expected.
    single = fetch_target(["int32_field"])
    assert single.has_field("int32_field"), "int32_field should be present"
    assert not single.has_field("string_field"), (
        'string_field should not be present when output_fields=["int32_field"]'
    )
    assert not single.has_field("float_field"), (
        'float_field should not be present when output_fields=["int32_field"]'
    )

    # Empty selection: the primary key survives, scalar fields do not.
    bare = fetch_target([])
    assert bare.id == target_id, "pk should still be set"
    assert not bare.has_field("int32_field"), (
        "int32_field should not be present when output_fields=[]"
    )
    assert not bare.has_field("string_field"), (
        "string_field should not be present when output_fields=[]"
    )

    # Two-field selection: both requested fields present, others absent.
    pair = fetch_target(["int32_field", "float_field"])
    assert pair.has_field("int32_field")
    assert pair.has_field("float_field")
    assert not pair.has_field("string_field")


class TestCollectionQuery:
@pytest.mark.parametrize("doc_num", [5])
Expand Down
4 changes: 3 additions & 1 deletion python/zvec/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ class _Collection:
def Destroy(self) -> None: ...
def DropColumn(self, arg0: str) -> None: ...
def DropIndex(self, arg0: str) -> None: ...
def Fetch(self, arg0: collections.abc.Sequence[str]) -> dict[str, _Doc]: ...
def Fetch(
self, pks: collections.abc.Sequence[str], output_fields: list[str] | None = None
) -> dict[str, _Doc]: ...
def Flush(self) -> None: ...
def GroupByQuery(self, arg0: ...) -> list[...]: ...
def Insert(self, arg0: collections.abc.Sequence[_Doc]) -> list[typing.Status]: ...
Expand Down
11 changes: 9 additions & 2 deletions python/zvec/model/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,24 @@ def delete_by_filter(self, filter: str) -> None:
self._obj.DeleteByFilter(filter)

# ========== Collection DQL-fetch Methods ==========
def fetch(self, ids: Union[str, list[str]]) -> dict[str, Doc]:
def fetch(
self,
ids: Union[str, list[str]],
*,
output_fields: Optional[list[str]] = None,
) -> dict[str, Doc]:
"""Retrieve documents by ID.

Args:
ids (Union[str, list[str]]): Document IDs to fetch.
output_fields (Optional[list[str]], optional): Scalar fields to
include (keyword-only). If None, all fields are returned; an empty
list returns no scalar fields. Defaults to None.

Returns:
dict[str, Doc]: Mapping from ID to document. Missing IDs are omitted.
"""
ids = [ids] if isinstance(ids, str) else ids
docs = self._obj.Fetch(ids)
docs = self._obj.Fetch(ids, output_fields)
return {
doc_id: py_doc
for doc_id, core_doc in docs.items()
Expand Down
20 changes: 19 additions & 1 deletion src/binding/c/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6031,6 +6031,8 @@ zvec_error_code_t zvec_collection_query(const zvec_collection_t *collection,

zvec_error_code_t zvec_collection_fetch(zvec_collection_t *collection,
const char *const *pks, size_t pk_count,
const char *const *output_fields,
size_t output_field_count,
zvec_doc_t ***results, size_t *doc_count) {
if (!collection || !pks || !results || !doc_count) {
set_last_error(
Expand Down Expand Up @@ -6063,8 +6065,24 @@ zvec_error_code_t zvec_collection_fetch(zvec_collection_t *collection,
}
}

// Build optional output_fields
std::optional<std::vector<std::string>> cpp_output_fields;
if (output_fields != nullptr && output_field_count > 0) {
std::vector<std::string> fields;
fields.reserve(output_field_count);
for (size_t i = 0; i < output_field_count; ++i) {
if (output_fields[i]) {
fields.emplace_back(output_fields[i]);
} else {
set_last_error("Null output_field at index " + std::to_string(i));
return ZVEC_ERROR_INVALID_ARGUMENT;
}
}
cpp_output_fields = std::move(fields);
}

// Call C++ fetch method
auto result = (*coll_ptr)->Fetch(pk_vector);
auto result = (*coll_ptr)->Fetch(pk_vector, cpp_output_fields);
if (!result.has_value()) {
set_last_error("Failed to fetch documents: " +
result.error().message());
Expand Down
15 changes: 9 additions & 6 deletions src/binding/python/model/python_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,15 @@ void ZVecPyCollection::bind_dql_methods(
// return GroupResults
return unwrap_expected(result);
})
.def("Fetch",
[](const Collection &self, const std::vector<std::string> &pks) {
const auto result = self.Fetch(pks);
// return DocPtrMap
return unwrap_expected(result);
});
.def(
"Fetch",
[](const Collection &self, const std::vector<std::string> &pks,
const std::optional<std::vector<std::string>> &output_fields) {
const auto result = self.Fetch(pks, output_fields);
// return DocPtrMap
return unwrap_expected(result);
},
py::arg("pks"), py::arg("output_fields") = py::none());
}

} // namespace zvec
9 changes: 6 additions & 3 deletions src/db/collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ class CollectionImpl : public Collection {
Result<GroupResults> GroupByQuery(
const GroupByVectorQuery &query) const override;

Result<DocPtrMap> Fetch(const std::vector<std::string> &pks) const override;
Result<DocPtrMap> Fetch(const std::vector<std::string> &pks,
const std::optional<std::vector<std::string>>
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否需要支持 include_vector = false 这种写法,类似于 query

&output_fields = std::nullopt) const override;

private:
void prepare_schema();
Expand Down Expand Up @@ -1605,7 +1607,8 @@ Result<GroupResults> CollectionImpl::GroupByQuery(
}

Result<DocPtrMap> CollectionImpl::Fetch(
const std::vector<std::string> &pks) const {
const std::vector<std::string> &pks,
const std::optional<std::vector<std::string>> &output_fields) const {
std::shared_lock lock(schema_handle_mtx_);

CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false);
Expand All @@ -1631,7 +1634,7 @@ Result<DocPtrMap> CollectionImpl::Fetch(
results.insert({pk, nullptr});
continue;
}
results.insert({pk, segment->Fetch(doc_id)});
results.insert({pk, segment->Fetch(doc_id, output_fields)});
}

return results;
Expand Down
26 changes: 22 additions & 4 deletions src/db/index/segment/segment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <ailego/parallel/multi_thread_list.h>
#include <ailego/pattern/defer.h>
#include <arrow/dataset/dataset.h>
Expand Down Expand Up @@ -133,7 +134,9 @@ class SegmentImpl : public Segment,

Status Delete(uint64_t g_doc_id) override;

Doc::Ptr Fetch(uint64_t g_doc_id) override;
Doc::Ptr Fetch(uint64_t g_doc_id,
const std::optional<std::vector<std::string>> &output_fields =
std::nullopt) override;

CombinedVectorColumnIndexer::Ptr get_combined_vector_indexer(
const std::string &field_name) const override;
Expand Down Expand Up @@ -1042,7 +1045,9 @@ Status SegmentImpl::ConvertVectorDataBufferToDocField(
}


Doc::Ptr SegmentImpl::Fetch(uint64_t g_doc_id) {
Doc::Ptr SegmentImpl::Fetch(
uint64_t g_doc_id,
const std::optional<std::vector<std::string>> &output_fields) {
std::lock_guard lock(seg_mtx_);

if (g_doc_id > segment_meta_->max_doc_id()) {
Expand All @@ -1067,8 +1072,21 @@ Doc::Ptr SegmentImpl::Fetch(uint64_t g_doc_id) {
std::vector<std::string> forward_columns;
forward_columns.push_back(GLOBAL_DOC_ID);
forward_columns.push_back(USER_ID);
for (const auto &field : collection_schema_->forward_fields()) {
forward_columns.push_back(field->name());
if (!output_fields.has_value()) {
// No output_fields specified: return all forward fields
for (const auto &field : collection_schema_->forward_fields()) {
forward_columns.push_back(field->name());
}
} else {
// output_fields specified: only return requested fields that exist
const auto &requested = *output_fields;
std::unordered_set<std::string> requested_set(requested.begin(),
requested.end());
for (const auto &field : collection_schema_->forward_fields()) {
if (requested_set.count(field->name())) {
forward_columns.push_back(field->name());
}
}
}

// Build result schema
Expand Down
4 changes: 3 additions & 1 deletion src/db/index/segment/segment.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ class Segment {

virtual Status Delete(uint64_t g_doc_id) = 0;

virtual Doc::Ptr Fetch(uint64_t g_doc_id) = 0;
virtual Doc::Ptr Fetch(uint64_t g_doc_id,
const std::optional<std::vector<std::string>>
&output_fields = std::nullopt) = 0;

// for sqlengine
virtual TablePtr fetch(const std::vector<std::string> &columns,
Expand Down
7 changes: 6 additions & 1 deletion src/include/zvec/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2650,14 +2650,19 @@ ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_query(
* @param collection Collection handle
* @param primary_keys Primary key array
* @param count Number of primary keys
* @param output_fields Array of field names to return; NULL means return all
* fields
* @param output_field_count Number of output_fields entries; 0 if
* output_fields is NULL. A count of 0 is also treated as "return all
* fields", even when output_fields is non-NULL
* @param[out] documents Returned document array (needs to be freed by calling
* zvec_docs_free)
* @param[out] found_count Number of found documents
* @return zvec_error_code_t Error code
*/
ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_collection_fetch(
zvec_collection_t *collection, const char *const *primary_keys,
size_t count, zvec_doc_t ***documents, size_t *found_count);
size_t count, const char *const *output_fields, size_t output_field_count,
zvec_doc_t ***documents, size_t *found_count);

// =============================================================================
// Document Related Structures
Expand Down
6 changes: 4 additions & 2 deletions src/include/zvec/db/collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#pragma once

#include <memory>
#include <optional>
#include <string>
#include <vector>
#include <zvec/db/doc.h>
Expand Down Expand Up @@ -101,8 +102,9 @@ class Collection {
virtual Result<GroupResults> GroupByQuery(
const GroupByVectorQuery &query) const = 0;

virtual Result<DocPtrMap> Fetch(
const std::vector<std::string> &pks) const = 0;
virtual Result<DocPtrMap> Fetch(const std::vector<std::string> &pks,
const std::optional<std::vector<std::string>>
&output_fields = std::nullopt) const = 0;
};

} // namespace zvec
36 changes: 33 additions & 3 deletions tests/c/c_api_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -3975,7 +3975,8 @@ void test_collection_nullable_roundtrip(void) {
const char *pks[] = {"pk_nullable"};
zvec_doc_t **fetched = NULL;
size_t fetched_count = 0;
err = zvec_collection_fetch(collection, pks, 1, &fetched, &fetched_count);
err = zvec_collection_fetch(collection, pks, 1, NULL, 0, &fetched,
&fetched_count);
TEST_ASSERT(err == ZVEC_OK);
TEST_ASSERT(fetched_count == 1);
if (fetched && fetched_count == 1) {
Expand Down Expand Up @@ -4689,15 +4690,44 @@ void test_collection_query_functions(void) {
zvec_collection_flush(collection);
zvec_collection_optimize(collection);

// Test zvec_collection_fetch
// Test zvec_collection_fetch (fetch all fields, NULL output_fields)
const char *pks[] = {"doc1", "doc2"};
zvec_doc_t **results = NULL;
size_t found_count = 0;
err = zvec_collection_fetch(collection, pks, 2, &results, &found_count);
err = zvec_collection_fetch(collection, pks, 2, NULL, 0, &results,
&found_count);
TEST_ASSERT(err == ZVEC_OK);
TEST_ASSERT(found_count == 2);
if (results && found_count == 2) {
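// NOTE(review): the check below uses ||, so it only requires that at least
// one of the two docs has the "name" field; use && if both must have it.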
TEST_ASSERT(zvec_doc_has_field(results[0], "name") == true ||
zvec_doc_has_field(results[1], "name") == true);
}
zvec_docs_free(results, found_count);

// Test zvec_collection_fetch with output_fields=["name"]
zvec_doc_t **results_partial = NULL;
size_t found_count_partial = 0;
const char *output_fields[] = {"name"};
err = zvec_collection_fetch(collection, pks, 2, output_fields, 1,
&results_partial, &found_count_partial);
TEST_ASSERT(err == ZVEC_OK);
TEST_ASSERT(found_count_partial == 2);
if (results_partial && found_count_partial == 2) {
for (size_t i = 0; i < found_count_partial; ++i) {
TEST_ASSERT(zvec_doc_has_field(results_partial[i], "name") == true);
}
}
zvec_docs_free(results_partial, found_count_partial);

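// Test zvec_collection_fetch with NULL output_fields and a count of 0.
// NOTE(review): this is the same call as the "fetch all fields" case above,
// not a "no scalar fields" request - confirm how an empty field list is
// meant to be passed through the C API.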
zvec_doc_t **results_empty_fields = NULL;
size_t found_count_empty = 0;
err = zvec_collection_fetch(collection, pks, 2, NULL, 0,
&results_empty_fields, &found_count_empty);
TEST_ASSERT(err == ZVEC_OK);
zvec_docs_free(results_empty_fields, found_count_empty);

// Test zvec_collection_get_options
zvec_collection_options_t *options = NULL;
err = zvec_collection_get_options(collection, &options);
Expand Down
Loading
Loading