forked from sepideh-abedini/MaskSQL
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
67 lines (47 loc) · 1.58 KB
/
main.py
File metadata and controls
67 lines (47 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""Main entry point for the MaskSQL pipeline."""
import argparse
import asyncio
import logging
import shutil
from pathlib import Path
from dotenv import load_dotenv
load_dotenv()
from src.masksql import MaskSQL # noqa: E402
from src.utils.logging import configure_logging # noqa: E402
logger = logging.getLogger(__name__)


def clean_cache_directory(cache_dir: str) -> None:
    """Delete the cache directory and everything inside it.

    The entire directory tree at ``cache_dir`` is removed with
    :func:`shutil.rmtree`; no files inside it are preserved. If the path does
    not exist, an error is logged and the function returns without raising.

    Parameters
    ----------
    cache_dir : str
        Path to the cache directory to remove.
    """
    cache_path = Path(cache_dir)
    if not cache_path.exists():
        # Nothing to clean: report it instead of letting rmtree raise.
        logger.error("Cache directory does not exist: %s", cache_dir)
        return
    shutil.rmtree(cache_path)
    logger.info("Cleanup complete")
async def main() -> None:
    """Parse the command line, build the pipeline, and dispatch.

    ``configure_logging`` runs first, then a ``MaskSQL`` instance is built from
    the file given by ``-c/--config`` (default ``configs/conf.yaml``). With
    ``--clean`` the instance's configured ``cache_dir`` is passed to
    ``clean_cache_directory``; otherwise the pipeline's ``evaluate`` coroutine
    is awaited.
    """
    cli = argparse.ArgumentParser(description="MaskSQL")
    cli.add_argument(
        "--clean",
        action="store_true",
        help="Clean cached files from cache directory",
    )
    cli.add_argument(
        "-c",
        "--config",
        default="configs/conf.yaml",
        help="Path to config file",
    )
    options = cli.parse_args()

    configure_logging()
    pipeline = MaskSQL.from_config(options.config)

    # Default path: run the evaluation; --clean short-circuits to cleanup.
    if not options.clean:
        await pipeline.evaluate()
        return
    clean_cache_directory(pipeline.conf.cache_dir)


if __name__ == "__main__":
    # Drive the async entry point to completion on a fresh event loop.
    asyncio.run(main())