Skip to content

Commit 85e9693

Browse files
committed
refactor: amended process pool to use executor supplied to duckdb pipeline/data contract rather than always instantiating new pool
1 parent 2e2f236 commit 85e9693

File tree

5 files changed

+37
-15
lines changed

5 files changed

+37
-15
lines changed

src/dve/core_engine/backends/implementations/duckdb/contract.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
# pylint: disable=R0903
44
import logging
55
from collections.abc import Iterator
6+
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
67
from functools import partial
7-
from multiprocessing import Pool, cpu_count
8+
from multiprocessing import cpu_count
89
from typing import Any, Optional
910
from uuid import uuid4
1011

@@ -70,10 +71,12 @@ def __init__(
7071
connection: DuckDBPyConnection,
7172
logger: Optional[logging.Logger] = None,
7273
debug: bool = False,
74+
executor: Optional[ProcessPoolExecutor] = None,
7375
**kwargs: Any,
7476
):
7577
self.debug = debug
7678
self._connection = connection
79+
self._executor = ProcessPoolExecutor(cpu_count() - 1) if not executor else executor
7780
"""A bool indicating whether to enable debug logging."""
7881

7982
super().__init__(logger, **kwargs)
@@ -164,11 +167,13 @@ def apply_data_contract(
164167

165168
batches = pq.ParquetFile(entity_locations[entity_name]).iter_batches(10000)
166169
msg_count = 0
167-
with Pool(cpu_count() - 1) as pool:
168-
for msgs in pool.imap_unordered(row_validator_helper, batches):
169-
if msgs:
170-
msg_writer.write_queue.put(msgs)
171-
msg_count += len(msgs)
170+
futures: list[Future] = [
171+
self._executor.submit(row_validator_helper, batch) for batch in batches
172+
]
173+
for future in as_completed(futures):
174+
if msgs := future.result():
175+
msg_writer.write_queue.put(msgs)
176+
msg_count += len(msgs)
172177

173178
self.logger.info(f"Data contract found {msg_count} issues in {entity_name}")
174179

src/dve/pipeline/duckdb_pipeline.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""DuckDB implementation for `Pipeline` object."""
22

33
import logging
4+
from concurrent.futures import ProcessPoolExecutor
45
from typing import Optional
56

67
from duckdb import DuckDBPyConnection, DuckDBPyRelation
@@ -33,12 +34,13 @@ def __init__(
3334
reference_data_loader: Optional[type[BaseRefDataLoader]] = None,
3435
job_run_id: Optional[int] = None,
3536
logger: Optional[logging.Logger] = None,
37+
executor: Optional[ProcessPoolExecutor] = None,
3638
):
3739
self._connection = connection
3840
super().__init__(
3941
processed_files_path,
4042
audit_tables,
41-
DuckDBDataContract(connection=self._connection),
43+
DuckDBDataContract(connection=self._connection, executor=executor),
4244
DuckDBStepImplementations.register_udfs(connection=self._connection),
4345
rules_path,
4446
submitted_files_path,

tests/features/environment.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from concurrent.futures import ProcessPoolExecutor
2+
from multiprocessing import cpu_count
13
import shutil
24
import tempfile
35
from pathlib import Path
@@ -27,6 +29,7 @@ def before_all(context: Context):
2729
temp_dir = Path(context.dbfs_root.__enter__())
2830
dbfs_impl = DBFSFilesystemImplementation(temp_dir)
2931
add_implementation(dbfs_impl)
32+
context.process_pool = ProcessPoolExecutor(cpu_count() - 1)
3033

3134

3235
def before_scenario(context: Context, scenario: Scenario):
@@ -78,3 +81,4 @@ def after_all(context: Context):
7881

7982
context.connection.close()
8083
shutil.rmtree(context.ddb_db_file.parent)
84+
context.process_pool.shutdown(wait=True, cancel_futures=True)

tests/features/steps/steps_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
77
"""
88
# pylint: disable=no-name-in-module
9-
from concurrent.futures import ThreadPoolExecutor
9+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
1010
from functools import partial, reduce
1111
from itertools import chain
1212
import operator
@@ -75,6 +75,7 @@ def setup_duckdb_pipeline(
7575
dataset_id: str,
7676
processing_path: Path,
7777
schema_file_name: Optional[str] = None,
78+
executor: Optional[ProcessPoolExecutor] = None
7879
):
7980

8081
schema_file_name = f"{dataset_id}.dischema.json" if not schema_file_name else schema_file_name
@@ -97,6 +98,7 @@ def setup_duckdb_pipeline(
9798
rules_path=rules_path,
9899
submitted_files_path=processing_path.as_posix(),
99100
reference_data_loader=DuckDBRefDataLoader,
101+
executor=executor
100102
)
101103

102104

@@ -204,7 +206,7 @@ def add_pipeline_to_ctx(
204206
context: Context, implementation: str, schema_file_name: Optional[str] = None
205207
):
206208
pipeline_map: Dict[str, Callable] = {
207-
"duckdb": partial(setup_duckdb_pipeline, connection=context.connection),
209+
"duckdb": partial(setup_duckdb_pipeline, connection=context.connection, executor=context.process_pool),
208210
"spark": partial(setup_spark_pipeline, spark=context.spark_session),
209211
}
210212
if not implementation in pipeline_map:

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_data_contract.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from concurrent.futures import ProcessPoolExecutor
12
import json
3+
from multiprocessing import cpu_count
24
from pathlib import Path
35
from typing import Any, Dict, List, Tuple
46

@@ -30,7 +32,13 @@
3032
temp_xml_file,
3133
)
3234

33-
def test_duckdb_data_contract_csv(temp_csv_file):
35+
@pytest.fixture(scope="module")
36+
def temp_process_pool_executor():
37+
with ProcessPoolExecutor(cpu_count() - 1) as pool:
38+
yield pool
39+
40+
41+
def test_duckdb_data_contract_csv(temp_csv_file, temp_process_pool_executor):
3442
uri, _, _, mdl = temp_csv_file
3543
connection = default_connection
3644

@@ -89,7 +97,7 @@ def test_duckdb_data_contract_csv(temp_csv_file):
8997
}
9098
entity_locations: Dict[str, URI] = {"test_ds": str(uri)}
9199

92-
data_contract: DuckDBDataContract = DuckDBDataContract(connection)
100+
data_contract: DuckDBDataContract = DuckDBDataContract(connection, executor=temp_process_pool_executor)
93101
entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta)
94102
rel: DuckDBPyRelation = entities.get("test_ds")
95103
assert dict(zip(rel.columns, rel.dtypes)) == {
@@ -100,7 +108,7 @@ def test_duckdb_data_contract_csv(temp_csv_file):
100108
assert stage_successful
101109

102110

103-
def test_duckdb_data_contract_xml(temp_xml_file):
111+
def test_duckdb_data_contract_xml(temp_xml_file, temp_process_pool_executor):
104112
uri, header_model, header_data, class_model, class_data = temp_xml_file
105113
connection = default_connection
106114
contract_meta = json.dumps(
@@ -187,7 +195,7 @@ def test_duckdb_data_contract_xml(temp_xml_file):
187195
reporting_fields={"test_header": ["school"], "test_class_info": ["year"]},
188196
)
189197

190-
data_contract: DuckDBDataContract = DuckDBDataContract(connection)
198+
data_contract: DuckDBDataContract = DuckDBDataContract(connection, executor=temp_process_pool_executor)
191199
entities, feedback_errors_uri, stage_successful = data_contract.apply_data_contract(get_parent(uri.as_posix()), entities, entity_locations, dc_meta)
192200
header_rel: DuckDBPyRelation = entities.get("test_header")
193201
header_expected_schema: Dict[str, DuckDBPyType] = {
@@ -327,10 +335,11 @@ def test_ddb_data_contract_read_nested_parquet(nested_all_string_parquet):
327335
}
328336

329337
def test_duckdb_data_contract_custom_error_details(nested_all_string_parquet_w_errors,
330-
nested_parquet_custom_dc_err_details):
338+
nested_parquet_custom_dc_err_details,
339+
temp_process_pool_executor):
331340
parquet_uri, contract_meta, _ = nested_all_string_parquet_w_errors
332341
connection = default_connection
333-
data_contract = DuckDBDataContract(connection)
342+
data_contract = DuckDBDataContract(connection, executor=temp_process_pool_executor)
334343

335344
entity = data_contract.read_parquet(path=parquet_uri)
336345
assert entity.count("*").fetchone()[0] == 2

0 commit comments

Comments (0)