Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit c9ab57f

Browse files
authored
Merge pull request #170 from datafold/oracle_tests
Tests now cover oracle, Redshift, snowflake and bigquery; Various fixes to said drivers.
2 parents 73e254a + 2020859 commit c9ab57f

15 files changed

+302
-164
lines changed

data_diff/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313

1414

1515
def connect_to_table(
16-
db_info: Union[str, dict], table_name: Union[DbPath, str], key_column: str = "id", thread_count: Optional[int] = 1, **kwargs
16+
db_info: Union[str, dict],
17+
table_name: Union[DbPath, str],
18+
key_column: str = "id",
19+
thread_count: Optional[int] = 1,
20+
**kwargs,
1721
):
1822
"""Connects to the given database, and creates a TableSegment instance
1923

data_diff/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@
2626
"-": "red",
2727
}
2828

29+
2930
def _remove_passwords_in_dict(d: dict):
3031
for k, v in d.items():
31-
if k == 'password':
32-
d[k] = '*' * len(v)
32+
if k == "password":
33+
d[k] = "*" * len(v)
3334
elif isinstance(v, dict):
3435
_remove_passwords_in_dict(v)
3536

data_diff/databases/base.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
TemporalType,
2121
UnknownColType,
2222
Text,
23+
DbTime,
2324
)
2425
from data_diff.sql import DbPath, SqlOrStr, Compiler, Explain, Select, TableName
2526

@@ -151,9 +152,10 @@ def _parse_type(
151152

152153
elif issubclass(cls, Decimal):
153154
if numeric_scale is None:
154-
raise ValueError(
155-
f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
156-
)
155+
numeric_scale = 0 # Needed for Oracle.
156+
# raise ValueError(
157+
# f"{self.name}: Unexpected numeric_scale is NULL, for column {'.'.join(table_path)}.{col_name} of type {type_repr}."
158+
# )
157159
return cls(precision=numeric_scale)
158160

159161
elif issubclass(cls, Float):
@@ -242,6 +244,13 @@ def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None
242244

243245
return f"LIMIT {limit}"
244246

247+
def concat(self, l: List[str]) -> str:
248+
joined_exprs = ", ".join(l)
249+
return f"concat({joined_exprs})"
250+
251+
def timestamp_value(self, t: DbTime) -> str:
252+
return "'%s'" % t.isoformat()
253+
245254
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
246255
if isinstance(coltype, String_UUID):
247256
return f"TRIM({value})"

data_diff/databases/connect.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def match_path(self, dsn):
8080
"presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://<user>@<host>/<catalog>/<schema>"),
8181
"bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery://<project>/<dataset>"),
8282
"databricks": MatchUriPath(
83-
Databricks, ["catalog", "schema"], help_str="databricks://:access_token@server_name/http_path",
83+
Databricks,
84+
["catalog", "schema"],
85+
help_str="databricks://:access_token@server_name/http_path",
8486
),
8587
"trino": MatchUriPath(Trino, ["catalog", "schema"], help_str="trino://<user>@<host>/<catalog>/<schema>"),
8688
}
@@ -125,9 +127,9 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
125127
if scheme == "databricks":
126128
assert not dsn.user
127129
kw = {}
128-
kw['access_token'] = dsn.password
129-
kw['http_path'] = dsn.path
130-
kw['server_hostname'] = dsn.host
130+
kw["access_token"] = dsn.password
131+
kw["http_path"] = dsn.path
132+
kw["server_hostname"] = dsn.host
131133
kw.update(dsn.query)
132134
else:
133135
kw = matcher.match_path(dsn)

data_diff/databases/database_types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import Sequence, Optional, Tuple, Union, Dict, Any
3+
from typing import Sequence, Optional, Tuple, Union, Dict, List
44
from datetime import datetime
55

66
from runtype import dataclass
@@ -120,13 +120,24 @@ def to_string(self, s: str) -> str:
120120
"Provide SQL for casting a column to string"
121121
...
122122

123+
@abstractmethod
124+
def concat(self, s: List[str]) -> str:
125+
"Provide SQL for concatenating a bunch of column into a string"
126+
...
127+
128+
@abstractmethod
129+
def timestamp_value(self, t: DbTime) -> str:
130+
"Provide SQL for the given timestamp value"
131+
...
132+
123133
@abstractmethod
124134
def md5_to_int(self, s: str) -> str:
125135
"Provide SQL for computing md5 and returning an int"
126136
...
127137

128138
@abstractmethod
129139
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
140+
"Provide SQL fragment for limit and offset inside a select"
130141
...
131142

132143
@abstractmethod

data_diff/databases/oracle.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from .base import ThreadedDatabase, import_helper, ConnectError, QueryError
55
from .base import DEFAULT_DATETIME_PRECISION, DEFAULT_NUMERIC_PRECISION
66

7-
SESSION_TIME_ZONE = None # Changed by the tests
7+
SESSION_TIME_ZONE = None # Changed by the tests
8+
89

910
@import_helper("oracle")
1011
def import_oracle():
@@ -89,6 +90,7 @@ def _parse_type(
8990
regexps = {
9091
r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
9192
r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
93+
r"TIMESTAMP\((\d)\)": Timestamp,
9294
}
9395
for regexp, t_cls in regexps.items():
9496
m = re.match(regexp + "$", type_repr)
@@ -99,14 +101,23 @@ def _parse_type(
99101
rounds=self.ROUNDS_ON_PREC_LOSS,
100102
)
101103

102-
return super()._parse_type(type_repr, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale)
104+
return super()._parse_type(
105+
table_name, col_name, type_repr, datetime_precision, numeric_precision, numeric_scale
106+
)
103107

104108
def offset_limit(self, offset: Optional[int] = None, limit: Optional[int] = None):
105109
if offset:
106110
raise NotImplementedError("No support for OFFSET in query")
107111

108112
return f"FETCH NEXT {limit} ROWS ONLY"
109113

114+
def concat(self, l: List[str]) -> str:
115+
joined_exprs = " || ".join(l)
116+
return f"({joined_exprs})"
117+
118+
def timestamp_value(self, t: DbTime) -> str:
119+
return "timestamp '%s'" % t.isoformat(" ")
120+
110121
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
111122
# Cast is necessary for correct MD5 (trimming not enough)
112123
return f"CAST(TRIM({value}) AS VARCHAR(36))"

data_diff/databases/postgresql.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from .base import ThreadedDatabase, import_helper, ConnectError
33
from .base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, _CHECKSUM_BITSIZE, TIMESTAMP_PRECISION_POS
44

5-
SESSION_TIME_ZONE = None # Changed by the tests
5+
SESSION_TIME_ZONE = None # Changed by the tests
6+
67

78
@import_helper("postgresql")
89
def import_postgresql():
@@ -49,7 +50,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
4950

5051
def create_connection(self):
5152
if not self._args:
52-
self._args['host'] = None # psycopg2 requires 1+ arguments
53+
self._args["host"] = None # psycopg2 requires 1+ arguments
5354

5455
pg = import_postgresql()
5556
try:

data_diff/databases/redshift.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
3535
def normalize_number(self, value: str, coltype: FractionalType) -> str:
3636
return self.to_string(f"{value}::decimal(38,{coltype.precision})")
3737

38+
def concat(self, l: List[str]) -> str:
39+
joined_exprs = " || ".join(l)
40+
return f"({joined_exprs})"
41+
3842
def select_table_schema(self, path: DbPath) -> str:
3943
schema, table = self._normalize_table_path(path)
4044

data_diff/databases/trino.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6666
else:
6767
s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
6868

69-
return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
69+
return (
70+
f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS + coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS + 6}, '0')"
71+
)
7072

7173
def normalize_number(self, value: str, coltype: FractionalType) -> str:
7274
return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
@@ -96,9 +98,7 @@ def _parse_type(
9698
if m:
9799
datetime_precision = int(m.group(1))
98100
return t_cls(
99-
precision=datetime_precision
100-
if datetime_precision is not None
101-
else DEFAULT_DATETIME_PRECISION,
101+
precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION,
102102
rounds=self.ROUNDS_ON_PREC_LOSS,
103103
)
104104

@@ -115,9 +115,7 @@ def _parse_type(
115115
if m:
116116
return n_cls()
117117

118-
return super()._parse_type(
119-
table_path, col_name, type_repr, datetime_precision, numeric_precision
120-
)
118+
return super()._parse_type(table_path, col_name, type_repr, datetime_precision, numeric_precision)
121119

122120
def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
123121
return f"TRIM({value})"

data_diff/sql.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ class Checksum(Sql):
121121

122122
def compile(self, c: Compiler):
123123
if len(self.exprs) > 1:
124-
compiled_exprs = ", ".join(f"coalesce({c.compile(expr)}, '<null>')" for expr in self.exprs)
125-
expr = f"concat({compiled_exprs})"
124+
compiled_exprs = [f"coalesce({c.compile(expr)}, '<null>')" for expr in self.exprs]
125+
expr = c.database.concat(compiled_exprs)
126126
else:
127127
# No need to coalesce - safe to assume that key cannot be null
128128
(expr,) = self.exprs
@@ -180,10 +180,9 @@ def compile(self, c: Compiler):
180180
@dataclass
181181
class Time(Sql):
182182
time: datetime
183-
column: Optional[SqlOrStr] = None
184183

185184
def compile(self, c: Compiler):
186-
return "'%s'" % self.time.isoformat()
185+
return c.database.timestamp_value(self.time)
187186

188187

189188
@dataclass

0 commit comments

Comments
 (0)