feat: add tts to rai core #419
Merged

Changes from all commits (32 commits)
Commits:

- 4832493 feat: add base impl of tts agent and start moving tts models into new… (rachwalk)
- 194c7eb feat: change connector api to support AudioSegment (rachwalk)
- 7313a5b feat: working TTS, with pausing (rachwalk)
- 3199a23 feat: working S2S (rachwalk)
- b7eae2a feat: add agent runner (rachwalk)
- 0ecc1e7 chore: add runner to __init__ (rachwalk)
- 7f3fb67 fix: working demo after rebase (rachwalk)
- 416f826 feat: add runners to create configurable, multi-agent deployments (rachwalk)
- 49bd338 diocs: add docstrings for affected classes (rachwalk)
- d42a0c0 chore: rename runner main method to run (rachwalk)
- e363e6d fix: tts agent support HRI msg (rachwalk)
- 92c8bc7 fix: s2s migrate to HRIMessage (rachwalk)
- da97c8e fix: end to end working runner with HRI (rachwalk)
- a02ea22 test: update tests to support AudioSegment api (rachwalk)
- 08df784 feat: working multiterminal version (rachwalk)
- 5f4d597 feat: working singleterminal setup (rachwalk)
- 8a81211 feat: remove runner (rachwalk)
- ad35944 chore: remove trash file (rachwalk)
- 65f9c05 fix: race condition on cancelling speech task (rachwalk)
- 4241d8a fix: race condition on single transcribe queue (rachwalk)
- 7aff7cf fix: send voice commands only on changes (maciejmajek)
- e16150b Revert "fix: send voice commands only on changes" (maciejmajek)
- a624d1b fix: minimise ros2 traffic (rachwalk)
- 339c05b docs: add S2S docs (rachwalk)
- e0dee3f fix: minimise ros2 traffic -- add missing if (rachwalk)
- bfab263 fix: conversational example use history (rachwalk)
- 3861410 docs: fix typos (rachwalk)
- 70268d9 chore: add comments on example (rachwalk)
- 746997d chore: remove useless comment (rachwalk)
- 0ad433c feat: add ElevenLabsTTS (maciejmajek)
- efefb9a docs: change the commant for device query (rachwalk)
- 0a37810 chore: pre-commit (rachwalk)
Updated documentation:

# Human Robot Interface via Voice

RAI provides two ROS enabled agents for Speech to Speech communication.

## Automatic Speech Recognition Agent

See `examples/s2s/asr.py` for example usage.

The agent requires configuration of the `sounddevice` and `ros2` connectors, a voice activity detection model (e.g. `SileroVAD`), and a transcription model (e.g. `LocalWhisper`), as well as, optionally, additional models that decide whether transcription should start (e.g. `OpenWakeWord`).

The Agent publishes information on two topics:

`/from_human`: `rai_interfaces/msg/HRIMessage` - containing transcriptions of the recorded speech

`/voice_commands`: `std_msgs/msg/String` - containing control commands that inform the consumer whether speech is currently detected (`{"data": "pause"}`), speech was detected and has now stopped (`{"data": "play"}`), or speech was transcribed (`{"data": "stop"}`)

The Agent utilises the sounddevice module to access the user's microphone; by default the `"default"` sound device is used.
To get information about available sound devices use:

```
python -c "import sounddevice; print(sounddevice.query_devices())"
```

The device can be identified by name and passed to the configuration.

## TextToSpeechAgent

See `examples/s2s/tts.py` for example usage.

The agent requires configuration of the `sounddevice` and `ros2` connectors, as well as a TextToSpeech model (e.g. `OpenTTS`).
The Agent listens for information on two topics:

`/to_human`: `rai_interfaces/msg/HRIMessage` - containing responses to be played to the human. These responses are synthesised and put into the playback queue.

`/voice_commands`: `std_msgs/msg/String` - containing control commands to pause the current playback (`{"data": "pause"}`), start/continue playback (`{"data": "play"}`), or stop the playback and drop the current playback queue (`{"data": "stop"}`)

The Agent utilises the sounddevice module to access the user's speaker; by default the `"default"` sound device is used.
To get a list of the names of available sound devices use:

```
python -c 'import sounddevice as sd; print([x["name"] for x in list(sd.query_devices())])'
```

The device can be identified by name and passed to the configuration.

### OpenTTS

To run OpenTTS (and the example) a Docker container with the model must be running.

To start it, run:

```
docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
```

## Running example

To run the provided example of an S2S configuration with a minimal LLM-based agent, run the following in 4 separate terminals:

```
$ docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
$ python ./examples/s2s/asr.py
$ python ./examples/s2s/tts.py
$ python ./examples/s2s/conversational.py
```
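The `/voice_commands` control protocol on the playback side can be sketched as a tiny state machine. This is a hypothetical consumer written for illustration only (the class and its names are not part of RAI): `pause` holds playback, `play` resumes it, and `stop` halts playback and drops the queued audio.

```python
from queue import Empty, Queue


class PlaybackGate:
    """Illustrative consumer of /voice_commands control messages."""

    def __init__(self) -> None:
        self.playing = True
        self.queue: Queue = Queue()  # queued audio chunks awaiting playback

    def handle(self, command: str) -> None:
        if command == "pause":
            self.playing = False  # human speech detected: hold playback
        elif command == "play":
            self.playing = True  # speech ended: resume playback
        elif command == "stop":
            self.playing = False
            while True:  # drop the current playback queue
                try:
                    self.queue.get_nowait()
                except Empty:
                    break
```

A real TTS consumer would additionally feed `self.queue` into an audio output stream; this sketch only captures the command semantics described above.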
New file (the ASR example referenced as `examples/s2s/asr.py`):

```python
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import signal
import time

import rclpy
from rai.agents import VoiceRecognitionAgent
from rai.communication.sound_device.api import SoundDeviceConfig

from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD

VAD_THRESHOLD = 0.8  # Note that this might be different depending on your device
OWW_THRESHOLD = 0.1  # Note that this might be different depending on your device

VAD_SAMPLING_RATE = 16000  # Or 8000
DEFAULT_BLOCKSIZE = 1280


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Voice Activity Detection and Wake Word Detection Configuration",
        allow_abbrev=True,
    )

    # Predefined arguments
    parser.add_argument(
        "--vad-threshold",
        type=float,
        default=VAD_THRESHOLD,
        help=f"Voice Activity Detection threshold (default: {VAD_THRESHOLD})",
    )
    parser.add_argument(
        "--oww-threshold",
        type=float,
        default=OWW_THRESHOLD,
        help=f"OpenWakeWord threshold (default: {OWW_THRESHOLD})",
    )
    parser.add_argument(
        "--vad-sampling-rate",
        type=int,
        choices=[8000, 16000],
        default=VAD_SAMPLING_RATE,
        help=f"VAD sampling rate (default: {VAD_SAMPLING_RATE})",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=DEFAULT_BLOCKSIZE,
        help=f"Audio block size (default: {DEFAULT_BLOCKSIZE})",
    )
    parser.add_argument(
        "--device-name",
        type=str,
        default="default",
        help="Microphone device name (default: 'default')",
    )

    # Use parse_known_args to ignore unknown arguments
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args


if __name__ == "__main__":
    args = parse_arguments()

    microphone_configuration = SoundDeviceConfig(
        stream=True,
        channels=1,
        device_name=args.device_name,
        block_size=args.block_size,
        consumer_sampling_rate=args.vad_sampling_rate,
        dtype="int16",
        device_number=None,
        is_input=True,
        is_output=False,
    )
    vad = SileroVAD(args.vad_sampling_rate, args.vad_threshold)
    oww = OpenWakeWord("hey jarvis", args.oww_threshold)
    whisper = LocalWhisper("tiny", args.vad_sampling_rate)
    # you can easily switch the provider by changing the whisper object
    # whisper = OpenAIWhisper("whisper-1", args.vad_sampling_rate, "en")

    rclpy.init()
    ros2_name = "rai_asr_agent"

    agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
    # optionally add additional models to decide when to record data for transcription
    # agent.add_detection_model(oww, pipeline="record")

    agent.run()

    def cleanup(signum, frame):
        agent.stop()
        rclpy.shutdown()
        exit(0)

    signal.signal(signal.SIGINT, cleanup)

    while True:
        time.sleep(1)
```
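The `parse_known_args` pattern used in the example can be shown in isolation: unknown flags (for instance ROS 2 launch arguments) are collected and returned instead of raising a parse error. A minimal standalone sketch:

```python
import argparse

# Same parser style as the example: abbreviations allowed, one typed flag
parser = argparse.ArgumentParser(allow_abbrev=True)
parser.add_argument("--vad-threshold", type=float, default=0.8)

# parse_known_args returns (namespace, leftover_args) rather than exiting
args, unknown = parser.parse_known_args(["--vad-threshold", "0.5", "--ros-args"])
print(args.vad_threshold, unknown)  # 0.5 ['--ros-args']
```

This is why the example can be launched alongside ROS 2 tooling that appends its own command-line arguments.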
New file (the conversational example referenced as `examples/s2s/conversational.py`):

```python
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import signal
import time
from queue import Queue
from threading import Event, Thread
from typing import Dict, List

import rclpy
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from rai.agents.base import BaseAgent
from rai.communication import BaseConnector
from rai.communication.ros2.api import IROS2Message
from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig
from rai.utils.model_initialization import get_llm_model

from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage

# NOTE: the Agent code included here is temporary until a dedicated speech agent
# is created; it can still serve as a reference for writing your own RAI agents


class LLMTextHandler(BaseCallbackHandler):
    def __init__(self, connector: ROS2HRIConnector):
        self.connector = connector
        self.token_buffer = ""

    def on_llm_new_token(self, token: str, **kwargs):
        self.token_buffer += token
        if len(self.token_buffer) > 100 or token in [".", "?", "!", ",", ";", ":"]:
            logging.info(f"Sending token buffer: {self.token_buffer}")
            self.connector.send_all_targets(AIMessage(content=self.token_buffer))
            self.token_buffer = ""

    def on_llm_end(
        self,
        response,
        *,
        run_id,
        parent_run_id=None,
        **kwargs,
    ):
        if self.token_buffer:
            logging.info(f"Sending token buffer: {self.token_buffer}")
            self.connector.send_all_targets(AIMessage(content=self.token_buffer))
            self.token_buffer = ""


class S2SConversationalAgent(BaseAgent):
    def __init__(self, connectors: Dict[str, BaseConnector]):  # type: ignore
        super().__init__(connectors=connectors)
        self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [
            SystemMessage(
                content="Pretend you are a robot. Answer as if you were a robot."
            )
        ]
        self.speech_queue: Queue[InterfacesHRIMessage] = Queue()

        self.llm = get_llm_model(model_type="complex_model", streaming=True)
        self._setup_ros_connector()
        self.main_thread = None
        self.stop_thread = Event()

    def run(self):
        logging.info("Running S2SConversationalAgent")
        self.main_thread = Thread(target=self._main_loop)
        self.main_thread.start()

    def _main_loop(self):
        while not self.stop_thread.is_set():
            time.sleep(0.01)
            speech = ""
            while not self.speech_queue.empty():
                speech += "".join(self.speech_queue.get().text)
            if speech != "":
                logging.info(f"Received human speech {speech}!")
                self.message_history.append(HumanMessage(content=speech))
                assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
                ai_answer = self.llm.invoke(
                    self.message_history,
                    config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
                )
                self.message_history.append(ai_answer)  # type: ignore

    def _on_from_human(self, msg: IROS2Message):
        assert isinstance(msg, InterfacesHRIMessage)
        logging.info("Received message from human: %s", msg.text)
        self.speech_queue.put(msg)

    def _setup_ros_connector(self):
        self.connectors["ros2"] = ROS2HRIConnector(
            sources=[
                (
                    "/from_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        is_subscriber=True,
                        source_author="human",
                        subscriber_callback=self._on_from_human,
                    ),
                )
            ],
            targets=[
                (
                    "/to_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        source_author="ai",
                        is_subscriber=False,
                    ),
                )
            ],
        )

    def stop(self):
        assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
        self.connectors["ros2"].shutdown()
        self.stop_thread.set()
        if self.main_thread is not None:
            self.main_thread.join()


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Text To Speech Configuration",
        allow_abbrev=True,
    )

    # Use parse_known_args to ignore unknown arguments
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args


if __name__ == "__main__":
    args = parse_arguments()
    rclpy.init()
    agent = S2SConversationalAgent(connectors={})
    agent.run()

    def cleanup(signum, frame):
        print("\nCaught SIGINT (Ctrl+C); performing cleanup.")
        agent.stop()
        rclpy.shutdown()
        exit(0)

    signal.signal(signal.SIGINT, cleanup)

    while True:
        time.sleep(1)
```
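The sentence-boundary buffering in `LLMTextHandler` can be isolated as a pure function for illustration. This is a hypothetical helper (not part of the PR) that mirrors the flush rules above: emit a chunk whenever the buffer grows past a length limit or a token is a punctuation delimiter, and flush any remainder at end of stream.

```python
def flush_points(tokens, max_len=100, delimiters=".?!,;:"):
    """Accumulate streamed tokens into chunks, mirroring LLMTextHandler's
    flush conditions (buffer length or punctuation token)."""
    buffer, chunks = "", []
    for tok in tokens:
        buffer += tok
        if len(buffer) > max_len or tok in delimiters:
            chunks.append(buffer)
            buffer = ""
    if buffer:  # corresponds to the on_llm_end flush of a non-empty buffer
        chunks.append(buffer)
    return chunks
```

Chunking on punctuation keeps latency low for the TTS consumer: each clause can be synthesised while the LLM is still streaming the rest of the answer.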
Review discussion:

> Which part of the name?

> What do you mean? All of it

> I think only the value to the `r'\sALSA'` should be passed

> Please refer to the `sd.query_devices()` documentation. It is generally expected that a reasonable user will not run code without understanding its effects on their machine. When referring to "name" in the documentation, I refer to whatever is returned under the "name" field of the objects returned by the function called. It would be highly counterintuitive to refer to anything else, especially via a regex; regexes are famously counterintuitive, and yours includes a trailing comma.

> Just provide an example.

> Please refer to line 17 of this file, which provides an example.