Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 992651d

Browse files
authored
Merge pull request #467 from dlawin/issue_460
expand --cloud output by polling for results
2 parents b9cae0f + f7d5d8c commit 992651d

File tree

5 files changed

+283
-172
lines changed

5 files changed

+283
-172
lines changed

data_diff/dbt.py

Lines changed: 147 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dataclasses import dataclass
1010
from packaging.version import parse as parse_version
1111
from typing import List, Optional, Dict, Tuple, Set
12-
from .utils import getLogger
12+
from .utils import dbt_diff_string_template, getLogger
1313
from .version import __version__
1414
from pathlib import Path
1515

@@ -69,16 +69,16 @@ class DiffVars:
6969
dev_path: List[str]
7070
prod_path: List[str]
7171
primary_keys: List[str]
72-
datasource_id: str
7372
connection: Dict[str, str]
7473
threads: Optional[int]
7574

7675

7776
def dbt_diff(
7877
profiles_dir_override: Optional[str] = None, project_dir_override: Optional[str] = None, is_cloud: bool = False
7978
) -> None:
79+
diff_threads = []
8080
set_entrypoint_name("CLI-dbt")
81-
dbt_parser = DbtParser(profiles_dir_override, project_dir_override, is_cloud)
81+
dbt_parser = DbtParser(profiles_dir_override, project_dir_override)
8282
models = dbt_parser.get_models()
8383
datadiff_variables = dbt_parser.get_datadiff_variables()
8484
config_prod_database = datadiff_variables.get("prod_database")
@@ -89,7 +89,17 @@ def dbt_diff(
8989
custom_schemas = True if custom_schemas is None else custom_schemas
9090
set_dbt_user_id(dbt_parser.dbt_user_id)
9191

92-
if not is_cloud:
92+
if is_cloud:
93+
if datasource_id is None:
94+
raise ValueError(
95+
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \nvars:\n data_diff:\n datasource_id: 1234"
96+
)
97+
datafold_host, url, api_key = _setup_cloud_diff()
98+
99+
# exit so the user can set the key
100+
if not api_key:
101+
return
102+
else:
93103
dbt_parser.set_connection()
94104

95105
if config_prod_database is None:
@@ -98,14 +108,14 @@ def dbt_diff(
98108
)
99109

100110
for model in models:
101-
diff_vars = _get_diff_vars(
102-
dbt_parser, config_prod_database, config_prod_schema, model, datasource_id, custom_schemas
103-
)
104-
105-
if is_cloud and len(diff_vars.primary_keys) > 0:
106-
_cloud_diff(diff_vars)
107-
elif not is_cloud and len(diff_vars.primary_keys) > 0:
108-
_local_diff(diff_vars)
111+
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, custom_schemas)
112+
113+
if diff_vars.primary_keys:
114+
if is_cloud:
115+
diff_thread = run_as_daemon(_cloud_diff, diff_vars, datasource_id, datafold_host, url, api_key)
116+
diff_threads.append(diff_thread)
117+
else:
118+
_local_diff(diff_vars)
109119
else:
110120
rich.print(
111121
"[red]"
@@ -116,6 +126,11 @@ def dbt_diff(
116126
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n"
117127
)
118128

129+
# wait for all threads
130+
if diff_threads:
131+
for thread in diff_threads:
132+
thread.join()
133+
119134
rich.print("Diffs Complete!")
120135

121136

@@ -124,7 +139,6 @@ def _get_diff_vars(
124139
config_prod_database: Optional[str],
125140
config_prod_schema: Optional[str],
126141
model,
127-
datasource_id: int,
128142
custom_schemas: bool,
129143
) -> DiffVars:
130144
dev_database = model.database
@@ -149,9 +163,7 @@ def _get_diff_vars(
149163
dev_qualified_list = [dev_database, dev_schema, model.alias]
150164
prod_qualified_list = [prod_database, prod_schema, model.alias]
151165

152-
return DiffVars(
153-
dev_qualified_list, prod_qualified_list, primary_keys, datasource_id, dbt_parser.connection, dbt_parser.threads
154-
)
166+
return DiffVars(dev_qualified_list, prod_qualified_list, primary_keys, dbt_parser.connection, dbt_parser.threads)
155167

156168

157169
def _local_diff(diff_vars: DiffVars) -> None:
@@ -221,33 +233,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
221233
)
222234

223235

224-
def _cloud_diff(diff_vars: DiffVars) -> None:
225-
datafold_host = os.environ.get("DATAFOLD_HOST")
226-
if datafold_host is None:
227-
datafold_host = "https://app.datafold.com"
228-
datafold_host = datafold_host.rstrip("/")
229-
rich.print(f"Cloud datafold host: {datafold_host}")
230-
231-
api_key = os.environ.get("DATAFOLD_API_KEY")
232-
if not api_key:
233-
rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
234-
yes_or_no = Confirm.ask("Would you like to generate a new API key?")
235-
if yes_or_no:
236-
webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
237-
return
238-
else:
239-
raise ValueError("Cannot diff because the API key is not provided")
240-
241-
if diff_vars.datasource_id is None:
242-
raise ValueError(
243-
"Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \nvars:\n data_diff:\n datasource_id: 1234"
244-
)
245-
246-
url = f"{datafold_host}/api/v1/datadiffs"
247-
236+
def _cloud_diff(diff_vars: DiffVars, datasource_id: int, datafold_host: str, url: str, api_key: str) -> None:
248237
payload = {
249-
"data_source1_id": diff_vars.datasource_id,
250-
"data_source2_id": diff_vars.datasource_id,
238+
"data_source1_id": datasource_id,
239+
"data_source2_id": datasource_id,
251240
"table1": diff_vars.prod_path,
252241
"table2": diff_vars.dev_path,
253242
"pk_columns": diff_vars.primary_keys,
@@ -258,27 +247,60 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
258247
"Content-Type": "application/json",
259248
}
260249
if is_tracking_enabled():
261-
event_json = create_start_event_json({"is_cloud": True, "datasource_id": diff_vars.datasource_id})
250+
event_json = create_start_event_json({"is_cloud": True, "datasource_id": datasource_id})
262251
run_as_daemon(send_event_json, event_json)
263252

264253
start = time.monotonic()
265254
error = None
266255
diff_id = None
256+
diff_url = None
267257
try:
268-
response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
269-
response.raise_for_status()
270-
data = response.json()
271-
diff_id = data["id"]
258+
diff_id = _cloud_submit_diff(url, payload, headers)
259+
summary_url = f"{url}/{diff_id}/summary_results"
260+
diff_results = _cloud_poll_and_get_summary_results(summary_url, headers)
261+
272262
diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
273-
rich.print(
274-
"[red]"
275-
+ ".".join(diff_vars.prod_path)
276-
+ " <> "
277-
+ ".".join(diff_vars.dev_path)
278-
+ "[/] \n Diff in progress: \n "
279-
+ diff_url
280-
+ "\n"
281-
)
263+
264+
rows_added_count = diff_results["pks"]["exclusives"][1]
265+
rows_removed_count = diff_results["pks"]["exclusives"][0]
266+
267+
rows_updated = diff_results["values"]["rows_with_differences"]
268+
total_rows = diff_results["values"]["total_rows"]
269+
rows_unchanged = int(total_rows) - int(rows_updated)
270+
diff_percent_list = {
271+
x["column_name"]: str(x["match"]) + "%"
272+
for x in diff_results["values"]["columns_diff_stats"]
273+
if x["match"] != 100.0
274+
}
275+
276+
if any([rows_added_count, rows_removed_count, rows_updated]):
277+
diff_output = dbt_diff_string_template(
278+
rows_added_count,
279+
rows_removed_count,
280+
rows_updated,
281+
str(rows_unchanged),
282+
diff_percent_list,
283+
"Value Match Percent:",
284+
)
285+
rich.print(
286+
"[red]"
287+
+ ".".join(diff_vars.prod_path)
288+
+ " <> "
289+
+ ".".join(diff_vars.dev_path)
290+
+ f"[/]\n{diff_url}\n"
291+
+ diff_output
292+
+ "\n"
293+
)
294+
else:
295+
rich.print(
296+
"[red]"
297+
+ ".".join(diff_vars.prod_path)
298+
+ " <> "
299+
+ ".".join(diff_vars.dev_path)
300+
+ f"[/]\n{diff_url}\n"
301+
+ "[green]No row differences[/] \n"
302+
)
303+
282304
except BaseException as ex: # Catch KeyboardInterrupt too
283305
error = ex
284306
finally:
@@ -302,15 +324,81 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
302324
send_event_json(event_json)
303325

304326
if error:
305-
raise error
327+
rich.print(
328+
"[red]"
329+
+ ".".join(diff_vars.prod_path)
330+
+ " <> "
331+
+ ".".join(diff_vars.dev_path) + "[/]\n"
332+
)
333+
if diff_id:
334+
diff_url = f"{datafold_host}/datadiffs/{diff_id}/overview"
335+
rich.print(f"{diff_url} \n")
336+
logger.error(error)
337+
338+
339+
def _setup_cloud_diff() -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Resolve the Datafold host, the diff API URL and the API key from the environment.

    The host is read from ``DATAFOLD_HOST`` (falling back to
    ``https://app.datafold.com``) with trailing slashes stripped; the API key
    is read from ``DATAFOLD_API_KEY``.

    Returns:
        ``(datafold_host, url, api_key)`` when an API key is available.
        ``(None, None, None)`` when no key is set and the user agreed to
        generate one: the Datafold login page is opened in a browser and the
        caller is expected to stop so the user can set the key.

    Raises:
        ValueError: if no API key is set and the user declines to generate one.
    """
    datafold_host = os.environ.get("DATAFOLD_HOST")
    if datafold_host is None:
        datafold_host = "https://app.datafold.com"
    datafold_host = datafold_host.rstrip("/")
    rich.print(f"Cloud datafold host: {datafold_host}\n")
    url = f"{datafold_host}/api/v1/datadiffs"

    api_key = os.environ.get("DATAFOLD_API_KEY")
    if api_key:
        return datafold_host, url, api_key

    # No key configured: offer to open the page where one can be generated.
    rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
    if not Confirm.ask("Would you like to generate a new API key?"):
        raise ValueError("Cannot diff because the API key is not provided")

    webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
    return None, None, None
358+
359+
360+
def _cloud_submit_diff(url, payload, headers) -> str:
    """Submit a diff request and return the id of the created diff.

    Args:
        url: Endpoint the diff request is POSTed to.
        payload: JSON-serialisable request body.
        headers: HTTP headers sent with the request.

    Returns:
        The ``id`` field of the JSON response, converted to ``str``.

    Raises:
        requests.HTTPError: if the response status indicates an error.
        Exception: if the JSON response has no ``id`` (missing or null).
    """
    response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    response_json = response.json()

    # Check for a missing/null id *before* converting to str: str(None) would
    # yield the truthy string "None" and the check below could never fire.
    diff_id = response_json.get("id")
    if diff_id is None:
        raise Exception(f"Api response did not contain a diff_id: {str(response_json)}")
    return str(diff_id)
369+
370+
371+
def _cloud_poll_and_get_summary_results(url, headers):
    """Poll the summary-results endpoint until the diff succeeds or fails.

    The endpoint is queried repeatedly, sleeping between attempts with an
    interval that starts at 5 seconds and doubles up to 60 seconds.

    Args:
        url: Summary-results URL to GET.
        headers: HTTP headers sent with each request.

    Returns:
        The decoded JSON response whose ``status`` is ``"success"``.

    Raises:
        requests.HTTPError: if a response status indicates an error.
        Exception: if the response ``status`` is ``"failed"``, or if more than
            300 seconds elapse without a successful result.
    """
    start_time = time.time()
    sleep_interval = 5  # starts at 5 sec
    max_sleep_interval = 60
    max_wait_time = 300

    while True:
        response = requests.request("GET", url, headers=headers, timeout=30)
        response.raise_for_status()
        response_json = response.json()

        status = response_json["status"]
        if status == "success":
            # Return right away: results are ready, so neither the timeout
            # check nor another sleep should run.
            return response_json
        if status == "failed":
            raise Exception(f"Diff failed: {str(response_json)}")

        if time.time() - start_time > max_wait_time:
            raise Exception("Timed out waiting for diff results")

        time.sleep(sleep_interval)
        sleep_interval = min(sleep_interval * 2, max_sleep_interval)
306395

307396

308397
class DbtParser:
309-
def __init__(self, profiles_dir_override: str, project_dir_override: str, is_cloud: bool) -> None:
398+
def __init__(self, profiles_dir_override: str, project_dir_override: str) -> None:
310399
self.parse_run_results, self.parse_manifest, self.ProfileRenderer, self.yaml = import_dbt()
311400
self.profiles_dir = Path(profiles_dir_override or default_profiles_dir())
312401
self.project_dir = Path(project_dir_override or default_project_dir())
313-
self.is_cloud = is_cloud
314402
self.connection = None
315403
self.project_dict = self.get_project_dict()
316404
self.manifest_obj = self.get_manifest_obj()

data_diff/diff_tables.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from data_diff.info_tree import InfoTree, SegmentInfo
1616

17-
from .utils import run_as_daemon, safezip, getLogger, truncate_error, Vector
17+
from .utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector
1818
from .thread_utils import ThreadedYielder
1919
from .table_segment import TableSegment, create_mesh_from_points
2020
from .tracking import create_end_event_json, create_start_event_json, send_event_json, is_tracking_enabled
@@ -139,18 +139,14 @@ def get_stats_string(self, is_dbt: bool = False):
139139
diff_stats = self._get_stats(is_dbt)
140140

141141
if is_dbt:
142-
string_output = "\n| Rows Added\t| Rows Removed\n"
143-
string_output += "------------------------------------------------------------\n"
144-
145-
string_output += f"| {diff_stats.diff_by_sign['-']}\t\t| {diff_stats.diff_by_sign['+']}\n"
146-
string_output += "------------------------------------------------------------\n\n"
147-
string_output += f"Updated Rows: {diff_stats.diff_by_sign['!']}\n"
148-
string_output += f"Unchanged Rows: {diff_stats.unchanged}\n\n"
149-
150-
string_output += f"Values Updated:"
151-
152-
for k, v in diff_stats.extra_column_diffs.items():
153-
string_output += f"\n{k}: {v}"
142+
string_output = dbt_diff_string_template(
143+
diff_stats.diff_by_sign["-"],
144+
diff_stats.diff_by_sign["+"],
145+
diff_stats.diff_by_sign["!"],
146+
diff_stats.unchanged,
147+
diff_stats.extra_column_diffs,
148+
"Values Updated:",
149+
)
154150

155151
else:
156152
string_output = ""

data_diff/utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,22 @@ def __sub__(self, other: "Vector"):
127127

128128
def __repr__(self) -> str:
129129
return "(%s)" % ", ".join(str(k) for k in self)
130+
131+
132+
def dbt_diff_string_template(
    rows_added: str, rows_removed: str, rows_updated: str, rows_unchanged: str, extra_info_dict: Dict, extra_info_str
) -> str:
    """Render the dbt diff summary as a plain-text report.

    The report consists of a two-column "Rows Added / Rows Removed" table,
    the updated and unchanged row counts, the ``extra_info_str`` heading and
    one ``key: value`` line per entry of ``extra_info_dict``.

    Args:
        rows_added: Value shown in the "Rows Added" column.
        rows_removed: Value shown in the "Rows Removed" column.
        rows_updated: Value shown after "Updated Rows:".
        rows_unchanged: Value shown after "Unchanged Rows:".
        extra_info_dict: Mapping rendered line by line after the heading.
        extra_info_str: Heading printed before the mapping entries.

    Returns:
        The assembled report string.
    """
    rule = "------------------------------------------------------------\n"
    fragments = [
        "\n| Rows Added\t| Rows Removed\n",
        rule,
        f"| {rows_added}\t\t| {rows_removed}\n",
        rule,
        "\n",
        f"Updated Rows: {rows_updated}\n",
        f"Unchanged Rows: {rows_unchanged}\n\n",
        extra_info_str,
    ]
    fragments.extend(f"\n{name}: {value}" for name, value in extra_info_dict.items())
    return "".join(fragments)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,6 @@ build-backend = "poetry.core.masonry.api"
8080

8181
[tool.poetry.scripts]
8282
data-diff = 'data_diff.__main__:main'
83+
84+
[tool.black]
85+
line-length = 120

0 commit comments

Comments
 (0)