Skip to content

feat: add StateBaseAgent #529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: development
Choose a base branch
from

Conversation

boczekbartek
Copy link
Member

@boczekbartek boczekbartek commented Apr 15, 2025

Purpose

  • LLM agents needs information from the environment that can represent their state.

Proposed Changes

  • Introduce StateBaseAgent that can expose the state to the LLM agent and collect state
    using aggragation functions.
  • Add interface for aggregation functions: BaseStateAggregator
  • Add example aggregators: for ros2 logs and VLM based for ros2 images

TODOs:

  • deregister callbacks in agent shutdown

Issues

Testing

source setup_shell.sh 
python examples/agents/state_based.py
vcs import . <demos.repos
cd src/examples/rai-rosbot-demo
./run-nav.sh # with no simulation nav2 will produce a lof of errors with logs, which will be added to state
ros2 topic pub /from_human rai_interfaces/msg/HRIMessage "{'text': 'Is nav2 stack working correctly?'}" --once

Summary by CodeRabbit

  • New Features

    • Introduced a state-based agent example for ROS2 environments, showcasing human-robot interaction and multimodal data aggregation.
    • Added a modular aggregator framework, including ROS2-specific aggregators for logs and images with language model-powered analysis and summarization.
    • Provided a configurable, thread-based agent architecture for periodic state aggregation and integration with language models.
    • Expanded public APIs for agent creation and configuration, supporting advanced stateful and multimodal workflows.
  • Bug Fixes

    • Improved image conversion to support both standard and compressed ROS2 image messages.
  • Documentation

    • Updated developer guide to reflect new agent instantiation methods and imports.
  • Chores

    • Refactored and reorganized agent and aggregator modules for better modularity and extensibility.
    • Updated and expanded package exports for easier access to new agent and aggregator components.

@boczekbartek boczekbartek force-pushed the bb/feat/state_based_agent branch from 66e2383 to 70c68a3 Compare April 16, 2025 07:45
@boczekbartek boczekbartek changed the title feat: add StateBaseAgent feat: add StateBaseAgent Apr 16, 2025
@boczekbartek boczekbartek force-pushed the bb/feat/state_based_agent branch from b7ad8a6 to fc10dcb Compare April 16, 2025 08:02
@boczekbartek boczekbartek marked this pull request as ready for review April 16, 2025 08:06
@boczekbartek
Copy link
Member Author

@coderabbitai full review

Copy link
Contributor

coderabbitai bot commented Apr 16, 2025

Walkthrough

This update introduces a modular framework for state-based agents with ROS2 integration, focusing on multimodal message aggregation and stateful language model interactions. It adds new aggregator abstractions and ROS2-specific aggregators for logs and images, and implements a generic BaseStateBasedAgent class for periodic state aggregation and agent control. ROS2 agent support is established with a dedicated connector and example script. The state-based agent creation logic is refactored: the previous monolithic implementation is removed, and agent creation is now handled via a new runnable in the LangChain integration. Public APIs are reorganized and expanded, with updated imports and documentation.

Changes

File(s) Change Summary
examples/agents/state_based.py New example script demonstrating setup and execution of a ROS2 state-based agent using new aggregators and connectors.
src/rai_core/rai/agents/ros2/state_based_agent.py, src/rai_core/rai/agents/ros2/__init__.py New ROS2 agent class (ROS2StateBasedAgent) and subpackage export.
src/rai_core/rai/agents/langchain/state_based_agent.py New BaseStateBasedAgent class and StateBasedConfig model for periodic state aggregation and agent logic.
src/rai_core/rai/agents/langchain/runnables.py Added create_state_based_runnable and retriever_wrapper for stateful agent runnable creation; supports multimodal state retrieval.
src/rai_core/rai/agents/langchain/__init__.py Expanded public API exports for LangChain agent components, including new state-based agent entities.
src/rai_core/rai/agents/langchain/agent.py, src/rai_core/rai/agents/langchain/react_agent.py Refactored imports to relative; minor docstring and event reset in stop method.
src/rai_core/rai/agents/__init__.py Removed export and import of create_state_based_agent.
src/rai_core/rai/agents/state_based.py Deleted previous monolithic state-based agent implementation and all its public entities.
src/rai_core/rai/aggregators/base.py, src/rai_core/rai/aggregators/__init__.py Introduced abstract BaseAggregator class and package export.
src/rai_core/rai/aggregators/ros2/aggregators.py, src/rai_core/rai/aggregators/ros2/__init__.py Added ROS2-specific aggregator classes for logs and images, and package export.
src/rai_core/rai/communication/ros2/api/conversion.py Updated convert_ros_img_to_base64 to support both Image and CompressedImage types; fixed color conversion constant.
docs/developer_guide.md Updated documentation to reference new state-based agent creation API.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant ROS2StateBasedAgent
    participant Aggregators
    participant ROS2Connector
    participant LLM
    participant Tools

    User->>ROS2StateBasedAgent: Start agent
    ROS2StateBasedAgent->>ROS2Connector: Setup connectors (subscribe/publish)
    loop Periodic Aggregation
        ROS2Connector->>Aggregators: Receive and buffer messages
        Aggregators->>ROS2StateBasedAgent: Aggregate state (logs, images, etc.)
    end
    ROS2StateBasedAgent->>LLM: Provide aggregated state (via retriever)
    LLM->>Tools: (Conditional) Call tools if needed
    Tools-->>LLM: Tool results
    LLM->>ROS2StateBasedAgent: Generate response
    ROS2StateBasedAgent->>ROS2Connector: Publish response (e.g., to HRI topic)
    User-->>ROS2StateBasedAgent: Shutdown signal
    ROS2StateBasedAgent->>ROS2Connector: Clean shutdown
Loading
sequenceDiagram
    participant Agent
    participant Aggregators
    participant LLM

    loop Aggregation Interval
        Agent->>Aggregators: Trigger aggregation
        Aggregators-->>Agent: Aggregated messages (logs, images, etc.)
    end
    Agent->>LLM: Run stateful runnable with aggregated state
    LLM->>Agent: Output (response, tool calls, etc.)
Loading

📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Lite

📥 Commits

Reviewing files that changed from the base of the PR and between bb98682 and cebcfb9.

📒 Files selected for processing (16)
  • docs/developer_guide.md (2 hunks)
  • examples/agents/state_based.py (1 hunks)
  • src/rai_core/rai/agents/__init__.py (0 hunks)
  • src/rai_core/rai/agents/langchain/__init__.py (1 hunks)
  • src/rai_core/rai/agents/langchain/agent.py (3 hunks)
  • src/rai_core/rai/agents/langchain/react_agent.py (1 hunks)
  • src/rai_core/rai/agents/langchain/runnables.py (2 hunks)
  • src/rai_core/rai/agents/langchain/state_based_agent.py (1 hunks)
  • src/rai_core/rai/agents/ros2/__init__.py (1 hunks)
  • src/rai_core/rai/agents/ros2/state_based_agent.py (1 hunks)
  • src/rai_core/rai/agents/state_based.py (0 hunks)
  • src/rai_core/rai/aggregators/__init__.py (1 hunks)
  • src/rai_core/rai/aggregators/base.py (1 hunks)
  • src/rai_core/rai/aggregators/ros2/__init__.py (1 hunks)
  • src/rai_core/rai/aggregators/ros2/aggregators.py (1 hunks)
  • src/rai_core/rai/communication/ros2/api/conversion.py (1 hunks)
💤 Files with no reviewable changes (2)
  • src/rai_core/rai/agents/init.py
  • src/rai_core/rai/agents/state_based.py
🧰 Additional context used
🧬 Code Graph Analysis (6)
src/rai_core/rai/agents/ros2/state_based_agent.py (2)
src/rai_core/rai/communication/ros2/connectors/ros2_connector.py (1)
  • ROS2Connector (19-20)
src/rai_core/rai/agents/langchain/state_based_agent.py (2)
  • BaseStateBasedAgent (47-188)
  • setup_connector (96-97)
src/rai_core/rai/aggregators/__init__.py (1)
src/rai_core/rai/aggregators/base.py (1)
  • BaseAggregator (24-54)
src/rai_core/rai/agents/ros2/__init__.py (1)
src/rai_core/rai/agents/ros2/state_based_agent.py (1)
  • ROS2StateBasedAgent (20-22)
src/rai_core/rai/aggregators/base.py (3)
src/rai_core/rai/communication/base_connector.py (1)
  • BaseMessage (40-48)
src/rai_core/rai/communication/ros2/ros_logs.py (2)
  • append (156-159)
  • clear (153-154)
src/rai_core/rai/aggregators/ros2/aggregators.py (4)
  • get (34-53)
  • get (59-66)
  • get (85-109)
  • get (139-169)
src/rai_core/rai/communication/ros2/api/conversion.py (1)
src/rai_bench/rai_bench/tool_calling_agent/messages/topics.py (1)
  • Image (39-46)
src/rai_core/rai/agents/langchain/agent.py (3)
src/rai_core/rai/agents/base.py (1)
  • BaseAgent (19-29)
src/rai_core/rai/agents/langchain/callback.py (1)
  • HRICallbackHandler (27-99)
src/rai_core/rai/agents/langchain/runnables.py (1)
  • ReActAgentState (40-49)
🔇 Additional comments (13)
src/rai_core/rai/aggregators/base.py (3)

32-40: LGTM! Good implementation of a circular buffer.

The constructor and __call__ method correctly implement a circular buffer pattern with automatic eviction of oldest elements when the maximum size is reached.


42-45: LGTM! Clear abstract method definition.

The abstract get method signature is well-defined with appropriate return type and documentation.


47-54:

❓ Verification inconclusive

Consider adding thread-safety if needed for concurrent access.

The buffer management methods look good, but if this class might be accessed from multiple threads, consider adding synchronization mechanisms.


🏁 Script executed:

#!/bin/bash
# Check if there are any indications of multi-threaded or concurrent usage of the aggregator classes
echo "Checking for potential multi-threaded usage of aggregators"
grep -r "Thread\|thread\|concurrent\|async\|await" --include="*.py" src/rai_core/rai | grep -i "aggregator"

Length of output: 234


Verify thread-safety requirements for Aggregator buffer

I ran a search across src/rai_core/rai for concurrency patterns (Thread, concurrent, async, await) and found no matches. It looks like these buffer methods aren’t currently used in a multi-threaded context. If you do plan to share an instance of this class across threads, consider:

  • Initializing a lock in __init__:
    from threading import Lock
    ...
    self._lock = Lock()
  • Wrapping buffer access in both methods:
     def clear_buffer(self) -> None:
    -    self._buffer.clear()
    +    with self._lock:
    +        self._buffer.clear()
    
     def get_buffer(self) -> List[T]:
    -    return list(self._buffer)
    +    with self._lock:
    +        return list(self._buffer)

If you’re not using this class concurrently, no changes are needed.

src/rai_core/rai/agents/ros2/__init__.py (1)

1-16: LGTM! Clean module export.

The import and export are implemented correctly, making ROS2StateBasedAgent available when importing from this package.

src/rai_core/rai/agents/langchain/react_agent.py (1)

22-23: LGTM! Good use of relative imports.

Converting from absolute to relative imports for internal package modules is a good practice that improves maintainability.

docs/developer_guide.md (1)

146-148: LGTM: Updated agent creation with new runnable-based approach

The updated agent creation code correctly uses the new create_state_based_runnable function, maintaining the same parameter structure while leveraging the new architecture.

src/rai_core/rai/agents/langchain/agent.py (3)

28-30: LGTM: Improved import organization

Changing from absolute to relative imports is a good practice that makes the code more maintainable within the package structure.


212-213: LGTM: Enhanced method documentation

Adding a docstring to the stop method improves code clarity and documentation.


221-221: LGTM: Proper event cleanup

Clearing the stop event after joining the thread ensures the event is reset properly, which is important for proper agent lifecycle management if the agent is restarted.

src/rai_core/rai/agents/ros2/state_based_agent.py (1)

15-18: LGTM: Clean imports

The imports are concise and appropriate, bringing in just the necessary components from their respective modules.

src/rai_core/rai/aggregators/ros2/__init__.py (2)

15-20: LGTM: Clear component imports

The imports are well-organized, bringing in all the aggregator classes from the implementation file.


22-27: LGTM: Well-defined public API

The __all__ list clearly defines the public API for this module, making it explicit which classes are intended to be used by consumers of this package.

examples/agents/state_based.py (1)

35-44: Mixed key types in aggregators dict may break lookup logic

StateBasedConfig.aggregators receives both
• a Tuple[str,str] key and
• a plain str key.

If the consuming code normalises topic names into one representation (e.g. always tuple), the second entry ("/rosout") might be ignored or raise KeyError.

Please confirm the model allows heterogeneous keys; otherwise, wrap the second key in a tuple too:

-    "/rosout": [
+    ("/rosout", "rcl_interfaces/msg/Log"): [

(Replace with the correct type string expected by the aggregator.)


🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🔭 Outside diff range comments (1)
src/rai_core/rai/communication/ros2/api/conversion.py (1)

117-134: 🛠️ Refactor suggestion

Consider consolidating image conversion functions.

The new function provides a cleaner approach for encoding ROS2 images to base64, supporting both Image and CompressedImage types. It's a good implementation, but as noted in the TODO comment, it should be merged with the existing convert_ros_img_to_base64 function to avoid code duplication.

Consider implementing this consolidation by either:

  1. Expanding the existing function to handle both types
  2. Creating a unified function that calls appropriate conversion methods based on message type
-# TODO(boczekbartek): merge with above
-def encode_ros2_img_to_base64(
-    img_message: sensor_msgs.msg.Image | sensor_msgs.msg.CompressedImage,
-) -> str:
-    msg_type = type(img_message)
-    if msg_type == sensor_msgs.msg.Image:
-        image = CvBridge().imgmsg_to_cv2(  # type: ignore
-            img_message, desired_encoding="rgb8"
-        )
-    elif msg_type == sensor_msgs.msg.CompressedImage:
-        image = CvBridge().compressed_imgmsg_to_cv2(  # type: ignore
-            img_message, desired_encoding="rgb8"
-        )
-    else:
-        raise ValueError(f"Unsupported message type: {msg_type}")
-    image = cast(cv2.Mat, image)
-    return preprocess_image(image)
+def convert_ros_img_to_base64(
+    msg: sensor_msgs.msg.Image | sensor_msgs.msg.CompressedImage,
+    desired_encoding: str = "rgb8"
+) -> str:
+    """Convert ROS2 image message to base64 encoded string.
+
+    Args:
+        msg: ROS2 image message
+        desired_encoding: Desired encoding for the image
+
+    Returns:
+        Base64 encoded string
+
+    Raises:
+        ValueError: If the message type is not supported
+    """
+    bridge = CvBridge()
+    msg_type = type(msg)
+    
+    if msg_type == sensor_msgs.msg.Image:
+        cv_image = cast(cv2.Mat, bridge.imgmsg_to_cv2(msg, desired_encoding=desired_encoding))
+    elif msg_type == sensor_msgs.msg.CompressedImage:
+        cv_image = cast(cv2.Mat, bridge.compressed_imgmsg_to_cv2(msg, desired_encoding=desired_encoding))
+    else:
+        raise ValueError(f"Unsupported message type: {msg_type}")
+    
+    return preprocess_image(cv_image)
🧹 Nitpick comments (5)
examples/agents/state_based.py (1)

12-12: Typo in License Header

There's a small spelling error in line 12: "goveself.rning" should be "governing".

-# See the License for the specific language goveself.rning permissions and
+# See the License for the specific language governing permissions and
src/rai_core/rai/agents/react_agent.py (1)

39-49: Clarify Handling of 'runnable' vs 'tools' and 'system_prompt'

When both runnable and tools / system_prompt are provided, the latter are effectively superseded by the custom runnable. Consider documenting or warning if both are passed, to avoid confusion.

src/rai_core/rai/agents/langchain/runnables.py (2)

126-127: Add Docstring to Explain StateMessage

A short docstring describing the purpose and usage of StateMessage can aid readability and maintainability.

+class StateMessage(HumanMultimodalMessage):
+    """
+    A multimodal message containing retrieved state information
+    (images, audio, textual content) for the agent to incorporate.
+    """

162-164: Use Named Function Instead of Lambda

Static analysis suggests rewriting the lambda expression as a named function to enhance clarity.

-        state_retriever = lambda: {}
+        def default_state_retriever():
+            return {}
+        state_retriever = default_state_retriever
🧰 Tools
🪛 Ruff (0.8.2)

163-163: Do not assign a lambda expression, use a def

Rewrite state_retriever as a def

(E731)

src/rai_core/rai/state_based/ros2/aggregators.py (1)

98-99: Optimize the image encoding step.

Currently, you encode all images in the buffer even though only the last image is used in the aggregator. This can be inefficient for large buffers or high-frequency image streams.

-def __call__(self, msgs: List[Image | CompressedImage]) -> str | None:
-    if len(msgs) == 0:
-        return None
-    b64_images: List[str] = [encode_ros2_img_to_base64(msg) for msg in msgs]
-    return encode_ros2_img_to_base64(msgs[-1])
+def __call__(self, msgs: List[Image | CompressedImage]) -> str | None:
+    if not msgs:
+        return None
+    return encode_ros2_img_to_base64(msgs[-1])
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c4ba634 and fc10dcb.

📒 Files selected for processing (10)
  • examples/agents/state_based.py (1 hunks)
  • src/rai_core/rai/agents/__init__.py (1 hunks)
  • src/rai_core/rai/agents/langchain/runnables.py (2 hunks)
  • src/rai_core/rai/agents/react_agent.py (2 hunks)
  • src/rai_core/rai/agents/state_based.py (0 hunks)
  • src/rai_core/rai/agents/state_based_agent.py (1 hunks)
  • src/rai_core/rai/communication/ros2/api/conversion.py (2 hunks)
  • src/rai_core/rai/initialization/model_initialization.py (2 hunks)
  • src/rai_core/rai/state_based/ros2/aggregators.py (1 hunks)
  • src/rai_core/rai/tools/ros2/generic/topics.py (2 hunks)
💤 Files with no reviewable changes (1)
  • src/rai_core/rai/agents/state_based.py
🧰 Additional context used
🧬 Code Graph Analysis (5)
src/rai_core/rai/tools/ros2/generic/topics.py (2)
src/rai_core/rai/communication/ros2/api/conversion.py (2)
  • encode_ros2_img_to_base64 (118-133)
  • ros2_message_to_dict (32-47)
src/rai_core/rai/messages/artifacts.py (1)
  • MultimodalArtifact (20-22)
src/rai_core/rai/agents/react_agent.py (1)
src/rai_core/rai/agents/langchain/runnables.py (2)
  • ReActAgentState (38-47)
  • create_react_runnable (78-123)
src/rai_core/rai/communication/ros2/api/conversion.py (1)
src/rai_core/rai/messages/conversion.py (1)
  • preprocess_image (24-48)
src/rai_core/rai/agents/__init__.py (1)
src/rai_core/rai/agents/state_based_agent.py (2)
  • StateBasedAgent (45-166)
  • StateBasedConfig (36-42)
src/rai_core/rai/agents/langchain/runnables.py (3)
src/rai_core/rai/agents/tool_runner.py (1)
  • ToolRunner (33-146)
src/rai_core/rai/initialization/model_initialization.py (1)
  • get_llm_model (125-158)
src/rai_core/rai/messages/multimodal.py (1)
  • HumanMultimodalMessage (60-81)
🪛 Ruff (0.8.2)
src/rai_core/rai/agents/langchain/runnables.py

163-163: Do not assign a lambda expression, use a def

Rewrite state_retriever as a def

(E731)

src/rai_core/rai/state_based/ros2/aggregators.py

51-51: Do not perform function call TopicConfig in argument defaults; instead, perform the call within the function, or read the default from a module-level singleton variable

(B008)

🔇 Additional comments (7)
src/rai_core/rai/communication/ros2/api/conversion.py (1)

29-29: Good addition of import for preprocess_image.

The import statement correctly brings in the preprocess_image function that will be used by the new encode_ros2_img_to_base64 function.

src/rai_core/rai/agents/__init__.py (2)

18-18: Updated import for the new StateBasedAgent architecture.

The import statement has been properly updated to reflect the refactoring from the factory function pattern to a class-based approach.


24-25: Properly updated all list with new exported symbols.

The all list has been correctly updated to include the new StateBasedAgent and StateBasedConfig classes, which aligns with the architectural changes in this PR.

src/rai_core/rai/initialization/model_initialization.py (2)

22-22: Good addition of BaseChatModel import.

Importing BaseChatModel from langchain.chat_models.base is a necessary addition to support the return type annotation being added.


130-130: Enhanced type safety with return type annotation.

Adding the explicit return type annotation -> BaseChatModel to the get_llm_model function improves type safety and code readability, making it clearer what type of object the function returns.

src/rai_core/rai/tools/ros2/generic/topics.py (2)

25-29: Well-structured imports with specific functions.

The imports have been properly updated to include only the necessary functions from the conversion module, improving code clarity by making dependencies explicit.


136-137:

✅ Verification successful

Simplified image processing using centralized utility function.

The implementation now leverages the new encode_ros2_img_to_base64 utility function, which reduces code duplication and centralizes image processing logic. This change makes the code more maintainable and consistent.


🏁 Script executed:

#!/bin/bash
# Verify that no other code in the repository still uses direct CvBridge conversion 
# for images that could benefit from the new utility function

echo "Searching for potential direct CvBridge usage for image conversion that could use encode_ros2_img_to_base64 instead:"
rg -A 3 -B 3 "CvBridge\(\).imgmsg_to_cv2|CvBridge\(\).compressed_imgmsg_to_cv2" --type py --glob '!**/conversion.py'

Length of output: 361


Centralized image processing utility verified

The changes introduced in src/rai_core/rai/tools/ros2/generic/topics.py (lines 136-137) now effectively leverage the encode_ros2_img_to_base64 utility function. The provided grep command confirmed there are no residual direct usages of CvBridge conversion methods in the repository, ensuring consistency and improved maintainability.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider renaming state_based catalog to aggregators.
Consider moving the BaseStateAggregator(ABC) to base.py and keeping the implementations in aggregators/ros2 catalog.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea, done in cfd0ca0

qos_profile = self._ari_connector._topic_api._resolve_qos_profile(
topic, config.auto_qos_matching, config.qos_profile, for_publisher=False
)
print(qos_profile)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
print(qos_profile)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

@boczekbartek boczekbartek Apr 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for not accepting your suggestion, I noticed your comment after cleaning this up myself :D

@boczekbartek boczekbartek force-pushed the bb/feat/state_based_agent branch from 4c06d7c to f31ffff Compare April 17, 2025 11:20
@boczekbartek boczekbartek force-pushed the bb/feat/state_based_agent branch from a0665a3 to 6620b74 Compare April 17, 2025 13:04
@boczekbartek
Copy link
Member Author

@coderabbitai full review

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Nitpick comments (10)
src/rai_core/rai/aggregators/base.py (1)

49-56: Consider adding return type annotations to utility methods.

While the code is well-structured, for consistency with the rest of the class, consider adding return type annotations to the clear and get_buffer methods.

-    def clear(self) -> None:
+    def clear(self) -> None:
     self._buffer.clear()

-    def get_buffer(self) -> List[T]:
+    def get_buffer(self) -> List[T]:
     return list(self._buffer)

-    def __str__(self) -> str:
+    def __str__(self) -> str:
     return f"{self.__class__.__name__}"
src/rai_core/rai/agents/ros2/state_based_agent.py (1)

19-21: Consider adding constructor parameters for ROS2Connector customization.

The implementation is simple and follows the inheritance pattern correctly. However, the ROS2Connector constructor accepts parameters like node_name and destroy_subscribers that could be useful for customization.

Consider extending the constructor to allow customization of the ROS2 connector:

-class ROS2StateBasedAgent(BaseStateBasedAgent):
-    def setup_connector(self):
-        return ROS2Connector()
+class ROS2StateBasedAgent(BaseStateBasedAgent):
+    def __init__(
+        self,
+        connectors: dict[str, HRIConnector[HRIMessage]],
+        config: StateBasedConfig,
+        llm: Optional[BaseChatModel] = None,
+        tools: Optional[List[BaseTool]] = None,
+        state: Optional[ReActAgentState] = None,
+        system_prompt: Optional[str] = None,
+        node_name: str = None,
+        destroy_subscribers: bool = False,
+    ):
+        super().__init__(
+            connectors, config, llm, tools, state, system_prompt
+        )
+        self.node_name = node_name
+        self.destroy_subscribers = destroy_subscribers
+
+    def setup_connector(self):
+        return ROS2Connector(
+            node_name=self.node_name,
+            destroy_subscribers=self.destroy_subscribers
+        )

This allows users to customize the ROS2 connector when creating the agent.

src/rai_core/rai/tools/ros2/generic/topics.py (1)

136-137: Simplified image encoding using the new utility function.

The code is much cleaner now, using the centralized encoding function instead of inline conversion logic. However, consider adding error handling for potential failures during image encoding.

-        b64_image = encode_ros2_img_to_base64(message.payload)
-        return "Image received successfully", MultimodalArtifact(images=[b64_image])
+        try:
+            b64_image = encode_ros2_img_to_base64(message.payload)
+            return "Image received successfully", MultimodalArtifact(images=[b64_image])
+        except ValueError as e:
+            # Handle unsupported message types
+            return f"Error encoding image: {str(e)}", MultimodalArtifact(images=[])
+        except Exception as e:
+            # Handle other errors
+            return f"Unexpected error encoding image: {str(e)}", MultimodalArtifact(images=[])
src/rai_core/rai/agents/react_agent.py (1)

39-49: self.llm may stay None when a pre‑built runnable is injected – confirm this will not break downstream usages

When the caller passes in a ready‑made runnable, the llm argument can legitimately be None.
Because we unconditionally store it (self.llm = llm), any later code that expects a concrete BaseChatModel may find None and raise.

If self.llm is never referenced outside this constructor, consider removing the attribute entirely or documenting that it can be None. Otherwise, guard all later usages.

examples/agents/state_based.py (1)

11-13: Typo in licence header

goveself.rninggoverning

While purely cosmetic, licence headers are scanned automatically by some compliance tools; correcting the typo avoids false positives.

-# See the License for the specific language goveself.rning permissions and
+# See the License for the specific language governing permissions and
src/rai_core/rai/agents/base_state_based_agent.py (2)

80-88: Unused ReentrantCallbackGroup—remove or integrate to avoid dead code

self._callback_group is created but never passed to the connector or used elsewhere.
Keeping unused members adds cognitive overhead and may mis‑lead future maintainers.

-        self._callback_group = ReentrantCallbackGroup()

If you plan to leverage it for ROS2 subscriptions later, add a TODO comment explaining when it will become relevant. Otherwise, delete it.


154-165: _stop_event.clear() at the end of stop() can restart threads unintentionally

clear() resets the flag that other threads (and run()) rely on.
If a user calls stop() and later forgets to recreate the agent, a subsequent run() will restart old threads on the same object, which is error‑prone.

Remove the clear() or add a comment justifying the reuse scenario.

src/rai_core/rai/aggregators/ros2/aggregators.py (3)

34-50: Consider clearing the buffer after processing

Without self.clear(), the same log messages are re‑processed in every interval.
If persistence is not desired, add self.clear() at the end of get().


130-157: Diff aggregator: misleading statistics and missing original length

The final message uses len(b64_images) twice, so it always reports n keyframes selected from n images.

-        return HumanMessage(
-            content=f"Result of the analysis of the {len(b64_images)} keyframes selected from {len(b64_images)} last images:\n{response}"
-        )
+        return HumanMessage(
+            content=(
+                f"Result of the analysis of the {len(b64_images)} keyframes "
+                f"selected from {len(msgs)} last images:\n{response}"
+            )
+        )

111-124: get_key_elements docstring & name mention 1st/middle/last but function returns in unspecified order

The helper currently returns [elements[0], elements[middle], elements[-1]], which is fine, but consider clarifying the order in the docstring or renaming to select_keyframes for intent clarity.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Lite

📥 Commits

Reviewing files that changed from the base of the PR and between 28eb194 and 3861b36.

📒 Files selected for processing (15)
  • examples/agents/state_based.py (1 hunks)
  • src/rai_core/rai/agents/__init__.py (1 hunks)
  • src/rai_core/rai/agents/base_state_based_agent.py (1 hunks)
  • src/rai_core/rai/agents/langchain/runnables.py (2 hunks)
  • src/rai_core/rai/agents/react_agent.py (2 hunks)
  • src/rai_core/rai/agents/ros2/__init__.py (1 hunks)
  • src/rai_core/rai/agents/ros2/state_based_agent.py (1 hunks)
  • src/rai_core/rai/agents/state_based.py (0 hunks)
  • src/rai_core/rai/aggregators/__init__.py (1 hunks)
  • src/rai_core/rai/aggregators/base.py (1 hunks)
  • src/rai_core/rai/aggregators/ros2/__init__.py (1 hunks)
  • src/rai_core/rai/aggregators/ros2/aggregators.py (1 hunks)
  • src/rai_core/rai/communication/ros2/api/conversion.py (2 hunks)
  • src/rai_core/rai/initialization/model_initialization.py (2 hunks)
  • src/rai_core/rai/tools/ros2/generic/topics.py (2 hunks)
💤 Files with no reviewable changes (1)
  • src/rai_core/rai/agents/state_based.py
🧰 Additional context used
🧬 Code Graph Analysis (8)
src/rai_core/rai/tools/ros2/generic/topics.py (2)
src/rai_core/rai/communication/ros2/api/conversion.py (2)
  • encode_ros2_img_to_base64 (118-133)
  • ros2_message_to_dict (32-47)
src/rai_core/rai/messages/artifacts.py (1)
  • MultimodalArtifact (20-22)
src/rai_core/rai/aggregators/__init__.py (1)
src/rai_core/rai/aggregators/base.py (1)
  • BaseAggregator (26-56)
src/rai_core/rai/aggregators/ros2/__init__.py (1)
src/rai_core/rai/aggregators/ros2/aggregators.py (4)
  • ROS2GetLastImageAggregator (54-63)
  • ROS2ImgVLMDescriptionAggregator (66-102)
  • ROS2ImgVLMDiffAggregator (105-157)
  • ROS2LogsAggregator (28-51)
src/rai_core/rai/agents/ros2/state_based_agent.py (2)
src/rai_core/rai/agents/base_state_based_agent.py (2)
  • BaseStateBasedAgent (47-165)
  • setup_connector (93-94)
src/rai_core/rai/communication/ros2/connectors/connector.py (1)
  • ROS2Connector (42-283)
src/rai_core/rai/agents/ros2/__init__.py (1)
src/rai_core/rai/agents/ros2/state_based_agent.py (1)
  • ROS2StateBasedAgent (19-21)
src/rai_core/rai/communication/ros2/api/conversion.py (1)
src/rai_core/rai/messages/conversion.py (1)
  • preprocess_image (24-48)
src/rai_core/rai/agents/react_agent.py (1)
src/rai_core/rai/agents/langchain/runnables.py (2)
  • ReActAgentState (39-48)
  • create_react_runnable (79-124)
src/rai_core/rai/aggregators/base.py (1)
src/rai_core/rai/messages/multimodal.py (1)
  • HumanMultimodalMessage (60-81)
🔇 Additional comments (15)
src/rai_core/rai/initialization/model_initialization.py (1)

22-22: Good addition of type annotation for improved type safety.

Adding the -> BaseChatModel return type annotation to the get_llm_model function and importing the required class improves code clarity and enables better static type checking.

Also applies to: 130-130

src/rai_core/rai/aggregators/base.py (4)

15-24: Well-structured imports and type definitions.

The imports are well-organized, and the use of generic type variable T provides a flexible foundation for different aggregator implementations.


26-38: Good design of the BaseAggregator abstract class.

The base class is well-documented and follows good OOP principles with clear separation of concerns. The use of Generic[T] allows for type-safe implementations for different message types.


39-43: Efficient buffer management implementation.

The __call__ method efficiently manages the buffer size by automatically removing the oldest message when the buffer reaches its maximum size.


44-47: Clear abstract method definition.

The abstract get method is well-defined with appropriate return type annotations that support both regular and multimodal message types.

src/rai_core/rai/agents/ros2/__init__.py (1)

1-16: Appropriate module structure for ROS2 agent exports.

The module follows standard Python packaging practices by defining __all__ to explicitly specify the public API. This provides a clean interface for users of the package.

src/rai_core/rai/aggregators/__init__.py (1)

1-17: Well-structured package initialization.

The module correctly exports the BaseAggregator class through __all__, following Python best practices for package organization.

src/rai_core/rai/agents/__init__.py (2)

15-15: Update imports to match the new architecture.

The import of BaseStateBasedAgent and StateBasedConfig from the new module reflects the architectural shift from the earlier factory function approach to a more modular, class-based design.


21-29: Exports updated appropriately.

The __all__ list correctly exports the new classes while removing the old factory function. This ensures that the public API is up-to-date with the architectural changes.

src/rai_core/rai/agents/ros2/state_based_agent.py (1)

15-17: Imports are appropriate and minimal.

The imports bring in only the necessary components: the base agent class and the ROS2 connector.

src/rai_core/rai/aggregators/ros2/__init__.py (2)

15-20: Well-structured imports for aggregator classes.

The imports clearly specify the four ROS2-specific aggregator classes from the local module.


22-27: Explicit exports in all list follow best practices.

The __all__ list explicitly exports the four aggregator classes, defining a clear public API for the module.

src/rai_core/rai/tools/ros2/generic/topics.py (1)

24-29: Updated imports to use the new conversion function.

The imports are updated to include the new encode_ros2_img_to_base64 function along with the existing ros2_message_to_dict function.

src/rai_core/rai/communication/ros2/api/conversion.py (1)

29-30: Added import for preprocess_image.

The import for the shared image preprocessing utility is appropriate.

src/rai_core/rai/agents/base_state_based_agent.py (1)

96-103: Callback registration leaks if setup_connector() fails

If any single register_callback raises, previously registered callbacks remain stored in the connector and _registered_callbacks, yet stop() will never be called. Wrap the loop in try/except, unregister already‑added callbacks on failure, or build the list first and register in one go.

self._buffer.append(msg)

@abstractmethod
def get(self) -> HumanMessage | HumanMultimodalMessage | None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseMessage

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return list(self._buffer)

def __str__(self) -> str:
return f"{self.__class__.__name__}"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add len of bufferf?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returns the VLM analysis of the last image in the aggregation buffer
"""

def __init__(self, max_size: int | None = None) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llm optional[BaseChatModel]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Returns the aggregated message"""
pass

def clear(self) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clear_buffer?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b64_images: List[str] = [encode_ros2_img_to_base64(msg) for msg in msgs]
self.clear()

system_prompt = "You are an expert in image analysis and your speciality is the"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move to static const, allow overriding via constructor

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@boczekbartek
Copy link
Member Author

@CodeRabbit full review

@boczekbartek boczekbartek force-pushed the bb/feat/state_based_agent branch from cebcfb9 to 81c575a Compare April 29, 2025 09:33
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Inline review comments failed to post. This is likely due to GitHub's limits when posting large numbers of comments. If you are seeing this consistently it is likely a permissions issue. Please check "Moderation" -> "Code review limits" under your organization settings.

Actionable comments posted: 2

♻️ Duplicate comments (4)
src/rai_core/rai/aggregators/__init__.py (1)

1-17: Clean package export, but check copyright year.

The init file correctly exports the BaseAggregator class, but the copyright year should be updated.

-# Copyright (C) 2025 Robotec.AI
+# Copyright (C) 2024 Robotec.AI
src/rai_core/rai/agents/langchain/runnables.py (2)

128-136: In-place mutation of shared HumanMessage objects remains unfixed
See earlier review; the code still mutates message.content, risking state corruption if the same object is reused elsewhere.


164-172: Creating ToolRunner([]) still raises ValueError – original issue persists
The runnable builds a ToolRunner even when tools is empty, reproducing the crash highlighted in the previous review.

src/rai_core/rai/aggregators/ros2/aggregators.py (1)

32-52: Duplicate-compression algorithm still loses the last duplicate block
(see previous review).

The logic drops the final repetition summary and replaces the new log with
the repetition counter, discarding the actual message (B in A A A B). A
corrected version was proposed earlier and is reposted below for convenience.

-        prev_parsed = None
-        counter = 0
+        prev_parsed = None
+        counter = 1
@@
-            if parsed == prev_parsed:
-                counter += 1
-                continue
-            else:
-                if counter != 0:
-                    parsed = f"Log above repeated {counter} times"
-            buffer.append(parsed)
-            counter = 0
-            prev_parsed = parsed
+            if parsed == prev_parsed:
+                counter += 1
+            else:
+                # Flush previous duplicates first
+                if counter > 1:
+                    buffer.append(f"Previous log repeated {counter-1} times")
+                buffer.append(parsed)
+                prev_parsed = parsed
+                counter = 1
+
+        if counter > 1:
+            buffer.append(f"Previous log repeated {counter-1} times")
-        result = f"Logs summary: {list(dict.fromkeys(buffer))}"
+        result = f"Logs summary: {buffer}"

This preserves every unique message and accurately reports repetitions at the
end of the buffer.

🧹 Nitpick comments (13)
src/rai_core/rai/aggregators/base.py (2)

1-13: Check the copyright year.

The copyright year is set to 2025, which appears to be in the future. Update it to the current year (2024) or the appropriate year when this code was written.

-# Copyright (C) 2025 Robotec.AI
+# Copyright (C) 2024 Robotec.AI

23-30: Enhance documentation to describe the purpose and use cases.

The current docstring explains the method behavior but doesn't describe the purpose of aggregators in the overall system.

 class BaseAggregator(ABC, Generic[T]):
     """
     Interface for aggregators.
 
     `__call__` method receives a message and appends it to the buffer.
     `get` method returns the aggregated message.
+    
+    Aggregators buffer messages of type T and convert them into BaseMessage instances
+    that can be used by LLM agents to incorporate environmental state information.
+    Subclasses should define how to convert buffered messages into a meaningful 
+    representation for the agent.
     """
docs/developer_guide.md (1)

134-135: Consider adding an explanation of "runnable" concept

The import has been correctly updated to reflect the architectural shift from create_state_based_agent to create_state_based_runnable. This aligns with the new modular framework for state-based agents introduced in the PR.

Consider adding a brief note explaining what a "runnable" is in the LangChain context, as this concept might not be familiar to all developers using the framework.

src/rai_core/rai/agents/ros2/state_based_agent.py (1)

20-22: Add class docstring to explain purpose and usage

The implementation clearly defines a ROS2-specific state-based agent by overriding the setup_connector method to return a ROS2Connector instance.

Add a docstring to the class explaining its purpose, how it integrates with ROS2, and providing usage examples. This would be especially helpful for new users:

class ROS2StateBasedAgent(BaseStateBasedAgent):
+    """
+    A state-based agent implementation for ROS2 environments.
+    
+    This class extends BaseStateBasedAgent to work specifically with ROS2,
+    setting up a ROS2Connector for communication and state aggregation.
+    
+    Example:
+        ```python
+        from rai.agents.ros2 import ROS2StateBasedAgent
+        from rai.aggregators.ros2 import ROS2LogsAggregator
+        
+        # Configure state aggregation
+        config = StateBasedConfig(
+            aggregators={"/rosout": [ROS2LogsAggregator()]},
+            time_interval=1.0
+        )
+        
+        # Create agent
+        agent = ROS2StateBasedAgent(config=config, ...)
+        agent.run()
+        ```
+    """
    def setup_connector(self):
        return ROS2Connector()
src/rai_core/rai/communication/ros2/api/conversion.py (1)

118-120: Use tobytes() and drop the extra copy

np.ndarray.tostring() is deprecated in NumPy ≥1.19 and will be removed.
Also, wrapping the result in bytes() above creates an unnecessary copy, unlike ndarray.tobytes(), which already returns bytes.

-        image_data = cv2.imencode(".png", cv_image)[1].tostring()  # type: ignore
+        image_data = cv2.imencode(".png", cv_image)[1].tobytes()  # type: ignore

Please update the earlier branches (bytes(cv2.imencode(...)[1])) similarly for consistency.

src/rai_core/rai/agents/langchain/__init__.py (2)

15-23: Heavy eager imports increase startup time & risk circular deps

Importing nine sub-modules unconditionally in __init__.py slows down import rai.agents.langchain and may trigger circular-import headaches (e.g. if .agent or .react_agent import this package back).
Consider deferring the imports:

from importlib import import_module
import sys
def __getattr__(name):
    mapping = {
        "BaseState": ".agent",
        "LangChainAgent": ".agent",
        ...
    }
    if name in mapping:
        module = import_module(f"{__name__}{mapping[name]}")
        obj = getattr(module, name)
        globals()[name] = obj
        return obj
    raise AttributeError(name)

This keeps the public surface unchanged but avoids loading unused code.


25-36: __all__ omission: newly exposed symbols must stay in sync

Every time a new submodule symbol is added you need to remember to edit __all__.
To avoid rot, you can generate it dynamically:

__all__ = [n for n in globals() if not n.startswith("_") and n[0].isupper()]

(or export only the mapping used in __getattr__, if you adopt lazy loading).

examples/agents/state_based.py (2)

12-13: Typo in licence boiler-plate

goveself.rning looks like a find-replace artefact – should be “governing”.


50-53: Graceful shutdown / exception propagation

agent.run() is blocking and exceptions inside the graph will currently terminate the process without a clear message.
Consider running it in a background task and wrapping in try/finally so that wait_for_shutdown is always reached and connectors are closed:

try:
    agent.run()
finally:
    wait_for_shutdown([agent])
src/rai_core/rai/agents/langchain/state_based_agent.py (2)

35-38: Fix minor typo in the StateBasedConfig description

"aggragator""aggregator".
Although harmless, public-facing docs/IDE hints inherit this typo.

-        description="Dict of topic : aggregator or (topic, msg_type) : aggragator"
+        description="Dict of topic : aggregator or (topic, msg_type) : aggregator"

172-187: Do not clear _stop_event after shutdown

stop() sets _stop_event to notify all threads that shutdown is in
progress. Clearing it at the end of the method re-opens a tiny time window in
which the agent instance appears “running” again, which can lead to surprises
if the same object is reused or inspected from another thread.

-        self._stop_event.clear()

Recommendation: remove the line unless you have a documented requirement to
restart the very same agent instance.

src/rai_core/rai/aggregators/ros2/aggregators.py (2)

165-169: Misreported image counts in VLM diff output

After calling get_key_elements, len(b64_images) refers to the selected
subset, not the total number of frames analysed. The summary therefore prints
the same number twice.

-            content=f"Result of the analysis of the {len(b64_images)} keyframes selected from {len(b64_images)} last images:\n{response}"
+        original_count = len(all_images)  # capture before subsampling
+        keyframes = self.get_key_elements(all_images)
+        ...
+            content=(
+                f"Result of the analysis of the {len(keyframes)} keyframes "
+                f"selected from {original_count} last images:\n{response}"
+            )

Capturing the original buffer length before subsampling (and using distinct
variable names) yields a correct, reader-friendly message.


113-118: Minor doc-string typo

“midden” → “middle”.

-    aggregation buffer: 1st, midden, last
+    aggregation buffer: 1st, middle, last
🛑 Comments failed to post (2)
src/rai_core/rai/communication/ros2/api/conversion.py (1)

95-106: 🛠️ Refactor suggestion

⚠️ Potential issue

Grayscale handling & channel-count detection can silently mis-route images

cv_image.shape[-1] assumes a 3-D ndarray.
For mono images returned by CvBridge, shape == (h, w) (2-D), so shape[-1] equals w, not 1.
The current logic therefore falls through to the BGR path, giving an incorrect colour conversion and distorted PNG.

-    if cv_image.shape[-1] == 4:
+    # BGRA
+    if len(cv_image.shape) == 3 and cv_image.shape[-1] == 4:
...
-    elif cv_image.shape[-1] == 1:
+    # MONO
+    elif len(cv_image.shape) == 2 or (len(cv_image.shape) == 3 and cv_image.shape[-1] == 1):

Adding the len(cv_image.shape) guard ensures both (h,w) and (h,w,1) grayscale inputs select the intended branch.
Please add a test with a MONO8 sensor_msgs.msg.Image to verify.

Committable suggestion skipped: line range outside the PR's diff.

src/rai_core/rai/agents/langchain/state_based_agent.py (1)

86-118: 🛠️ Refactor suggestion

Guard _aggregation_results with a lock to avoid race conditions

_aggregation_results is written in the background aggregation thread and may be
read concurrently by:

  1. The runnable graph (through self.get_state()).
  2. Any external callers that query get_state().

Without synchronisation, a read during a concurrent write can raise
RuntimeError: dictionary changed size during iteration or return a partially
updated state.

Suggested minimal fix:

@@
-        self._aggregation_results: Dict[str, HumanMessage | HumanMultimodalMessage] = (
-            dict()
-        )
+        self._aggregation_results: Dict[str, HumanMessage | HumanMultimodalMessage] = {}
+        self._state_lock = threading.Lock()
@@
     def get_state(self) -> Dict[str, HumanMessage | HumanMultimodalMessage]:
         """Returns output for all aggregators"""
-        return self._aggregation_results
+        with self._state_lock:
+            # Return a shallow copy to prevent callers from mutating internal state
+            return dict(self._aggregation_results)
@@
-                self._aggregation_results[source] = output
+                with self._state_lock:
+                    self._aggregation_results[source] = output

This keeps the change local and inexpensive while guaranteeing
thread-safety.
If higher throughput is required later, consider switching to a
collections.OrderedDict guarded by threading.RLock or an immutable
snapshotting strategy.

Also applies to: 168-171

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants