Skip to content

Commit ae3f14a

Browse files
committed
refactor: unified gathering and saving results for both benchmarks
new base benchmark class update visualise to manipulation bench
1 parent 48a528d commit ae3f14a

File tree

21 files changed

+696
-450
lines changed

21 files changed

+696
-450
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (C) 2025 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import csv
16+
import logging
17+
from abc import ABC, abstractmethod
18+
from pathlib import Path
19+
20+
from langgraph.graph.state import CompiledStateGraph
21+
from pydantic import BaseModel, Field
22+
23+
24+
class BenchmarkSummary(BaseModel):
    """Aggregate statistics describing one benchmark run of a single model.

    The declaration order of the fields is kept stable because the CSV
    helpers in this module derive the column order from the declared fields.
    """

    model_name: str = Field(..., description="Name of the LLM.")
    success_rate: float = Field(
        ..., description="Percentage of successfully completed tasks."
    )
    avg_time: float = Field(..., description="Average time taken across all tasks.")
    # This is a run-level total, not a per-task value, so the description
    # speaks about all tasks rather than "this Task".
    total_extra_tool_calls_used: int = Field(
        ..., description="Total number of extra tool calls used across all tasks."
    )
    total_tasks: int = Field(..., description="Total number of executed tasks.")
34+
35+
36+
class BaseBenchmark(ABC):
    """Base class for all benchmarks.

    Stores the state shared by every benchmark (model name, output file
    locations, logger) and offers CSV helpers for persisting Pydantic models.
    Subclasses must implement ``run_next`` and ``compute_and_save_summary``.
    """

    def __init__(
        self,
        model_name: str,
        results_dir: Path,
        logger: logging.Logger | None = None,
    ) -> None:
        """Initialize the base benchmark.

        Parameters
        ----------
        model_name : str
            Name of the LLM model.
        results_dir : Path
            Directory that will hold ``results.csv`` and
            ``results_summary.csv``. The directory is not created here.
        logger : logging.Logger | None
            Logger instance. When omitted, a logger named after this module
            is used.
        """
        # Accept plain strings as well, so the ``/`` joins below always work.
        results_dir = Path(results_dir)

        self.model_name = model_name
        self.results_filename = results_dir / "results.csv"
        self.summary_filename = results_dir / "results_summary.csv"
        self.logger = logger if logger is not None else logging.getLogger(__name__)

    @staticmethod
    def csv_initialize(filename: Path, base_model_cls: type[BaseModel]) -> None:
        """Create (or truncate) a CSV file and write its header row.

        Parameters
        ----------
        filename : Path
            Filename of the CSV file.
        base_model_cls : type[BaseModel]
            Pydantic model class whose fields become the CSV columns.
        """
        # ``model_fields`` also lists fields inherited from parent models,
        # whereas ``__annotations__`` only holds the ones declared directly
        # on ``base_model_cls``.
        fieldnames = list(base_model_cls.model_fields)

        with open(filename, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writeheader()

    @staticmethod
    def csv_writerow(filename: Path, base_model_instance: BaseModel) -> None:
        """Append one row built from a Pydantic model instance to a CSV file.

        String values have every run of whitespace (including newlines)
        collapsed to a single space so that each record stays on one line.

        Parameters
        ----------
        filename : Path
            Filename of the CSV file.
        base_model_instance : BaseModel
            Pydantic model instance holding the data to be written.
        """
        row = base_model_instance.model_dump()

        for key, value in row.items():
            if isinstance(value, str):
                # Collapse whitespace so multiline text doesn't break the csv
                row[key] = " ".join(value.split())

        # Use the same column source as ``csv_initialize`` so the row always
        # lines up with the header, inherited fields included. Read it from
        # the class, not the instance.
        fieldnames = list(type(base_model_instance).model_fields)

        with open(filename, mode="a", newline="", encoding="utf-8") as file:
            writer = csv.DictWriter(file, fieldnames=fieldnames)
            writer.writerow(row)

    @abstractmethod
    def run_next(self, agent: CompiledStateGraph) -> None:
        """Run the next task/scenario of the benchmark.

        Parameters
        ----------
        agent : CompiledStateGraph
            Compiled agent graph used to execute the task.
        """
        pass

    @abstractmethod
    def compute_and_save_summary(self) -> None:
        """Compute summary statistics and save them to the summary file."""
        pass
124+
125+
# TODO (jm) this can probably be the same for all benchmarks in the future

src/rai_bench/rai_bench/examples/manipulation_o3de/main.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,9 @@
3333
from rai_open_set_vision.tools import GetGrabbingPointTool
3434

3535
from rai_bench.examples.manipulation_o3de.scenarios import (
36-
easy_scenarios,
37-
hard_scenarios,
38-
medium_scenarios,
3936
trivial_scenarios,
40-
very_hard_scenarios,
4137
)
42-
from rai_bench.manipulation_o3de.benchmark import Benchmark
38+
from rai_bench.manipulation_o3de.benchmark import ManipulationO3DEBenchmark
4339
from rai_sim.o3de.o3de_bridge import (
4440
O3DEngineArmManipulationBridge,
4541
)
@@ -72,17 +68,8 @@ def run_benchmark(model_name: str, vendor: str, out_dir: str):
7268
node.declare_parameter("conversion_ratio", 1.0)
7369

7470
# define model
75-
7671
llm = get_llm_model_direct(model_name=model_name, vendor=vendor)
7772

78-
system_prompt = """
79-
You are a robotic arm with interfaces to detect and manipulate objects.
80-
Here are the coordinates information:
81-
x - front to back (positive is forward)
82-
y - left to right (positive is right)
83-
z - up to down (positive is up)
84-
Before starting the task, make sure to grab the camera image to understand the environment.
85-
"""
8673
# define tools
8774
tools: List[BaseTool] = [
8875
GetObjectPositionsTool(
@@ -165,33 +152,33 @@ def run_benchmark(model_name: str, vendor: str, out_dir: str):
165152
t_scenarios = trivial_scenarios(
166153
configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
167154
)
168-
e_scenarios = easy_scenarios(
169-
configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
170-
)
171-
m_scenarios = medium_scenarios(
172-
configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
173-
)
174-
h_scenarios = hard_scenarios(
175-
configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
176-
)
177-
vh_scenarios = very_hard_scenarios(
178-
configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
179-
)
155+
# e_scenarios = easy_scenarios(
156+
# configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
157+
# )
158+
# m_scenarios = medium_scenarios(
159+
# configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
160+
# )
161+
# h_scenarios = hard_scenarios(
162+
# configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
163+
# )
164+
# vh_scenarios = very_hard_scenarios(
165+
# configs_dir=configs_dir, connector_path=connector_path, logger=bench_logger
166+
# )
180167

181-
all_scenarios = t_scenarios + e_scenarios + m_scenarios + h_scenarios + vh_scenarios
168+
all_scenarios = t_scenarios
182169
o3de = O3DEngineArmManipulationBridge(connector, logger=agent_logger)
183170
try:
184171
# define benchmark
185-
results_filename = f"{out_dir}/results.csv"
186-
benchmark = Benchmark(
172+
benchmark = ManipulationO3DEBenchmark(
173+
model_name=model_name,
187174
simulation_bridge=o3de,
188175
scenarios=all_scenarios,
189176
logger=bench_logger,
190-
results_filename=results_filename,
177+
results_dir=Path(out_dir),
191178
)
192-
for _ in range(len(all_scenarios)):
179+
for scenario in all_scenarios:
193180
agent = create_conversational_agent(
194-
llm, tools, system_prompt, logger=agent_logger
181+
llm, tools, scenario.task.system_prompt, logger=agent_logger
195182
)
196183
benchmark.run_next(agent=agent)
197184
o3de.reset_arm()

src/rai_bench/rai_bench/examples/manipulation_o3de/scenarios.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from rclpy.impl.rcutils_logger import RcutilsLogger
2020

21-
from rai_bench.manipulation_o3de.benchmark import Benchmark, Scenario
21+
from rai_bench.manipulation_o3de.benchmark import ManipulationO3DEBenchmark, Scenario
2222
from rai_bench.manipulation_o3de.interfaces import Task
2323
from rai_bench.manipulation_o3de.tasks import (
2424
BuildCubeTowerTask,
@@ -86,7 +86,7 @@ def trivial_scenarios(
8686
place_object_tasks.append(
8787
PlaceObjectAtCoordTask(obj, coord, disp, logger=logger)
8888
)
89-
easy_place_objects_scenarios = Benchmark.create_scenarios(
89+
easy_place_objects_scenarios = ManipulationO3DEBenchmark.create_scenarios(
9090
tasks=place_object_tasks,
9191
simulation_configs=simulations_configs,
9292
simulation_configs_paths=simulation_configs_paths,
@@ -99,7 +99,7 @@ def trivial_scenarios(
9999
for objects in object_groups
100100
]
101101

102-
easy_move_to_left_scenarios = Benchmark.create_scenarios(
102+
easy_move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios(
103103
tasks=move_to_left_tasks,
104104
simulation_configs=simulations_configs,
105105
simulation_configs_paths=simulation_configs_paths,
@@ -168,7 +168,7 @@ def easy_scenarios(
168168
place_object_tasks.append(
169169
PlaceObjectAtCoordTask(obj, coord, disp, logger=logger)
170170
)
171-
easy_place_objects_scenarios = Benchmark.create_scenarios(
171+
easy_place_objects_scenarios = ManipulationO3DEBenchmark.create_scenarios(
172172
tasks=place_object_tasks,
173173
simulation_configs=simulations_configs,
174174
simulation_configs_paths=simulation_configs_paths,
@@ -188,15 +188,15 @@ def easy_scenarios(
188188
for objects in object_groups
189189
]
190190

191-
easy_move_to_left_scenarios = Benchmark.create_scenarios(
191+
easy_move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios(
192192
tasks=move_to_left_tasks,
193193
simulation_configs=simulations_configs,
194194
simulation_configs_paths=simulation_configs_paths,
195195
)
196196

197197
# place cubes
198198
task = PlaceCubesTask(threshold_distance=0.2, logger=logger)
199-
easy_place_cubes_scenarios = Benchmark.create_scenarios(
199+
easy_place_cubes_scenarios = ManipulationO3DEBenchmark.create_scenarios(
200200
tasks=[task],
201201
simulation_configs=simulations_configs,
202202
simulation_configs_paths=simulation_configs_paths,
@@ -284,7 +284,7 @@ def medium_scenarios(
284284
for objects in object_groups
285285
]
286286

287-
move_to_left_scenarios = Benchmark.create_scenarios(
287+
move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios(
288288
tasks=move_to_left_tasks,
289289
simulation_configs=medium_simulations_configs,
290290
simulation_configs_paths=medium_simulation_configs_paths,
@@ -293,7 +293,7 @@ def medium_scenarios(
293293

294294
# place cubes
295295
task = PlaceCubesTask(threshold_distance=0.1, logger=logger)
296-
easy_place_cubes_scenarios = Benchmark.create_scenarios(
296+
easy_place_cubes_scenarios = ManipulationO3DEBenchmark.create_scenarios(
297297
tasks=[task],
298298
simulation_configs=medium_simulations_configs,
299299
simulation_configs_paths=medium_simulation_configs_paths,
@@ -310,7 +310,7 @@ def medium_scenarios(
310310
for objects in object_groups
311311
]
312312

313-
build_tower_scenarios = Benchmark.create_scenarios(
313+
build_tower_scenarios = ManipulationO3DEBenchmark.create_scenarios(
314314
tasks=build_tower_tasks,
315315
simulation_configs=easy_simulations_configs,
316316
simulation_configs_paths=easy_simulation_configs_paths,
@@ -330,7 +330,7 @@ def medium_scenarios(
330330
GroupObjectsTask(obj_types=objects, logger=logger) for objects in object_groups
331331
]
332332

333-
group_object_scenarios = Benchmark.create_scenarios(
333+
group_object_scenarios = ManipulationO3DEBenchmark.create_scenarios(
334334
tasks=group_object_tasks,
335335
simulation_configs=easy_simulations_configs,
336336
simulation_configs_paths=easy_simulation_configs_paths,
@@ -418,15 +418,15 @@ def hard_scenarios(
418418
for objects in object_groups
419419
]
420420

421-
move_to_left_scenarios = Benchmark.create_scenarios(
421+
move_to_left_scenarios = ManipulationO3DEBenchmark.create_scenarios(
422422
tasks=move_to_left_tasks,
423423
simulation_configs=hard_simulations_configs,
424424
simulation_configs_paths=hard_simulation_configs_paths,
425425
)
426426

427427
# place cubes
428428
task = PlaceCubesTask(threshold_distance=0.1, logger=logger)
429-
easy_place_cubes_scenarios = Benchmark.create_scenarios(
429+
easy_place_cubes_scenarios = ManipulationO3DEBenchmark.create_scenarios(
430430
tasks=[task],
431431
simulation_configs=hard_simulations_configs,
432432
simulation_configs_paths=hard_simulation_configs_paths,
@@ -442,7 +442,7 @@ def hard_scenarios(
442442
for objects in object_groups
443443
]
444444

445-
build_tower_scenarios = Benchmark.create_scenarios(
445+
build_tower_scenarios = ManipulationO3DEBenchmark.create_scenarios(
446446
tasks=build_tower_tasks,
447447
simulation_configs=medium_simulations_configs,
448448
simulation_configs_paths=medium_simulation_configs_paths,
@@ -464,7 +464,7 @@ def hard_scenarios(
464464
GroupObjectsTask(obj_types=objects, logger=logger) for objects in object_groups
465465
]
466466

467-
group_object_scenarios = Benchmark.create_scenarios(
467+
group_object_scenarios = ManipulationO3DEBenchmark.create_scenarios(
468468
tasks=group_object_tasks,
469469
simulation_configs=medium_simulations_configs,
470470
simulation_configs_paths=medium_simulation_configs_paths,
@@ -534,7 +534,7 @@ def very_hard_scenarios(
534534
for objects in object_groups
535535
]
536536

537-
build_tower_scenarios = Benchmark.create_scenarios(
537+
build_tower_scenarios = ManipulationO3DEBenchmark.create_scenarios(
538538
tasks=build_tower_tasks,
539539
simulation_configs=hard_simulations_configs,
540540
simulation_configs_paths=hard_simulation_configs_paths,
@@ -555,7 +555,7 @@ def very_hard_scenarios(
555555
GroupObjectsTask(obj_types=objects, logger=logger) for objects in object_groups
556556
]
557557

558-
group_object_scenarios = Benchmark.create_scenarios(
558+
group_object_scenarios = ManipulationO3DEBenchmark.create_scenarios(
559559
tasks=group_object_tasks,
560560
simulation_configs=hard_simulations_configs,
561561
simulation_configs_paths=hard_simulation_configs_paths,

src/rai_bench/rai_bench/examples/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
if __name__ == "__main__":
2020
models_name = ["llama3.2", "qwen2.5:7b"]
2121
vendors = ["ollama", "ollama"]
22-
benchmarks = ["manipulation_o3de"]
22+
benchmarks = ["tool_calling_agent", "manipulation_o3de"]
2323
extra_tool_calls = [5]
2424
repeats = 1
2525

src/rai_bench/rai_bench/examples/tool_calling_agent/main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def run_benchmark(model_name: str, vendor: str, out_dir: str, extra_tool_calls:
5454
experiment_dir = Path(out_dir)
5555
experiment_dir.mkdir(parents=True, exist_ok=True)
5656
log_filename = experiment_dir / "benchmark.log"
57-
results_filename = experiment_dir / "results.csv"
5857

5958
file_handler = logging.FileHandler(log_filename)
6059
file_handler.setLevel(logging.DEBUG)
@@ -76,7 +75,10 @@ def run_benchmark(model_name: str, vendor: str, out_dir: str, extra_tool_calls:
7675
task.set_logger(bench_logger)
7776

7877
benchmark = ToolCallingAgentBenchmark(
79-
tasks=all_tasks, logger=bench_logger, results_filename=results_filename
78+
tasks=all_tasks,
79+
logger=bench_logger,
80+
model_name=model_name,
81+
results_dir=experiment_dir,
8082
)
8183

8284
llm = get_llm_model_direct(model_name=model_name, vendor=vendor)
@@ -87,7 +89,7 @@ def run_benchmark(model_name: str, vendor: str, out_dir: str, extra_tool_calls:
8789
system_prompt=task.get_system_prompt(),
8890
logger=agent_logger,
8991
)
90-
benchmark.run_next(agent=agent, model_name=model_name)
92+
benchmark.run_next(agent=agent)
9193

9294

9395
if __name__ == "__main__":

0 commit comments

Comments
 (0)