Skip to content

Commit 7dc4b43

Browse files
committed
feat: integrate with ros2 connector
1 parent 1e948a2 commit 7dc4b43

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

src/rai/rai/agents/voice_agent.py

+20-6
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
from numpy.typing import NDArray
2424

2525
from rai.agents.base import BaseAgent
26-
from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
26+
from rai.communication import (
27+
AudioInputDeviceConfig,
28+
ROS2ARIConnector,
29+
ROS2ARIMessage,
30+
StreamingAudioInputDevice,
31+
)
2732
from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel
2833

2934

@@ -38,6 +43,7 @@ def __init__(
3843
self,
3944
microphone_device_id: int, # TODO: Change to name based instead of id based identification
4045
microphone_config: AudioInputDeviceConfig,
46+
ros2_name: str,
4147
transcription_model: BaseTranscriptionModel,
4248
vad: BaseVoiceDetectionModel,
4349
grace_period: float = 1.0,
@@ -51,7 +57,8 @@ def __init__(
5157
microphone.configure_device(
5258
target=str(microphone_device_id), config=microphone_config
5359
)
54-
super().__init__(connectors={"microphone": microphone})
60+
ros2_connector = ROS2ARIConnector(ros2_name)
61+
super().__init__(connectors={"microphone": microphone, "ros2": ros2_connector})
5562
self.microphone_device_id = str(microphone_device_id)
5663
self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
5764
self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
@@ -89,7 +96,10 @@ def add_detection_model(
8996
def run(self):
9097
self.running = True
9198
self.listener_handle = self.connectors["microphone"].start_action(
92-
self.microphone_device_id, self.on_new_sample
99+
action_data=None,
100+
target=self.microphone_device_id,
101+
on_feedback=self.on_new_sample,
102+
on_done=lambda: None,
93103
)
94104

95105
def stop(self):
@@ -184,8 +194,12 @@ def transcription_thread(self, identifier: str):
184194
del self.buffer_reminders[identifier]
185195
# self.transcription_model.save_wav(f"{identifier}.wav")
186196
transcription = self.transcription_model.consume_transcription()
197+
print("Transcription: ", transcription)
198+
self.connectors["ros2"].send_message(
199+
ROS2ARIMessage(
200+
{"data": transcription}, {"msg_type": "std_msgs/msg/String"}
201+
),
202+
"/from_human",
203+
)
187204
self.transcription_threads[identifier]["transcription"] = transcription
188205
self.transcription_threads[identifier]["event"].set()
189-
# TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
190-
self.logger.info(f"transcription thread {identifier} finished")
191-
print(transcription)

src/rai/rai/communication/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .ari_connector import ARIConnector, ARIMessage
1616
from .base_connector import BaseConnector, BaseMessage
1717
from .hri_connector import HRIConnector, HRIMessage, HRIPayload
18+
from .ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
1819
from .sound_device_connector import (
1920
AudioInputDeviceConfig,
2021
SoundDeviceError,
@@ -29,6 +30,8 @@
2930
"HRIConnector",
3031
"HRIMessage",
3132
"HRIPayload",
33+
"ROS2ARIConnector",
34+
"ROS2ARIMessage",
3235
"StreamingAudioInputDevice",
3336
"SoundDeviceError",
3437
"AudioInputDeviceConfig",

src/rai/rai/communication/ros2/connectors.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import threading
1616
import uuid
17-
from typing import Any, Callable, Dict, Optional
17+
from typing import Any, Callable, Dict, Optional, TypedDict
1818

1919
from rclpy.executors import MultiThreadedExecutor
2020
from rclpy.node import Node
@@ -23,8 +23,14 @@
2323
from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI
2424

2525

26+
class ROS2ARIPayload(TypedDict):
27+
data: Any
28+
29+
2630
class ROS2ARIMessage(ARIMessage):
27-
def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
31+
def __init__(
32+
self, payload: ROS2ARIPayload, metadata: Optional[Dict[str, Any]] = None
33+
):
2834
super().__init__(payload, metadata)
2935

3036

0 commit comments

Comments
 (0)