15
15
import argparse
16
16
17
17
import rclpy
18
- from rai .node import RaiStateBasedLlmNode , describe_ros_image
19
- from rai .tools .ros .native import (
20
- GetCameraImage ,
21
- GetMsgFromTopic ,
22
- Ros2GenericServiceCaller ,
23
- Ros2GetRobotInterfaces ,
24
- Ros2ShowMsgInterfaceTool ,
25
- )
18
+ from langchain_core .messages import HumanMessage
19
+ from langchain_core .runnables import Runnable
20
+ from rai .agents .conversational_agent import State , create_conversational_agent
21
+ from rai .communication .ros2 .connectors import ROS2ARIConnector
22
+ from rai .tools .ros2 .topics import ROS2TopicsToolkit
26
23
from rai .tools .time import WaitForSecondsTool
27
- from rclpy . action import ActionClient
24
+ from rai . utils . model_initialization import get_llm_model
28
25
from rclpy .callback_groups import ReentrantCallbackGroup
29
26
from rclpy .executors import MultiThreadedExecutor
30
27
from rclpy .node import Node
34
31
35
32
36
33
class MockBehaviorTreeNode (Node ):
37
- def __init__ (self , tractor_number : int ):
34
+ def __init__ (self , tractor_number : int , agent : Runnable [ State , State ] ):
38
35
super ().__init__ (f"mock_behavior_tree_node_{ tractor_number } " )
39
36
self .tractor_number = tractor_number
37
+ self .agent = agent
40
38
41
39
# Create a callback group for concurrent execution
42
40
self .callback_group = ReentrantCallbackGroup ()
@@ -48,11 +46,6 @@ def __init__(self, tractor_number: int):
48
46
callback_group = self .callback_group ,
49
47
)
50
48
51
- # Create action client
52
- self .perform_task_client = ActionClient (
53
- self , Task , "/perform_task" , callback_group = self .callback_group
54
- )
55
-
56
49
# Create timer for periodic checks
57
50
self .create_timer (
58
51
5.0 , self .check_tractor_state , callback_group = self .callback_group
@@ -79,18 +72,7 @@ async def check_tractor_state(self):
79
72
goal_msg .description = ""
80
73
goal_msg .task = "Anomaly detected. Please decide what to do."
81
74
82
- self .perform_task_client .wait_for_server ()
83
-
84
- future = self .perform_task_client .send_goal_async (goal_msg )
85
- await future
86
-
87
- goal_handle = future .result ()
88
- if goal_handle .accepted :
89
- self .get_logger ().info ("Goal accepted by perform_task action server" )
90
- result = await goal_handle .get_result_async ()
91
- self .get_logger ().info (f"Result: { result .result } " )
92
- else :
93
- self .get_logger ().warn ("Goal rejected by perform_task action server" )
75
+ self .agent .invoke (State (messages = [HumanMessage (content = str (goal_msg ))]))
94
76
95
77
96
78
def main ():
@@ -105,30 +87,9 @@ def main():
105
87
args = parser .parse_args ()
106
88
107
89
tractor_number = args .tractor_number
108
- tractor_prefix = f"/tractor{ tractor_number } "
109
90
110
91
rclpy .init ()
111
92
112
- observe_topics = [
113
- f"{ tractor_prefix } /camera_image_color" ,
114
- ]
115
-
116
- observe_postprocessors = {
117
- f"{ tractor_prefix } /camera_image_color" : describe_ros_image
118
- }
119
-
120
- topics_allowlist = [
121
- "/rosout" ,
122
- f"{ tractor_prefix } /camera_image_color" ,
123
- # Services
124
- f"{ tractor_prefix } /continue" ,
125
- f"{ tractor_prefix } /current_state" ,
126
- f"{ tractor_prefix } /flash" ,
127
- f"{ tractor_prefix } /replan" ,
128
- ]
129
-
130
- actions_allowlist = []
131
-
132
93
SYSTEM_PROMPT = f"""
133
94
You are autonomous tractor { tractor_number } operating in an agricultural field. You are activated whenever the tractor stops due to an unexpected situation. Your task is to call a service based on your assessment of the situation.
134
95
@@ -142,35 +103,28 @@ def main():
142
103
143
104
Important: You must call only one service. The tractor can only handle one service call.
144
105
"""
145
-
146
- rai_node = RaiStateBasedLlmNode (
147
- observe_topics = observe_topics ,
148
- observe_postprocessors = observe_postprocessors ,
149
- allowlist = topics_allowlist + actions_allowlist ,
106
+ connector = ROS2ARIConnector ()
107
+ agent = create_conversational_agent (
108
+ llm = get_llm_model ("complex_model" ),
150
109
system_prompt = SYSTEM_PROMPT ,
151
110
tools = [
152
- Ros2ShowMsgInterfaceTool ,
153
- WaitForSecondsTool ,
154
- GetMsgFromTopic ,
155
- GetCameraImage ,
156
- Ros2GetRobotInterfaces ,
157
- Ros2GenericServiceCaller ,
111
+ * ROS2TopicsToolkit (connector = connector ).get_tools (),
112
+ WaitForSecondsTool (),
158
113
],
159
114
)
160
115
161
- mock_node = MockBehaviorTreeNode (tractor_number )
116
+ mock_node = MockBehaviorTreeNode (tractor_number , agent )
162
117
163
118
# Use a MultiThreadedExecutor to allow for concurrent execution
164
119
executor = MultiThreadedExecutor ()
165
- executor .add_node (rai_node )
166
120
executor .add_node (mock_node )
167
121
168
122
try :
169
123
executor .spin ()
170
124
except KeyboardInterrupt :
171
125
pass
172
126
finally :
173
- rai_node . destroy_node ()
127
+ connector . shutdown ()
174
128
mock_node .destroy_node ()
175
129
rclpy .shutdown ()
176
130
0 commit comments