diff --git a/covid19_drdfm/streamlit/pages/2_Comparative_Run_Analysis.py b/covid19_drdfm/streamlit/pages/2_Comparative_Run_Analysis.py index 96b9729..54e6a19 100644 --- a/covid19_drdfm/streamlit/pages/2_Comparative_Run_Analysis.py +++ b/covid19_drdfm/streamlit/pages/2_Comparative_Run_Analysis.py @@ -21,7 +21,9 @@ def center_title(text): center_title("Comparative Run Analysis") # Parameter to runs -path_to_results = Path(st.text_input("Path directory of runs", value="./covid19_drdfm/data/example-data")) +run_paths_string = "./covid19_drdfm/data/example-data" +run_paths = Path(run_paths_string) +path_to_results = Path(st.text_input("Path directory of runs", value=run_paths_string)) df = parse_multiple_runs(path_to_results) @@ -53,19 +55,42 @@ def create_plot(df): st.plotly_chart(fig, use_container_width=True) return metric - -def get_summary(df): +# Function to get number of failed states +def count_failed_states(run_name): + failed_states_count = 0 + failed_file_path = run_paths / run_name / "failed.txt" + if failed_file_path.exists(): + with open(failed_file_path, 'r') as failed_file: + for line in failed_file: + if "Matrix is not positive definite" in line: + failed_states_count += 1 + return failed_states_count + +# Function to calculate deviation from the minimum number of failed states for a specific run +def calculate_deviation(run_name): + min_failed_states = float('inf') + for run_path in run_paths.iterdir(): + if run_path.is_dir(): + failed_states_count = count_failed_states(run_path.name) + min_failed_states = min(min_failed_states, failed_states_count) + + failed_states_count = count_failed_states(run_name) + deviation = - (failed_states_count - min_failed_states) + return deviation + +def get_summary(df, run_name): # Median metrics - col1, col2, col3 = st.columns(3) - col1.metric("Median Log Likelihood", df["Log Likelihood"].median()) - col2.metric("Median AIC", df["AIC"].median()) - col3.metric("Median EM Iterations", df["EM Iterations"].median()) + col1, col2, col3, col4 = st.columns(4) + col1.metric("Number of Failed States", count_failed_states(run_name), calculate_deviation(run_name)) + col2.metric("Median Log Likelihood", df["Log Likelihood"].median()) + col3.metric("Median AIC", df["AIC"].median()) + col4.metric("Median EM Iterations", df["EM Iterations"].median()) def show_summary(df): run = st.selectbox("Select a run", df["Run"].unique()) filtered_df = df[(df["Run"] == run)] - return get_summary(filtered_df) + return get_summary(filtered_df, run) def run_normal(df, metric):