
Commit 06fdbca

feat: functional multithreading transcription
1 parent 4826a5e commit 06fdbca

4 files changed: +53 −20 lines

src/rai/rai/agents/voice_agent.py (+23 −9)

@@ -13,9 +13,10 @@
 # limitations under the License.
 
 
+import logging
 import time
 from threading import Event, Lock, Thread
-from typing import Any, List, TypedDict
+from typing import Any, List, Optional, TypedDict
 from uuid import uuid4
 
 import numpy as np
@@ -40,7 +41,12 @@ def __init__(
         transcription_model: BaseTranscriptionModel,
         vad: BaseVoiceDetectionModel,
         grace_period: float = 1.0,
+        logger: Optional[logging.Logger] = None,
     ):
+        if logger is None:
+            self.logger = logging.getLogger(__name__)
+        else:
+            self.logger = logger
         microphone = StreamingAudioInputDevice()
         microphone.configure_device(
             target=str(microphone_device_id), config=microphone_config
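
The new optional logger argument falls back to a module-level logger when the caller does not supply one. A minimal standalone sketch of that pattern (the Worker class name is illustrative, not part of this commit):

import logging
from typing import Optional


class Worker:
    def __init__(self, logger: Optional[logging.Logger] = None):
        # Use the injected logger if given, otherwise fall back to the module logger.
        self.logger = logging.getLogger(__name__) if logger is None else logger


logging.basicConfig(level=logging.INFO)
Worker().logger.info("default module logger")
Worker(logger=logging.getLogger("voice")).logger.info("injected logger")
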
@@ -87,16 +93,20 @@ def run(self):
         )
 
     def stop(self):
+        self.logger.info("Stopping voice agent")
         self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
-        to_finish = list(self.transcription_threads.keys())
-        while len(to_finish) > 0:
+        to_finish = len(list(self.transcription_threads.keys()))
+        while to_finish > 0:
             for thread_id in self.transcription_threads:
                 if self.transcription_threads[thread_id]["event"].is_set():
                     self.transcription_threads[thread_id]["thread"].join()
-                    to_finish.remove(thread_id)
+                    to_finish -= 1
                 else:
-                    print(f"Waiting for transcription of {thread_id} to finish")
+                    self.logger.info(
+                        f"Waiting for transcription of {thread_id} to finish..."
+                    )
+        self.logger.info("Voice agent stopped")
 
     def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
         sample_time = time.time()
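
The rewritten stop() waits for outstanding transcription threads by checking their completion Events and joining them, counting down instead of mutating the dict it iterates over. A rough, self-contained sketch of the same Event-then-join handshake (the worker body below is a stand-in, not the agent's code):

import time
from threading import Event, Thread


def worker(done: Event) -> None:
    time.sleep(0.2)  # stand-in for transcription work
    done.set()       # signal completion so the owner can join safely


threads = {}
for i in range(3):
    done = Event()
    t = Thread(target=worker, args=(done,))
    t.start()
    threads[str(i)] = {"thread": t, "event": done}

for thread_id, entry in threads.items():
    entry["event"].wait()  # block until this worker reports it is done
    entry["thread"].join()
    print(f"transcription thread {thread_id} joined")
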
@@ -112,7 +122,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
         should_record = self.should_record(indata, output_parameters)
 
         if should_record:
-            print("Start recording")
+            self.logger.info("starting recording...")
             self.recording_started = True
             thread_id = str(uuid4())[0:8]
             transcription_thread = Thread(
@@ -129,13 +139,14 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             }
 
         if voice_detected:
+            self.logger.debug("Voice detected... resetting grace period")
             self.grace_period_start = sample_time
 
         if (
             self.recording_started
             and sample_time - self.grace_period_start > self.grace_period
         ):
-            print("Stop recording")
+            self.logger.info("Grace period ended... stopping recording")
             self.recording_started = False
             self.grace_period_start = 0
             with self.sample_buffer_lock:
@@ -148,12 +159,12 @@ def should_record(
     ) -> bool:
         for model in self.should_record_pipeline:
             detected, output = model.detected(audio_data, input_parameters)
-            print(f"Detected: {detected}: {output}")
             if detected:
                 return True
         return False
 
     def transcription_thread(self, identifier: str):
+        self.logger.info(f"transcription thread {identifier} started")
        with self.transcription_lock:
            while self.active_thread == identifier:
                with self.sample_buffer_lock:
@@ -171,7 +182,10 @@ def transcription_thread(self, identifier: str):
                 audio_data = np.concatenate(audio_data)
                 self.transcription_model.transcribe(audio_data)
                 del self.buffer_reminders[identifier]
+        # self.transcription_model.save_wav(f"{identifier}.wav")
         transcription = self.transcription_model.consume_transcription()
         self.transcription_threads[identifier]["transcription"] = transcription
         self.transcription_threads[identifier]["event"].set()
-        # TODO: sending the transcription
+        # TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
+        self.logger.info(f"transcription thread {identifier} finished")
+        print(transcription)
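
The transcription worker drains audio chunks that the microphone callback keeps appending, with both sides holding the same lock while touching the buffer. A minimal producer/consumer sketch of that hand-off (the module-level buffer and function names are illustrative only):

import numpy as np
from threading import Lock

sample_buffer: list[np.ndarray] = []
sample_buffer_lock = Lock()


def on_new_sample(indata: np.ndarray) -> None:
    # Producer side: the audio callback appends chunks under the lock.
    with sample_buffer_lock:
        sample_buffer.append(indata)


def drain_buffer() -> np.ndarray:
    # Consumer side: the transcription thread takes everything buffered so far.
    with sample_buffer_lock:
        chunks = list(sample_buffer)
        sample_buffer.clear()
    return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int16)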

src/rai/rai/communication/sound_device_connector.py (+3 −6)

@@ -30,7 +30,6 @@ def __init__(self, msg: str):
 class AudioInputDeviceConfig(TypedDict):
     block_size: int
     consumer_sampling_rate: int
-    target_sampling_rate: int
     dtype: str
     device_number: Optional[int]

@@ -44,7 +43,6 @@ class ConfiguredAudioInputDevice:
         sample_rate (int): Device sample rate
         consumer_sampling_rate (int): The sampling rate of the consumer
         window_size_samples (int): The size of the window in samples
-        target_sampling_rate (int): The target sampling rate
         dtype (str): The data type of the audio samples
     """

@@ -58,7 +56,6 @@ def __init__(self, config: AudioInputDeviceConfig):
         self.window_size_samples = int(
             config["block_size"] * self.sample_rate / config["consumer_sampling_rate"]
         )
-        self.target_sampling_rate = int(config["target_sampling_rate"])
         self.dtype = config["dtype"]

@@ -108,9 +105,9 @@ def start_action(
 
         def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
             indata = indata.flatten()
-            sample_time_length = len(indata) / target_device.target_sampling_rate
-            if target_device.sample_rate != target_device.target_sampling_rate:
-                indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate))  # type: ignore
+            sample_time_length = len(indata) / target_device.sample_rate
+            if target_device.sample_rate != target_device.consumer_sampling_rate:
+                indata = resample(indata, int(sample_time_length * target_device.consumer_sampling_rate))  # type: ignore
             flag_dict = {
                 "input_overflow": status.input_overflow,
                 "input_underflow": status.input_underflow,

src/rai_asr/rai_asr/models/local_whisper.py (+27 −4)

@@ -26,20 +26,43 @@ class LocalWhisper(BaseTranscriptionModel):
     def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
         if torch.cuda.is_available():
-            print("Using CUDA")
             self.whisper = whisper.load_model(self.model_name, device="cuda")
         else:
             self.whisper = whisper.load_model(self.model_name)
 
-        self.samples = None
+        # TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging
+        # self.samples = None
+
+    def consume_transcription(self) -> str:
+        ret = super().consume_transcription()
+        # self.samples = None
+        return ret
+
+    # def save_wav(self, output_filename: str):
+    #     assert self.samples is not None, "No samples to save"
+    #     combined_samples = self.samples
+    #     if combined_samples.dtype.kind == "f":
+    #         combined_samples = np.clip(combined_samples, -1.0, 1.0)
+    #         combined_samples = (combined_samples * 32767).astype(np.int16)
+    #     elif combined_samples.dtype != np.int16:
+    #         combined_samples = combined_samples.astype(np.int16)
+
+    #     with wave.open(output_filename, "wb") as wav_file:
+    #         n_channels = 1
+    #         sampwidth = 2
+    #         wav_file.setnchannels(n_channels)
+    #         wav_file.setsampwidth(sampwidth)
+    #         wav_file.setframerate(self.sample_rate)
+    #         wav_file.writeframes(combined_samples.tobytes())
 
     def transcribe(self, data: NDArray[np.int16]):
+        # self.samples = (
+        #     np.concatenate((self.samples, data)) if self.samples is not None else data
+        # )
         normalized_data = data.astype(np.float32) / 32768.0
-        print("Starting transcription")
         result = whisper.transcribe(
             self.whisper, normalized_data
         )  # TODO: handling of additional transcribe arguments (perhaps in model init)
-        print("Finished transcription")
         transcription = result["text"]
         transcription = cast(str, transcription)
         self.latest_transcription += transcription
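
Whisper expects float32 audio in [-1.0, 1.0], hence the division by 32768.0 before transcription, and the commented-out save_wav helper writes the raw int16 samples with the standard-library wave module. A standalone sketch of both steps (file name, tone, and sample rate are example values):

import wave

import numpy as np

sample_rate = 16000
t = np.arange(sample_rate) / sample_rate
samples = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)  # 1 s of a 440 Hz tone

# Normalize int16 PCM to float32 in [-1.0, 1.0] before handing it to the transcription model.
normalized = samples.astype(np.float32) / 32768.0

# Dump the raw samples to a mono 16-bit WAV for debugging.
with wave.open("debug.wav", "wb") as wav_file:
    wav_file.setnchannels(1)  # mono
    wav_file.setsampwidth(2)  # 2 bytes per sample == int16
    wav_file.setframerate(sample_rate)
    wav_file.writeframes(samples.tobytes())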

tests/communication/test_sound_device_connector.py (−1)

@@ -31,7 +31,6 @@ def device_config():
     return {
         "block_size": 1024,
         "consumer_sampling_rate": 44100,
-        "target_sampling_rate": 16000,
         "dtype": "float32",
     }
3736
