Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/binding/c/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4307,7 +4307,7 @@ size_t zvec_doc_memory_usage(const zvec_doc_t *doc) {
return doc_ptr->memory_usage();)
}

zvec_error_code_t zvec_doc_validate(const zvec_doc_t *doc,
zvec_error_code_t zvec_doc_validate_and_sanitize(zvec_doc_t *doc,
const zvec_collection_schema_t *schema,
bool is_update, char **error_msg) {
if (!doc || !schema) {
Expand All @@ -4327,15 +4327,15 @@ zvec_error_code_t zvec_doc_validate(const zvec_doc_t *doc,
return status_to_error_code(status);
}

auto doc_ptr = reinterpret_cast<const zvec::Doc *>(doc);
status = doc_ptr->validate(schema_ptr, is_update); if (!status.ok()) {
auto doc_ptr = reinterpret_cast<zvec::Doc *>(doc);
status = doc_ptr->validate_and_sanitize(schema_ptr, is_update); if (!status.ok()) {
if (error_msg) {
*error_msg = copy_string(status.message());
}
return status_to_error_code(status);
}

if (error_msg) { *error_msg = nullptr; }
if (error_msg) { *error_msg = nullptr; }
return ZVEC_OK;)
}

Expand Down
12 changes: 6 additions & 6 deletions src/db/collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include "db/common/file_helper.h"
#include "db/common/profiler.h"
#include "db/common/typedef.h"
#include "db/index/column/vector_column/vector_column_indexer.h"
#include "db/index/common/delete_store.h"
#include "db/index/common/id_map.h"
#include "db/index/common/index_filter.h"
Expand Down Expand Up @@ -1443,8 +1442,8 @@ Result<WriteResults> CollectionImpl::write_impl(std::vector<Doc> &docs,
CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false);

for (auto &&doc : docs) {
auto validate = doc.validate(schema_, mode == WriteMode::UPDATE);
CHECK_RETURN_STATUS_EXPECTED(validate);
auto s = doc.validate_and_sanitize(schema_, mode == WriteMode::UPDATE);
CHECK_RETURN_STATUS_EXPECTED(s);
}

// TODO: The granularity of the write_lock is too coarse.
Expand All @@ -1458,7 +1457,6 @@ Result<WriteResults> CollectionImpl::write_impl(std::vector<Doc> &docs,
kMaxWriteBatchSize));
}

// validate docs
for (auto &&doc : docs) {
if (need_switch_to_new_segment()) {
auto s = switch_to_new_segment_for_writing();
Expand Down Expand Up @@ -1583,15 +1581,17 @@ Result<DocPtrList> CollectionImpl::Query(const VectorQuery &query) const {

CHECK_DESTROY_RETURN_STATUS_EXPECTED(destroyed_, false);

auto s = query.validate(schema_->get_vector_field(query.field_name_));
VectorQuery sanitized = query;
auto s = sanitized.validate_and_sanitize(
schema_->get_vector_field(sanitized.field_name_));
CHECK_RETURN_STATUS_EXPECTED(s);

auto segments = get_all_segments();
if (segments.empty()) {
return DocPtrList();
}

return sql_engine_->execute(schema_, query, segments);
return sql_engine_->execute(schema_, sanitized, segments);
}

Result<GroupResults> CollectionImpl::GroupByQuery(
Expand Down
144 changes: 119 additions & 25 deletions src/db/index/common/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <regex>
#include <stdexcept>
#include <zvec/ailego/internal/platform.h>
Expand Down Expand Up @@ -114,6 +117,9 @@ std::string get_value_type_name(const Doc::Value &value, bool is_vector) {
value);
}


namespace {

template <typename T>
T byte_swap(T value) {
if constexpr (std::is_same_v<T, float16_t>) {
Expand Down Expand Up @@ -159,6 +165,68 @@ T read_value_from_buffer(const uint8_t *&data) {
return value;
}

template <typename T>
std::string vec_to_string(const std::vector<T> &v) {
std::ostringstream oss;
oss << "[";
for (size_t i = 0; i < v.size(); ++i) {
if (i > 0) oss << ", ";
oss << +v[i]; // + from print as char
}
oss << "]";
return oss.str();
}

template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};

template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;


bool sort_and_find_duplicates(uint32_t *indices, char *values, size_t n,
size_t value_byte_size) {
if (n <= 1) {
return false;
}
bool already_sorted = true;
for (size_t i = 1; i < n; ++i) {
if (indices[i] == indices[i - 1]) {
return true;
}
if (indices[i] < indices[i - 1]) {
already_sorted = false;
break;
}
}
if (already_sorted) {
return false;
}
std::vector<size_t> perm(n);
std::iota(perm.begin(), perm.end(), size_t{0});
std::sort(perm.begin(), perm.end(),
[&](size_t a, size_t b) { return indices[a] < indices[b]; });
std::vector<uint32_t> sorted_indices(n);
std::vector<char> sorted_values(n * value_byte_size);
for (size_t i = 0; i < n; ++i) {
sorted_indices[i] = indices[perm[i]];
std::memcpy(sorted_values.data() + i * value_byte_size,
values + perm[i] * value_byte_size, value_byte_size);
}
std::memcpy(indices, sorted_indices.data(), n * sizeof(uint32_t));
std::memcpy(values, sorted_values.data(), n * value_byte_size);
for (size_t i = 1; i < n; ++i) {
if (indices[i] == indices[i - 1]) {
return true;
}
}
return false;
}

} // namespace


void Doc::write_to_buffer(std::vector<uint8_t> &buffer, const void *src,
size_t size) {
Expand Down Expand Up @@ -693,8 +761,8 @@ Doc::Ptr Doc::deserialize(const uint8_t *data, size_t /*size*/) {
return doc;
}

Status Doc::validate(const CollectionSchema::Ptr &schema,
bool is_update) const {
Status Doc::validate_and_sanitize(const CollectionSchema::Ptr &schema,
bool is_update) {
if (!schema) {
return Status::InternalError("schema is null during doc validation");
}
Expand Down Expand Up @@ -739,7 +807,7 @@ Status Doc::validate(const CollectionSchema::Ptr &schema,
}
}

const Value &field_value = field_pair->second;
Value &field_value = field_pair->second;
DataType expected_type = field_schema->data_type();
bool type_match = true;
uint32_t value_dimension = 0;
Expand Down Expand Up @@ -860,7 +928,7 @@ Status Doc::validate(const CollectionSchema::Ptr &schema,
std::pair<std::vector<uint32_t>, std::vector<float16_t>>>(
field_value);
if (type_match) {
auto [sparse_indices, sparse_values] = std::get<
auto &[sparse_indices, sparse_values] = std::get<
std::pair<std::vector<uint32_t>, std::vector<float16_t>>>(
field_value);
if (sparse_values.size() != sparse_indices.size()) {
Expand All @@ -874,6 +942,14 @@ Status Doc::validate(const CollectionSchema::Ptr &schema,
"] exceeds the maximum number of sparse indices (",
kSparseMaxDimSize, ")");
}
if (sort_and_find_duplicates(
sparse_indices.data(),
reinterpret_cast<char *>(sparse_values.data()),
sparse_indices.size(), sizeof(float16_t))) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
}
break;
}
Expand All @@ -895,6 +971,14 @@ Status Doc::validate(const CollectionSchema::Ptr &schema,
"] exceeds the maximum number of sparse indices (",
kSparseMaxDimSize, ")");
}
if (sort_and_find_duplicates(
sparse_indices.data(),
reinterpret_cast<char *>(sparse_values.data()),
sparse_indices.size(), sizeof(float))) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
}
break;
}
Expand Down Expand Up @@ -1036,24 +1120,6 @@ size_t Doc::memory_usage() const {
return usage;
}

template <typename T>
std::string vec_to_string(const std::vector<T> &v) {
std::ostringstream oss;
oss << "[";
for (size_t i = 0; i < v.size(); ++i) {
if (i > 0) oss << ", ";
oss << +v[i]; // + from print as char
}
oss << "]";
return oss.str();
}

template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

std::string Doc::to_detail_string() const {
std::stringstream oss;
Expand Down Expand Up @@ -1202,7 +1268,7 @@ bool Doc::operator==(const Doc &other) const {
return true;
}

Status VectorQuery::validate(const FieldSchema *schema) const {
Status VectorQuery::validate_and_sanitize(const FieldSchema *schema) {
if ((uint32_t)topk_ > kMaxQueryTopk) {
return Status::InvalidArgument("Invalid query: topk[", topk_,
"] exceeds the maximum allowed value of ",
Expand Down Expand Up @@ -1274,12 +1340,40 @@ Status VectorQuery::validate(const FieldSchema *schema) const {
"] is not a dense vector field");
}
} else if (schema->is_sparse_vector()) {
// Validate sparse indices size
if (query_sparse_indices_.size() > kSparseMaxDimSize * sizeof(uint32_t)) {
size_t value_byte_size = 0;
switch (schema->data_type()) {
case DataType::SPARSE_VECTOR_FP32:
value_byte_size = sizeof(float);
break;
case DataType::SPARSE_VECTOR_FP16:
value_byte_size = sizeof(float16_t);
break;
default:
return Status::InvalidArgument(
"Invalid query: sparse vector type of field[", field_name_,
"] is not supported");
}
if (query_sparse_indices_.size() % sizeof(uint32_t) != 0 ||
query_sparse_values_.size() % value_byte_size != 0 ||
query_sparse_indices_.size() / sizeof(uint32_t) !=
query_sparse_values_.size() / value_byte_size) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name_,
"] has mismatched indices and values sizes");
}
size_t n_indices = query_sparse_indices_.size() / sizeof(uint32_t);
if (n_indices > kSparseMaxDimSize) {
return Status::InvalidArgument(
"Invalid query: too many sparse indices, the maximum allowed is ",
kSparseMaxDimSize);
}
if (sort_and_find_duplicates(
reinterpret_cast<uint32_t *>(query_sparse_indices_.data()),
query_sparse_values_.data(), n_indices, value_byte_size)) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name_,
"] contains duplicate indices");
}
} else {
return Status::InvalidArgument("Invalid query: field[", field_name_,
"] is not a vector field");
Expand Down
6 changes: 3 additions & 3 deletions src/include/zvec/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3044,9 +3044,9 @@ ZVEC_EXPORT size_t ZVEC_CALL zvec_doc_memory_usage(const zvec_doc_t *doc);
* @param[out] error_msg Error message (needs manual release)
* @return zvec_error_code_t Error code
*/
ZVEC_EXPORT zvec_error_code_t ZVEC_CALL
zvec_doc_validate(const zvec_doc_t *doc, const zvec_collection_schema_t *schema,
bool is_update, char **error_msg);
ZVEC_EXPORT zvec_error_code_t ZVEC_CALL zvec_doc_validate_and_sanitize(
zvec_doc_t *doc, const zvec_collection_schema_t *schema, bool is_update,
char **error_msg);

/**
* @brief Get detailed string representation of document
Expand Down
6 changes: 3 additions & 3 deletions src/include/zvec/db/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ class Doc {
fields_.erase(field_name);
}

Status validate(const CollectionSchema::Ptr &schema,
bool is_update = false) const;
Status validate_and_sanitize(const CollectionSchema::Ptr &schema,
bool is_update = false);

size_t memory_usage() const;

Expand Down Expand Up @@ -378,7 +378,7 @@ struct VectorQuery {
std::optional<std::vector<std::string>> output_fields_;
QueryParams::Ptr query_params_;

Status validate(const FieldSchema *schema) const;
Status validate_and_sanitize(const FieldSchema *schema);
};

struct GroupByVectorQuery {
Expand Down
3 changes: 2 additions & 1 deletion tests/c/c_api_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -4784,7 +4784,8 @@ void test_doc_advanced_functions(void) {
&(int32_t){42}, sizeof(int32_t));

char *error_msg = NULL;
zvec_error_code_t err = zvec_doc_validate(val_doc, schema, false, &error_msg);
zvec_error_code_t err =
zvec_doc_validate_and_sanitize(val_doc, schema, false, &error_msg);
TEST_ASSERT(err == ZVEC_OK);
if (error_msg) {
zvec_free(error_msg);
Expand Down
Loading
Loading