
Commit 4826a5e

feat: basic multithreading implementation of transcription
1 parent 5dfbcf9 commit 4826a5e

6 files changed, +134 -80 lines changed


src/rai/rai/agents/base.py

-4
@@ -27,10 +27,6 @@ def __init__(
             connectors = {}
         self.connectors: dict[str, BaseConnector] = connectors
 
-    @abstractmethod
-    def setup(self, *args, **kwargs):
-        pass
-
     @abstractmethod
    def run(self, *args, **kwargs):
        pass
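
With the abstract `setup` removed, `run` is the only method a `BaseAgent` subclass still has to implement. A minimal sketch of a conforming subclass (the `EchoAgent` name and body are illustrative, not part of this commit):

from rai.agents.base import BaseAgent


class EchoAgent(BaseAgent):
    # setup() is no longer part of the abstract interface, so one-off
    # initialization can move into __init__ or the top of run().
    def run(self, *args, **kwargs):
        print("EchoAgent running", args, kwargs)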

src/rai/rai/agents/voice_agent.py

+115 -56
@@ -14,46 +14,61 @@
 
 
 import time
-from threading import Lock, Thread
-from typing import Any, List, cast
+from threading import Event, Lock, Thread
+from typing import Any, List, TypedDict
+from uuid import uuid4
 
 import numpy as np
 from numpy.typing import NDArray
 
 from rai.agents.base import BaseAgent
 from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
-from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
+from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel
 
 
-class VoiceRecognitionAgent(BaseAgent):
-    def __init__(self):
-        super().__init__(connectors={"microphone": StreamingAudioInputDevice()})
-        self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
-        self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
-        self.transcription_lock = Lock()
-        self.shared_samples = []
-        self.recording_started = False
-        self.ran_setup = False
+class ThreadData(TypedDict):
+    thread: Thread
+    event: Event
+    transcription: str
 
-    def __call__(self):
-        self.run()
 
-    def setup(
+class VoiceRecognitionAgent(BaseAgent):
+    def __init__(
         self,
         microphone_device_id: int,  # TODO: Change to name based instead of id based identification
         microphone_config: AudioInputDeviceConfig,
         transcription_model: BaseTranscriptionModel,
+        vad: BaseVoiceDetectionModel,
+        grace_period: float = 1.0,
     ):
-        self.connectors["microphone"] = cast(
-            StreamingAudioInputDevice, self.connectors["microphone"]
+        microphone = StreamingAudioInputDevice()
+        microphone.configure_device(
+            target=str(microphone_device_id), config=microphone_config
         )
+        super().__init__(connectors={"microphone": microphone})
         self.microphone_device_id = str(microphone_device_id)
-        self.connectors["microphone"].configure_device(
-            target=self.microphone_device_id, config=microphone_config
-        )
+        self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
+        self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
+
         self.transcription_model = transcription_model
-        self.ran_setup = True
-        self.running = False
+        self.transcription_lock = Lock()
+
+        self.vad: BaseVoiceDetectionModel = vad
+
+        self.grace_period = grace_period
+        self.grace_period_start = 0
+
+        self.recording_started = False
+        self.ran_setup = False
+
+        self.sample_buffer = []
+        self.sample_buffer_lock = Lock()
+        self.active_thread = ""
+        self.transcription_threads: dict[str, ThreadData] = {}
+        self.buffer_reminders: dict[str, list[NDArray]] = {}
+
+    def __call__(self):
+        self.run()
 
     def add_detection_model(
         self, model: BaseVoiceDetectionModel, pipeline: str = "record"
@@ -70,49 +85,93 @@ def run(self):
         self.listener_handle = self.connectors["microphone"].start_action(
             self.microphone_device_id, self.on_new_sample
         )
-        self.transcription_thread = Thread(target=self._transcription_function)
-        self.transcription_thread.start()
 
     def stop(self):
         self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
-        self.transcription_thread.join()
+        to_finish = list(self.transcription_threads.keys())
+        while len(to_finish) > 0:
+            for thread_id in self.transcription_threads:
+                if self.transcription_threads[thread_id]["event"].is_set():
+                    self.transcription_threads[thread_id]["thread"].join()
+                    to_finish.remove(thread_id)
+                else:
+                    print(f"Waiting for transcription of {thread_id} to finish")
 
     def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
-        should_stop = self.should_stop_recording(indata)
-        if self.should_start_recording(indata):
+        sample_time = time.time()
+        with self.sample_buffer_lock:
+            self.sample_buffer.append(indata)
+            if not self.recording_started and len(self.sample_buffer) > 5:
+                self.sample_buffer = self.sample_buffer[-5:]
+
+        voice_detected, output_parameters = self.vad.detected(indata, {})
+        should_record = False
+        # TODO: second condition is temporary
+        if voice_detected and not self.recording_started:
+            should_record = self.should_record(indata, output_parameters)
+
+        if should_record:
+            print("Start recording")
             self.recording_started = True
-        if self.recording_started and not should_stop:
-            with self.transcription_lock:
-                self.shared_samples.extend(indata)
+            thread_id = str(uuid4())[0:8]
+            transcription_thread = Thread(
+                target=self.transcription_thread,
+                args=[thread_id],
+            )
+            transcription_finished = Event()
+            self.active_thread = thread_id
+            transcription_thread.start()
+            self.transcription_threads[thread_id] = {
+                "thread": transcription_thread,
+                "event": transcription_finished,
+                "transcription": "",
+            }
+
+        if voice_detected:
+            self.grace_period_start = sample_time
 
-    def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
-        output_parameters = {}
+        if (
+            self.recording_started
+            and sample_time - self.grace_period_start > self.grace_period
+        ):
+            print("Stop recording")
+            self.recording_started = False
+            self.grace_period_start = 0
+            with self.sample_buffer_lock:
+                self.buffer_reminders[self.active_thread] = self.sample_buffer
+                self.sample_buffer = []
+            self.active_thread = ""
+
+    def should_record(
+        self, audio_data: NDArray, input_parameters: dict[str, Any]
+    ) -> bool:
         for model in self.should_record_pipeline:
-            should_listen, output_parameters = model.detected(
-                audio_data, output_parameters
-            )
-            print(should_listen, output_parameters)
-            if not should_listen:
-                return False
-        return True
-
-    def should_stop_recording(self, audio_data: NDArray[np.int16]) -> bool:
-        output_parameters = {}
-        for model in self.should_stop_pipeline:
-            should_listen, output_parameters = model.detected(
-                audio_data, output_parameters
-            )
-            if should_listen:
+            detected, output = model.detected(audio_data, input_parameters)
+            print(f"Detected: {detected}: {output}")
+            if detected:
                 return True
         return False
 
-    def _transcription_function(self):
-        while self.running:
-            time.sleep(0.1)
-            # critical section for samples
-            with self.transcription_lock:
-                samples = np.array(self.shared_samples)
-                self.shared_samples = []
-            # end critical section for samples
-            self.transcription_model.add_samples(samples)
+    def transcription_thread(self, identifier: str):
+        with self.transcription_lock:
+            while self.active_thread == identifier:
+                with self.sample_buffer_lock:
+                    if len(self.sample_buffer) == 0:
+                        continue
+                    audio_data = self.sample_buffer.copy()
+                    self.sample_buffer = []
+                audio_data = np.concatenate(audio_data)
+                self.transcription_model.transcribe(audio_data)
+
+            # transcription of the remainder of the buffer
+            with self.sample_buffer_lock:
+                if identifier in self.buffer_reminders:
+                    audio_data = self.buffer_reminders[identifier]
+                    audio_data = np.concatenate(audio_data)
+                    self.transcription_model.transcribe(audio_data)
+                    del self.buffer_reminders[identifier]
+            transcription = self.transcription_model.consume_transcription()
+            self.transcription_threads[identifier]["transcription"] = transcription
+            self.transcription_threads[identifier]["event"].set()
+        # TODO: sending the transcription
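
For context, a rough usage sketch of the reworked constructor-based API. The device id, the `AudioInputDeviceConfig` contents, and the model constructor arguments are placeholders, and `LocalWhisper`/`SileroVAD` being importable from `rai_asr.models` is an assumption:

import time

from rai.agents.voice_agent import VoiceRecognitionAgent
from rai.communication import AudioInputDeviceConfig
from rai_asr.models import LocalWhisper, SileroVAD

microphone_config = AudioInputDeviceConfig(...)  # placeholder field values
agent = VoiceRecognitionAgent(
    microphone_device_id=0,  # placeholder device id
    microphone_config=microphone_config,
    transcription_model=LocalWhisper("tiny", 16000),
    vad=SileroVAD(...),  # placeholder constructor arguments
    grace_period=1.0,
)
agent.run()     # starts the microphone listener; transcription threads
                # are spawned on demand from on_new_sample
time.sleep(30)  # let it capture some speech
agent.stop()    # blocks until every transcription thread has finished

for thread_id, data in agent.transcription_threads.items():
    print(thread_id, data["transcription"])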

src/rai_asr/rai_asr/models/base.py

+7 -4
@@ -35,10 +35,13 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         self.sample_rate = sample_rate
         self.language = language
 
-    @abstractmethod
-    def add_samples(self, data: NDArray[np.int16]):
-        pass
+        self.latest_transcription = ""
+
+    def consume_transcription(self) -> str:
+        ret = self.latest_transcription
+        self.latest_transcription = ""
+        return ret
 
     @abstractmethod
-    def transcribe(self) -> str:
+    def transcribe(self, data: NDArray[np.int16]):
         pass
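
The base class now owns the transcription buffer: implementations of `transcribe(data)` append to `latest_transcription`, and callers drain it with `consume_transcription`. A minimal sketch of that contract with a dummy model (illustrative only, not part of this commit):

import numpy as np
from numpy.typing import NDArray

from rai_asr.models.base import BaseTranscriptionModel


class DummyTranscriptionModel(BaseTranscriptionModel):
    # Stand-in that records chunk sizes instead of running real ASR.
    def transcribe(self, data: NDArray[np.int16]):
        self.latest_transcription += f"[{len(data)} samples] "


model = DummyTranscriptionModel("dummy", sample_rate=16000)
model.transcribe(np.zeros(1600, dtype=np.int16))
model.transcribe(np.zeros(3200, dtype=np.int16))
print(model.consume_transcription())  # -> "[1600 samples] [3200 samples] "
print(model.consume_transcription())  # -> "" (cleared by the first consume)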

src/rai_asr/rai_asr/models/local_whisper.py

+11 -13
@@ -15,6 +15,7 @@
 from typing import cast
 
 import numpy as np
+import torch
 import whisper
 from numpy._typing import NDArray
 
@@ -24,24 +25,21 @@
 class LocalWhisper(BaseTranscriptionModel):
     def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
-        self.whisper = whisper.load_model(self.model_name)
+        if torch.cuda.is_available():
+            print("Using CUDA")
+            self.whisper = whisper.load_model(self.model_name, device="cuda")
+        else:
+            self.whisper = whisper.load_model(self.model_name)
 
         self.samples = None
 
-    def add_samples(self, data: NDArray[np.int16]):
+    def transcribe(self, data: NDArray[np.int16]):
         normalized_data = data.astype(np.float32) / 32768.0
-        self.samples = (
-            np.concatenate([self.samples, normalized_data])
-            if self.samples is not None
-            else data
-        )
-
-    def transcribe(self) -> str:
-        if self.samples is None:
-            raise ValueError("No samples to transcribe")
+        print("Starting transcription")
         result = whisper.transcribe(
-            self.whisper, self.samples
+            self.whisper, normalized_data
         )  # TODO: handling of additional transcribe arguments (perhaps in model init)
+        print("Finished transcription")
         transcription = result["text"]
         transcription = cast(str, transcription)
-        return transcription
+        self.latest_transcription += transcription
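
Sketch of the resulting call pattern, assuming 16 kHz int16 microphone chunks (the model size and chunk contents are placeholders):

import numpy as np

from rai_asr.models.local_whisper import LocalWhisper

model = LocalWhisper("tiny", 16000)  # loads on CUDA automatically when available

# Each chunk is normalized from int16 to float32 in [-1, 1) and
# transcribed immediately; the text accumulates on the model.
chunk = np.zeros(16000, dtype=np.int16)  # one second of silence, placeholder
model.transcribe(chunk)

print(model.consume_transcription())  # accumulated text, cleared after the call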

src/rai_asr/rai_asr/models/open_wake_word.py

+1 -2
@@ -37,12 +37,11 @@ def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
     def detected(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
-        print(len(audio_data))
         predictions = self.model.predict(audio_data)
         ret = input_parameters.copy()
         ret.update({self.model_name: {"predictions": predictions}})
         for key, value in predictions.items():
             if value > self.threshold:
-                self.model.reset()
+                # self.model.reset()
                 return True, ret
         return False, ret
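
With `reset()` commented out, the openWakeWord model keeps its internal state across calls instead of being cleared on every hit. A sketch of exercising `detected` directly (the `OpenWakeWord` class name, model path, and frame size are assumptions; openWakeWord models typically consume 80 ms frames, i.e. 1280 samples at 16 kHz):

import numpy as np

from rai_asr.models.open_wake_word import OpenWakeWord  # class name assumed

wake_word = OpenWakeWord("path/to/wake_word.tflite", threshold=0.5)  # placeholder path

chunk = np.zeros(1280, dtype=np.int16)  # one 80 ms frame at 16 kHz
detected, params = wake_word.detected(chunk, {})
print(detected, params)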

src/rai_asr/rai_asr/models/silero_vad.py

-1
@@ -57,6 +57,5 @@ def detected(
         ).item()
         ret = input_parameters.copy()
         ret.update({self.model_name: {"vad_confidence": vad_confidence}})
-        self.model.reset_states()  # NOTE: see streaming example at the bottom https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies
 
         return vad_confidence > self.threshold, ret
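
Dropping `reset_states()` keeps the recurrent VAD state across chunks, matching the streaming usage in the linked Silero examples: each call now scores a chunk in the context of the ongoing stream. A sketch of the per-chunk loop (the `SileroVAD` class name and constructor arguments are assumptions; Silero VAD expects 512-sample chunks at 16 kHz):

import numpy as np

from rai_asr.models.silero_vad import SileroVAD  # class name assumed

vad = SileroVAD(...)  # placeholder constructor arguments

params: dict = {}
for _ in range(10):
    chunk = np.zeros(512, dtype=np.int16)  # placeholder audio
    voice_detected, params = vad.detected(chunk, params)
    # vad_confidence in params now reflects the stream so far, not an
    # isolated chunk, because the model state is no longer reset.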
