
refactor: move asr and tts agents #469


Merged: 7 commits, Mar 24, 2025
Changes from all commits
4 changes: 2 additions & 2 deletions examples/s2s/asr.py
@@ -17,9 +17,9 @@
import time

import rclpy
from rai.agents import VoiceRecognitionAgent
from rai.communication.sound_device.api import SoundDeviceConfig

from rai_asr.agents import SpeechRecognitionAgent
from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD

VAD_THRESHOLD = 0.8 # Note that this might be different depending on your device
@@ -100,7 +100,7 @@ def parse_arguments():
rclpy.init()
ros2_name = "rai_asr_agent"

agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
agent = SpeechRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
# optionally add additional models to decide when to record data for transcription
# agent.add_detection_model(oww, pipeline="record")

2 changes: 1 addition & 1 deletion examples/s2s/tts.py
@@ -17,9 +17,9 @@
import time

import rclpy
from rai.agents import TextToSpeechAgent
from rai.communication.sound_device import SoundDeviceConfig

from rai_tts.agents import TextToSpeechAgent
from rai_tts.models import OpenTTS


76 changes: 62 additions & 14 deletions src/rai_asr/README.md
@@ -2,26 +2,74 @@

## Description

The RAI ASR (Automatic Speech Recognition) node utilizes a combination of voice activity detection (VAD) and a speech recognition model to transcribe spoken language into text. The node is configured to handle multiple languages and model types, providing flexibility in various ASR applications. It detects speech, records it, and then uses a model to transcribe the recorded audio into text.
This is the [RAI](https://github.com/RobotecAI/rai) automatic speech recognition package.
It contains Agent definitions for the ASR feature.

## Installation
## Models

```bash
rosdep install --from-paths src --ignore-src -r
This package contains three types of models: Voice Activity Detection (VAD), wake word detection, and transcription.

The `detect` API for VAD and wake word models has the following signature:

```python
def detect(
    self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
```

It allows chaining the models into detection pipelines. The `input_parameters` argument provides a way to pass the output dictionary from previous models.
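
As an illustration, two toy models with this `detect` signature can be chained by feeding each model's output dictionary into the next one's `input_parameters`. The classes below are stand-ins written for this sketch, not models shipped by the package:

```python
from typing import Any, Tuple

import numpy as np
from numpy.typing import NDArray


class EnergyGate:
    """Toy stand-in for a VAD model: fires when mean energy exceeds a threshold."""

    def __init__(self, threshold: float) -> None:
        self.threshold = threshold

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        energy = float(np.mean(np.abs(audio_data)))
        return energy > self.threshold, {**input_parameters, "energy": energy}


class AlwaysOn:
    """Toy stand-in for a wake word model that always reports a detection."""

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        return True, {**input_parameters, "wake_word": "hey_rai"}


def run_pipeline(models, audio_data: NDArray) -> Tuple[bool, dict[str, Any]]:
    """Chain models: each receives the accumulated output of its predecessors."""
    params: dict[str, Any] = {}
    detected = True
    for model in models:
        hit, params = model.detect(audio_data, params)
        detected = detected and hit  # all models must agree for a detection
    return detected, params


audio = np.ones(16000, dtype=np.int16)  # one second of fake audio at 16 kHz
detected, params = run_pipeline([EnergyGate(threshold=0.5), AlwaysOn()], audio)
```

Each model only appends to the shared dictionary, so later models can read earlier results (here `"energy"`) without any coupling between the classes.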

The `transcribe` API for transcription models has the following signature:

```python
def transcribe(self, data: NDArray[np.int16]) -> str:
```

## Subscribed Topics
It takes audio data encoded as 2-byte integers and returns the transcription as a string.

### SileroVAD

[SileroVAD](https://github.com/snakers4/silero-vad) is an open source VAD model. It requires no additional setup. It returns a confidence score indicating whether the provided recording contains voice.

### OpenWakeWord

[OpenWakeWord](https://github.com/dscripka/openWakeWord) is an open source package that provides multiple pre-configured models and also supports custom wake words.
Refer to the package documentation for adding custom wake words.

The model is expected to return `True` if the audio sample contains the wake word.

### OpenAIWhisper

[OpenAIWhisper](https://platform.openai.com/docs/guides/speech-to-text) is a cloud-based transcription model. Refer to the documentation for configuration capabilities.
The environment variable `OPENAI_API_KEY` needs to be set to a valid OpenAI API key in order to use this model.

### LocalWhisper

[LocalWhisper](https://github.com/openai/whisper) is the locally hosted version of OpenAI Whisper. It supports GPU acceleration and offers the same configuration capabilities as the cloud-based one.

### FasterWhisper

[FasterWhisper](https://github.com/SYSTRAN/faster-whisper) is another implementation of the whisper model. It's optimized for speed and memory footprint. It follows the same API as the other two provided implementations.

### Custom Models

Custom VAD, Wake Word, or other detection models can be implemented by inheriting from `rai_asr.base.BaseVoiceDetectionModel`. The `detect` and `reset` methods must be implemented.

Custom transcription models can be implemented by inheriting from `rai_asr.base.BaseTranscriptionModel`. The `transcribe` method must be implemented.
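
A minimal sketch of what the two interfaces above might look like in practice. To keep it self-contained the sketch uses plain classes rather than the actual `rai_asr` base classes, and the detection logic (a simple RMS threshold) is only a placeholder:

```python
from typing import Any, Tuple

import numpy as np
from numpy.typing import NDArray


class ThresholdVAD:
    """Illustrates the detection interface: subclasses of
    rai_asr.base.BaseVoiceDetectionModel implement detect() and reset()."""

    def __init__(self, threshold: float = 500.0) -> None:
        self.threshold = threshold
        self.frames_seen = 0

    def detect(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        self.frames_seen += 1
        rms = float(np.sqrt(np.mean(audio_data.astype(np.float64) ** 2)))
        return rms > self.threshold, {**input_parameters, "rms": rms}

    def reset(self) -> None:
        # clear per-utterance state between recordings
        self.frames_seen = 0


class FixedTranscriber:
    """Illustrates the transcription interface: transcribe() takes int16
    audio and returns a string (a real model would run ASR here)."""

    def transcribe(self, data: NDArray[np.int16]) -> str:
        return f"<{data.shape[0]} samples>"
```

In the real package the classes would inherit from the bases named above; the method signatures are the part that matters.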

## Agents

### Speech Recognition Agent

This node does not subscribe to any topics. It operates independently, capturing audio directly from the microphone.
The Speech Recognition Agent uses ROS 2 and sounddevice `Connectors` to communicate with other agents and to access the microphone.

## Published Topics
It fulfills the following ROS 2 communication API:

- **`rai_asr/transcription`** (`std_msgs/String`): Publishes the transcribed text obtained from the audio recording.
- **`rai_asr/status`** (`std_msgs/String`): Publishes node status (recording, transcribing). During transcription, the node does not listen/record.
Publishes to topic `/to_human: [HRIMessage]`:
`message.text` is set to the transcription result produced by the selected transcription model.

## Parameters
Publishes to topic `/voice_commands: [std_msgs/msg/String]`:

- **`language`** (`string`, default: `"en"`): The language code for the ASR model. This parameter defines the language in which the audio will be transcribed.
- **`model`** (`string`, default: `"base"`): The type of ASR model to use. Different models may have different performance characteristics. For list of models see `python -c "import whisper;print(whisper.available_models())"`
- **`silence_grace_period`** (`double`, default: `1.0`): The grace period in seconds after silence is detected to stop recording. This helps in determining the end of a speech segment.
- **`sample_rate`** (`integer`, default: `0`): The sample rate for audio capture. If set to 0, the sample rate will be auto-detected.
- `"pause"` - when voice is detected but the `detection_pipeline` didn't return a detection (for interruptive S2S)
- `"play"` - when voice is not detected, but there was previously a transcription sent
- `"stop"` - when voice is detected and the `detection_pipeline` returned a detection (or is empty)
19 changes: 19 additions & 0 deletions src/rai_asr/rai_asr/agents/__init__.py
@@ -0,0 +1,19 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rai_asr.agents.asr_agent import SpeechRecognitionAgent

__all__ = [
"SpeechRecognitionAgent",
]
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Robotec.AI
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,6 @@

import numpy as np
from numpy.typing import NDArray

from rai.agents.base import BaseAgent
from rai.communication import (
HRIPayload,
@@ -33,6 +32,7 @@
SoundDeviceConnector,
SoundDeviceMessage,
)

from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel


@@ -43,7 +43,7 @@ class ThreadData(TypedDict):
joined: bool


class VoiceRecognitionAgent(BaseAgent):
class SpeechRecognitionAgent(BaseAgent):
"""
Agent responsible for voice recognition, transcription, and processing voice activity.

75 changes: 0 additions & 75 deletions src/rai_asr/rai_asr/asr_clients.py

This file was deleted.

4 changes: 0 additions & 4 deletions src/rai_core/rai/agents/__init__.py
@@ -16,14 +16,10 @@
from rai.agents.react_agent import ReActAgent
from rai.agents.state_based import create_state_based_agent
from rai.agents.tool_runner import ToolRunner
from rai.agents.tts_agent import TextToSpeechAgent
from rai.agents.voice_agent import VoiceRecognitionAgent

__all__ = [
"ReActAgent",
"TextToSpeechAgent",
"ToolRunner",
"VoiceRecognitionAgent",
"create_conversational_agent",
"create_state_based_agent",
]
62 changes: 62 additions & 0 deletions src/rai_tts/README.md
@@ -0,0 +1,62 @@
# RAI Text To Speech

This is the [RAI](https://github.com/RobotecAI/rai) text to speech package.
It contains Agent definitions for the TTS feature.

## Models

Out of the box, the following models are supported:

### ElevenLabs

[ElevenLabs](https://elevenlabs.io/) is a proprietary cloud provider for TTS. Refer to the website for the documentation.
In order to use it, the `ELEVENLABS_API_KEY` environment variable must be set to a valid API key.

### OpenTTS

[OpenTTS](https://github.com/synesthesiam/opentts) is an open source model for TTS.
It can be easily set up using Docker:

```bash
docker run -it -p 5500:5500 synesthesiam/opentts:en --no-espeak
```

This starts a basic English OpenTTS server on the default port 5500.
Refer to the provider's documentation for available voices and options.
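
Once the server is up, speech can be requested over HTTP. The endpoint layout and voice name below are assumptions based on typical OpenTTS deployments, not part of this package; verify them against the OpenTTS documentation for your version:

```python
from urllib.parse import urlencode

# Assumed endpoint layout: OpenTTS exposes GET /api/tts with `voice` and
# `text` query parameters and responds with WAV audio. The default voice
# name here is only an example.
def opentts_request_url(
    text: str,
    voice: str = "glow-speak:en-us_mary_ann",
    host: str = "http://localhost:5500",
) -> str:
    return f"{host}/api/tts?" + urlencode({"voice": voice, "text": text})


url = opentts_request_url("hello robot")
# With the server running, fetch the audio with e.g.
# urllib.request.urlopen(url).read()
```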

### Custom Models

To add a custom TTS model, inherit from the `rai_tts.models.base.TTSModel` class.

You can use the following template:

```python
class MyTTSModel(TTSModel):
    def get_speech(self, text: str) -> AudioSegment:
        ...
        return AudioSegment()

    def get_tts_params(self) -> Tuple[int, int]:
        ...
        return sample_rate, channels
```

Such a model will work with the `TextToSpeechAgent` described below.

## Agents

### TextToSpeechAgent

The TextToSpeechAgent utilises ROS 2 and sounddevice `Connectors` to receive data and play it through a speaker.
It complies with the following ROS 2 API:

Subscription topic `/to_human: [rai_interfaces/msg/HRIMessage]`:
`message.text` will be parsed, run through the TTS model, and played through the speaker.
Subscription topic `/voice_commands: [std_msgs/msg/String]`:
The following values are accepted:

- `"play"`: allow playback through the speaker (if the voice queue is not empty)
- `"pause"`: pause playback through the speaker
- `"stop"`: stop the current playback and clear the queue
- `"tog_play"`: toggle between play and pause
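
For a quick manual test, a command can be published from a terminal with the standard ROS 2 CLI (this assumes a sourced ROS 2 environment and the default topic name above):

```shell
# Pause TTS playback; --once publishes a single message and exits
ros2 topic pub --once /voice_commands std_msgs/msg/String "data: 'pause'"
```
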
5 changes: 3 additions & 2 deletions src/rai_tts/rai_tts/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .tts_clients import ElevenLabsClient, OpenTTSClient
from .agents import TextToSpeechAgent
from .models import ElevenLabsTTS, OpenTTS

__all__ = ["ElevenLabsClient", "OpenTTSClient"]
__all__ = ["ElevenLabsTTS", "OpenTTS", "TextToSpeechAgent"]
19 changes: 19 additions & 0 deletions src/rai_tts/rai_tts/agents/__init__.py
@@ -0,0 +1,19 @@
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from rai_tts.agents.tts_agent import TextToSpeechAgent

__all__ = [
"TextToSpeechAgent",
]
@@ -1,4 +1,4 @@
# Copyright (C) 2024 Robotec.AI
# Copyright (C) 2025 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,8 +21,6 @@

from numpy._typing import NDArray
from pydub import AudioSegment
from std_msgs.msg import String

from rai.agents.base import BaseAgent
from rai.communication import (
ROS2HRIConnector,
@@ -33,6 +31,8 @@
from rai.communication.ros2.api import IROS2Message
from rai.communication.ros2.connectors import ROS2HRIMessage
from rai.communication.sound_device.connector import SoundDeviceMessage
from std_msgs.msg import String

from rai_interfaces.msg._hri_message import HRIMessage
from rai_tts.models.base import TTSModel
