11import datetime
22import decimal
3- import os
43from typing import Tuple , Union , List
54from unittest import skipIf
65
1918from sqlalchemy .engine .reflection import Inspector
2019from sqlalchemy .orm import DeclarativeBase , Mapped , Session , mapped_column
2120from sqlalchemy .schema import DropColumnComment , SetColumnComment
22- from sqlalchemy .types import BOOLEAN , DECIMAL , Date , DateTime , Integer , String
21+ from sqlalchemy .types import BOOLEAN , DECIMAL , Date , Integer , String
2322
2423try :
2524 from sqlalchemy .orm import declarative_base
@@ -49,12 +48,12 @@ def version_agnostic_select(object_to_select, *args, **kwargs):
4948 return select (object_to_select , * args , ** kwargs )
5049
5150
52- def version_agnostic_connect_arguments (catalog = None , schema = None ) -> Tuple [str , dict ]:
53- HOST = os . environ . get ( "host" )
54- HTTP_PATH = os . environ . get ( "http_path" )
55- ACCESS_TOKEN = os . environ . get ( "access_token" )
56- CATALOG = catalog or os . environ . get ( "catalog" )
57- SCHEMA = schema or os . environ . get ( "schema" )
51+ def version_agnostic_connect_arguments (connection_details ) -> Tuple [str , dict ]:
52+ HOST = connection_details [ "host" ]
53+ HTTP_PATH = connection_details [ "http_path" ]
54+ ACCESS_TOKEN = connection_details [ "access_token" ]
55+ CATALOG = connection_details [ "catalog" ]
56+ SCHEMA = connection_details [ "schema" ]
5857
5958 ua_connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
6059
@@ -77,8 +76,8 @@ def version_agnostic_connect_arguments(catalog=None, schema=None) -> Tuple[str,
7776
7877
7978@pytest .fixture
80- def db_engine () -> Engine :
81- conn_string , connect_args = version_agnostic_connect_arguments ()
79+ def db_engine (connection_details ) -> Engine :
80+ conn_string , connect_args = version_agnostic_connect_arguments (connection_details )
8281 return create_engine (conn_string , connect_args = connect_args )
8382
8483
@@ -92,10 +91,11 @@ def run_query(db_engine: Engine, query: Union[str, Text]):
9291
9392
9493@pytest .fixture
95- def samples_engine () -> Engine :
96- conn_string , connect_args = version_agnostic_connect_arguments (
97- catalog = "samples" , schema = "nyctaxi"
98- )
94+ def samples_engine (connection_details ) -> Engine :
95+ details = connection_details .copy ()
96+ details ["catalog" ] = "samples"
97+ details ["schema" ] = "nyctaxi"
98+ conn_string , connect_args = version_agnostic_connect_arguments (details )
9999 return create_engine (conn_string , connect_args = connect_args )
100100
101101
@@ -141,7 +141,7 @@ def test_connect_args(db_engine):
141141def test_pandas_upload (db_engine , metadata_obj ):
142142 import pandas as pd
143143
144- SCHEMA = os . environ . get ( "schema" )
144+ SCHEMA = "default"
145145 try :
146146 df = pd .read_excel (
147147 "src/databricks/sqlalchemy/test_local/e2e/demo_data/MOCK_DATA.xlsx"
@@ -409,7 +409,9 @@ def test_get_table_names_smoke_test(samples_engine: Engine):
409409 _names is not None , "get_table_names did not succeed"
410410
411411
412- def test_has_table_across_schemas (db_engine : Engine , samples_engine : Engine ):
412+ def test_has_table_across_schemas (
413+ db_engine : Engine , samples_engine : Engine , catalog : str , schema : str
414+ ):
413415 """For this test to pass these conditions must be met:
414416 - Table samples.nyctaxi.trips must exist
415417 - Table samples.tpch.customer must exist
@@ -426,9 +428,6 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):
426428 )
427429
428430 # 3) Check for a table within a different catalog
429- other_catalog = os .environ .get ("catalog" )
430- other_schema = os .environ .get ("schema" )
431-
432431 # Create a table in a different catalog
433432 with db_engine .connect () as conn :
434433 conn .execute (text ("CREATE TABLE test_has_table (numbers_are_cool INT);" ))
@@ -442,8 +441,8 @@ def test_has_table_across_schemas(db_engine: Engine, samples_engine: Engine):
442441 assert samples_engine .dialect .has_table (
443442 connection = conn ,
444443 table_name = "test_has_table" ,
445- schema = other_schema ,
446- catalog = other_catalog ,
444+ schema = schema ,
445+ catalog = catalog ,
447446 )
448447 finally :
449448 conn .execute (text ("DROP TABLE test_has_table;" ))
@@ -503,12 +502,12 @@ def test_get_columns(db_engine, sample_table: str):
503502
504503class TestCommentReflection :
505504 @pytest .fixture (scope = "class" )
506- def engine (self ):
507- HOST = os . environ . get ( "host" )
508- HTTP_PATH = os . environ . get ( "http_path" )
509- ACCESS_TOKEN = os . environ . get ( "access_token" )
510- CATALOG = os . environ . get ( "catalog" )
511- SCHEMA = os . environ . get ( "schema" )
505+ def engine (self , connection_details : dict ):
506+ HOST = connection_details [ "host" ]
507+ HTTP_PATH = connection_details [ "http_path" ]
508+ ACCESS_TOKEN = connection_details [ "access_token" ]
509+ CATALOG = connection_details [ "catalog" ]
510+ SCHEMA = connection_details [ "schema" ]
512511
513512 connection_string = f"databricks://token:{ ACCESS_TOKEN } @{ HOST } ?http_path={ HTTP_PATH } &catalog={ CATALOG } &schema={ SCHEMA } "
514513 connect_args = {"_user_agent_entry" : USER_AGENT_TOKEN }
0 commit comments