Skip to content

Commit 37c89b8

Browse files
committed
Check
Working e2e tests prototype
1 parent b032c3f commit 37c89b8

File tree

4 files changed

+42
-15
lines changed

4 files changed

+42
-15
lines changed

src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ def execute(
789789
790790
:returns self
791791
"""
792-
792+
793793
param_approach = self._determine_parameter_approach(parameters)
794794
if param_approach == ParameterApproach.NONE:
795795
prepared_params = NO_NATIVE_PARAMS
@@ -808,7 +808,7 @@ def execute(
808808
prepared_operation, prepared_params = self._prepare_native_parameters(
809809
transformed_operation, normalized_parameters, param_structure
810810
)
811-
811+
812812
self._check_not_closed()
813813
self._close_and_clear_active_result_set()
814814
execute_response = self.thrift_backend.execute_command(

src/databricks/sql/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def user_friendly_error_message(self, no_retry_reason, attempt, elapsed):
429429
# Taken from PyHive
430430
class ParamEscaper:
431431
_DATE_FORMAT = "%Y-%m-%d"
432-
_TIME_FORMAT = "%H:%M:%S.%f"
432+
_TIME_FORMAT = "%H:%M:%S.%f %z"
433433
_DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT)
434434

435435
def escape_args(self, parameters):
@@ -459,13 +459,15 @@ def escape_string(self, item):
459459

460460
def escape_sequence(self, item):
461461
l = map(self.escape_item, item)
462+
l = list(map(str, l))
462463
return "ARRAY(" + ",".join(l) + ")"
463464

464465
def escape_mapping(self, item):
465466
l = map(
466467
self.escape_item,
467468
(element for key, value in item.items() for element in (key, value)),
468469
)
470+
l = list(map(str, l))
469471
return "MAP(" + ",".join(l) + ")"
470472

471473
def escape_datetime(self, item, format, cutoff=0):

tests/e2e/test_complex_types.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from numpy import ndarray
3+
from typing import Sequence
34

45
from tests.e2e.test_driver import PySQLPytestTestCase
56

@@ -17,7 +18,10 @@ def table_fixture(self, connection_details):
1718
CREATE TABLE IF NOT EXISTS pysql_test_complex_types_table (
1819
array_col ARRAY<STRING>,
1920
map_col MAP<STRING, INTEGER>,
20-
struct_col STRUCT<field1: STRING, field2: INTEGER>
21+
struct_col STRUCT<field1: STRING, field2: INTEGER>,
22+
array_array_col ARRAY<ARRAY<STRING>>,
23+
array_map_col ARRAY<MAP<STRING, INTEGER>>,
24+
map_array_col MAP<STRING, ARRAY<STRING>>
2125
)
2226
"""
2327
)
@@ -28,7 +32,10 @@ def table_fixture(self, connection_details):
2832
VALUES (
2933
ARRAY('a', 'b', 'c'),
3034
MAP('a', 1, 'b', 2, 'c', 3),
31-
NAMED_STRUCT('field1', 'a', 'field2', 1)
35+
NAMED_STRUCT('field1', 'a', 'field2', 1),
36+
ARRAY(ARRAY('a','b','c')),
37+
ARRAY(MAP('a', 1, 'b', 2, 'c', 3)),
38+
MAP('a', ARRAY('a', 'b', 'c'), 'b', ARRAY('d', 'e'))
3239
)
3340
"""
3441
)
@@ -38,7 +45,7 @@ def table_fixture(self, connection_details):
3845

3946
@pytest.mark.parametrize(
4047
"field,expected_type",
41-
[("array_col", ndarray), ("map_col", list), ("struct_col", dict)],
48+
[("array_col", ndarray), ("map_col", list), ("struct_col", dict), ("array_array_col", ndarray), ("array_map_col", ndarray), ("map_array_col", list)],
4249
)
4350
def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
4451
"""Confirms the return types of a complex type field when reading as arrow"""
@@ -47,10 +54,10 @@ def test_read_complex_types_as_arrow(self, field, expected_type, table_fixture):
4754
result = cursor.execute(
4855
"SELECT * FROM pysql_test_complex_types_table LIMIT 1"
4956
).fetchone()
50-
57+
5158
assert isinstance(result[field], expected_type)
5259

53-
@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col")])
60+
@pytest.mark.parametrize("field", [("array_col"), ("map_col"), ("struct_col"), ("array_array_col"), ("array_map_col"), ("map_array_col")])
5461
def test_read_complex_types_as_string(self, field, table_fixture):
5562
"""Confirms the return type of a complex type that is returned as a string"""
5663
with self.cursor(

tests/e2e/test_parameterized_queries.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
TimestampParameter,
2727
TinyIntParameter,
2828
VoidParameter,
29+
ArrayParameter,
30+
MapParameter,
2931
)
3032
from tests.e2e.test_driver import PySQLPytestTestCase
3133

@@ -50,6 +52,8 @@ class Primitive(Enum):
5052
DOUBLE = 3.14
5153
FLOAT = 3.15
5254
SMALLINT = 51
55+
ARRAYS = ["a", "b", "c"]
56+
MAPS = {"a": 1, "b": 2, "c": 3}
5357

5458

5559
class PrimitiveExtra(Enum):
@@ -103,6 +107,8 @@ class TestParameterizedQueries(PySQLPytestTestCase):
103107
Primitive.BOOL: "boolean_col",
104108
Primitive.DATE: "date_col",
105109
Primitive.TIMESTAMP: "timestamp_col",
110+
Primitive.ARRAYS: "array_col",
111+
Primitive.MAPS: "map_col",
106112
Primitive.NONE: "null_col",
107113
}
108114

@@ -134,7 +140,9 @@ def inline_table(self, connection_details):
134140
string_col STRING,
135141
boolean_col BOOLEAN,
136142
date_col DATE,
137-
timestamp_col TIMESTAMP
143+
timestamp_col TIMESTAMP,
144+
array_col ARRAY<STRING>,
145+
map_col MAP<STRING, INT>
138146
) USING DELTA
139147
"""
140148

@@ -167,9 +175,9 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle):
167175
This is a no-op but is included to make the test-code easier to read.
168176
"""
169177
target_column = self._get_inline_table_column(params.get("p"))
170-
INSERT_QUERY = f"INSERT INTO pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)"
171-
SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1"
172-
DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table"
178+
INSERT_QUERY = f"INSERT INTO ___________________first.jprakash.pysql_e2e_inline_param_test_table (`{target_column}`) VALUES (%(p)s)"
179+
SELECT_QUERY = f"SELECT {target_column} `col` FROM ___________________first.jprakash.pysql_e2e_inline_param_test_table LIMIT 1"
180+
DELETE_QUERY = "DELETE FROM ___________________first.jprakash.pysql_e2e_inline_param_test_table"
173181

174182
with self.connection(extra_params={"use_inline_params": True}) as conn:
175183
with conn.cursor() as cursor:
@@ -229,10 +237,18 @@ def _eq(self, actual, expected: Primitive):
229237
If primitive is Primitive.DOUBLE than an extra quantize step is performed before
230238
making the assertion.
231239
"""
232-
if expected in (Primitive.DOUBLE, Primitive.FLOAT):
233-
return self._quantize(actual) == self._quantize(expected.value)
240+
actual_parsed = actual
241+
expected_parsed = expected.value
234242

235-
return actual == expected.value
243+
if expected in (Primitive.DOUBLE, Primitive.FLOAT):
244+
actual_parsed = self._quantize(actual)
245+
expected_parsed = self._quantize(expected.value)
246+
elif expected == Primitive.ARRAYS:
247+
actual_parsed = actual.tolist()
248+
elif expected == Primitive.MAPS:
249+
expected_parsed = list(expected.value.items())
250+
251+
return actual_parsed == expected_parsed
236252

237253
@pytest.mark.parametrize("primitive", Primitive)
238254
@pytest.mark.parametrize(
@@ -278,6 +294,8 @@ def test_primitive_single(
278294
(Primitive.SMALLINT, SmallIntParameter),
279295
(PrimitiveExtra.TIMESTAMP_NTZ, TimestampNTZParameter),
280296
(PrimitiveExtra.TINYINT, TinyIntParameter),
297+
(Primitive.ARRAYS, ArrayParameter),
298+
(Primitive.MAPS, MapParameter),
281299
],
282300
)
283301
def test_dbsqlparameter_single(

0 commit comments

Comments
 (0)