Skip to content

Commit a1f03ee

Browse files
authored
feat: dynamic services (#496)
1 parent f5f206f commit a1f03ee

File tree

22 files changed

+799
-346
lines changed

22 files changed

+799
-346
lines changed

examples/s2s/conversational.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
2626
from rai.agents.base import BaseAgent
2727
from rai.communication import BaseConnector
28-
from rai.communication.ros2.api import IROS2Message
29-
from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig
28+
from rai.communication.ros2 import IROS2Message, ROS2HRIConnector, TopicConfig
3029
from rai.utils.model_initialization import get_llm_model
3130

3231
from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage

src/rai_core/rai/communication/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from .ari_connector import ARIConnector, ARIMessage
1616
from .base_connector import BaseConnector, BaseMessage
1717
from .hri_connector import HRIConnector, HRIMessage, HRIPayload
18-
from .ros2.api import TopicConfig
19-
from .ros2.connectors import (
18+
from .ros2 import (
2019
ROS2ARIConnector,
2120
ROS2ARIMessage,
2221
ROS2HRIConnector,
2322
ROS2HRIMessage,
23+
TopicConfig,
2424
)
2525
from .sound_device.connector import (
2626
SoundDeviceConfig,

src/rai_core/rai/communication/base_connector.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,26 @@ def receive_message(self, source: str, timeout_sec: float, **kwargs: Any) -> T:
5353
@abstractmethod
5454
def service_call(
5555
self, message: T, target: str, timeout_sec: float, **kwargs: Any
56-
) -> T:
56+
) -> BaseMessage:
57+
pass
58+
59+
@abstractmethod
60+
def create_service(
61+
self,
62+
service_name: str,
63+
on_request: Callable,
64+
on_done: Optional[Callable] = None,
65+
**kwargs: Any,
66+
) -> str:
67+
pass
68+
69+
@abstractmethod
70+
def create_action(
71+
self,
72+
action_name: str,
73+
generate_feedback_callback: Callable,
74+
**kwargs: Any,
75+
) -> str:
5776
pass
5877

5978
@abstractmethod

src/rai_core/rai/communication/hri_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def __init__(
162162
raise HRIException(
163163
f"Error while instantiating {str(self.__class__)}: Message type T derived from HRIMessage needs to be provided e.g. Connector[MessageType]()"
164164
)
165-
self.T_class = get_args(self.__orig_bases__[0])[0]
165+
self.T_class = get_args(self.__orig_bases__[-1])[0]
166166

167167
def _build_message(
168168
self,

src/rai_core/rai/communication/ros2/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,26 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .api import ConfigurableROS2TopicAPI, ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
15+
from .api import (
16+
ConfigurableROS2TopicAPI,
17+
IROS2Message,
18+
ROS2ActionAPI,
19+
ROS2ServiceAPI,
20+
ROS2TopicAPI,
21+
TopicConfig,
22+
)
1623
from .connectors import ROS2ARIConnector, ROS2HRIConnector
1724
from .messages import ROS2ARIMessage, ROS2HRIMessage
1825

1926
__all__ = [
2027
"ConfigurableROS2TopicAPI",
28+
"IROS2Message",
2129
"ROS2ARIConnector",
2230
"ROS2ARIMessage",
2331
"ROS2ActionAPI",
2432
"ROS2HRIConnector",
2533
"ROS2HRIMessage",
2634
"ROS2ServiceAPI",
2735
"ROS2TopicAPI",
36+
"TopicConfig",
2837
]

src/rai_core/rai/communication/ros2/api.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Type,
3535
TypedDict,
3636
cast,
37+
runtime_checkable,
3738
)
3839

3940
import rclpy
@@ -47,8 +48,15 @@
4748
import rosidl_runtime_py.set_message
4849
import rosidl_runtime_py.utilities
4950
from action_msgs.srv import CancelGoal
50-
from rclpy.action import ActionClient
51+
from rclpy.action import ActionClient, CancelResponse, GoalResponse
5152
from rclpy.action.client import ClientGoalHandle
53+
from rclpy.action.server import (
54+
ActionServer,
55+
ServerGoalHandle,
56+
default_cancel_callback,
57+
default_goal_callback,
58+
default_handle_accepted_callback,
59+
)
5260
from rclpy.callback_groups import ReentrantCallbackGroup
5361
from rclpy.publisher import Publisher
5462
from rclpy.qos import (
@@ -57,15 +65,19 @@
5765
LivelinessPolicy,
5866
QoSProfile,
5967
ReliabilityPolicy,
68+
qos_profile_action_status_default,
69+
qos_profile_services_default,
6070
)
71+
from rclpy.service import Service
6172
from rclpy.task import Future
6273
from rclpy.topic_endpoint_info import TopicEndpointInfo
6374

6475
from rai.tools.ros.utils import import_message_from_str
6576

6677

78+
@runtime_checkable
6779
class IROS2Message(Protocol):
68-
__slots__: tuple
80+
__slots__: list
6981

7082
def get_fields_and_field_types(self) -> dict: ...
7183

@@ -590,6 +602,7 @@ class ROS2ServiceAPI:
590602
def __init__(self, node: rclpy.node.Node) -> None:
591603
self.node = node
592604
self._logger = node.get_logger()
605+
self._services: Dict[str, Service] = {}
593606

594607
def call_service(
595608
self,
@@ -622,6 +635,19 @@ def call_service(
622635
def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]:
623636
return self.node.get_service_names_and_types()
624637

638+
def create_service(
639+
self,
640+
service_name: str,
641+
service_type: str,
642+
callback: Callable[[Any, Any], Any],
643+
**kwargs,
644+
) -> str:
645+
srv_cls = import_message_from_str(service_type)
646+
service = self.node.create_service(srv_cls, service_name, callback, **kwargs)
647+
handle = str(uuid.uuid4())
648+
self._services[handle] = service
649+
return handle
650+
625651

626652
class ROS2ActionData(TypedDict):
627653
action_client: Optional[ActionClient]
@@ -636,6 +662,7 @@ def __init__(self, node: rclpy.node.Node) -> None:
636662
self.node = node
637663
self._logger = node.get_logger()
638664
self.actions: Dict[str, ROS2ActionData] = {}
665+
self._action_servers: Dict[str, ActionServer] = {}
639666
self._callback_executor = ThreadPoolExecutor(max_workers=10)
640667

641668
def _generate_handle(self):
@@ -672,6 +699,87 @@ def _safe_callback_wrapper(
672699
except Exception as e:
673700
self._logger.error(f"Error in feedback callback: {str(e)}")
674701

702+
def create_action_server(
703+
self,
704+
action_type: str,
705+
action_name: str,
706+
execute_callback: Callable[[ServerGoalHandle], Type[IROS2Message]],
707+
*,
708+
callback_group: Optional[rclpy.node.CallbackGroup] = None,
709+
goal_callback: Callable[[IROS2Message], GoalResponse] = default_goal_callback,
710+
handle_accepted_callback: Callable[
711+
[ServerGoalHandle], None
712+
] = default_handle_accepted_callback,
713+
cancel_callback: Callable[
714+
[IROS2Message], CancelResponse
715+
] = default_cancel_callback,
716+
goal_service_qos_profile: QoSProfile = qos_profile_services_default,
717+
result_service_qos_profile: QoSProfile = qos_profile_services_default,
718+
cancel_service_qos_profile: QoSProfile = qos_profile_services_default,
719+
feedback_pub_qos_profile: QoSProfile = QoSProfile(depth=10),
720+
status_pub_qos_profile: QoSProfile = qos_profile_action_status_default,
721+
result_timeout: int = 900,
722+
) -> str:
723+
"""
724+
Create an action server.
725+
726+
Args:
727+
action_type: The action message type with namespace
728+
action_name: The name of the action server
729+
execute_callback: The callback to execute when a goal is received
730+
callback_grou: The callback group to use for the action server
731+
goal_callback: The callback to execute when a goal is received
732+
handle_accepted_callback: The callback to execute when a goal handle is accepted
733+
cancel_callback: The callback to execute when a goal is canceled
734+
goal_service_qos_profile: The QoS profile for the goal service
735+
result_service_qos_profile: The QoS profile for the result service
736+
cancel_service_qos_profile: The QoS profile for the cancel service
737+
feedback_pub_qos_profile: The QoS profile for the feedback publisher
738+
status_pub_qos_profile: The QoS profile for the status publisher
739+
result_timeout: The timeout for waiting for a result
740+
741+
Returns:
742+
The handle for the created action server
743+
744+
Raises:
745+
ValueError: If the action server cannot be created
746+
"""
747+
handle = self._generate_handle()
748+
action_ros_type = import_message_from_str(action_type)
749+
try:
750+
action_server = ActionServer(
751+
node=self.node,
752+
action_type=action_ros_type,
753+
action_name=action_name,
754+
execute_callback=execute_callback,
755+
callback_group=callback_group,
756+
goal_callback=goal_callback,
757+
handle_accepted_callback=handle_accepted_callback,
758+
cancel_callback=cancel_callback,
759+
goal_service_qos_profile=goal_service_qos_profile,
760+
result_service_qos_profile=result_service_qos_profile,
761+
cancel_service_qos_profile=cancel_service_qos_profile,
762+
feedback_pub_qos_profile=feedback_pub_qos_profile,
763+
status_pub_qos_profile=status_pub_qos_profile,
764+
result_timeout=result_timeout,
765+
)
766+
self._logger.info(f"Created action server: {action_name}")
767+
except TypeError as e:
768+
import inspect
769+
770+
signature = inspect.signature(ActionServer.__init__)
771+
args = [
772+
param.name
773+
for param in signature.parameters.values()
774+
if param.name != "self"
775+
]
776+
777+
raise ValueError(
778+
f"Failed to create action server: {str(e)}. Valid arguments are: {args}"
779+
)
780+
self._action_servers[handle] = action_server
781+
return handle
782+
675783
def send_goal(
676784
self,
677785
action_name: str,

0 commit comments

Comments
 (0)