8
8
from collections import defaultdict
9
9
from typing import List , Tuple , Iterator , Optional
10
10
import logging
11
- from concurrent .futures import ThreadPoolExecutor
11
+ from concurrent .futures import ThreadPoolExecutor , as_completed
12
12
13
13
from runtype import dataclass
14
14
@@ -315,17 +315,16 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
315
315
('-', columns) for items in table2 but not in table1
316
316
Where `columns` is a tuple of values for the involved columns, i.e. (id, ...extra)
317
317
"""
318
+ # Validate options
318
319
if self .bisection_factor >= self .bisection_threshold :
319
320
raise ValueError ("Incorrect param values (bisection factor must be lower than threshold)" )
320
321
if self .bisection_factor < 2 :
321
322
raise ValueError ("Must have at least two segments per iteration (i.e. bisection_factor >= 2)" )
322
323
324
+ # Query and validate schema
323
325
table1 , table2 = self ._threaded_call ("with_schema" , [table1 , table2 ])
324
326
self ._validate_and_adjust_columns (table1 , table2 )
325
327
326
- key_ranges = self ._threaded_call ("query_key_range" , [table1 , table2 ])
327
- mins , maxs = zip (* key_ranges )
328
-
329
328
key_type = table1 ._schema [table1 .key_column ]
330
329
key_type2 = table2 ._schema [table2 .key_column ]
331
330
if not isinstance (key_type , IKey ):
@@ -334,23 +333,42 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment) -> DiffResult:
334
333
raise NotImplementedError (f"Cannot use column of type { key_type2 } as a key" )
335
334
assert key_type .python_type is key_type2 .python_type
336
335
337
- # We add 1 because our ranges are exclusive of the end (like in Python)
338
- try :
339
- min_key = min (map (key_type .python_type , mins ))
340
- max_key = max (map (key_type .python_type , maxs )) + 1
341
- except (TypeError , ValueError ) as e :
342
- raise type (e )(f"Cannot apply { key_type } to { mins } , { maxs } ." ) from e
336
+ # Query min/max values
337
+ key_ranges = self ._threaded_call_as_completed ("query_key_range" , [table1 , table2 ])
343
338
344
- table1 = table1 .new (min_key = min_key , max_key = max_key )
345
- table2 = table2 .new (min_key = min_key , max_key = max_key )
339
+ # Start with the first completed value, so we don't waste time waiting
340
+ min_key1 , max_key1 = self ._parse_key_range_result (key_type , next (key_ranges ))
341
+
342
+ table1 , table2 = [t .new (min_key = min_key1 , max_key = max_key1 ) for t in (table1 , table2 )]
346
343
347
344
logger .info (
348
345
f"Diffing tables | segments: { self .bisection_factor } , bisection threshold: { self .bisection_threshold } . "
349
346
f"key-range: { table1 .min_key } ..{ table2 .max_key } , "
350
347
f"size: { table2 .max_key - table1 .min_key } "
351
348
)
352
349
353
- return self ._bisect_and_diff_tables (table1 , table2 )
350
+ # Bisect (split) the table into segments, and diff them recursively.
351
+ yield from self ._bisect_and_diff_tables (table1 , table2 )
352
+
353
+ # Now we check for the second min-max, to diff the portions we "missed".
354
+ min_key2 , max_key2 = self ._parse_key_range_result (key_type , next (key_ranges ))
355
+
356
+ if min_key2 < min_key1 :
357
+ pre_tables = [t .new (min_key = min_key2 , max_key = min_key1 ) for t in (table1 , table2 )]
358
+ yield from self ._bisect_and_diff_tables (* pre_tables )
359
+
360
+ if max_key2 > max_key1 :
361
+ post_tables = [t .new (min_key = max_key1 , max_key = max_key2 ) for t in (table1 , table2 )]
362
+ yield from self ._bisect_and_diff_tables (* post_tables )
363
+
364
+ def _parse_key_range_result (self , key_type , key_range ):
365
+ mn , mx = key_range
366
+ cls = key_type .python_type
367
+ # We add 1 because our ranges are exclusive of the end (like in Python)
368
+ try :
369
+ return cls (mn ), cls (mx ) + 1
370
+ except (TypeError , ValueError ) as e :
371
+ raise type (e )(f"Cannot apply { key_type } to { mn } , { mx } ." ) from e
354
372
355
373
def _validate_and_adjust_columns (self , table1 , table2 ):
356
374
for c in table1 ._relevant_columns :
@@ -474,12 +492,26 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun
474
492
if checksum1 != checksum2 :
475
493
yield from self ._bisect_and_diff_tables (table1 , table2 , level = level , max_rows = max (count1 , count2 ))
476
494
477
- def _thread_map (self , func , iter ):
495
+ def _thread_map (self , func , iterable ):
496
+ if not self .threaded :
497
+ return map (func , iterable )
498
+
499
+ with ThreadPoolExecutor (max_workers = self .max_threadpool_size ) as task_pool :
500
+ return task_pool .map (func , iterable )
501
+
502
+ def _threaded_call (self , func , iterable ):
503
+ "Calls a method for each object in iterable."
504
+ return list (self ._thread_map (methodcaller (func ), iterable ))
505
+
506
+ def _thread_as_completed (self , func , iterable ):
478
507
if not self .threaded :
479
- return map (func , iter )
508
+ return map (func , iterable )
480
509
481
- task_pool = ThreadPoolExecutor (max_workers = self .max_threadpool_size )
482
- return task_pool .map (func , iter )
510
+ with ThreadPoolExecutor (max_workers = self .max_threadpool_size ) as task_pool :
511
+ futures = [task_pool .submit (func , item ) for item in iterable ]
512
+ for future in as_completed (futures ):
513
+ yield future .result ()
483
514
484
- def _threaded_call (self , func , iter ):
485
- return list (self ._thread_map (methodcaller (func ), iter ))
515
+ def _threaded_call_as_completed (self , func , iterable ):
516
+ "Calls a method for each object in iterable. Returned in order of completion."
517
+ return self ._thread_as_completed (methodcaller (func ), iterable )
0 commit comments