@@ -23,9 +23,9 @@ def test_basic_workflow(tmp_path, inference_network, summary_network):
2323
2424 assert "loss" in list (history .history .keys ())
2525 assert len (history .history ["loss" ]) == 4
26- assert list (plots .keys ()) == ["losses" , "recovery" , "calibration_ecdf" , "z_score_contraction" ]
26+ assert list (plots .keys ()) == ["losses" , "recovery" , "calibration_ecdf" , "coverage" , " z_score_contraction" ]
2727 assert list (metrics .columns ) == ["p1" , "p2" ]
28- assert metrics .values .shape == (3 , 2 )
28+ assert metrics .values .shape == (4 , 2 )
2929
3030 # Ensure saving and loading from workflow works fine
3131 loaded_approximator = keras .saving .load_model (os .path .join (str (tmp_path ), "model.keras" ))
@@ -54,9 +54,9 @@ def test_basic_workflow_fusion(
5454
5555 assert "loss" in list (history .history .keys ())
5656 assert len (history .history ["loss" ]) == 4
57- assert list (plots .keys ()) == ["losses" , "recovery" , "calibration_ecdf" , "z_score_contraction" ]
57+ assert list (plots .keys ()) == ["losses" , "recovery" , "calibration_ecdf" , "coverage" , " z_score_contraction" ]
5858 assert list (metrics .columns ) == ["p1" , "p2" ]
59- assert metrics .values .shape == (3 , 2 )
59+ assert metrics .values .shape == (4 , 2 )
6060
6161 # Ensure saving and loading from workflow works fine
6262 loaded_approximator = keras .saving .load_model (os .path .join (str (tmp_path ), "model.keras" ))
0 commit comments