Skip to content

Commit 3069571

Browse files
committed
Improved example [skip ci]
1 parent e6b60f3 commit 3069571

File tree

1 file changed

+25 -25 lines changed

examples/hybrid_search/src/main.rs

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -48,35 +48,11 @@ fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
48 48
)?;
49 49
}
50 50

51-
let sql = "
52-
WITH semantic_search AS (
53-
SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
54-
FROM documents
55-
ORDER BY embedding <=> $2
56-
LIMIT 20
57-
),
58-
keyword_search AS (
59-
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
60-
FROM documents, plainto_tsquery('english', $1) query
61-
WHERE to_tsvector('english', content) @@ query
62-
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
63-
LIMIT 20
64-
)
65-
SELECT
66-
COALESCE(semantic_search.id, keyword_search.id) AS id,
67-
COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
68-
COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
69-
FROM semantic_search
70-
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
71-
ORDER BY score DESC
72-
LIMIT 5
73-
";
74-
75 51
let query = "growling bear";
76 52
let query_embedding = model.embed(query)?;
77 53
let k = 60.0;
78 54

79-
for row in client.query(sql, &[&query, &Vector::from(query_embedding), &k])? {
55+
for row in client.query(HYBRID_SQL, &[&query, &Vector::from(query_embedding), &k])? {
80 56
let id: i32 = row.get(0);
81 57
let score: f64 = row.get(1);
82 58
println!("document: {}, RRF score: {}", id, score);
@@ -85,6 +61,30 @@ fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
85 61
Ok(())
86 62
}
87 63

64+
const HYBRID_SQL: &str = "
65+
WITH semantic_search AS (
66+
SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
67+
FROM documents
68+
ORDER BY embedding <=> $2
69+
LIMIT 20
70+
),
71+
keyword_search AS (
72+
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
73+
FROM documents, plainto_tsquery('english', $1) query
74+
WHERE to_tsvector('english', content) @@ query
75+
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
76+
LIMIT 20
77+
)
78+
SELECT
79+
COALESCE(semantic_search.id, keyword_search.id) AS id,
80+
COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
81+
COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
82+
FROM semantic_search
83+
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
84+
ORDER BY score DESC
85+
LIMIT 5
86+
";
87+
88 88
struct EmbeddingModel {
89 89
tokenizer: Tokenizer,
90 90
model: BertModel,

0 commit comments

Comments
 (0)