|
1 | 1 | import logging
|
2 | 2 | import os
|
| 3 | +import time |
3 | 4 | import webbrowser
|
4 | 5 | from datetime import timedelta
|
5 | 6 | from pathlib import Path
|
|
12 | 13 | from databricks.sdk import WorkspaceClient
|
13 | 14 | from databricks.sdk.errors import NotFound
|
14 | 15 | from databricks.sdk.retries import retried
|
| 16 | +from databricks.sdk.service.sql import ( |
| 17 | + CreateWarehouseRequestWarehouseType, |
| 18 | + EndpointInfoWarehouseType, |
| 19 | + SpotInstancePolicy, |
| 20 | +) |
15 | 21 |
|
16 | 22 | from databricks.labs.remorph.__about__ import __version__
|
17 | 23 | from databricks.labs.remorph.config import MorphConfig
|
18 | 24 |
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Product metadata object built from this file's path.
# NOTE(review): the import of ProductInfo is not visible in this excerpt — confirm where it comes from.
PRODUCT_INFO = ProductInfo(__file__)
# Name prefix for SQL warehouses created by the installer; _configure_warehouse
# appends a nanosecond timestamp to it when creating a new warehouse.
WAREHOUSE_PREFIX = "Remorph Transpiler Validation"
22 | 29 |
|
23 | 30 |
|
24 | 31 | class WorkspaceInstaller:
|
@@ -50,38 +57,79 @@ def configure(self) -> MorphConfig:
|
50 | 57 | logger.debug(f"Cannot find previous installation: {err}")
|
51 | 58 | logger.info("Please answer a couple of questions to configure Remorph")
|
52 | 59 |
|
| 60 | + # default params |
| 61 | + catalog_name = "transpiler_test" |
| 62 | + schema_name = "convertor_test" |
| 63 | + ws_config = None |
| 64 | + |
53 | 65 | source_prompt = self._prompts.choice("Select the source", ["snowflake", "tsql"])
|
54 | 66 | source = source_prompt.lower()
|
55 | 67 |
|
56 | 68 | skip_validation = self._prompts.confirm("Do you want to Skip Validation")
|
57 | 69 |
|
58 |
| - catalog_name = self._prompts.question("Enter catalog_name") |
59 |
| - |
60 |
| - try: |
61 |
| - self._catalog_setup.get(catalog_name) |
62 |
| - except NotFound: |
63 |
| - self.setup_catalog(catalog_name) |
| 70 | + if not skip_validation: |
| 71 | + ws_config = self._configure_runtime() |
| 72 | + catalog_name = self._prompts.question("Enter catalog_name") |
| 73 | + try: |
| 74 | + self._catalog_setup.get(catalog_name) |
| 75 | + except NotFound: |
| 76 | + self.setup_catalog(catalog_name) |
64 | 77 |
|
65 |
| - schema_name = self._prompts.question("Enter schema_name") |
| 78 | + schema_name = self._prompts.question("Enter schema_name") |
66 | 79 |
|
67 |
| - try: |
68 |
| - self._catalog_setup.get_schema(f"{catalog_name}.{schema_name}") |
69 |
| - except NotFound: |
70 |
| - self.setup_schema(catalog_name, schema_name) |
| 80 | + try: |
| 81 | + self._catalog_setup.get_schema(f"{catalog_name}.{schema_name}") |
| 82 | + except NotFound: |
| 83 | + self.setup_schema(catalog_name, schema_name) |
71 | 84 |
|
72 | 85 | config = MorphConfig(
|
73 | 86 | source=source,
|
74 | 87 | skip_validation=skip_validation,
|
75 | 88 | catalog_name=catalog_name,
|
76 | 89 | schema_name=schema_name,
|
77 |
| - sdk_config=None, |
| 90 | + sdk_config=ws_config, |
78 | 91 | )
|
79 | 92 |
|
80 | 93 | ws_file_url = self._installation.save(config)
|
81 | 94 | if self._prompts.confirm("Open config file in the browser and continue installing?"):
|
82 | 95 | webbrowser.open(ws_file_url)
|
83 | 96 | return config
|
84 | 97 |
|
| 98 | + def _configure_runtime(self) -> dict[str, str]: |
| 99 | + if self._prompts.confirm("Do you want to use SQL Warehouse for validation?"): |
| 100 | + warehouse_id = self._configure_warehouse() |
| 101 | + return {"warehouse_id": warehouse_id} |
| 102 | + |
| 103 | + if self._ws.config.cluster_id: |
| 104 | + logger.info(f"Using cluster {self._ws.config.cluster_id} for validation") |
| 105 | + return {"cluster": self._ws.config.cluster_id} |
| 106 | + |
| 107 | + cluster_id = self._prompts.question("Enter a valid cluster_id to proceed") |
| 108 | + return {"cluster": cluster_id} |
| 109 | + |
| 110 | + def _configure_warehouse(self): |
| 111 | + def warehouse_type(_): |
| 112 | + return _.warehouse_type.value if not _.enable_serverless_compute else "SERVERLESS" |
| 113 | + |
| 114 | + pro_warehouses = {"[Create new PRO SQL warehouse]": "create_new"} | { |
| 115 | + f"{_.name} ({_.id}, {warehouse_type(_)}, {_.state.value})": _.id |
| 116 | + for _ in self._ws.warehouses.list() |
| 117 | + if _.warehouse_type == EndpointInfoWarehouseType.PRO |
| 118 | + } |
| 119 | + warehouse_id = self._prompts.choice_from_dict( |
| 120 | + "Select PRO or SERVERLESS SQL warehouse to run validation on", pro_warehouses |
| 121 | + ) |
| 122 | + if warehouse_id == "create_new": |
| 123 | + new_warehouse = self._ws.warehouses.create( |
| 124 | + name=f"{WAREHOUSE_PREFIX} {time.time_ns()}", |
| 125 | + spot_instance_policy=SpotInstancePolicy.COST_OPTIMIZED, |
| 126 | + warehouse_type=CreateWarehouseRequestWarehouseType.PRO, |
| 127 | + cluster_size="Small", |
| 128 | + max_num_clusters=1, |
| 129 | + ) |
| 130 | + warehouse_id = new_warehouse.id |
| 131 | + return warehouse_id |
| 132 | + |
85 | 133 | @retried(on=[NotFound], timeout=timedelta(minutes=5))
|
86 | 134 | def setup_catalog(self, catalog_name: str):
|
87 | 135 | allow_catalog_creation = self._prompts.confirm(
|
|
0 commit comments