Skip to content

Commit b3476c2

Browse files
authored
Merge pull request #60 from ydb-platform/compatibility_sa_1_4
Attempt to support sqlalchemy 1.4+
2 parents 3c83092 + 3962c00 commit b3476c2

File tree

11 files changed

+750
-586
lines changed

11 files changed

+750
-586
lines changed

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
sqlalchemy >= 2.0.7, < 3.0.0
1+
sqlalchemy >= 1.4.0, < 3.0.0
22
ydb >= 3.18.8
3-
ydb-dbapi >= 0.1.1
3+
ydb-dbapi >= 0.1.2

test/test_core.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
from ydb_sqlalchemy import sqlalchemy as ydb_sa
1515
from ydb_sqlalchemy.sqlalchemy import types
1616

17+
if sa.__version__ >= "2.":
18+
from sqlalchemy import NullPool
19+
from sqlalchemy import QueuePool
20+
else:
21+
from sqlalchemy.pool import NullPool
22+
from sqlalchemy.pool import QueuePool
23+
1724

1825
def clear_sql(stm):
1926
return stm.replace("\n", " ").replace(" ", " ").strip()
@@ -94,7 +101,7 @@ def test_sa_crud(self, connection):
94101
(5, "c"),
95102
]
96103

97-
def test_cached_query(self, connection_no_trans: sa.Connection, connection: sa.Connection):
104+
def test_cached_query(self, connection_no_trans, connection):
98105
table = self.tables.test
99106

100107
with connection_no_trans.begin() as transaction:
@@ -249,7 +256,7 @@ def test_primitive_types(self, connection):
249256
assert row == (42, "Hello World!", 3.5, True)
250257

251258
def test_integer_types(self, connection):
252-
stmt = sa.Select(
259+
stmt = sa.select(
253260
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))),
254261
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))),
255262
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))),
@@ -263,8 +270,8 @@ def test_integer_types(self, connection):
263270
result = connection.execute(stmt).fetchone()
264271
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")
265272

266-
def test_datetime_types(self, connection: sa.Connection):
267-
stmt = sa.Select(
273+
def test_datetime_types(self, connection):
274+
stmt = sa.select(
268275
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))),
269276
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))),
270277
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_TIMESTAMP", datetime.datetime.now(), sa.TIMESTAMP))),
@@ -273,7 +280,7 @@ def test_datetime_types(self, connection: sa.Connection):
273280
result = connection.execute(stmt).fetchone()
274281
assert result == (b"Timestamp", b"Datetime", b"Timestamp")
275282

276-
def test_datetime_types_timezone(self, connection: sa.Connection):
283+
def test_datetime_types_timezone(self, connection):
277284
table = self.tables.test_datetime_types
278285
tzinfo = datetime.timezone(datetime.timedelta(hours=3, minutes=42))
279286

@@ -476,7 +483,8 @@ def define_tables(cls, metadata: sa.MetaData):
476483
Column("id", Integer, primary_key=True),
477484
)
478485

479-
def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection):
486+
@pytest.mark.skipif(sa.__version__ < "2.", reason="Something was different in SA<2, good to fix")
487+
def test_rollback(self, connection_no_trans, connection):
480488
table = self.tables.test
481489

482490
connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
@@ -491,7 +499,7 @@ def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Conne
491499
result = cursor.fetchall()
492500
assert result == []
493501

494-
def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection):
502+
def test_commit(self, connection_no_trans, connection):
495503
table = self.tables.test
496504

497505
connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
@@ -506,9 +514,7 @@ def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connect
506514
assert set(result) == {(3,), (4,)}
507515

508516
@pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY))
509-
def test_interactive_transaction(
510-
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
511-
):
517+
def test_interactive_transaction(self, connection_no_trans, connection, isolation_level):
512518
table = self.tables.test
513519
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection
514520

@@ -535,9 +541,7 @@ def test_interactive_transaction(
535541
IsolationLevel.AUTOCOMMIT,
536542
),
537543
)
538-
def test_not_interactive_transaction(
539-
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
540-
):
544+
def test_not_interactive_transaction(self, connection_no_trans, connection, isolation_level):
541545
table = self.tables.test
542546
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection
543547

@@ -573,7 +577,7 @@ class IsolationSettings(NamedTuple):
573577
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True),
574578
}
575579

576-
def test_connection_set(self, connection_no_trans: sa.Connection):
580+
def test_connection_set(self, connection_no_trans):
577581
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection
578582

579583
for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items():
@@ -614,8 +618,8 @@ def ydb_pool(self, ydb_driver):
614618
session_pool.stop()
615619

616620
def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
617-
engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
618-
engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
621+
engine1 = sa.create_engine(config.db_url, poolclass=QueuePool, connect_args={"ydb_session_pool": ydb_pool})
622+
engine2 = sa.create_engine(config.db_url, poolclass=QueuePool, connect_args={"ydb_session_pool": ydb_pool})
619623

620624
with engine1.connect() as conn1, engine2.connect() as conn2:
621625
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
@@ -629,8 +633,8 @@ def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
629633
assert not ydb_driver._stopped
630634

631635
def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
632-
engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
633-
engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
636+
engine1 = sa.create_engine(config.db_url, poolclass=NullPool, connect_args={"ydb_session_pool": ydb_pool})
637+
engine2 = sa.create_engine(config.db_url, poolclass=NullPool, connect_args={"ydb_session_pool": ydb_pool})
634638

635639
with engine1.connect() as conn1, engine2.connect() as conn2:
636640
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
@@ -861,7 +865,7 @@ def test_insert_in_name_and_field(self, connection):
861865
class TestSecondaryIndex(TestBase):
862866
__backend__ = True
863867

864-
def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
868+
def test_column_indexes(self, connection, metadata: sa.MetaData):
865869
table = Table(
866870
"test_column_indexes/table",
867871
metadata,
@@ -884,7 +888,7 @@ def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
884888
index1 = indexes_map["ix_test_column_indexes_table_index_col2"]
885889
assert index1.index_columns == ["index_col2"]
886890

887-
def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
891+
def test_async_index(self, connection, metadata: sa.MetaData):
888892
table = Table(
889893
"test_async_index/table",
890894
metadata,
@@ -903,7 +907,7 @@ def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
903907
assert set(index.index_columns) == {"index_col1", "index_col2"}
904908
# TODO: Check type after https://github.com/ydb-platform/ydb-python-sdk/issues/351
905909

906-
def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
910+
def test_cover_index(self, connection, metadata: sa.MetaData):
907911
table = Table(
908912
"test_cover_index/table",
909913
metadata,
@@ -922,7 +926,7 @@ def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
922926
assert set(index.index_columns) == {"index_col1"}
923927
# TODO: Check covered columns after https://github.com/ydb-platform/ydb-python-sdk/issues/409
924928

925-
def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaData):
929+
def test_indexes_reflection(self, connection, metadata: sa.MetaData):
926930
table = Table(
927931
"test_indexes_reflection/table",
928932
metadata,
@@ -948,7 +952,7 @@ def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaDa
948952
"test_async_cover_index": {"index_col1"},
949953
}
950954

951-
def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaData):
955+
def test_index_simple_usage(self, connection, metadata: sa.MetaData):
952956
persons = Table(
953957
"test_index_simple_usage/persons",
954958
metadata,
@@ -979,7 +983,7 @@ def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaDa
979983
cursor = connection.execute(select_stmt)
980984
assert cursor.scalar_one() == "Sarah Connor"
981985

982-
def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.MetaData):
986+
def test_index_with_join_usage(self, connection, metadata: sa.MetaData):
983987
persons = Table(
984988
"test_index_with_join_usage/persons",
985989
metadata,
@@ -1033,7 +1037,7 @@ def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.Met
10331037
cursor = connection.execute(select_stmt)
10341038
assert cursor.one() == ("Sarah Connor", "wanted")
10351039

1036-
def test_index_deletion(self, connection: sa.Connection, metadata: sa.MetaData):
1040+
def test_index_deletion(self, connection, metadata: sa.MetaData):
10371041
persons = Table(
10381042
"test_index_deletion/persons",
10391043
metadata,
@@ -1062,7 +1066,7 @@ def define_tables(cls, metadata: sa.MetaData):
10621066
Table("table", metadata, sa.Column("id", sa.Integer, primary_key=True))
10631067

10641068
@classmethod
1065-
def insert_data(cls, connection: sa.Connection):
1069+
def insert_data(cls, connection):
10661070
table = cls.tables["some_dir/nested_dir/table"]
10671071
root_table = cls.tables["table"]
10681072

test/test_suite.py

Lines changed: 35 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest
6969
from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest
7070
from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest
71-
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest
71+
7272
from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest
7373
from sqlalchemy.testing.suite.test_types import StringTest as _StringTest
7474
from sqlalchemy.testing.suite.test_types import (
@@ -78,14 +78,16 @@
7878
TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
7979
)
8080
from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest
81-
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest
8281

8382
from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types
8483

8584
test_types_suite = sqlalchemy.testing.suite.test_types
8685
col_creator = test_types_suite.Column
8786

8887

88+
OLD_SA = sa.__version__ < "2."
89+
90+
8991
def column_getter(*args, **kwargs):
9092
col = col_creator(*args, **kwargs)
9193
if col.name == "x":
@@ -275,30 +277,35 @@ class BinaryTest(_BinaryTest):
275277
pass
276278

277279

278-
class TrueDivTest(_TrueDivTest):
279-
@pytest.mark.skip("Unsupported builtin: FLOOR")
280-
def test_floordiv_numeric(self, connection, left, right, expected):
281-
pass
280+
if not OLD_SA:
281+
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest
282282

283-
@pytest.mark.skip("Truediv unsupported for int")
284-
def test_truediv_integer(self, connection, left, right, expected):
285-
pass
283+
class TrueDivTest(_TrueDivTest):
284+
@pytest.mark.skip("Unsupported builtin: FLOOR")
285+
def test_floordiv_numeric(self, connection, left, right, expected):
286+
pass
286287

287-
@pytest.mark.skip("Truediv unsupported for int")
288-
def test_truediv_integer_bound(self, connection):
289-
pass
288+
@pytest.mark.skip("Truediv unsupported for int")
289+
def test_truediv_integer(self, connection, left, right, expected):
290+
pass
290291

291-
@pytest.mark.skip("Numeric is not Decimal")
292-
def test_truediv_numeric(self):
293-
# SqlAlchemy maybe eat Decimal and throw Double
294-
pass
292+
@pytest.mark.skip("Truediv unsupported for int")
293+
def test_truediv_integer_bound(self, connection):
294+
pass
295295

296-
@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
297-
def test_truediv_float(self, connection, left, right, expected):
298-
eq_(
299-
connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))),
300-
expected,
301-
)
296+
@pytest.mark.skip("Numeric is not Decimal")
297+
def test_truediv_numeric(self):
298+
# SqlAlchemy maybe eat Decimal and throw Double
299+
pass
300+
301+
@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
302+
def test_truediv_float(self, connection, left, right, expected):
303+
eq_(
304+
connection.scalar(
305+
select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))
306+
),
307+
expected,
308+
)
302309

303310

304311
class ExistsTest(_ExistsTest):
@@ -539,9 +546,12 @@ def test_from_as_table(self, connection):
539546
eq_(connection.execute(sa.select(table)).fetchall(), [(1,), (2,), (3,)])
540547

541548

542-
@pytest.mark.skip("uuid unsupported for columns")
543-
class NativeUUIDTest(_NativeUUIDTest):
544-
pass
549+
if not OLD_SA:
550+
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest
551+
552+
@pytest.mark.skip("uuid unsupported for columns")
553+
class NativeUUIDTest(_NativeUUIDTest):
554+
pass
545555

546556

547557
@pytest.mark.skip("unsupported Time data type")

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,5 @@ max-line-length = 120
6868
ignore=E203,W503
6969
per-file-ignores =
7070
ydb_sqlalchemy/__init__.py: F401
71+
ydb_sqlalchemy/sqlalchemy/compiler/__init__.py: F401
7172
exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,docs/*

0 commit comments

Comments
 (0)