Commit aa0481f

Enhanced install script to enforce usage of a warehouse or cluster when skip-validation is set to False (#213)
1 parent d39ce9a commit aa0481f

10 files changed: +203 -82 lines changed

labs.yml (-1)

@@ -6,7 +6,6 @@ install:
   require_running_cluster: false
   require_databricks_connect: true
   script: src/databricks/labs/remorph/install.py
-  warehouse_types: ["SERVERLESS", "PRO"]
 uninstall:
   script: src/databricks/labs/remorph/uninstall.py
 entrypoint: src/databricks/labs/remorph/cli.py

src/databricks/labs/remorph/cli.py (+8 -2)

@@ -3,6 +3,7 @@
 
 from databricks.labs.blueprint.cli import App
 from databricks.labs.blueprint.entrypoint import get_logger
+from databricks.labs.blueprint.installation import Installation
 from databricks.sdk import WorkspaceClient
 
 from databricks.labs.remorph.config import MorphConfig
@@ -29,6 +30,10 @@ def transpile(
 ):
     """transpiles source dialect to databricks dialect"""
     logger.info(f"user: {w.current_user.me()}")
+    installation = Installation.current(w, 'remorph')
+    default_config = installation.load(MorphConfig)
+
+    # TODO refactor cli based on the default config
 
     if source.lower() not in {"snowflake", "tsql"}:
         raise_validation_exception(
@@ -37,20 +42,21 @@ def transpile(
     if not os.path.exists(input_sql) or input_sql in {None, ""}:
         raise_validation_exception(f"Error: Invalid value for '--input_sql': Path '{input_sql}' does not exist.")
     if output_folder == "":
-        output_folder = None
+        output_folder = default_config.output_folder if default_config.output_folder else None
     if skip_validation.lower() not in {"true", "false"}:
         raise_validation_exception(
             f"Error: Invalid value for '--skip_validation': '{skip_validation}' is not one of 'true', 'false'. "
         )
 
+    sdk_config = default_config.sdk_config if default_config.sdk_config else None
     config = MorphConfig(
         source=source.lower(),
         input_sql=input_sql,
         output_folder=output_folder,
         skip_validation=skip_validation.lower() == "true",  # convert to bool
         catalog_name=catalog_name,
         schema_name=schema_name,
-        sdk_config=w.config,
+        sdk_config=sdk_config,
     )
 
     status = morph(w, config)

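After this change the transpile command no longer passes w.config straight through; output_folder and sdk_config fall back to the configuration saved at install time. Below is a minimal sketch of that fallback order, assuming the 'remorph' product name used in the diff above; resolve_transpile_config is a hypothetical helper, not code from this commit.

from databricks.labs.blueprint.installation import Installation
from databricks.sdk import WorkspaceClient

from databricks.labs.remorph.config import MorphConfig


def resolve_transpile_config(
    w: WorkspaceClient,
    source: str,
    input_sql: str,
    output_folder: str | None,
    skip_validation: bool,
    catalog_name: str,
    schema_name: str,
) -> MorphConfig:
    # hypothetical helper: CLI flags are passed through, while output_folder and
    # sdk_config fall back to whatever the install script saved for this user
    installation = Installation.current(w, "remorph")  # locate the per-user 'remorph' installation
    default_config = installation.load(MorphConfig)    # MorphConfig previously saved by install.py

    return MorphConfig(
        source=source.lower(),
        input_sql=input_sql,
        output_folder=output_folder or default_config.output_folder,
        skip_validation=skip_validation,
        catalog_name=catalog_name,
        schema_name=schema_name,
        sdk_config=default_config.sdk_config,  # plain dict, e.g. {"warehouse_id": "..."}
    )
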
src/databricks/labs/remorph/config.py (+1 -3)

@@ -1,8 +1,6 @@
 import logging
 from dataclasses import dataclass
 
-from databricks.sdk.core import Config
-
 logger = logging.getLogger(__name__)
 
 
@@ -12,7 +10,7 @@ class MorphConfig:
     __version__ = 1
 
     source: str
-    sdk_config: Config | None
+    sdk_config: dict[str, str] | None
     input_sql: str | None = None
     output_folder: str | None = None
     skip_validation: bool = False

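sdk_config is now a plain dict[str, str] rather than a databricks.sdk.core.Config object, which serialises cleanly when the installer saves the configuration. A small illustrative construction follows; every value is a placeholder.

from databricks.labs.remorph.config import MorphConfig

config = MorphConfig(
    source="snowflake",
    sdk_config={"warehouse_id": "abc123"},  # or {"cluster_id": "0123-456789-abcdef"}, or None
    input_sql="/path/to/queries",
    output_folder="/path/to/output",
    skip_validation=False,
    catalog_name="transpiler_test",
    schema_name="convertor_test",
)
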
src/databricks/labs/remorph/helpers/db_sql.py (+5 -2)

@@ -16,13 +16,16 @@
 
 
 def get_sql_backend(ws: WorkspaceClient, config: MorphConfig) -> SqlBackend:
-    sdk_config = ws.config
-    warehouse_id = isinstance(sdk_config.warehouse_id, str) and sdk_config.warehouse_id
+    sdk_config = config.sdk_config
+    warehouse_id = sdk_config.get("warehouse_id", None) if sdk_config else None
+    cluster_id = sdk_config.get("cluster_id", None) if sdk_config else None
     catalog_name = config.catalog_name
     schema_name = config.schema_name
     if warehouse_id:
         sql_backend = StatementExecutionBackend(ws, warehouse_id, catalog=catalog_name, schema=schema_name)
     else:
+        # assigning cluster id explicitly to the config as user can provide them during installation
+        ws.config.cluster_id = cluster_id if cluster_id else ws.config.cluster_id
         sql_backend = RuntimeBackend() if "DATABRICKS_RUNTIME_VERSION" in os.environ else DatabricksConnectBackend(ws)
     try:
         sql_backend.execute(f"use catalog {catalog_name}")

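Backend selection is now driven by the values captured at install time: a warehouse_id selects the statement-execution backend, otherwise validation runs on a cluster, either in-notebook or through Databricks Connect. A condensed sketch of that decision is below; pick_backend_kind is illustrative only and not part of the commit.

import os


def pick_backend_kind(sdk_config: dict[str, str] | None) -> str:
    # mirrors the branching in get_sql_backend above, returning only the chosen kind
    warehouse_id = sdk_config.get("warehouse_id") if sdk_config else None
    if warehouse_id:
        return "StatementExecutionBackend"  # validate on the configured SQL warehouse
    if "DATABRICKS_RUNTIME_VERSION" in os.environ:
        return "RuntimeBackend"             # already running on a Databricks cluster
    return "DatabricksConnectBackend"       # local run; cluster_id from sdk_config (or ws.config) is used

Note that when a cluster_id is present it is copied onto ws.config before the Databricks Connect backend is built, per the comment in the diff above.
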
src/databricks/labs/remorph/install.py (+60 -12)

@@ -1,5 +1,6 @@
 import logging
 import os
+import time
 import webbrowser
 from datetime import timedelta
 from pathlib import Path
@@ -12,13 +13,19 @@
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
 from databricks.sdk.retries import retried
+from databricks.sdk.service.sql import (
+    CreateWarehouseRequestWarehouseType,
+    EndpointInfoWarehouseType,
+    SpotInstancePolicy,
+)
 
 from databricks.labs.remorph.__about__ import __version__
 from databricks.labs.remorph.config import MorphConfig
 
 logger = logging.getLogger(__name__)
 
 PRODUCT_INFO = ProductInfo(__file__)
+WAREHOUSE_PREFIX = "Remorph Transpiler Validation"
 
 
 class WorkspaceInstaller:
@@ -50,38 +57,79 @@ def configure(self) -> MorphConfig:
             logger.debug(f"Cannot find previous installation: {err}")
         logger.info("Please answer a couple of questions to configure Remorph")
 
+        # default params
+        catalog_name = "transpiler_test"
+        schema_name = "convertor_test"
+        ws_config = None
+
         source_prompt = self._prompts.choice("Select the source", ["snowflake", "tsql"])
         source = source_prompt.lower()
 
         skip_validation = self._prompts.confirm("Do you want to Skip Validation")
 
-        catalog_name = self._prompts.question("Enter catalog_name")
-
-        try:
-            self._catalog_setup.get(catalog_name)
-        except NotFound:
-            self.setup_catalog(catalog_name)
+        if not skip_validation:
+            ws_config = self._configure_runtime()
+            catalog_name = self._prompts.question("Enter catalog_name")
+            try:
+                self._catalog_setup.get(catalog_name)
+            except NotFound:
+                self.setup_catalog(catalog_name)
 
-        schema_name = self._prompts.question("Enter schema_name")
+            schema_name = self._prompts.question("Enter schema_name")
 
-        try:
-            self._catalog_setup.get_schema(f"{catalog_name}.{schema_name}")
-        except NotFound:
-            self.setup_schema(catalog_name, schema_name)
+            try:
+                self._catalog_setup.get_schema(f"{catalog_name}.{schema_name}")
+            except NotFound:
+                self.setup_schema(catalog_name, schema_name)
 
         config = MorphConfig(
             source=source,
             skip_validation=skip_validation,
             catalog_name=catalog_name,
            schema_name=schema_name,
-            sdk_config=None,
+            sdk_config=ws_config,
         )
 
         ws_file_url = self._installation.save(config)
         if self._prompts.confirm("Open config file in the browser and continue installing?"):
             webbrowser.open(ws_file_url)
         return config
 
+    def _configure_runtime(self) -> dict[str, str]:
+        if self._prompts.confirm("Do you want to use SQL Warehouse for validation?"):
+            warehouse_id = self._configure_warehouse()
+            return {"warehouse_id": warehouse_id}
+
+        if self._ws.config.cluster_id:
+            logger.info(f"Using cluster {self._ws.config.cluster_id} for validation")
+            return {"cluster": self._ws.config.cluster_id}
+
+        cluster_id = self._prompts.question("Enter a valid cluster_id to proceed")
+        return {"cluster": cluster_id}
+
+    def _configure_warehouse(self):
+        def warehouse_type(_):
+            return _.warehouse_type.value if not _.enable_serverless_compute else "SERVERLESS"
+
+        pro_warehouses = {"[Create new PRO SQL warehouse]": "create_new"} | {
+            f"{_.name} ({_.id}, {warehouse_type(_)}, {_.state.value})": _.id
+            for _ in self._ws.warehouses.list()
+            if _.warehouse_type == EndpointInfoWarehouseType.PRO
+        }
+        warehouse_id = self._prompts.choice_from_dict(
+            "Select PRO or SERVERLESS SQL warehouse to run validation on", pro_warehouses
+        )
+        if warehouse_id == "create_new":
+            new_warehouse = self._ws.warehouses.create(
+                name=f"{WAREHOUSE_PREFIX} {time.time_ns()}",
+                spot_instance_policy=SpotInstancePolicy.COST_OPTIMIZED,
+                warehouse_type=CreateWarehouseRequestWarehouseType.PRO,
+                cluster_size="Small",
+                max_num_clusters=1,
+            )
+            warehouse_id = new_warehouse.id
+        return warehouse_id
+
     @retried(on=[NotFound], timeout=timedelta(minutes=5))
     def setup_catalog(self, catalog_name: str):
         allow_catalog_creation = self._prompts.confirm(

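The net effect on what configure() saves: with validation skipped the defaults are kept and no compute is recorded, while with validation enabled the installer insists on a warehouse or cluster and prompts for catalog and schema. Two illustrative MorphConfig values follow; all ids and names are placeholders.

from databricks.labs.remorph.config import MorphConfig

# skip_validation=True: defaults retained, no warehouse/cluster enforced
config_skip = MorphConfig(
    source="snowflake",
    skip_validation=True,
    catalog_name="transpiler_test",  # default from configure()
    schema_name="convertor_test",    # default from configure()
    sdk_config=None,
)

# skip_validation=False: a PRO/SERVERLESS warehouse or a cluster id is captured
config_validate = MorphConfig(
    source="snowflake",
    skip_validation=False,
    catalog_name="my_catalog",
    schema_name="my_schema",
    sdk_config={"warehouse_id": "1234567890abcdef"},  # or {"cluster": "0123-456789-abcdef"}
)
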
src/databricks/labs/remorph/transpiler/execute.py (+5 -4)

@@ -128,7 +128,10 @@ def morph(workspace_client: WorkspaceClient, config: MorphConfig):
     skip_validation = config.skip_validation
     status = []
     result = MorphStatus([], 0, 0, 0, [])
-    validator = Validator(db_sql.get_sql_backend(workspace_client, config))
+    validator = None
+    if not config.skip_validation:
+        validator = Validator(db_sql.get_sql_backend(workspace_client, config))
+
     if input_sql.is_file():
         if is_sql_file(input_sql):
             msg = f"Processing for sqls under this file: {input_sql}"
@@ -157,16 +160,14 @@ def morph(workspace_client: WorkspaceClient, config: MorphConfig):
     validate_error_count = result.validate_error_count
 
     error_list_count = parse_error_count + validate_error_count
-
     if not skip_validation:
         logger.info(f"No of Sql Failed while Validating: {validate_error_count}")
 
+    error_log_file = "None"
     if error_list_count > 0:
         error_log_file = Path.cwd() / f"err_{os.getpid()}.lst"
         with error_log_file.open("a") as e:
             e.writelines(f"{err}\n" for err in result.error_log_list)
-    else:
-        error_log_file = "None"
 
     status.append(
         {

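Because the Validator is now built only when validation is requested, morph() can run without any warehouse or cluster at all when skip_validation is true. A hedged usage sketch: WorkspaceClient() assumes ambient Databricks authentication, and the paths are placeholders.

from databricks.sdk import WorkspaceClient

from databricks.labs.remorph.config import MorphConfig
from databricks.labs.remorph.transpiler.execute import morph

w = WorkspaceClient()
config = MorphConfig(
    source="snowflake",
    input_sql="/tmp/queries",
    output_folder="/tmp/transpiled",
    skip_validation=True,  # Validator (and get_sql_backend) is never constructed
    catalog_name="transpiler_test",
    schema_name="convertor_test",
    sdk_config=None,
)
status = morph(w, config)  # returns the per-run status list assembled at the end of morph()
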
tests/unit/conftest.py (+3 -3)

@@ -33,10 +33,10 @@ def mock_workspace_client():
     yield client
 
 
-@pytest.fixture(scope="session")
-def morph_config(mock_databricks_config):
+@pytest.fixture()
+def morph_config():
     yield MorphConfig(
-        sdk_config=mock_databricks_config,
+        sdk_config={"cluster_id": "test_cluster"},
         source="snowflake",
         input_sql="input_sql",
         output_folder="output_folder",

tests/unit/helpers/test_db_sql.py (+23 -19)

@@ -6,64 +6,68 @@
 from databricks.labs.remorph.helpers.db_sql import get_sql_backend
 
 
-@pytest.mark.usefixtures("mock_workspace_client", "morph_config")
+@pytest.fixture()
+def morph_config_sqlbackend(morph_config):
+    return morph_config
+
+
 @patch('databricks.labs.remorph.helpers.db_sql.StatementExecutionBackend')
 def test_get_sql_backend_with_warehouse_id(
     stmt_execution_backend,
     mock_workspace_client,
-    morph_config,
+    morph_config_sqlbackend,
 ):
-    mock_workspace_client.config.warehouse_id = "test_warehouse_id"
-    sql_backend = get_sql_backend(mock_workspace_client, morph_config)
+    morph_config_sqlbackend.sdk_config = {"warehouse_id": "test_warehouse_id"}
+    sql_backend = get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
     stmt_execution_backend.assert_called_once_with(
         mock_workspace_client,
         "test_warehouse_id",
-        catalog=morph_config.catalog_name,
-        schema=morph_config.schema_name,
+        catalog=morph_config_sqlbackend.catalog_name,
+        schema=morph_config_sqlbackend.schema_name,
     )
     assert isinstance(sql_backend, stmt_execution_backend.return_value.__class__)
 
 
-@pytest.mark.usefixtures("mock_workspace_client", "morph_config")
 @patch('databricks.labs.remorph.helpers.db_sql.DatabricksConnectBackend')
 def test_get_sql_backend_without_warehouse_id(
     databricks_connect_backend,
     mock_workspace_client,
-    morph_config,
+    morph_config_sqlbackend,
 ):
     mock_dbc_backend_instance = databricks_connect_backend.return_value
-    sql_backend = get_sql_backend(mock_workspace_client, morph_config)
+    # morph config mock object has cluster id
+    sql_backend = get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
     databricks_connect_backend.assert_called_once_with(mock_workspace_client)
-    mock_dbc_backend_instance.execute.assert_any_call(f"use catalog {morph_config.catalog_name}")
-    mock_dbc_backend_instance.execute.assert_any_call(f"use {morph_config.schema_name}")
+    mock_dbc_backend_instance.execute.assert_any_call(f"use catalog {morph_config_sqlbackend.catalog_name}")
+    mock_dbc_backend_instance.execute.assert_any_call(f"use {morph_config_sqlbackend.schema_name}")
     assert isinstance(sql_backend, databricks_connect_backend.return_value.__class__)
 
 
-@pytest.mark.usefixtures("mock_workspace_client", "morph_config", "monkeypatch")
+@pytest.mark.usefixtures("monkeypatch")
 @patch('databricks.labs.remorph.helpers.db_sql.RuntimeBackend')
 def test_get_sql_backend_without_warehouse_id_in_notebook(
     runtime_backend,
     mock_workspace_client,
-    morph_config,
+    morph_config_sqlbackend,
     monkeypatch,
 ):
     monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3")
     mock_runtime_backend_instance = runtime_backend.return_value
-    sql_backend = get_sql_backend(mock_workspace_client, morph_config)
+    morph_config_sqlbackend.sdk_config = None
+    sql_backend = get_sql_backend(mock_workspace_client, morph_config_sqlbackend)
     runtime_backend.assert_called_once()
-    mock_runtime_backend_instance.execute.assert_any_call(f"use catalog {morph_config.catalog_name}")
-    mock_runtime_backend_instance.execute.assert_any_call(f"use {morph_config.schema_name}")
+    mock_runtime_backend_instance.execute.assert_any_call(f"use catalog {morph_config_sqlbackend.catalog_name}")
+    mock_runtime_backend_instance.execute.assert_any_call(f"use {morph_config_sqlbackend.schema_name}")
     assert isinstance(sql_backend, runtime_backend.return_value.__class__)
 
 
-@pytest.mark.usefixtures("mock_workspace_client", "morph_config")
 @patch('databricks.labs.remorph.helpers.db_sql.DatabricksConnectBackend')
 def test_get_sql_backend_with_error(
     databricks_connect_backend,
     mock_workspace_client,
-    morph_config,
+    morph_config_sqlbackend,
 ):
     mock_dbc_backend_instance = databricks_connect_backend.return_value
     mock_dbc_backend_instance.execute.side_effect = DatabricksError("Test error")
     with pytest.raises(DatabricksError):
-        get_sql_backend(mock_workspace_client, morph_config)
+        get_sql_backend(mock_workspace_client, morph_config_sqlbackend)

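In the same spirit as the updated tests, here is a hedged sketch of one more case that exercises the cluster_id path through sdk_config. This test is illustrative only, not part of the commit, and it assumes the mock_workspace_client and morph_config fixtures behave as in conftest.py above (a mocked client whose config attributes can be set and read back).

from unittest.mock import patch

from databricks.labs.remorph.helpers.db_sql import get_sql_backend


@patch('databricks.labs.remorph.helpers.db_sql.DatabricksConnectBackend')
def test_get_sql_backend_copies_cluster_id(databricks_connect_backend, mock_workspace_client, morph_config):
    morph_config.sdk_config = {"cluster_id": "another_cluster"}
    get_sql_backend(mock_workspace_client, morph_config)
    # the cluster id captured at install time is pushed onto the workspace client config
    assert mock_workspace_client.config.cluster_id == "another_cluster"
    databricks_connect_backend.assert_called_once_with(mock_workspace_client)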