Skip to content

Commit 9d08988

Browse files
committed
feat: add configurable voice agent basic version
1 parent 95e7124 commit 9d08988

File tree

7 files changed

+187
-2
lines changed

7 files changed

+187
-2
lines changed

src/rai/rai/agents/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from rai.agents.conversational_agent import create_conversational_agent
1616
from rai.agents.state_based import create_state_based_agent
1717
from rai.agents.tool_runner import ToolRunner
18+
from rai.agents.voice_agent import VoiceRecognitionAgent
1819

1920
__all__ = [
2021
"ToolRunner",
2122
"create_conversational_agent",
2223
"create_state_based_agent",
24+
"VoiceRecognitionAgent",
2325
]

src/rai/rai/agents/base.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from abc import ABC, abstractmethod
17+
from typing import Optional
18+
19+
from rai.communication import BaseConnector
20+
21+
22+
class BaseAgent(ABC):
23+
def __init__(
24+
self, connectors: Optional[dict[str, BaseConnector]] = None, *args, **kwargs
25+
):
26+
if connectors is None:
27+
connectors = {}
28+
self.connectors: dict[str, BaseConnector] = connectors
29+
30+
@abstractmethod
31+
def setup(self, *args, **kwargs):
32+
pass
33+
34+
@abstractmethod
35+
def run(self, *args, **kwargs):
36+
pass

src/rai/rai/agents/voice_agent.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from threading import Lock, Thread
17+
from typing import Any, List, Tuple
18+
19+
import numpy as np
20+
from numpy.typing import NDArray
21+
22+
from rai.agents.base import BaseAgent
23+
from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
24+
from rai_asr.models.base import BaseVoiceDetectionModel
25+
26+
27+
class VoiceRecognitionAgent(BaseAgent):
28+
def __init__(self):
29+
super().__init__(connectors={"microphone": StreamingAudioInputDevice()})
30+
self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
31+
self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
32+
self.transcription_lock = Lock()
33+
self.shared_samples = []
34+
self.recording_started = False
35+
self.ran_setup = False
36+
37+
def __call__(self):
38+
self.run()
39+
40+
def setup(
41+
self, microphone_device_id: int, microphone_config: AudioInputDeviceConfig
42+
):
43+
assert isinstance(self.connectors["microphone"], StreamingAudioInputDevice)
44+
self.microphone_device_id = str(microphone_device_id)
45+
self.connectors["microphone"].configure_device(
46+
target=self.microphone_device_id, config=microphone_config
47+
)
48+
self.ran_setup = True
49+
50+
def run(self):
51+
self.listener_handle = self.connectors["microphone"].start_action(
52+
self.microphone_device_id, self.on_new_sample
53+
)
54+
self.transcription_thread = Thread(target=self._transcription_function)
55+
self.transcription_thread.start()
56+
57+
def stop(self):
58+
self.connectors["microphone"].terminate_action(self.listener_handle)
59+
self.transcription_thread.join()
60+
61+
def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
62+
should_stop, should_cancel = self.should_stop_recording(indata)
63+
print(indata)
64+
if should_cancel:
65+
self.cancel_task()
66+
if (self.recording_started and not should_stop) or (
67+
self.should_start_recording(indata)
68+
):
69+
with self.transcription_lock:
70+
self.shared_samples.extend(indata)
71+
72+
def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
73+
output_parameters = {}
74+
for model in self.should_record_pipeline:
75+
should_listen, output_parameters = model.detected(
76+
audio_data, output_parameters
77+
)
78+
if not should_listen:
79+
return False
80+
return True
81+
82+
def should_stop_recording(self, audio_data: NDArray[np.int16]) -> Tuple[bool, bool]:
83+
output_parameters = {}
84+
for model in self.should_stop_pipeline:
85+
should_listen, output_parameters = model.detected(
86+
audio_data, output_parameters
87+
)
88+
# TODO: Add handling output parametrs for checking if should cancel
89+
if should_listen:
90+
return False, False
91+
return True, False
92+
93+
def _transcription_function(self):
94+
with self.transcription_lock:
95+
samples = np.array(self.shared_samples)
96+
print(samples)
97+
self.shared_samples = []

src/rai/rai/communication/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313
# limitations under the License.
1414

1515
from .base_connector import BaseConnector, BaseMessage
16-
from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice
16+
from .sound_device_connector import (
17+
AudioInputDeviceConfig,
18+
SoundDeviceError,
19+
StreamingAudioInputDevice,
20+
)
1721

1822
__all__ = [
1923
"BaseMessage",
2024
"BaseConnector",
2125
"StreamingAudioInputDevice",
2226
"SoundDeviceError",
27+
"AudioInputDeviceConfig",
2328
]

src/rai/rai/communication/sound_device_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def start_action(
9999
self,
100100
target: str,
101101
on_feedback: Callable[[np.ndarray, dict[str, Any]], None],
102-
on_finish: Callable = lambda _: None,
102+
on_finish: Callable = lambda: None,
103103
) -> str:
104104

105105
target_device = self.configred_devices.get(target)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .base import BaseVoiceDetectionModel
16+
17+
__all__ = ["BaseVoiceDetectionModel"]

src/rai_asr/rai_asr/models/base.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from abc import ABC, abstractmethod
17+
from typing import Any, Tuple
18+
19+
from numpy._typing import NDArray
20+
21+
22+
class BaseVoiceDetectionModel(ABC):
23+
24+
@abstractmethod
25+
def detected(
26+
self, audio_data: NDArray, input_parameters: dict[str, Any]
27+
) -> Tuple[bool, dict[str, Any]]:
28+
pass

0 commit comments

Comments
 (0)