Skip to content

Commit 8576b84

Browse files
committed
feat: working singleterminal setup
1 parent 3d8d64c commit 8576b84

File tree

9 files changed

+490
-33
lines changed

9 files changed

+490
-33
lines changed

examples/s2s/asr.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import signal
17+
import time
18+
19+
import rclpy
20+
from rai.agents import VoiceRecognitionAgent
21+
from rai.communication.sound_device.api import SoundDeviceConfig
22+
23+
from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD
24+
25+
# Defaults for the --vad-threshold and --oww-threshold command-line options.
VAD_THRESHOLD = 0.8 # Note that this might be different depending on your device
26+
OWW_THRESHOLD = 0.1 # Note that this might be different depending on your device
27+
28+
# Default for --vad-sampling-rate; the CLI only accepts 8000 or 16000.
VAD_SAMPLING_RATE = 16000 # Or 8000
29+
# Default for --block-size (passed to the sound device config as block_size).
DEFAULT_BLOCKSIZE = 1280
30+
31+
32+
def parse_arguments() -> argparse.Namespace:
    """Parse the command-line options of the ASR example.

    Recognised options are ``--vad-threshold``, ``--oww-threshold``,
    ``--vad-sampling-rate``, ``--block-size`` and ``--device-name``; their
    defaults come from the module-level constants.  Options that are not
    recognised are reported on stdout and otherwise ignored, so the script
    can be started with arguments meant for other programs.

    Returns:
        The namespace holding the recognised options.
    """
    parser = argparse.ArgumentParser(
        description="Voice Activity Detection and Wake Word Detection Configuration",
        allow_abbrev=True,
    )

    # Help texts use argparse's "%(default)s" placeholder so the documented
    # default can never drift from the value actually passed as ``default=``
    # (the VAD help used to advertise 0.5 while the real default is 0.8).
    parser.add_argument(
        "--vad-threshold",
        type=float,
        default=VAD_THRESHOLD,
        help="Voice Activity Detection threshold (default: %(default)s)",
    )
    parser.add_argument(
        "--oww-threshold",
        type=float,
        default=OWW_THRESHOLD,
        help="OpenWakeWord threshold (default: %(default)s)",
    )
    parser.add_argument(
        "--vad-sampling-rate",
        type=int,
        choices=[8000, 16000],
        default=VAD_SAMPLING_RATE,
        help="VAD sampling rate (default: %(default)s)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=DEFAULT_BLOCKSIZE,
        help="Audio block size (default: %(default)s)",
    )
    parser.add_argument(
        "--device-name",
        type=str,
        default="default",
        help="Microphone device name (default: '%(default)s')",
    )

    # parse_known_args() returns the leftovers instead of aborting on them.
    args, unknown = parser.parse_known_args()

    if unknown:
        print(f"Ignoring unknown arguments: {unknown}")

    return args
78+
79+
80+
if __name__ == "__main__":
    # Script entry point: wire the microphone, the detection models and the
    # transcription model into a VoiceRecognitionAgent and keep it running
    # until the process receives SIGINT.
    args = parse_arguments()

    microphone_configuration = SoundDeviceConfig(
        stream=True,
        channels=1,
        device_name=args.device_name,
        block_size=args.block_size,
        consumer_sampling_rate=args.vad_sampling_rate,
        dtype="int16",
        device_number=None,
        is_input=True,
        is_output=False,
    )
    vad = SileroVAD(args.vad_sampling_rate, args.vad_threshold)
    oww = OpenWakeWord("hey jarvis", args.oww_threshold)
    whisper = LocalWhisper("tiny", args.vad_sampling_rate)
    # whisper = OpenAIWhisper("whisper-1", args.vad_sampling_rate, "en")

    rclpy.init()
    ros2_name = "rai_asr_agent"

    agent = VoiceRecognitionAgent(microphone_configuration, ros2_name, whisper, vad)
    agent.add_detection_model(oww, pipeline="record")

    # NOTE(review): run() is presumably non-blocking, since the keep-alive
    # loop below is only reached once it returns - confirm against the agent.
    agent.run()

    def cleanup(signum, frame):
        """SIGINT handler: stop the agent, shut rclpy down and exit with 0."""
        print("\nCustom handler: Caught SIGINT (Ctrl+C).")
        print("Performing cleanup")
        agent.stop()
        rclpy.shutdown()
        # Raise SystemExit directly: the bare ``exit`` name is only injected
        # by the ``site`` module for interactive sessions and is not
        # guaranteed to exist (e.g. ``python -S`` or frozen executables).
        raise SystemExit(0)

    signal.signal(signal.SIGINT, cleanup)

    print("Runnin")
    # Keep the main thread alive so the SIGINT handler can be delivered.
    while True:
        time.sleep(1)

examples/s2s/conversational.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import logging
17+
import signal
18+
import time
19+
from queue import Queue
20+
from threading import Event, Thread
21+
from typing import Dict, List
22+
23+
import rclpy
24+
from langchain_core.callbacks import BaseCallbackHandler
25+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
26+
from rai.agents.base import BaseAgent
27+
from rai.communication import BaseConnector
28+
from rai.communication.ros2.api import IROS2Message
29+
from rai.communication.ros2.connectors import ROS2HRIConnector, TopicConfig
30+
from rai.utils.model_initialization import get_llm_model
31+
32+
from rai_interfaces.msg import HRIMessage as InterfacesHRIMessage
33+
34+
# NOTE: the Agent code included here is temporary until a dedicated speech agent is created
35+
# it can still serve as a reference for writing your own RAI agents
36+
37+
38+
class LLMTextHandler(BaseCallbackHandler):
    """Callback handler that forwards streamed LLM output in small chunks.

    Tokens are accumulated in ``token_buffer`` and sent to every target of
    the connector as an ``AIMessage`` when a punctuation token arrives, when
    the buffer grows past 100 characters, or when the LLM run ends.
    """

    # A token equal to one of these triggers a flush of the buffer.
    _FLUSH_TOKENS = (".", "?", "!", ",", ";", ":")
    # Flush as soon as the buffer holds more characters than this.
    _MAX_BUFFERED_CHARS = 100

    def __init__(self, connector: ROS2HRIConnector):
        self.connector = connector
        self.token_buffer = ""

    def _flush(self):
        """Send the buffered text, if any, and reset the buffer."""
        if not self.token_buffer:
            return
        self.connector.send_all_targets(AIMessage(content=self.token_buffer))
        self.token_buffer = ""

    def on_llm_new_token(self, token: str, **kwargs):
        """Buffer ``token`` and flush on punctuation or an over-long buffer."""
        self.token_buffer = self.token_buffer + token
        buffer_too_long = len(self.token_buffer) > self._MAX_BUFFERED_CHARS
        if buffer_too_long or token in self._FLUSH_TOKENS:
            self._flush()

    def on_llm_end(
        self,
        response,
        *,
        run_id,
        parent_run_id=None,
        **kwargs,
    ):
        """Send whatever is still buffered when the LLM run finishes."""
        self._flush()
60+
61+
62+
class S2SConversationalAgent(BaseAgent):
    """Agent that answers incoming human messages with a streaming LLM.

    Messages received on ``/from_human`` are queued by the subscriber
    callback.  A worker thread concatenates the queued texts, appends them to
    the dialogue history and invokes the LLM; the answer is streamed to the
    connector targets (``/to_human``) through ``LLMTextHandler``.
    """

    def __init__(self, connectors: Dict[str, BaseConnector]):  # type: ignore
        super().__init__(connectors=connectors)
        # Dialogue so far, starting with the system prompt.
        self.message_history: List[HumanMessage | AIMessage | SystemMessage] = [
            SystemMessage(
                content="Pretend you are a robot. Answer as if you were a robot."
            )
        ]
        # Filled by the subscriber callback, drained by the worker thread.
        self.speech_queue: Queue[InterfacesHRIMessage] = Queue()

        self.llm = get_llm_model(model_type="complex_model", streaming=True)
        # Create the thread state before the subscriber is registered so the
        # instance is fully initialised by the time callbacks can fire.
        self.main_thread = None
        self.stop_thread = Event()
        self._setup_ros_connector()

    def run(self):
        """Start the worker thread that processes queued human messages."""
        logging.info("Running S2SConversationalAgent")
        self.main_thread = Thread(target=self._main_loop)
        self.main_thread.start()

    def _main_loop(self):
        """Poll the queue and answer the collected speech until stopped."""
        while not self.stop_thread.is_set():
            time.sleep(0.01)
            speech = ""
            while not self.speech_queue.empty():
                speech += "".join(self.speech_queue.get().text)
                logging.info("Received human speech %s!", speech)
            if speech == "":
                continue

            self.message_history.append(HumanMessage(content=speech))
            assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
            # Send the whole history (system prompt and earlier turns), not
            # just the latest utterance; otherwise the system prompt is never
            # applied and the model has no memory of the conversation.
            # A copy is passed so the model never sees later mutations.
            ai_answer = self.llm.invoke(
                list(self.message_history),
                config={"callbacks": [LLMTextHandler(self.connectors["ros2"])]},
            )
            self.message_history.append(ai_answer)  # type: ignore

    def _on_from_human(self, msg: IROS2Message):
        """Subscriber callback: queue a message received on ``/from_human``."""
        assert isinstance(msg, InterfacesHRIMessage)
        logging.info("Received message from human: %s", msg.text)
        self.speech_queue.put(msg)

    def _setup_ros_connector(self):
        """Create the ROS 2 HRI connector and store it under ``"ros2"``."""
        self.connectors["ros2"] = ROS2HRIConnector(
            sources=[
                (
                    "/from_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        is_subscriber=True,
                        source_author="human",
                        subscriber_callback=self._on_from_human,
                    ),
                )
            ],
            targets=[
                (
                    "/to_human",
                    TopicConfig(
                        "rai_interfaces/msg/HRIMessage",
                        source_author="ai",
                        is_subscriber=False,
                    ),
                )
            ],
        )

    def stop(self):
        """Stop the worker thread, then shut the ROS 2 connector down."""
        # Stop and join the worker first: it may still be streaming an answer
        # through the connector, which must therefore stay usable until the
        # thread has finished.
        self.stop_thread.set()
        if self.main_thread is not None:
            self.main_thread.join()
        assert isinstance(self.connectors["ros2"], ROS2HRIConnector)
        self.connectors["ros2"].shutdown()
136+
137+
138+
def parse_arguments():
    """Parse the command line of the conversational example.

    The parser declares no options of its own; anything found on the command
    line is reported on stdout and ignored.

    Returns:
        The (empty) namespace produced by argparse.
    """
    cli = argparse.ArgumentParser(
        allow_abbrev=True,
        description="Text To Speech Configuration",
    )

    # parse_known_args() hands back unrecognised arguments instead of exiting.
    known_and_unknown = cli.parse_known_args()
    args = known_and_unknown[0]
    unknown = known_and_unknown[1]

    if len(unknown) > 0:
        print(f"Ignoring unknown arguments: {unknown}")

    return args
151+
152+
153+
if __name__ == "__main__":
    # Script entry point: start the conversational agent and keep the
    # process alive until it receives SIGINT.
    # The parser defines no options; the call is kept because it reports
    # any unexpected command-line arguments.
    parse_arguments()
    rclpy.init()
    agent = S2SConversationalAgent(connectors={})
    agent.run()

    def cleanup(signum, frame):
        """SIGINT handler: stop the agent, shut rclpy down and exit with 0."""
        print("\nCustom handler: Caught SIGINT (Ctrl+C).")
        print("Performing cleanup")
        agent.stop()
        rclpy.shutdown()
        # Raise SystemExit directly: the bare ``exit`` name is only injected
        # by the ``site`` module for interactive sessions and is not
        # guaranteed to exist (e.g. ``python -S`` or frozen executables).
        raise SystemExit(0)

    signal.signal(signal.SIGINT, cleanup)

    print("Runnin")
    # Keep the main thread alive so the SIGINT handler can be delivered.
    while True:
        time.sleep(1)

examples/s2s/run.sh

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#!/usr/bin/env bash
2+
# Directory where the scripts are located
3+
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
4+
5+
# Array to store PIDs of background processes
6+
declare -a PIDS
7+
8+
# Function to run a script with the given arguments
9+
run_script() {
    # $1 is the Python script to launch; every further argument is passed on.
    local target="$1"

    # Start the script in the background with the remaining arguments.
    python3 "$target" "${@:2}" &

    # Remember the PID of the job just started so it can be signalled later.
    PIDS+=("$!")
}
16+
17+
# Function to handle Ctrl+C (SIGINT)
18+
handle_sigint() {
    # SIGINT trap: forward the signal to every recorded child that is still
    # alive, wait for all of them to finish, then exit with status 0.
    echo -e "\nReceived SIGINT, forwarding to all running Python processes..."

    local child
    for child in "${PIDS[@]}"; do
        # kill -0 only probes whether the process still exists.
        kill -0 "$child" 2>/dev/null || continue
        echo "Sending SIGINT to process $child"
        kill -SIGINT "$child"
    done

    echo "Waiting for all processes to exit..."
    wait

    echo "All processes have exited. Cleaning up and exiting."
    exit 0
}
35+
36+
# Main logic
37+
main() {
    # Launch every Python script found under $SCRIPT_DIR in parallel, passing
    # all of this script's arguments to each, and wait for them to finish.
    # Returns 1 if any script exits with a non-zero status.
    local -a scripts
    local script child_pid
    local failures=0

    # Set up trap for SIGINT (Ctrl+C)
    trap handle_sigint SIGINT

    # Find all Python scripts in the scripts directory
    mapfile -t scripts < <(find "$SCRIPT_DIR" -name "*.py")

    # If no scripts found, exit
    if [ ${#scripts[@]} -eq 0 ]; then
        echo "No Python scripts found in $SCRIPT_DIR"
        exit 1
    fi

    echo "Found ${#scripts[@]} Python scripts in $SCRIPT_DIR"

    # Run all scripts in parallel with all arguments properly quoted
    for script in "${scripts[@]}"; do
        run_script "$script" "$@"
    done

    echo "All scripts are running in the background. Press Ctrl+C to stop them."

    # Wait for each child individually: a bare `wait` always returns 0, which
    # would hide failing scripts behind the success message below.
    for child_pid in "${PIDS[@]}"; do
        if ! wait "$child_pid"; then
            echo "Process $child_pid exited with a non-zero status." >&2
            failures=$((failures + 1))
        fi
    done

    if [ "$failures" -ne 0 ]; then
        echo "$failures script(s) failed." >&2
        return 1
    fi

    echo "All scripts completed successfully."
}
64+
65+
# Call main with all arguments properly quoted
66+
main "$@"

0 commit comments

Comments
 (0)