Skip to content

Commit 1070b37

Browse files
committed
Vector: Fix type checking and compatibility with SQLAlchemy 1.x
1 parent 51e5874 commit 1070b37

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/sqlalchemy_cratedb/type/vector.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
- The type implementation might want to be accompanied by corresponding support
2626
for the `KNN_MATCH` function, similar to what the dialect already offers for
2727
fulltext search through its `Match` predicate.
28+
- After dropping support for SQLAlchemy 1.3, use
29+
`class FloatVector(sa.TypeDecorator[t.Sequence[float]]):`
2830
2931
## Origin
3032
This module is based on the corresponding pgvector implementation
@@ -44,7 +46,7 @@
4446
__all__ = ["FloatVector"]
4547

4648

47-
def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]:
49+
def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]:
4850
import numpy as np
4951

5052
# from `pgvector.utils`
@@ -77,8 +79,7 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
7779
return value
7880

7981

80-
class FloatVector(sa.TypeDecorator[t.Sequence[float]]):
81-
82+
class FloatVector(sa.TypeDecorator):
8283
"""
8384
An improved implementation of the `FloatVector` data type for CrateDB,
8485
compared to the previous implementation on behalf of the LangChain adapter.
@@ -146,14 +147,14 @@ def __init__(self, dimensions: int = None):
146147
def as_generic(self):
147148
return sa.ARRAY
148149

149-
def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
150+
def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable:
150151
def process(value: t.Iterable) -> t.Optional[t.List]:
151152
return to_db(value, self.dimensions)
152153

153154
return process
154155

155-
def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable:
156-
def process(value: t.Any) -> t.Optional[npt.ArrayLike]:
156+
def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable:
157+
def process(value: t.Any) -> t.Optional["npt.ArrayLike"]:
157158
return from_db(value)
158159

159160
return process

0 commit comments

Comments
 (0)