1
1
import datetime
2
2
import decimal
3
- import os
4
3
from typing import Tuple , Union , List
5
4
from unittest import skipIf
6
5
19
18
from sqlalchemy .engine .reflection import Inspector
20
19
from sqlalchemy .orm import DeclarativeBase , Mapped , Session , mapped_column
21
20
from sqlalchemy .schema import DropColumnComment , SetColumnComment
22
- from sqlalchemy .types import BOOLEAN , DECIMAL , Date , DateTime , Integer , String
21
+ from sqlalchemy .types import BOOLEAN , DECIMAL , Date , Integer , String
23
22
24
23
try :
25
24
from sqlalchemy .orm import declarative_base
@@ -49,12 +48,12 @@ def version_agnostic_select(object_to_select, *args, **kwargs):
49
48
return select (object_to_select , * args , ** kwargs )
50
49
51
50
52
- def version_agnostic_connect_arguments (catalog = None , schema = None ) -> Tuple [str , dict ]:
53
- HOST = os . environ . get ( "host" )
54
- HTTP_PATH = os . environ . get ( "http_path" )
55
- ACCESS_TOKEN = os . environ . get ( "access_token" )
56
- CATALOG = catalog or os . environ . get ( "catalog" )
57
- SCHEMA = schema or os . environ . get ( "schema" )
51
+ def version_agnostic_connect_arguments (connection_details ) -> Tuple [str , dict ]:
52
+ HOST = connection_details [ "host" ]
53
+ HTTP_PATH = connection_details [ "http_path" ]
54
+ ACCESS_TOKEN = connection_details [ "access_token" ]
55
+ CATALOG = connection_details [ "catalog" ]
56
+ SCHEMA = connection_details [ "schema" ]
58
57
59
58
ua_connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
60
59
@@ -77,8 +76,8 @@ def version_agnostic_connect_arguments(catalog=None, schema=None) -> Tuple[str,
77
76
78
77
79
78
@pytest .fixture
80
- def db_engine () -> Engine :
81
- conn_string , connect_args = version_agnostic_connect_arguments ()
79
+ def db_engine (connection_details ) -> Engine :
80
+ conn_string , connect_args = version_agnostic_connect_arguments (connection_details )
82
81
return create_engine (conn_string , connect_args = connect_args )
83
82
84
83
@@ -92,10 +91,11 @@ def run_query(db_engine: Engine, query: Union[str, Text]):
92
91
93
92
94
93
@pytest .fixture
95
- def samples_engine () -> Engine :
96
- conn_string , connect_args = version_agnostic_connect_arguments (
97
- catalog = "samples" , schema = "nyctaxi"
98
- )
94
+ def samples_engine (connection_details ) -> Engine :
95
+ details = connection_details .copy ()
96
+ details ["catalog" ] = "samples"
97
+ details ["schema" ] = "nyctaxi"
98
+ conn_string , connect_args = version_agnostic_connect_arguments (details )
99
99
return create_engine (conn_string , connect_args = connect_args )
100
100
101
101
@@ -141,7 +141,7 @@ def test_connect_args(db_engine):
141
141
def test_pandas_upload (db_engine , metadata_obj ):
142
142
import pandas as pd
143
143
144
- SCHEMA = os . environ . get ( "schema" )
144
+ SCHEMA = "default"
145
145
try :
146
146
df = pd .read_excel (
147
147
"src/databricks/sqlalchemy/test_local/e2e/demo_data/MOCK_DATA.xlsx"
@@ -409,7 +409,9 @@ def test_get_table_names_smoke_test(samples_engine: Engine):
409
409
_names is not None , "get_table_names did not succeed"
410
410
411
411
412
- def test_has_table_across_schemas (db_engine : Engine , samples_engine : Engine ):
412
+ def test_has_table_across_schemas (
413
+ db_engine : Engine , samples_engine : Engine , catalog : str , schema : str
414
+ ):
413
415
"""For this test to pass these conditions must be met:
414
416
- Table samples.nyctaxi.trips must exist
415
417
- Table samples.tpch.customer must exist
@@ -426,9 +428,6 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):
426
428
)
427
429
428
430
# 3) Check for a table within a different catalog
429
- other_catalog = os .environ .get ("catalog" )
430
- other_schema = os .environ .get ("schema" )
431
-
432
431
# Create a table in a different catalog
433
432
with db_engine .connect () as conn :
434
433
conn .execute (text ("CREATE TABLE test_has_table (numbers_are_cool INT);" ))
@@ -442,8 +441,8 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):
442
441
assert samples_engine .dialect .has_table (
443
442
connection = conn ,
444
443
table_name = "test_has_table" ,
445
- schema = other_schema ,
446
- catalog = other_catalog ,
444
+ schema = schema ,
445
+ catalog = catalog ,
447
446
)
448
447
finally :
449
448
conn .execute (text ("DROP TABLE test_has_table;" ))
@@ -503,12 +502,12 @@ def test_get_columns(db_engine, sample_table: str):
503
502
504
503
class TestCommentReflection :
505
504
@pytest .fixture (scope = "class" )
506
- def engine (self ):
507
- HOST = os . environ . get ( "host" )
508
- HTTP_PATH = os . environ . get ( "http_path" )
509
- ACCESS_TOKEN = os . environ . get ( "access_token" )
510
- CATALOG = os . environ . get ( "catalog" )
511
- SCHEMA = os . environ . get ( "schema" )
505
+ def engine (self , connection_details : dict ):
506
+ HOST = connection_details [ "host" ]
507
+ HTTP_PATH = connection_details [ "http_path" ]
508
+ ACCESS_TOKEN = connection_details [ "access_token" ]
509
+ CATALOG = connection_details [ "catalog" ]
510
+ SCHEMA = connection_details [ "schema" ]
512
511
513
512
connection_string = f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } "
514
513
connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
0 commit comments