Skip to content

Commit 130792c

Browse files
committed
Started support for sparsevec type
1 parent f379fd2 commit 130792c

File tree

3 files changed

+42
-11
lines changed

3 files changed

+42
-11
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.2.0 (unreleased)
2+
3+
- Added support for `sparsevec` type
4+
15
## 0.1.1 (2023-03-09)
26

37
- Added `new` method

examples/sparse/example.lua

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ local cjson = require("cjson")
99
local http = require("socket.http")
1010
local ltn12 = require("ltn12")
1111
local pgmoon = require("pgmoon")
12+
local pgvector = require("./src/pgvector")
1213

1314
local pg = pgmoon.new({
1415
database = "pgvector_example",
@@ -20,14 +21,6 @@ assert(pg:query("CREATE EXTENSION IF NOT EXISTS vector"))
2021
assert(pg:query("DROP TABLE IF EXISTS documents"))
2122
assert(pg:query("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))"))
2223

23-
function sparsevec(elements, dim)
24-
local e = {}
25-
for k, v in pairs(elements) do
26-
table.insert(e, k .. ":" .. v)
27-
end
28-
return "{" .. table.concat(e, ",") .. "}/" .. dim
29-
end
30-
3124
function embed(inputs)
3225
local url = "http://localhost:3000/embed_sparse"
3326
local data = {
@@ -67,12 +60,12 @@ local documents = {
6760
local embeddings = embed(documents)
6861
for i, content in ipairs(documents) do
6962
local embedding = embeddings[i]
70-
assert(pg:query("INSERT INTO documents (content, embedding) VALUES ($1, $2::text::sparsevec)", content, sparsevec(embedding, 30522)))
63+
assert(pg:query("INSERT INTO documents (content, embedding) VALUES ($1, $2)", content, pgvector.sparsevec(embedding, 30522)))
7164
end
7265

7366
local query = "forest"
7467
local embedding = embed({query})[1]
75-
local res = assert(pg:query("SELECT content FROM documents ORDER BY embedding <#> $1::text::sparsevec LIMIT 5", sparsevec(embedding, 30522)))
68+
local res = assert(pg:query("SELECT content FROM documents ORDER BY embedding <#> $1 LIMIT 5", pgvector.sparsevec(embedding, 30522)))
7669
for i, row in ipairs(res) do
7770
print(row["content"])
7871
end

src/pgvector.lua

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ local vector_mt = {
66
end
77
}
88

9+
local sparsevec_mt = {
10+
pgmoon_serialize = function(v)
11+
return 0, sparsevec_serialize(v)
12+
end
13+
}
14+
915
function pgvector.new(v)
1016
local vec = {}
1117
for _, x in ipairs(v) do
@@ -30,10 +36,38 @@ function pgvector.deserialize(v)
3036
return setmetatable(vec, vector_mt)
3137
end
3238

39+
function pgvector.sparsevec(elements, dim)
40+
for k, v in pairs(elements) do
41+
assert(type(k) == "number")
42+
assert(type(v) == "number")
43+
end
44+
assert(type(dim) == "number")
45+
46+
local vec = {}
47+
vec["elements"] = elements
48+
vec["dim"] = dim
49+
return setmetatable(vec, sparsevec_mt)
50+
end
51+
52+
function sparsevec_serialize(vec)
53+
local elements = {}
54+
for i, v in pairs(vec["elements"]) do
55+
table.insert(elements, tonumber(i) .. ":" .. tonumber(v))
56+
end
57+
return "{" .. table.concat(elements, ",") .. "}/" .. tonumber(vec["dim"])
58+
end
59+
60+
function sparsevec_deserialize(v)
61+
-- TODO
62+
end
63+
3364
function pgvector.setup_vector(pg)
34-
local row = pg:query("SELECT to_regtype('vector')::oid AS vector_oid")[1]
65+
local row = pg:query("SELECT to_regtype('vector')::oid AS vector_oid, to_regtype('sparsevec')::oid AS sparsevec_oid")[1]
3566
assert(row["vector_oid"], "vector type not found in the database")
3667
pg:set_type_deserializer(row["vector_oid"], "vector", function(self, v) return pgvector.deserialize(v) end)
68+
-- if row["sparsevec_oid"] do
69+
-- pg:set_type_deserializer(row["sparsevec_oid"], "sparsevec", function(self, v) return sparsevec_deserialize(v) end)
70+
-- end
3771
end
3872

3973
return pgvector

0 commit comments

Comments
 (0)