Commit b62fee1

fix: collecting messages from agent (#544)
1 parent bb98682 commit b62fee1

File tree

3 files changed: +35 −32 lines changed


src/rai_bench/rai_bench/manipulation_o3de/benchmark.py

+21-18
@@ -184,33 +184,36 @@ def run_next(self, agent: CompiledStateGraph) -> None:
         tool_calls_num = 0

         ts = time.perf_counter()
+        prev_count: int = 0
         for state in agent.stream(
             {"messages": [HumanMessage(content=scenario.task.get_prompt())]},
             {
                 "recursion_limit": 100
             },  # NOTE (jmatejcz) what should be recursion limit?
         ):
-            graph_node_name = list(state.keys())[0]
-            msg = state[graph_node_name]["messages"][-1]
-
-            if isinstance(msg, HumanMultimodalMessage):
-                last_msg = msg.text
-            elif isinstance(msg, BaseMessage):
-                if isinstance(msg.content, list):
-                    if len(msg.content) == 1:
-                        if type(msg.content[0]) is dict:
-                            last_msg = msg.content[0].get("text", "")
-                else:
-                    last_msg = msg.content
-                self._logger.debug(f"{graph_node_name}: {last_msg}")  # type: ignore
+            node = next(iter(state))
+            new_messages = state[node]["messages"][prev_count:]
+            prev_count = len(state[node]["messages"])
+
+            for msg in new_messages:
+                if isinstance(msg, HumanMultimodalMessage):
+                    last_msg = msg.text
+                elif isinstance(msg, BaseMessage):
+                    if isinstance(msg.content, list):
+                        if len(msg.content) == 1:
+                            if type(msg.content[0]) is dict:
+                                last_msg = msg.content[0].get("text", "")
+                    else:
+                        last_msg = msg.content
+                    self._logger.debug(f"{node}: {last_msg}")  # type: ignore

-            else:
-                raise ValueError(f"Unexpected type of message: {type(msg)}")
+                else:
+                    raise ValueError(f"Unexpected type of message: {type(msg)}")

-            if isinstance(msg, AIMessage):
-                tool_calls_num += len(msg.tool_calls)
+                if isinstance(msg, AIMessage):
+                    tool_calls_num += len(msg.tool_calls)

-            self._logger.info(f"AI Message: {msg}")  # type: ignore
+                self._logger.info(f"AI Message: {msg}")  # type: ignore

         te = time.perf_counter()
         try:
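The manipulation benchmark previously inspected only the last message of each streamed state, so intermediate messages could be missed or handled twice. The new code slices the node's message list with prev_count so every message is processed exactly once. A minimal, self-contained sketch of that slicing pattern; the fake_stream generator below is a stand-in for agent.stream() and is purely illustrative:

# Sketch of the prev_count slicing pattern used in the diff above.
# fake_stream stands in for agent.stream(); it yields {node_name: {"messages": [...]}}
# with the full, growing message list on every update, as the LangGraph state does here.
from typing import Dict, Iterator, List


def fake_stream() -> Iterator[Dict[str, Dict[str, List[str]]]]:
    history: List[str] = []
    for batch in (
        ["human: pick up the cube"],
        ["ai: calling tool move_arm", "tool: arm moved"],
        ["ai: task done"],
    ):
        history.extend(batch)
        yield {"agent_node": {"messages": list(history)}}


def collect_new_messages() -> List[str]:
    collected: List[str] = []
    prev_count = 0
    for state in fake_stream():
        node = next(iter(state))                  # single node key per update
        all_messages = state[node]["messages"]
        new_messages = all_messages[prev_count:]  # only messages added this step
        prev_count = len(all_messages)
        collected.extend(new_messages)
    return collected


if __name__ == "__main__":
    for msg in collect_new_messages():
        print(msg)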

src/rai_bench/rai_bench/tool_calling_agent/benchmark.py

+13-13
@@ -181,9 +181,10 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None:

         ts = time.perf_counter()
         messages: List[BaseMessage] = []
+        prev_count: int = 0
         try:
             if isinstance(task, SpatialReasoningAgentTask):
-                for event in agent.stream(
+                for state in agent.stream(
                     {
                         "messages": [
                             HumanMultimodalMessage(
@@ -193,20 +194,24 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None:
                     },
                     config=config,
                 ):
-                    flattened = {k: v for d in event.values() for k, v in d.items()}
-                    # cos = event.values()
-                    messages.extend(flattened["messages"])
+                    node = next(iter(state))
+                    all_messages = state[node]["messages"]
+                    for new_msg in all_messages[prev_count:]:
+                        messages.append(new_msg)
+                    prev_count = len(messages)
             else:
-                for event in agent.stream(
+                for state in agent.stream(
                     {"messages": [HumanMultimodalMessage(content=task.get_prompt())]},
                     config=config,
                 ):
-                    flattened = {k: v for d in event.values() for k, v in d.items()}
-                    messages.extend(flattened["messages"])
+                    node = next(iter(state))
+                    all_messages = state[node]["messages"]
+                    for new_msg in all_messages[prev_count:]:
+                        messages.append(new_msg)
+                    prev_count = len(messages)

         except GraphRecursionError as e:
             self.logger.error(msg=f"Reached recursion limit {e}")
-            # task.fail_rest_of_validators()

         self.logger.debug(messages)
         toll_calls = task.get_tool_calls_from_messages(messages=messages)
@@ -252,11 +257,6 @@ def run_next(self, agent: CompiledStateGraph, model_name: str) -> None:
         if completed_tasks == self.num_tasks:
             self._compute_and_save_summary()

-        # except StopIteration:
-        #     if self.task_results:
-        #         self._compute_and_save_summary()
-        #     print("No more scenarios left to run.")
-
     def _compute_and_save_summary(self):
         self.logger.info("Computing and saving average results...")
         for model_name, results in self.model_results.items():
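The same prev_count bookkeeping is applied to both stream loops here, so messages accumulates each message exactly once even though every state update carries the full history. As a rough illustration of what can then be done with that list, the sketch below pulls tool calls out of AIMessage objects; it assumes langchain_core message types and is not the actual task.get_tool_calls_from_messages implementation.

# Illustrative only: extracting tool calls from a collected message list.
# Assumes langchain_core is installed; AIMessage.tool_calls is a list of dicts
# with "name", "args", and "id" keys.
from typing import Any, Dict, List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage


def extract_tool_calls(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
    calls: List[Dict[str, Any]] = []
    for msg in messages:
        if isinstance(msg, AIMessage):
            calls.extend(msg.tool_calls)  # empty list when no tools were called
    return calls


if __name__ == "__main__":
    history: List[BaseMessage] = [
        HumanMessage(content="what objects are on the table?"),
        AIMessage(
            content="",
            tool_calls=[{"name": "get_camera_image", "args": {}, "id": "call_1"}],
        ),
        AIMessage(content="I see a cube and a sphere."),
    ]
    print(extract_tool_calls(history))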

src/rai_extensions/rai_open_set_vision/scripts/run_vision_agents.py

+1-1
@@ -14,7 +14,7 @@


 import rclpy
-from rai.utils import wait_for_shutdown
+from rai.agents import wait_for_shutdown
 from rai_open_set_vision.agents import GroundedSamAgent, GroundingDinoAgent

