|
73 | 73 | parameters_short.append((model, dataset_i))
|
74 | 74 |
|
75 | 75 |
|
76 |
| -def verify_explanations(tmpdirname, additional_cols): |
| 76 | +def verify_explanations(tmpdirname, additional_cols, target_category_columns): |
77 | 77 | glb_expl = pd.read_csv(f"{tmpdirname}/results/global_explanation.csv", index_col=0)
|
78 | 78 | loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv")
|
79 | 79 | assert loc_expl.shape[0] == PERIODS
|
80 |
| - for x in ["Date", "Series"]: |
| 80 | + columns = ["Date", "Series"] |
| 81 | + if not target_category_columns: |
| 82 | + columns.remove("Series") |
| 83 | + for x in columns: |
81 | 84 | assert x in set(loc_expl.columns)
|
82 | 85 | # for x in additional_cols:
|
83 | 86 | # assert x in set(loc_expl.columns)
|
84 | 87 | # assert x in set(glb_expl.index)
|
85 |
| - assert "Series 1" in set(glb_expl.columns) |
86 | 88 |
|
87 | 89 |
|
88 | 90 | @pytest.mark.parametrize("model, data_details", parameters_short)
|
@@ -151,6 +153,7 @@ def test_load_datasets(model, data_details):
|
151 | 153 | verify_explanations(
|
152 | 154 | tmpdirname=tmpdirname,
|
153 | 155 | additional_cols=additional_cols,
|
| 156 | + target_category_columns=yaml_i["spec"]['target_category_columns'] |
154 | 157 | )
|
155 | 158 | if include_test_data:
|
156 | 159 | test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")
|
|
0 commit comments