Skip to content

Commit 2a2c613

Browse files
authored
feat: add conversation id to HRI message (#480)
1 parent a1f03ee commit 2a2c613

File tree

10 files changed

+156
-35
lines changed

10 files changed

+156
-35
lines changed

examples/s2s/conversational.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,18 @@
3535

3636

3737
class LLMTextHandler(BaseCallbackHandler):
38-
def __init__(self, connector: ROS2HRIConnector):
38+
def __init__(self, connector: ROS2HRIConnector, speech_id: str = ""):
3939
self.connector = connector
4040
self.token_buffer = ""
41+
self.speech_id = speech_id
4142

4243
def on_llm_new_token(self, token: str, **kwargs):
4344
self.token_buffer += token
4445
if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]:
4546
logging.info(f"Sending token buffer: {self.token_buffer}")
46-
self.connector.send_all_targets(AIMessage(content=self.token_buffer))
47+
self.connector.send_all_targets(
48+
AIMessage(content=self.token_buffer), self.speech_id
49+
)
4750
self.token_buffer = ""
4851

4952
def on_llm_end(
@@ -74,6 +77,7 @@ def __init__(self, connectors: Dict[str, BaseConnector]): # type: ignore
7477
self._setup_ros_connector()
7578
self.main_thread = None
7679
self.stop_thread = Event()
80+
self.current_speech_id = ""
7781

7882
def run(self):
7983
logging.info("Running S2SConversationalAgent")
@@ -85,14 +89,24 @@ def _main_loop(self):
8589
time.sleep(0.01)
8690
speech = ""
8791
while not self.speech_queue.empty():
88-
speech += "".join(self.speech_queue.get().text)
92+
speech_message = self.speech_queue.get()
93+
speech += "".join(speech_message.text)
8994
logging.info(f"Received human speech {speech}!")
95+
self.current_speech_id = speech_message.conversation_id
9096
if speech != "":
91-
self.message_history.append(HumanMessage(content=speech))
97+
self.message_history.append(
98+
HumanMessage(content=speech, conversation_id=self.current_speech_id)
99+
)
92100
assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
93101
ai_answer = self.llm.invoke(
94102
self.message_history,
95-
config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
103+
config={
104+
"callbacks": [
105+
LLMTextHandler(
106+
self.connectors["ros2"], self.current_speech_id
107+
)
108+
]
109+
},
96110
)
97111
self.message_history.append(ai_answer) # type: ignore
98112

src/rai_asr/rai_asr/agents/asr_agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,5 +274,9 @@ def _send_ros2_message(self, data: str, topic: str):
274274
except Exception as e:
275275
self.logger.error(f"Error sending message to {topic}: {e}")
276276
else:
277-
msg = ROS2HRIMessage(HRIPayload(text=data), "human")
277+
msg = ROS2HRIMessage(
278+
HRIPayload(text=data),
279+
"human",
280+
ROS2HRIMessage.generate_conversation_id(),
281+
)
278282
self.connectors["ros2_hri"].send_message(msg, topic)

src/rai_core/rai/agents/langchain/callback.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import threading
1717
from typing import List, Optional
18+
from uuid import UUID
1819

1920
from langchain_core.callbacks import BaseCallbackHandler
2021
from langchain_core.messages import AIMessage
@@ -39,45 +40,61 @@ def __init__(
3940
self.max_buffer_size = max_buffer_size
4041
self._buffer_lock = threading.Lock()
4142
self.logger = logger or logging.getLogger(__name__)
43+
self.current_conversation_id = None
44+
self.current_chunk_id = 0
4245

4346
def _should_split(self, token: str) -> bool:
4447
return token in self.splitting_chars
4548

46-
def _send_all_targets(self, tokens: str):
49+
def _send_all_targets(self, tokens: str, done: bool = False):
4750
self.logger.info(
4851
f"Sending {len(tokens)} tokens to {len(self.connectors)} connectors"
4952
)
5053
for connector_name, connector in self.connectors.items():
5154
try:
52-
connector.send_all_targets(AIMessage(content=tokens))
55+
connector.send_all_targets(
56+
AIMessage(content=tokens),
57+
self.current_conversation_id,
58+
self.current_chunk_id,
59+
done,
60+
)
5361
self.logger.debug(f"Sent {len(tokens)} tokens to {connector_name}")
5462
except Exception as e:
5563
self.logger.error(
5664
f"Failed to send {len(tokens)} tokens to {connector_name}: {e}"
5765
)
5866

59-
def on_llm_new_token(self, token: str, **kwargs):
67+
def on_llm_new_token(self, token: str, *, run_id: UUID, **kwargs):
6068
if token == "":
6169
return
70+
if self.current_conversation_id != str(run_id):
71+
self.current_conversation_id = str(run_id)
72+
self.current_chunk_id = 0
6273
if self.aggregate_chunks:
6374
with self._buffer_lock:
6475
self.chunks_buffer += token
6576
if len(self.chunks_buffer) < self.max_buffer_size:
6677
if self._should_split(token):
6778
self._send_all_targets(self.chunks_buffer)
6879
self.chunks_buffer = ""
80+
self.current_chunk_id += 1
6981
else:
7082
self._send_all_targets(self.chunks_buffer)
7183
self.chunks_buffer = ""
84+
self.current_chunk_id += 1
7285
else:
7386
self._send_all_targets(token)
87+
self.current_chunk_id += 1
7488

7589
def on_llm_end(
7690
self,
7791
response: LLMResult,
92+
*,
93+
run_id: UUID,
7894
**kwargs,
7995
):
96+
self.current_conversation_id = str(run_id)
8097
if self.aggregate_chunks and self.chunks_buffer:
8198
with self._buffer_lock:
82-
self._send_all_targets(self.chunks_buffer)
99+
self._send_all_targets(self.chunks_buffer, done=True)
83100
self.chunks_buffer = ""

src/rai_core/rai/communication/hri_connector.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import uuid
1617
from dataclasses import dataclass, field
1718
from io import BytesIO
1819
from typing import Any, Dict, Generic, Literal, Optional, Sequence, TypeVar, get_args
@@ -55,19 +56,25 @@ def __init__(
5556
payload: HRIPayload,
5657
metadata: Optional[Dict[str, Any]] = None,
5758
message_author: Literal["ai", "human"] = "ai",
59+
communication_id: Optional[str] = None,
60+
seq_no: int = 0,
61+
seq_end: bool = False,
5862
**kwargs,
5963
):
6064
super().__init__(payload, metadata)
6165
self.message_author = message_author
6266
self.text = payload.text
6367
self.images = payload.images
6468
self.audios = payload.audios
69+
self.communication_id = communication_id
70+
self.seq_no = seq_no
71+
self.seq_end = seq_end
6572

6673
def __bool__(self) -> bool:
6774
return bool(self.text or self.images or self.audios)
6875

6976
def __repr__(self):
70-
return f"HRIMessage(type={self.message_author}, text={self.text}, images={self.images}, audios={self.audios})"
77+
return f"HRIMessage(type={self.message_author}, text={self.text}, images={self.images}, audios={self.audios}, communication_id={self.communication_id}, seq_no={self.seq_no}, seq_end={self.seq_end})"
7178

7279
def _image_to_base64(self, image: ImageType) -> str:
7380
buffered = BytesIO()
@@ -115,6 +122,7 @@ def to_langchain(self) -> LangchainBaseMessage:
115122
def from_langchain(
116123
cls,
117124
message: LangchainBaseMessage | RAIMultimodalMessage,
125+
communication_id: Optional[str] = None,
118126
) -> "HRIMessage":
119127
if isinstance(message, RAIMultimodalMessage):
120128
text = message.text
@@ -137,8 +145,14 @@ def from_langchain(
137145
),
138146
),
139147
message_author=message.type, # type: ignore
148+
communication_id=communication_id,
140149
)
141150

151+
@classmethod
152+
def generate_communication_id(cls) -> str:
153+
"""Generate a unique communication ID."""
154+
return str(uuid.uuid1())
155+
142156

143157
T = TypeVar("T", bound=HRIMessage)
144158

@@ -167,12 +181,21 @@ def __init__(
167181
def _build_message(
168182
self,
169183
message: LangchainBaseMessage | RAIMultimodalMessage,
184+
communication_id: Optional[str] = None,
185+
seq_no: int = 0,
186+
seq_end: bool = False,
170187
) -> T:
171-
return self.T_class.from_langchain(message)
188+
return self.T_class.from_langchain(message, communication_id, seq_no, seq_end)
172189

173-
def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage):
190+
def send_all_targets(
191+
self,
192+
message: LangchainBaseMessage | RAIMultimodalMessage,
193+
communication_id: Optional[str] = None,
194+
seq_no: int = 0,
195+
seq_end: bool = False,
196+
):
174197
for target in self.configured_targets:
175-
to_send = self._build_message(message)
198+
to_send = self._build_message(message, communication_id, seq_no, seq_end)
176199
self.send_message(to_send, target)
177200

178201
def receive_all_sources(self, timeout_sec: float = 1.0) -> dict[str, T]:

src/rai_core/rai/communication/ros2/messages.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,15 @@ def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
4242

4343

4444
class ROS2HRIMessage(HRIMessage):
45-
def __init__(self, payload: HRIPayload, message_author: Literal["ai", "human"]):
46-
super().__init__(payload, {}, message_author)
45+
def __init__(
46+
self,
47+
payload: HRIPayload,
48+
message_author: Literal["ai", "human"],
49+
communication_id: Optional[str] = None,
50+
seq_no: int = 0,
51+
seq_end: bool = False,
52+
):
53+
super().__init__(payload, {}, message_author, communication_id, seq_no, seq_end)
4754

4855
@classmethod
4956
def from_ros2(
@@ -66,9 +73,13 @@ def from_ros2(
6673
)
6774
for audio_msg in cast(List[ROS2HRIMessage__Audio], msg.audios)
6875
]
76+
communication_id = msg.communication_id if msg.communication_id != "" else None
6977
return ROS2HRIMessage(
7078
payload=HRIPayload(text=msg.text, images=pil_images, audios=audio_segments),
7179
message_author=message_author,
80+
communication_id=communication_id,
81+
seq_no=msg.seq_no,
82+
seq_end=msg.seq_end,
7283
)
7384

7485
def to_ros2_dict(self) -> OrderedDict[str, Any]:
@@ -94,6 +105,9 @@ def to_ros2_dict(self) -> OrderedDict[str, Any]:
94105
text=self.payload.text,
95106
images=img_msgs,
96107
audios=audio_msgs,
108+
communication_id=self.communication_id or "",
109+
seq_no=self.seq_no,
110+
seq_end=self.seq_end,
97111
)
98112
),
99113
)

src/rai_interfaces/msg/HRIMessage.msg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ std_msgs/Header header
1818
string text
1919
sensor_msgs/Image[] images
2020
rai_interfaces/AudioMessage[] audios
21+
string communication_id
22+
int64 seq_no
23+
bool seq_end

src/rai_tts/rai_tts/agents/tts_agent.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class TextToSpeechAgent(BaseAgent):
7474
Text-to-speech model used for generating audio.
7575
logger : Optional[logging.Logger], optional
7676
Logger instance for logging messages, by default None.
77+
max_speech_history : int, optional
78+
Maximum amount of speech ids to remember, by default 64
7779
"""
7880

7981
def __init__(
@@ -82,6 +84,7 @@ def __init__(
8284
ros2_name: str,
8385
tts: TTSModel,
8486
logger: Optional[logging.Logger] = None,
87+
max_speech_history=64,
8588
):
8689
if logger is None:
8790
self.logger = logging.getLogger(__name__)
@@ -101,8 +104,10 @@ def __init__(
101104
super().__init__(connectors={"ros2": ros2_connector, "speaker": speaker})
102105

103106
self.current_transcription_id = str(uuid4())[0:8]
107+
self.current_speech_id = None
104108
self.text_queues: dict[str, Queue] = {self.current_transcription_id: Queue()}
105109
self.audio_queues: dict[str, Queue] = {self.current_transcription_id: Queue()}
110+
self.remembered_speech_ids: list[str] = []
106111

107112
self.tog_play_event = Event()
108113
self.stop_event = Event()
@@ -224,7 +229,17 @@ def _on_to_human_message(self, message: IROS2Message):
224229
self.logger.warning(
225230
f"Starting playback, current id: {self.current_transcription_id}"
226231
)
227-
self.text_queues[self.current_transcription_id].put(msg.text)
232+
if (
233+
self.current_speech_id is None
234+
and msg.conversation_id is not None
235+
and msg.conversation_id not in self.remembered_speech_ids
236+
):
237+
self.current_speech_id = msg.conversation_id
238+
self.remembered_speech_ids.append(self.current_speech_id)
239+
if len(self.remembered_speech_ids) > 64:
240+
self.remembered_speech_ids.pop(0)
241+
if self.current_speech_id == msg.conversation_id:
242+
self.text_queues[self.current_transcription_id].put(msg.text)
228243
self.playback_data.playing = True
229244

230245
def _on_command_message(self, message: IROS2Message):
@@ -237,6 +252,7 @@ def _on_command_message(self, message: IROS2Message):
237252
elif message.data == "pause":
238253
self.playback_data.playing = False
239254
elif message.data == "stop":
255+
self.current_speech_id = None
240256
self.playback_data.playing = False
241257
previous_id = self.current_transcription_id
242258
self.logger.warning(f"Stopping playback, previous id: {previous_id}")

tests/communication/ros2/helpers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,12 +180,11 @@ def goal_response_callback(self, future):
180180

181181
def get_result_callback(self, future):
182182
result = future.result().result
183-
self.get_logger().info(f"Result: {result.sequence}")
184-
rclpy.shutdown()
183+
self.get_logger().info(f"Result: {result}")
185184

186185
def feedback_callback(self, feedback_msg):
187186
feedback = feedback_msg.feedback
188-
self.get_logger().info(f"Received feedback: {feedback.partial_sequence}")
187+
self.get_logger().info(f"Received feedback: {feedback}")
189188

190189

191190
class TestServiceClient(Node):

tests/communication/ros2/test_connectors.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,9 @@ def test_ros2hri_default_message_publish(
184184
audios = [AudioSegment.silent(duration=1000)]
185185
text = "Hello, HRI!"
186186
payload = HRIPayload(images=images, audios=audios, text=text)
187-
message = ROS2HRIMessage(payload=payload, message_author="ai")
187+
message = ROS2HRIMessage(
188+
payload=payload, message_author="ai", communication_id=""
189+
)
188190
connector.send_message(message, target=topic_name)
189191
time.sleep(1) # wait for the message to be received
190192

@@ -231,13 +233,11 @@ def test_ros2ari_connector_create_service(
231233
service_client = TestServiceClient()
232234
executors, threads = multi_threaded_spinner([service_client])
233235
service_client.send_request()
234-
time.sleep(0.01)
236+
time.sleep(0.02)
235237
assert mock_callback.called
236-
except Exception as e:
237-
raise e
238-
239-
connector.shutdown()
240-
shutdown_executors_and_threads(executors, threads)
238+
finally:
239+
connector.shutdown()
240+
shutdown_executors_and_threads(executors, threads)
241241

242242

243243
def test_ros2ari_connector_action_call(ros_setup: None, request: pytest.FixtureRequest):
@@ -256,7 +256,7 @@ def test_ros2ari_connector_action_call(ros_setup: None, request: pytest.FixtureR
256256
action_client = TestActionClient()
257257
executors, threads = multi_threaded_spinner([action_client])
258258
action_client.send_goal()
259-
time.sleep(0.01)
259+
time.sleep(0.02)
260+
assert mock_callback.called
260261
finally:
261262
shutdown_executors_and_threads(executors, threads)
262-
assert mock_callback.called

0 commit comments

Comments
 (0)