Skip to content

Commit 681b877

Browse files
committed
updated test cases
1 parent 32c9c9a commit 681b877

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

tests/operators/forecast/test_datasets.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,16 +73,18 @@
7373
parameters_short.append((model, dataset_i))
7474

7575

76-
def verify_explanations(tmpdirname, additional_cols):
76+
def verify_explanations(tmpdirname, additional_cols, target_category_columns):
7777
glb_expl = pd.read_csv(f"{tmpdirname}/results/global_explanation.csv", index_col=0)
7878
loc_expl = pd.read_csv(f"{tmpdirname}/results/local_explanation.csv")
7979
assert loc_expl.shape[0] == PERIODS
80-
for x in ["Date", "Series"]:
80+
columns = ["Date", "Series"]
81+
if not target_category_columns:
82+
columns.remove("Series")
83+
for x in columns:
8184
assert x in set(loc_expl.columns)
8285
# for x in additional_cols:
8386
# assert x in set(loc_expl.columns)
8487
# assert x in set(glb_expl.index)
85-
assert "Series 1" in set(glb_expl.columns)
8688

8789

8890
@pytest.mark.parametrize("model, data_details", parameters_short)
@@ -151,6 +153,7 @@ def test_load_datasets(model, data_details):
151153
verify_explanations(
152154
tmpdirname=tmpdirname,
153155
additional_cols=additional_cols,
156+
target_category_columns=yaml_i["spec"]['target_category_columns']
154157
)
155158
if include_test_data:
156159
test_metrics = pd.read_csv(f"{tmpdirname}/results/test_metrics.csv")

0 commit comments

Comments
 (0)