Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 5632150

Browse files
authored
Merge pull request #498 from dlawin/issue_447
handle all custom schemas scenarios
2 parents 25692cb + b018eb2 commit 5632150

File tree

2 files changed

+83
-27
lines changed

2 files changed

+83
-27
lines changed

data_diff/dbt.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,17 @@ def dbt_diff(
8585
datadiff_variables = dbt_parser.get_datadiff_variables()
8686
config_prod_database = datadiff_variables.get("prod_database")
8787
config_prod_schema = datadiff_variables.get("prod_schema")
88+
config_prod_custom_schema = datadiff_variables.get("prod_custom_schema")
8889
datasource_id = datadiff_variables.get("datasource_id")
89-
custom_schemas = datadiff_variables.get("custom_schemas")
90-
# custom schemas is default dbt behavior, so default to True if the var doesn't exist
91-
custom_schemas = True if custom_schemas is None else custom_schemas
9290
set_dbt_user_id(dbt_parser.dbt_user_id)
9391
set_dbt_version(dbt_parser.dbt_version)
9492
set_dbt_project_id(dbt_parser.dbt_project_id)
9593

94+
if datadiff_variables.get("custom_schemas") is not None:
95+
logger.warning(
96+
"vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source"
97+
)
98+
9699
if is_cloud:
97100
api = _initialize_api()
98101
# exit so the user can set the key
@@ -125,7 +128,9 @@ def dbt_diff(
125128
)
126129

127130
for model in models:
128-
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, custom_schemas)
131+
diff_vars = _get_diff_vars(
132+
dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model
133+
)
129134

130135
if diff_vars.primary_keys:
131136
if is_cloud:
@@ -149,22 +154,33 @@ def _get_diff_vars(
149154
dbt_parser: "DbtParser",
150155
config_prod_database: Optional[str],
151156
config_prod_schema: Optional[str],
157+
config_prod_custom_schema: Optional[str],
152158
model,
153-
custom_schemas: bool,
154159
) -> DiffVars:
155160
dev_database = model.database
156161
dev_schema = model.schema_
157162

158163
primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")
159164

160165
prod_database = config_prod_database if config_prod_database else dev_database
161-
prod_schema = config_prod_schema if config_prod_schema else dev_schema
162166

163-
# if project has custom schemas (default)
164-
# need to construct the prod schema as <prod_target_schema>_<custom_schema>
165-
# https://docs.getdbt.com/docs/build/custom-schemas
166-
if custom_schemas and model.config.schema_:
167-
prod_schema = prod_schema + "_" + model.config.schema_
167+
# prod schema name differs from dev schema name
168+
if config_prod_schema:
169+
custom_schema = model.config.schema_
170+
171+
# the model has a custom schema config(schema='some_schema')
172+
if custom_schema:
173+
if not config_prod_custom_schema:
174+
raise ValueError(
175+
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n"
176+
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
177+
)
178+
prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema)
179+
# no custom schema, use the default
180+
else:
181+
prod_schema = config_prod_schema
182+
else:
183+
prod_schema = dev_schema
168184

169185
if dbt_parser.requires_upper:
170186
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]

tests/test_dbt.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -740,9 +740,6 @@ def test_diff_not_is_cloud_no_pks(
740740
"prod_schema": "prod_schema",
741741
"datasource_id": 1,
742742
}
743-
host = "a_host"
744-
url = "a_url"
745-
api_key = "a_api_key"
746743

747744
mock_dbt_parser_inst.get_models.return_value = [mock_model]
748745
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
@@ -756,42 +753,67 @@ def test_diff_not_is_cloud_no_pks(
756753
mock_local_diff.assert_not_called()
757754
self.assertEqual(mock_print.call_count, 1)
758755

759-
def test_get_diff_vars_custom_schemas_prod_db_and_schema(self):
756+
def test_get_diff_vars_replace_custom_schema(self):
760757
mock_model = Mock()
761758
prod_database = "a_prod_db"
762759
prod_schema = "a_prod_schema"
763760
primary_keys = ["a_primary_key"]
764761
mock_model.database = "a_dev_db"
765-
mock_model.schema_ = "a_custom_dev_schema"
762+
mock_model.schema_ = "a_custom_schema"
766763
mock_model.config.schema_ = mock_model.schema_
767764
mock_model.alias = "a_model_name"
768765
mock_dbt_parser = Mock()
769766
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
770767
mock_dbt_parser.requires_upper = False
771768

772-
diff_vars = _get_diff_vars(mock_dbt_parser, "a_prod_db", "a_prod_schema", mock_model, custom_schemas=True)
769+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod_<custom_schema>", mock_model)
773770

774771
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
775-
assert diff_vars.prod_path == [prod_database, prod_schema + "_" + mock_model.schema_, mock_model.alias]
772+
assert diff_vars.prod_path == [prod_database, "prod_" + mock_model.schema_, mock_model.alias]
776773
assert diff_vars.primary_keys == primary_keys
777774
assert diff_vars.connection == mock_dbt_parser.connection
778775
assert diff_vars.threads == mock_dbt_parser.threads
776+
assert prod_schema not in diff_vars.prod_path
777+
779778
mock_dbt_parser.get_pk_from_model.assert_called_once()
780779

781-
def test_get_diff_vars_false_custom_schemas_prod_db_and_schema(self):
780+
def test_get_diff_vars_static_custom_schema(self):
782781
mock_model = Mock()
783782
prod_database = "a_prod_db"
784783
prod_schema = "a_prod_schema"
785784
primary_keys = ["a_primary_key"]
786785
mock_model.database = "a_dev_db"
787-
mock_model.schema_ = "a_custom_dev_schema"
786+
mock_model.schema_ = "a_custom_schema"
788787
mock_model.config.schema_ = mock_model.schema_
789788
mock_model.alias = "a_model_name"
790789
mock_dbt_parser = Mock()
791-
mock_dbt_parser.get_pk_from_model.return_value = ["a_primary_key"]
790+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
792791
mock_dbt_parser.requires_upper = False
793792

794-
diff_vars = _get_diff_vars(mock_dbt_parser, "a_prod_db", "a_prod_schema", mock_model, custom_schemas=False)
793+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)
794+
795+
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
796+
assert diff_vars.prod_path == [prod_database, "prod", mock_model.alias]
797+
assert diff_vars.primary_keys == primary_keys
798+
assert diff_vars.connection == mock_dbt_parser.connection
799+
assert diff_vars.threads == mock_dbt_parser.threads
800+
assert prod_schema not in diff_vars.prod_path
801+
mock_dbt_parser.get_pk_from_model.assert_called_once()
802+
803+
def test_get_diff_vars_no_custom_schema_on_model(self):
804+
mock_model = Mock()
805+
prod_database = "a_prod_db"
806+
prod_schema = "a_prod_schema"
807+
primary_keys = ["a_primary_key"]
808+
mock_model.database = "a_dev_db"
809+
mock_model.schema_ = "a_custom_schema"
810+
mock_model.config.schema_ = None
811+
mock_model.alias = "a_model_name"
812+
mock_dbt_parser = Mock()
813+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
814+
mock_dbt_parser.requires_upper = False
815+
816+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)
795817

796818
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
797819
assert diff_vars.prod_path == [prod_database, prod_schema, mock_model.alias]
@@ -800,23 +822,41 @@ def test_get_diff_vars_false_custom_schemas_prod_db_and_schema(self):
800822
assert diff_vars.threads == mock_dbt_parser.threads
801823
mock_dbt_parser.get_pk_from_model.assert_called_once()
802824

803-
def test_get_diff_vars_false_custom_schemas_prod_db(self):
825+
def test_get_diff_vars_match_dev_schema(self):
804826
mock_model = Mock()
805827
prod_database = "a_prod_db"
806828
primary_keys = ["a_primary_key"]
807829
mock_model.database = "a_dev_db"
808-
mock_model.schema_ = "a_custom_dev_schema"
809-
mock_model.config.schema_ = mock_model.schema_
830+
mock_model.schema_ = "a_schema"
831+
mock_model.config.schema_ = None
810832
mock_model.alias = "a_model_name"
811833
mock_dbt_parser = Mock()
812-
mock_dbt_parser.get_pk_from_model.return_value = ["a_primary_key"]
834+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
813835
mock_dbt_parser.requires_upper = False
814836

815-
diff_vars = _get_diff_vars(mock_dbt_parser, "a_prod_db", None, mock_model, custom_schemas=False)
837+
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)
816838

817839
assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
818840
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
819841
assert diff_vars.primary_keys == primary_keys
820842
assert diff_vars.connection == mock_dbt_parser.connection
821843
assert diff_vars.threads == mock_dbt_parser.threads
822844
mock_dbt_parser.get_pk_from_model.assert_called_once()
845+
846+
def test_get_diff_custom_schema_no_config_exception(self):
847+
mock_model = Mock()
848+
prod_database = "a_prod_db"
849+
prod_schema = "a_prod_schema"
850+
primary_keys = ["a_primary_key"]
851+
mock_model.database = "a_dev_db"
852+
mock_model.schema_ = "a_schema"
853+
mock_model.config.schema_ = "a_custom_schema"
854+
mock_model.alias = "a_model_name"
855+
mock_dbt_parser = Mock()
856+
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
857+
mock_dbt_parser.requires_upper = False
858+
859+
with self.assertRaises(ValueError):
860+
_get_diff_vars(mock_dbt_parser, prod_database, prod_schema, None, mock_model)
861+
862+
mock_dbt_parser.get_pk_from_model.assert_called_once()

0 commit comments

Comments
 (0)