Skip to content

Commit 52bfc7c

Browse files
committed
Account for new metric
1 parent fa6e701 commit 52bfc7c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/test_workflows/test_basic_workflow.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ def test_basic_workflow(tmp_path, inference_network, summary_network):
2323

2424
assert "loss" in list(history.history.keys())
2525
assert len(history.history["loss"]) == 4
26-
assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "z_score_contraction"]
26+
assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "coverage", "z_score_contraction"]
2727
assert list(metrics.columns) == ["p1", "p2"]
28-
assert metrics.values.shape == (3, 2)
28+
assert metrics.values.shape == (4, 2)
2929

3030
# Ensure saving and loading from workflow works fine
3131
loaded_approximator = keras.saving.load_model(os.path.join(str(tmp_path), "model.keras"))
@@ -54,9 +54,9 @@ def test_basic_workflow_fusion(
5454

5555
assert "loss" in list(history.history.keys())
5656
assert len(history.history["loss"]) == 4
57-
assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "z_score_contraction"]
57+
assert list(plots.keys()) == ["losses", "recovery", "calibration_ecdf", "coverage", "z_score_contraction"]
5858
assert list(metrics.columns) == ["p1", "p2"]
59-
assert metrics.values.shape == (3, 2)
59+
assert metrics.values.shape == (4, 2)
6060

6161
# Ensure saving and loading from workflow works fine
6262
loaded_approximator = keras.saving.load_model(os.path.join(str(tmp_path), "model.keras"))

0 commit comments

Comments
 (0)