23
23
from numpy .typing import NDArray
24
24
25
25
from rai .agents .base import BaseAgent
26
- from rai .communication import AudioInputDeviceConfig , StreamingAudioInputDevice
26
+ from rai .communication import (
27
+ AudioInputDeviceConfig ,
28
+ ROS2ARIConnector ,
29
+ ROS2ARIMessage ,
30
+ StreamingAudioInputDevice ,
31
+ )
27
32
from rai_asr .models import BaseTranscriptionModel , BaseVoiceDetectionModel
28
33
29
34
@@ -38,6 +43,7 @@ def __init__(
38
43
self ,
39
44
microphone_device_id : int , # TODO: Change to name based instead of id based identification
40
45
microphone_config : AudioInputDeviceConfig ,
46
+ ros2_name : str ,
41
47
transcription_model : BaseTranscriptionModel ,
42
48
vad : BaseVoiceDetectionModel ,
43
49
grace_period : float = 1.0 ,
@@ -51,7 +57,8 @@ def __init__(
51
57
microphone .configure_device (
52
58
target = str (microphone_device_id ), config = microphone_config
53
59
)
54
- super ().__init__ (connectors = {"microphone" : microphone })
60
+ ros2_connector = ROS2ARIConnector (ros2_name )
61
+ super ().__init__ (connectors = {"microphone" : microphone , "ros2" : ros2_connector })
55
62
self .microphone_device_id = str (microphone_device_id )
56
63
self .should_record_pipeline : List [BaseVoiceDetectionModel ] = []
57
64
self .should_stop_pipeline : List [BaseVoiceDetectionModel ] = []
@@ -89,7 +96,10 @@ def add_detection_model(
89
96
def run (self ):
90
97
self .running = True
91
98
self .listener_handle = self .connectors ["microphone" ].start_action (
92
- self .microphone_device_id , self .on_new_sample
99
+ action_data = None ,
100
+ target = self .microphone_device_id ,
101
+ on_feedback = self .on_new_sample ,
102
+ on_done = lambda : None ,
93
103
)
94
104
95
105
def stop (self ):
@@ -184,8 +194,12 @@ def transcription_thread(self, identifier: str):
184
194
del self .buffer_reminders [identifier ]
185
195
# self.transcription_model.save_wav(f"{identifier}.wav")
186
196
transcription = self .transcription_model .consume_transcription ()
197
+ print ("Transcription: " , transcription )
198
+ self .connectors ["ros2" ].send_message (
199
+ ROS2ARIMessage (
200
+ {"data" : transcription }, {"msg_type" : "std_msgs/msg/String" }
201
+ ),
202
+ "/from_human" ,
203
+ )
187
204
self .transcription_threads [identifier ]["transcription" ] = transcription
188
205
self .transcription_threads [identifier ]["event" ].set ()
189
- # TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
190
- self .logger .info (f"transcription thread { identifier } finished" )
191
- print (transcription )
0 commit comments