Skip to content

Commit d1d51e5

Browse files
committed
Better support bytes, IPs, and JSON
1 parent 4ac17da commit d1d51e5

File tree

3 files changed

+211
-21
lines changed

3 files changed

+211
-21
lines changed

src/document.rs

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ use pyo3::{
66
basic::CompareOp,
77
prelude::*,
88
types::{
9-
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyList, PyTimeAccess,
10-
PyTuple,
9+
PyAny, PyBool, PyDateAccess, PyDateTime, PyDict, PyInt, PyList,
10+
PyTimeAccess, PyTuple,
1111
},
12+
Python,
1213
};
1314

1415
use chrono::{offset::TimeZone, NaiveDateTime, Utc};
@@ -23,7 +24,8 @@ use serde_json::Value as JsonValue;
2324
use std::{
2425
collections::{BTreeMap, HashMap},
2526
fmt,
26-
net::Ipv6Addr,
27+
net::{IpAddr, Ipv6Addr},
28+
str::FromStr,
2729
};
2830

2931
pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
@@ -50,6 +52,11 @@ pub(crate) fn extract_value(any: &PyAny) -> PyResult<Value> {
5052
if let Ok(b) = any.extract::<Vec<u8>>() {
5153
return Ok(Value::Bytes(b));
5254
}
55+
if let Ok(dict) = any.downcast::<PyDict>() {
56+
if let Ok(json) = pythonize::depythonize(dict) {
57+
return Ok(Value::JsonObject(json));
58+
}
59+
}
5360
Err(to_pyerr(format!("Value unsupported {any:?}")))
5461
}
5562

@@ -105,7 +112,37 @@ pub(crate) fn extract_value_for_type(
105112
.map_err(to_pyerr_for_type("Facet", field_name, any))?
106113
.inner,
107114
),
108-
_ => return Err(to_pyerr(format!("Value unsupported {:?}", any))),
115+
tv::schema::Type::Bytes => Value::Bytes(
116+
any.extract::<Vec<u8>>()
117+
.map_err(to_pyerr_for_type("Bytes", field_name, any))?,
118+
),
119+
tv::schema::Type::Json => {
120+
if let Ok(json_str) = any.extract::<&str>() {
121+
return serde_json::from_str(json_str)
122+
.map(Value::JsonObject)
123+
.map_err(to_pyerr_for_type("Json", field_name, any));
124+
}
125+
126+
Value::JsonObject(
127+
any.downcast::<PyDict>()
128+
.map(|dict| pythonize::depythonize(&dict))
129+
.map_err(to_pyerr_for_type("Json", field_name, any))?
130+
.map_err(to_pyerr_for_type("Json", field_name, any))?,
131+
)
132+
}
133+
tv::schema::Type::IpAddr => {
134+
let val = any
135+
.extract::<&str>()
136+
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?;
137+
138+
IpAddr::from_str(val)
139+
.map(|addr| match addr {
140+
IpAddr::V4(addr) => addr.to_ipv6_mapped(),
141+
IpAddr::V6(addr) => addr,
142+
})
143+
.map(Value::IpAddr)
144+
.map_err(to_pyerr_for_type("IpAddr", field_name, any))?
145+
}
109146
};
110147

111148
Ok(value)
@@ -126,6 +163,20 @@ fn extract_value_single_or_list_for_type(
126163
) -> PyResult<Vec<Value>> {
127164
// Check if a numeric fast field supports multivalues.
128165
if let Ok(values) = any.downcast::<PyList>() {
166+
// Process an array of integers as a single entry if it is a bytes field.
167+
if field_type.value_type() == tv::schema::Type::Bytes
168+
&& values
169+
.get_item(0)
170+
.map(|v| v.is_instance_of::<PyInt>())
171+
.unwrap_or(false)
172+
{
173+
return Ok(vec![extract_value_for_type(
174+
values,
175+
field_type.value_type(),
176+
field_name,
177+
)?]);
178+
}
179+
129180
values
130181
.iter()
131182
.map(|any| {

src/schemabuilder.rs

Lines changed: 82 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
use pyo3::{exceptions, prelude::*};
44

5-
use tantivy::schema;
6-
75
use crate::schema::Schema;
86
use std::sync::{Arc, RwLock};
9-
use tantivy::schema::{DateOptions, INDEXED};
7+
use tantivy::schema::{
8+
self, BytesOptions, DateOptions, IpAddrOptions, INDEXED,
9+
};
1010

1111
/// Tantivy has a very strict schema.
1212
/// You need to specify in advance whether a field is indexed or not,
@@ -357,22 +357,96 @@ impl SchemaBuilder {
357357

358358
/// Add a bytes field to the schema.
359359
///
360-
/// Bytes field are not searchable and are only used
361-
/// as fast field, to associate any kind of payload
362-
/// to a document.
360+
/// Args:
361+
/// name (str): The name of the field.
362+
/// stored (bool, optional): If true sets the field as stored, the
363+
/// content of the field can be later restored from a Searcher.
364+
/// Defaults to False.
365+
/// indexed (bool, optional): If true sets the field to be indexed.
366+
/// fast (bool, optional): Set the bytes options as a fast field. A fast
367+
/// field is column-oriented storage in tantivy. It is
368+
/// designed for the fast random access of some document fields
369+
/// given a document id.
370+
#[pyo3(signature = (
371+
name,
372+
stored = false,
373+
indexed = false,
374+
fast = false
375+
))]
376+
fn add_bytes_field(
377+
&mut self,
378+
name: &str,
379+
stored: bool,
380+
indexed: bool,
381+
fast: bool,
382+
) -> PyResult<Self> {
383+
let builder = &mut self.builder;
384+
let mut opts = BytesOptions::default();
385+
if stored {
386+
opts = opts.set_stored();
387+
}
388+
if indexed {
389+
opts = opts.set_indexed();
390+
}
391+
if fast {
392+
opts = opts.set_fast();
393+
}
394+
395+
if let Some(builder) = builder.write().unwrap().as_mut() {
396+
builder.add_bytes_field(name, opts);
397+
} else {
398+
return Err(exceptions::PyValueError::new_err(
399+
"Schema builder object isn't valid anymore.",
400+
));
401+
}
402+
Ok(self.clone())
403+
}
404+
405+
/// Add an IP address field to the schema.
363406
///
364407
/// Args:
365408
/// name (str): The name of the field.
366-
fn add_bytes_field(&mut self, name: &str) -> PyResult<Self> {
409+
/// stored (bool, optional): If true sets the field as stored, the
410+
/// content of the field can be later restored from a Searcher.
411+
/// Defaults to False.
412+
/// indexed (bool, optional): If true sets the field to be indexed.
413+
/// fast (bool, optional): Set the IP address options as a fast field. A
414+
/// fast field is column-oriented storage in tantivy. It
415+
/// is designed for the fast random access of some document fields
416+
/// given a document id.
417+
#[pyo3(signature = (
418+
name,
419+
stored = false,
420+
indexed = false,
421+
fast = false
422+
))]
423+
fn add_ip_addr_field(
424+
&mut self,
425+
name: &str,
426+
stored: bool,
427+
indexed: bool,
428+
fast: bool,
429+
) -> PyResult<Self> {
367430
let builder = &mut self.builder;
431+
let mut opts = IpAddrOptions::default();
432+
if stored {
433+
opts = opts.set_stored();
434+
}
435+
if indexed {
436+
opts = opts.set_indexed();
437+
}
438+
if fast {
439+
opts = opts.set_fast();
440+
}
368441

369442
if let Some(builder) = builder.write().unwrap().as_mut() {
370-
builder.add_bytes_field(name, INDEXED);
443+
builder.add_ip_addr_field(name, opts);
371444
} else {
372445
return Err(exceptions::PyValueError::new_err(
373446
"Schema builder object isn't valid anymore.",
374447
));
375448
}
449+
376450
Ok(self.clone())
377451
}
378452

tests/tantivy_test.py

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import copy
44
import datetime
5+
import json
56
import tantivy
67
import pickle
78
import pytest
@@ -365,7 +366,9 @@ def test_order_by_search(self):
365366
searched_doc = index.searcher().doc(doc_address)
366367
assert searched_doc["title"] == ["Test title"]
367368

368-
result = searcher.search(query, 10, order_by_field="order", order=tantivy.Order.Asc)
369+
result = searcher.search(
370+
query, 10, order_by_field="order", order=tantivy.Order.Asc
371+
)
369372

370373
assert len(result.hits) == 3
371374

@@ -443,7 +446,7 @@ def test_with_merges(self):
443446

444447
assert searcher.num_segments < 8
445448

446-
def test_doc_from_dict_schema_validation(self):
449+
def test_doc_from_dict_numeric_validation(self):
447450
schema = (
448451
SchemaBuilder()
449452
.add_unsigned_field("unsigned")
@@ -504,6 +507,70 @@ def test_doc_from_dict_schema_validation(self):
504507
schema,
505508
)
506509

510+
def test_doc_from_dict_bytes_validation(self):
511+
schema = SchemaBuilder().add_bytes_field("bytes").build()
512+
513+
good = Document.from_dict({"bytes": b"hello"}, schema)
514+
good = Document.from_dict({"bytes": [[1, 2, 3], [4, 5, 6]]}, schema)
515+
good = Document.from_dict({"bytes": [1, 2, 3]}, schema)
516+
517+
with pytest.raises(ValueError):
518+
bad = Document.from_dict({"bytes": [1, 2, 256]}, schema)
519+
520+
with pytest.raises(ValueError):
521+
bad = Document.from_dict({"bytes": "hello"}, schema)
522+
523+
with pytest.raises(ValueError):
524+
bad = Document.from_dict({"bytes": [1024, "there"]}, schema)
525+
526+
def test_doc_from_dict_ip_addr_validation(self):
527+
schema = SchemaBuilder().add_ip_addr_field("ip").build()
528+
529+
good = Document.from_dict({"ip": "127.0.0.1"}, schema)
530+
good = Document.from_dict({"ip": "::1"}, schema)
531+
532+
with pytest.raises(ValueError):
533+
bad = Document.from_dict({"ip": 12309812348}, schema)
534+
535+
with pytest.raises(ValueError):
536+
bad = Document.from_dict({"ip": "256.100.0.1"}, schema)
537+
538+
with pytest.raises(ValueError):
539+
bad = Document.from_dict(
540+
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:DEF0:1234"}, schema
541+
)
542+
543+
with pytest.raises(ValueError):
544+
bad = Document.from_dict(
545+
{"ip": "1234:5678:9ABC:DEF0:1234:5678:9ABC:GHIJ"}, schema
546+
)
547+
548+
def test_doc_from_dict_json_validation(self):
549+
# Test implicit JSON
550+
good = Document.from_dict({"dict": {"hello": "world"}})
551+
552+
schema = SchemaBuilder().add_json_field("json").build()
553+
554+
good = Document.from_dict({"json": {}}, schema)
555+
good = Document.from_dict({"json": {"hello": "world"}}, schema)
556+
good = Document.from_dict(
557+
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]}, schema
558+
)
559+
560+
list_of_jsons = [
561+
{"hello": "world"},
562+
{"nested": {"hello": ["world", "!"]}, "numbers": [1, 2, 3]},
563+
]
564+
good = Document.from_dict({"json": list_of_jsons}, schema)
565+
566+
good = Document.from_dict({"json": json.dumps(list_of_jsons[1])}, schema)
567+
568+
with pytest.raises(ValueError):
569+
bad = Document.from_dict({"json": 123}, schema)
570+
571+
with pytest.raises(ValueError):
572+
bad = Document.from_dict({"json": "hello"}, schema)
573+
507574
def test_search_result_eq(self, ram_index, spanish_index):
508575
eng_index = ram_index
509576
eng_query = eng_index.parse_query("sea whale", ["title", "body"])
@@ -650,10 +717,6 @@ def test_document_with_facet(self):
650717
doc = tantivy.Document(facet=facet)
651718
assert doc["facet"][0].to_path() == ["asia/oceania", "fiji"]
652719

653-
def test_document_error(self):
654-
with pytest.raises(ValueError):
655-
tantivy.Document(name={})
656-
657720
def test_document_eq(self):
658721
doc1 = tantivy.Document(name="Bill", reference=[1, 2])
659722
doc2 = tantivy.Document.from_dict({"name": "Bill", "reference": [1, 2]})
@@ -848,9 +911,11 @@ def test_document_snippet(self, dir_index):
848911
result = searcher.search(query)
849912
assert len(result.hits) == 1
850913

851-
snippet_generator = SnippetGenerator.create(searcher, query, doc_schema, "title")
914+
snippet_generator = SnippetGenerator.create(
915+
searcher, query, doc_schema, "title"
916+
)
852917

853-
for (score, doc_address) in result.hits:
918+
for score, doc_address in result.hits:
854919
doc = searcher.doc(doc_address)
855920
snippet = snippet_generator.snippet_from_doc(doc)
856921
highlights = snippet.highlighted()
@@ -859,4 +924,4 @@ def test_document_snippet(self, dir_index):
859924
assert first.start == 20
860925
assert first.end == 23
861926
html_snippet = snippet.to_html()
862-
assert html_snippet == 'The Old Man and the <b>Sea</b>'
927+
assert html_snippet == "The Old Man and the <b>Sea</b>"

0 commit comments

Comments
 (0)