Skip to content

Better support bytes, IPs, and JSON #152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions src/document.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ use pyo3::{
basic::CompareOp,
prelude::*,
types::{
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess,
PyTuple,
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyInt, PyList,
PyTimeAccess, PyTuple,
},
Python,
};

use chrono::{offset::TimeZone, NaiveDateTime, Utc};
Expand All @@ -23,7 +24,8 @@ use serde_json::Value as JsonValue;
use std::{
collections::{BTreeMap, HashMap},
fmt,
net::Ipv6Addr,
net::{IpAddr, Ipv6Addr},
str::FromStr,
};

pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
Expand All @@ -50,6 +52,11 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
if let Ok(b) = any.extract::<Vec<u8>>() {
return Ok(Value::Bytes(b));
}
if let Ok(dict) = any.downcast::<PyDict>() {
if let Ok(json) = pythonize::depythonize(dict) {
return Ok(Value::JsonObject(json));
}
}
Err(to_pyerr(format!("Value unsupported {any:?}")))
}

Expand Down Expand Up @@ -105,7 +112,37 @@ pub(crate) fn extract_value_for_type(
.map_err(to_pyerr_for_type("Facet", field_name, any))?
.inner,
),
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
tv::schema::Type::Bytes => Value::Bytes(
any.extract::<Vec<u8>>()
.map_err(to_pyerr_for_type("Bytes", field_name, any))?,
),
tv::schema::Type::Json => {
if let Ok(json_str) = any.extract::<&str>() {
return serde_json::from_str(json_str)
.map(Value::JsonObject)
.map_err(to_pyerr_for_type("Json", field_name, any));
}

Value::JsonObject(
any.downcast::<PyDict>()
.map(|dict| pythonize::depythonize(&dict))
.map_err(to_pyerr_for_type("Json", field_name, any))?
.map_err(to_pyerr_for_type("Json", field_name, any))?,
)
}
tv::schema::Type::IpAddr => {
let val = any
.extract::<&str>()
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?;

IpAddr::from_str(val)
.map(|addr| match addr {
IpAddr::V4(addr) => addr.to_ipv6_mapped(),
IpAddr::V6(addr) => addr,
})
.map(Value::IpAddr)
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?
}
};

Ok(value)
Expand All @@ -126,6 +163,20 @@ fn extract_value_single_or_list_for_type(
) -> PyResult<Vec<Value>> {
// Check if a numeric fast field supports multivalues.
if let Ok(values) = any.downcast::<PyList>() {
// Process an array of integers as a single entry if it is a bytes field.
if field_type.value_type() == tv::schema::Type::Bytes
&& values
.get_item(0)
.map(|v| v.is_instance_of::<PyInt>())
.unwrap_or(false)
{
return Ok(vec![extract_value_for_type(
values,
field_type.value_type(),
field_name,
)?]);
}

values
.iter()
.map(|any| {
Expand Down
90 changes: 82 additions & 8 deletions src/schemabuilder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

use pyo3::{exceptions, prelude::*};

use tantivy::schema;

use crate::schema::Schema;
use std::sync::{Arc, RwLock};
use tantivy::schema::{DateOptions, INDEXED};
use tantivy::schema::{
self, BytesOptions, DateOptions, IpAddrOptions, INDEXED,
};

/// Tantivy has a very strict schema.
/// You need to specify in advance whether a field is indexed or not,
Expand Down Expand Up @@ -357,22 +357,96 @@ impl SchemaBuilder {

/// Add a fast bytes field to the schema.
///
/// Bytes field are not searchable and are only used
/// as fast field, to associate any kind of payload
/// to a document.
/// Args:
/// name (str): The name of the field.
/// stored (bool, optional): If true sets the field as stored, the
/// content of the field can be later restored from a Searcher.
/// Defaults to False.
/// indexed (bool, optional): If true sets the field to be indexed.
/// fast (str, optional): Set the bytes options as a fast field. A fast
/// field is a column-oriented fashion storage for tantivy. It is
/// designed for the fast random access of some document fields
/// given a document id.
#[pyo3(signature = (
name,
stored = false,
indexed = false,
fast = false
))]
fn add_bytes_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;
let mut opts = BytesOptions::default();
if stored {
opts = opts.set_stored();
}
if indexed {
opts = opts.set_indexed();
}
if fast {
opts = opts.set_fast();
}

if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_bytes_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
));
}
Ok(self.clone())
}

/// Add an IP address field to the schema.
///
/// Args:
/// name (str): The name of the field.
fn add_bytes_field(&mut self, name: &str) -> PyResult<Self> {
/// stored (bool, optional): If true sets the field as stored, the
/// content of the field can be later restored from a Searcher.
/// Defaults to False.
/// indexed (bool, optional): If true sets the field to be indexed.
/// fast (str, optional): Set the IP address options as a fast field. A
/// fast field is a column-oriented fashion storage for tantivy. It
/// is designed for the fast random access of some document fields
/// given a document id.
#[pyo3(signature = (
name,
stored = false,
indexed = false,
fast = false
))]
fn add_ip_addr_field(
&mut self,
name: &str,
stored: bool,
indexed: bool,
fast: bool,
) -> PyResult<Self> {
let builder = &mut self.builder;
let mut opts = IpAddrOptions::default();
if stored {
opts = opts.set_stored();
}
if indexed {
opts = opts.set_indexed();
}
if fast {
opts = opts.set_fast();
}

if let Some(builder) = builder.write().unwrap().as_mut() {
builder.add_bytes_field(name, INDEXED);
builder.add_ip_addr_field(name, opts);
} else {
return Err(exceptions::PyValueError::new_err(
"Schema builder object isn't valid anymore.",
));
}

Ok(self.clone())
}

Expand Down
83 changes: 74 additions & 9 deletions tests/tantivy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import datetime
import json
import tantivy
import pickle
import pytest
Expand Down Expand Up @@ -365,7 +366,9 @@ def test_order_by_search(self):
searched_doc = index.searcher().doc(doc_address)
assert searched_doc["title"] == ["Test title"]

result = searcher.search(query, 10, order_by_field="order", order=tantivy.Order.Asc)
result = searcher.search(
query, 10, order_by_field="order", order=tantivy.Order.Asc
)

assert len(result.hits) == 3

Expand Down Expand Up @@ -443,7 +446,7 @@ def test_with_merges(self):

assert searcher.num_segments < 8

def test_doc_from_dict_schema_validation(self):
def test_doc_from_dict_numeric_validation(self):
schema = (
SchemaBuilder()
.add_unsigned_field("unsigned")
Expand Down Expand Up @@ -504,6 +507,70 @@ def test_doc_from_dict_schema_validation(self):
schema,
)

def test_doc_from_dict_bytes_validation(self):
schema = SchemaBuilder().add_bytes_field("bytes").build()

good = Document.from_dict({"bytes": b"hello"}, schema)
good = Document.from_dict({"bytes": [[1, 2, 3], [4, 5, 6]]}, schema)
good = Document.from_dict({"bytes": [1, 2, 3]}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"bytes": [1, 2, 256]}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"bytes": "hello"}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"bytes": [1024, "there"]}, schema)

def test_doc_from_dict_ip_addr_validation(self):
schema = SchemaBuilder().add_ip_addr_field("ip").build()

good = Document.from_dict({"ip": "127.0.0.1"}, schema)
good = Document.from_dict({"ip": "::1"}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"ip": 12309812348}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"ip": "256.100.0.1"}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict(
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234"}, schema
)

with pytest.raises(ValueError):
bad = Document.from_dict(
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:GHIJ"}, schema
)

def test_doc_from_dict_json_validation(self):
# Test implicit JSON
good = Document.from_dict({"dict": {"hello": "world"}})

schema = SchemaBuilder().add_json_field("json").build()

good = Document.from_dict({"json": {}}, schema)
good = Document.from_dict({"json": {"hello": "world"}}, schema)
good = Document.from_dict(
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]}, schema
)

list_of_jsons = [
{"hello": "world"},
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]},
]
good = Document.from_dict({"json": list_of_jsons}, schema)

good = Document.from_dict({"json": json.dumps(list_of_jsons[1])}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"json": 123}, schema)

with pytest.raises(ValueError):
bad = Document.from_dict({"json": "hello"}, schema)

def test_search_result_eq(self, ram_index, spanish_index):
eng_index = ram_index
eng_query = eng_index.parse_query("sea whale", ["title", "body"])
Expand Down Expand Up @@ -650,10 +717,6 @@ def test_document_with_facet(self):
doc = tantivy.Document(facet=facet)
assert doc["facet"][0].to_path() == ["asia/oceania", "fiji"]

def test_document_error(self):
with pytest.raises(ValueError):
tantivy.Document(name={})

def test_document_eq(self):
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]})
Expand Down Expand Up @@ -848,9 +911,11 @@ def test_document_snippet(self, dir_index):
result = searcher.search(query)
assert len(result.hits) == 1

snippet_generator = SnippetGenerator.create(searcher, query, doc_schema, "title")
snippet_generator = SnippetGenerator.create(
searcher, query, doc_schema, "title"
)

for (score, doc_address) in result.hits:
for score, doc_address in result.hits:
doc = searcher.doc(doc_address)
snippet = snippet_generator.snippet_from_doc(doc)
highlights = snippet.highlighted()
Expand All @@ -859,4 +924,4 @@ def test_document_snippet(self, dir_index):
assert first.start == 20
assert first.end == 23
html_snippet = snippet.to_html()
assert html_snippet == 'The Old Man and the <b>Sea</b>'
assert html_snippet == "The Old Man and the <b>Sea</b>"