14
14
from ydb_sqlalchemy import sqlalchemy as ydb_sa
15
15
from ydb_sqlalchemy .sqlalchemy import types
16
16
17
+ if sa .__version__ >= "2." :
18
+ from sqlalchemy import NullPool
19
+ from sqlalchemy import QueuePool
20
+ else :
21
+ from sqlalchemy .pool import NullPool
22
+ from sqlalchemy .pool import QueuePool
23
+
17
24
18
25
def clear_sql (stm ):
19
26
return stm .replace ("\n " , " " ).replace (" " , " " ).strip ()
@@ -94,7 +101,7 @@ def test_sa_crud(self, connection):
94
101
(5 , "c" ),
95
102
]
96
103
97
- def test_cached_query (self , connection_no_trans : sa . Connection , connection : sa . Connection ):
104
+ def test_cached_query (self , connection_no_trans , connection ):
98
105
table = self .tables .test
99
106
100
107
with connection_no_trans .begin () as transaction :
@@ -249,7 +256,7 @@ def test_primitive_types(self, connection):
249
256
assert row == (42 , "Hello World!" , 3.5 , True )
250
257
251
258
def test_integer_types (self , connection ):
252
- stmt = sa .Select (
259
+ stmt = sa .select (
253
260
sa .func .FormatType (sa .func .TypeOf (sa .bindparam ("p_uint8" , 8 , types .UInt8 ))),
254
261
sa .func .FormatType (sa .func .TypeOf (sa .bindparam ("p_uint16" , 16 , types .UInt16 ))),
255
262
sa .func .FormatType (sa .func .TypeOf (sa .bindparam ("p_uint32" , 32 , types .UInt32 ))),
@@ -263,8 +270,8 @@ def test_integer_types(self, connection):
263
270
result = connection .execute (stmt ).fetchone ()
264
271
assert result == (b"Uint8" , b"Uint16" , b"Uint32" , b"Uint64" , b"Int8" , b"Int16" , b"Int32" , b"Int64" )
265
272
266
- def test_datetime_types (self , connection : sa . Connection ):
267
- stmt = sa .Select (
273
+ def test_datetime_types (self , connection ):
274
+ stmt = sa .select (
268
275
sa .func .FormatType (sa .func .TypeOf (sa .bindparam ("p_datetime" , datetime .datetime .now (), sa .DateTime ))),
269
276
sa .func .FormatType (sa .func .TypeOf (sa .bindparam ("p_DATETIME" , datetime .datetime .now (), sa .DATETIME ))),
270
277
sa .func .FormatType (sa .func .TypeOf (sa .bindparam ("p_TIMESTAMP" , datetime .datetime .now (), sa .TIMESTAMP ))),
@@ -273,7 +280,7 @@ def test_datetime_types(self, connection: sa.Connection):
273
280
result = connection .execute (stmt ).fetchone ()
274
281
assert result == (b"Timestamp" , b"Datetime" , b"Timestamp" )
275
282
276
- def test_datetime_types_timezone (self , connection : sa . Connection ):
283
+ def test_datetime_types_timezone (self , connection ):
277
284
table = self .tables .test_datetime_types
278
285
tzinfo = datetime .timezone (datetime .timedelta (hours = 3 , minutes = 42 ))
279
286
@@ -476,7 +483,8 @@ def define_tables(cls, metadata: sa.MetaData):
476
483
Column ("id" , Integer , primary_key = True ),
477
484
)
478
485
479
- def test_rollback (self , connection_no_trans : sa .Connection , connection : sa .Connection ):
486
+ @pytest .mark .skipif (sa .__version__ < "2." , reason = "Something was different in SA<2, good to fix" )
487
+ def test_rollback (self , connection_no_trans , connection ):
480
488
table = self .tables .test
481
489
482
490
connection_no_trans .execution_options (isolation_level = IsolationLevel .SERIALIZABLE )
@@ -491,7 +499,7 @@ def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Conne
491
499
result = cursor .fetchall ()
492
500
assert result == []
493
501
494
- def test_commit (self , connection_no_trans : sa . Connection , connection : sa . Connection ):
502
+ def test_commit (self , connection_no_trans , connection ):
495
503
table = self .tables .test
496
504
497
505
connection_no_trans .execution_options (isolation_level = IsolationLevel .SERIALIZABLE )
@@ -506,9 +514,7 @@ def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connect
506
514
assert set (result ) == {(3 ,), (4 ,)}
507
515
508
516
@pytest .mark .parametrize ("isolation_level" , (IsolationLevel .SERIALIZABLE , IsolationLevel .SNAPSHOT_READONLY ))
509
- def test_interactive_transaction (
510
- self , connection_no_trans : sa .Connection , connection : sa .Connection , isolation_level
511
- ):
517
+ def test_interactive_transaction (self , connection_no_trans , connection , isolation_level ):
512
518
table = self .tables .test
513
519
dbapi_connection : dbapi .Connection = connection_no_trans .connection .dbapi_connection
514
520
@@ -535,9 +541,7 @@ def test_interactive_transaction(
535
541
IsolationLevel .AUTOCOMMIT ,
536
542
),
537
543
)
538
- def test_not_interactive_transaction (
539
- self , connection_no_trans : sa .Connection , connection : sa .Connection , isolation_level
540
- ):
544
+ def test_not_interactive_transaction (self , connection_no_trans , connection , isolation_level ):
541
545
table = self .tables .test
542
546
dbapi_connection : dbapi .Connection = connection_no_trans .connection .dbapi_connection
543
547
@@ -573,7 +577,7 @@ class IsolationSettings(NamedTuple):
573
577
IsolationLevel .SNAPSHOT_READONLY : IsolationSettings (ydb .QuerySnapshotReadOnly ().name , True ),
574
578
}
575
579
576
- def test_connection_set (self , connection_no_trans : sa . Connection ):
580
+ def test_connection_set (self , connection_no_trans ):
577
581
dbapi_connection : dbapi .Connection = connection_no_trans .connection .dbapi_connection
578
582
579
583
for sa_isolation_level , ydb_isolation_settings in self .YDB_ISOLATION_SETTINGS_MAP .items ():
@@ -614,8 +618,8 @@ def ydb_pool(self, ydb_driver):
614
618
session_pool .stop ()
615
619
616
620
def test_sa_queue_pool_with_ydb_shared_session_pool (self , ydb_driver , ydb_pool ):
617
- engine1 = sa .create_engine (config .db_url , poolclass = sa . QueuePool , connect_args = {"ydb_session_pool" : ydb_pool })
618
- engine2 = sa .create_engine (config .db_url , poolclass = sa . QueuePool , connect_args = {"ydb_session_pool" : ydb_pool })
621
+ engine1 = sa .create_engine (config .db_url , poolclass = QueuePool , connect_args = {"ydb_session_pool" : ydb_pool })
622
+ engine2 = sa .create_engine (config .db_url , poolclass = QueuePool , connect_args = {"ydb_session_pool" : ydb_pool })
619
623
620
624
with engine1 .connect () as conn1 , engine2 .connect () as conn2 :
621
625
dbapi_conn1 : dbapi .Connection = conn1 .connection .dbapi_connection
@@ -629,8 +633,8 @@ def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
629
633
assert not ydb_driver ._stopped
630
634
631
635
def test_sa_null_pool_with_ydb_shared_session_pool (self , ydb_driver , ydb_pool ):
632
- engine1 = sa .create_engine (config .db_url , poolclass = sa . NullPool , connect_args = {"ydb_session_pool" : ydb_pool })
633
- engine2 = sa .create_engine (config .db_url , poolclass = sa . NullPool , connect_args = {"ydb_session_pool" : ydb_pool })
636
+ engine1 = sa .create_engine (config .db_url , poolclass = NullPool , connect_args = {"ydb_session_pool" : ydb_pool })
637
+ engine2 = sa .create_engine (config .db_url , poolclass = NullPool , connect_args = {"ydb_session_pool" : ydb_pool })
634
638
635
639
with engine1 .connect () as conn1 , engine2 .connect () as conn2 :
636
640
dbapi_conn1 : dbapi .Connection = conn1 .connection .dbapi_connection
@@ -861,7 +865,7 @@ def test_insert_in_name_and_field(self, connection):
861
865
class TestSecondaryIndex (TestBase ):
862
866
__backend__ = True
863
867
864
- def test_column_indexes (self , connection : sa . Connection , metadata : sa .MetaData ):
868
+ def test_column_indexes (self , connection , metadata : sa .MetaData ):
865
869
table = Table (
866
870
"test_column_indexes/table" ,
867
871
metadata ,
@@ -884,7 +888,7 @@ def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
884
888
index1 = indexes_map ["ix_test_column_indexes_table_index_col2" ]
885
889
assert index1 .index_columns == ["index_col2" ]
886
890
887
- def test_async_index (self , connection : sa . Connection , metadata : sa .MetaData ):
891
+ def test_async_index (self , connection , metadata : sa .MetaData ):
888
892
table = Table (
889
893
"test_async_index/table" ,
890
894
metadata ,
@@ -903,7 +907,7 @@ def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
903
907
assert set (index .index_columns ) == {"index_col1" , "index_col2" }
904
908
# TODO: Check type after https://github.com/ydb-platform/ydb-python-sdk/issues/351
905
909
906
- def test_cover_index (self , connection : sa . Connection , metadata : sa .MetaData ):
910
+ def test_cover_index (self , connection , metadata : sa .MetaData ):
907
911
table = Table (
908
912
"test_cover_index/table" ,
909
913
metadata ,
@@ -922,7 +926,7 @@ def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
922
926
assert set (index .index_columns ) == {"index_col1" }
923
927
# TODO: Check covered columns after https://github.com/ydb-platform/ydb-python-sdk/issues/409
924
928
925
- def test_indexes_reflection (self , connection : sa . Connection , metadata : sa .MetaData ):
929
+ def test_indexes_reflection (self , connection , metadata : sa .MetaData ):
926
930
table = Table (
927
931
"test_indexes_reflection/table" ,
928
932
metadata ,
@@ -948,7 +952,7 @@ def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaDa
948
952
"test_async_cover_index" : {"index_col1" },
949
953
}
950
954
951
- def test_index_simple_usage (self , connection : sa . Connection , metadata : sa .MetaData ):
955
+ def test_index_simple_usage (self , connection , metadata : sa .MetaData ):
952
956
persons = Table (
953
957
"test_index_simple_usage/persons" ,
954
958
metadata ,
@@ -979,7 +983,7 @@ def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaDa
979
983
cursor = connection .execute (select_stmt )
980
984
assert cursor .scalar_one () == "Sarah Connor"
981
985
982
- def test_index_with_join_usage (self , connection : sa . Connection , metadata : sa .MetaData ):
986
+ def test_index_with_join_usage (self , connection , metadata : sa .MetaData ):
983
987
persons = Table (
984
988
"test_index_with_join_usage/persons" ,
985
989
metadata ,
@@ -1033,7 +1037,7 @@ def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.Met
1033
1037
cursor = connection .execute (select_stmt )
1034
1038
assert cursor .one () == ("Sarah Connor" , "wanted" )
1035
1039
1036
- def test_index_deletion (self , connection : sa . Connection , metadata : sa .MetaData ):
1040
+ def test_index_deletion (self , connection , metadata : sa .MetaData ):
1037
1041
persons = Table (
1038
1042
"test_index_deletion/persons" ,
1039
1043
metadata ,
@@ -1062,7 +1066,7 @@ def define_tables(cls, metadata: sa.MetaData):
1062
1066
Table ("table" , metadata , sa .Column ("id" , sa .Integer , primary_key = True ))
1063
1067
1064
1068
@classmethod
1065
- def insert_data (cls , connection : sa . Connection ):
1069
+ def insert_data (cls , connection ):
1066
1070
table = cls .tables ["some_dir/nested_dir/table" ]
1067
1071
root_table = cls .tables ["table" ]
1068
1072
0 commit comments