9
9
from dataclasses import dataclass
10
10
from packaging .version import parse as parse_version
11
11
from typing import List , Optional , Dict , Tuple , Set
12
- from .utils import getLogger
12
+ from .utils import dbt_diff_string_template , getLogger
13
13
from .version import __version__
14
14
from pathlib import Path
15
15
@@ -69,16 +69,16 @@ class DiffVars:
69
69
dev_path : List [str ]
70
70
prod_path : List [str ]
71
71
primary_keys : List [str ]
72
- datasource_id : str
73
72
connection : Dict [str , str ]
74
73
threads : Optional [int ]
75
74
76
75
77
76
def dbt_diff (
78
77
profiles_dir_override : Optional [str ] = None , project_dir_override : Optional [str ] = None , is_cloud : bool = False
79
78
) -> None :
79
+ diff_threads = []
80
80
set_entrypoint_name ("CLI-dbt" )
81
- dbt_parser = DbtParser (profiles_dir_override , project_dir_override , is_cloud )
81
+ dbt_parser = DbtParser (profiles_dir_override , project_dir_override )
82
82
models = dbt_parser .get_models ()
83
83
datadiff_variables = dbt_parser .get_datadiff_variables ()
84
84
config_prod_database = datadiff_variables .get ("prod_database" )
@@ -89,7 +89,17 @@ def dbt_diff(
89
89
custom_schemas = True if custom_schemas is None else custom_schemas
90
90
set_dbt_user_id (dbt_parser .dbt_user_id )
91
91
92
- if not is_cloud :
92
+ if is_cloud :
93
+ if datasource_id is None :
94
+ raise ValueError (
95
+ "Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \n vars:\n data_diff:\n datasource_id: 1234"
96
+ )
97
+ datafold_host , url , api_key = _setup_cloud_diff ()
98
+
99
+ # exit so the user can set the key
100
+ if not api_key :
101
+ return
102
+ else :
93
103
dbt_parser .set_connection ()
94
104
95
105
if config_prod_database is None :
@@ -98,14 +108,14 @@ def dbt_diff(
98
108
)
99
109
100
110
for model in models :
101
- diff_vars = _get_diff_vars (
102
- dbt_parser , config_prod_database , config_prod_schema , model , datasource_id , custom_schemas
103
- )
104
-
105
- if is_cloud and len ( diff_vars . primary_keys ) > 0 :
106
- _cloud_diff ( diff_vars )
107
- elif not is_cloud and len ( diff_vars . primary_keys ) > 0 :
108
- _local_diff (diff_vars )
111
+ diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , custom_schemas )
112
+
113
+ if diff_vars . primary_keys :
114
+ if is_cloud :
115
+ diff_thread = run_as_daemon ( _cloud_diff , diff_vars , datasource_id , datafold_host , url , api_key )
116
+ diff_threads . append ( diff_thread )
117
+ else :
118
+ _local_diff (diff_vars )
109
119
else :
110
120
rich .print (
111
121
"[red]"
@@ -116,6 +126,11 @@ def dbt_diff(
116
126
+ "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
117
127
)
118
128
129
+ # wait for all threads
130
+ if diff_threads :
131
+ for thread in diff_threads :
132
+ thread .join ()
133
+
119
134
rich .print ("Diffs Complete!" )
120
135
121
136
@@ -124,7 +139,6 @@ def _get_diff_vars(
124
139
config_prod_database : Optional [str ],
125
140
config_prod_schema : Optional [str ],
126
141
model ,
127
- datasource_id : int ,
128
142
custom_schemas : bool ,
129
143
) -> DiffVars :
130
144
dev_database = model .database
@@ -149,9 +163,7 @@ def _get_diff_vars(
149
163
dev_qualified_list = [dev_database , dev_schema , model .alias ]
150
164
prod_qualified_list = [prod_database , prod_schema , model .alias ]
151
165
152
- return DiffVars (
153
- dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads
154
- )
166
+ return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , dbt_parser .connection , dbt_parser .threads )
155
167
156
168
157
169
def _local_diff (diff_vars : DiffVars ) -> None :
@@ -221,33 +233,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
221
233
)
222
234
223
235
224
- def _cloud_diff (diff_vars : DiffVars ) -> None :
225
- datafold_host = os .environ .get ("DATAFOLD_HOST" )
226
- if datafold_host is None :
227
- datafold_host = "https://app.datafold.com"
228
- datafold_host = datafold_host .rstrip ("/" )
229
- rich .print (f"Cloud datafold host: { datafold_host } " )
230
-
231
- api_key = os .environ .get ("DATAFOLD_API_KEY" )
232
- if not api_key :
233
- rich .print ("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY." )
234
- yes_or_no = Confirm .ask ("Would you like to generate a new API key?" )
235
- if yes_or_no :
236
- webbrowser .open (f"{ datafold_host } /login?next={ datafold_host } /users/me" )
237
- return
238
- else :
239
- raise ValueError ("Cannot diff because the API key is not provided" )
240
-
241
- if diff_vars .datasource_id is None :
242
- raise ValueError (
243
- "Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \n vars:\n data_diff:\n datasource_id: 1234"
244
- )
245
-
246
- url = f"{ datafold_host } /api/v1/datadiffs"
247
-
236
+ def _cloud_diff (diff_vars : DiffVars , datasource_id : int , datafold_host : str , url : str , api_key : str ) -> None :
248
237
payload = {
249
- "data_source1_id" : diff_vars . datasource_id ,
250
- "data_source2_id" : diff_vars . datasource_id ,
238
+ "data_source1_id" : datasource_id ,
239
+ "data_source2_id" : datasource_id ,
251
240
"table1" : diff_vars .prod_path ,
252
241
"table2" : diff_vars .dev_path ,
253
242
"pk_columns" : diff_vars .primary_keys ,
@@ -258,27 +247,60 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
258
247
"Content-Type" : "application/json" ,
259
248
}
260
249
if is_tracking_enabled ():
261
- event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : diff_vars . datasource_id })
250
+ event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : datasource_id })
262
251
run_as_daemon (send_event_json , event_json )
263
252
264
253
start = time .monotonic ()
265
254
error = None
266
255
diff_id = None
256
+ diff_url = None
267
257
try :
268
- response = requests . request ( "POST" , url , headers = headers , json = payload , timeout = 30 )
269
- response . raise_for_status ()
270
- data = response . json ( )
271
- diff_id = data [ "id" ]
258
+ diff_id = _cloud_submit_diff ( url , payload , headers )
259
+ summary_url = f" { url } / { diff_id } /summary_results"
260
+ diff_results = _cloud_poll_and_get_summary_results ( summary_url , headers )
261
+
272
262
diff_url = f"{ datafold_host } /datadiffs/{ diff_id } /overview"
273
- rich .print (
274
- "[red]"
275
- + "." .join (diff_vars .prod_path )
276
- + " <> "
277
- + "." .join (diff_vars .dev_path )
278
- + "[/] \n Diff in progress: \n "
279
- + diff_url
280
- + "\n "
281
- )
263
+
264
+ rows_added_count = diff_results ["pks" ]["exclusives" ][1 ]
265
+ rows_removed_count = diff_results ["pks" ]["exclusives" ][0 ]
266
+
267
+ rows_updated = diff_results ["values" ]["rows_with_differences" ]
268
+ total_rows = diff_results ["values" ]["total_rows" ]
269
+ rows_unchanged = int (total_rows ) - int (rows_updated )
270
+ diff_percent_list = {
271
+ x ["column_name" ]: str (x ["match" ]) + "%"
272
+ for x in diff_results ["values" ]["columns_diff_stats" ]
273
+ if x ["match" ] != 100.0
274
+ }
275
+
276
+ if any ([rows_added_count , rows_removed_count , rows_updated ]):
277
+ diff_output = dbt_diff_string_template (
278
+ rows_added_count ,
279
+ rows_removed_count ,
280
+ rows_updated ,
281
+ str (rows_unchanged ),
282
+ diff_percent_list ,
283
+ "Value Match Percent:" ,
284
+ )
285
+ rich .print (
286
+ "[red]"
287
+ + "." .join (diff_vars .prod_path )
288
+ + " <> "
289
+ + "." .join (diff_vars .dev_path )
290
+ + f"[/]\n { diff_url } \n "
291
+ + diff_output
292
+ + "\n "
293
+ )
294
+ else :
295
+ rich .print (
296
+ "[red]"
297
+ + "." .join (diff_vars .prod_path )
298
+ + " <> "
299
+ + "." .join (diff_vars .dev_path )
300
+ + f"[/]\n { diff_url } \n "
301
+ + "[green]No row differences[/] \n "
302
+ )
303
+
282
304
except BaseException as ex : # Catch KeyboardInterrupt too
283
305
error = ex
284
306
finally :
@@ -302,15 +324,81 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
302
324
send_event_json (event_json )
303
325
304
326
if error :
305
- raise error
327
+ rich .print (
328
+ "[red]"
329
+ + "." .join (diff_vars .prod_path )
330
+ + " <> "
331
+ + "." .join (diff_vars .dev_path ) + "[/]\n "
332
+ )
333
+ if diff_id :
334
+ diff_url = f"{ datafold_host } /datadiffs/{ diff_id } /overview"
335
+ rich .print (f"{ diff_url } \n " )
336
+ logger .error (error )
337
+
338
+
339
def _setup_cloud_diff() -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Resolve the Datafold host, the diff API URL and the API key from the environment.

    Reads ``DATAFOLD_HOST`` (defaulting to ``https://app.datafold.com``) and
    ``DATAFOLD_API_KEY``. When no API key is set, the user is asked whether
    they want to generate one; on "yes" the login page is opened in a browser.

    Returns:
        ``(datafold_host, url, api_key)`` when an API key is available, or
        ``(None, None, None)`` when the user was sent off to generate a key.

    Raises:
        ValueError: if no API key is set and the user declines to generate one.
    """
    # `or` instead of an `is None` check so an empty DATAFOLD_HOST also falls
    # back to the default rather than producing a host-less URL.
    datafold_host = (os.environ.get("DATAFOLD_HOST") or "https://app.datafold.com").rstrip("/")
    rich.print(f"Cloud datafold host: {datafold_host}\n")
    url = f"{datafold_host}/api/v1/datadiffs"

    api_key = os.environ.get("DATAFOLD_API_KEY")
    if api_key:
        return datafold_host, url, api_key

    rich.print("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY.")
    if not Confirm.ask("Would you like to generate a new API key?"):
        raise ValueError("Cannot diff because the API key is not provided")

    webbrowser.open(f"{datafold_host}/login?next={datafold_host}/users/me")
    # Nothing can be diffed until the key is exported; signal that with all-None.
    return None, None, None
358
+
359
+
360
def _cloud_submit_diff(url, payload, headers) -> str:
    """Submit a diff request and return the id of the created diff.

    Args:
        url: endpoint the diff is POSTed to.
        payload: JSON-serialisable request body.
        headers: HTTP headers sent with the request.

    Returns:
        The ``id`` field of the JSON response, converted to ``str``.

    Raises:
        Exception: if the JSON response has no ``id`` (or it is null).
        Also propagates whatever ``response.raise_for_status()`` raises for an
        error status code.
    """
    response = requests.request("POST", url, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    response_json = response.json()

    # Look the id up *before* converting to str: str(None) is the truthy
    # string "None", which would make the missing-id check below unreachable.
    diff_id = response_json.get("id")
    if diff_id is None:
        raise Exception(f"Api response did not contain a diff_id: {str(response_json)}")
    return str(diff_id)
369
+
370
+
371
def _cloud_poll_and_get_summary_results(url, headers):
    """Poll ``url`` until the diff reports success and return the response JSON.

    The endpoint is polled with an exponentially growing delay (5s, 10s, ...
    capped at 60s). Polling gives up once more than 300 seconds have elapsed.

    Args:
        url: summary-results endpoint to GET.
        headers: HTTP headers sent with every request.

    Returns:
        The decoded JSON response whose ``status`` is ``"success"``.

    Raises:
        Exception: if the response ``status`` is ``"failed"`` or the wait
            limit is exceeded. Also propagates whatever
            ``response.raise_for_status()`` raises for an error status code.
    """
    # monotonic clock: elapsed time must not be affected by system clock changes
    start_time = time.monotonic()
    sleep_interval = 5  # seconds; doubled after every unsuccessful poll
    max_sleep_interval = 60
    max_wait_time = 300

    while True:
        response = requests.request("GET", url, headers=headers, timeout=30)
        response.raise_for_status()
        response_json = response.json()

        status = response_json["status"]
        if status == "success":
            # Return right away: no timeout check and no extra sleep once
            # the results are already in hand.
            return response_json
        if status == "failed":
            raise Exception(f"Diff failed: {str(response_json)}")

        if time.monotonic() - start_time > max_wait_time:
            raise Exception("Timed out waiting for diff results")

        time.sleep(sleep_interval)
        sleep_interval = min(sleep_interval * 2, max_sleep_interval)
306
395
307
396
308
397
class DbtParser :
309
- def __init__ (self , profiles_dir_override : str , project_dir_override : str , is_cloud : bool ) -> None :
398
+ def __init__ (self , profiles_dir_override : str , project_dir_override : str ) -> None :
310
399
self .parse_run_results , self .parse_manifest , self .ProfileRenderer , self .yaml = import_dbt ()
311
400
self .profiles_dir = Path (profiles_dir_override or default_profiles_dir ())
312
401
self .project_dir = Path (project_dir_override or default_project_dir ())
313
- self .is_cloud = is_cloud
314
402
self .connection = None
315
403
self .project_dict = self .get_project_dict ()
316
404
self .manifest_obj = self .get_manifest_obj ()
0 commit comments