Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 4afeaf0

Browse files
authored
Merge pull request #174 from datafold/issue80
Start running as soon as first min/max query returns (Issue #80)
2 parents b2e951b + dc7b833 commit 4afeaf0

File tree

3 files changed

+55
-24
lines changed

3 files changed

+55
-24
lines changed

data_diff/databases/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ def _refine_coltypes(self, table_path: DbPath, col_dict: Dict[str, ColType]):
205205
fields = [self.normalize_uuid(c, String_UUID()) for c in text_columns]
206206
samples_by_row = self.query(Select(fields, TableName(table_path), limit=16), list)
207207
if not samples_by_row:
208-
logger.warning(f"Table {table_path} is empty.")
209-
return
208+
raise ValueError(f"Table {table_path} is empty.")
210209

211210
samples_by_col = list(zip(*samples_by_row))
212211

data_diff/diff_tables.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import defaultdict
99
from typing import List, Tuple, Iterator, Optional
1010
import logging
11-
from concurrent.futures import ThreadPoolExecutor
11+
from concurrent.futures import ThreadPoolExecutor, as_completed
1212

1313
from runtype import dataclass
1414

@@ -315,17 +315,16 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
315315
('-', columns) for items in table2 but not in table1
316316
Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
317317
"""
318+
# Validate options
318319
if self.bisection_factor >= self.bisection_threshold:
319320
raise ValueError("Incorrect param values (bisection factor must be lower than threshold)")
320321
if self.bisection_factor < 2:
321322
raise ValueError("Must have at least two segments per iteration (i.e. bisection_factor >= 2)")
322323

324+
# Query and validate schema
323325
table1, table2 = self._threaded_call("with_schema", [table1, table2])
324326
self._validate_and_adjust_columns(table1, table2)
325327

326-
key_ranges = self._threaded_call("query_key_range", [table1, table2])
327-
mins, maxs = zip(*key_ranges)
328-
329328
key_type = table1._schema[table1.key_column]
330329
key_type2 = table2._schema[table2.key_column]
331330
if not isinstance(key_type, IKey):
@@ -334,23 +333,42 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
334333
raise NotImplementedError(f"Cannot use column of type {key_type2} as a key")
335334
assert key_type.python_type is key_type2.python_type
336335

337-
# We add 1 because our ranges are exclusive of the end (like in Python)
338-
try:
339-
min_key = min(map(key_type.python_type, mins))
340-
max_key = max(map(key_type.python_type, maxs)) + 1
341-
except (TypeError, ValueError) as e:
342-
raise type(e)(f"Cannot apply {key_type} to {mins}, {maxs}.") from e
336+
# Query min/max values
337+
key_ranges = self._threaded_call_as_completed("query_key_range", [table1, table2])
343338

344-
table1 = table1.new(min_key=min_key, max_key=max_key)
345-
table2 = table2.new(min_key=min_key, max_key=max_key)
339+
# Start with the first completed value, so we don't waste time waiting
340+
min_key1, max_key1 = self._parse_key_range_result(key_type, next(key_ranges))
341+
342+
table1, table2 = [t.new(min_key=min_key1, max_key=max_key1) for t in (table1, table2)]
346343

347344
logger.info(
348345
f"Diffing tables | segments: {self.bisection_factor}, bisection threshold: {self.bisection_threshold}. "
349346
f"key-range: {table1.min_key}..{table2.max_key}, "
350347
f"size: {table2.max_key-table1.min_key}"
351348
)
352349

353-
return self._bisect_and_diff_tables(table1, table2)
350+
# Bisect (split) the table into segments, and diff them recursively.
351+
yield from self._bisect_and_diff_tables(table1, table2)
352+
353+
# Now we check for the second min-max, to diff the portions we "missed".
354+
min_key2, max_key2 = self._parse_key_range_result(key_type, next(key_ranges))
355+
356+
if min_key2 < min_key1:
357+
pre_tables = [t.new(min_key=min_key2, max_key=min_key1) for t in (table1, table2)]
358+
yield from self._bisect_and_diff_tables(*pre_tables)
359+
360+
if max_key2 > max_key1:
361+
post_tables = [t.new(min_key=max_key1, max_key=max_key2) for t in (table1, table2)]
362+
yield from self._bisect_and_diff_tables(*post_tables)
363+
364+
def _parse_key_range_result(self, key_type, key_range):
    """Convert a raw (min, max) key pair into a half-open key range.

    Both values are converted with ``key_type.python_type``. The upper bound
    is incremented by one because ranges are exclusive of the end (like in
    Python).

    Raises the same exception type (TypeError or ValueError) that the
    conversion raised, with a message naming the key type and raw values.
    """
    mn, mx = key_range
    to_key = key_type.python_type
    try:
        lower = to_key(mn)
        # +1 turns the inclusive maximum into an exclusive end bound.
        upper = to_key(mx) + 1
    except (TypeError, ValueError) as err:
        raise type(err)(f"Cannot apply {key_type} to {mn}, {mx}.") from err
    return lower, upper
354372

355373
def _validate_and_adjust_columns(self, table1, table2):
356374
for c in table1._relevant_columns:
@@ -474,12 +492,26 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
474492
if checksum1 != checksum2:
475493
yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2))
476494

477-
def _thread_map(self, func, iter):
495+
def _thread_map(self, func, iterable):
    """Apply ``func`` to every item of ``iterable``, returning an iterator of results.

    When ``self.threaded`` is true the calls are dispatched to a thread pool
    of at most ``self.max_threadpool_size`` workers; leaving the ``with``
    block waits for the pool to shut down before the result iterator is
    handed back. Otherwise a plain lazy ``map`` is returned.
    """
    if self.threaded:
        with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as pool:
            return pool.map(func, iterable)

    return map(func, iterable)
501+
502+
def _threaded_call(self, func, iterable):
    """Call the method named ``func`` on each object in ``iterable``.

    Returns the results as a list, produced via ``self._thread_map``.
    """
    invoke = methodcaller(func)
    results = self._thread_map(invoke, iterable)
    return list(results)
505+
506+
def _thread_as_completed(self, func, iterable):
    """Yield ``func(item)`` for each item of ``iterable`` as results become available.

    When ``self.threaded`` is true, every call is submitted to a thread pool
    of at most ``self.max_threadpool_size`` workers and results are yielded
    in order of completion, so the caller can start working on the first
    result without waiting for the rest. Otherwise the calls run lazily, one
    by one, in input order.

    Any exception raised by ``func`` propagates to the consumer when the
    corresponding result is requested.
    """
    if not self.threaded:
        # This function is a generator (it contains `yield`), so a plain
        # `return map(...)` would discard the results and yield nothing.
        # Delegate to the map iterator instead.
        yield from map(func, iterable)
        return

    with ThreadPoolExecutor(max_workers=self.max_threadpool_size) as task_pool:
        futures = [task_pool.submit(func, item) for item in iterable]
        for future in as_completed(futures):
            yield future.result()
483514

484-
def _threaded_call(self, func, iter):
485-
return list(self._thread_map(methodcaller(func), iter))
515+
def _threaded_call_as_completed(self, func, iterable):
    """Call the method named ``func`` on each object in ``iterable``.

    Results are returned in order of completion, as produced by
    ``self._thread_as_completed``.
    """
    invoke = methodcaller(func)
    return self._thread_as_completed(invoke, iterable)

tests/test_diff_tables.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def test_string_keys(self):
405405
f"INSERT INTO {self.table_src} VALUES ('unexpected', '<-- this bad value should not break us')", None
406406
)
407407

408-
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
408+
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
409409

410410

411411
@test_per_database
@@ -592,7 +592,7 @@ def setUp(self):
592592

593593
def test_right_table_empty(self):
594594
differ = TableDiffer()
595-
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
595+
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))
596596

597597
def test_left_table_empty(self):
598598
queries = [
@@ -605,4 +605,4 @@ def test_left_table_empty(self):
605605
_commit(self.connection)
606606

607607
differ = TableDiffer()
608-
self.assertRaises(ValueError, differ.diff_tables, self.a, self.b)
608+
self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b))

0 commit comments

Comments
 (0)