Commit 1695191

feat: add tts to rai core (#419)
Co-authored-by: Maciej Majek <[email protected]>
1 parent 7890b3f commit 1695191

File tree

29 files changed: +1810 −213 lines

Lines changed: 47 additions & 18 deletions
````diff
@@ -1,36 +1,65 @@
 # Human Robot Interface via Voice
 
-> [!IMPORTANT]
-> RAI_ASR supports both local Whisper models and OpenAI Whisper (cloud). When using the cloud version, the OPENAI_API_KEY environment variable must be set with a valid API key.
+RAI provides two ROS enabled agents for Speech to Speech communication.
 
-## Running example
+## Automatic Speech Recognition Agent
+
+See `examples/s2s/asr.py` for an example usage.
+
+The agent requires configuration of the `sounddevice` and `ros2` connectors, a voice activity detection model (e.g. `SileroVAD`), and a transcription model (e.g. `LocalWhisper`), plus, optionally, additional models that decide whether transcription should start (e.g. `OpenWakeWord`).
 
-When your robot's whoami package is ready, run the following:
+The agent publishes information on two topics:
 
-> [!TIP]
-> Make sure rai_whoami is running.
+`/from_human`: `rai_interfaces/msg/HRIMessage` - transcriptions of the recorded speech.
 
-** Parameters **
-recording_device: The device you want to record with. Check available with:
+`/voice_commands`: `std_msgs/msg/String` - control commands that tell the consumer whether speech is currently detected (`{"data": "pause"}`), was detected and has now stopped (`{"data": "play"}`), or has been transcribed (`{"data": "stop"}`).
 
-```bash
-python -c 'import sounddevice as sd; print(sd.query_devices())'
+The agent uses the `sounddevice` module to access the user's microphone; the `"default"` sound device is used unless configured otherwise.
+To get information about the available sound devices use:
+
+```
+python -c "import sounddevice; print(sounddevice.query_devices())"
 ```
 
-keep_speaker_busy: some speakers may go into low power mode, which may result in truncated speech beginnings. Set to true to play low frequency, low volume noise to prevent sleep mode.
+The device can be identified by name and passed to the configuration.
+
+## TextToSpeechAgent
+
+See `examples/s2s/tts.py` for an example usage.
+
+The agent requires configuration of the `sounddevice` and `ros2` connectors as well as a TextToSpeech model (e.g. `OpenTTS`).
+The agent listens on two topics:
+
+`/to_human`: `rai_interfaces/msg/HRIMessage` - responses to be played to the human. These responses are synthesised and put into the playback queue.
+
+`/voice_commands`: `std_msgs/msg/String` - control commands to pause the current playback (`{"data": "pause"}`), start or continue playback (`{"data": "play"}`), or stop the playback and drop the current playback queue (`{"data": "stop"}`).
+
+The agent uses the `sounddevice` module to access the user's speaker; the `"default"` sound device is used unless configured otherwise.
+To get a list of the names of the available sound devices use:
+
+```
+python -c 'import sounddevice as sd; print([x["name"] for x in list(sd.query_devices())])'
+```
+
+The device can be identified by name and passed to the configuration.
 
 ### OpenTTS
 
-```bash
-ros2 launch rai_bringup hri.launch.py tts_vendor:=opentts robot_description_package:=<robot_description_package> recording_device:=0 keep_speaker_busy:=(true|false) asr_vendor:=(whisper|openai)
+To run OpenTTS (and the example), a docker server containing the model must be running.
 
+To start it run:
+
+```
+docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
 ```
 
-> [!NOTE]
-> Run OpenTTS with `docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak`
+## Running example
 
-### ElevenLabs
+To run the provided example of an S2S configuration with a minimal LLM-based agent, run the following in 4 separate terminals:
 
-```bash
-ros2 launch rai_bringup hri.launch.py robot_description_package:=<robot_description_package> recording_device:=0 keep_speaker_busy:=(true|false) asr_vendor:=(whisper|openai)
+```
+$ docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
+$ python ./examples/s2s/asr.py
+$ python ./examples/s2s/tts.py
+$ python ./examples/s2s/conversational.py
+```
````
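The README states that the sound device can be identified by name and passed to the configuration. A minimal, dependency-free sketch of that lookup, operating on the list-of-dicts shape that `sounddevice.query_devices()` returns (the helper name and the sample device list are illustrative, not part of RAI):

```python
from typing import Dict, List, Optional

# Stand-in for sounddevice.query_devices(); on a real system you would call
# the library instead of using this hard-coded list.
SAMPLE_DEVICES: List[Dict] = [
    {"name": "HDA Intel PCH: ALC298 Analog (hw:0,0)", "max_input_channels": 2},
    {"name": "default", "max_input_channels": 32},
]


def find_device_by_name(devices: List[Dict], name: str) -> Optional[Dict]:
    """Return the first device whose name contains `name`, else None."""
    for device in devices:
        if name in device["name"]:
            return device
    return None


if __name__ == "__main__":
    chosen = find_device_by_name(SAMPLE_DEVICES, "default")
    print(chosen["name"] if chosen else "no match")
```

In practice the matched name (or simply `"default"`) would then be supplied as the `device_name` field of the `SoundDeviceConfig` shown in `examples/s2s/asr.py` below.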

examples/s2s/asr.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
```python
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import signal
import time

import rclpy
from rai.agents import VoiceRecognitionAgent
from rai.communication.sound_device.api import SoundDeviceConfig

from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD

VAD_THRESHOLD = 0.8  # Note that this might be different depending on your device
OWW_THRESHOLD = 0.1  # Note that this might be different depending on your device

VAD_SAMPLING_RATE = 16000  # Or 8000
DEFAULT_BLOCKSIZE = 1280


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Voice Activity Detection and Wake Word Detection Configuration",
        allow_abbrev=True,
    )

    # Predefined arguments
    parser.add_argument(
        "--vad-threshold",
        type=float,
        default=VAD_THRESHOLD,
        help=f"Voice Activity Detection threshold (default: {VAD_THRESHOLD})",
    )
    parser.add_argument(
        "--oww-threshold",
        type=float,
        default=OWW_THRESHOLD,
        help=f"OpenWakeWord threshold (default: {OWW_THRESHOLD})",
    )
    parser.add_argument(
        "--vad-sampling-rate",
        type=int,
        choices=[8000, 16000],
        default=VAD_SAMPLING_RATE,
        help=f"VAD sampling rate (default: {VAD_SAMPLING_RATE})",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=DEFAULT_BLOCKSIZE,
        help=f"Audio block size (default: {DEFAULT_BLOCKSIZE})",
    )
    parser.add_argument(
        "--device-name",
        type=str,
        default="default",
        help="Microphone device name (default: 'default')",
    )

    # Use parse_known_args to ignore unknown arguments
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args


if __name__ == "__main__":
    args = parse_arguments()

    microphone_configuration = SoundDeviceConfig(
        stream=True,
        channels=1,
        device_name=args.device_name,
        block_size=args.block_size,
        consumer_sampling_rate=args.vad_sampling_rate,
        dtype="int16",
        device_number=None,
        is_input=True,
        is_output=False,
    )
    vad = SileroVAD(args.vad_sampling_rate, args.vad_threshold)
    oww = OpenWakeWord("hey jarvis", args.oww_threshold)
    whisper = LocalWhisper("tiny", args.vad_sampling_rate)
    # you can easily switch the provider by changing the whisper object
    # (import OpenAIWhisper as well if you use it)
    # whisper = OpenAIWhisper("whisper-1", args.vad_sampling_rate, "en")

    rclpy.init()
    ros2_name = "rai_asr_agent"

    agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
    # optionally add additional models to decide when to record data for transcription
    # agent.add_detection_model(oww, pipeline="record")

    agent.run()

    def cleanup(signum, frame):
        agent.stop()
        rclpy.shutdown()
        exit(0)

    signal.signal(signal.SIGINT, cleanup)

    while True:
        time.sleep(1)
```
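The `/voice_commands` strings exchanged by these agents follow the pause/play/stop contract described in the README above. A minimal, ROS-free sketch of how a playback consumer (such as the TTS side) might interpret them; `PlaybackGate` is a hypothetical illustration, not part of RAI:

```python
class PlaybackGate:
    """Tracks whether audio playback should run, given /voice_commands data.

    Models the TTS-side contract from the README: "pause" halts playback,
    "play" starts or resumes it, and "stop" halts playback and drops the
    queued chunks that were not yet played.
    """

    def __init__(self):
        self.playing = True
        self.queue = []  # pending playback chunks

    def handle(self, data: str) -> None:
        if data == "pause":
            self.playing = False
        elif data == "play":
            self.playing = True
        elif data == "stop":
            self.playing = False
            self.queue.clear()  # drop anything not yet played
```

A real consumer would call `handle()` from its `std_msgs/msg/String` subscriber callback with `msg.data`.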

examples/s2s/conversational.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
```python
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import signal
import time
from queue import Queue
from threading import Event, Thread
from typing import Dict, List

import rclpy
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from rai.agents.base import BaseAgent
from rai.communication import BaseConnector
from rai.communication.ros2.api import IROS2Message
from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig
from rai.utils.model_initialization import get_llm_model

from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage

# NOTE: the Agent code included here is temporary until a dedicated speech agent
# is created; it can still serve as a reference for writing your own RAI agents


class LLMTextHandler(BaseCallbackHandler):
    def __init__(self, connector: ROS2HRIConnector):
        self.connector = connector
        self.token_buffer = ""

    def on_llm_new_token(self, token: str, **kwargs):
        self.token_buffer += token
        if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]:
            logging.info(f"Sending token buffer: {self.token_buffer}")
            self.connector.send_all_targets(AIMessage(content=self.token_buffer))
            self.token_buffer = ""

    def on_llm_end(
        self,
        response,
        *,
        run_id,
        parent_run_id=None,
        **kwargs,
    ):
        if self.token_buffer:
            logging.info(f"Sending token buffer: {self.token_buffer}")
            self.connector.send_all_targets(AIMessage(content=self.token_buffer))
            self.token_buffer = ""


class S2SConversationalAgent(BaseAgent):
    def __init__(self, connectors: Dict[str, BaseConnector]):  # type: ignore
        super().__init__(connectors=connectors)
        self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [
            SystemMessage(
                content="Pretend you are a robot. Answer as if you were a robot."
            )
        ]
        self.speech_queue: Queue[InterfacesHRIMessage] = Queue()

        self.llm = get_llm_model(model_type="complex_model", streaming=True)
        self._setup_ros_connector()
        self.main_thread = None
        self.stop_thread = Event()

    def run(self):
        logging.info("Running S2SConversationalAgent")
        self.main_thread = Thread(target=self._main_loop)
        self.main_thread.start()

    def _main_loop(self):
        while not self.stop_thread.is_set():
            time.sleep(0.01)
            speech = ""
            while not self.speech_queue.empty():
                speech += "".join(self.speech_queue.get().text)
                logging.info(f"Received human speech {speech}!")
            if speech != "":
                self.message_history.append(HumanMessage(content=speech))
                assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
                ai_answer = self.llm.invoke(
                    self.message_history,
                    config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
                )
                self.message_history.append(ai_answer)  # type: ignore

    def _on_from_human(self, msg: IROS2Message):
        assert isinstance(msg, InterfacesHRIMessage)
        logging.info("Received message from human: %s", msg.text)
        self.speech_queue.put(msg)

    def _setup_ros_connector(self):
        self.connectors["ros2"] = ROS2HRIConnector(
            sources=[
                (
                    "/from_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        is_subscriber=True,
                        source_author="human",
                        subscriber_callback=self._on_from_human,
                    ),
                )
            ],
            targets=[
                (
                    "/to_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        source_author="ai",
                        is_subscriber=False,
                    ),
                )
            ],
        )

    def stop(self):
        assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
        self.connectors["ros2"].shutdown()
        self.stop_thread.set()
        if self.main_thread is not None:
            self.main_thread.join()


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Text To Speech Configuration",
        allow_abbrev=True,
    )

    # Use parse_known_args to ignore unknown arguments
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args


if __name__ == "__main__":
    args = parse_arguments()
    rclpy.init()
    agent = S2SConversationalAgent(connectors={})
    agent.run()

    def cleanup(signum, frame):
        print("\nCustom handler: Caught SIGINT (Ctrl+C).")
        print("Performing cleanup")
        agent.stop()
        rclpy.shutdown()
        exit(0)

    signal.signal(signal.SIGINT, cleanup)

    while True:
        time.sleep(1)
```
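The buffering rule in `LLMTextHandler` above (flush once the buffer exceeds 100 characters or the incoming token is a punctuation mark, so the TTS side receives sentence-sized chunks instead of single tokens) can be exercised in isolation. The following standalone sketch re-creates that rule without ROS or LangChain; the function name is illustrative, not part of RAI:

```python
# Punctuation tokens that trigger a flush, mirroring LLMTextHandler.
PUNCTUATION = [".", "?", "!", ",", ";", ":"]


def chunk_tokens(tokens):
    """Group a token stream into chunks at punctuation/size boundaries."""
    chunks, buffer = [], ""
    for token in tokens:
        buffer += token
        if len(buffer) > 100 or token in PUNCTUATION:
            chunks.append(buffer)
            buffer = ""
    if buffer:  # mirror on_llm_end: flush whatever remains
        chunks.append(buffer)
    return chunks


if __name__ == "__main__":
    # → ['Hello,', ' how are you?']
    print(chunk_tokens(["Hello", ",", " how", " are", " you", "?"]))
```

Flushing on punctuation rather than on every token keeps the number of `/to_human` messages small while still giving the TTS agent natural prosodic boundaries to synthesise.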
