Skip to content

Commit 48a528d

Browse files
committed
refactor: applied review suggestion to visualise script
1 parent 160ec4e commit 48a528d

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

src/rai_bench/rai_bench/results_processing/visualise.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -145,6 +145,7 @@ def display_models_summ_data(df: pd.DataFrame):
145145
x_label="Model Name",
146146
y_label="Success Rate (%)",
147147
color_column="model_name",
148+
y_range=(0.0, 100.0),
148149
)
149150
st.plotly_chart(fig1, use_container_width=True) # type: ignore
150151

@@ -185,6 +186,7 @@ def display_models_extra_calls_data(df: pd.DataFrame, extra_calls: int):
185186
x_label="Model Name",
186187
y_label="Success Rate (%)",
187188
color_column="model_name",
189+
y_range=(0.0, 1.0),
188190
)
189191
st.plotly_chart(fig1, use_container_width=True) # type: ignore
190192

@@ -215,6 +217,7 @@ def display_task_type_performance(model_results: ModelResults):
215217
title="Success Rate by Task Type",
216218
x_label="Task Type",
217219
y_label="Avg Score",
220+
y_range=(0.0, 1.0),
218221
)
219222
st.plotly_chart(fig_type_score, use_container_width=True) # type: ignore
220223

@@ -244,6 +247,7 @@ def display_task_complexity_performance(model_results: ModelResults):
244247
title="Success Rate by Task Complexity",
245248
x_label="Task Complexity",
246249
y_label="Avg Score",
250+
y_range=(0.0, 1.0),
247251
)
248252
st.plotly_chart(fig_complexity_score, use_container_width=True) # type: ignore
249253

@@ -284,8 +288,21 @@ def display_detailed_task_type_analysis(
284288
model_results: ModelResults, selected_type: str
285289
):
286290
"""Display detailed analysis for a specific task type."""
287-
# Get task data for the selected type
288-
filtered_by_complexity = create_task_metrics_dataframe(model_results, "complexity")
291+
# first, get only the tasks of the selected type
292+
tasks_for_type_df = create_task_details_dataframe(model_results, selected_type)
293+
if tasks_for_type_df.empty:
294+
st.warning(f"No tasks of type {selected_type} found.")
295+
return
296+
# Now aggregate by complexity for that type
297+
filtered_by_complexity = (
298+
tasks_for_type_df.groupby("complexity") # type: ignore
299+
.agg(
300+
avg_score=("score", "mean"),
301+
avg_time=("total_time", "mean"),
302+
avg_extra_tool_calls=("extra_tool_calls_used", "mean"),
303+
)
304+
.reset_index()
305+
)
289306
filtered_by_complexity = filtered_by_complexity[
290307
filtered_by_complexity["complexity"].notna()
291308
]
@@ -299,6 +316,7 @@ def display_detailed_task_type_analysis(
299316
title=f"Success Rate by Task Complexity for '{selected_type}' Tasks",
300317
x_label="Task Complexity",
301318
y_label="Avg Score",
319+
y_range=(0.0, 1.0),
302320
)
303321
st.plotly_chart(fig_complexity_score, use_container_width=True) # type: ignore
304322

@@ -631,7 +649,11 @@ def main():
631649
st.set_page_config(layout="wide", page_title="LLM Task Results Visualizer")
632650
st.title("RAI BENCHMARK RESULTS")
633651

634-
run_folders = get_available_runs(EXPERIMENT_DIR)
652+
try:
653+
run_folders = get_available_runs(EXPERIMENT_DIR)
654+
except FileNotFoundError:
655+
st.error(f"Experiments directory '{EXPERIMENT_DIR}' not found.")
656+
return
635657

636658
if not run_folders:
637659
st.warning("No benchmark runs found in the experiments directory.")

0 commit comments

Comments
 (0)