@@ -740,9 +740,6 @@ def test_diff_not_is_cloud_no_pks(
740
740
"prod_schema" : "prod_schema" ,
741
741
"datasource_id" : 1 ,
742
742
}
743
- host = "a_host"
744
- url = "a_url"
745
- api_key = "a_api_key"
746
743
747
744
mock_dbt_parser_inst .get_models .return_value = [mock_model ]
748
745
mock_dbt_parser_inst .get_datadiff_variables .return_value = expected_dbt_vars_dict
@@ -756,42 +753,67 @@ def test_diff_not_is_cloud_no_pks(
756
753
mock_local_diff .assert_not_called ()
757
754
self .assertEqual (mock_print .call_count , 1 )
758
755
759
- def test_get_diff_vars_custom_schemas_prod_db_and_schema (self ):
756
+ def test_get_diff_vars_replace_custom_schema (self ):
760
757
mock_model = Mock ()
761
758
prod_database = "a_prod_db"
762
759
prod_schema = "a_prod_schema"
763
760
primary_keys = ["a_primary_key" ]
764
761
mock_model .database = "a_dev_db"
765
- mock_model .schema_ = "a_custom_dev_schema "
762
+ mock_model .schema_ = "a_custom_schema "
766
763
mock_model .config .schema_ = mock_model .schema_
767
764
mock_model .alias = "a_model_name"
768
765
mock_dbt_parser = Mock ()
769
766
mock_dbt_parser .get_pk_from_model .return_value = primary_keys
770
767
mock_dbt_parser .requires_upper = False
771
768
772
- diff_vars = _get_diff_vars (mock_dbt_parser , "a_prod_db" , "a_prod_schema " , mock_model , custom_schemas = True )
769
+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod_<custom_schema> " , mock_model )
773
770
774
771
assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
775
- assert diff_vars .prod_path == [prod_database , prod_schema + "_ " + mock_model .schema_ , mock_model .alias ]
772
+ assert diff_vars .prod_path == [prod_database , "prod_ " + mock_model .schema_ , mock_model .alias ]
776
773
assert diff_vars .primary_keys == primary_keys
777
774
assert diff_vars .connection == mock_dbt_parser .connection
778
775
assert diff_vars .threads == mock_dbt_parser .threads
776
+ assert prod_schema not in diff_vars .prod_path
777
+
779
778
mock_dbt_parser .get_pk_from_model .assert_called_once ()
780
779
781
- def test_get_diff_vars_false_custom_schemas_prod_db_and_schema (self ):
780
+ def test_get_diff_vars_static_custom_schema (self ):
782
781
mock_model = Mock ()
783
782
prod_database = "a_prod_db"
784
783
prod_schema = "a_prod_schema"
785
784
primary_keys = ["a_primary_key" ]
786
785
mock_model .database = "a_dev_db"
787
- mock_model .schema_ = "a_custom_dev_schema "
786
+ mock_model .schema_ = "a_custom_schema "
788
787
mock_model .config .schema_ = mock_model .schema_
789
788
mock_model .alias = "a_model_name"
790
789
mock_dbt_parser = Mock ()
791
- mock_dbt_parser .get_pk_from_model .return_value = [ "a_primary_key" ]
790
+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
792
791
mock_dbt_parser .requires_upper = False
793
792
794
- diff_vars = _get_diff_vars (mock_dbt_parser , "a_prod_db" , "a_prod_schema" , mock_model , custom_schemas = False )
793
+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
794
+
795
+ assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
796
+ assert diff_vars .prod_path == [prod_database , "prod" , mock_model .alias ]
797
+ assert diff_vars .primary_keys == primary_keys
798
+ assert diff_vars .connection == mock_dbt_parser .connection
799
+ assert diff_vars .threads == mock_dbt_parser .threads
800
+ assert prod_schema not in diff_vars .prod_path
801
+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
802
+
803
+ def test_get_diff_vars_no_custom_schema_on_model (self ):
804
+ mock_model = Mock ()
805
+ prod_database = "a_prod_db"
806
+ prod_schema = "a_prod_schema"
807
+ primary_keys = ["a_primary_key" ]
808
+ mock_model .database = "a_dev_db"
809
+ mock_model .schema_ = "a_custom_schema"
810
+ mock_model .config .schema_ = None
811
+ mock_model .alias = "a_model_name"
812
+ mock_dbt_parser = Mock ()
813
+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
814
+ mock_dbt_parser .requires_upper = False
815
+
816
+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , "prod" , mock_model )
795
817
796
818
assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
797
819
assert diff_vars .prod_path == [prod_database , prod_schema , mock_model .alias ]
@@ -800,23 +822,41 @@ def test_get_diff_vars_false_custom_schemas_prod_db_and_schema(self):
800
822
assert diff_vars .threads == mock_dbt_parser .threads
801
823
mock_dbt_parser .get_pk_from_model .assert_called_once ()
802
824
803
- def test_get_diff_vars_false_custom_schemas_prod_db (self ):
825
+ def test_get_diff_vars_match_dev_schema (self ):
804
826
mock_model = Mock ()
805
827
prod_database = "a_prod_db"
806
828
primary_keys = ["a_primary_key" ]
807
829
mock_model .database = "a_dev_db"
808
- mock_model .schema_ = "a_custom_dev_schema "
809
- mock_model .config .schema_ = mock_model . schema_
830
+ mock_model .schema_ = "a_schema "
831
+ mock_model .config .schema_ = None
810
832
mock_model .alias = "a_model_name"
811
833
mock_dbt_parser = Mock ()
812
- mock_dbt_parser .get_pk_from_model .return_value = [ "a_primary_key" ]
834
+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
813
835
mock_dbt_parser .requires_upper = False
814
836
815
- diff_vars = _get_diff_vars (mock_dbt_parser , "a_prod_db" , None , mock_model , custom_schemas = False )
837
+ diff_vars = _get_diff_vars (mock_dbt_parser , prod_database , None , None , mock_model )
816
838
817
839
assert diff_vars .dev_path == [mock_model .database , mock_model .schema_ , mock_model .alias ]
818
840
assert diff_vars .prod_path == [prod_database , mock_model .schema_ , mock_model .alias ]
819
841
assert diff_vars .primary_keys == primary_keys
820
842
assert diff_vars .connection == mock_dbt_parser .connection
821
843
assert diff_vars .threads == mock_dbt_parser .threads
822
844
mock_dbt_parser .get_pk_from_model .assert_called_once ()
845
+
846
+ def test_get_diff_custom_schema_no_config_exception (self ):
847
+ mock_model = Mock ()
848
+ prod_database = "a_prod_db"
849
+ prod_schema = "a_prod_schema"
850
+ primary_keys = ["a_primary_key" ]
851
+ mock_model .database = "a_dev_db"
852
+ mock_model .schema_ = "a_schema"
853
+ mock_model .config .schema_ = "a_custom_schema"
854
+ mock_model .alias = "a_model_name"
855
+ mock_dbt_parser = Mock ()
856
+ mock_dbt_parser .get_pk_from_model .return_value = primary_keys
857
+ mock_dbt_parser .requires_upper = False
858
+
859
+ with self .assertRaises (ValueError ):
860
+ _get_diff_vars (mock_dbt_parser , prod_database , prod_schema , None , mock_model )
861
+
862
+ mock_dbt_parser .get_pk_from_model .assert_called_once ()
0 commit comments