
Commit 5dfbcf9

fix: apply suggestions from preliminary review

1 parent 7871662

5 files changed: +76 −30 lines changed


src/rai/rai/agents/voice_agent.py (+40 −19)

@@ -13,15 +13,16 @@
 # limitations under the License.
 
 
+import time
 from threading import Lock, Thread
-from typing import Any, List, Tuple
+from typing import Any, List, cast
 
 import numpy as np
 from numpy.typing import NDArray
 
 from rai.agents.base import BaseAgent
 from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
-from rai_asr.models.base import BaseVoiceDetectionModel
+from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
 
 
 class VoiceRecognitionAgent(BaseAgent):
@@ -38,34 +39,50 @@ def __call__(self):
         self.run()
 
     def setup(
-        self, microphone_device_id: int, microphone_config: AudioInputDeviceConfig
+        self,
+        microphone_device_id: int,  # TODO: Change to name based instead of id based identification
+        microphone_config: AudioInputDeviceConfig,
+        transcription_model: BaseTranscriptionModel,
     ):
-        assert isinstance(self.connectors["microphone"], StreamingAudioInputDevice)
+        self.connectors["microphone"] = cast(
+            StreamingAudioInputDevice, self.connectors["microphone"]
+        )
         self.microphone_device_id = str(microphone_device_id)
         self.connectors["microphone"].configure_device(
             target=self.microphone_device_id, config=microphone_config
         )
+        self.transcription_model = transcription_model
         self.ran_setup = True
+        self.running = False
+
+    def add_detection_model(
+        self, model: BaseVoiceDetectionModel, pipeline: str = "record"
+    ):
+        if pipeline == "record":
+            self.should_record_pipeline.append(model)
+        elif pipeline == "stop":
+            self.should_stop_pipeline.append(model)
+        else:
+            raise ValueError("Pipeline should be either 'record' or 'stop'")
 
     def run(self):
+        self.running = True
         self.listener_handle = self.connectors["microphone"].start_action(
             self.microphone_device_id, self.on_new_sample
         )
         self.transcription_thread = Thread(target=self._transcription_function)
         self.transcription_thread.start()
 
     def stop(self):
+        self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
         self.transcription_thread.join()
 
     def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
-        should_stop, should_cancel = self.should_stop_recording(indata)
-        print(indata)
-        if should_cancel:
-            self.cancel_task()
-        if (self.recording_started and not should_stop) or (
-            self.should_start_recording(indata)
-        ):
+        should_stop = self.should_stop_recording(indata)
+        if self.should_start_recording(indata):
+            self.recording_started = True
+        if self.recording_started and not should_stop:
            with self.transcription_lock:
                self.shared_samples.extend(indata)
 
@@ -75,23 +92,27 @@ def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
             should_listen, output_parameters = model.detected(
                 audio_data, output_parameters
             )
+            print(should_listen, output_parameters)
             if not should_listen:
                 return False
         return True
 
-    def should_stop_recording(self, audio_data: NDArray[np.int16]) -> Tuple[bool, bool]:
+    def should_stop_recording(self, audio_data: NDArray[np.int16]) -> bool:
         output_parameters = {}
         for model in self.should_stop_pipeline:
             should_listen, output_parameters = model.detected(
                 audio_data, output_parameters
             )
-            # TODO: Add handling output parametrs for checking if should cancel
            if should_listen:
-                return False, False
-        return True, False
+                return True
+        return False
 
     def _transcription_function(self):
-        with self.transcription_lock:
-            samples = np.array(self.shared_samples)
-            print(samples)
-            self.shared_samples = []
+        while self.running:
+            time.sleep(0.1)
+            # critical section for samples
+            with self.transcription_lock:
+                samples = np.array(self.shared_samples)
+                self.shared_samples = []
+            # end critical section for samples
+            self.transcription_model.add_samples(samples)
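Taken together, the new surface is setup → add_detection_model → run/stop. Below is a minimal wiring sketch, not part of the commit: the agent constructor call, device id, and model parameters are illustrative assumptions, and the AudioInputDeviceConfig fields depend on rai.communication.

import time

from rai.agents.voice_agent import VoiceRecognitionAgent
from rai.communication import AudioInputDeviceConfig
from rai_asr.models.local_whisper import LocalWhisper
from rai_asr.models.silero_vad import SileroVAD

agent = VoiceRecognitionAgent()  # assumed constructor; BaseAgent may need arguments
agent.setup(
    microphone_device_id=0,  # still id-based, per the TODO in setup()
    microphone_config=AudioInputDeviceConfig(...),  # fill in for your hardware
    transcription_model=LocalWhisper("tiny", sample_rate=16000),
)

# Detection models gate the two pipelines served by add_detection_model().
agent.add_detection_model(SileroVAD(sampling_rate=16000), pipeline="record")
agent.add_detection_model(SileroVAD(sampling_rate=16000), pipeline="stop")

agent.run()       # starts the microphone action plus the transcription thread
time.sleep(10.0)  # let samples accumulate via on_new_sample()
agent.stop()      # clears self.running so the worker loop exits, then joins

Note that nothing in this commit calls transcribe(); run() only buffers audio into the transcription model, so a caller would presumably invoke agent.transcription_model.transcribe() after stop().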

src/rai_asr/rai_asr/models/base.py (+4 −3)

@@ -36,8 +36,9 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         self.language = language
 
     @abstractmethod
-    def transcribe(self, data: NDArray[np.int16]) -> str:
+    def add_samples(self, data: NDArray[np.int16]):
         pass
 
-    def __call__(self, data: NDArray[np.int16]) -> str:
-        return self.transcribe(data)
+    @abstractmethod
+    def transcribe(self) -> str:
+        pass
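The base class now splits accumulation from decoding: add_samples() is fed incrementally from the agent's worker thread, and transcribe() is called once a result is wanted (the old __call__ shortcut is gone). A throwaway subclass makes the contract concrete; the placeholder "transcription" below is an illustration, not a real model.

from typing import List

import numpy as np
from numpy.typing import NDArray

from rai_asr.models.base import BaseTranscriptionModel


class DummyTranscription(BaseTranscriptionModel):
    """Buffers samples and reports how much audio it has seen."""

    def __init__(self, model_name: str = "dummy", sample_rate: int = 16000):
        super().__init__(model_name, sample_rate)
        self.buffered: List[np.int16] = []

    def add_samples(self, data: NDArray[np.int16]):
        # Called repeatedly; implementations are expected to accumulate state.
        self.buffered.extend(data)

    def transcribe(self) -> str:
        # A real model would decode the buffered audio here.
        seconds = len(self.buffered) / self.sample_rate
        return f"<{seconds:.2f}s of audio buffered>"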

src/rai_asr/rai_asr/models/local_whisper.py (+19 −4)

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
+
 import numpy as np
 import whisper
 from numpy._typing import NDArray
@@ -24,9 +26,22 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
         self.whisper = whisper.load_model(self.model_name)
 
-    def transcribe(self, data: NDArray[np.int16]) -> str:
-        result = whisper.transcribe(self.whisper, data.astype(np.float32) / 32768.0)
+        self.samples = None
+
+    def add_samples(self, data: NDArray[np.int16]):
+        normalized_data = data.astype(np.float32) / 32768.0
+        self.samples = (
+            np.concatenate([self.samples, normalized_data])
+            if self.samples is not None
+            else normalized_data
+        )
+
+    def transcribe(self) -> str:
+        if self.samples is None:
+            raise ValueError("No samples to transcribe")
+        result = whisper.transcribe(
+            self.whisper, self.samples
+        )  # TODO: handling of additional transcribe arguments (perhaps in model init)
         transcription = result["text"]
-        # NOTE: this is only for type enforcement, doesn't need to work on runtime
-        assert isinstance(transcription, str)
+        transcription = cast(str, transcription)
         return transcription
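A quick smoke test of the new incremental surface; the values here are illustrative, and the synthetic tone obviously will not decode to speech.

import numpy as np

from rai_asr.models.local_whisper import LocalWhisper

model = LocalWhisper("tiny", sample_rate=16000)  # "tiny" keeps the download small

# One second of a 440 Hz tone as int16, standing in for microphone chunks.
t = np.linspace(0.0, 1.0, 16000, dtype=np.float32)
chunk = (np.sin(2.0 * np.pi * 440.0 * t) * 32767.0).astype(np.int16)

model.add_samples(chunk)  # converted to float32 in [-1, 1] internally
model.add_samples(chunk)  # subsequent chunks are concatenated

print(model.transcribe())  # raises ValueError if no samples were added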

src/rai_asr/rai_asr/models/open_ai_whisper.py (+11 −2)

@@ -36,10 +36,19 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
             self.openai_client.audio.transcriptions.create,
             model=self.model_name,
         )
+        self.samples = None
 
-    def transcribe(self, data: NDArray[np.int16]) -> str:
+    def add_samples(self, data: NDArray[np.int16]):
+        normalized_data = data.astype(np.float32) / 32768.0
+        self.samples = (
+            np.concatenate([self.samples, normalized_data])
+            if self.samples is not None
+            else normalized_data
+        )
+
+    def transcribe(self) -> str:
         with io.BytesIO() as temp_wav_buffer:
-            wavfile.write(temp_wav_buffer, self.sample_rate, data)
+            wavfile.write(temp_wav_buffer, self.sample_rate, self.samples)
             temp_wav_buffer.seek(0)
             temp_wav_buffer.name = "temp.wav"
             response = self.model(file=temp_wav_buffer, language=self.language)
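For reference, the in-memory WAV hand-off that transcribe() performs can be sketched standalone. This assumes, as the code does, that scipy.io.wavfile.write serializes a float32 array (expected range −1.0 to 1.0) as an IEEE-float WAV, and that setting .name on the buffer lets the OpenAI client infer the upload format; the client construction is elided.

import io

import numpy as np
from scipy.io import wavfile

samples = np.zeros(16000, dtype=np.float32)  # one second of silence at 16 kHz

with io.BytesIO() as temp_wav_buffer:
    wavfile.write(temp_wav_buffer, 16000, samples)
    temp_wav_buffer.seek(0)
    temp_wav_buffer.name = "temp.wav"  # BytesIO accepts the attribute; clients read it
    # response = openai_client.audio.transcriptions.create(
    #     model="whisper-1", file=temp_wav_buffer
    # )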

src/rai_asr/rai_asr/models/silero_vad.py (+2 −2)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Tuple
+from typing import Any, Literal, Tuple
 
 import numpy as np
 import torch
@@ -22,7 +22,7 @@
 
 
 class SileroVAD(BaseVoiceDetectionModel):
-    def __init__(self, sampling_rate=16000, threshold=0.5):
+    def __init__(self, sampling_rate: Literal[8000, 16000] = 16000, threshold=0.5):
         super(SileroVAD, self).__init__()
         self.model_name = "silero_vad"
         self.model, _ = torch.hub.load(
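The Literal[8000, 16000] annotation encodes the fact that the Silero VAD model only operates at those two sampling rates. Following the detected() contract the agent's pipelines use in voice_agent.py, a standalone call might look like the sketch below; the 512-sample window and random "audio" are assumptions for illustration only.

import numpy as np

from rai_asr.models.silero_vad import SileroVAD

vad = SileroVAD(sampling_rate=16000, threshold=0.5)

# Silero expects short fixed-size windows; 512 samples at 16 kHz is typical.
chunk = np.random.randint(-1000, 1000, size=512, dtype=np.int16)

should_listen, output_parameters = vad.detected(chunk, {})
print(should_listen, output_parameters)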
