Skip to content

Commit bdc7747

Browse files
committed
Added Candle example [skip ci]
1 parent c3de6ca commit bdc7747

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Or check out some examples:
  - [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
  - [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
+ - [Sentence embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/candle/src/main.rs) with Candle
  - [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
  - [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`

examples/candle/Cargo.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[package]
name = "example"
version = "0.1.0"
edition = "2021"
# Example crate only — never published to crates.io.
publish = false

[dependencies]
# Candle: model definition, tensor ops, and weight loading for BERT.
candle-core = "0.6"
candle-nn = "0.6"
candle-transformers = "0.6"
# Downloads model files from the Hugging Face Hub.
hf-hub = "0.3"
# This repository's crate, with the synchronous `postgres` integration enabled.
pgvector = { path = "../..", features = ["postgres"] }
postgres = "0.19"
serde_json = "1"
tokenizers = "0.19"

examples/candle/src/main.rs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// https://github.com/huggingface/candle/tree/main/candle-examples/examples/bert
2+
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
3+
4+
use candle_core::{Device, Tensor};
5+
use candle_nn::VarBuilder;
6+
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
7+
use hf_hub::api::sync::Api;
8+
use pgvector::Vector;
9+
use postgres::{Client, NoTls};
10+
use std::error::Error;
11+
use std::fs::read_to_string;
12+
use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};
13+
14+
fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
15+
let mut client = Client::configure()
16+
.host("localhost")
17+
.dbname("pgvector_example")
18+
.user(std::env::var("USER")?.as_str())
19+
.connect(NoTls)?;
20+
21+
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
22+
client.execute("DROP TABLE IF EXISTS documents", &[])?;
23+
client.execute(
24+
"CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding vector(384))",
25+
&[],
26+
)?;
27+
28+
let model = EmbeddingModel::new("sentence-transformers/all-MiniLM-L6-v2")?;
29+
30+
let input = [
31+
"The dog is barking",
32+
"The cat is purring",
33+
"The bear is growling",
34+
];
35+
let embeddings = input
36+
.iter()
37+
.map(|text| model.embed(text))
38+
.collect::<Result<Vec<_>, _>>()?;
39+
40+
for (content, embedding) in input.iter().zip(embeddings) {
41+
client.execute(
42+
"INSERT INTO documents (content, embedding) VALUES ($1, $2)",
43+
&[&content, &Vector::from(embedding)],
44+
)?;
45+
}
46+
47+
let document_id = 2;
48+
for row in client.query("SELECT content FROM documents WHERE id != $1 ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = $1) LIMIT 5", &[&document_id])? {
49+
let content: &str = row.get(0);
50+
println!("{}", content);
51+
}
52+
53+
Ok(())
54+
}
55+
56+
struct EmbeddingModel {
57+
tokenizer: Tokenizer,
58+
model: BertModel,
59+
}
60+
61+
impl EmbeddingModel {
62+
pub fn new(model_id: &str) -> Result<Self, Box<dyn Error + Send + Sync>> {
63+
let api = Api::new()?;
64+
let repo = api.model(model_id.to_string());
65+
let tokenizer_path = repo.get("tokenizer.json")?;
66+
let config_path = repo.get("config.json")?;
67+
let weights_path = repo.get("model.safetensors")?;
68+
69+
let mut tokenizer = Tokenizer::from_file(tokenizer_path)?;
70+
let padding = PaddingParams {
71+
strategy: PaddingStrategy::BatchLongest,
72+
..Default::default()
73+
};
74+
tokenizer.with_padding(Some(padding));
75+
76+
let device = Device::Cpu;
77+
let config: Config = serde_json::from_str(&read_to_string(config_path)?)?;
78+
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, &device)? };
79+
let model = BertModel::load(vb, &config)?;
80+
81+
Ok(Self { tokenizer, model })
82+
}
83+
84+
// embed one at a time since BertModel does not support attention mask
85+
// https://github.com/huggingface/candle/issues/1798
86+
fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
87+
let tokens = self.tokenizer.encode(text, true)?;
88+
let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
89+
let token_type_ids = token_ids.zeros_like()?;
90+
let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
91+
let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
92+
let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
93+
Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)
94+
}
95+
}

0 commit comments

Comments (0)