Skip to content

Commit 135cd53

Browse files
committed
Simplified code
1 parent 07cccce commit 135cd53

11 files changed

Lines changed: 26 additions & 46 deletions

File tree

pgvector/asyncpg/register.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ async def register_vector(conn: Connection, schema: str = 'public') -> None:
77
'vector',
88
schema=schema,
99
encoder=Vector._to_db_binary,
10-
decoder=Vector._from_db_binary,
10+
decoder=Vector.from_binary,
1111
format='binary'
1212
)
1313

@@ -16,15 +16,15 @@ async def register_vector(conn: Connection, schema: str = 'public') -> None:
1616
'halfvec',
1717
schema=schema,
1818
encoder=HalfVector._to_db_binary,
19-
decoder=HalfVector._from_db_binary,
19+
decoder=HalfVector.from_binary,
2020
format='binary'
2121
)
2222

2323
await conn.set_type_codec(
2424
'sparsevec',
2525
schema=schema,
2626
encoder=SparseVector._to_db_binary,
27-
decoder=SparseVector._from_db_binary,
27+
decoder=SparseVector.from_binary,
2828
format='binary'
2929
)
3030
except ValueError as e:

pgvector/django/halfvec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def to_python(self, value: Any) -> HalfVector | None:
3131
if value is None or isinstance(value, HalfVector):
3232
return value
3333
elif isinstance(value, str):
34-
return HalfVector._from_db(value)
34+
return HalfVector.from_text(value)
3535
else:
3636
return HalfVector(value)
3737

pgvector/django/vector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def to_python(self, value: Any) -> Vector | None:
3131
if value is None or isinstance(value, Vector):
3232
return value
3333
elif isinstance(value, str):
34-
return Vector._from_db(value)
34+
return Vector.from_text(value)
3535
else:
3636
return Vector(value)
3737

pgvector/halfvec.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,3 @@ def _from_db(cls, value: str | HalfVector | None) -> HalfVector | None:
9696
return value
9797

9898
return cls.from_text(value)
99-
100-
@classmethod
101-
def _from_db_binary(cls, value: bytes | HalfVector | None) -> HalfVector | None:
102-
if value is None or isinstance(value, HalfVector):
103-
return value
104-
105-
return cls.from_binary(value)

pgvector/pg8000/register.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def register_vector(conn: Connection) -> None:
1111
raise RuntimeError('vector type not found in the database')
1212

1313
conn.register_out_adapter(Vector, Vector._to_db)
14-
conn.register_in_adapter(type_info['vector'], Vector._from_db)
14+
conn.register_in_adapter(type_info['vector'], Vector.from_text)
1515

1616
try:
1717
import numpy as np
@@ -21,8 +21,8 @@ def register_vector(conn: Connection) -> None:
2121

2222
if 'halfvec' in type_info:
2323
conn.register_out_adapter(HalfVector, HalfVector._to_db)
24-
conn.register_in_adapter(type_info['halfvec'], HalfVector._from_db)
24+
conn.register_in_adapter(type_info['halfvec'], HalfVector.from_text)
2525

2626
if 'sparsevec' in type_info:
2727
conn.register_out_adapter(SparseVector, SparseVector._to_db)
28-
conn.register_in_adapter(type_info['sparsevec'], SparseVector._from_db)
28+
conn.register_in_adapter(type_info['sparsevec'], SparseVector.from_text)

pgvector/psycopg/bit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ class BitDumper(Dumper):
1212
format = Format.TEXT
1313

1414
def dump(self, obj: Bit) -> Buffer | None:
15-
return Bit._to_db(obj).encode('utf8')
15+
return obj.to_text().encode('utf8')
1616

1717

1818
class BitBinaryDumper(BitDumper):
1919
format = Format.BINARY
2020

2121
def dump(self, obj: Bit) -> Buffer | None:
22-
return Bit._to_db_binary(obj)
22+
return obj.to_binary()
2323

2424

2525
def register_bit_info(context: BaseConnection[Any], info: TypeInfo | None) -> None:

pgvector/psycopg/halfvec.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@ class HalfVectorDumper(Dumper):
1212
format = Format.TEXT
1313

1414
def dump(self, obj: HalfVector) -> Buffer | None:
15-
value = HalfVector._to_db(obj)
16-
return value if value is None else value.encode('utf8')
15+
return obj.to_text().encode('utf8')
1716

1817

1918
class HalfVectorBinaryDumper(HalfVectorDumper):
2019
format = Format.BINARY
2120

2221
def dump(self, obj: HalfVector) -> Buffer | None:
23-
return HalfVector._to_db_binary(obj)
22+
return obj.to_binary()
2423

2524

2625
class HalfVectorLoader(Loader):
@@ -29,7 +28,7 @@ class HalfVectorLoader(Loader):
2928
def load(self, data: Buffer) -> HalfVector | None:
3029
if isinstance(data, memoryview):
3130
data = bytes(data)
32-
return HalfVector._from_db(data.decode('utf8'))
31+
return HalfVector.from_text(data.decode('utf8'))
3332

3433

3534
class HalfVectorBinaryLoader(HalfVectorLoader):
@@ -38,7 +37,7 @@ class HalfVectorBinaryLoader(HalfVectorLoader):
3837
def load(self, data: Buffer) -> HalfVector | None:
3938
if isinstance(data, (bytearray, memoryview)):
4039
data = bytes(data)
41-
return HalfVector._from_db_binary(data)
40+
return HalfVector.from_binary(data)
4241

4342

4443
def register_halfvec_info(context: BaseConnection[Any], info: TypeInfo) -> None:

pgvector/psycopg/sparsevec.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,14 @@ class SparseVectorDumper(Dumper):
1212
format = Format.TEXT
1313

1414
def dump(self, obj: SparseVector) -> Buffer | None:
15-
value = SparseVector._to_db(obj)
16-
return value if value is None else value.encode('utf8')
15+
return obj.to_text().encode('utf8')
1716

1817

1918
class SparseVectorBinaryDumper(SparseVectorDumper):
2019
format = Format.BINARY
2120

2221
def dump(self, obj: SparseVector) -> Buffer | None:
23-
return SparseVector._to_db_binary(obj)
22+
return obj.to_binary()
2423

2524

2625
class SparseVectorLoader(Loader):
@@ -29,7 +28,7 @@ class SparseVectorLoader(Loader):
2928
def load(self, data: Buffer) -> SparseVector | None:
3029
if isinstance(data, memoryview):
3130
data = bytes(data)
32-
return SparseVector._from_db(data.decode('utf8'))
31+
return SparseVector.from_text(data.decode('utf8'))
3332

3433

3534
class SparseVectorBinaryLoader(SparseVectorLoader):
@@ -38,7 +37,7 @@ class SparseVectorBinaryLoader(SparseVectorLoader):
3837
def load(self, data: Buffer) -> SparseVector | None:
3938
if isinstance(data, (bytearray, memoryview)):
4039
data = bytes(data)
41-
return SparseVector._from_db_binary(data)
40+
return SparseVector.from_binary(data)
4241

4342

4443
def register_sparsevec_info(context: BaseConnection[Any], info: TypeInfo) -> None:

pgvector/psycopg/vector.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@ class VectorDumper(Dumper):
1616
format = Format.TEXT
1717

1818
def dump(self, obj: Vector | np.ndarray) -> Buffer | None:
19-
value = Vector._to_db(obj)
20-
return value if value is None else value.encode('utf8')
19+
if not isinstance(obj, Vector):
20+
obj = Vector(obj)
21+
return obj.to_text().encode('utf8')
2122

2223

2324
class VectorBinaryDumper(VectorDumper):
2425
format = Format.BINARY
2526

2627
def dump(self, obj: Vector | np.ndarray) -> Buffer | None:
27-
return Vector._to_db_binary(obj)
28+
if not isinstance(obj, Vector):
29+
obj = Vector(obj)
30+
return obj.to_binary()
2831

2932

3033
class VectorLoader(Loader):
@@ -33,7 +36,7 @@ class VectorLoader(Loader):
3336
def load(self, data: Buffer) -> Vector | None:
3437
if isinstance(data, memoryview):
3538
data = bytes(data)
36-
return Vector._from_db(data.decode('utf8'))
39+
return Vector.from_text(data.decode('utf8'))
3740

3841

3942
class VectorBinaryLoader(VectorLoader):
@@ -42,7 +45,7 @@ class VectorBinaryLoader(VectorLoader):
4245
def load(self, data: Buffer) -> Vector | None:
4346
if isinstance(data, (bytearray, memoryview)):
4447
data = bytes(data)
45-
return Vector._from_db_binary(data)
48+
return Vector.from_binary(data)
4649

4750

4851
def register_vector_info(context: BaseConnection[Any], info: TypeInfo | None) -> None:

pgvector/sparsevec.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,3 @@ def _from_db(cls, value: str | SparseVector | None) -> SparseVector | None:
174174
return value
175175

176176
return cls.from_text(value)
177-
178-
@classmethod
179-
def _from_db_binary(cls, value: bytes | SparseVector | None) -> SparseVector | None:
180-
if value is None or isinstance(value, SparseVector):
181-
return value
182-
183-
return cls.from_binary(value)

0 commit comments

Comments
 (0)