from pathlib import Path
import logging
import yaml
import duckdb

from databricks.labs.remorph.assessments.profiler_config import PipelineConfig, Step
from databricks.labs.remorph.connections.database_manager import DatabaseManager

logger = logging.getLogger(__name__)
logger.setLevel("INFO")

DB_NAME = "profiler_extract.db"


class PipelineClass:
    """Runs the profiler extraction pipeline described by a PipelineConfig against a source database."""

    def __init__(self, config: PipelineConfig, executor: DatabaseManager):
        self.config = config
        self.executor = executor
        self.db_path_prefix = Path(config.extract_folder)

    def execute(self):
        logger.info(f"Pipeline initialized with config: {self.config.name}, version: {self.config.version}")
        for step in self.config.steps:
            if step.flag == "active":
                logger.debug(f"Executing step: {step.name}")
                self._execute_step(step)
        logger.info("Pipeline execution completed")

    def _execute_step(self, step: Step):
        logger.debug(f"Reading query from file: {step.extract_query}")
        with open(step.extract_query, 'r', encoding='utf-8') as file:
            query = file.read()

        # Execute the query using the database manager
        logger.info(f"Executing query: {query}")
        result = self.executor.execute_query(query)

        # Save the result to duckdb
        self._save_to_db(result, step.name, str(step.mode))

    def _save_to_db(self, result, step_name: str, mode: str, batch_size: int = 1000):
        self._create_dir(self.db_path_prefix)
        conn = duckdb.connect(str(self.db_path_prefix / DB_NAME))
        columns = result.keys()
        # TODO: Infer column types from the SQLAlchemy result; result.cursor.description is not reliable,
        # so every column is stored as STRING for now.
        schema = ' STRING, '.join(columns) + ' STRING'

        # Handle write modes
        if mode == 'overwrite':
            conn.execute(f"CREATE OR REPLACE TABLE {step_name} ({schema})")
        elif mode == 'append':
            conn.execute(f"CREATE TABLE IF NOT EXISTS {step_name} ({schema})")

        # Batch insert using prepared statements
        placeholders = ', '.join(['?' for _ in columns])
        insert_query = f"INSERT INTO {step_name} VALUES ({placeholders})"

        # Fetch and insert rows in batches
        while True:
            rows = result.fetchmany(batch_size)
            if not rows:
                break
            conn.executemany(insert_query, rows)

        conn.close()

    @staticmethod
    def _create_dir(dir_path: Path):
        # mkdir with exist_ok=True already handles an existing directory
        dir_path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def load_config_from_yaml(file_path: str) -> PipelineConfig:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        steps = [Step(**step) for step in data['steps']]
        return PipelineConfig(
            name=data['name'], version=data['version'], extract_folder=data['extract_folder'], steps=steps
        )
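
For context, a minimal usage sketch (not part of the diff): it assumes a DatabaseManager instance for the source system, here called `executor`, has already been configured elsewhere, and that the YAML file uses the field names read by `load_config_from_yaml` and `Step` above; the file name and step values are hypothetical.

# Usage sketch. `executor` is assumed to be a configured DatabaseManager;
# "profiler_pipeline.yml" and the values below are hypothetical examples of the
# fields the loader and steps reference:
#
#   name: example_profiler
#   version: "1.0"
#   extract_folder: /tmp/profiler_extract
#   steps:
#     - name: table_inventory
#       extract_query: queries/table_inventory.sql
#       mode: overwrite
#       flag: active

config = PipelineClass.load_config_from_yaml("profiler_pipeline.yml")
pipeline = PipelineClass(config=config, executor=executor)
pipeline.execute()  # writes one DuckDB table per active step into <extract_folder>/profiler_extract.db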