Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 010a473

Browse files
authored
Merge pull request #230 from datafold/issue229
Fixed support for diffing columns of different names
2 parents 0cd0c40 + 0a18d0b commit 010a473

File tree

5 files changed

+77
-41
lines changed

5 files changed

+77
-41
lines changed

data_diff/databases/clickhouse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def cursor(self, cursor_factory=None):
6969
def _parse_type_repr(self, type_repr: str) -> Optional[Type[ColType]]:
7070
nullable_prefix = "Nullable("
7171
if type_repr.startswith(nullable_prefix):
72-
type_repr = type_repr[len(nullable_prefix):].rstrip(")")
72+
type_repr = type_repr[len(nullable_prefix) :].rstrip(")")
7373

7474
if type_repr.startswith("Decimal"):
7575
type_repr = "Decimal"
@@ -91,7 +91,7 @@ def to_string(self, s: str) -> str:
9191
return f"toString({s})"
9292

9393
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
94-
prec= coltype.precision
94+
prec = coltype.precision
9595
if coltype.rounds:
9696
timestamp = f"toDateTime64(round(toUnixTimestamp64Micro(toDateTime64({value}, 6)) / 1000000, {prec}), 6)"
9797
return self.to_string(timestamp)

data_diff/diff_tables.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -172,42 +172,42 @@ def _parse_key_range_result(self, key_type, key_range):
172172
raise type(e)(f"Cannot apply {key_type} to {mn}, {mx}.") from e
173173

174174
def _validate_and_adjust_columns(self, table1, table2):
175-
for c in table1._relevant_columns:
176-
if c not in table1._schema:
175+
for c1, c2 in safezip(table1._relevant_columns, table2._relevant_columns):
176+
if c1 not in table1._schema:
177177
raise ValueError(f"Column '{c}' not found in schema for table {table1}")
178-
if c not in table2._schema:
178+
if c2 not in table2._schema:
179179
raise ValueError(f"Column '{c}' not found in schema for table {table2}")
180180

181181
# Update schemas to minimal mutual precision
182-
col1 = table1._schema[c]
183-
col2 = table2._schema[c]
182+
col1 = table1._schema[c1]
183+
col2 = table2._schema[c2]
184184
if isinstance(col1, PrecisionType):
185185
if not isinstance(col2, PrecisionType):
186-
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
186+
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
187187

188188
lowest = min(col1, col2, key=attrgetter("precision"))
189189

190190
if col1.precision != col2.precision:
191-
logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}")
191+
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
192192

193-
table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
194-
table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)
193+
table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds)
194+
table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds)
195195

196196
elif isinstance(col1, NumericType):
197197
if not isinstance(col2, NumericType):
198-
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
198+
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
199199

200200
lowest = min(col1, col2, key=attrgetter("precision"))
201201

202202
if col1.precision != col2.precision:
203-
logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}")
203+
logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}")
204204

205-
table1._schema[c] = col1.replace(precision=lowest.precision)
206-
table2._schema[c] = col2.replace(precision=lowest.precision)
205+
table1._schema[c1] = col1.replace(precision=lowest.precision)
206+
table2._schema[c2] = col2.replace(precision=lowest.precision)
207207

208208
elif isinstance(col1, StringType):
209209
if not isinstance(col2, StringType):
210-
raise TypeError(f"Incompatible types for column '{c}': {col1} <-> {col2}")
210+
raise TypeError(f"Incompatible types for column '{c1}': {col1} <-> {col2}")
211211

212212
for t in [table1, table2]:
213213
for c in t._relevant_columns:

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
TEST_ORACLE_CONN_STRING: str = None
2222
TEST_DATABRICKS_CONN_STRING: str = os.environ.get("DATADIFF_DATABRICKS_URI")
2323
TEST_TRINO_CONN_STRING: str = os.environ.get("DATADIFF_TRINO_URI") or None
24-
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
24+
# clickhouse uri for provided docker - "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
2525
TEST_CLICKHOUSE_CONN_STRING: str = os.environ.get("DATADIFF_CLICKHOUSE_URI") or None
2626

2727
DEFAULT_N_SAMPLES = 50

tests/test_database_types.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ def init_conns():
271271
],
272272
"uuid": [
273273
"String",
274-
]
275-
}
274+
],
275+
},
276276
}
277277

278278

@@ -482,13 +482,13 @@ def _insert_to_table(conn, table, values, type):
482482
if type.startswith("DateTime64"):
483483
value = f"'{sample.replace(tzinfo=None)}'"
484484

485-
elif type == 'DateTime':
485+
elif type == "DateTime":
486486
sample = sample.replace(tzinfo=None)
487487
# Clickhouse's DateTime does not allow to store micro/milli/nano seconds
488488
value = f"'{str(sample)[:19]}'"
489489

490-
elif type.startswith('Decimal'):
491-
precision = int(type[8:].rstrip(')').split(',')[1])
490+
elif type.startswith("Decimal"):
491+
precision = int(type[8:].rstrip(")").split(",")[1])
492492
value = round(sample, precision)
493493

494494
else:

tests/test_diff_tables.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -234,25 +234,7 @@ def setUp(self):
234234
f"create table {self.table_dst}(id int, userid int, movieid int, rating {float_type}, timestamp timestamp)",
235235
None,
236236
)
237-
# self.preql(
238-
# f"""
239-
# table {self.table_src_name} {{
240-
# userid: int
241-
# movieid: int
242-
# rating: float
243-
# timestamp: timestamp
244-
# }}
245-
246-
# table {self.table_dst_name} {{
247-
# userid: int
248-
# movieid: int
249-
# rating: float
250-
# timestamp: timestamp
251-
# }}
252-
# commit()
253-
# """
254-
# )
255-
self.preql.commit()
237+
_commit(self.connection)
256238

257239
self.table = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
258240
self.table2 = TableSegment(self.connection, self.table_dst_path, "id", "timestamp", case_sensitive=False)
@@ -402,6 +384,60 @@ def test_diff_sorted_by_key(self):
402384
self.assertEqual(expected, diff)
403385

404386

387+
@test_per_database
388+
class TestDiffTables2(TestPerDatabase):
389+
def test_diff_column_names(self):
390+
float_type = _get_float_type(self.connection)
391+
392+
self.connection.query(
393+
f"create table {self.table_src}(id int, rating {float_type}, timestamp timestamp)",
394+
None,
395+
)
396+
self.connection.query(
397+
f"create table {self.table_dst}(id2 int, rating2 {float_type}, timestamp2 timestamp)",
398+
None,
399+
)
400+
_commit(self.connection)
401+
402+
time = "2022-01-01 00:00:00"
403+
time2 = "2021-01-01 00:00:00"
404+
405+
time_str = f"timestamp '{time}'"
406+
time_str2 = f"timestamp '{time2}'"
407+
_insert_rows(
408+
self.connection,
409+
self.table_src,
410+
["id", "rating", "timestamp"],
411+
[
412+
[1, 9, time_str],
413+
[2, 9, time_str2],
414+
[3, 9, time_str],
415+
[4, 9, time_str2],
416+
[5, 9, time_str],
417+
],
418+
)
419+
420+
_insert_rows(
421+
self.connection,
422+
self.table_dst,
423+
["id2", "rating2", "timestamp2"],
424+
[
425+
[1, 9, time_str],
426+
[2, 9, time_str2],
427+
[3, 9, time_str],
428+
[4, 9, time_str2],
429+
[5, 9, time_str],
430+
],
431+
)
432+
433+
table1 = TableSegment(self.connection, self.table_src_path, "id", "timestamp", case_sensitive=False)
434+
table2 = TableSegment(self.connection, self.table_dst_path, "id2", "timestamp2", case_sensitive=False)
435+
436+
differ = TableDiffer()
437+
diff = list(differ.diff_tables(table1, table2))
438+
assert diff == []
439+
440+
405441
@test_per_database
406442
class TestUUIDs(TestPerDatabase):
407443
def setUp(self):

0 commit comments

Comments
 (0)