diff --git a/sqllineage/cli.py b/sqllineage/cli.py index 91593496..857de5cf 100644 --- a/sqllineage/cli.py +++ b/sqllineage/cli.py @@ -99,6 +99,11 @@ def main(args=None) -> None: type=str, ) args = parser.parse_args(args) + metadata_provider = ( + SQLAlchemyMetaDataProvider(args.sqlalchemy_url) + if args.sqlalchemy_url + else DummyMetaDataProvider() + ) if args.e and args.f: warnings.warn("Both -e and -f options are specified. -e option will be ignored") if args.f or args.e: @@ -106,11 +111,7 @@ def main(args=None) -> None: runner = LineageRunner( sql, dialect=args.dialect, - metadata_provider=( - SQLAlchemyMetaDataProvider(args.sqlalchemy_url) - if args.sqlalchemy_url - else DummyMetaDataProvider() - ), + metadata_provider=metadata_provider, verbose=args.verbose, draw_options={ "host": args.host, @@ -126,7 +127,13 @@ def main(args=None) -> None: else: runner.print_table_lineage() elif args.graph_visualization: - return draw_lineage_graph(**{"host": args.host, "port": args.port}) + return draw_lineage_graph( + **{ + "host": args.host, + "port": args.port, + "metadata_provider": metadata_provider, + } + ) elif args.dialects: dialects = [] for _, supported_dialects in LineageRunner.supported_dialects().items(): diff --git a/sqllineage/drawing.py b/sqllineage/drawing.py index 6b40f385..961667ec 100644 --- a/sqllineage/drawing.py +++ b/sqllineage/drawing.py @@ -19,6 +19,7 @@ from sqllineage import DEFAULT_DIALECT, DEFAULT_HOST, DEFAULT_PORT, STATIC_FOLDER from sqllineage.config import SQLLineageConfig +from sqllineage.core.metadata.dummy import DummyMetaDataProvider from sqllineage.exceptions import SQLLineageException from sqllineage.utils.constant import LineageLevel from sqllineage.utils.helpers import extract_sql_from_args @@ -30,6 +31,7 @@ class SQLLineageApp: def __init__(self) -> None: self.routes: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {} self.root_path = Path(SQLLineageConfig.DIRECTORY) + self.metadata_provider = DummyMetaDataProvider() def route(self, path: str): def wrapper(handler): @@ -165,7 +167,9 @@ def lineage(payload): req_args = Namespace(**payload) sql = extract_sql_from_args(req_args) dialect = getattr(req_args, "dialect", DEFAULT_DIALECT) - lr = LineageRunner(sql, dialect=dialect, verbose=True) + lr = LineageRunner( + sql, dialect=dialect, verbose=True, metadata_provider=app.metadata_provider + ) data = { "verbose": str(lr), "dag": lr.to_cytoscape(), @@ -206,8 +210,10 @@ def draw_lineage_graph(**kwargs) -> None: port = kwargs.pop("port", DEFAULT_PORT) querystring = urlencode({k: v for k, v in kwargs.items() if v}) path = f"/?{querystring}" if querystring else "/" - if "f" in kwargs: - app.root_path = Path(kwargs["f"]).parent + if f := kwargs.get("f"): + app.root_path = Path(f).parent + if metadata_provider := kwargs.get("metadata_provider"): + app.metadata_provider = metadata_provider with make_server(host, port, app) as httpd: print(f" * SQLLineage Running on http://{host}:{port}{path}") httpd.serve_forever() diff --git a/sqllineage/runner.py b/sqllineage/runner.py index 9f9e472b..9d666b53 100644 --- a/sqllineage/runner.py +++ b/sqllineage/runner.py @@ -1,7 +1,7 @@ import logging import warnings from collections import OrderedDict -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from sqllineage import DEFAULT_DIALECT, SQLPARSE_DIALECT from sqllineage.config import SQLLineageConfig @@ -41,7 +41,7 @@ def __init__( metadata_provider: MetaDataProvider = DummyMetaDataProvider(), verbose: bool = False, silent_mode: bool = False, - draw_options: Optional[Dict[str, str]] = None, + draw_options: Optional[Dict[str, Any]] = None, ): """ The entry point of SQLLineage after command line options are parsed. @@ -120,6 +120,7 @@ def draw(self) -> None: draw_options.pop("f", None) draw_options["e"] = self._sql draw_options["dialect"] = self._dialect + draw_options["metadata_provider"] = self._metadata_provider return draw_lineage_graph(**draw_options) @lazy_method diff --git a/tests/core/test_cli.py b/tests/core/test_cli.py index eb1a2be2..7a836362 100644 --- a/tests/core/test_cli.py +++ b/tests/core/test_cli.py @@ -21,6 +21,7 @@ def test_cli_dummy(_): main(["-f", sql_file, "-g"]) main(["-f", sql_file, "--silent_mode"]) main(["-f", sql_file, "--sqlalchemy_url=sqlite:///:memory:"]) + main(["--sqlalchemy_url=sqlite:///:memory:", "-g"]) break main(["-g"]) main(["-ds"])